# Hugging Face Spaces deployment entry point.
# (The previous "Spaces: Runtime error" lines were status-page text pasted
# into the source; preserved here as a comment so the file stays valid Python.)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import BertTokenizer
from models.bert_model import BertMultiOutputModel
from config import TEXT_COLUMN, LABEL_COLUMNS, MAX_LEN, DEVICE
from dataset_utils import load_label_encoders
import numpy as np
import os

app = FastAPI()

# Load tokenizer and trained model once at import time so every request
# reuses the same objects instead of reloading weights per call.
model_path = "BERT_model.pth"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# One classification head per label column; each head's size is the number
# of classes the corresponding label encoder learned during training.
model = BertMultiOutputModel([len(load_label_encoders()[col].classes_) for col in LABEL_COLUMNS]).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()  # disable dropout/batch-norm updates for inference
class PredictionRequest(BaseModel):
    """Request body for the prediction endpoint.

    Attributes:
        sanction_context: Free text to classify; fed directly to the
            BERT tokenizer.
    """
    sanction_context: str
# NOTE(review): route decorator reconstructed — the pasted source lost all
# decorators, leaving the handlers unregistered. Confirm the original path.
@app.get("/")
async def root():
    """Landing endpoint confirming the API process is up and serving."""
    return {"status": "healthy", "message": "BERT API is running"}
# NOTE(review): route decorator reconstructed — the pasted source lost all
# decorators, leaving the handlers unregistered. Confirm the original path.
@app.get("/health")
async def health_check():
    """Lightweight liveness probe for deployment health checks."""
    return {"status": "healthy"}
# NOTE(review): route decorator reconstructed — the pasted source lost all
# decorators, leaving the handlers unregistered. Confirm the original path.
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify the request text with the multi-output BERT model.

    Returns a dict keyed by label column; each entry holds the decoded
    predicted class and the full per-class probability distribution.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        # Tokenize to a fixed-length tensor (pad + truncate to MAX_LEN).
        inputs = tokenizer(
            request.sanction_context,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        )
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        # Forward pass without gradient tracking; the model returns one
        # logits tensor per label column (one head per output).
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = [torch.softmax(output, dim=1).cpu().numpy() for output in outputs]
            predictions = [np.argmax(prob, axis=1) for prob in probabilities]

        # Label encoders map the integer class indices back to label names.
        label_encoders = load_label_encoders()

        response = {}
        for col, pred, prob in zip(LABEL_COLUMNS, predictions, probabilities):
            decoded_pred = label_encoders[col].inverse_transform(pred)[0]
            response[col] = {
                "prediction": decoded_pred,
                "probabilities": {
                    label: float(prob[0][j])
                    for j, label in enumerate(label_encoders[col].classes_)
                }
            }
        return response
    except Exception as e:
        # Surface any failure (tokenization, model, decoding) as a 500
        # rather than crashing the worker.
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    # Hugging Face Spaces serves on port 7860; honor a PORT override so the
    # same entry point works in other environments.
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)