Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from ml_service import get_ml_service | |
| from schemas import PredictionRequest, PredictionResponse, PredictionItem | |
| import asyncio | |
# Initialize FastAPI app
app = FastAPI(
    title="ML Text Classification API",
    description="API for multi-label text classification using DistilBERT",
    version="1.0.0",
)

# Configure CORS: any origin may call the API.
# Credentials are disabled on purpose. FastAPI's CORS docs state that
# allow_origins / allow_methods / allow_headers must NOT be ["*"] when
# allow_credentials is True. None of the endpoints in this module read cookies
# or auth headers, so credentials are not needed. If they ever are, replace
# the "*" wildcards with an explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Root endpoint: describe the API and list its main routes.

    Returns:
        A dict with the API name, its version (taken from the FastAPI app
        metadata) and a map of endpoint names to their paths.
    """
    return {
        "message": "ML Text Classification API",
        # Read from the app so it cannot drift from the FastAPI(version=...) value.
        "version": app.version,
        "endpoints": {
            "health": "/health",
            "predict": "/predict",
        },
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always returns ``{"status": "healthy"}`` while the process is serving
    requests. It does not call the ML service, so it does not confirm that
    the model can actually run.
    """
    return {"status": "healthy"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(prediction_request: PredictionRequest) -> PredictionResponse:
    """
    Predict labels for the given text.

    The ML service call is blocking, so it is off-loaded to the event loop's
    default executor (a thread pool) to keep the server responsive.

    Args:
        prediction_request: Request containing the text to classify.

    Returns:
        PredictionResponse with one PredictionItem (label, score) per entry
        returned by the ML service.

    Raises:
        HTTPException: 400 if ``text`` is missing/falsy; 500 if obtaining the
            ML service, running inference, or building the response fails.
    """
    if not prediction_request.text:
        raise HTTPException(
            status_code=400,
            detail="Text field is required and cannot be empty",
        )

    try:
        ml_service = get_ml_service()

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        predictions = await loop.run_in_executor(
            None,  # use the loop's default executor
            ml_service.predict,
            prediction_request.text,
        )

        # Presumably each prediction is a mapping with 'label' and 'score'
        # keys; a missing key surfaces as a 500 below.
        results = [
            PredictionItem(label=item["label"], score=item["score"])
            for item in predictions
        ]
        return PredictionResponse(results=results)
    except Exception as e:
        # NOTE(review): echoing the exception text to clients may leak internal
        # details; consider logging it server-side and returning a generic message.
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {e}",
        ) from e
if __name__ == "__main__":
    import os

    import uvicorn

    # Let the hosting platform pick the port via $PORT; default stays 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))