Hemang1915 committed on
Commit
ca4fd98
·
verified ·
1 Parent(s): be6f038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -6
app.py CHANGED
@@ -5,13 +5,53 @@ from fastapi import FastAPI, Request
5
  from pydantic import BaseModel
6
  import pickle
7
  import logging
 
 
 
 
 
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Load label encoders
14
  try:
 
15
  with open("main_category_encoder_5k.pkl", "rb") as f:
16
  main_category_encoder = pickle.load(f)
17
  with open("sub_category_encoder_5k.pkl", "rb") as f:
@@ -32,7 +72,7 @@ except Exception as e:
32
  class BERTFNN(nn.Module):
33
  def __init__(self, num_main_classes, num_sub_classes):
34
  super(BERTFNN, self).__init__()
35
- self.bert = BertModel.from_pretrained("./bert-model") # Load locally
36
  self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
37
  self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
38
 
@@ -73,19 +113,27 @@ class TransactionInput(BaseModel):
73
  async def root():
74
  return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."}
75
 
76
- # Define predict endpoint with confidence scores
77
  @app.post("/predict")
78
  async def predict_category(transaction: TransactionInput, request: Request):
 
79
  try:
80
  logger.info(f"Received request: {transaction.dict()}")
81
- # Tokenize input
82
- tokens = tokenizer(transaction.description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
 
 
 
 
 
83
  input_ids = tokens["input_ids"].to(device)
84
  attention_mask = tokens["attention_mask"].to(device)
 
85
 
86
  # Get model predictions
87
  with torch.no_grad():
88
  main_logits, sub_logits = model(input_ids, attention_mask)
 
89
 
90
  # Compute softmax probabilities for main category
91
  main_probs = torch.softmax(main_logits, dim=1)
@@ -100,10 +148,12 @@ async def predict_category(transaction: TransactionInput, request: Request):
100
  # Decode category labels
101
  main_category = main_category_encoder.inverse_transform([main_category_idx])[0]
102
  sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0]
 
103
 
104
  # Prepare response
105
  response = {
106
- "description": transaction.description,
 
107
  "main_category": main_category,
108
  "main_confidence": round(main_confidence, 4),
109
  "sub_category": sub_category,
@@ -112,5 +162,5 @@ async def predict_category(transaction: TransactionInput, request: Request):
112
  logger.info(f"Response: {response}")
113
  return response
114
  except Exception as e:
115
- logger.error(f"Error in prediction: {e}")
116
  return {"error": str(e)}, 500
 
5
  from pydantic import BaseModel
6
  import pickle
7
  import logging
8
+ import os
9
+ import re
10
+
11
+ # Set Hugging Face cache directory
12
+ os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/cache" # Replace with a writable path
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
# Text cleaning for transaction descriptions.
# Patterns are compiled once at import time; the original looked up and
# re-applied six raw regex strings on every /predict request.
# Order matters: the two-segment 'UPI/DR/...' prefixes must be stripped
# before the generic one-segment '<word>/<digits>/' pattern runs.
_CLEAN_PATTERNS = [
    re.compile(r'^upi/dr/[0-9]+/', re.IGNORECASE),    # 'UPI/DR/12345678/'
    re.compile(r'^upi-dr-[0-9]+-', re.IGNORECASE),    # 'UPI-DR-12345678-'
    re.compile(r'^visa/[0-9]+/', re.IGNORECASE),      # 'VISA/123456/'
    re.compile(r'^[a-zA-Z]+/[0-9]+/', re.IGNORECASE), # e.g. 'POS/123456/'
    re.compile(r'^[a-zA-Z]+-[0-9]+-', re.IGNORECASE), # e.g. 'POS-123456-'
    re.compile(r'\b[0-9]{6,}\b', re.IGNORECASE),      # standalone 6+ digit codes
]

# Runs of '-', '_' or '/' collapse to a single space after prefix removal.
_SEPARATOR_RUN = re.compile(r'[-_/]+')


def clean_description(text):
    """
    Clean a transaction description by removing prefixes, numeric codes,
    and separators.

    Examples:
        'UPI/DR/12345678/netflix subscription' -> 'netflix subscription'
        'UPI-DR-12345678-netflix subscription' -> 'netflix subscription'
        'VISA/123456/uber ride to office'      -> 'uber ride to office'

    Args:
        text: Raw transaction description.

    Returns:
        Lower-cased, cleaned description, or the placeholder
        "unknown transaction" if cleaning leaves an empty string.
    """
    # Lower-case first (assumes the model was trained on lower-cased text
    # -- TODO confirm against training pipeline).
    text = text.lower()

    # Strip bank/gateway prefixes and standalone numeric reference codes.
    for pattern in _CLEAN_PATTERNS:
        text = pattern.sub('', text)

    # Normalize remaining separators to spaces, then squeeze whitespace.
    text = _SEPARATOR_RUN.sub(' ', text)
    text = ' '.join(text.split())

    # Never feed the tokenizer an empty string.
    return text if text else "unknown transaction"
51
+
52
  # Load label encoders
53
  try:
54
+ # Note: Ensure these were pickled with scikit-learn 1.6.1 to avoid version mismatch
55
  with open("main_category_encoder_5k.pkl", "rb") as f:
56
  main_category_encoder = pickle.load(f)
57
  with open("sub_category_encoder_5k.pkl", "rb") as f:
 
72
  class BERTFNN(nn.Module):
73
  def __init__(self, num_main_classes, num_sub_classes):
74
  super(BERTFNN, self).__init__()
75
+ self.bert = BertModel.from_pretrained("./bert-model")
76
  self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
77
  self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
78
 
 
113
  async def root():
114
  return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."}
115
 
116
+ # Define predict endpoint with text cleaning and confidence scores
117
  @app.post("/predict")
118
  async def predict_category(transaction: TransactionInput, request: Request):
119
+ logger.info("Starting prediction for request")
120
  try:
121
  logger.info(f"Received request: {transaction.dict()}")
122
+
123
+ # Clean the input description
124
+ cleaned_description = clean_description(transaction.description)
125
+ logger.info(f"Cleaned description: {cleaned_description}")
126
+
127
+ # Tokenize cleaned description
128
+ tokens = tokenizer(cleaned_description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
129
  input_ids = tokens["input_ids"].to(device)
130
  attention_mask = tokens["attention_mask"].to(device)
131
+ logger.info("Tokenization completed")
132
 
133
  # Get model predictions
134
  with torch.no_grad():
135
  main_logits, sub_logits = model(input_ids, attention_mask)
136
+ logger.info("Model inference completed")
137
 
138
  # Compute softmax probabilities for main category
139
  main_probs = torch.softmax(main_logits, dim=1)
 
148
  # Decode category labels
149
  main_category = main_category_encoder.inverse_transform([main_category_idx])[0]
150
  sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0]
151
+ logger.info("Category decoding completed")
152
 
153
  # Prepare response
154
  response = {
155
+ "original_description": transaction.description,
156
+ "cleaned_description": cleaned_description,
157
  "main_category": main_category,
158
  "main_confidence": round(main_confidence, 4),
159
  "sub_category": sub_category,
 
162
  logger.info(f"Response: {response}")
163
  return response
164
  except Exception as e:
165
+ logger.error(f"Error in prediction: {str(e)}", exc_info=True)
166
  return {"error": str(e)}, 500