Spaces:
Running
Running
File size: 2,521 Bytes
93fc243 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
import io
import logging
from typing import Tuple
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
# Roboflow inference
from inference import get_model
# Root logging config for the process; module logger used by all handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")
# FastAPI setup
app = FastAPI(title="Vehicle Type Predictor")
# Allow cross-origin requests from any origin so a browser frontend hosted
# elsewhere can call this API. Wide-open CORS is acceptable for a demo only.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # you can tighten this later if needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load Roboflow model at startup.
# If the key is missing or loading fails, the app still starts with model=None
# so /health can report the problem and /predict can answer 503.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
MODEL_ID = "vehicle-classification-eapcd/19"

if ROBOFLOW_API_KEY is None:
    logger.error("ROBOFLOW_API_KEY not found in environment variables")
    model = None
else:
    try:
        logger.info("Loading Roboflow model...")
        model = get_model(model_id=MODEL_ID, api_key=ROBOFLOW_API_KEY)
        logger.info("Roboflow model loaded successfully")
    except Exception:
        # logger.exception records the full traceback; keep serving degraded.
        logger.exception("Failed to load Roboflow model")
        model = None
# Response model
class PredictionResponse(BaseModel):
    """JSON body returned by /predict: top predicted class and its confidence."""

    # Predicted vehicle class name, as reported by the Roboflow model.
    label: str
    # Model confidence score — presumably in [0, 1] per Roboflow output; TODO confirm.
    confidence: float
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded vehicle image and return the top prediction.

    Raises:
        HTTPException 503: model failed to load at startup.
        HTTPException 400: upload is not an image.
        HTTPException 500: inference failed or returned no predictions.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # content_type can be None for malformed multipart parts — treat as invalid
    # instead of crashing on .startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        contents = await file.read()
        # Roboflow accepts PIL Image directly
        img = Image.open(io.BytesIO(contents)).convert("RGB")
        # Run inference
        result = model.infer(img)
        if not result.get("predictions"):
            raise HTTPException(status_code=500, detail="No predictions returned")
        # Take top prediction
        pred = result["predictions"][0]
        label = pred.get("class", "Unknown")
        confidence = float(pred.get("confidence", 0.0))
        # Lazy %-args: formatting is skipped when INFO is disabled.
        logger.info("Predicted %s (%.4f) for %s", label, confidence, file.filename)
        return PredictionResponse(label=label, confidence=confidence)
    except HTTPException:
        # Bug fix: don't let the broad handler below mask deliberate HTTP errors
        # (e.g. the "No predictions returned" 500) with a generic message.
        raise
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed")
@app.get("/health")
def health():
    """Liveness probe: reports service status and whether the model loaded."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}
|