# NOTE(review): the three lines above this file's code were Hugging Face Spaces
# status-page residue ("Spaces:" / "Runtime error") pasted into the source; they
# are not Python and have been reduced to this comment so the file parses.
import os

import requests

# Where the downloaded checkpoint is cached between container restarts.
MODEL_PATH = "/tmp/bert_model.pth"
# Google Drive file id of the fine-tuned multi-output BERT checkpoint.
FILE_ID = "1qqmBxbxM0CmxPGC4sqO6vLJAe-Kikiv4"
| def download_from_google_drive(file_id, dest_path): | |
| URL = "https://docs.google.com/uc?export=download" | |
| session = requests.Session() | |
| response = session.get(URL, params={'id': file_id}, stream=True) | |
| def get_confirm_token(response): | |
| for key, value in response.cookies.items(): | |
| if key.startswith('download_warning'): | |
| return value | |
| return None | |
| token = get_confirm_token(response) | |
| if token: | |
| params = {'id': file_id, 'confirm': token} | |
| response = session.get(URL, params=params, stream=True) | |
| with open(dest_path, "wb") as f: | |
| for chunk in response.iter_content(32768): | |
| if chunk: | |
| f.write(chunk) | |
# Fetch the checkpoint once per container; /tmp persists across restarts
# of the app process but not across container rebuilds.
if not os.path.exists(MODEL_PATH):
    print("Downloading model from Google Drive...")
    download_from_google_drive(FILE_ID, MODEL_PATH)
# NOTE(review): these imports are deliberately placed after the download step
# above so a missing checkpoint fails fast before the heavy libraries load;
# moving them to the top of the file would change that ordering.
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from fastapi import FastAPI
from pydantic import BaseModel

# One classification head per label column; order must match the head order
# used when the checkpoint was trained (zip'd with predictions below).
LABEL_COLUMNS = [
    'Red_Flag_Reason', 'Maker_Action', 'Escalation_Level',
    'Risk_Category', 'Risk_Drivers', 'Investigation_Outcome'
]

PRETRAINED_MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 128  # tokenizer truncation/padding length
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class BertMultiOutput(nn.Module):
    """BERT encoder with one independent linear classification head per task.

    ``forward`` returns a list of logit tensors, one per entry of
    ``num_labels_per_output`` (same order).
    """

    def __init__(self, num_labels_per_output):
        super().__init__()
        self.bert = BertModel.from_pretrained(PRETRAINED_MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        hidden = self.bert.config.hidden_size
        heads = []
        for n_labels in num_labels_per_output:
            heads.append(nn.Linear(hidden, n_labels))
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask):
        """Encode the batch once and score it with every head."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        return [head(pooled) for head in self.classifiers]
# SECURITY(review): weights_only=False unpickles arbitrary objects from a file
# downloaded off Google Drive — an attacker controlling that file gets code
# execution. It is required here because the checkpoint also stores sklearn
# LabelEncoder objects; consider shipping the encoders separately so the
# weights can be loaded with weights_only=True.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Fitted sklearn LabelEncoders keyed by label column; head sizes derive from
# each encoder's class count, in the dict's insertion order.
label_encoders = checkpoint['label_encoders']
num_labels_list = [len(le.classes_) for le in label_encoders.values()]

model = BertMultiOutput(num_labels_list).to(DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference only: disable dropout

# NOTE(review): the tokenizer is read from a local "bert_tokenizer/" directory
# that must be committed to the Space, unlike the model which is downloaded —
# verify the directory exists in the repo.
tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/")
app = FastAPI()


class PredictRequest(BaseModel):
    """JSON body for the prediction endpoint: a single free-text narrative."""

    text: str
@app.get("/")  # fix: decorator was missing, so the route was never registered
def root():
    """Health-check endpoint confirming the model finished loading."""
    return {"message": "Multi-output BERT is ready!"}
@app.post("/predict")  # fix: decorator was missing, so the route was never registered
def predict(request: PredictRequest):
    """Classify *request.text* across all label columns.

    Returns:
        dict: ``{"predictions": {label_column: decoded_class, ...}}`` with
        one entry per LABEL_COLUMNS item.
    """
    inputs = tokenizer(
        request.text,
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    # argmax per head -> integer class ids, one per label column
    preds = [torch.argmax(output, dim=1).item() for output in outputs]
    # map ids back to the original string labels via the fitted encoders
    # (relies on LABEL_COLUMNS matching the checkpoint's head order)
    decoded = {
        label: label_encoders[label].inverse_transform([pred])[0]
        for label, pred in zip(LABEL_COLUMNS, preds)
    }
    return {"predictions": decoded}