| from huggingface_hub import hf_hub_download |
| from typing import Dict, List, Any |
| from ultralytics import YOLO |
| import json |
| import urllib.request |
| import cv2 |
| from io import BytesIO |
| import numpy as np |
class EndpointHandler:
    """Inference endpoint handler that runs a YOLO object detector on an image URL.

    The detector weights are fetched from the ``Drazcat-AI/flejes`` Hugging Face
    Hub repository when the handler is constructed. Each call receives an event
    whose ``"inputs"`` entry is an image URL. The handler downloads and decodes
    that image and returns the detected boxes as a JSON string inside a
    ``{"statusCode": ..., "body": ...}`` dict.
    """

    # Hub location of the trained detector weights.
    _REPO_ID = "Drazcat-AI/flejes"
    _WEIGHTS_FILE = "flejes-16/runs/detect/train/weights/best.pt"
    # URL schemes accepted for ``inputs``. ``file://`` (and any other scheme
    # urllib understands) is refused so callers cannot make the endpoint read
    # files from the host it runs on.
    _ALLOWED_SCHEMES = ("http", "https", "data")
    # Seconds to wait on the image download. This stops a slow or unresponsive
    # server from blocking the handler indefinitely.
    _DOWNLOAD_TIMEOUT = 30

    def __init__(self, path=""):
        """Download (or reuse the cached copy of) the weights and load the model.

        Args:
            path: Unused; kept for signature compatibility with callers.
        """
        # hf_hub_download returns the local path of the cached file. A single
        # call therefore both fetches the weights and tells us where they are,
        # so no second, local-only lookup is needed.
        weights_path = hf_hub_download(repo_id=self._REPO_ID, filename=self._WEIGHTS_FILE)
        self.model = YOLO(weights_path)

    def predict_objects(self, image_path, image_size_m, imgsz=(1280, 960)):
        """Run the detector on one image and return its boxes.

        Args:
            image_path: Image source passed straight to the YOLO model. ``__call__``
                passes the array produced by ``cv2.imdecode``.
            image_size_m: ``(sx, sy)`` multipliers applied to the box values.
                ``x`` and ``width`` are scaled by ``sx``; ``y`` and ``height``
                are scaled by ``sy``. Use ``(1, 1)`` to keep the model's coordinates.
            imgsz: Inference size forwarded to the model (as a list). The default
                matches the previously hard-coded ``[1280, 960]``.

        Returns:
            ``{"predictions": [...]}``. Each entry holds rounded integer ``x``,
            ``y``, ``width`` and ``height`` taken from ultralytics ``Boxes.xywh``
            (box centre plus size), a float ``confidence``, and the ``class`` name.
        """
        result = self.model(image_path, imgsz=list(imgsz))[0]
        scale_x, scale_y = image_size_m[0], image_size_m[1]

        predictions = []
        for box in result.boxes:
            cx, cy, width, height = box.xywh[0].tolist()
            predictions.append({
                "x": round(cx * scale_x),
                "y": round(cy * scale_y),
                "width": round(width * scale_x),
                "height": round(height * scale_y),
                "confidence": box.conf[0].item(),
                # The class id comes back as a float. Cast it so the lookup
                # works whether ``names`` is keyed by int or is a list.
                "class": result.names[int(box.cls[0].item())],
            })
        return {"predictions": predictions}

    @staticmethod
    def _response(status_code, payload):
        """Wrap ``payload`` in the ``{"statusCode", "body"}`` envelope, JSON-encoding the body."""
        return {"statusCode": status_code, "body": json.dumps(payload)}

    def __call__(self, event: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request: fetch the image at ``event["inputs"]`` and detect objects.

        Args:
            event: Request payload. ``event["inputs"]`` must be an http(s) or
                ``data:`` URL pointing to an image.

        Returns:
            ``{"statusCode": int, "body": str}``, where ``body`` is JSON. The
            status code is one of:

            - 200: ``body`` holds the ``predict_objects`` result.
            - 400: ``inputs`` is missing, is not a string, uses a disallowed
              URL scheme, or does not decode to an image.
            - 500: any other failure, such as a download error or a model error.
        """
        from urllib.parse import urlsplit

        if "inputs" not in event:
            return self._response(400, "Error: Please provide an 'inputs' parameter.")

        image_url = event["inputs"]
        if not isinstance(image_url, str):
            return self._response(400, "Error: 'inputs' must be an image URL string.")
        if urlsplit(image_url).scheme not in self._ALLOWED_SCHEMES:
            return self._response(400, "Error: 'inputs' must be an http(s) or data URL.")

        try:
            with urllib.request.urlopen(image_url, timeout=self._DOWNLOAD_TIMEOUT) as response:
                image_content = np.asarray(bytearray(response.read()), dtype=np.uint8)
            image = cv2.imdecode(image_content, cv2.IMREAD_COLOR)
            if image is None:
                # imdecode returns None for data it cannot decode. Passing None
                # on would be wrong: ultralytics treats a None source as
                # "missing" and can substitute its bundled sample images.
                return self._response(400, "Error: 'inputs' did not point to a decodable image.")
            predictions = self.predict_objects(image, (1, 1))
        except Exception as e:  # endpoint boundary: report the failure instead of crashing
            return self._response(500, f"Error: {str(e)}")
        return self._response(200, predictions)