ganeshkonapalli committed on
Commit
3b69eed
·
verified ·
1 Parent(s): a615712

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +6 -0
  2. bert_multioutput_model.pth +3 -0
  3. main.py +59 -0
  4. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Base image: slim Python 3.9.
FROM python:3.9-slim

# All following paths are relative to /app inside the image.
WORKDIR /app

# Copy the dependency list on its own first so the install layer can be
# reused from cache when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the build context (application code and model file).
COPY . .

# Serve the FastAPI object `app` from main.py on all interfaces, port 7860.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
bert_multioutput_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69f4ba092be2d47ddd20ce865d13ff92795af25cb29301f657b442c68ddad6fa
3
+ size 35
main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer, BertModel
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ import pandas as pd
7
+ from sklearn.preprocessing import LabelEncoder
8
+
9
# Checkpoint file, resolved relative to the process working directory.
MODEL_PATH = "bert_multioutput_model.pth"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Names of the predicted outputs. /predict pairs the i-th model output with
# the i-th entry here, so the order is significant.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
12
+
13
class InputText(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw text to classify.
    text: str
15
+
16
class MultiOutputBERT(nn.Module):
    """BERT encoder followed by one linear classification head per output.

    Every head reads the same dropout-regularised pooled BERT output and
    produces logits over its own set of classes.
    """

    def __init__(self, num_classes_per_label):
        """Build the encoder, the dropout layer and the classification heads.

        Args:
            num_classes_per_label: Iterable giving the number of classes for
                each head; one ``nn.Linear`` is created per entry, in order.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)

        # Every head maps the encoder's hidden size to that label's classes.
        hidden_size = self.bert.config.hidden_size
        heads = []
        for num_classes in num_classes_per_label:
            heads.append(nn.Linear(hidden_size, num_classes))
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and every head.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids``.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            token_type_ids: Forwarded to the encoder as ``token_type_ids``.

        Returns:
            A list with one logits tensor per head, in the order the heads
            were created.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return [head(pooled) for head in self.classifiers]
31
+
32
# Load the checkpoint once at import time so every request reuses the same
# model, tokenizer and label encoders.
#
# The checkpoint stores more than tensors: "label_encoders" holds objects
# whose inverse_transform() is called in /predict (presumably scikit-learn
# LabelEncoder instances -- confirm against the training script). Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle such objects, so full unpickling is requested explicitly.
#
# SECURITY NOTE: weights_only=False runs the pickle machinery and can execute
# arbitrary code. Only load checkpoint files that come from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
num_classes_per_label = checkpoint["num_classes_per_label"]
label_encoders = checkpoint["label_encoders"]
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Rebuild the architecture with the saved head sizes, restore the weights,
# and switch to eval mode so dropout is disabled during inference.
model = MultiOutputBERT(num_classes_per_label)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)
model.eval()

app = FastAPI()
43
+
44
@app.get("/")
def home():
    """Return a fixed status message showing that the service is running."""
    status = {"message": "✅ Multi-output BERT API is live."}
    return status
47
+
48
@app.post("/predict")
def predict(request: InputText):
    """Classify one text and return the decoded label for every output.

    Args:
        request: Request body carrying the text to classify.

    Returns:
        A dict mapping each name in ``LABEL_COLUMNS`` to the label decoded
        from the highest-scoring class of the corresponding model output.
    """
    # Inputs longer than 128 tokens are truncated.
    encoded = tokenizer(
        request.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**encoded)

    predictions = {}
    for i, head_logits in enumerate(logits):
        column = LABEL_COLUMNS[i]
        pred_idx = torch.argmax(head_logits, dim=1).item()
        label = label_encoders[column].inverse_transform([pred_idx])[0]
        # Decoded labels may be NumPy scalars; non-string ones (e.g.
        # numpy.int64) cannot be serialised to JSON by FastAPI, so convert
        # them to native Python values before returning.
        predictions[column] = label.item() if hasattr(label, "item") else label
    return predictions
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Web framework and ASGI server (pinned).
fastapi==0.110.0
uvicorn==0.29.0

# Model and data stack.
# NOTE(review): these are unpinned, so each image build may pick up different
# versions; consider pinning them to the versions used to create the checkpoint.
torch
transformers
scikit-learn
pandas