| """Custom inference handler for RF-DETR Threat Detection on HuggingFace Inference Endpoints.""" |
|
|
| from typing import Any, Dict, List |
| import io |
| import numpy as np |
| from PIL import Image |
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoints handler wrapping an RF-DETR Nano threat detector.

    The weights are loaded once in ``__init__``. Each request payload is then
    handled by ``__call__``, which returns a list of labelled bounding boxes.
    """

    # Confidence threshold used when the request does not supply one.
    DEFAULT_THRESHOLD = 0.25

    def __init__(self, path: str = ""):
        """Load the fine-tuned RF-DETR Nano checkpoint.

        Args:
            path: Directory containing ``checkpoint_best_total.pth``.
        """
        # Imported inside __init__, so importing this module does not itself
        # require rfdetr to be installed.
        from rfdetr import RFDETRNano
        import os

        weights = os.path.join(path, "checkpoint_best_total.pth")
        self.model = RFDETRNano(resolution=960, pretrain_weights=weights)
        self.model.optimize_for_inference()

        # Model class id -> human-readable label. Ids missing from this map are
        # reported as "threat_<id>".
        # NOTE(review): ids start at 1; presumably id 0 is unused by this
        # checkpoint — confirm against the training dataset's class list.
        self.class_map = {1: "Gun", 2: "Explosive", 3: "Grenade", 4: "Knife"}

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run threat detection on one request payload.

        Args:
            data: Request body. ``data["inputs"]`` is the image in one of these forms:
                raw bytes, a base64 string (bare or as a ``data:`` URI), or an
                already-decoded image object that is passed to the model.
                ``data["parameters"]`` may hold ``{"threshold": <number>}``.
                The default threshold is ``DEFAULT_THRESHOLD``.

        Returns:
            One dict per detection, shaped as
            ``{"label": str, "score": float, "box": {"xmin", "ymin", "xmax", "ymax"}}``.
            The score is rounded to 4 decimals and box corners are truncated to ints.
            An empty list is returned when nothing is detected.

        Raises:
            ValueError: if ``inputs`` is missing or ``threshold`` is not numeric.
        """
        image = self._load_image(data.get("inputs"))
        threshold = self._parse_threshold(data.get("parameters"))
        detections = self.model.predict(image, threshold=threshold)
        return self._format_detections(detections)

    @staticmethod
    def _load_image(inputs: Any) -> Any:
        """Decode the request ``inputs`` field into an RGB image for the model."""
        if inputs is None:
            raise ValueError("Request payload is missing the 'inputs' field.")
        if isinstance(inputs, str):
            import base64

            # Accept data URIs ("data:image/png;base64,<payload>") as well as bare
            # base64. ':' is not in the base64 alphabet, so a bare payload can
            # never start with "data:".
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            inputs = base64.b64decode(inputs)
        if isinstance(inputs, (bytes, bytearray, memoryview)):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        # An already-decoded object (presumably a PIL image when the endpoint
        # deserializes image/* bodies — confirm) is passed through. If it exposes
        # a PIL-style colour mode other than RGB, it is normalised to RGB so it
        # matches the decoded-bytes path above.
        mode = getattr(inputs, "mode", None)
        if mode is not None and mode != "RGB" and hasattr(inputs, "convert"):
            return inputs.convert("RGB")
        return inputs

    @classmethod
    def _parse_threshold(cls, parameters: Any) -> float:
        """Return the confidence threshold requested in ``parameters``, as a float."""
        # A JSON body with "parameters": null arrives as None, not as a missing key.
        threshold = (parameters or {}).get("threshold")
        if threshold is None:
            return cls.DEFAULT_THRESHOLD
        try:
            return float(threshold)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid 'threshold' parameter: {threshold!r}") from err

    def _format_detections(self, detections: Any) -> List[Dict[str, Any]]:
        """Convert the model's detections object into JSON-serializable dicts."""
        class_ids = getattr(detections, "class_id", None)
        if class_ids is None or len(class_ids) == 0:
            return []

        results = []
        for cid, conf, bbox in zip(class_ids, detections.confidence, detections.xyxy):
            cid = int(cid)
            # Corners are truncated (not rounded) to ints, as before.
            xmin, ymin, xmax, ymax = (int(v) for v in bbox[:4])
            results.append({
                "label": self.class_map.get(cid, f"threat_{cid}"),
                "score": round(float(conf), 4),
                "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
            })
        return results
|
|