# NOTE: the original "Spaces: / Runtime error" lines here were Hugging Face
# Spaces UI residue captured during extraction, not part of the application code.
| from fastapi import FastAPI, Request, HTTPException | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import requests | |
| import logging | |
| from contextlib import asynccontextmanager | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer.
# All three stay None until `lifespan` finishes loading them at startup;
# the endpoints use the None check as their "model not ready" signal (503).
model = None
tokenizer = None
label_mapping = None
@asynccontextmanager  # required: FastAPI's `lifespan=` expects an async context manager, not a bare async generator
async def lifespan(app: FastAPI):
    """Load the model, tokenizer, and label mapping on startup; log on shutdown.

    Populates the module-level ``model``, ``tokenizer``, and ``label_mapping``
    globals that the endpoints read. Raises on any load failure so the app
    refuses to start half-initialized.
    """
    global model, tokenizer, label_mapping
    try:
        model_name = "ityndall/james-river-classifier"
        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Load label mapping published alongside the model weights on the Hub.
        label_mapping_url = f"https://huggingface.co/{model_name}/resolve/main/label_mapping.json"
        response = requests.get(label_mapping_url, timeout=30)
        # Fail fast on 4xx/5xx instead of JSON-decoding an HTML error page.
        response.raise_for_status()
        label_mapping = response.json()
        logger.info("Model loaded successfully")
        logger.info(f"Available labels: {list(label_mapping['id2label'].values())}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise e
    yield
    # Cleanup (if needed)
    logger.info("Shutting down...")
# FastAPI application instance; `lifespan` loads the model before serving traffic.
app = FastAPI(
    title="James River Survey Classification API",
    description="API for classifying survey-related text messages into job types",
    version="1.0.0",
    lifespan=lifespan
)
# Request model
class PredictionRequest(BaseModel):
    """Request body for the prediction endpoint."""
    message: str  # free-form survey message text to classify
# Response model
class PredictionResponse(BaseModel):
    """Response body for the prediction endpoint."""
    label: str  # job-type label taken from the model's id2label mapping
    confidence: float  # softmax probability of the predicted class, in [0, 1]
@app.get("/")  # decorator restored: without it the endpoint is never registered
async def root():
    """Root endpoint with API information.

    Returns service metadata plus the currently loaded label set
    (empty list if the model has not finished loading yet).
    """
    return {
        "message": "James River Survey Classification API",
        "version": "1.0.0",
        "model": "ityndall/james-river-classifier",
        "available_labels": list(label_mapping["id2label"].values()) if label_mapping else [],
        "endpoints": {
            "predict": "/predict - POST endpoint for text classification",
            "health": "/health - GET endpoint for health check"
        }
    }
@app.get("/health")  # decorator restored: path documented in the root endpoint listing
async def health_check():
    """Health check endpoint.

    Returns 200 once model, tokenizer, and label mapping are all loaded;
    503 while the service is still initializing (or failed to load).
    """
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}
@app.post("/predict", response_model=PredictionResponse)  # decorator restored: path documented in the root endpoint listing
async def predict(request: PredictionRequest):
    """Predict the survey job type for the given message.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: message is empty/whitespace.
        HTTPException 500: any unexpected inference failure.
    """
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        text = request.message.strip()
        if not text:
            raise HTTPException(status_code=400, detail="Message cannot be empty")
        # Tokenize and predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_id = probs.argmax().item()
        confidence = probs[0][predicted_class_id].item()
        # Map class id to human-readable label (JSON keys are strings).
        label = label_mapping["id2label"][str(predicted_class_id)]
        logger.info(f"Prediction: '{text}' -> {label} (confidence: {confidence:.3f})")
        return PredictionResponse(label=label, confidence=confidence)
    except HTTPException:
        # Bug fix: re-raise as-is so the deliberate 400 above isn't converted to a 500
        # by the blanket handler below.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
# Legacy endpoint for backward compatibility
# NOTE(review): no route decorator is visible in the source for this handler, and
# the legacy path cannot be recovered from the code — restore the original
# @app.post(<legacy path>) decorator so this endpoint is actually registered.
async def predict_legacy(request: Request):
    """Legacy endpoint that accepts raw JSON (for backward compatibility).

    Expects a JSON body with a "message" field and returns a plain dict
    instead of the PredictionResponse model.
    """
    try:
        data = await request.json()
        message = data.get("message", "")
        if not message:
            raise HTTPException(status_code=400, detail="Message field is required")
        # Delegate to the main predict function
        prediction_request = PredictionRequest(message=message)
        result = await predict(prediction_request)
        return {"label": result.label, "confidence": result.confidence}
    except HTTPException:
        # Bug fix: preserve deliberate status codes (e.g. the 400 above, or a 503
        # from predict) instead of collapsing everything to 500.
        raise
    except Exception as e:
        logger.error(f"Error in legacy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Run directly with uvicorn; port 7860 is the Hugging Face Spaces convention,
    # host 0.0.0.0 so the containerized app is reachable from outside.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)