# Page header (not part of the script): MTerryJack's picture / Upload 10 files / 148aad1 verified
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
from PIL import Image
from ultralytics import YOLO
def load_keypoint_names(base_dir: Path) -> dict[int, str]:
    """Load the keypoint index -> label mapping stored in *base_dir*.

    Reads ``keypoints_metadata.json`` from *base_dir*. The file is expected to
    hold a JSON list whose first element is an object with a ``"keypoints"``
    object that maps stringified integer indices to labels.

    The lookup is best-effort: a missing, unreadable, unparseable, or
    unexpectedly shaped file yields an empty mapping instead of raising.

    Args:
        base_dir: Directory that contains ``keypoints_metadata.json``.

    Returns:
        Mapping from integer keypoint index to its label (coerced to ``str``).
        Entries whose key cannot be parsed as an integer are skipped.
    """
    metadata_path = base_dir / "keypoints_metadata.json"
    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        data = json.loads(metadata_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bytes
        # that are not valid UTF-8 (both exception types derive from it).
        return {}

    # Validate the shape explicitly so that well-formed JSON with an
    # unexpected layout degrades to "no names" rather than crashing.
    if not isinstance(data, list) or not data:
        return {}
    first_entry = data[0]
    if not isinstance(first_entry, dict):
        return {}
    keypoints = first_entry.get("keypoints")
    if not isinstance(keypoints, dict):
        return {}

    names: dict[int, str] = {}
    for idx_str, label in keypoints.items():
        try:
            idx = int(idx_str)
        except ValueError:
            continue
        names[idx] = str(label)
    return names
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    Args:
        frame: Either raw encoded image data (``bytes``, ``bytearray`` or
            ``memoryview``) or a value whose ``str()`` is a filesystem path.
        base_dir: Fallback directory tried when the path does not exist
            relative to the current working directory.

    Returns:
        A new image converted to mode ``"RGB"``.

    Raises:
        Any exception raised by ``Image.open`` or ``convert`` propagates to
        the caller; nothing is caught here.
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        source: Any = BytesIO(frame)
    else:
        path = Path(str(frame))
        if not path.is_absolute():
            path = (Path.cwd() / path).resolve()
        if not path.exists():
            # Fall back to a path relative to base_dir; keep the original
            # path (and let Image.open fail) if that does not exist either.
            candidate = (base_dir / str(frame)).resolve()
            if candidate.exists():
                path = candidate
        source = path

    # Image.open is lazy and keeps the underlying file open. convert() loads
    # the pixel data and returns a separate image, so the opened image can be
    # closed deterministically as soon as the conversion is done.
    with Image.open(source) as opened:
        return opened.convert("RGB")
def load_model(*_args: Any, **_kwargs: Any):
    """Create the YOLO pose model from ``weights.onnx`` beside this file.

    All positional and keyword arguments are accepted and ignored.

    Returns:
        A ``YOLO`` instance built with ``task="pose"``, or ``None`` when
        ``weights.onnx`` does not exist in this file's directory.
    """
    weights_file = Path(__file__).resolve().with_name("weights.onnx")
    return YOLO(str(weights_file), task="pose") if weights_file.exists() else None
def run_model(model, frame: "np.ndarray") -> List[Dict[str, Any]]:
    """Run pose inference on one frame and return its keypoints.

    Args:
        model: Callable that takes a PIL image and returns a sequence of
            results; only the first result is used.
        frame: Image array passed to ``Image.fromarray``.
            NOTE(review): presumably an RGB uint8 array of shape (H, W, 3),
            as produced by this script's entry point — confirm for other
            callers.

    Returns:
        One dict per detection, each of the form
        ``{"frame_idx": 0, "points": [{"id": ..., "x": ..., "y": ...}, ...]}``.
        ``id`` is the label from ``keypoints_metadata.json`` when one exists
        for that index, otherwise the index as a string. ``frame_idx`` is
        always 0. The list is empty when the model yields no results or the
        first result has no keypoints.
    """
    base_dir = Path(__file__).resolve().parent
    keypoint_names = load_keypoint_names(base_dir)

    results = model(Image.fromarray(frame))
    # Guard against an empty result sequence instead of raising IndexError.
    if not results:
        return []
    keypoints = results[0].keypoints
    if keypoints is None:
        return []

    outputs: List[Dict[str, Any]] = []
    for points in keypoints.xy:
        point_items = [
            {
                "id": keypoint_names.get(idx, str(idx)),
                "x": float(xy[0]),
                "y": float(xy[1]),
            }
            for idx, xy in enumerate(points.tolist())
        ]
        outputs.append({"frame_idx": 0, "points": point_items})
    return outputs
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    The only option is ``--stdin-raw``. It is a ``store_true`` flag whose
    default is already ``True``, so the parsed value is ``True`` whether or
    not the flag is given.
    """
    cli = argparse.ArgumentParser(
        description="Run basketball court keypoint detection (YOLO pose ONNX).",
    )
    stdin_raw_options = {
        "action": "store_true",
        "default": True,
        "help": "Read raw image bytes from stdin.",
    }
    cli.add_argument("--stdin-raw", **stdin_raw_options)
    return cli
if __name__ == "__main__":
    # Parse argv for validation and --help only; the namespace is not used.
    build_parser().parse_args()
    script_dir = Path(__file__).resolve().parent

    detector = load_model()
    if detector is None:
        # No model available: emit an empty JSON list and exit successfully.
        print("[]")
        sys.exit(0)

    try:
        # The whole of stdin is treated as one encoded image.
        picture = load_image(sys.stdin.buffer.read(), script_dir)
    except Exception:
        # Unreadable or undecodable input is reported as "no detections".
        print("[]")
        sys.exit(0)
    else:
        detections = run_model(detector, np.array(picture))
        print(json.dumps(detections))