"""FastAPI inference service for a multi-output RoBERTa classifier.

Loads a fine-tuned RoBERTa checkpoint and per-column label encoders at
startup, then serves a /predict endpoint that returns, for each label
column, the decoded prediction and the full class-probability map.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from models.roberta_model import RobertaMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os

app = FastAPI()

# --- Startup: load tokenizer, label encoders, and model weights ---------
model_path = "saved_models/ROBERTA_model.pth"  # Adjust if different
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# One label encoder per output column; class counts size the model's heads.
label_encoders = load_label_encoders()
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize model and load weights.
# SECURITY NOTE: torch.load unpickles arbitrary objects from the file.
# Only load checkpoints from trusted sources (on torch >= 1.13, consider
# weights_only=True to restrict deserialization to tensor data).
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()  # inference mode: disables dropout etc.


class PredictionRequest(BaseModel):
    """Request body for /predict."""
    # Free-text sanction context to classify.
    sanction_context: str


@app.get("/")
async def root():
    """Root health check so platform probes receive a 200."""
    return {"status": "healthy", "message": "RoBERTa API is running"}


@app.get("/health")
async def health_check():
    """Dedicated health-check endpoint."""
    return {"status": "healthy"}


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify the input text across every configured label column.

    Returns:
        dict mapping each label column to
        {"prediction": decoded label,
         "probabilities": {class label: probability}}.

    Raises:
        HTTPException: 500 with the error message if inference fails.
    """
    try:
        # Tokenize to a fixed-length batch of size 1.
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )

        # Move inputs to the configured device.
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Model emits one logits tensor per label column; no gradients needed.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        probabilities = [
            torch.softmax(output, dim=1).cpu().numpy() for output in outputs
        ]
        predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Decode each column's argmax and attach the full probability map.
        # (Original used enumerate() here but never used the index.)
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                },
            }
        return response
    except Exception as e:
        # Surface any inference failure as a 500, preserving the cause chain.
        raise HTTPException(status_code=500, detail=str(e)) from e


# For local or Spaces deployment
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)