Hemang1915 commited on
Commit
3f09e7b
·
verified ·
1 Parent(s): dd7fc5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -5,6 +5,7 @@ from fastapi import FastAPI, Request
5
  from pydantic import BaseModel
6
  import pickle
7
  import logging
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
@@ -20,9 +21,10 @@ except FileNotFoundError as e:
20
  logger.error(f"Failed to load label encoders: {e}")
21
  raise
22
 
23
- # Load tokenizer
24
  try:
25
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 
26
  except Exception as e:
27
  logger.error(f"Failed to load tokenizer: {e}")
28
  raise
@@ -31,7 +33,7 @@ except Exception as e:
31
  class BERTFNN(nn.Module):
32
  def __init__(self, num_main_classes, num_sub_classes):
33
  super(BERTFNN, self).__init__()
34
- self.bert = BertModel.from_pretrained("bert-base-uncased")
35
  self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
36
  self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
37
 
@@ -52,6 +54,7 @@ try:
52
  model = BERTFNN(num_main_classes, num_sub_classes).to(device)
53
  model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device))
54
  model.eval()
 
55
  except FileNotFoundError as e:
56
  logger.error(f"Failed to load model weights: {e}")
57
  raise
 
5
  from pydantic import BaseModel
6
  import pickle
7
  import logging
8
+ import os
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
 
21
  logger.error(f"Failed to load label encoders: {e}")
22
  raise
23
 
24
+ # Load tokenizer from local directory
25
  try:
26
+ tokenizer = BertTokenizer.from_pretrained("./tokenizer")
27
+ logger.info("Tokenizer loaded successfully")
28
  except Exception as e:
29
  logger.error(f"Failed to load tokenizer: {e}")
30
  raise
 
33
  class BERTFNN(nn.Module):
34
  def __init__(self, num_main_classes, num_sub_classes):
35
  super(BERTFNN, self).__init__()
36
+ self.bert = BertModel.from_pretrained("bert-base-uncased", cache_dir="./cache")
37
  self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
38
  self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
39
 
 
54
  model = BERTFNN(num_main_classes, num_sub_classes).to(device)
55
  model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device))
56
  model.eval()
57
+ logger.info("Model loaded successfully")
58
  except FileNotFoundError as e:
59
  logger.error(f"Failed to load model weights: {e}")
60
  raise