File size: 4,767 Bytes
27abab4
 
 
 
 
 
0a6516c
27abab4
 
 
 
 
 
 
 
 
 
0a6516c
 
27abab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a6516c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27abab4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import requests
import logging
from contextlib import asynccontextmanager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer.
# All three are populated by the lifespan handler on startup; every endpoint
# treats a None value as "model not loaded" and returns HTTP 503.
model = None
tokenizer = None
label_mapping = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, tokenizer, and label mapping on startup.

    Populates the module-level ``model``, ``tokenizer``, and ``label_mapping``
    globals read by the endpoints. Any load failure is logged and re-raised so
    the application refuses to start without a working model.
    """
    global model, tokenizer, label_mapping
    
    try:
        model_name = "ityndall/james-river-classifier"
        logger.info(f"Loading model: {model_name}")
        
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        
        # Load label mapping from the model repo. A timeout prevents a hung
        # network call from stalling startup forever, and raise_for_status()
        # surfaces HTTP errors instead of parsing an error page as JSON.
        label_mapping_url = f"https://huggingface.co/{model_name}/resolve/main/label_mapping.json"
        response = requests.get(label_mapping_url, timeout=30)
        response.raise_for_status()
        label_mapping = response.json()
        
        logger.info("Model loaded successfully")
        logger.info(f"Available labels: {list(label_mapping['id2label'].values())}")
        
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise e
    
    yield
    
    # Cleanup (if needed)
    logger.info("Shutting down...")

# Application instance; the lifespan handler loads the model before serving.
app = FastAPI(
    title="James River Survey Classification API",
    description="API for classifying survey-related text messages into job types",
    version="1.0.0",
    lifespan=lifespan
)

# Request model
class PredictionRequest(BaseModel):
    """Request body for /predict: the text message to classify."""
    message: str

# Response model
class PredictionResponse(BaseModel):
    """Response body for /predict: predicted label and softmax confidence."""
    label: str
    confidence: float

@app.get("/")
async def root():
    """Describe the API: name, version, backing model, labels, and endpoints."""
    info = {
        "message": "James River Survey Classification API",
        "version": "1.0.0",
        "model": "ityndall/james-river-classifier",
    }
    # Labels are only available once the lifespan handler has loaded the mapping.
    info["available_labels"] = list(label_mapping["id2label"].values()) if label_mapping else []
    info["endpoints"] = {
        "predict": "/predict - POST endpoint for text classification",
        "health": "/health - GET endpoint for health check",
    }
    return info

@app.get("/health")
async def health_check():
    """Report readiness: 503 until the model artifacts are fully loaded."""
    ready = all(obj is not None for obj in (model, tokenizer, label_mapping))
    if not ready:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "healthy", "model_loaded": True}

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Predict the survey job type for the given message.

    Returns:
        PredictionResponse with the predicted label and its softmax confidence.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if the message is
            empty, 500 on unexpected inference errors.
    """
    
    if model is None or tokenizer is None or label_mapping is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    # Validate outside the try block so the 400 is not swallowed by the broad
    # except below and re-raised as a 500 (original bug).
    text = request.message.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    
    try:
        # Tokenize and predict
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            predicted_class_id = probs.argmax().item()
            confidence = probs[0][predicted_class_id].item()
        
        # Label mapping keys are JSON strings, so index with str(id).
        label = label_mapping["id2label"][str(predicted_class_id)]
        
        logger.info(f"Prediction: '{text}' -> {label} (confidence: {confidence:.3f})")
        
        return PredictionResponse(label=label, confidence=confidence)
        
    except HTTPException:
        # Propagate deliberate HTTP errors unchanged instead of masking as 500.
        raise
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

# Legacy endpoint for backward compatibility
@app.post("/predict_legacy")
async def predict_legacy(request: Request):
    """Legacy endpoint that accepts raw JSON (for backward compatibility).

    Delegates to predict() and flattens its response into a plain dict.

    Raises:
        HTTPException: 400 when the "message" field is missing/empty; any
            status propagated from predict() (e.g. 503); 500 otherwise.
    """
    try:
        data = await request.json()
        message = data.get("message", "")
        
        if not message:
            raise HTTPException(status_code=400, detail="Message field is required")
        
        # Use the main predict function
        prediction_request = PredictionRequest(message=message)
        result = await predict(prediction_request)
        
        return {"label": result.label, "confidence": result.confidence}
        
    except HTTPException:
        # Preserve intentional status codes (400 above, 503/400 from predict)
        # instead of collapsing them into a 500 (original bug).
        raise
    except Exception as e:
        logger.error(f"Error in legacy endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Serve directly with uvicorn on all interfaces, port 7860
    # (presumably the Hugging Face Spaces default port — confirm with deploy config).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)