"""FastAPI service serving a multi-output BERT classifier.

On import, the module downloads the fine-tuned checkpoint from Google Drive
(unless it is already cached at ``MODEL_PATH``), rebuilds the model, loads its
weights and label encoders, and exposes ``GET /`` and ``POST /predict``.
"""

import os
import tempfile

import requests
import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertModel, BertTokenizer

MODEL_PATH = "/tmp/bert_model.pth"
FILE_ID = "1qqmBxbxM0CmxPGC4sqO6vLJAe-Kikiv4"

# Seconds to wait on the Google Drive server (connect / between reads); without
# a timeout a stalled connection would hang service start-up indefinitely.
DOWNLOAD_TIMEOUT = 60


def download_from_google_drive(file_id, dest_path):
    """Download a Google Drive file to ``dest_path``.

    If Drive sets a ``download_warning*`` cookie (its virus-scan interstitial
    for large files), the cookie value is sent back as the ``confirm``
    parameter to fetch the actual content.

    The payload is streamed into a temporary file in the destination directory
    and only moved onto ``dest_path`` once completely written. A failed or
    interrupted download therefore never leaves a truncated file behind, which
    the ``os.path.exists`` cache check would otherwise accept on every restart.

    Args:
        file_id: Google Drive file id.
        dest_path: Where the downloaded file is stored.

    Raises:
        requests.HTTPError: Drive answered with an HTTP error status.
        RuntimeError: Drive returned an HTML page instead of the file
            (e.g. a confirmation page this function could not get past).
    """
    url = "https://docs.google.com/uc?export=download"

    def get_confirm_token(response):
        # Value of the first cookie whose name starts with 'download_warning'.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    with requests.Session() as session:
        response = session.get(
            url, params={'id': file_id}, stream=True, timeout=DOWNLOAD_TIMEOUT
        )
        token = get_confirm_token(response)
        if token:
            response.close()  # discard the interstitial page before retrying
            params = {'id': file_id, 'confirm': token}
            response = session.get(
                url, params=params, stream=True, timeout=DOWNLOAD_TIMEOUT
            )

        with response:
            response.raise_for_status()
            # An HTML body means we received a web page (warning/confirmation
            # or error page), not the binary checkpoint; saving it would
            # produce a file torch.load cannot read.
            if response.headers.get("Content-Type", "").startswith("text/html"):
                raise RuntimeError(
                    f"Google Drive returned an HTML page instead of the file "
                    f"for id {file_id!r}; the download confirmation step was "
                    f"not handled."
                )

            dest_dir = os.path.dirname(os.path.abspath(dest_path))
            fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in response.iter_content(32768):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
                # Atomic on the same filesystem: dest_path is either absent
                # or complete, never partial.
                os.replace(tmp_path, dest_path)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise


if not os.path.exists(MODEL_PATH):
    print("Downloading model from Google Drive...")
    download_from_google_drive(FILE_ID, MODEL_PATH)

# Output heads, in the order the model's classifiers are built and decoded.
# NOTE(review): assumes the checkpoint's classifier heads were trained in this
# same order — confirm against the training script.
LABEL_COLUMNS = [
    'Red_Flag_Reason', 'Maker_Action', 'Escalation_Level',
    'Risk_Category', 'Risk_Drivers', 'Investigation_Outcome'
]

PRETRAINED_MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 128  # tokenizer truncation / padding length
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BertMultiOutput(nn.Module):
    """BERT encoder with one linear classification head per output.

    All heads share the dropout-regularized ``pooler_output`` of the encoder.

    Args:
        num_labels_per_output: Number of classes for each head, in head order.
    """

    def __init__(self, num_labels_per_output):
        super().__init__()
        # Loads pretrained weights; they are subsequently overwritten by the
        # strict load_state_dict call on the fine-tuned checkpoint below.
        self.bert = BertModel.from_pretrained(PRETRAINED_MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        self.classifiers = nn.ModuleList([
            nn.Linear(self.bert.config.hidden_size, n_labels)
            for n_labels in num_labels_per_output
        ])

    def forward(self, input_ids, attention_mask):
        """Return a list with one logits tensor per classification head."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        logits = [classifier(pooled_output) for classifier in self.classifiers]
        return logits


# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects from a
# file fetched over the network — safe only while the Drive file is trusted.
# It is presumably required because the checkpoint also stores the
# label-encoder objects, not just tensors.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
label_encoders = checkpoint['label_encoders']

# Fail at start-up, not on every request, if an output has no encoder.
_missing_labels = [col for col in LABEL_COLUMNS if col not in label_encoders]
if _missing_labels:
    raise KeyError(f"Checkpoint has no label encoder for: {_missing_labels}")

# Build the heads in LABEL_COLUMNS order so head i is decoded with the encoder
# of LABEL_COLUMNS[i] in predict(); iterating the dict could disagree with it.
num_labels_list = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
model = BertMultiOutput(num_labels_list).to(DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disable dropout for inference
# The state dict has been copied into the model's parameters; drop the
# duplicate copy held by the checkpoint dict.
del checkpoint

tokenizer = BertTokenizer.from_pretrained("bert_tokenizer/")

app = FastAPI()


class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str


@app.get("/")
def root():
    """Liveness endpoint."""
    return {"message": "Multi-output BERT is ready!"}


@app.post("/predict")
def predict(request: PredictRequest):
    """Classify ``request.text`` with every head.

    Returns:
        ``{"predictions": {label_column: decoded_class, ...}}`` with one entry
        per name in ``LABEL_COLUMNS``.
    """
    inputs = tokenizer(
        request.text,
        truncation=True,
        padding='max_length',
        max_length=MAX_LEN,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        # Pass only what forward() accepts: BertTokenizer also returns
        # token_type_ids, which ``model(**inputs)`` would forward and
        # BertMultiOutput.forward() would reject with a TypeError.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    preds = [torch.argmax(output, dim=1).item() for output in outputs]

    decoded = {}
    for label, pred in zip(LABEL_COLUMNS, preds):
        value = label_encoders[label].inverse_transform([pred])[0]
        # inverse_transform presumably returns a NumPy array; unwrap NumPy
        # scalars (e.g. numpy.int64) to native Python so the response is
        # JSON-serializable.
        if hasattr(value, "item"):
            value = value.item()
        decoded[label] = value

    return {"predictions": decoded}