Spaces:
Runtime error
Runtime error
File size: 4,767 Bytes
27abab4 0a6516c 27abab4 0a6516c 27abab4 0a6516c 27abab4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import requests
import logging
from contextlib import asynccontextmanager
# Set up module-level logging (INFO so startup/prediction events are visible in container logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer, populated once by the lifespan handler.
# They stay None until startup completes; endpoints check for None and return 503.
model = None          # AutoModelForSequenceClassification instance
tokenizer = None      # AutoTokenizer instance
label_mapping = None  # dict with an "id2label" key mapping class-id strings to label names
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, tokenizer, and label mapping on startup; log on shutdown.

    Populates the module-level globals ``model``, ``tokenizer`` and
    ``label_mapping``. Any failure is logged and re-raised so the app
    refuses to start with a half-loaded model.
    """
    global model, tokenizer, label_mapping
    try:
        model_name = "ityndall/james-river-classifier"
        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Load label mapping from the model repo.
        # timeout: don't let a hung Hub request block startup indefinitely.
        # raise_for_status: a 404/500 would otherwise be fed to .json() and
        # fail with a confusing JSONDecodeError.
        label_mapping_url = f"https://huggingface.co/{model_name}/resolve/main/label_mapping.json"
        response = requests.get(label_mapping_url, timeout=30)
        response.raise_for_status()
        label_mapping = response.json()
        logger.info("Model loaded successfully")
        logger.info(f"Available labels: {list(label_mapping['id2label'].values())}")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise e
    yield
    # Cleanup (if needed)
    logger.info("Shutting down...")
# Application instance; `lifespan` loads the model before the first request is served.
app = FastAPI(
    title="James River Survey Classification API",
    description="API for classifying survey-related text messages into job types",
    version="1.0.0",
    lifespan=lifespan
)
# Request model
class PredictionRequest(BaseModel):
    """Request body for /predict: the raw text message to classify."""
    message: str
# Response model
class PredictionResponse(BaseModel):
    """Response body for /predict: predicted label and its softmax probability."""
    label: str
    confidence: float
@app.get("/")
async def root():
    """Root endpoint with API information."""
    # Labels are only available once the lifespan handler has populated the mapping.
    available = []
    if label_mapping:
        available = list(label_mapping["id2label"].values())
    info = {
        "message": "James River Survey Classification API",
        "version": "1.0.0",
        "model": "ityndall/james-river-classifier",
        "available_labels": available,
        "endpoints": {
            "predict": "/predict - POST endpoint for text classification",
            "health": "/health - GET endpoint for health check"
        }
    }
    return info
@app.get("/health")
async def health_check():
    """Health check endpoint: 503 until the model is fully loaded."""
    loaded = all(obj is not None for obj in (model, tokenizer, label_mapping))
    if not loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the survey job type for the given message.

    Raises:
        HTTPException 503: model not yet loaded.
        HTTPException 400: empty message.
        HTTPException 500: unexpected inference failure.
    """
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Validate outside the try block so the 400 is not swallowed below.
    text = request.message.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    try:
        # Tokenize and predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class_id = probs.argmax().item()
        confidence = probs[0][predicted_class_id].item()
        # Get label (id2label keys are JSON strings, so index with str())
        label = label_mapping["id2label"][str(predicted_class_id)]
        logger.info(f"Prediction: '{text}' -> {label} (confidence: {confidence:.3f})")
        return PredictionResponse(label=label, confidence=confidence)
    except HTTPException:
        # BUG FIX: previously HTTPExceptions raised inside the try were caught
        # by the generic handler and converted to 500s — preserve their status.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
# Legacy endpoint for backward compatibility
@app.post("/predict_legacy")
async def predict_legacy(request: Request):
    """Legacy endpoint that accepts raw JSON (for backward compatibility).

    Delegates to `predict` and flattens the response into a plain dict.

    Raises:
        HTTPException 400: missing/empty "message" field.
        HTTPException 500: unexpected failure (e.g. malformed JSON body).
    """
    try:
        data = await request.json()
        message = data.get("message", "")
        if not message:
            raise HTTPException(status_code=400, detail="Message field is required")
        # Use the main predict function
        prediction_request = PredictionRequest(message=message)
        result = await predict(prediction_request)
        return {"label": result.label, "confidence": result.confidence}
    except HTTPException:
        # BUG FIX: previously the 400 above (and any 503/400 from `predict`)
        # was caught by the generic handler and re-raised as a 500 —
        # re-raise so clients see the intended status code.
        raise
    except Exception as e:
        logger.error(f"Error in legacy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Script entry point: run a local uvicorn server (port 7860 is the HF Spaces default).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
|