ganeshkonapalli committed on
Commit
3212eba
·
verified ·
1 Parent(s): 7faa9a5

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +12 -0
  2. deberta_model.pkl +3 -0
  3. main.py +63 -0
  4. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the FastAPI DeBERTa multi-output classifier service.
FROM python:3.10-slim

# Relative paths used at runtime (e.g. "app/deberta_model.pkl" opened by
# app/main.py) resolve against this directory.
WORKDIR /code

# Ship the application package (code, requirements, pickled checkpoint).
COPY ./app /code/app

# --no-cache-dir keeps pip's download cache out of the image layers; the
# torch / transformers wheels are large, so the cache only bloats the image.
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir -r app/requirements.txt

# Port uvicorn listens on (must match --port in CMD).
EXPOSE 7860

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
deberta_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdafeca39a5bdbb2b3b1dd7d628de614be1ad34d1064a6dfdec8884e999414bc
3
+ size 556709919
main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import torch.nn as nn
5
+ import pickle
6
+ from transformers import DebertaModel, DebertaTokenizer
7
+ import uvicorn
8
+
9
# Names of the predicted fields. predict() pairs them positionally with the
# model's output heads and with the label encoders, so the order matters.
LABEL_COLUMNS = [
    'Red_Flag_Reason',
    'Maker_Action',
    'Escalation_Level',
    'Risk_Category',
    'Risk_Drivers',
    'Investigation_Outcome',
]

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
12
+
13
class InputText(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Raw text to classify.
    text: str
15
+
16
# Load the serialized training artefacts once, at import time.
#
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable because the checkpoint ships with the application;
# never point this path at a file from an untrusted source.
# NOTE(review): if the pickled tensors were saved from a CUDA device, loading
# on a CPU-only host may fail -- confirm how the checkpoint was written.
with open("app/deberta_model.pkl", "rb") as f:
    checkpoint = pickle.load(f)

# The checkpoint is indexed like a dict; pull out the pieces needed to
# preprocess requests and to decode predictions.
tokenizer, label_encoders = checkpoint['tokenizer'], checkpoint['label_encoders']
21
+
22
class DebertaMultiOutput(nn.Module):
    """DeBERTa encoder followed by one linear classification head per output.

    Args:
        num_labels_per_output: iterable of ints; one linear head is created
            for each entry, with that many output units.
    """

    def __init__(self, num_labels_per_output):
        super().__init__()
        self.deberta = DebertaModel.from_pretrained("microsoft/deberta-base")
        self.dropout = nn.Dropout(0.3)

        # Every head reads the same encoder representation, so they all share
        # the encoder's hidden size as their input width.
        hidden_size = self.deberta.config.hidden_size
        heads = []
        for n_labels in num_labels_per_output:
            heads.append(nn.Linear(hidden_size, n_labels))
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask):
        """Return a list with one logits tensor per classification head."""
        encoded = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # Use the last-layer hidden state at sequence position 0 as the
        # pooled representation of the whole input.
        first_token_state = encoded.last_hidden_state[:, 0]
        pooled = self.dropout(first_token_state)

        logits = []
        for head in self.classifiers:
            logits.append(head(pooled))
        return logits
35
+
36
# Size each classification head from the number of classes its label encoder
# knows about.
num_labels = [len(encoder.classes_) for encoder in label_encoders.values()]

# Rebuild the network, restore the trained weights from the checkpoint, then
# move it to the target device and switch it to inference mode.
model = DebertaMultiOutput(num_labels)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE).eval()

app = FastAPI()
43
+
44
@app.get("/")
def root():
    """Return a static status message for GET /."""
    status = {"message": "🟢 DeBERTa multi-output classifier ready."}
    return status
47
+
48
@app.post("/predict")
def predict(input: InputText):
    """Classify one text and return the decoded label for every output.

    Args:
        input: request body; only ``input.text`` is read. (The name shadows
            the ``input`` builtin but is kept so the endpoint signature is
            unchanged.)

    Returns:
        dict mapping each name in LABEL_COLUMNS to the label decoded from the
        highest-scoring class of the corresponding model output.
    """
    # Tokenize to PyTorch tensors, truncating to at most 128 tokens.
    inputs = tokenizer(
        input.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    input_ids = inputs['input_ids'].to(DEVICE)
    attention_mask = inputs['attention_mask'].to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    preds = {}
    # Outputs, column names and encoders are paired by position.
    for output, col, le in zip(outputs, LABEL_COLUMNS, label_encoders.values()):
        pred_idx = torch.argmax(output, dim=1).item()
        pred_label = le.inverse_transform([pred_idx])[0]
        # inverse_transform may hand back a NumPy scalar; convert it to the
        # plain Python equivalent so the response is always JSON-serializable.
        if hasattr(pred_label, "item"):
            pred_label = pred_label.item()
        preds[col] = pred_label

    return preds
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ scikit-learn