Spaces:
Sleeping
Sleeping
import io
import logging
import os
import threading
import time

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from ultralytics import YOLO
logging.basicConfig(level=logging.INFO)

# FastAPI application instance for this module.
app = FastAPI()

# Cross-Origin Resource Sharing policy: accept requests from any origin,
# using any HTTP method and any request header.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers should not be combined with allow_credentials=True. Nothing visible
# in this module uses cookies or auth, so False may be what is intended --
# confirm with any front-end clients before changing it.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# The weights file is expected to sit next to this source file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.pt")
try:
    # Load the weights and move the model to the selected device.
    model = YOLO(MODEL_PATH).to(device)
    # Label lookup indexed by the integer class id the model predicts.
    class_names = model.names
    logging.info("Model loaded successfully with %d classes.", len(class_names))
except Exception as e:
    # Keep the module importable without a usable model: predict() checks
    # `model is None` and answers with an error instead of crashing.
    # logging.exception records the full traceback, so the "Check logs for
    # details" hint given to API clients actually points at useful details.
    model = None
    class_names = []
    logging.exception("Error loading model: %s", e)
@app.get("/")
def home() -> dict:
    """Describe the API and advertise the detection route.

    Without a route decorator this handler was never registered on ``app``.

    Returns:
        dict with a human-readable ``"message"`` and the ``"available_route"``
        clients should call for detection.
    """
    return {
        "message": "Traffic Sign Detection API",
        "available_route": "/detection/",
    }
# Serializes calls into the shared model. Inference runs in worker threads
# (see predict), and one ultralytics model instance should not be used from
# several threads concurrently.
_INFERENCE_LOCK = threading.Lock()


def _decode_rgb_image(contents):
    """Decode raw upload bytes into a fully loaded RGB ``PIL.Image``.

    ``Image.open`` only parses the header; pixel data is decoded lazily.
    Forcing the decode here (via ``convert`` or ``load``) makes truncated or
    corrupt files fail as invalid input rather than later inside inference.

    Raises:
        UnidentifiedImageError: if the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(contents))
        if image.mode != "RGB":
            image = image.convert("RGB")
        else:
            image.load()
    except UnidentifiedImageError:
        raise
    except OSError as err:
        # PIL reports truncated/corrupt pixel data with a plain OSError.
        raise UnidentifiedImageError(f"cannot decode image: {err}") from err
    return image


def _detect(image):
    """Run the model on *image*; return ``(predictions, inference_time_ms)``.

    Blocking -- call it from a worker thread. Only ``model.predict`` is timed
    (time spent waiting for the lock is excluded).
    """
    with _INFERENCE_LOCK:
        start = time.perf_counter()  # monotonic, unlike time.time()
        results = model.predict(image, save=False, imgsz=640, device=device)
        inference_time = round((time.perf_counter() - start) * 1000, 2)

    predictions = []
    for box in results[0].boxes:
        class_id = int(box.cls)
        try:
            class_name = class_names[class_id]
        except (KeyError, IndexError):
            # Id missing from the label table (list or dict) -> placeholder.
            class_name = "Unknown"
        predictions.append({"class_id": class_id, "class_name": class_name})
    return predictions, inference_time


@app.post("/detection/")
async def predict(file: UploadFile = File(...)) -> dict:
    """Detect objects in an uploaded image.

    The upload is read asynchronously. PIL decoding and YOLO inference are
    blocking, so both run in a worker thread to keep the event loop free
    for other requests.

    Args:
        file: the uploaded image.

    Returns:
        dict with ``"predictions"`` (one ``{"class_id", "class_name"}`` dict
        per detected box) and ``"inference_time_ms"`` (duration of
        ``model.predict`` in milliseconds, rounded to 2 decimals).

    Raises:
        HTTPException: 500 if the model failed to load; 400 if the upload
            is not a decodable image; 500 for any other prediction failure.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Check logs for details.")
    try:
        contents = await file.read()  # reading uploaded image
        image = await run_in_threadpool(_decode_rgb_image, contents)
        predictions, inference_time = await run_in_threadpool(_detect, image)
    except UnidentifiedImageError as err:
        raise HTTPException(status_code=400, detail="Invalid image file") from err
    except Exception as e:
        logging.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not predictions:
        logging.info("No objects detected.")
    # returns both predictions and inference time
    return {
        "predictions": predictions,
        "inference_time_ms": inference_time,
    }