File size: 1,852 Bytes
b11aebf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Custom inference handler for RF-DETR Threat Detection on HuggingFace Inference Endpoints."""

from typing import Any, Dict, List
import io
import numpy as np
from PIL import Image


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints serving RF-DETR Nano.

    The handler is constructed with the model repository path and then called
    with each request payload; it returns one result dict per detection.
    """

    # Confidence threshold used when the request does not supply one.
    DEFAULT_THRESHOLD = 0.25

    def __init__(self, path: str = ""):
        """Load the fine-tuned RF-DETR Nano checkpoint found under ``path``.

        Args:
            path: Directory containing ``checkpoint_best_total.pth``.
        """
        # Imported here rather than at module level so that importing this
        # module does not itself require ``rfdetr``.
        from rfdetr import RFDETRNano
        import os

        weights = os.path.join(path, "checkpoint_best_total.pth")
        # NOTE(review): resolution presumably matches the training resolution
        # of this checkpoint — confirm before changing.
        self.model = RFDETRNano(resolution=960, pretrain_weights=weights)
        self.model.optimize_for_inference()

        # Model class id -> label; ids not listed are reported as "threat_<id>".
        self.class_map = {1: "Gun", 2: "Explosive", 3: "Grenade", 4: "Knife"}

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run detection on one image from the request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is raw image bytes, a
                base64-encoded image string (optionally a ``data:`` URI), or an
                already-decoded image object handed to the model as-is (PIL
                images are converted to RGB first). The optional
                ``data["parameters"]["threshold"]`` overrides the confidence
                threshold (default ``DEFAULT_THRESHOLD``).

        Returns:
            One dict per detection with keys ``label``, ``score`` (rounded to
            4 decimals) and ``box`` (``xmin``/``ymin``/``xmax``/``ymax`` as
            ints, truncated toward zero).

        Raises:
            ValueError: If ``inputs`` is missing or ``threshold`` is not numeric.
        """
        image = self._load_image(data.get("inputs"))
        threshold = self._parse_threshold(data)
        detections = self.model.predict(image, threshold=threshold)
        return self._to_results(detections)

    @staticmethod
    def _load_image(inputs: Any) -> Any:
        """Decode the request's ``inputs`` field into an image for the model."""
        if inputs is None:
            raise ValueError("Request payload must contain an 'inputs' field.")
        if isinstance(inputs, (bytes, bytearray)):
            return Image.open(io.BytesIO(inputs)).convert("RGB")
        if isinstance(inputs, str):
            import base64

            # Accept data URIs such as "data:image/png;base64,<payload>".
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            return Image.open(io.BytesIO(base64.b64decode(inputs))).convert("RGB")
        # A pre-decoded PIL image may be RGBA/grayscale; normalise it so it
        # matches what the bytes/base64 paths produce.
        if isinstance(inputs, Image.Image) and inputs.mode != "RGB":
            return inputs.convert("RGB")
        return inputs

    @classmethod
    def _parse_threshold(cls, data: Dict[str, Any]) -> float:
        """Return the confidence threshold requested in ``data["parameters"]``."""
        # ``parameters`` (or ``threshold``) may be present but null in JSON.
        parameters = data.get("parameters") or {}
        threshold = parameters.get("threshold")
        if threshold is None:
            return cls.DEFAULT_THRESHOLD
        # Coerce explicitly so numeric strings such as "0.5" are accepted.
        return float(threshold)

    def _to_results(self, detections: Any) -> List[Dict[str, Any]]:
        """Convert model detections into JSON-serialisable result dicts."""
        class_ids = getattr(detections, "class_id", None)
        # Treat a missing, None or empty ``class_id`` as "no detections".
        if class_ids is None or len(class_ids) == 0:
            return []

        results = []
        for raw_cid, raw_conf, box in zip(
            class_ids, detections.confidence, detections.xyxy
        ):
            cid = int(raw_cid)
            xmin, ymin, xmax, ymax = (int(v) for v in box[:4])
            results.append({
                "label": self.class_map.get(cid, f"threat_{cid}"),
                "score": round(float(raw_conf), 4),
                "box": {
                    "xmin": xmin,
                    "ymin": ymin,
                    "xmax": xmax,
                    "ymax": ymax,
                },
            })
        return results