File size: 2,703 Bytes
fdc1b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from ultralytics import YOLO
from PIL import Image, UnidentifiedImageError
import io
import os
import time
import torch
import logging

# Configure root logging once at import time so model-load errors are visible.
logging.basicConfig(level=logging.INFO)

# initializing FastAPI 
app = FastAPI()

# enabling Cross-Origin Resource Sharing
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# invalid per the CORS spec -- browsers refuse a wildcard origin on
# credentialed requests. Pin concrete origins or drop allow_credentials.
# TODO confirm whether any client actually sends credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# use GPU if available, else use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model weights are expected to sit next to this source file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.pt")

try:
    model = YOLO(MODEL_PATH).to(device) # loading model to specified device
    class_names = model.names # mapping of class id -> human-readable name
    logging.info(f"Model loaded successfully with {len(class_names)} classes.")
except Exception as e:
    model = None
    # Fix: model.names is a dict, so the failure fallback should be an empty
    # dict (was []) to keep class_names the same type on both paths.
    class_names = {}
    logging.error(f"Error loading model: {e}")

@app.get("/")
def home():
    """Root endpoint: identify the API and point callers at the detection route."""
    payload = {
        "message": "Traffic Sign Detection API",
        "available_route": "/detection/",
    }
    return payload

@app.post("/detection/")
async def predict(file: UploadFile = File(...)):
    """Run traffic-sign detection on an uploaded image.

    Returns a JSON body with the detected classes and the inference time in
    milliseconds. Raises HTTP 400 for non-image uploads and HTTP 500 when the
    model is unavailable or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded. Check logs for details.")

    try:
        # Whole upload is buffered in memory; assumes uploads are small images
        # -- TODO confirm an upstream size limit exists.
        contents = await file.read() # reading uploaded image
        image = Image.open(io.BytesIO(contents))

        if image.mode != "RGB": # YOLO expects 3-channel input; convert palettized/grayscale/alpha images
            image = image.convert("RGB")

        # Fix: use a monotonic clock for interval timing. time.time() can jump
        # (NTP adjustments / clock changes), which would corrupt the reported
        # inference time; perf_counter() is monotonic and high-resolution.
        start_time = time.perf_counter()
        results = model.predict(image, save=False, imgsz=640, device=device)
        end_time = time.perf_counter()
        inference_time = round((end_time - start_time) * 1000, 2) # calculating inference time in ms

        predictions = [] # storing all detection results in this array

        for box in results[0].boxes:
            class_id = int(box.cls)
            # Guard against ids the class-name mapping does not cover.
            class_name = class_names[class_id] if class_id < len(class_names) else "Unknown"

            predictions.append({
                "class_id": class_id,
                "class_name": class_name
            })

        if not predictions:
            logging.info("No objects detected.")

        # returns both predictions and inference time
        return {
            "predictions": predictions,
            "inference_time_ms": inference_time
        }

    except UnidentifiedImageError:
        # Upload was not a decodable image.
        raise HTTPException(status_code=400, detail="Invalid image file")

    except Exception as e:
        # NOTE(review): detail=str(e) leaks internal error text to clients --
        # consider a generic message in production.
        logging.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))