# octavian7's picture
# Deploy YOLO model for traffic sign detection
# fdc1b64
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
from PIL import Image, UnidentifiedImageError
import io
import os
import time
import torch
import logging
# Root logger emits INFO-level records and above.
logging.basicConfig(level=logging.INFO)

# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Cross-Origin Resource Sharing: every origin, HTTP method and request header
# is accepted, and credentialed requests are allowed.
# NOTE(review): FastAPI's CORS documentation says wildcard values must not be
# combined with allow_credentials=True -- confirm this permissive setup is
# intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The weights file is looked up in the same directory as this module.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.pt")

try:
    model = YOLO(MODEL_PATH)
    # Move the weights in a separate statement so that `model` never depends
    # on what .to() happens to return.
    model.to(device)
    # Class-id -> class-name lookup taken from the loaded checkpoint.
    class_names = model.names
    logging.info(f"Model loaded successfully with {len(class_names)} classes.")
except Exception as e:
    # Startup must not crash on a missing or broken checkpoint: `model` stays
    # None so the detection endpoint can report the failure to callers.
    model = None
    # Empty mapping keeps the fallback the same kind of container as
    # `model.names` (len() and key lookups behave consistently).
    class_names = {}
    # logging.exception records the traceback; the endpoint's error message
    # points operators at the logs, so the details have to be there.
    logging.exception("Error loading model: %s", e)
@app.get("/")
def home():
    """Return a short description of the service and its detection route."""
    info = dict(
        message="Traffic Sign Detection API",
        available_route="/detection/",
    )
    return info
def _decode_rgb_image(raw: bytes) -> "Image.Image":
    """Decode uploaded bytes into a fully loaded RGB PIL image.

    Raises:
        HTTPException: status 400 when the bytes cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(raw))
        # Image.open only parses the header; force the pixel data to be decoded
        # here so truncated or corrupt uploads are rejected as bad input rather
        # than failing later inside the model call.
        image.load()
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Invalid image file") from exc
    # The model is always given a 3-channel RGB image.
    return image if image.mode == "RGB" else image.convert("RGB")


def _class_label(names, class_id: int) -> str:
    """Return the name stored for *class_id* in *names*, or "Unknown" if absent.

    Works for both a mapping keyed by class id and a plain sequence.
    """
    try:
        return names[class_id]
    except (KeyError, IndexError):
        return "Unknown"


@app.post("/detection/")
async def predict(file: UploadFile = File(...)):
    """Run traffic-sign detection on an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with "predictions" (one {"class_id", "class_name"} entry per
        detected box) and "inference_time_ms" (duration of the model call in
        milliseconds, rounded to two decimals).

    Raises:
        HTTPException: 500 when the model failed to load or inference fails;
            400 when the upload is not a decodable image.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Check logs for details.")
    try:
        image = _decode_rgb_image(await file.read())

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        # NOTE(review): model.predict is a synchronous call made directly from
        # this async handler, so other requests wait while it runs.
        started = time.perf_counter()
        results = model.predict(image, save=False, imgsz=640, device=device)
        inference_time = round((time.perf_counter() - started) * 1000, 2)

        # Only the first result is read: a single image was submitted.
        predictions = [
            {"class_id": class_id, "class_name": _class_label(class_names, class_id)}
            for class_id in (int(box.cls) for box in results[0].boxes)
        ]
        if not predictions:
            logging.info("No objects detected.")

        return {
            "predictions": predictions,
            "inference_time_ms": inference_time,
        }
    except HTTPException:
        # Already a well-formed API error (e.g. the 400 from decoding).
        raise
    except Exception as e:
        # Keep the traceback in the logs; the client only receives the message.
        logging.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e