# app.py — FastAPI inference service for a multi-output RoBERTa classifier
# (roberta-model Space; checkpoint uploaded as saved_models/ROBERTA_model.pth).
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from models.roberta_model import RobertaMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os
app = FastAPI()

# --- One-time setup at import: tokenizer, label encoders, model weights ---
# Path to the fine-tuned checkpoint; adjust if the file lives elsewhere.
model_path = "saved_models/ROBERTA_model.pth"
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Per-column label encoders; the model has one output head per label column,
# sized by the number of classes that column's encoder knows about.
label_encoders = load_label_encoders()
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize model and load weights.
# weights_only=True: a state dict contains only tensors, and this prevents
# torch.load from executing arbitrary pickled code in the checkpoint file.
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
model.eval()  # inference mode: disables dropout/batch-norm updates
# Request format
class PredictionRequest(BaseModel):
    """Request body for POST /predict: free-text sanction context to classify."""
    sanction_context: str
# Root health check
@app.get("/")
async def root():
    """Liveness probe at the root path: confirms the API process is up."""
    payload = {"status": "healthy", "message": "RoBERTa API is running"}
    return payload
@app.get("/health")
async def health_check():
    """Minimal health endpoint for deployment-platform probes."""
    return dict(status="healthy")
# Prediction endpoint
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify the given sanction context across all label columns.

    For each column in LABEL_COLUMNS, returns the decoded top prediction and
    the full per-class probability distribution.

    Raises:
        HTTPException: 500 with the error message on any tokenization or
            inference failure.
    """
    try:
        # Tokenize the input text into fixed-length tensors.
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )
        # Move inputs to the same device as the model.
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # One logits tensor per label column; no gradients needed at inference.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
        predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Format the response: decode class indices back to label values and
        # coerce numpy scalars to plain Python types for JSON encoding.
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            # inverse_transform returns an array; take the single (batch=1) item.
            decoded_pred = str(label_encoders[col].inverse_transform(pred)[0])
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    str(label): float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }
            }
        return response
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# For local or Spaces deployment
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces injects PORT; default to 7860 locally.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)