lovebird25 / main.py
Paul
Initial commit
75146bf
import asyncio
import logging

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from ml_service import get_ml_service
from schemas import PredictionRequest, PredictionResponse, PredictionItem
# Application metadata passed straight to the FastAPI constructor.
_app_metadata = dict(
    title="ML Text Classification API",
    description="API for multi-label text classification using DistilBERT",
    version="1.0.0",
)
app = FastAPI(**_app_metadata)

# Cross-origin policy: accept any origin, any method and any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# risky CORS combination; consider an explicit origin list before exposing
# this service publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/")
async def root():
    """Return a short service description: name, version and route map."""
    route_map = dict(health="/health", predict="/predict")
    return dict(
        message="ML Text Classification API",
        version="1.0.0",
        endpoints=route_map,
    )
@app.get("/health")
async def health_check():
    """Report a constant healthy status.

    This does not call the ML service, so it only shows that the HTTP
    process is answering requests.
    """
    status_payload = dict(status="healthy")
    return status_payload
@app.post("/predict", response_model=PredictionResponse)
async def predict(prediction_request: PredictionRequest):
    """
    Predict labels for the given text.

    Args:
        prediction_request: Request containing the text to classify.

    Returns:
        PredictionResponse whose ``results`` hold one PredictionItem
        (label, score) per entry returned by the ML service.

    Raises:
        HTTPException: 400 if the text is missing, empty, or whitespace-only;
            500 if obtaining the service, running inference, or building the
            response fails (the original error is chained and logged).
    """
    text = prediction_request.text
    # Whitespace-only input has nothing to classify, so reject it like "".
    # The isinstance guard keeps this safe if the schema ever allows non-str.
    if not text or (isinstance(text, str) and not text.strip()):
        raise HTTPException(
            status_code=400,
            detail="Text field is required and cannot be empty",
        )

    try:
        ml_service = get_ml_service()
        # Inference is blocking, so run it in the default thread-pool executor
        # to keep the event loop free. Inside a coroutine the running loop
        # must be fetched with get_running_loop(); get_event_loop() is
        # deprecated for this use.
        loop = asyncio.get_running_loop()
        predictions = await loop.run_in_executor(None, ml_service.predict, text)
        # presumably each prediction is a mapping with 'label' and 'score'
        # keys — TODO confirm against ml_service.predict
        results = [
            PredictionItem(label=item["label"], score=item["score"])
            for item in predictions
        ]
        # Kept inside the try so response-validation errors still map to 500.
        return PredictionResponse(results=results)
    except Exception as e:
        # Record the full traceback server-side; previously it was discarded.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {e}",
        ) from e
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this file runs as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)