"""FastAPI service that classifies vehicle images with a Roboflow model.

Endpoints:
    POST /predict -- upload an image, get the top class label and confidence.
    GET  /health  -- liveness probe that also reports whether the model loaded.
"""

import io
import logging
import os
import threading
from typing import Any, Optional, Tuple

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

# Roboflow inference
from inference import get_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")

# FastAPI setup
app = FastAPI(title="Vehicle Type Predictor")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can tighten this later if needed
    # NOTE(review): FastAPI's CORS docs advise against combining
    # allow_credentials=True with wildcard origins. Nothing in this service
    # reads cookies or auth headers, so consider setting this to False.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load Roboflow model at startup. `model` stays None when the API key is
# missing or loading fails; /predict then answers 503 and /health reports it.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
MODEL_ID = "vehicle-classification-eapcd/19"

if ROBOFLOW_API_KEY is None:
    logger.error("❌ ROBOFLOW_API_KEY not found in environment variables")
    model = None
else:
    try:
        logger.info("🚀 Loading Roboflow model...")
        model = get_model(model_id=MODEL_ID, api_key=ROBOFLOW_API_KEY)
        logger.info("✅ Roboflow model loaded successfully")
    except Exception:
        logger.exception("❌ Failed to load Roboflow model")
        model = None

# Inference runs in a worker thread (see `predict`). Whether the model object
# tolerates concurrent calls is not established here, so calls are serialized,
# matching the one-at-a-time behaviour of running it on the event loop.
_infer_lock = threading.Lock()


# Response model
class PredictionResponse(BaseModel):
    """Top classification result returned by ``POST /predict``."""

    label: str  # predicted class name ("Unknown" if the prediction has none)
    confidence: float  # confidence reported by the model for ``label``


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: the bytes are not a readable image (this includes
            ``PIL.UnidentifiedImageError``) or the file is truncated.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's
            pixel-count safety limit.
    """
    with Image.open(io.BytesIO(contents)) as src:
        # convert() loads the data and returns a new image, so closing `src`
        # afterwards is safe.
        return src.convert("RGB")


def _run_inference(img: Image.Image) -> Any:
    """Run the loaded model on *img*, allowing only one call at a time."""
    with _infer_lock:
        # Roboflow accepts PIL Image directly
        return model.infer(img)


def _top_prediction(result: Any) -> Optional[Tuple[str, float]]:
    """Return ``(label, confidence)`` of the highest-confidence prediction.

    Accepts either a dict with a ``"predictions"`` list of dicts (the shape
    this service originally assumed) or a list holding one response per
    image, as Roboflow's ``inference`` docs show with ``model.infer(img)[0]``.
    A response object exposing pydantic's ``model_dump``/``dict`` is turned
    into a dict first.

    Returns:
        ``(label, confidence)``, or None when there are no predictions.

    Raises:
        TypeError: *result* has a shape this function does not recognise.
    """
    if isinstance(result, (list, tuple)):
        if not result:
            return None
        result = result[0]
    # pydantic v2 / v1 response objects -> dict; by_alias keeps a field
    # aliased to "class" under that key.
    if hasattr(result, "model_dump"):
        result = result.model_dump(by_alias=True)
    elif not isinstance(result, dict) and hasattr(result, "dict"):
        result = result.dict(by_alias=True)
    if not isinstance(result, dict):
        raise TypeError(
            f"Unexpected inference result type: {type(result).__name__}"
        )

    predictions = result.get("predictions")
    if not predictions:
        return None
    # Choose the maximum explicitly rather than trusting the list order.
    best = max(predictions, key=lambda p: float(p.get("confidence", 0.0)))
    return best.get("class", "Unknown"), float(best.get("confidence", 0.0))


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top label and confidence.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the upload is
            not declared as an image or cannot be decoded; 500 if inference
            fails or returns no predictions.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # content_type is None when the multipart part carries no Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()

    # Decoding and inference are synchronous CPU work; run them off the event
    # loop so other requests (e.g. /health) are not stalled meanwhile.
    try:
        img = await run_in_threadpool(_decode_image, contents)
    except (OSError, Image.DecompressionBombError):
        logger.warning("Could not decode upload %s as an image", file.filename)
        raise HTTPException(status_code=400, detail="Invalid image file") from None

    try:
        result = await run_in_threadpool(_run_inference, img)
        top = _top_prediction(result)
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from None

    # Checked outside the try above so this specific detail reaches the
    # client instead of being rewritten to "Prediction failed".
    if top is None:
        raise HTTPException(status_code=500, detail="No predictions returned")

    label, confidence = top
    logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
    return PredictionResponse(label=label, confidence=confidence)


@app.get("/health")
def health():
    """Liveness probe; ``model_loaded`` is False when startup loading failed."""
    return {"status": "ok", "model_loaded": model is not None}