File size: 4,973 Bytes
d66fee6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
from PIL import Image
import onnxruntime as ort
def load_class_names(base_dir: Path) -> dict[int, str]:
    """Read class labels from ``class_names.txt`` inside *base_dir*.

    Line ``i`` (0-based) of the file becomes the label for class id ``i``.
    Surrounding whitespace is stripped. Blank lines are skipped but still
    consume an index, so ids stay aligned with line numbers.

    Returns an empty dict when the file does not exist.
    """
    labels_path = base_dir / "class_names.txt"
    try:
        # Explicit UTF-8: the default encoding is locale-dependent and would
        # mis-decode non-ASCII labels on hosts with a legacy code page.
        text = labels_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}
    return {
        idx: label
        for idx, label in enumerate(line.strip() for line in text.splitlines())
        if label
    }
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Return *frame* as an RGB ``PIL.Image.Image``.

    Accepted inputs:

    * ``bytes`` / ``bytearray`` / ``memoryview`` -- encoded image data in
      any format Pillow can identify.
    * ``PIL.Image.Image`` or ``numpy.ndarray`` -- already-decoded pixels,
      converted to RGB.
    * anything else -- treated as a filesystem path via ``str(frame)``.
      A relative path is tried against the current working directory
      first. If nothing exists there, it is retried relative to *base_dir*.

    Errors raised by Pillow (e.g. ``FileNotFoundError``,
    ``PIL.UnidentifiedImageError``) propagate to the caller.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        # Context manager: the source buffer is released as soon as
        # convert() has copied the pixels out.
        with Image.open(BytesIO(frame)) as img:
            return img.convert("RGB")
    if isinstance(frame, Image.Image):
        return frame.convert("RGB")
    if isinstance(frame, np.ndarray):
        # Without this branch an array would be stringified into a bogus path.
        return Image.fromarray(frame).convert("RGB")
    path = Path(str(frame))
    if not path.is_absolute():
        path = (Path.cwd() / path).resolve()
    if not path.exists():
        # Fall back to resolving the same relative path under base_dir.
        candidate = (base_dir / str(frame)).resolve()
        if candidate.exists():
            path = candidate
    # Close the file handle deterministically instead of relying on GC.
    with Image.open(path) as img:
        return img.convert("RGB")
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> List[int]:
if boxes.size == 0:
return []
x1, y1, x2, y2 = boxes.T
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep: List[int] = []
while order.size > 0:
i = int(order[0])
keep.append(i)
if order.size == 1:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.clip(xx2 - xx1, 0, None)
h = np.clip(yy2 - yy1, 0, None)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
inds = np.where(iou <= iou_thresh)[0]
order = order[inds + 1]
return keep
def load_model(*_args: Any, **_kwargs: Any):
    """Load the bundled YOLOv5 ONNX model plus the metadata ``run_model`` needs.

    Any positional or keyword arguments are accepted and ignored, so callers
    that pass configuration keep working.

    Returns:
        ``None`` if ``yolov5s_weights.onnx`` is not present next to this
        file. Otherwise a dict with:

        * ``"session"`` -- an ``onnxruntime.InferenceSession`` using the
          CPU execution provider,
        * ``"input_name"`` -- the name of the model's first input,
        * ``"names"`` -- class id -> label map read from ``class_names.txt``
          (empty if that file is missing),
        * ``"size"`` -- side length in pixels of the square network input.
    """
    base_dir = Path(__file__).resolve().parent
    model_path = base_dir / "yolov5s_weights.onnx"
    if not model_path.exists():
        return None
    session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
    model_input = session.get_inputs()[0]

    # run_model feeds a single NCHW tensor. A static export reports integer
    # dims such as [1, 3, 640, 640], while dynamic axes come back as str/None.
    # Use the model's own square side when it is fixed, so a non-640 export
    # does not fail on a shape mismatch; otherwise keep YOLOv5's 640 default.
    size = 640
    dims = list(model_input.shape or [])
    if len(dims) == 4:
        height, width = dims[2], dims[3]
        if isinstance(height, int) and height > 0 and height == width:
            size = height

    return {
        "session": session,
        "input_name": model_input.name,
        "names": load_class_names(base_dir),
        "size": size,
    }
def run_model(
    model,
    frame: "np.ndarray",
    conf_thresh: float = 0.25,
    iou_thresh: float = 0.45,
) -> List[Dict[str, Any]]:
    """Detect objects in one frame using the dict returned by ``load_model``.

    Args:
        model: dict with keys ``session``, ``input_name``, ``names`` and
            ``size``. Anything that is not a dict (e.g. ``None`` when the
            weights are missing) yields ``[]``.
        frame: pixel array accepted by ``PIL.Image.fromarray``. Presumably
            HxWx3 uint8 RGB -- confirm with callers.
        conf_thresh: a candidate is kept only if
            ``objectness * best_class_confidence`` is strictly greater.
        iou_thresh: IoU above which a lower-scoring box is suppressed.
            Suppression is class-agnostic, so boxes of different classes
            also suppress each other.

    Returns:
        One dict per detection, highest score first, with keys:

        * ``frame_idx`` -- always 0,
        * ``class`` -- the label, or the id as a string when unknown,
        * ``bbox`` -- ``[x1, y1, x2, y2]`` in original-image pixels,
          clipped to the image,
        * ``score``,
        * ``track_id`` -- ``"f0-d<k>"``.
    """
    if not isinstance(model, dict):
        return []
    session: ort.InferenceSession = model["session"]
    input_name = model["input_name"]
    names: dict[int, str] = model["names"]
    size = int(model["size"])

    image = Image.fromarray(frame).convert("RGB")
    orig_w, orig_h = image.size
    # Plain stretch to size x size (no letterbox). Boxes are mapped back
    # below with independent x and y scale factors.
    resized = image.resize((size, size))
    inp = np.array(resized).astype("float32") / 255.0
    inp = np.transpose(inp, (2, 0, 1))[None, ...]  # HWC -> NCHW, batch of 1

    preds = session.run(None, {input_name: inp})[0][0]
    # Assumes the YOLOv5 head layout (num_candidates, 5 + num_classes):
    # cx, cy, w, h, objectness, then per-class scores -- TODO confirm for
    # other exports.
    if preds.ndim != 2 or preds.shape[1] < 6:
        return []

    class_scores = preds[:, 5:]
    class_ids = np.argmax(class_scores, axis=1)
    class_conf = class_scores[np.arange(class_scores.shape[0]), class_ids]
    scores = preds[:, 4] * class_conf

    keep = scores > conf_thresh
    if not keep.any():
        return []
    cx, cy, w, h = preds[keep, :4].T
    scores = scores[keep]
    class_ids = class_ids[keep]
    boxes_xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)

    keep_idx = _nms(boxes_xyxy, scores, iou_thresh)

    # Map from network-input pixels back to the original image (loop-invariant),
    # then clip so no box extends outside the frame.
    scale = np.array([orig_w / size, orig_h / size, orig_w / size, orig_h / size])
    bounds = np.array([orig_w, orig_h, orig_w, orig_h], dtype=float)

    detections: List[Dict[str, Any]] = []
    for det_idx, i in enumerate(keep_idx):
        xyxy = np.clip(boxes_xyxy[i] * scale, 0.0, bounds)
        class_id = int(class_ids[i])
        detections.append(
            {
                "frame_idx": 0,
                "class": names.get(class_id, str(class_id)),
                "bbox": [float(v) for v in xyxy],
                "score": float(scores[i]),
                "track_id": f"f0-d{det_idx}",
            }
        )
    return detections
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the detection script.

    ``--stdin-raw`` uses ``store_true`` with ``default=True``, so its value
    is True whether or not the flag is given. It is accepted only so that
    existing invocations that pass it keep working.
    """
    stdin_raw_opts = {
        "action": "store_true",
        "default": True,
        "help": "Read raw image bytes from stdin.",
    }
    cli = argparse.ArgumentParser(description="Run rotten fruit detection (YOLOv5 ONNX).")
    cli.add_argument("--stdin-raw", **stdin_raw_opts)
    return cli
if __name__ == "__main__":
    # The parsed arguments are not used. Parsing still provides --help and
    # rejects unknown options before any work is done.
    build_parser().parse_args()
    base_dir = Path(__file__).resolve().parent
    model = load_model()
    if model is None:
        # Keep the stdout contract (a JSON list, exit 0) but report why it
        # is empty on stderr so the failure is not silent.
        print("[warn] yolov5s_weights.onnx not found; emitting no detections", file=sys.stderr)
        print("[]")
        sys.exit(0)
    try:
        image = load_image(sys.stdin.buffer.read(), base_dir)
    except Exception as exc:  # best-effort CLI: any decode failure -> empty result
        print(f"[warn] could not decode input image: {exc}", file=sys.stderr)
        print("[]")
        sys.exit(0)
    output = run_model(model, np.array(image))
    print(json.dumps(output))
|