"""Traffic sign detection HTTP API.

Serves a YOLO model loaded from ``detection.pt`` (next to this file) via FastAPI:

* ``GET /`` -- service banner listing the detection route.
* ``POST /detection/`` -- upload an image; returns the detected classes and the
  model inference time in milliseconds.
"""
import io
import logging
import os
import threading
import time
from typing import Any, Dict, List, Tuple

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from ultralytics import YOLO

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# initializing FastAPI
app = FastAPI()

# Enable Cross-Origin Resource Sharing so browser front-ends can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any site
# make credentialed requests. This API reads no cookies/auth, so consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# use GPU if available, else use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.pt")

# Largest accepted upload in bytes (override with the MAX_UPLOAD_BYTES env var).
# Larger uploads get a 413 and are never read into memory or handed to the decoder.
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", str(20 * 1024 * 1024)))

# Inference runs in worker threads (see predict). The single shared model instance
# is not guaranteed to be thread-safe, so calls into it are serialized.
_inference_lock = threading.Lock()

try:
    model = YOLO(MODEL_PATH).to(device)  # loading model to specified device
    # Ultralytics exposes names as {class_id: label}; _class_name also accepts a list.
    class_names = model.names
    logger.info("Model loaded successfully with %d classes.", len(class_names))
except Exception as e:
    # Keep the app importable so "/" still answers; "/detection/" reports a 500.
    model = None
    class_names = []
    logger.exception("Error loading model: %s", e)


def _class_name(class_id: int) -> str:
    """Return the label for *class_id*, or "Unknown" if the model has no such class."""
    if isinstance(class_names, dict):
        return class_names.get(class_id, "Unknown")
    return class_names[class_id] if 0 <= class_id < len(class_names) else "Unknown"


def _decode_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes are not a readable image (unknown format,
            truncated/corrupt data, or over Pillow's decompression-bomb limit).
    """
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open only parses the header. load() forces a full decode so corrupt
        # or truncated files fail here (400) rather than inside the model (500).
        image.load()
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    if image.mode != "RGB":  # the model is fed 3-channel RGB input
        image = image.convert("RGB")
    return image


def _detect(image: Image.Image) -> Tuple[List[Dict[str, Any]], float]:
    """Run the detector on *image*.

    Returns:
        ``(predictions, inference_time_ms)``. Each prediction has ``class_id``,
        ``class_name``, ``confidence`` and ``bbox`` (``[x1, y1, x2, y2]`` as reported
        by the model's ``xyxy`` boxes).
    """
    with _inference_lock:
        # Timed inside the lock so the figure excludes time spent queued behind
        # other requests.
        start = time.perf_counter()
        results = model.predict(image, save=False, imgsz=640, device=device)
        inference_time = round((time.perf_counter() - start) * 1000, 2)

    boxes = results[0].boxes
    if boxes is None:  # e.g. a checkpoint that is not a detection model
        return [], inference_time

    # One .tolist() per attribute moves all boxes off the device at once instead of
    # converting each box's tensors individually.
    predictions = []
    for cls, conf, xyxy in zip(boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()):
        class_id = int(cls)
        predictions.append({
            "class_id": class_id,
            "class_name": _class_name(class_id),
            "confidence": round(conf, 4),
            "bbox": [round(coord, 2) for coord in xyxy],
        })
    return predictions, inference_time


@app.get("/")
def home():
    """Describe the service and the route that performs detection."""
    return {
        "message": "Traffic Sign Detection API",
        "available_route": "/detection/",
    }


@app.post("/detection/")
async def predict(file: UploadFile = File(...)):
    """Detect traffic signs in an uploaded image.

    Returns:
        ``{"predictions": [...], "inference_time_ms": float}``.

    Raises:
        HTTPException: 400 for unreadable images, 413 for uploads larger than
            ``MAX_UPLOAD_BYTES``, 500 if the model is unavailable or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Check logs for details.")

    try:
        # Read at most one byte past the limit: enough to detect an oversized upload
        # without loading all of it into memory.
        contents = await file.read(MAX_UPLOAD_BYTES + 1)
        if len(contents) > MAX_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Image file too large")
        # Decoding and inference are blocking CPU/GPU work. Running them in the
        # threadpool keeps the event loop free to serve other requests meanwhile.
        image = await run_in_threadpool(_decode_image, contents)
        predictions, inference_time = await run_in_threadpool(_detect, image)
    except HTTPException:
        raise
    except Exception as exc:
        # Full details go to the server log; the (untrusted) client gets a generic message.
        logger.exception("Prediction error")
        raise HTTPException(
            status_code=500, detail="Prediction failed. Check logs for details."
        ) from exc

    if not predictions:
        logger.info("No objects detected.")

    # returns both predictions and inference time
    return {
        "predictions": predictions,
        "inference_time_ms": inference_time,
    }