from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import requests
import logging
from contextlib import asynccontextmanager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Globals populated exactly once at startup by the lifespan handler.
model = None
tokenizer = None
label_mapping = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, tokenizer and label mapping on startup.

    Everything is stored in module-level globals so the route handlers
    can reach it.  If any part of loading fails, the exception is logged
    and re-raised so the server refuses to start half-initialized.
    """
    global model, tokenizer, label_mapping
    try:
        model_name = "ityndall/james-river-classifier"
        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)

        # Fetch the id -> label mapping published alongside the model weights.
        label_mapping_url = f"https://huggingface.co/{model_name}/resolve/main/label_mapping.json"
        # FIX: add a timeout and check the HTTP status — the original call
        # could block forever and would happily .json() an error page.
        response = requests.get(label_mapping_url, timeout=30)
        response.raise_for_status()
        label_mapping = response.json()

        logger.info("Model loaded successfully")
        logger.info(f"Available labels: {list(label_mapping['id2label'].values())}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        # FIX: bare `raise` preserves the original traceback
        # (`raise e` re-raises from here and loses the chain).
        raise
    yield
    # Cleanup (if needed)
    logger.info("Shutting down...")


app = FastAPI(
    title="James River Survey Classification API",
    description="API for classifying survey-related text messages into job types",
    version="1.0.0",
    lifespan=lifespan
)


# Request model
class PredictionRequest(BaseModel):
    message: str


# Response model
class PredictionResponse(BaseModel):
    label: str
    confidence: float


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "James River Survey Classification API",
        "version": "1.0.0",
        "model": "ityndall/james-river-classifier",
        "available_labels": list(label_mapping["id2label"].values()) if label_mapping else [],
        "endpoints": {
            "predict": "/predict - POST endpoint for text classification",
            "health": "/health - GET endpoint for health check"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint: 503 until the model is fully loaded."""
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the survey job type for the given message.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an empty
            message, 500 for any unexpected failure during inference.
    """
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # FIX: validate *outside* the try block. The original raised the 400
    # inside the try, where `except Exception` caught it and re-emitted
    # it as a 500 "Prediction error" — clients never saw the 400.
    text = request.message.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    try:
        # Tokenize and predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predicted_class_id = probs.argmax().item()
            confidence = probs[0][predicted_class_id].item()

        # id2label keys are JSON object keys, i.e. strings — hence str().
        label = label_mapping["id2label"][str(predicted_class_id)]

        logger.info(f"Prediction: '{text}' -> {label} (confidence: {confidence:.3f})")
        return PredictionResponse(label=label, confidence=confidence)
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")


# Legacy endpoint for backward compatibility
@app.post("/predict_legacy")
async def predict_legacy(request: Request):
    """Legacy endpoint that accepts raw JSON (for backward compatibility)."""
    try:
        data = await request.json()
        message = data.get("message", "")
        if not message:
            raise HTTPException(status_code=400, detail="Message field is required")

        # Delegate to the main predict handler.
        prediction_request = PredictionRequest(message=message)
        result = await predict(prediction_request)
        return {"label": result.label, "confidence": result.confidence}
    except HTTPException:
        # FIX: let intentional HTTP errors pass through unchanged — the
        # original wrapped its own 400 (and any status from predict())
        # into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error in legacy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)