File size: 1,852 Bytes
b11aebf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | """Custom inference handler for RF-DETR Threat Detection on HuggingFace Inference Endpoints."""
from typing import Any, Dict, List
import io
import numpy as np
from PIL import Image
class EndpointHandler:
    """Inference Endpoints handler wrapping a fine-tuned RF-DETR Nano detector.

    Following the HuggingFace custom-handler convention, ``__init__`` receives
    the repository path and the instance is then called with each request
    payload.
    """

    def __init__(self, path: str = ""):
        """Load the fine-tuned RF-DETR Nano checkpoint.

        Args:
            path: Directory containing ``checkpoint_best_total.pth``.
        """
        # Imported here rather than at module level, deferring the rfdetr
        # import until the handler is actually constructed.
        import os

        from rfdetr import RFDETRNano

        weights = os.path.join(path, "checkpoint_best_total.pth")
        self.model = RFDETRNano(resolution=960, pretrain_weights=weights)
        self.model.optimize_for_inference()
        # Model class id -> label; ids not listed here are reported as "threat_<id>".
        self.class_map = {1: "Gun", 2: "Explosive", 3: "Grenade", 4: "Knife"}

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run detection on one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is raw image bytes, a
                base64-encoded image string (optionally a ``data:`` URL), or
                any other object, which is passed to the model unchanged.
                ``data["parameters"]["threshold"]`` optionally overrides the
                default confidence threshold of 0.25.

        Returns:
            One dict per detection with ``label``, ``score`` (rounded to 4
            decimals) and an integer ``box`` (``xmin``/``ymin``/``xmax``/``ymax``).

        Raises:
            ValueError: if ``inputs`` is missing or the threshold is not numeric.
        """
        image = self._load_image(data.get("inputs"))
        # "parameters" (and "threshold" inside it) may be absent or explicitly null.
        parameters = data.get("parameters") or {}
        raw_threshold = parameters.get("threshold")
        # float() also accepts thresholds sent as JSON strings, e.g. "0.4".
        threshold = 0.25 if raw_threshold is None else float(raw_threshold)
        detections = self.model.predict(image, threshold=threshold)
        return self._format_detections(detections)

    @staticmethod
    def _load_image(inputs: Any) -> Any:
        """Decode the request ``inputs`` field into an image for the model."""
        if inputs is None:
            raise ValueError("Request payload is missing 'inputs'.")
        if isinstance(inputs, (bytes, bytearray)):
            raw = bytes(inputs)
        elif isinstance(inputs, str):
            import base64

            # Drop a "data:<mime>;base64," header. b64decode would otherwise keep
            # its alphabet characters and misalign (corrupt) the real payload.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            raw = base64.b64decode(inputs)
        else:
            # Presumably an already-decoded image; handed to the model as-is.
            return inputs
        # convert() loads the pixel data into a new image, so the source can be closed.
        with Image.open(io.BytesIO(raw)) as img:
            return img.convert("RGB")

    def _format_detections(self, detections: Any) -> List[Dict[str, Any]]:
        """Convert model detections into the endpoint's JSON-serialisable list."""
        class_ids = getattr(detections, "class_id", None)
        # Treat a missing, None, or empty class_id as "no detections".
        if class_ids is None or len(class_ids) == 0:
            return []
        results = []
        for cid, conf, bbox in zip(class_ids, detections.confidence, detections.xyxy):
            cid = int(cid)
            # int() truncates fractional pixel coordinates toward zero.
            xmin, ymin, xmax, ymax = (int(v) for v in bbox[:4])
            results.append({
                "label": self.class_map.get(cid, f"threat_{cid}"),
                "score": round(float(conf), 4),
                "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
            })
        return results
|