|
|
import base64
import io
import os
from typing import Dict, Any

from PIL import Image
from ultralytics import YOLO
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Hugging Face Inference Endpoint handler for the LocustGuard YOLO model.

    Accepted request bodies:

    1) Direct HTTP / Spaces:
        {"image": "<base64>", "conf": 0.25, "iou": 0.45}

    2) Playground / Hosted API:
        {"inputs": {"image": "<base64>", "conf": 0.25, "iou": 0.45}}

    3) HF standard:
        {"inputs": "<base64>"}
        {"inputs": "<base64>", "parameters": {"conf": 0.25, "iou": 0.45}}

    4) Binary requests already deserialized by the serving toolkit:
        ``inputs`` may be raw encoded image bytes or a PIL image.

    Base64 strings may carry a ``data:<mime>;base64,`` prefix. ``conf`` and
    ``iou`` must be numbers in [0, 1] (missing or null means the default).
    Thresholds given next to the image take precedence over ``parameters``.

    Returns:
        {
            "detections": [
                {
                    "label": str,
                    "confidence": float,                      # rounded to 3 decimals
                    "coordinates": [xmin, ymin, xmax, ymax],  # rounded to 2 decimals
                }
            ]
        }

    Malformed requests (wrong shape, bad thresholds, undecodable image)
    return ``{"error": "<message>"}`` instead of raising.
    """

    DEFAULT_CONF = 0.25
    DEFAULT_IOU = 0.45
    IMAGE_SIZE = 640  # inference resolution passed to YOLO as ``imgsz``

    _INVALID_INPUT = "Invalid input. Expected base64 image under key 'image' or 'inputs'."

    def __init__(self, path: str = "."):
        """Load the ``best.pt`` weights found in the model directory *path*."""
        # os.path.join avoids building "/best.pt" when path is "" and
        # tolerates a trailing separator on path.
        self.model = YOLO(os.path.join(path, "best.pt"))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run detection on the image in *data* and return its detections."""
        if not isinstance(data, dict):
            return {"error": self._INVALID_INPUT}

        payload = data.get("inputs", data)
        params = data.get("parameters")
        options: Dict[str, Any] = dict(params) if isinstance(params, dict) else {}

        if isinstance(payload, dict):
            if "image" not in payload:
                return {"error": self._INVALID_INPUT}
            # Thresholds supplied alongside the image override "parameters".
            options.update(payload)
        elif not (
            isinstance(payload, (str, bytes, bytearray))
            or isinstance(payload, Image.Image)
        ):
            return {"error": self._INVALID_INPUT}

        # Validate thresholds before the (more expensive) image decode.
        try:
            conf = self._parse_threshold(options, "conf", self.DEFAULT_CONF)
            iou = self._parse_threshold(options, "iou", self.DEFAULT_IOU)
        except ValueError as e:
            return {"error": str(e)}

        try:
            image = self._load_image(payload)
        except Exception as e:  # request boundary: report any decode failure to the client
            return {"error": f"Failed to decode image: {e}"}

        results = self.model(
            image,
            conf=conf,
            iou=iou,
            imgsz=self.IMAGE_SIZE,
            verbose=False,
        )
        return {"detections": self._format_detections(results[0])}

    @staticmethod
    def _parse_threshold(options: Dict[str, Any], key: str, default: float) -> float:
        """Return ``options[key]`` (or *default* if missing/None) as a float in [0, 1].

        Raises:
            ValueError: if the value is not numeric or lies outside [0, 1].
        """
        message = f"Invalid '{key}': expected a number between 0 and 1."
        raw = options.get(key)
        if raw is None:
            raw = default
        try:
            value = float(raw)
        except (TypeError, ValueError):
            raise ValueError(message) from None
        # Written as a positive range test so NaN is rejected as well.
        if not 0.0 <= value <= 1.0:
            raise ValueError(message)
        return value

    @classmethod
    def _load_image(cls, payload: Any) -> "Image.Image":
        """Turn an already shape-validated payload into an RGB PIL image."""
        if isinstance(payload, dict):
            return cls._decode_base64_image(payload["image"])
        if isinstance(payload, str):
            return cls._decode_base64_image(payload)
        if isinstance(payload, (bytes, bytearray)):
            # Raw encoded image file, e.g. a binary request body.
            return Image.open(io.BytesIO(payload)).convert("RGB")
        return payload.convert("RGB")  # already a PIL image

    @staticmethod
    def _decode_base64_image(encoded: Any) -> "Image.Image":
        """Decode base64 text/bytes (optionally a data URI) into an RGB image."""
        if isinstance(encoded, str):
            header, sep, body = encoded.partition(",")
            # "," is not in the base64 alphabet, so it only appears in a data URI.
            if sep and header.startswith("data:"):
                encoded = body
        image_bytes = base64.b64decode(encoded)
        # convert() forces PIL's lazy loader, so corrupt data fails here.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def _format_detections(self, result: Any) -> list:
        """Convert one Ultralytics result into the response's detection dicts."""
        boxes = result.boxes
        if boxes is None:
            return []
        names = self.model.names
        # Move each tensor to Python lists once instead of indexing per box.
        rows = zip(boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist())
        return [
            {
                "label": names[int(cls_id)],
                "confidence": round(float(score), 3),
                "coordinates": [round(float(v), 2) for v in xyxy],
            }
            for xyxy, cls_id, score in rows
        ]
|
|
|