Hemang1915 committed on
Commit
a791c56
·
1 Parent(s): f6aaa53

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ import tensorflow as tf
5
+ from tensorflow.keras.layers import Dense
6
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
7
+ import numpy as np
8
+ import pickle
9
+ from stable_baselines3 import PPO
10
+
11
# ASGI application object; the route decorators further down attach to it.
app = FastAPI(
    title="Transaction Classifier API",
    description="API for classifying banking transactions.",
)
print("FastAPI app initialized...")
13
+
14
# Lazily populated module-level singletons. Each one stays None until the
# loader functions (load_resources / load_ppo_model) fill it in on first use.
model = None            # Keras model read from 'model.h5'
tokenizer = None        # object unpickled from 'tokenizer.pkl'
le_category = None      # label encoder unpickled from 'le_category.pkl'
le_subcategory = None   # label encoder unpickled from 'le_subcategory.pkl'
ppo_model = None        # stable-baselines3 PPO policy ('ppo_finetuned_model')

# Fixed sequence length handed to pad_sequences(maxlen=...) in the predict route.
max_len: int = 20
20
+
21
class HierarchicalPrediction(tf.keras.layers.Layer):
    """Subcategory head restricted to the subcategories of the predicted category.

    The layer projects its feature input to ``num_subcategories`` logits, takes
    the argmax of the category probabilities, looks up which subcategory ids
    belong to that category, and applies a softmax in which every other
    subcategory is pushed to the most negative float32 value.

    Args:
        num_subcategories: Width of the subcategory output.
        cat_to_subcat_tensor: Integer lookup table indexed by category id. Each
            row lists the subcategory ids of that category and is padded with
            ``-1``; rows must be ``max_subcats_per_cat`` wide. Accepts a
            tensor, NumPy array or nested list; it is stored as an int32 tensor.
        max_subcats_per_cat: Row width of ``cat_to_subcat_tensor``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
        super().__init__(**kwargs)
        self.num_subcategories = num_subcategories
        # Normalise to an int32 tensor: call() stacks these ids with the int32
        # output of tf.range, and tf.stack rejects mixed integer dtypes (NumPy
        # tables default to int64). Already-int32 tensors pass through as-is.
        self.cat_to_subcat_tensor = tf.cast(
            tf.convert_to_tensor(cat_to_subcat_tensor), tf.int32
        )
        self.max_subcats_per_cat = max_subcats_per_cat
        # Linear projection; the softmax is applied after masking in call().
        self.subcategory_dense = Dense(num_subcategories, activation=None)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        """Return masked subcategory probabilities.

        Args:
            inputs: Pair ``(features, category_probs)``. ``category_probs`` is
                indexed as (batch, categories); ``features`` is whatever the
                upstream layer emits (named ``lstm_output`` by the original
                author) and is fed to the dense projection.

        Returns:
            Tensor of shape (batch, num_subcategories) produced by a softmax
            over the masked logits.
        """
        features, category_probs = inputs
        logits = self.subcategory_dense(features)

        n_rows = tf.shape(category_probs)[0]
        top_category = tf.argmax(category_probs, axis=1)
        # (batch, max_subcats_per_cat) ids allowed for each row, -1 = padding.
        allowed = tf.gather(self.cat_to_subcat_tensor, top_category)

        # Pair every allowed id with its row number -> (row, subcat) coordinates.
        row_ids = tf.tile(tf.range(n_rows)[:, tf.newaxis], [1, self.max_subcats_per_cat])
        coords = tf.reshape(tf.stack([row_ids, allowed], axis=-1), [-1, 2])

        # Drop the coordinates that came from -1 padding.
        keep = tf.reshape(tf.not_equal(allowed, -1), [-1])
        coords = tf.boolean_mask(coords, keep)

        # Scatter ones at the allowed coordinates to build a (batch, subcat) mask.
        ones = tf.ones(tf.shape(coords)[0], dtype=tf.float32)
        mask = tf.scatter_nd(coords, ones, [n_rows, self.num_subcategories])

        # Disallowed positions receive float32.min so softmax drives them to ~0.
        masked_logits = logits * mask + (1 - mask) * tf.float32.min
        return tf.nn.softmax(masked_logits)

    def get_config(self):
        """Return a config whose lookup table is a plain nested list."""
        config = super().get_config()
        config.update({
            'num_subcategories': self.num_subcategories,
            # Nested list instead of ndarray so the config survives json.dumps.
            'cat_to_subcat_tensor': self.cat_to_subcat_tensor.numpy().tolist(),
            'max_subcats_per_cat': self.max_subcats_per_cat,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config`` output without mutating it."""
        # Shallow copy: the caller's dict must not be rewritten in place.
        config = dict(config)
        config['cat_to_subcat_tensor'] = tf.constant(
            config['cat_to_subcat_tensor'], dtype=tf.int32
        )
        return cls(**config)
62
+
63
# Make the custom layer resolvable by name when Keras deserialises a saved model.
tf.keras.utils.get_custom_objects().update(
    {'HierarchicalPrediction': HierarchicalPrediction}
)
64
+
65
def _read_pickle(path):
    """Unpickle and return the object stored at *path*."""
    # NOTE(review): pickle executes arbitrary code on load; this is only safe
    # while the .pkl artifacts are trusted files shipped with the app.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def load_resources():
    """Load the classifier artifacts on first use and return them.

    Each module-level global is filled only while it is still ``None``, so
    later calls return the cached objects without touching disk.

    Returns:
        Tuple ``(model, tokenizer, le_category, le_subcategory)``.
    """
    global model, tokenizer, le_category, le_subcategory

    if model is None:
        print("Loading BiLSTM model...")
        model = tf.keras.models.load_model(
            'model.h5',
            custom_objects={'HierarchicalPrediction': HierarchicalPrediction},
        )
        print("BiLSTM model loaded.")

    if tokenizer is None:
        print("Loading tokenizer...")
        tokenizer = _read_pickle('tokenizer.pkl')
        print("Tokenizer loaded.")

    if le_category is None:
        print("Loading category label encoder...")
        le_category = _read_pickle('le_category.pkl')
        print("Category label encoder loaded.")

    if le_subcategory is None:
        print("Loading subcategory label encoder...")
        le_subcategory = _read_pickle('le_subcategory.pkl')
        print("Subcategory label encoder loaded.")

    return model, tokenizer, le_category, le_subcategory
87
+
88
def load_ppo_model():
    """Return the PPO policy, loading 'ppo_finetuned_model' on the first call."""
    global ppo_model

    # Already cached by an earlier call: nothing to load.
    if ppo_model is not None:
        return ppo_model

    print("Loading PPO model...")
    ppo_model = PPO.load('ppo_finetuned_model')
    print("PPO model loaded.")
    return ppo_model
95
+
96
# Request body for POST /predict.
# (Documented with comments rather than a docstring so the generated schema
# is left exactly as it was.)
class TransactionRequest(BaseModel):
    # Free-text transaction description; it is tokenised and padded before
    # being fed to the model.
    description: str
    # When truthy, the PPO policy picks the class indices instead of the
    # argmax of the classifier outputs.
    use_rl: Optional[bool] = False
99
+
100
# Response body of POST /predict.
# (Documented with comments rather than a docstring so the generated schema
# is left exactly as it was.)
class PredictionResponse(BaseModel):
    # Labels decoded through the category / subcategory label encoders.
    category: str
    subcategory: str
    # Classifier output for the chosen class, multiplied by 100 in the
    # predict route (i.e. expressed as a percentage).
    category_confidence: float
    subcategory_confidence: float
105
+
106
@app.get("/")
async def root():
    # Landing route: returns a static pointer to the /predict endpoint.
    # (Comment instead of docstring so the OpenAPI description is unchanged.)
    greeting = "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."
    return {"message": greeting}
109
+
110
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TransactionRequest):
    # Classify one transaction description.
    #
    # The description is tokenised, padded to max_len and run through the Keras
    # model, whose first two outputs are read as category and subcategory
    # probabilities. With use_rl set, the PPO policy chooses the class indices
    # instead; the reported confidences are still the Keras model's
    # probabilities for whichever indices were chosen.
    #
    # Any failure is reported as HTTP 500 with the original exception chained
    # as the cause, so server-side tracebacks keep the real origin.
    # (Comments instead of a docstring so the OpenAPI description is unchanged.)
    try:
        model, tokenizer, le_category, le_subcategory = load_resources()
        num_subcategories = len(le_subcategory.classes_)

        seq = tokenizer.texts_to_sequences([request.description])
        pad = pad_sequences(seq, maxlen=max_len, padding='post')
        pred = model.predict(pad, verbose=0)

        # Single-item batch: take row 0 of each output head once, up front.
        cat_probs = pred[0][0]
        subcat_probs = pred[1][0]

        if request.use_rl:
            ppo = load_ppo_model()
            obs = pad[0]
            action, _ = ppo.predict(obs)
            print(f"RL Action: {action}, Observation: {obs}")
            # The flat action encodes cat_idx * num_subcategories + subcat_idx.
            cat_idx = action // num_subcategories
            subcat_idx = action % num_subcategories
            print(f"RL cat_idx: {cat_idx}, subcat_idx: {subcat_idx}")
        else:
            cat_idx = np.argmax(cat_probs)
            subcat_idx = np.argmax(subcat_probs)

        cat_pred = le_category.inverse_transform([cat_idx])[0]
        subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
        # Probabilities are scaled to percentages for the response.
        cat_conf = float(cat_probs[cat_idx] * 100)
        subcat_conf = float(subcat_probs[subcat_idx] * 100)

        print(f"Predicted: category={cat_pred}, subcategory={subcat_pred}")
        return {
            "category": cat_pred,
            "subcategory": subcat_pred,
            "category_confidence": cat_conf,
            "subcategory_confidence": subcat_conf,
        }
    except Exception as e:
        # Boundary handler: surface every failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
148
+
149
@app.get("/health")
async def health_check():
    # Liveness probe: always answers "healthy"; it does not inspect whether
    # any model has been loaded yet.
    # (Comment instead of docstring so the OpenAPI description is unchanged.)
    return dict(status="healthy")
152
+
153
# Printed once at import time, after all definitions have run; the models
# themselves are only loaded when a request first needs them.
print("API ready with lazy-loaded resources.")