Hemang1915 commited on
Commit
d2bb04d
·
verified ·
1 Parent(s): cb557cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -73,22 +73,41 @@ class TransactionInput(BaseModel):
73
  async def root():
74
  return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."}
75
 
76
- # Define predict endpoint
77
  @app.post("/predict")
78
  async def predict_category(transaction: TransactionInput, request: Request):
79
  try:
80
  logger.info(f"Received request: {transaction.dict()}")
 
81
  tokens = tokenizer(transaction.description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
82
  input_ids = tokens["input_ids"].to(device)
83
  attention_mask = tokens["attention_mask"].to(device)
 
 
84
  with torch.no_grad():
85
  main_logits, sub_logits = model(input_ids, attention_mask)
86
- main_category = torch.argmax(main_logits, dim=1).cpu().item()
87
- sub_category = torch.argmax(sub_logits, dim=1).cpu().item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  response = {
89
  "description": transaction.description,
90
- "main_category": main_category_encoder.inverse_transform([main_category])[0],
91
- "sub_category": sub_category_encoder.inverse_transform([sub_category])[0]
 
 
92
  }
93
  logger.info(f"Response: {response}")
94
  return response
 
73
  async def root():
74
  return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."}
75
 
76
+ # Define predict endpoint with confidence scores
77
  @app.post("/predict")
78
  async def predict_category(transaction: TransactionInput, request: Request):
79
  try:
80
  logger.info(f"Received request: {transaction.dict()}")
81
+ # Tokenize input
82
  tokens = tokenizer(transaction.description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
83
  input_ids = tokens["input_ids"].to(device)
84
  attention_mask = tokens["attention_mask"].to(device)
85
+
86
+ # Get model predictions
87
  with torch.no_grad():
88
  main_logits, sub_logits = model(input_ids, attention_mask)
89
+
90
+ # Compute softmax probabilities for main category
91
+ main_probs = torch.softmax(main_logits, dim=1)
92
+ main_category_idx = torch.argmax(main_probs, dim=1).cpu().item()
93
+ main_confidence = main_probs[0, main_category_idx].cpu().item()
94
+
95
+ # Compute softmax probabilities for subcategory
96
+ sub_probs = torch.softmax(sub_logits, dim=1)
97
+ sub_category_idx = torch.argmax(sub_probs, dim=1).cpu().item()
98
+ sub_confidence = sub_probs[0, sub_category_idx].cpu().item()
99
+
100
+ # Decode category labels
101
+ main_category = main_category_encoder.inverse_transform([main_category_idx])[0]
102
+ sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0]
103
+
104
+ # Prepare response
105
  response = {
106
  "description": transaction.description,
107
+ "main_category": main_category,
108
+ "main_confidence": round(main_confidence, 4),
109
+ "sub_category": sub_category,
110
+ "sub_confidence": round(sub_confidence, 4)
111
  }
112
  logger.info(f"Response: {response}")
113
  return response