|
|
import logging
import os
import pickle
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
|
|
|
|
|
|
|
|
os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/cache" |
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def clean_description(text): |
|
|
""" |
|
|
Clean transaction description by removing prefixes, numeric codes, and separators. |
|
|
Examples: |
|
|
'UPI/DR/12345678/netflix subscription' -> 'netflix subscription' |
|
|
'UPI-DR-12345678-netflix subscription' -> 'netflix subscription' |
|
|
'VISA/123456/uber ride to office' -> 'uber ride to office' |
|
|
""" |
|
|
|
|
|
text = text.lower() |
|
|
|
|
|
|
|
|
patterns = [ |
|
|
r'^upi/dr/[0-9]+/', |
|
|
r'^upi-dr-[0-9]+-', |
|
|
r'^visa/[0-9]+/', |
|
|
r'^[a-zA-Z]+/[0-9]+/', |
|
|
r'^[a-zA-Z]+-[0-9]+-', |
|
|
r'\b[0-9]{6,}\b', |
|
|
] |
|
|
|
|
|
for pattern in patterns: |
|
|
text = re.sub(pattern, '', text, flags=re.IGNORECASE) |
|
|
|
|
|
|
|
|
text = re.sub(r'[-_/]+', ' ', text) |
|
|
|
|
|
|
|
|
text = ' '.join(text.split()) |
|
|
|
|
|
|
|
|
return text if text else "unknown transaction" |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
with open("main_category_encoder_5k.pkl", "rb") as f: |
|
|
main_category_encoder = pickle.load(f) |
|
|
with open("sub_category_encoder_5k.pkl", "rb") as f: |
|
|
sub_category_encoder = pickle.load(f) |
|
|
except FileNotFoundError as e: |
|
|
logger.error(f"Failed to load label encoders: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
try: |
|
|
tokenizer = BertTokenizer.from_pretrained("./tokenizer") |
|
|
logger.info("Tokenizer loaded successfully") |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to load tokenizer: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
class BERTFNN(nn.Module): |
|
|
def __init__(self, num_main_classes, num_sub_classes): |
|
|
super(BERTFNN, self).__init__() |
|
|
self.bert = BertModel.from_pretrained("./bert-model") |
|
|
self.fc_main = nn.Linear(self.bert.config.hidden_size, num_main_classes) |
|
|
self.fc_sub = nn.Linear(self.bert.config.hidden_size + num_main_classes, num_sub_classes) |
|
|
|
|
|
def forward(self, input_ids, attention_mask): |
|
|
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) |
|
|
cls_embedding = outputs.last_hidden_state[:, 0, :] |
|
|
main_logits = self.fc_main(cls_embedding) |
|
|
main_pred = torch.softmax(main_logits, dim=1) |
|
|
combined_input = torch.cat((cls_embedding, main_pred), dim=1) |
|
|
sub_logits = self.fc_sub(combined_input) |
|
|
return main_logits, sub_logits |
|
|
|
|
|
|
|
|
num_main_classes = len(main_category_encoder.classes_) |
|
|
num_sub_classes = len(sub_category_encoder.classes_) |
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
try: |
|
|
model = BERTFNN(num_main_classes, num_sub_classes).to(device) |
|
|
model.load_state_dict(torch.load("expense_categorization_5k.pth", map_location=device)) |
|
|
model.eval() |
|
|
logger.info("Model loaded successfully") |
|
|
except FileNotFoundError as e: |
|
|
logger.error(f"Failed to load model weights: {e}") |
|
|
raise |
|
|
except Exception as e: |
|
|
logger.error(f"Failed to initialize model: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
|
|
|
class TransactionInput(BaseModel): |
|
|
description: str |
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
async def root(): |
|
|
return {"message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."} |
|
|
|
|
|
|
|
|
@app.post("/predict") |
|
|
async def predict_category(transaction: TransactionInput, request: Request): |
|
|
logger.info("Starting prediction for request") |
|
|
try: |
|
|
logger.info(f"Received request: {transaction.dict()}") |
|
|
|
|
|
|
|
|
cleaned_description = clean_description(transaction.description) |
|
|
logger.info(f"Cleaned description: {cleaned_description}") |
|
|
|
|
|
|
|
|
tokens = tokenizer(cleaned_description, return_tensors="pt", truncation=True, padding="max_length", max_length=64) |
|
|
input_ids = tokens["input_ids"].to(device) |
|
|
attention_mask = tokens["attention_mask"].to(device) |
|
|
logger.info("Tokenization completed") |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
main_logits, sub_logits = model(input_ids, attention_mask) |
|
|
logger.info("Model inference completed") |
|
|
|
|
|
|
|
|
main_probs = torch.softmax(main_logits, dim=1) |
|
|
main_category_idx = torch.argmax(main_probs, dim=1).cpu().item() |
|
|
main_confidence = main_probs[0, main_category_idx].cpu().item() |
|
|
|
|
|
|
|
|
sub_probs = torch.softmax(sub_logits, dim=1) |
|
|
sub_category_idx = torch.argmax(sub_probs, dim=1).cpu().item() |
|
|
sub_confidence = sub_probs[0, sub_category_idx].cpu().item() |
|
|
|
|
|
|
|
|
main_category = main_category_encoder.inverse_transform([main_category_idx])[0] |
|
|
sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0] |
|
|
logger.info("Category decoding completed") |
|
|
|
|
|
|
|
|
response = { |
|
|
"category": main_category, |
|
|
"subcategory": sub_category, |
|
|
"category_confidence": round(main_confidence, 4), |
|
|
"subcategory_confidence": round(sub_confidence, 4) |
|
|
} |
|
|
logger.info(f"Response: {response}") |
|
|
return response |
|
|
except Exception as e: |
|
|
logger.error(f"Error in prediction: {str(e)}", exc_info=True) |
|
|
return {"error": str(e)}, 500 |