# Hosting-page header text captured along with the file (not program code):
# MTerryJack's picture
# Upload 8 files
# cbce3da verified
from __future__ import annotations
import argparse
import json
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, DetrForObjectDetection
# (model, processor) pair, in the order load_model() returns it and
# run_model() unpacks it.
ModelBundle = Tuple[DetrForObjectDetection, AutoImageProcessor]
def load_image(frame: Any, base_dir: Path) -> Image.Image:
    """Decode *frame* into an RGB PIL image.

    Args:
        frame: Encoded image data (``bytes``, ``bytearray`` or
            ``memoryview``), or any value whose ``str()`` is a file path.
        base_dir: Fallback directory tried when a path does not exist
            relative to the current working directory.

    Returns:
        The result of ``convert("RGB")`` on the opened image.

    Raises:
        Any exception raised by ``PIL.Image.open`` or ``convert`` is
        propagated unchanged (e.g. for a missing or undecodable source).
    """
    if isinstance(frame, (bytes, bytearray, memoryview)):
        source = BytesIO(frame)
    else:
        path = Path(str(frame))
        if not path.is_absolute():
            path = (Path.cwd() / path).resolve()
        if not path.exists():
            # Second chance: interpret the same string relative to base_dir.
            candidate = (base_dir / str(frame)).resolve()
            if candidate.exists():
                path = candidate
        source = path
    # convert() hands back a separate image, so the image opened here can be
    # closed immediately instead of keeping its file handle around until
    # garbage collection.
    with Image.open(source) as opened:
        return opened.convert("RGB")
def load_model(*_args: Any, **_kwargs: Any) -> ModelBundle | None:
    """Load the detector and its image processor from this file's directory.

    Positional and keyword arguments are accepted and ignored.

    Returns:
        ``(model, processor)`` with the model switched to eval mode, or
        ``None`` when no ``config.json`` sits next to this file.
    """
    weights_dir = Path(__file__).resolve().parent
    config_file = weights_dir / "config.json"
    if config_file.exists():
        location = str(weights_dir)
        image_processor = AutoImageProcessor.from_pretrained(location)
        detector = DetrForObjectDetection.from_pretrained(location)
        detector.eval()
        return detector, image_processor
    return None
def run_model(model_bundle: ModelBundle, frame: "np.ndarray") -> List[Dict[str, Any]]:
    """Run object detection on one frame and return JSON-ready detections.

    Args:
        model_bundle: ``(model, processor)`` pair.
        frame: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        One dict per detection kept at the 0.5 score threshold, with keys
        ``frame_idx`` (always 0), ``class``, ``bbox`` (list of floats taken
        from the post-processed box), ``score`` and ``track_id``
        (``"f0-d<index>"``).
    """
    picture = Image.fromarray(frame)
    detector, image_processor = model_bundle

    batch = image_processor(images=picture, return_tensors="pt")
    with torch.no_grad():
        raw_outputs = detector(**batch)

    # PIL reports (width, height); post-processing is given (height, width).
    width, height = picture.size
    per_image = image_processor.post_process_object_detection(
        raw_outputs,
        threshold=0.5,
        target_sizes=torch.tensor([(height, width)]),
    )
    found = per_image[0]

    id_to_name = detector.config.id2label or {}
    renamed = {"LABEL_0": "qr_code"}

    def _label_for(raw_label: Any) -> str:
        # Fall back to the numeric id as text when the config has no name,
        # then apply the override table.
        idx = int(raw_label.item())
        resolved = id_to_name.get(idx, str(idx))
        return renamed.get(resolved, resolved)

    return [
        {
            "frame_idx": 0,
            "class": _label_for(label),
            "bbox": [float(coord) for coord in box.tolist()],
            "score": float(score.item()),
            "track_id": f"f0-d{position}",
        }
        for position, (score, label, box) in enumerate(
            zip(found["scores"], found["labels"], found["boxes"])
        )
    ]
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    The only option is ``--stdin-raw``, a ``store_true`` flag whose default
    is already ``True``, so the parsed value is ``True`` with or without it.
    """
    cli = argparse.ArgumentParser(description="Run QR detect element.")
    stdin_flag = {
        "action": "store_true",
        "default": True,
        "help": "Read raw image bytes from stdin.",
    }
    cli.add_argument("--stdin-raw", **stdin_flag)
    return cli
if __name__ == "__main__":
    # Script entry point: read one encoded image from stdin and print the
    # detections as a JSON list on stdout. stdout carries only JSON; any
    # diagnostics go to stderr.

    # The parsed namespace is not used; parsing still rejects unknown
    # arguments and serves --help.
    build_parser().parse_args()
    base_dir = Path(__file__).resolve().parent
    model_bundle = load_model()
    if model_bundle is None:
        # Best-effort contract: emit an empty result and exit 0, but say why.
        print("model could not be loaded; emitting no detections", file=sys.stderr)
        print("[]")
        sys.exit(0)
    try:
        image = load_image(sys.stdin.buffer.read(), base_dir)
    except Exception as exc:
        # Deliberate catch-all at the process boundary: an unreadable input
        # yields an empty result, with the cause reported rather than dropped.
        print(f"could not load image from stdin: {exc!r}", file=sys.stderr)
        print("[]")
        sys.exit(0)
    frame = np.array(image)
    output = run_model(model_bundle, frame)
    print(json.dumps(output))