Hemang1915 committed on
Commit
ac2c610
·
verified ·
1 Parent(s): b635cad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -73
app.py CHANGED
@@ -1,73 +1,73 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import BertTokenizer, BertModel
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
- import pickle
7
-
8
- # Load label encoders
9
- with open("main_category_encoder_5k.pkl", "rb") as f:
10
- main_category_encoder = pickle.load(f)
11
-
12
- with open("sub_category_encoder_5k.pkl", "rb") as f:
13
- sub_category_encoder = pickle.load(f)
14
-
15
- # Load tokenizer
16
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
17
-
18
- # Define the model
19
- class BERTFNN(nn.Module):
20
- def __init__(self, num_main_classes, num_sub_classes):
21
- super(BERTFNN, self).__init__()
22
- self.bert = BertModel.from_pretrained("bert-base-uncased")
23
- self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
24
- self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
25
-
26
- def forward(self, input_ids, attention_mask):
27
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
28
- cls_embedding = outputs.last_hidden_state[:, 0, :]
29
- main_logits = self.fc_main(cls_embedding)
30
- main_pred = torch.softmax(main_logits, dim=1)
31
- combined_input = torch.cat((cls_embedding, main_pred), dim=1)
32
- sub_logits = self.fc_sub(combined_input)
33
- return main_logits, sub_logits
34
-
35
- # Load trained model
36
- num_main_classes = len(main_category_encoder.classes_)
37
- num_sub_classes = len(sub_category_encoder.classes_)
38
-
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
- model = BERTFNN(num_main_classes, num_sub_classes).to(device)
41
- model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device))
42
- model.eval()
43
-
44
- # Initialize FastAPI
45
- app = FastAPI()
46
-
47
- # Define request body
48
- class TransactionInput(BaseModel):
49
- description: str
50
-
51
- # Define predict function
52
- @app.post("/predict")
53
- def predict_category(transaction: TransactionInput):
54
- tokens = tokenizer(transaction.description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
55
- input_ids = tokens["input_ids"].to(device)
56
- attention_mask = tokens["attention_mask"].to(device)
57
-
58
- with torch.no_grad():
59
- main_logits, sub_logits = model(input_ids, attention_mask)
60
-
61
- main_category = torch.argmax(main_logits, dim=1).cpu().item()
62
- sub_category = torch.argmax(sub_logits, dim=1).cpu().item()
63
-
64
- return {
65
- "description": transaction.description,
66
- "main_category": main_category_encoder.inverse_transform([main_category])[0],
67
- "sub_category": sub_category_encoder.inverse_transform([sub_category])[0]
68
- }
69
-
70
- # Run the API (for local testing)
71
- if __name__ == "__main__":
72
- import uvicorn
73
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer, BertModel
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ import pickle
7
+
8
# ---- Artifact loading -------------------------------------------------
def _load_pickle(path):
    """Deserialize one pickled artifact from *path* (trusted local file)."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

# Label encoders mapping category names <-> integer class ids.
# NOTE(review): presumably sklearn LabelEncoders (they expose .classes_
# and .inverse_transform below) — confirm against the training script.
main_category_encoder = _load_pickle("main_category_encoder_5k.pkl")
sub_category_encoder = _load_pickle("sub_category_encoder_5k.pkl")

# Tokenizer matching the fine-tuned BERT backbone used by the model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
17
+
18
+ # Define the model
19
# Define the model
class BERTFNN(nn.Module):
    """BERT encoder with two stacked classification heads.

    The main-category head reads the [CLS] embedding directly; the
    sub-category head additionally receives the softmaxed main-category
    distribution, so the sub prediction is conditioned on the main one.
    """

    def __init__(self, num_main_classes, num_sub_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        hidden = self.bert.config.hidden_size
        self.fc_main = nn.Linear(hidden, num_main_classes)
        # Sub head input = [CLS] embedding concatenated with main probs.
        self.fc_sub = nn.Linear(hidden + num_main_classes, num_sub_classes)

    def forward(self, input_ids, attention_mask):
        """Return ``(main_logits, sub_logits)`` for a tokenized batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Embedding of the first token ([CLS]) summarizes the sequence.
        cls_vec = encoded.last_hidden_state[:, 0, :]
        main_logits = self.fc_main(cls_vec)
        main_probs = torch.softmax(main_logits, dim=1)
        sub_logits = self.fc_sub(torch.cat((cls_vec, main_probs), dim=1))
        return main_logits, sub_logits
34
+
35
# ---- Load trained model ----------------------------------------------
# Head sizes are derived from the fitted encoders so the label space and
# the checkpoint stay in sync.
num_main_classes = len(main_category_encoder.classes_)
num_sub_classes = len(sub_category_encoder.classes_)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTFNN(num_main_classes, num_sub_classes).to(device)
# weights_only=True: the checkpoint is consumed only as a tensor
# state_dict, so restrict torch.load to tensors instead of the full
# pickle path (arbitrary-code-execution risk on a tampered file).
# NOTE(review): requires a torch version supporting weights_only.
state_dict = torch.load(
    "expense_categorization_5k.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout

# Initialize FastAPI
app = FastAPI()
46
+
47
# Define request body
class TransactionInput(BaseModel):
    """Request body for /predict: one free-text transaction description."""

    # Raw transaction text to be categorized.
    description: str
50
+
51
# ---- Inference endpoint -----------------------------------------------
@app.post("/predict")
def predict_category(transaction: TransactionInput):
    """Classify a transaction description into main and sub categories."""
    encoding = tokenizer(
        transaction.description,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=64,
    )
    ids = encoding["input_ids"].to(device)
    mask = encoding["attention_mask"].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        main_logits, sub_logits = model(ids, mask)

    # Batch size is 1, so argmax yields a single class index per head.
    main_idx = torch.argmax(main_logits, dim=1).cpu().item()
    sub_idx = torch.argmax(sub_logits, dim=1).cpu().item()

    return {
        "description": transaction.description,
        "main_category": main_category_encoder.inverse_transform([main_idx])[0],
        "sub_category": sub_category_encoder.inverse_transform([sub_idx])[0],
    }
69
+
70
# Local-development entry point; in production, launch via an ASGI server.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)