# Source: Hugging Face file viewer header — botInfinity / main.py, commit 93fc243 (2.52 kB).
import os
import io
import logging
from typing import Tuple
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image
# Roboflow inference
from inference import get_model
# Module-level logging and FastAPI application wiring.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vehicle-predictor")

# FastAPI setup
app = FastAPI(title="Vehicle Type Predictor")

# Allow browser clients from any origin to call the API.
cors_options = dict(
    allow_origins=["*"],  # you can tighten this later if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **cors_options)
# Load the Roboflow model once at import time so every request reuses the
# same in-memory model.  `model` stays None when the API key is missing or
# loading fails; /predict reports 503 in that case.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
MODEL_ID = "vehicle-classification-eapcd/19"

model = None
if ROBOFLOW_API_KEY is None:
    logger.error("❌ ROBOFLOW_API_KEY not found in environment variables")
else:
    try:
        logger.info("🚀 Loading Roboflow model...")
        model = get_model(model_id=MODEL_ID, api_key=ROBOFLOW_API_KEY)
        logger.info("✅ Roboflow model loaded successfully")
    except Exception:
        # logger.exception already records the traceback; the original's
        # unused `as e` binding is dropped.  The emoji here and above were
        # UTF-8 mojibake in the scraped source and are restored.
        logger.exception("❌ Failed to load Roboflow model")
        model = None
# Response model
class PredictionResponse(BaseModel):
    """JSON body returned by POST /predict: top-1 class and its score."""

    # Predicted vehicle class name (the model's "class" field; "Unknown" if absent).
    label: str
    # Model confidence for `label`; presumably in [0.0, 1.0] — TODO confirm
    # against the Roboflow model's output.
    confidence: float
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)) -> PredictionResponse:
    """Classify an uploaded image and return the top prediction.

    Raises:
        HTTPException 503: the model failed to load at startup.
        HTTPException 400: the upload is not an image.
        HTTPException 500: inference failed or returned no predictions.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # content_type may be None for some clients; guard before .startswith
    # (the original would raise AttributeError and surface as a 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        contents = await file.read()
        # Roboflow accepts PIL Image directly
        img = Image.open(io.BytesIO(contents)).convert("RGB")
        # Run inference.
        # NOTE(review): assumes infer() returns a dict with a "predictions"
        # list; some inference SDK versions return a list of response
        # objects instead — confirm against the installed SDK.
        result = model.infer(img)
        if not result.get("predictions"):
            raise HTTPException(status_code=500, detail="No predictions returned")
        # Take top prediction
        pred = result["predictions"][0]
        label = pred.get("class", "Unknown")
        confidence = float(pred.get("confidence", 0.0))
        logger.info(f"Predicted {label} ({confidence:.4f}) for {file.filename}")
        return PredictionResponse(label=label, confidence=confidence)
    except HTTPException:
        # Bug fix: HTTPException subclasses Exception, so the broad handler
        # below used to swallow the "No predictions returned" detail (and the
        # 500 raised above) and replace it with a generic message. Re-raise
        # HTTP errors unchanged.
        raise
    except Exception:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed")
@app.get("/health")
def health():
    """Liveness probe: service status plus whether the model loaded."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}