# Source: Hugging Face Space by Hemang1915 — "Update app.py" (commit 4e9b7ce, verified).
import logging
import os
import pickle
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
# Set Hugging Face cache directory.
# NOTE(review): "/path/to/writable/cache" is a placeholder — replace with a real
# writable path before deployment. Also, TRANSFORMERS_CACHE is typically read
# when `transformers` is imported (which happens above this line), so setting it
# here may have no effect — confirm, or set it in the environment before startup.
os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/cache" # Replace with a writable path
# Set up logging: INFO level, module-scoped logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Text cleaning: patterns are compiled once at import time so each request
# avoids re-building them (the cleaning runs on every /predict call).
_CLEANING_PATTERNS = [
    re.compile(r'^upi/dr/[0-9]+/', re.IGNORECASE),    # 'UPI/DR/12345678/'
    re.compile(r'^upi-dr-[0-9]+-', re.IGNORECASE),    # 'UPI-DR-12345678-'
    re.compile(r'^visa/[0-9]+/', re.IGNORECASE),      # 'VISA/123456/'
    re.compile(r'^[a-zA-Z]+/[0-9]+/', re.IGNORECASE),  # generic prefixes like 'POS/123456/'
    re.compile(r'^[a-zA-Z]+-[0-9]+-', re.IGNORECASE),  # generic prefixes like 'POS-123456-'
    re.compile(r'\b[0-9]{6,}\b', re.IGNORECASE),      # standalone numeric codes (6+ digits)
]
# Collapse runs of '-', '_', '/' into a single space.
_SEPARATOR_RUNS = re.compile(r'[-_/]+')


def clean_description(text: str) -> str:
    """
    Clean a transaction description by removing prefixes, numeric codes, and separators.

    Examples:
        'UPI/DR/12345678/netflix subscription' -> 'netflix subscription'
        'UPI-DR-12345678-netflix subscription' -> 'netflix subscription'
        'VISA/123456/uber ride to office'      -> 'uber ride to office'

    Returns:
        The lowercased, cleaned text, or the sentinel "unknown transaction"
        if cleaning strips everything away (so the model always gets input).
    """
    # Lowercase first (the model is assumed to have been trained on lowercase text).
    text = text.lower()
    # Strip known prefixes and standalone numeric codes, in declaration order.
    for pattern in _CLEANING_PATTERNS:
        text = pattern.sub('', text)
    # Replace separator runs with a single space, then normalize whitespace.
    text = _SEPARATOR_RUNS.sub(' ', text)
    text = ' '.join(text.split())
    # Never return an empty string — downstream tokenization expects content.
    return text if text else "unknown transaction"
# Load the fitted label encoders (main + sub category) from disk.
# Note: these should have been pickled with scikit-learn 1.6.1 to avoid
# version-mismatch errors on unpickling.
def _read_pickle(path):
    # Unpickle a single object from a binary file.
    with open(path, "rb") as fh:
        return pickle.load(fh)

try:
    main_category_encoder = _read_pickle("main_category_encoder_5k.pkl")
    sub_category_encoder = _read_pickle("sub_category_encoder_5k.pkl")
except FileNotFoundError as e:
    logger.error(f"Failed to load label encoders: {e}")
    raise
# Load the BERT tokenizer from the bundled local directory.
_TOKENIZER_DIR = "./tokenizer"
try:
    tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_DIR)
    logger.info("Tokenizer loaded successfully")
except Exception as err:
    logger.error(f"Failed to load tokenizer: {err}")
    raise
# Model definition.
class BERTFNN(nn.Module):
    """BERT encoder with two stacked feed-forward classification heads.

    The main-category head reads the [CLS] embedding directly. The
    sub-category head reads the [CLS] embedding concatenated with the
    softmaxed main-category prediction, so the sub-category prediction
    is conditioned on the main-category distribution.
    """

    def __init__(self, num_main_classes, num_sub_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained("./bert-model")
        hidden = self.bert.config.hidden_size
        self.fc_main = nn.Linear(hidden, num_main_classes)
        self.fc_sub = nn.Linear(hidden + num_main_classes, num_sub_classes)

    def forward(self, input_ids, attention_mask):
        """Return (main_logits, sub_logits) for a batch of token ids."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] token embedding summarizes the whole sequence.
        pooled = encoded.last_hidden_state[:, 0, :]
        main_logits = self.fc_main(pooled)
        # Feed the main-category probability distribution into the sub head.
        main_distribution = torch.softmax(main_logits, dim=1)
        sub_logits = self.fc_sub(torch.cat((pooled, main_distribution), dim=1))
        return main_logits, sub_logits
# Load trained model weights and put the network into inference mode.
num_main_classes = len(main_category_encoder.classes_)  # size of main-category label space
num_sub_classes = len(sub_category_encoder.classes_)  # size of sub-category label space
# Prefer GPU when available; map_location below keeps weight loading consistent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    model = BERTFNN(num_main_classes, num_sub_classes).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load weight
    # files from a trusted source (consider weights_only=True on newer torch).
    model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device))
    # Disable dropout / batch-norm updates for deterministic inference.
    model.eval()
    logger.info("Model loaded successfully")
except FileNotFoundError as e:
    logger.error(f"Failed to load model weights: {e}")
    raise
except Exception as e:
    logger.error(f"Failed to initialize model: {e}")
    raise
# Initialize FastAPI
app = FastAPI()
# Define request body
class TransactionInput(BaseModel):
    """Request schema for POST /predict."""

    # Raw bank-transaction description, e.g. 'UPI/DR/12345678/netflix subscription'.
    description: str
# Root endpoint: simple liveness/usage message, useful for manual checks.
@app.get("/")
async def root():
    """Return a welcome message pointing clients at POST /predict."""
    welcome = {
        "message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."
    }
    return welcome
# Predict endpoint: cleans the raw description, runs the BERT classifier,
# and returns main/sub categories with softmax confidence scores.
@app.post("/predict")
async def predict_category(transaction: TransactionInput, request: Request):
    """
    Categorize a bank-transaction description.

    Body: {"description": "<raw transaction text>"}
    Returns: {"category", "subcategory", "category_confidence", "subcategory_confidence"}
    Raises: HTTPException(500) if any stage of the pipeline fails.
    """
    logger.info("Starting prediction for request")
    try:
        logger.info(f"Received request: {transaction.dict()}")
        # Normalize the raw description (strip UPI/VISA prefixes, codes, separators).
        cleaned_description = clean_description(transaction.description)
        logger.info(f"Cleaned description: {cleaned_description}")
        # Tokenize; max_length=64 — presumably the training-time sequence length.
        tokens = tokenizer(cleaned_description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)
        logger.info("Tokenization completed")
        # Inference without gradient tracking.
        with torch.no_grad():
            main_logits, sub_logits = model(input_ids, attention_mask)
        logger.info("Model inference completed")
        # Softmax -> per-class probabilities; argmax -> predicted index; the
        # probability at that index is reported as the confidence.
        main_probs = torch.softmax(main_logits, dim=1)
        main_category_idx = torch.argmax(main_probs, dim=1).cpu().item()
        main_confidence = main_probs[0, main_category_idx].cpu().item()
        sub_probs = torch.softmax(sub_logits, dim=1)
        sub_category_idx = torch.argmax(sub_probs, dim=1).cpu().item()
        sub_confidence = sub_probs[0, sub_category_idx].cpu().item()
        # Decode indices back to label strings. str() guards against numpy
        # scalar types from inverse_transform, which the JSON encoder may
        # not serialize cleanly.
        main_category = str(main_category_encoder.inverse_transform([main_category_idx])[0])
        sub_category = str(sub_category_encoder.inverse_transform([sub_category_idx])[0])
        logger.info("Category decoding completed")
        response = {
            "category": main_category,
            "subcategory": sub_category,
            "category_confidence": round(main_confidence, 4),
            "subcategory_confidence": round(sub_confidence, 4),
        }
        logger.info(f"Response: {response}")
        return response
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}", exc_info=True)
        # BUG FIX: the previous `return {"error": str(e)}, 500` does NOT set
        # an HTTP 500 in FastAPI — the tuple is serialized as a JSON array
        # with status 200. Raising HTTPException produces a real 500.
        raise HTTPException(status_code=500, detail=str(e))