| from fastapi import FastAPI, UploadFile, File, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from PIL import Image, UnidentifiedImageError |
| import onnxruntime as ort |
| import numpy as np |
| import io |
| import os |
| import time |
| import logging |
|
|
| |
# Emit INFO-level records (model load status and per-request summaries).
logging.basicConfig(level=logging.INFO)

app = FastAPI(title="Traffic Sign Detection API", version="1.0")

# CORS: requests are accepted from any origin, with any method and headers.
# Credentials stay disabled: the FastAPI CORS documentation states that
# wildcard origins/methods/headers may not be combined with
# allow_credentials=True, and none of the routes in this file read cookies
# or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# The ONNX weights file is looked up in the same directory as this source file.
_MODEL_FILENAME = "detection.onnx"
MODEL_PATH = os.path.join(os.path.split(__file__)[0], _MODEL_FILENAME)
|
|
| |
# Lookup from integer class id to display label. Each label's position in the
# tuple below is its class id (0 through 36), so the order must not change.
CLASS_NAMES = dict(enumerate((
    "Crossroad",
    "Cycle Prohibited",
    "Gap in the Median",
    "Give Way",
    "Go Slow",
    "Horn Prohibited",
    "Hospital",
    "Keep Left",
    "Left Turn",
    "Men at Work",
    "No Entry",
    "No Left Turn",
    "No Overtaking",
    "No Parking",
    "No Right Turn",
    "No Stopping",
    "Parking",
    "Pedestrian Crossing",
    "Right Turn",
    "Roundabout",
    "School Ahead",
    "Side Road Left",
    "Side Road Right",
    "Speed Breaker",
    "Speed Limit 20",
    "Speed Limit 30",
    "Speed Limit 40",
    "Speed Limit 50",
    "Speed Limit 60",
    "Speed Limit 80",
    "Stop",
    "T Intersection",
    "Traffic Signal Ahead",
    "U-Turn Prohibited",
    "U-Turn",
    "Y Intersection",
    "Zigzag Road",
)))
|
|
| |
# Load the model once at import time so every request reuses one session.
# A load failure is logged instead of raised: the app still starts, and the
# handlers below check `session is None` before attempting inference.
# Both names are pre-bound so that `input_name` always exists, even when
# loading fails part-way through.
session = None
input_name = None
try:
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    # Only the first declared input is fed at inference time.
    input_name = session.get_inputs()[0].name
    logging.info("ONNX model loaded successfully.")
except Exception as e:
    # Reset both names: the session may have been created before the
    # input lookup failed, and a half-initialised pair must not be used.
    session = None
    input_name = None
    # exc_info keeps the traceback, as the request handlers' error logs do.
    logging.error(f"Error loading ONNX model: {e}", exc_info=True)
|
|
| |
# Root route: returns a static welcome message that points callers at /docs.
@app.get("/")
def home():
    greeting = "Welcome to the Traffic Sign Detection API. Visit /docs for API documentation."
    return {"message": greeting}
|
|
| |
# Health probe: always answers "ok"; `model_loaded` is False when the
# module-level ONNX session could not be created.
@app.get("/health/")
def health_check():
    model_ready = session is not None
    return dict(status="ok", model_loaded=model_ready)
|
|
| |
def _extract_detections(output, original_width, original_height,
                        confidence_threshold=0.3, max_detections=10,
                        input_size=640):
    """Convert a raw model output tensor into a ranked list of detections.

    The tensor is read as ``output[0, :4, :]`` (four box values per candidate)
    and ``output[0, 4:, :]`` (one score per class per candidate).
    NOTE(review): this assumes a (1, 4 + num_classes, num_candidates) export
    whose box values are in pixels of the resized input - confirm against
    the model.

    Args:
        output: First output array returned by the inference session.
        original_width: Width of the uploaded image, in pixels.
        original_height: Height of the uploaded image, in pixels.
        confidence_threshold: Candidates whose best class score is not
            strictly above this value are dropped.
        max_detections: Maximum number of detections returned.
        input_size: Side length the image was resized to before inference.

    Returns:
        A list of dicts with keys ``class_id``, ``class_name``,
        ``confidence`` and ``bbox``, sorted by confidence in descending
        order and truncated to ``max_detections`` entries.
    """
    box_predictions = output[0, :4, :]
    class_predictions = output[0, 4:, :]

    logging.info(f"Processing output: box_predictions shape: {box_predictions.shape}, class_predictions shape: {class_predictions.shape}")

    # Best class and its score for every candidate, computed in one pass.
    max_class_indices = np.argmax(class_predictions, axis=0)
    max_class_values = np.max(class_predictions, axis=0)

    # Filter in NumPy so the Python loop only visits candidates that pass.
    # The cast to float64 makes the comparison match a plain Python-float
    # comparison exactly, whatever dtype the model emits.
    above_threshold = np.flatnonzero(
        max_class_values.astype(np.float64) > confidence_threshold
    )

    detections = []
    for anchor_idx in above_threshold:
        class_id = int(max_class_indices[anchor_idx])
        confidence = float(max_class_values[anchor_idx])
        x, y, w, h = (float(v) for v in box_predictions[:, anchor_idx])

        detections.append({
            "class_id": class_id,
            # Plain lookup: an id outside CLASS_NAMES is reported as
            # "Unknown-<id>" instead of being wrapped onto another label.
            "class_name": CLASS_NAMES.get(class_id, f"Unknown-{class_id}"),
            # The score is reported on the model's own scale, capped at 100.
            # NOTE(review): if the model emits 0-1 scores, clients receive
            # values such as 0.87 rather than a percentage - confirm intent.
            "confidence": round(min(confidence, 100.0), 2),
            # Map box values from the resized input back to the upload.
            "bbox": {
                "x": (x / input_size) * original_width,
                "y": (y / input_size) * original_height,
                "width": (w / input_size) * original_width,
                "height": (h / input_size) * original_height,
            },
        })

    # Stable sort: candidates with equal rounded confidence keep index order.
    # NOTE(review): no non-maximum suppression is applied here, so several
    # overlapping candidates for one sign may all appear in the result.
    detections.sort(key=lambda det: det["confidence"], reverse=True)
    return detections[:max_detections]


@app.post("/detection/")
async def predict(file: UploadFile = File(...)):
    """Detect traffic signs in an uploaded image.

    Returns up to ten detections (class id, class name, confidence and a
    bounding box scaled to the uploaded image), ordered by confidence,
    together with the inference time in milliseconds. When nothing passes
    the confidence threshold, a message is returned instead of a list.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="ONNX model not loaded.")

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        # Keep the upload's size so boxes can be mapped back onto it.
        original_width, original_height = image.size

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize to 640x640, scale pixels to [0, 1], then reorder
        # height-width-channel to channel-height-width and add a batch axis.
        img_resized = image.resize((640, 640))
        img_array = np.array(img_resized, dtype=np.float32) / 255.0
        img_array = np.transpose(img_array, (2, 0, 1))
        img_array = np.expand_dims(img_array, axis=0)

        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        outputs = session.run(None, {input_name: img_array})
        inference_time = round((time.perf_counter() - start_time) * 1000, 2)

        detections = _extract_detections(
            outputs[0], original_width, original_height
        )

        logging.info(f"Found {len(detections)} detections")
        for det in detections[:5]:
            logging.info(f"Detection: {det}")

        if not detections:
            return {"message": "No traffic signs detected", "inference_time_ms": inference_time}

        return {
            "detections": detections,
            "inference_time_ms": inference_time
        }

    except UnidentifiedImageError:
        # The upload could not be recognised as an image: client error.
        raise HTTPException(status_code=400, detail="Invalid image file")

    except Exception as e:
        logging.error(f"Prediction error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
@app.post("/debug-detection/")
async def debug_prediction(file: UploadFile = File(...)):
    """Special endpoint for debugging classification issues.

    Runs inference on the uploaded image and returns, for the five candidates
    with the highest best-class score, the raw class/score/box values plus
    the raw scores of the speed-limit classes for side-by-side comparison.
    No confidence threshold or box rescaling is applied.
    """
    if session is None:
        raise HTTPException(status_code=500, detail="ONNX model not loaded.")

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Resize to 640x640, scale pixels to [0, 1], then reorder
        # height-width-channel to channel-height-width and add a batch axis.
        img_resized = image.resize((640, 640))
        img_array = np.array(img_resized, dtype=np.float32) / 255.0
        img_array = np.transpose(img_array, (2, 0, 1))
        img_array = np.expand_dims(img_array, axis=0)

        outputs = session.run(None, {input_name: img_array})
        output = outputs[0]

        # Rows 0-3 hold box values, the remaining rows one score per class.
        # NOTE(review): assumes a (1, 4 + num_classes, num_candidates)
        # output layout - confirm against the exported model.
        box_predictions = output[0, :4, :]
        class_predictions = output[0, 4:, :]
        num_classes = class_predictions.shape[0]

        # Rank candidates by their best class score, highest first.
        max_confidence_per_anchor = np.max(class_predictions, axis=0)
        sorted_anchor_indices = np.argsort(-max_confidence_per_anchor)
        top_anchors = sorted_anchor_indices[:5]

        debug_info = {
            "top_detections": [],
            "speed_limit_comparison": []
        }

        # Class ids of the "Speed Limit NN" entries in CLASS_NAMES.
        speed_limit_classes = [24, 25, 26, 27, 28, 29]

        for anchor_idx in top_anchors:
            scores = class_predictions[:, anchor_idx]
            class_id = int(np.argmax(scores))
            confidence = float(scores[class_id])
            class_name = CLASS_NAMES.get(class_id, f"Unknown-{class_id}")

            debug_info["top_detections"].append({
                "anchor_idx": int(anchor_idx),
                "class_id": class_id,
                "class_name": class_name,
                "confidence": confidence,
                "bbox": [float(box_predictions[i, anchor_idx]) for i in range(4)]
            })

            # Skip speed-limit ids the model has no score row for, so a model
            # with fewer classes yields a partial comparison, not an error.
            speed_limit_probs = {
                f"{CLASS_NAMES.get(sl_class)}": float(scores[sl_class])
                for sl_class in speed_limit_classes
                if sl_class < num_classes
            }

            debug_info["speed_limit_comparison"].append({
                "anchor_idx": int(anchor_idx),
                "highest_class": class_name,
                "speed_limit_probabilities": speed_limit_probs
            })

        return debug_info

    except UnidentifiedImageError:
        # The upload could not be recognised as an image: client error.
        raise HTTPException(status_code=400, detail="Invalid image file")

    except Exception as e:
        logging.error(f"Debug error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))