from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from models.roberta_model import RobertaMultiOutputModel
from config import LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os

app = FastAPI()

# Tokenizer and path to the trained checkpoint (weights are loaded below)
model_path = "saved_models/ROBERTA_model.pth"  # Adjust if different
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load label encoders
label_encoders = load_label_encoders()
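# One classification head per label column, sized by its fitted encoder's classes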
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize model and load weights
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()

# Request body schema: {"sanction_context": "<text to classify>"}
class PredictionRequest(BaseModel):
    sanction_context: str

# Root health check
@app.get("/")
async def root():
    return {"status": "healthy", "message": "RoBERTa API is running"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Prediction endpoint
@app.post("/predict")
async def predict(request: PredictionRequest):
    try:
        # Tokenize the input text
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )
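        # Dict of tensors ("input_ids", "attention_mask"), each shaped (1, MAX_LEN)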
        
        # Move inputs to device
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Get model predictions
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
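            # The model returns one logits tensor per label column;
            # softmax gives per-class probabilities, argmax the predicted index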
            probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Format the response
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j]) 
                    for j, label in enumerate(label_encoders[col].classes_)
                }
            }

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# For local or Spaces deployment
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
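
# Example request once the server is running (the input text below is
# hypothetical; any string works for the "sanction_context" field):
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"sanction_context": "Entity X was sanctioned for ..."}'
#
# The response maps each label column to its decoded prediction plus
# per-class probabilities.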