Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, HTTPException | |
| from pydantic import BaseModel | |
| from typing import Dict, Union, List | |
| from models.text_classification import TextClassificationModel | |
# Router that the endpoint handlers in this module belong to; presumably
# mounted on the FastAPI app elsewhere via include_router — confirm.
router: APIRouter = APIRouter()
# Single classifier instance, created once at import time and shared by every
# handler in this module (predict, predict_batch, get_model_info).
model: TextClassificationModel = TextClassificationModel()
class TextInput(BaseModel):
    # Request body for single-text classification.
    # (Documented with comments, not a docstring: pydantic copies a class
    # docstring into the generated JSON schema "description".)
    text: str  # raw text; handed to model.predict unchanged
class BatchTextInput(BaseModel):
    # Request body for batch classification: each entry is classified
    # separately. No length limit is enforced here; an empty list is accepted.
    texts: List[str]
class PredictionResponse(BaseModel):
    # Schema for a single classification result.
    # NOTE(review): not referenced by the handlers visible here (they are
    # annotated as returning plain dicts); presumably used as a route
    # response_model — confirm.
    label: str  # label string for the prediction
    confidence: float  # score for ``label``; no range is enforced here
class BatchPredictionResponse(BaseModel):
    # Schema for a batch result, presumably one entry per input text in input
    # order (that is how predict_batch builds its list) — not enforced here.
    # NOTE(review): not referenced by the handlers visible here; presumably
    # used as a route response_model — confirm.
    predictions: List[PredictionResponse]
async def predict(input_data: TextInput) -> Dict[str, Union[str, float]]:
    """Make a prediction for a single text.\f

    Args:
        input_data: Request body whose ``text`` is passed to ``model.predict``.

    Returns:
        The result of ``model.predict`` unchanged. The annotation declares a
        str -> str/float mapping; presumably label/confidence as in
        ``PredictionResponse`` — confirm against the model.

    Raises:
        HTTPException: status 500 when ``model.predict`` raises anything; the
            detail embeds the exception text and the original exception is
            chained as ``__cause__``.
    """
    # The "\f" above makes FastAPI cut the OpenAPI description after the
    # summary line, so the published API docs are unchanged.
    # NOTE(review): no @router route decorator is visible for this handler in
    # this view; confirm it is registered on ``router``.
    try:
        result = await model.predict(input_data.text)
    except Exception as e:
        # Chain explicitly so server logs show the model failure as the direct
        # cause of the HTTP error (ruff B904).
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}",
        ) from e
    return result
async def predict_batch(
    input_data: BatchTextInput,
) -> Dict[str, List[Dict[str, Union[str, float]]]]:
    """Make predictions for multiple texts.\f

    Texts are classified one at a time, in input order. The batch is
    all-or-nothing: if any single prediction raises, no partial results are
    returned.

    Args:
        input_data: Request body whose ``texts`` are each passed to
            ``model.predict``.

    Returns:
        ``{"predictions": [...]}`` with one ``model.predict`` result per input
        text, in the same order; an empty ``texts`` list yields an empty list.

    Raises:
        HTTPException: status 500 when any ``model.predict`` call raises; the
            detail embeds the exception text and the original exception is
            chained as ``__cause__``.
    """
    # The "\f" above makes FastAPI cut the OpenAPI description after the
    # summary line, so the published API docs are unchanged.
    # NOTE(review): no @router route decorator is visible for this handler in
    # this view; confirm it is registered on ``router``.
    try:
        # Async comprehension: each call is awaited before the next starts,
        # exactly like the former append loop (no added concurrency).
        predictions = [await model.predict(text) for text in input_data.texts]
    except Exception as e:
        # Chain explicitly so server logs show the model failure as the direct
        # cause of the HTTP error (ruff B904).
        raise HTTPException(
            status_code=500,
            detail=f"Batch prediction failed: {str(e)}",
        ) from e
    return {"predictions": predictions}
async def get_model_info():
    """Get information about the text classification model."""
    # Docstring kept verbatim: FastAPI publishes it as the OpenAPI description.
    # NOTE(review): no @router route decorator is visible for this handler in
    # this view; confirm it is registered on ``router``.
    # Delegate to the shared module-level model. get_info() is called without
    # ``await``, and whatever it returns is passed back unchanged.
    info = model.get_info()
    return info