# james-river-api / app.py
# Author: ityndall
# Last commit: "Fix FastAPI lifespan handler for modern FastAPI compatibility" (0a6516c)
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import requests
import logging
from contextlib import asynccontextmanager
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer.
# Populated once by the `lifespan` startup handler; they stay None until the
# model is loaded, which lets the endpoints answer 503 before readiness.
model = None
tokenizer = None
label_mapping = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, tokenizer, and label mapping on startup.

    Everything before ``yield`` runs once at application startup and fills
    the module-level ``model``, ``tokenizer`` and ``label_mapping`` globals
    used by the endpoints; everything after runs at shutdown.

    Raises:
        Exception: re-raises any load failure so the app refuses to start
        with a half-initialized model.
    """
    global model, tokenizer, label_mapping
    try:
        model_name = "ityndall/james-river-classifier"
        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Load label mapping (class id -> label name) published next to the model.
        label_mapping_url = f"https://huggingface.co/{model_name}/resolve/main/label_mapping.json"
        # timeout: a hung request must not stall startup forever.
        # raise_for_status: turn an HTTP error page into a clear exception
        # instead of a confusing JSON decode error on the error body.
        response = requests.get(label_mapping_url, timeout=30)
        response.raise_for_status()
        label_mapping = response.json()
        logger.info("Model loaded successfully")
        logger.info(f"Available labels: {list(label_mapping['id2label'].values())}")
    except Exception:
        # logger.exception logs the traceback; bare `raise` preserves it
        # (`raise e` would reset the chain).
        logger.exception("Error loading model")
        raise
    yield
    # Cleanup (if needed)
    logger.info("Shutting down...")
# FastAPI application; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="James River Survey Classification API",
    description="API for classifying survey-related text messages into job types",
    version="1.0.0",
    lifespan=lifespan
)
# Request model
class PredictionRequest(BaseModel):
    """Request body for /predict: the raw text message to classify."""
    # The survey-related text to classify; leading/trailing whitespace is
    # stripped by the endpoint before inference.
    message: str
# Response model
class PredictionResponse(BaseModel):
    """Response body for /predict: predicted label and its confidence."""
    # Human-readable class name taken from label_mapping["id2label"].
    label: str
    # Softmax probability of the predicted class, in [0, 1].
    confidence: float
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # Labels are only known once the lifespan handler has loaded the mapping.
    if label_mapping:
        labels = list(label_mapping["id2label"].values())
    else:
        labels = []
    info = {
        "message": "James River Survey Classification API",
        "version": "1.0.0",
        "model": "ityndall/james-river-classifier",
        "available_labels": labels,
    }
    info["endpoints"] = {
        "predict": "/predict - POST endpoint for text classification",
        "health": "/health - GET endpoint for health check",
    }
    return info
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Healthy only when all three startup artifacts are in place.
    ready = all(obj is not None for obj in (model, tokenizer, label_mapping))
    if not ready:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the survey job type for the given message.

    Args:
        request: body containing the text message to classify.

    Returns:
        PredictionResponse with the top label and its softmax confidence.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an empty
        message, 500 on unexpected inference errors.
    """
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    text = request.message.strip()
    # Validate OUTSIDE the try block: the original raised the 400 inside it,
    # so the generic `except Exception` below swallowed it and re-raised a 500.
    if not text:
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    try:
        # Tokenize and predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predicted_class_id = probs.argmax().item()
            confidence = probs[0][predicted_class_id].item()
        # Get label; id2label keys come from JSON, so they are strings.
        label = label_mapping["id2label"][str(predicted_class_id)]
        logger.info(f"Prediction: '{text}' -> {label} (confidence: {confidence:.3f})")
        return PredictionResponse(label=label, confidence=confidence)
    except Exception as e:
        # logger.exception includes the traceback for debugging.
        logger.exception("Error during prediction")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
# Legacy endpoint for backward compatibility
@app.post("/predict_legacy")
async def predict_legacy(request: Request):
    """Legacy endpoint that accepts raw JSON (for backward compatibility).

    Expects a body like ``{"message": "..."}`` and returns
    ``{"label": ..., "confidence": ...}``.

    Raises:
        HTTPException: 400 when ``message`` is missing/empty; any status
        propagated from ``predict`` (400/503); 500 on unexpected errors.
    """
    try:
        data = await request.json()
        message = data.get("message", "")
        if not message:
            raise HTTPException(status_code=400, detail="Message field is required")
        # Delegate to the main predict endpoint.
        prediction_request = PredictionRequest(message=message)
        result = await predict(prediction_request)
        return {"label": result.label, "confidence": result.confidence}
    except HTTPException:
        # Preserve deliberate HTTP errors (the 400 above, or 400/503 from
        # predict). The original `except Exception` re-wrapped them as 500.
        raise
    except Exception as e:
        logger.exception("Error in legacy endpoint")
        raise HTTPException(status_code=500, detail=str(e))
# Run the ASGI app directly with uvicorn when executed as a script.
# NOTE(review): port 7860 is the Hugging Face Spaces convention — presumably
# this is deployed there; confirm before changing.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)