# Space status: Running
import io
import logging
import os
from typing import Tuple

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

# Roboflow inference
from inference import get_model
# Root logging at INFO so model-loading and prediction messages are emitted;
# this service logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="vehicle-predictor")
# Web application. CORS is left fully open (any origin, method and header)
# so a browser front-end served from another origin can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True
# effectively lets any site make credentialed requests; narrow
# allow_origins once the front-end's origin is known.
app = FastAPI(title="Vehicle Type Predictor")
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Load the Roboflow model once, at import time. If the key is missing or
# loading fails, `model` stays None; the handlers check for that and report it
# (predict answers 503, health reports model_loaded=False).
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
# Defaults to the original hard-coded model; set ROBOFLOW_MODEL_ID to use
# another model/version without editing the code.
MODEL_ID = os.environ.get("ROBOFLOW_MODEL_ID", "vehicle-classification-eapcd/19")

if not ROBOFLOW_API_KEY:
    # Treat an empty value like a missing one instead of trying to load with it.
    logger.error("ROBOFLOW_API_KEY not found in environment variables")
    model = None
else:
    try:
        logger.info("Loading Roboflow model %s...", MODEL_ID)
        model = get_model(model_id=MODEL_ID, api_key=ROBOFLOW_API_KEY)
        logger.info("Roboflow model loaded successfully")
    except Exception:
        # Deliberate best-effort: keep the app up (without a model) rather
        # than crash at startup; the traceback is logged for diagnosis.
        logger.exception("Failed to load Roboflow model")
        model = None
# Response schema
class PredictionResponse(BaseModel):
    """JSON body returned for a prediction: the chosen class and its score."""

    # Class name taken from the model's prediction.
    label: str
    # Model-reported confidence for `label`.
    confidence: float
def _label_and_confidence(pred) -> Tuple[str, float]:
    """Return ``(label, confidence)`` for one prediction entry.

    Dict entries are read through their ``"class"`` / ``"confidence"`` keys;
    any other object through its ``class_name`` / ``confidence`` attributes.
    Missing fields fall back to ``"Unknown"`` and ``0.0``.
    """
    if isinstance(pred, dict):
        label = pred.get("class", "Unknown")
        confidence = pred.get("confidence", 0.0)
    else:
        label = getattr(pred, "class_name", "Unknown")
        confidence = getattr(pred, "confidence", 0.0)
    return label, float(confidence)


def _top_prediction(result):
    """Return ``(label, confidence)`` of the most confident prediction, or ``None``.

    ``result`` is whatever ``model.infer`` returned. Two shapes are accepted:
    a dict with a ``"predictions"`` list (the shape this service originally
    assumed), or an object exposing ``.predictions``. A list/tuple of
    responses is reduced to its first element. ``None`` is returned when there
    is nothing to choose from.

    NOTE(review): the list-of-response-objects shape follows the ``inference``
    package's documented ``model.infer(image)[0]`` usage; confirm against the
    installed version.
    """
    if isinstance(result, (list, tuple)):
        if not result:
            return None
        result = result[0]
    if isinstance(result, dict):
        predictions = result.get("predictions")
    else:
        predictions = getattr(result, "predictions", None)
    if not predictions:
        return None
    # Choose by score instead of trusting the list to be sorted; on ties,
    # max() keeps the earliest entry.
    return max(
        (_label_and_confidence(p) for p in predictions),
        key=lambda pair: pair[1],
    )


# NOTE(review): no route decorator was visible for this handler (it was likely
# lost when the file was copied), so it was never registered on `app`.
# The "/predict" path is assumed from the function name; confirm with clients.
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top vehicle-type prediction.

    Raises:
        HTTPException(503): the Roboflow model was not loaded at startup.
        HTTPException(400): the upload is not declared as ``image/*`` or its
            bytes cannot be decoded by Pillow.
        HTTPException(500): inference failed or returned no predictions.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # content_type may be None (e.g. no Content-Type on the multipart part).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        try:
            img = Image.open(io.BytesIO(contents)).convert("RGB")
        except OSError:
            # PIL.UnidentifiedImageError subclasses OSError: corrupt or
            # non-image bytes are a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail="File is not a valid image"
            ) from None
        # model.infer is a synchronous call; run it in a worker thread so it
        # does not block the event loop for other requests.
        result = await run_in_threadpool(model.infer, img)
        top = _top_prediction(result)
    except HTTPException:
        # Deliberate HTTP errors pass through instead of becoming generic 500s.
        raise
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from None

    # Checked outside the try so this specific error is not masked above.
    if top is None:
        raise HTTPException(status_code=500, detail="No predictions returned")
    label, confidence = top
    logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
    return PredictionResponse(label=label, confidence=confidence)
# NOTE(review): no route decorator was visible for this handler, so it was
# never registered on `app`; "/health" is assumed from the function name.
@app.get("/health")
def health():
    """Report that the service is up and whether the Roboflow model loaded.

    Returns:
        dict: ``{"status": "ok", "model_loaded": bool}``, where
        ``model_loaded`` is False when the module-level ``model`` is None.
    """
    return {"status": "ok", "model_loaded": model is not None}