File size: 6,440 Bytes
ac2c610
 
 
d92700f
ac2c610
 
d92700f
ca4fd98
 
 
 
 
d92700f
 
 
 
ac2c610
ca4fd98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac2c610
d92700f
ca4fd98
d92700f
 
 
 
 
 
 
ac2c610
3f09e7b
d92700f
3f09e7b
 
d92700f
 
 
ac2c610
 
 
 
 
ca4fd98
ac2c610
 
 
 
 
 
 
3e3eb39
 
ac2c610
 
 
 
 
 
3e3eb39
d92700f
 
 
 
3f09e7b
d92700f
 
 
 
 
 
ac2c610
 
 
 
 
 
 
 
d92700f
 
 
 
ac2c610
ca4fd98
d92700f
 
ca4fd98
d92700f
 
ca4fd98
 
 
 
 
 
 
d92700f
 
ca4fd98
d2bb04d
 
d92700f
 
ca4fd98
d2bb04d
 
 
 
 
 
 
 
 
 
 
 
 
 
ca4fd98
d2bb04d
 
d92700f
4e9b7ce
 
 
 
d92700f
 
 
 
ca4fd98
d92700f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import logging
import os
import pickle
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel

# Set the Hugging Face cache directory.
# NOTE(review): "/path/to/writable/cache" is a deployment placeholder — replace
# it with a real writable path. Using setdefault instead of plain assignment so
# an operator-provided TRANSFORMERS_CACHE is not clobbered with a nonexistent
# placeholder at import time (the original unconditionally overwrote it).
os.environ.setdefault("TRANSFORMERS_CACHE", "/path/to/writable/cache")

# Module-level logger for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Text cleaning function
def clean_description(text):
    """
    Clean transaction description by removing prefixes, numeric codes, and separators.
    Examples:
        'UPI/DR/12345678/netflix subscription' -> 'netflix subscription'
        'UPI-DR-12345678-netflix subscription' -> 'netflix subscription'
        'VISA/123456/uber ride to office' -> 'uber ride to office'
    """
    # Lowercase up front so downstream matching is case-insensitive regardless
    # of how the bank formats the description.
    cleaned = text.lower()

    # Known transaction prefixes and standalone numeric codes to strip.
    noise_patterns = (
        r'^upi/dr/[0-9]+/',     # 'UPI/DR/12345678/'
        r'^upi-dr-[0-9]+-',     # 'UPI-DR-12345678-'
        r'^visa/[0-9]+/',       # 'VISA/123456/'
        r'^[a-zA-Z]+/[0-9]+/',  # generic slash-form prefixes, e.g. 'POS/123456/'
        r'^[a-zA-Z]+-[0-9]+-',  # generic dash-form prefixes, e.g. 'POS-123456-'
        r'\b[0-9]{6,}\b',       # standalone numeric codes (6+ digits)
    )
    for pattern in noise_patterns:
        cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)

    # Collapse runs of separator characters into single spaces,
    # then normalize whitespace.
    cleaned = re.sub(r'[-_/]+', ' ', cleaned)
    cleaned = ' '.join(cleaned.split())

    # Fall back to a sentinel when cleaning removed everything.
    return cleaned or "unknown transaction"

# Load label encoders
# NOTE: pickle.load runs arbitrary code — only load trusted artifacts, and
# ensure they were pickled with scikit-learn 1.6.1 to avoid version mismatch.
try:
    _loaded_encoders = []
    for _encoder_path in ("main_category_encoder_5k.pkl", "sub_category_encoder_5k.pkl"):
        with open(_encoder_path, "rb") as fh:
            _loaded_encoders.append(pickle.load(fh))
    main_category_encoder, sub_category_encoder = _loaded_encoders
except FileNotFoundError as e:
    logger.error(f"Failed to load label encoders: {e}")
    raise

# Load tokenizer from local directory
try:
    # Local copy of the BERT tokenizer files saved under ./tokenizer;
    # loading from disk avoids network access at startup.
    tokenizer = BertTokenizer.from_pretrained("./tokenizer")
    logger.info("Tokenizer loaded successfully")
except Exception as e:
    # Broad catch is deliberate: any load failure (missing files, bad config)
    # is fatal at startup, so log and re-raise.
    logger.error(f"Failed to load tokenizer: {e}")
    raise

# Define the model
class BERTFNN(nn.Module):
    """BERT encoder with two linear heads for hierarchical classification.

    The sub-category head receives the [CLS] embedding concatenated with the
    softmaxed main-category distribution, so sub-category predictions are
    conditioned on the main-category prediction.
    """

    def __init__(self, num_main_classes, num_sub_classes):
        super().__init__()
        # Attribute names (bert, fc_main, fc_sub) must stay fixed:
        # they are the keys of the saved state_dict.
        self.bert = BertModel.from_pretrained("./bert-model")
        hidden_size = self.bert.config.hidden_size
        self.fc_main = nn.Linear(hidden_size, num_main_classes)
        self.fc_sub = nn.Linear(hidden_size + num_main_classes, num_sub_classes)

    def forward(self, input_ids, attention_mask):
        """Return (main_logits, sub_logits) for a tokenized batch."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] token embedding summarizes the sequence.
        pooled = encoded.last_hidden_state[:, 0, :]
        main_logits = self.fc_main(pooled)
        # Feed the main-category probability distribution into the sub head.
        sub_features = torch.cat((pooled, torch.softmax(main_logits, dim=1)), dim=1)
        return main_logits, self.fc_sub(sub_features)

# Load trained model
num_main_classes = len(main_category_encoder.classes_)
num_sub_classes = len(sub_category_encoder.classes_)
# Prefer GPU when available; map_location keeps CPU-only hosts working.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    model = BERTFNN(num_main_classes, num_sub_classes).to(device)
    # weights_only=True (torch >= 1.13) restricts unpickling to tensors /
    # plain containers — a .pth file is otherwise an arbitrary-code pickle.
    # A state_dict loads fine under this restriction.
    model.load_state_dict(
        torch.load("expense_categorization_5k.pth", map_location=device, weights_only=True)
    )
    model.eval()  # inference mode: disables dropout etc.
    logger.info("Model loaded successfully")
except FileNotFoundError as e:
    logger.error(f"Failed to load model weights: {e}")
    raise
except Exception as e:
    logger.error(f"Failed to initialize model: {e}")
    raise

# Initialize FastAPI; endpoints are registered on this instance via decorators below.
app = FastAPI()

# Define request body
class TransactionInput(BaseModel):
    """Request payload for POST /predict: the raw bank transaction description."""

    # Raw transaction text, e.g. 'UPI/DR/12345678/netflix subscription'.
    description: str

# Root endpoint: static landing message so GET / does not 404 and points
# clients at the prediction endpoint.
@app.get("/")
async def root():
    """Return a static welcome message directing clients to POST /predict."""
    payload = {
        "message": "Welcome to the Expense Categorization API. Use POST /predict to categorize expenses."
    }
    return payload

# Define predict endpoint with text cleaning and confidence scores
@app.post("/predict")
async def predict_category(transaction: TransactionInput, request: Request):
    """Categorize a transaction description into main and sub categories.

    Args:
        transaction: request body holding the raw description string.
        request: raw request object (unused; kept for signature compatibility).

    Returns:
        dict with category, subcategory and their rounded softmax confidences.

    Raises:
        HTTPException: status 500 with the error message if any stage fails.
    """
    logger.info("Starting prediction for request")
    try:
        logger.info(f"Received request: {transaction.dict()}")

        # Strip bank prefixes / numeric codes before tokenization.
        cleaned_description = clean_description(transaction.description)
        logger.info(f"Cleaned description: {cleaned_description}")

        # Tokenize to a fixed length of 64 (should match training configuration).
        tokens = tokenizer(cleaned_description, return_tensors="pt", truncation=True, padding="max_length", max_length=64)
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)
        logger.info("Tokenization completed")

        # Inference without gradient tracking.
        with torch.no_grad():
            main_logits, sub_logits = model(input_ids, attention_mask)
        logger.info("Model inference completed")

        # Softmax -> probabilities; argmax -> predicted class index.
        main_probs = torch.softmax(main_logits, dim=1)
        main_category_idx = torch.argmax(main_probs, dim=1).cpu().item()
        main_confidence = main_probs[0, main_category_idx].cpu().item()

        sub_probs = torch.softmax(sub_logits, dim=1)
        sub_category_idx = torch.argmax(sub_probs, dim=1).cpu().item()
        sub_confidence = sub_probs[0, sub_category_idx].cpu().item()

        # Map indices back to human-readable labels.
        main_category = main_category_encoder.inverse_transform([main_category_idx])[0]
        sub_category = sub_category_encoder.inverse_transform([sub_category_idx])[0]
        logger.info("Category decoding completed")

        # Prepare response
        response = {
            "category": main_category,
            "subcategory": sub_category,
            "category_confidence": round(main_confidence, 4),
            "subcategory_confidence": round(sub_confidence, 4)
        }
        logger.info(f"Response: {response}")
        return response
    except Exception as e:
        logger.error(f"Error in prediction: {str(e)}", exc_info=True)
        # BUG FIX: the original `return {"error": str(e)}, 500` made FastAPI
        # serialize the (dict, int) tuple as a 200-OK JSON array — the client
        # never saw a 500. Raising HTTPException sets the real status code.
        raise HTTPException(status_code=500, detail=str(e))