# Uploaded by MuayThaiLegz with huggingface_hub ("Upload handler.py with huggingface_hub"),
# commit b11aebf (verified).
"""Custom inference handler for RF-DETR Threat Detection on HuggingFace Inference Endpoints."""
from typing import Any, Dict, List
import io
import numpy as np
from PIL import Image
class EndpointHandler:
    """HuggingFace Inference Endpoints handler wrapping a fine-tuned RF-DETR Nano model.

    The handler is constructed once with the model repository path and is then
    called with each deserialized request payload.
    """

    def __init__(self, path: str = ""):
        """Load ``checkpoint_best_total.pth`` from ``path`` into an RF-DETR Nano model.

        Args:
            path: Directory that contains ``checkpoint_best_total.pth``.
        """
        import os

        # rfdetr is imported here rather than at module level, so it is only
        # required when the handler is actually constructed.
        from rfdetr import RFDETRNano

        weights = os.path.join(path, "checkpoint_best_total.pth")
        self.model = RFDETRNano(resolution=960, pretrain_weights=weights)
        self.model.optimize_for_inference()
        # Model class id -> label; ids not listed here are reported as "threat_<id>".
        self.class_map = {1: "Gun", 2: "Explosive", 3: "Grenade", 4: "Knife"}

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run detection on one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is raw image bytes, a
                base64-encoded image string (optionally a ``data:`` URI), or
                any other object, which is passed to the model unchanged.
                ``data["parameters"]["threshold"]`` optionally overrides the
                default confidence threshold of 0.25; ``parameters`` may be
                absent or null.

        Returns:
            One dict per detection with ``label``, ``score`` (rounded to 4
            decimals) and ``box`` (``xmin``/``ymin``/``xmax``/``ymax``, each
            truncated to ``int``). Empty list when nothing is detected.

        Raises:
            ValueError: If ``inputs`` is missing/None or ``threshold`` is not
                convertible to ``float``.
        """
        image = self._load_image(data.get("inputs"))

        # ``parameters`` may be omitted or sent as JSON null.
        parameters = data.get("parameters") or {}
        raw_threshold = parameters.get("threshold")
        try:
            threshold = 0.25 if raw_threshold is None else float(raw_threshold)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"threshold must be a number, got {raw_threshold!r}"
            ) from err

        detections = self.model.predict(image, threshold=threshold)
        return self._format_detections(detections)

    @staticmethod
    def _load_image(inputs: Any) -> Any:
        """Decode encoded image payloads to an RGB PIL image; pass other inputs through."""
        if inputs is None:
            raise ValueError("Request payload is missing 'inputs'.")
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        if isinstance(inputs, str):
            import base64

            # Strip a "data:<mime>;base64," prefix; b64decode would otherwise
            # keep its alphanumeric characters and corrupt the payload.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            return Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
        # NOTE(review): any other object (e.g. an image already decoded by the
        # endpoint toolkit) is forwarded as-is, without RGB conversion.
        return inputs

    def _format_detections(self, detections: Any) -> List[Dict[str, Any]]:
        """Convert the model's detections object into JSON-serializable dicts."""
        class_ids = getattr(detections, "class_id", None)
        # Guard against a missing or None class_id array so len() is never
        # called on None.
        if class_ids is None or len(class_ids) == 0:
            return []

        results = []
        for raw_cid, raw_conf, raw_box in zip(
            class_ids, detections.confidence, detections.xyxy
        ):
            cid = int(raw_cid)
            bbox = raw_box.tolist()
            results.append({
                "label": self.class_map.get(cid, f"threat_{cid}"),
                "score": round(float(raw_conf), 4),
                "box": {
                    "xmin": int(bbox[0]),
                    "ymin": int(bbox[1]),
                    "xmax": int(bbox[2]),
                    "ymax": int(bbox[3]),
                },
            })
        return results