# RoBERTa multi-output classification API (FastAPI, Hugging Face Spaces deployment).
# NOTE(review): the original header here was page residue ("Spaces:" /
# "Configuration error") left over from an HTML/table extraction.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import RobertaTokenizer | |
| from models.roberta_model import RobertaMultiOutputModel | |
| from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE | |
| from dataset_utils import load_label_encoders | |
| import numpy as np | |
| import os | |
app = FastAPI()

# Path to the trained checkpoint; adjust if the model is saved elsewhere.
model_path = "saved_models/ROBERTA_model.pth"
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# One fitted label encoder per output column; the per-column class counts
# determine the sizes of the model's output heads.
label_encoders = load_label_encoders()
num_classes = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Initialize the multi-output model, load trained weights, and switch to
# eval mode (disables dropout/batch-norm updates for inference).
model = RobertaMultiOutputModel(num_classes).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()
class PredictionRequest(BaseModel):
    """Request body for /predict: the free-text sanction context to classify."""

    sanction_context: str
# Root health check. NOTE(review): the route decorator was missing in the
# extracted source; without it FastAPI never exposes this handler.
@app.get("/")
async def root():
    """Report that the service is up and the API is reachable."""
    return {"status": "healthy", "message": "RoBERTa API is running"}
# NOTE(review): the route decorator was missing in the extracted source;
# "/health" is the conventional probe path for this handler.
@app.get("/health")
async def health_check():
    """Lightweight liveness probe."""
    return {"status": "healthy"}
# Prediction endpoint. NOTE(review): the route decorator was missing in the
# extracted source; without it the handler is never registered.
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify a sanction context across all configured label columns.

    For each column in LABEL_COLUMNS the response contains the decoded
    top-1 prediction and the full per-class probability distribution.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Tokenize to the fixed length the model was trained with.
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # One logits tensor per output head; softmax over the class dim.
            probabilities = [
                torch.softmax(output, dim=1).cpu().numpy() for output in outputs
            ]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Build one entry per label column: decoded class + distribution.
        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                },
            }
        return response
    except Exception as e:
        # Surface any failure as a 500, preserving the cause chain.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Entry point for local runs or Spaces deployment; the PORT env var
# overrides the Spaces default of 7860.
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)