octavian7 commited on
Commit
fdc1b64
·
1 Parent(s): c57e453

Deploy YOLO model for traffic sign detection

Browse files
Files changed (3) hide show
  1. app.py +88 -0
  2. detection.pt +3 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from ultralytics import YOLO
4
+ from PIL import Image, UnidentifiedImageError
5
+ import io
6
+ import os
7
+ import time
8
+ import torch
9
+ import logging
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+
13
+ # initializing FastAPI
14
+ app = FastAPI()
15
+
16
+ # enabling Cross-Origin Resource Sharing
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
+
25
+ # use GPU if available, else use CPU
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.pt")
29
+
30
+ try:
31
+ model = YOLO(MODEL_PATH).to(device) # loading model to specified device
32
+ class_names = model.names # storing class names from model to class_names
33
+ logging.info(f"Model loaded successfully with {len(class_names)} classes.")
34
+ except Exception as e:
35
+ model = None
36
+ class_names = []
37
+ logging.error(f"Error loading model: {e}")
38
+
39
+ @app.get("/")
40
+ def home():
41
+ return {
42
+ "message": "Traffic Sign Detection API",
43
+ "available_route": "/detection/"
44
+ }
45
+
46
+ @app.post("/detection/")
47
+ async def predict(file: UploadFile = File(...)):
48
+ if model is None:
49
+ raise HTTPException(status_code=500, detail="Model not loaded. Check logs for details.")
50
+
51
+ try:
52
+ contents = await file.read() # reading uploaded image
53
+ image = Image.open(io.BytesIO(contents))
54
+
55
+ if image.mode != "RGB": # check if image is in RGB
56
+ image = image.convert("RGB")
57
+
58
+ start_time = time.time()
59
+ results = model.predict(image, save=False, imgsz=640, device=device)
60
+ end_time = time.time()
61
+ inference_time = round((end_time - start_time) * 1000, 2) # calculating inference time
62
+
63
+ predictions = [] # storing all detection results in this array
64
+
65
+ for box in results[0].boxes:
66
+ class_id = int(box.cls)
67
+ class_name = class_names[class_id] if class_id < len(class_names) else "Unknown"
68
+
69
+ predictions.append({
70
+ "class_id": class_id,
71
+ "class_name": class_name
72
+ })
73
+
74
+ if not predictions:
75
+ logging.info("No objects detected.")
76
+
77
+ # returns both predictions and inference time
78
+ return {
79
+ "predictions": predictions,
80
+ "inference_time_ms": inference_time
81
+ }
82
+
83
+ except UnidentifiedImageError:
84
+ raise HTTPException(status_code=400, detail="Invalid image file")
85
+
86
+ except Exception as e:
87
+ logging.error(f"Prediction error: {e}")
88
+ raise HTTPException(status_code=500, detail=str(e))
detection.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8774ed0becad01a934857be831f1c3243d4af9c37e9b12c297c88340b59d3553
3
+ size 22546019
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ torch
5
+ ultralytics
6
+ python-multipart