ganeshkonapalli committed on
Commit
4d84411
·
verified ·
1 Parent(s): e94e533

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -10
app.py CHANGED
@@ -1,14 +1,53 @@
1
- from fastapi import FastAPI, Request
2
- from pydantic import BaseModel
3
- from app.model_utils import load_model, predict_label
 
4
 
5
- app = FastAPI()
6
- tokenizer, model, label_encoders = load_model()
 
 
 
 
 
 
 
7
 
8
- class InputText(BaseModel):
9
- text: str
 
 
 
 
 
 
 
10
 
11
- @app.post("/predict")
12
- def predict(input: InputText):
13
- predictions = predict_label(input.text, tokenizer, model, label_encoders)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  return predictions
 
1
+ import torch
2
+ import pickle
3
+ import torch.nn as nn
4
+ from transformers import BertTokenizer, BertModel
5
 
6
# HuggingFace checkpoint used for both the encoder weights and the tokenizer.
PRETRAINED_MODEL_NAME = 'bert-base-uncased'

# Names of the predicted outputs, one classification head per entry.
# Order matters: predict_label assumes classifier head i produces
# logits for LABEL_COLUMNS[i].
LABEL_COLUMNS = [
    'Red_Flag_Reason',
    'Maker_Action',
    'Escalation_Level',
    'Risk_Category',
    'Risk_Drivers',
    'Investigation_Outcome'
]
15
 
16
class BertMultiOutput(nn.Module):
    """BERT encoder with one independent linear classification head per output.

    Args:
        num_labels_per_output: iterable of class counts, one per head; head i
            emits logits of width num_labels_per_output[i].
    """

    def __init__(self, num_labels_per_output):
        super().__init__()
        self.bert = BertModel.from_pretrained(PRETRAINED_MODEL_NAME)
        self.dropout = nn.Dropout(0.3)
        hidden = self.bert.config.hidden_size
        heads = [nn.Linear(hidden, class_count) for class_count in num_labels_per_output]
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return a list of per-head logit tensors."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled [CLS] representation, regularized before each head.
        features = self.dropout(encoded.pooler_output)
        return [head(features) for head in self.classifiers]
31
+
32
def load_model():
    """Load the tokenizer, multi-output model, and label encoders.

    Reads "bert_model.pkl", a pickled dict with keys 'tokenizer',
    'label_encoders' (mapping label name -> fitted encoder with .classes_),
    and 'model_state_dict'.

    Returns:
        (tokenizer, model, label_encoders) with the model in eval mode.
    """
    # SECURITY: pickle.load can execute arbitrary code — only ever load a
    # bert_model.pkl produced by a trusted training pipeline.
    with open("bert_model.pkl", "rb") as f:
        bundle = pickle.load(f)
    tokenizer = bundle['tokenizer']
    label_encoders = bundle['label_encoders']
    # Build head sizes in LABEL_COLUMNS order so that classifier head i
    # corresponds to LABEL_COLUMNS[i], as predict_label assumes. Relying on
    # label_encoders.values() would silently mismatch heads if the dict's
    # insertion order ever differed from LABEL_COLUMNS.
    num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
    model = BertMultiOutput(num_labels)
    model.load_state_dict(bundle['model_state_dict'])
    model.eval()  # disable dropout for deterministic inference
    return tokenizer, model, label_encoders
42
+
43
def predict_label(text, tokenizer, model, label_encoders, label_columns=None):
    """Predict one label per output head for a single piece of text.

    Args:
        text: input string to classify.
        tokenizer: callable returning a dict of "pt" tensors (HF tokenizer API).
        model: callable taking (input_ids=..., attention_mask=...) and
            returning a list of logit tensors, one per output head.
        label_encoders: mapping of label name -> encoder with inverse_transform.
        label_columns: optional ordered list of label names matching the
            model's heads; defaults to the module-level LABEL_COLUMNS.

    Returns:
        dict mapping each label name to its decoded predicted class.
    """
    if label_columns is None:
        label_columns = LABEL_COLUMNS
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        # Pass only the tensors the model's forward() accepts. BertTokenizer
        # also emits token_type_ids, so model(**inputs) would raise
        # TypeError: forward() got an unexpected keyword argument.
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    predictions = {}
    for label_name, logits in zip(label_columns, outputs):
        pred_class = torch.argmax(logits, dim=1).item()
        predictions[label_name] = label_encoders[label_name].inverse_transform([pred_class])[0]
    return predictions