# rttsd / app.py
# (Hugging Face Space file header: uploaded by octavian7, commit be51db7 verified, "Update app.py")
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
import onnxruntime as ort
import numpy as np
import io
import os
import time
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
# Initialize FastAPI app
app = FastAPI(title="Traffic Sign Detection API", version="1.0")
# Enable CORS so browser front-ends on any origin can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive; tighten the origin list if credentialed requests are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Path to the ONNX detection model, resolved relative to this file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "detection.onnx")

# Human-readable traffic-sign labels; index in this list is the class id.
_LABELS = [
    "Crossroad", "Cycle Prohibited", "Gap in the Median", "Give Way",
    "Go Slow", "Horn Prohibited", "Hospital", "Keep Left",
    "Left Turn", "Men at Work", "No Entry", "No Left Turn",
    "No Overtaking", "No Parking", "No Right Turn", "No Stopping",
    "Parking", "Pedestrian Crossing", "Right Turn", "Roundabout",
    "School Ahead", "Side Road Left", "Side Road Right",
    "Speed Breaker", "Speed Limit 20", "Speed Limit 30",
    "Speed Limit 40", "Speed Limit 50", "Speed Limit 60",
    "Speed Limit 80", "Stop", "T Intersection",
    "Traffic Signal Ahead", "U-Turn Prohibited", "U-Turn",
    "Y Intersection", "Zigzag Road",
]
# Class id -> label mapping used by the endpoints (37 classes, ids 0-36).
CLASS_NAMES = dict(enumerate(_LABELS))
# Load the ONNX model once at import time. On failure the app still starts;
# endpoints check `session is None` and surface the problem via /health/
# and HTTP 500 responses instead of crashing the process.
try:
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    # Name of the model's first input tensor, used as the feed key at inference.
    input_name = session.get_inputs()[0].name
    logging.info("ONNX model loaded successfully.")
except Exception as e:
    session = None  # sentinel checked by every endpoint before running inference
    logging.error(f"Error loading ONNX model: {e}")
# Root route (API welcome message)
@app.get("/")
def home():
    """Root endpoint: points callers at the interactive API docs."""
    welcome = "Welcome to the Traffic Sign Detection API. Visit /docs for API documentation."
    return {"message": welcome}
# Health check route
@app.get("/health/")
def health_check():
    """Liveness probe: reports service status and whether the ONNX model loaded."""
    model_ready = session is not None
    return {"status": "ok", "model_loaded": model_ready}
# Traffic sign detection route
@app.post("/detection/")
async def predict(file: UploadFile = File(...)):
    """Run traffic-sign detection on an uploaded image.

    Returns the top-10 detections above a confidence threshold, each with a
    class id/name, a confidence percentage, and a bounding box scaled back to
    the original image's coordinate space, plus the inference time in ms.

    Raises:
        HTTPException 400: the upload is not a decodable image.
        HTTPException 500: the model is unavailable or inference failed.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="ONNX model not loaded.")
    try:
        # Read and open the image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        # Keep original image dimensions for scaling coordinates later
        original_width, original_height = image.size
        # Ensure image is in RGB format
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Preprocess: resize to the model's 640x640 input, scale to [0,1],
        # HWC -> CHW, then add the batch dimension.
        img_resized = image.resize((640, 640))
        img_array = np.array(img_resized, dtype=np.float32) / 255.0
        img_array = np.transpose(img_array, (2, 0, 1))
        img_array = np.expand_dims(img_array, axis=0)
        # Run inference
        start_time = time.time()
        outputs = session.run(None, {input_name: img_array})
        inference_time = round((time.time() - start_time) * 1000, 2)
        # YOLOv8-style output layout: (1, 4 + num_classes, num_anchors);
        # rows 0-3 are box parameters (x, y, w, h), the rest per-class scores.
        output = outputs[0]
        box_predictions = output[0, :4, :]    # (4, num_anchors)
        class_predictions = output[0, 4:, :]  # (num_classes, num_anchors)
        logging.info(
            "Processing output: box_predictions shape: %s, class_predictions shape: %s",
            box_predictions.shape, class_predictions.shape,
        )
        CONFIDENCE_THRESHOLD = 0.3
        # Best class per anchor; filter by threshold with a vectorized mask
        # instead of the previous Python loop over every anchor.
        max_class_indices = np.argmax(class_predictions, axis=0)
        max_class_values = np.max(class_predictions, axis=0)
        keep = np.nonzero(max_class_values > CONFIDENCE_THRESHOLD)[0]
        detections = []
        for anchor_idx in keep:
            class_id = int(max_class_indices[anchor_idx])
            confidence = float(max_class_values[anchor_idx])
            x, y, w, h = (float(box_predictions[i, anchor_idx]) for i in range(4))
            # NOTE(review): assumes the model emits unnormalized scores — the
            # 0.3 threshold above is applied to the raw value while the
            # reported confidence is divided by 100; confirm against the model.
            normalized_confidence = min(confidence / 100.0, 1.0)
            detections.append({
                "class_id": class_id,
                # Bug fix: `class_id % len(CLASS_NAMES)` silently remapped
                # out-of-range ids onto wrong labels and made the "Unknown"
                # fallback unreachable; use plain lookup like /debug-detection/.
                "class_name": CLASS_NAMES.get(class_id, f"Unknown-{class_id}"),
                "confidence": round(normalized_confidence * 100, 2),
                # Scale the box from 640x640 model space to the original image.
                "bbox": {
                    "x": (x / 640) * original_width,
                    "y": (y / 640) * original_height,
                    "width": (w / 640) * original_width,
                    "height": (h / 640) * original_height,
                },
            })
        # Highest-confidence first, capped at 10 to keep the response small.
        detections.sort(key=lambda d: d["confidence"], reverse=True)
        detections = detections[:10]
        logging.info(f"Found {len(detections)} detections")
        for det in detections[:5]:  # Log first 5 detections
            logging.info(f"Detection: {det}")
        if not detections:
            return {"message": "No traffic signs detected", "inference_time_ms": inference_time}
        return {
            "detections": detections,
            "inference_time_ms": inference_time
        }
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Invalid image file")
    except Exception as e:
        logging.error(f"Prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/debug-detection/")
async def debug_prediction(file: UploadFile = File(...)):
    """Special endpoint for debugging classification issues.

    Reports the five highest-confidence anchors and, for each one, the raw
    score of every speed-limit class, to help diagnose confusion between
    visually similar speed-limit signs.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="ONNX model not loaded.")
    try:
        # Decode the uploaded image.
        payload = await file.read()
        img = Image.open(io.BytesIO(payload))
        orig_w, orig_h = img.size  # read for parity with /detection/
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Same preprocessing as /detection/: 640x640, [0,1], NCHW float32.
        resized = img.resize((640, 640))
        chw = np.transpose(np.array(resized, dtype=np.float32) / 255.0, (2, 0, 1))
        batch = np.expand_dims(chw, axis=0)
        # Inference; first output holds boxes (rows 0-3) and class scores.
        raw = session.run(None, {input_name: batch})[0]
        boxes = raw[0, :4, :]
        scores = raw[0, 4:, :]
        # Rank anchors by their best class score and keep the top five.
        best_per_anchor = np.max(scores, axis=0)
        top_anchors = np.argsort(-best_per_anchor)[:5]
        debug_info = {
            "top_detections": [],
            "speed_limit_comparison": []
        }
        # Class ids 24-29 are Speed Limit 20/30/40/50/60/80.
        speed_limit_classes = [24, 25, 26, 27, 28, 29]
        for anchor in top_anchors:
            winner = int(np.argmax(scores[:, anchor]))
            winner_score = float(np.max(scores[:, anchor]))
            cx, cy, bw, bh = (float(boxes[k, anchor]) for k in range(4))
            debug_info["top_detections"].append({
                "anchor_idx": int(anchor),
                "class_id": winner,
                "class_name": CLASS_NAMES.get(winner, f"Unknown-{winner}"),
                "confidence": winner_score,
                "bbox": [cx, cy, bw, bh]
            })
            # Raw score of each speed-limit class at this anchor.
            sl_probs = {
                f"{CLASS_NAMES.get(sl)}": float(scores[sl, anchor])
                for sl in speed_limit_classes
            }
            debug_info["speed_limit_comparison"].append({
                "anchor_idx": int(anchor),
                "highest_class": CLASS_NAMES.get(winner, f"Unknown-{winner}"),
                "speed_limit_probabilities": sl_probs
            })
        return debug_info
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail="Invalid image file")
    except Exception as e:
        logging.error(f"Debug error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))