File size: 5,021 Bytes
a791c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75289c7
 
 
 
 
 
 
 
 
 
 
 
a791c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba4070
 
 
 
a791c56
 
 
9ba4070
 
a791c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import pickle

# FastAPI application instance; endpoints are registered on it below.
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
print("FastAPI app initialized...")

# Maximum token-sequence length used when padding model inputs (see /predict).
# NOTE(review): assumed to match the length used at training time — confirm.
max_len = 20

class HierarchicalPrediction(tf.keras.layers.Layer):
    """Subcategory head constrained by the predicted category.

    Computes subcategory logits from the LSTM features, then masks out every
    subcategory that does not belong to the argmax category (looked up in
    ``cat_to_subcat_tensor``) before the softmax, so only subcategories of
    the chosen category receive non-negligible probability.
    """

    def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
        """
        Args:
            num_subcategories: total number of subcategory classes.
            cat_to_subcat_tensor: int tensor of shape
                (num_categories, max_subcats_per_cat) listing the valid
                subcategory ids per category, padded with -1.
                (Inferred from the -1 handling in call() — confirm against
                the training code.)
            max_subcats_per_cat: number of columns in the lookup table.
        """
        super(HierarchicalPrediction, self).__init__(**kwargs)
        self.num_subcategories = num_subcategories
        self.cat_to_subcat_tensor = cat_to_subcat_tensor
        self.max_subcats_per_cat = max_subcats_per_cat
        self.subcategory_dense = Dense(num_subcategories, activation=None)

    def build(self, input_shape):
        # No weights of its own; the Dense sub-layer builds lazily on first call.
        super(HierarchicalPrediction, self).build(input_shape)

    def call(self, inputs):
        """Return masked subcategory probabilities, shape (batch, num_subcategories)."""
        lstm_output, category_probs = inputs
        subcat_logits = self.subcategory_dense(lstm_output)
        batch_size = tf.shape(category_probs)[0]
        # Hard decision: restrict subcategories to the argmax category only.
        predicted_categories = tf.argmax(category_probs, axis=1)
        # (batch, max_subcats_per_cat) valid subcategory ids, -1 where padded.
        valid_subcat_indices = tf.gather(self.cat_to_subcat_tensor, predicted_categories)
        batch_indices = tf.range(batch_size)
        batch_indices_expanded = tf.tile(batch_indices[:, tf.newaxis], [1, self.max_subcats_per_cat])
        # Pair every batch row with each of its valid subcategory columns.
        update_indices = tf.stack([batch_indices_expanded, valid_subcat_indices], axis=-1)
        update_indices = tf.reshape(update_indices, [-1, 2])
        # Drop the -1 padding entries before scattering.
        valid_mask = tf.not_equal(valid_subcat_indices, -1)
        valid_indices = tf.boolean_mask(update_indices, tf.reshape(valid_mask, [-1]))
        updates = tf.ones(tf.shape(valid_indices)[0], dtype=tf.float32)
        # 1.0 where the subcategory is allowed for the row's category, else 0.0.
        mask = tf.scatter_nd(valid_indices, updates, [batch_size, self.num_subcategories])
        # Push disallowed logits toward float32 minimum so softmax gives ~0 there.
        masked_logits = subcat_logits * mask + (1 - mask) * tf.float32.min
        return tf.nn.softmax(masked_logits)

    def get_config(self):
        config = super(HierarchicalPrediction, self).get_config()
        config.update({
            'num_subcategories': self.num_subcategories,
            # Store the lookup table as a plain nested list: a raw numpy array
            # is not JSON-serializable and breaks saving the model config
            # (e.g. model.to_json()); from_config rebuilds a tensor either way.
            'cat_to_subcat_tensor': self.cat_to_subcat_tensor.numpy().tolist(),
            'max_subcats_per_cat': self.max_subcats_per_cat
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the lookup-table tensor from its serialized (list) form.
        config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
        return cls(**config)

# Register the custom layer globally so Keras can deserialize saved models
# that reference 'HierarchicalPrediction' by name.
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction

def load_resources():
    """Load and memoize the model, tokenizer, and label encoders.

    The heavy artifacts are loaded once on the first call and cached on the
    function object; subsequent calls return the cached tuple instead of
    re-reading model.h5 and the three pickle files from disk on every
    /predict request (the original reloaded everything per request).

    Returns:
        tuple: (model, tokenizer, le_category, le_subcategory)
    """
    cached = getattr(load_resources, "_cache", None)
    if cached is not None:
        return cached
    print("Loading BiLSTM model...")
    model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
    print("Loading tokenizer...")
    # NOTE(review): pickle.load is only acceptable because these are locally
    # produced training artifacts — never point these paths at untrusted files.
    with open('tokenizer.pkl', 'rb') as f:
        tokenizer = pickle.load(f)
    print("Loading category label encoder...")
    with open('le_category.pkl', 'rb') as f:
        le_category = pickle.load(f)
    print("Loading subcategory label encoder...")
    with open('le_subcategory.pkl', 'rb') as f:
        le_subcategory = pickle.load(f)
    print(f"Num categories: {len(le_category.classes_)}, Num subcategories: {len(le_subcategory.classes_)}")
    load_resources._cache = (model, tokenizer, le_category, le_subcategory)
    return load_resources._cache

class TransactionRequest(BaseModel):
    """Request body for POST /predict."""
    # Free-text banking transaction description to classify.
    description: str

class PredictionResponse(BaseModel):
    """Response body for POST /predict."""
    # Predicted labels (decoded via the fitted label encoders).
    category: str
    subcategory: str
    # Confidences are percentages in [0, 100] (argmax probability * 100).
    category_confidence: float
    subcategory_confidence: float

@app.get("/")
async def root():
    """Landing endpoint that points clients at the prediction route."""
    welcome = {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
    return welcome

@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TransactionRequest):
    """Classify a transaction description into category and subcategory.

    Tokenizes and pads the description, runs the model, and returns the
    argmax label of each output head together with its confidence
    (probability * 100). Any failure surfaces as an HTTP 500.
    """
    try:
        model, tokenizer, le_category, le_subcategory = load_resources()

        padded = pad_sequences(
            tokenizer.texts_to_sequences([request.description]),
            maxlen=max_len,
            padding='post',
        )
        predictions = model.predict(padded, verbose=0)

        # First output head: category probabilities; second: subcategory.
        category_probs = predictions[0][0]
        subcategory_probs = predictions[1][0]
        best_cat = np.argmax(category_probs)
        best_subcat = np.argmax(subcategory_probs)

        cat_pred = le_category.inverse_transform([best_cat])[0]
        subcat_pred = le_subcategory.inverse_transform([best_subcat])[0]
        cat_conf = float(category_probs[best_cat] * 100)
        subcat_conf = float(subcategory_probs[best_subcat] * 100)

        print(f"Predicted: category={cat_pred}, subcategory={subcat_pred}")
        return {
            "category": cat_pred,
            "subcategory": subcat_pred,
            "category_confidence": cat_conf,
            "subcategory_confidence": subcat_conf,
        }
    except Exception as e:
        # Boundary handler: wrap any underlying failure as a 500 for the client.
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")

@app.get("/health")
async def health_check():
    """Liveness probe: reports healthy whenever the process is serving."""
    status_payload = {"status": "healthy"}
    return status_payload

# Startup banner; the model/tokenizer are loaded by load_resources() at
# request time rather than at import time.
print("API ready with lazy-loaded resources.")