| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoImageProcessor, DetrForObjectDetection | |
# A loaded DETR detector paired with the image processor that prepares its
# inputs and post-processes its outputs; produced by load_model().
ModelBundle = Tuple[
    DetrForObjectDetection,
    AutoImageProcessor,
]
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into a fully loaded RGB PIL image.

    *frame* is either raw encoded image bytes (``bytes``, ``bytearray`` or
    ``memoryview``) or any value whose ``str()`` is a filesystem path. A
    relative path is first resolved against the current working directory;
    if nothing exists there, the same relative path under *base_dir* is
    tried as a fallback.

    Args:
        frame: Encoded image bytes or a path-like value.
        base_dir: Fallback directory for relative paths.

    Returns:
        An independent image in ``"RGB"`` mode.

    Raises:
        Whatever ``PIL.Image.open`` raises for missing or undecodable input
        (e.g. ``FileNotFoundError``, ``PIL.UnidentifiedImageError``).
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        source: Any = BytesIO(frame)
    else:
        path = Path(str(frame))
        if not path.is_absolute():
            path = (Path.cwd() / path).resolve()
        if not path.exists():
            candidate = (base_dir / str(frame)).resolve()
            if candidate.exists():
                path = candidate
        source = path
    # Image.open is lazy and may keep the file handle open (always for
    # multi-frame formats). The context manager releases it deterministically;
    # convert() loads the pixels and returns a new image that does not need
    # the source to stay open.
    with Image.open(source) as img:
        return img.convert("RGB")
def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None:
    """Load the DETR detector and its image processor from this file's directory.

    Any positional or keyword arguments are accepted and ignored, so the
    function can be called with whatever a host framework passes.

    Returns:
        ``(model, processor)`` with the model switched to eval mode, or
        ``None`` when there is no ``config.json`` next to this file.
    """
    model_dir = Path(__file__).resolve().parent
    if (model_dir / "config.json").exists():
        location = str(model_dir)
        processor = AutoImageProcessor.from_pretrained(location)
        detector = DetrForObjectDetection.from_pretrained(location)
        detector.eval()
        return detector, processor
    return None
def run_model(
    model_bundle: ModelBundle,
    frame: "np.ndarray",
    *,
    threshold: float = 0.5,
) -> List[Dict[str, Any]]:
    """Run the DETR detector on one frame and return its detections.

    Args:
        model_bundle: ``(model, processor)`` pair as returned by ``load_model``.
        frame: Pixel array accepted by ``PIL.Image.fromarray``; presumably an
            ``(H, W, C)`` uint8 array (TODO confirm against callers). It is
            converted to 3-channel RGB before preprocessing.
        threshold: Minimum score a detection must reach to be kept. Defaults
            to 0.5, the value previously hard-coded here.

    Returns:
        One dict per kept detection with keys ``frame_idx`` (always 0),
        ``class``, ``bbox`` (``[x0, y0, x1, y1]`` in input-image pixels),
        ``score`` and ``track_id`` (``"f0-d<index>"``).
    """
    model, processor = model_bundle
    # Force RGB, as load_image does: grayscale or RGBA arrays would otherwise
    # reach a processor whose normalization expects exactly 3 channels.
    image = Image.fromarray(frame).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs,
        threshold=threshold,
        target_sizes=target_sizes,
    )[0]
    names = model.config.id2label or {}
    # "LABEL_0" is presumably the generic name a checkpoint gets when no
    # label names were configured; report it as this element's class.
    name_overrides = {"LABEL_0": "qr_code"}
    detections: List[Dict[str, Any]] = []
    for det_idx, (score, label, box) in enumerate(
        zip(results["scores"], results["labels"], results["boxes"])
    ):
        class_id = int(label.item())
        class_name = names.get(class_id, str(class_id))
        class_name = name_overrides.get(class_name, class_name)
        detections.append(
            {
                "frame_idx": 0,
                "class": class_name,
                "bbox": [float(x) for x in box.tolist()],
                "score": float(score.item()),
                "track_id": f"f0-d{det_idx}",
            }
        )
    return detections
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the QR detect element.

    The only option is ``--stdin-raw``. It combines ``action="store_true"``
    with ``default=True``, so its parsed value is ``True`` whether or not
    the flag is given; the flag is effectively accepted for compatibility.
    """
    cli = argparse.ArgumentParser(description="Run QR detect element.")
    stdin_raw_opts: Dict[str, Any] = dict(
        action="store_true",
        default=True,
        help="Read raw image bytes from stdin.",
    )
    cli.add_argument("--stdin-raw", **stdin_raw_opts)
    return cli
if __name__ == "__main__":
    # Parsed for --help and argument validation only; the result is unused
    # because the image is always read as raw bytes from stdin below.
    build_parser().parse_args()
    base_dir = Path(__file__).resolve().parent
    model_bundle = load_model()
    if model_bundle is None:
        # No model files next to this script: report "no detections".
        print("[]")
        sys.exit(0)
    try:
        image = load_image(sys.stdin.buffer.read(), base_dir)
    except Exception as exc:  # best-effort boundary: bad input => no detections
        # Keep stdout machine-readable ("[]"), but leave a diagnostic on
        # stderr instead of silently swallowing the failure.
        print(f"qr-detect: could not decode input image: {exc!r}", file=sys.stderr)
        print("[]")
        sys.exit(0)
    frame = np.array(image)
    output = run_model(model_bundle, frame)
    print(json.dumps(output))