# Uploaded by MTerryJack ("Upload 9 files"), commit 4f72f5f (verified).
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
from PIL import Image
from ultralytics import YOLO
def load_class_names(base_dir: Path) -> dict[int, str]:
    """Read optional class labels from ``base_dir / "class_names.txt"``.

    A line's 0-based position in the file is used as its class id. Blank
    lines are skipped but still consume an id, so line N always maps to
    class N.

    Args:
        base_dir: Directory expected to contain ``class_names.txt``.

    Returns:
        Mapping of class id to stripped label; empty if the file is absent.
    """
    labels_path = base_dir / "class_names.txt"
    if not labels_path.exists():
        return {}
    # Decode explicitly: the platform default encoding (e.g. cp1252 on
    # Windows) would mangle or reject non-ASCII labels.
    text = labels_path.read_text(encoding="utf-8")
    return {
        idx: label
        for idx, label in enumerate(line.strip() for line in text.splitlines())
        if label
    }
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    Args:
        frame: Either encoded image bytes (``bytes``/``bytearray``/
            ``memoryview``) or any value whose ``str()`` is a file path.
            A relative path is resolved against the current working
            directory first; if that does not exist, ``base_dir / frame``
            is used when it exists.
        base_dir: Fallback directory for relative paths.

    Returns:
        The decoded image converted to RGB.

    Raises:
        Propagates errors from ``PIL.Image.open`` / decoding, e.g.
        ``FileNotFoundError`` when no candidate path exists.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        source: Any = BytesIO(frame)
    else:
        path = Path(str(frame))
        if not path.is_absolute():
            path = (Path.cwd() / path).resolve()
        if not path.exists():
            candidate = (base_dir / str(frame)).resolve()
            if candidate.exists():
                path = candidate
        source = path
    # Close the underlying file deterministically instead of relying on GC.
    # convert() returns a new image object, so it remains usable afterwards.
    with Image.open(source) as img:
        return img.convert("RGB")
def load_model(*_args: Any, **_kwargs: Any):
    """Build the YOLO detector from the ONNX weights stored beside this file.

    Any positional or keyword arguments are accepted and ignored.

    Returns:
        A ``YOLO`` instance created with ``task="detect"``, or ``None`` when
        ``yolov5s_weights.onnx`` is not present in this module's directory.
        If ``class_names.txt`` provides labels and the wrapped model exposes
        a ``names`` attribute, those labels are installed on it.
    """
    here = Path(__file__).resolve().parent
    weights = here / "yolov5s_weights.onnx"
    if weights.exists():
        detector = YOLO(str(weights), task="detect")
        labels = load_class_names(here)
        if labels:
            inner = getattr(detector, "model", None)
            # Only patch when the wrapped model exposes a ``names`` attribute.
            if hasattr(inner, "names"):
                inner.names = labels
        return detector
    return None
def run_model(model, frame: "np.ndarray") -> List[Dict[str, Any]]:
    """Run *model* on a single image and return JSON-serialisable detections.

    Args:
        model: A loaded Ultralytics ``YOLO`` detector.
        frame: Array accepted by ``PIL.Image.fromarray``. The CLI passes
            ``np.array`` of an RGB PIL image; other callers should confirm
            they supply the same layout.

    Returns:
        One dict per detected box with keys ``frame_idx`` (always 0),
        ``class`` (label string; ``class_names.txt`` takes precedence over
        the model's own names, then the numeric id as a string), ``bbox``
        (the box's ``xyxy`` coordinates as floats), ``score`` (confidence)
        and ``track_id`` (``"f0-d<index>"``). Empty when nothing is detected.
    """
    base_dir = Path(__file__).resolve().parent
    fallback_names = load_class_names(base_dir)
    image = Image.fromarray(frame)
    # Ultralytics' default (verbose) prediction logs a per-image summary to
    # stdout, which would corrupt the JSON the CLI prints there.
    results = model(image, verbose=False)
    if not results:
        return []
    result = results[0]
    if result.boxes is None:
        return []
    names = result.names or model.names or fallback_names or {}
    detections: List[Dict[str, Any]] = []
    for det_idx, box in enumerate(result.boxes):
        xyxy = box.xyxy[0].tolist()
        class_id = int(box.cls[0].item())
        label = fallback_names.get(class_id, names.get(class_id, str(class_id)))
        detections.append(
            {
                "frame_idx": 0,
                "class": label,
                "bbox": [float(x) for x in xyxy],
                "score": float(box.conf[0].item()),
                "track_id": f"f0-d{det_idx}",
            }
        )
    return detections
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    ``--stdin-raw`` is a ``store_true`` flag whose default is already
    ``True``, so it parses to ``True`` whether or not it is given. It is
    accepted for interface compatibility; the image is always read from
    stdin.
    """
    option_spec = dict(
        action="store_true",
        default=True,
        help="Read raw image bytes from stdin.",
    )
    cli = argparse.ArgumentParser(
        description="Run crowd counting detection (YOLOv5 ONNX).",
    )
    cli.add_argument("--stdin-raw", **option_spec)
    return cli
if __name__ == "__main__":
    # Parse so --help and unknown flags behave normally; --stdin-raw always
    # parses to True, so no option changes behaviour.
    build_parser().parse_args()
    base_dir = Path(__file__).resolve().parent
    model = load_model()
    if model is None:
        # stdout stays a valid (empty) JSON list; explain the reason on stderr.
        print(
            f"warning: model weights not found in {base_dir}; emitting no detections",
            file=sys.stderr,
        )
        print("[]")
        sys.exit(0)
    try:
        image = load_image(sys.stdin.buffer.read(), base_dir)
    except Exception as exc:  # CLI boundary: any unreadable input yields "[]"
        print(
            f"warning: could not decode image from stdin: {exc}",
            file=sys.stderr,
        )
        print("[]")
        sys.exit(0)
    frame = np.array(image)
    output = run_model(model, frame)
    print(json.dumps(output))