# NOTE: removed non-Python extraction artifacts that preceded the module
# (a "File size" banner, a git-blame hash gutter, and a flattened run of
# line numbers). They were not part of the source and broke the syntax.
import logging
import os
import pickle
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import BertModel, BertTokenizer
# Set Hugging Face cache directory
os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/cache" # Replace with a writable path
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Text cleaning function
def clean_description(text):
"""
Clean transaction description by removing prefixes, numeric codes, and separators.
Examples:
'UPI/DR/12345678/netflix subscription' -> 'netflix subscription'
'UPI-DR-12345678-netflix subscription' -> 'netflix subscription'
'VISA/123456/uber ride to office' -> 'uber ride to office'
"""
# Convert to lowercase (optional, depending on model training)
text = text.lower()
# Remove common transaction prefixes and codes
patterns = [
r'^upi/dr/[0-9]+/', # Matches 'UPI/DR/12345678/'
r'^upi-dr-[0-9]+-', # Matches 'UPI-DR-12345678-'
r'^visa/[0-9]+/', # Matches 'VISA/123456/'
r'^[a-zA-Z]+/[0-9]+/', # Matches other prefixes like 'POS/123456/'
r'^[a-zA-Z]+-[0-9]+-', # Matches other prefixes like 'POS-123456-'
r'\b[0-9]{6,}\b', # Matches standalone numeric codes (6+ digits)
]
for pattern in patterns:
text = re.sub(pattern, '', text, flags=re.IGNORECASE)
# Replace multiple separators with a single space
text = re.sub(r'[-_/]+', ' ', text)
# Remove extra whitespace
text = ' '.join(text.split())
# Return cleaned text, or original if cleaning results in empty string
return text if text else "unknown transaction"
# Load label encoders
try:
# Note: Ensure these were pickled with scikit-learn 1.6.1 to avoid version mismatch
with open("main_category_encoder_5k.pkl", "rb") as f:
main_category_encoder = pickle.load(f)
with open("sub_category_encoder_5k.pkl", "rb") as f:
sub_category_encoder = pickle.load(f)
except FileNotFoundError as e:
logger.error(f"Failed to load label encoders: {e}")
raise
# Load tokenizer from local directory
try:
tokenizer = BertTokenizer.from_pretrained("./tokenizer")
logger.info("Tokenizer loaded successfully")
except Exception as e:
logger.error(f"Failed to load tokenizer: {e}")
raise
# Define the model
class BERTFNN(nn.Module):
def __init__(self, num_main_classes, num_sub_classes):
super(BERTFNN, self).__init__()
self.bert = BertModel.from_pretrained("./bert-model")
self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes)
self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes)
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_embedding = outputs.last_hidden_state[:, 0, :]
main_logits = self.fc_main(cls_embedding)
main_pred = torch.softmax(main_logits, dim=1)
combined_input = torch.cat((cls_embedding, main_pred), dim=1)
sub_logits = self.fc_sub(combined_input)
return main_logits, sub_logits
# Load trained model
num_main_classes = len(main_category_encoder.classes_)
num_sub_classes = len(sub_category_encoder.classes_)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
model = BERTFNN(num_main_classes, num_sub_classes).to(device)
model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device))
model.eval()
logger.info("Model loaded successfully")
except FileNotFoundError as e:
logger.error(f"Failed to load model weights: {e}")
raise
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
# Initialize FastAPI
app = FastAPI()
# Define request body
class TransactionInput(BaseModel):
description: str
# Define root endpoint for debugging
@app.get("/")
async def root():
return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."}
# Define predict endpoint with text cleaning and confidence scores
@app.post("/predict")
async def predict_category(transaction: TransactionInput, request: Request):
logger.info("Starting prediction for request")
try:
logger.info(f"Received request: {transaction.dict()}")
# Clean the input description
cleaned_description = clean_description(transaction.description)
logger.info(f"Cleaned description: {cleaned_description}")
# Tokenize cleaned description
tokens = tokenizer(cleaned_description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens["attention_mask"].to(device)
logger.info("Tokenization completed")
# Get model predictions
with torch.no_grad():
main_logits, sub_logits = model(input_ids, attention_mask)
logger.info("Model inference completed")
# Compute softmax probabilities for main category
main_probs = torch.softmax(main_logits, dim=1)
main_category_idx = torch.argmax(main_probs, dim=1).cpu().item()
main_confidence = main_probs[0, main_category_idx].cpu().item()
# Compute softmax probabilities for subcategory
sub_probs = torch.softmax(sub_logits, dim=1)
sub_category_idx = torch.argmax(sub_probs, dim=1).cpu().item()
sub_confidence = sub_probs[0, sub_category_idx].cpu().item()
# Decode category labels
main_category = main_category_encoder.inverse_transform([main_category_idx])[0]
sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0]
logger.info("Category decoding completed")
# Prepare response
response = {
"category": main_category,
"subcategory": sub_category,
"category_confidence": round(main_confidence, 4),
"subcategory_confidence": round(sub_confidence, 4)
}
logger.info(f"Response: {response}")
return response
except Exception as e:
logger.error(f"Error in prediction: {str(e)}", exc_info=True)
return {"error": str(e)}, 500 |