Hemang1915 committed on
Commit
50109cc
·
1 Parent(s): 8666b71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from tensorflow.keras.layers import Dense
5
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+ import pickle
7
+ from stable_baselines3 import PPO
8
+ from pydantic import BaseModel
9
+ from typing import Optional
10
+
11
+ # Define the custom HierarchicalPrediction layer
12
class HierarchicalPrediction(tf.keras.layers.Layer):
    """Subcategory softmax restricted to the subcategories of the predicted category.

    The layer takes ``[lstm_output, category_probs]``, projects ``lstm_output``
    to ``num_subcategories`` logits, and masks out every subcategory that does
    not belong to the arg-max category before applying softmax.

    Args:
        num_subcategories: Width of the subcategory output.
        cat_to_subcat_tensor: Integer lookup indexed by category id. Each row
            holds ``max_subcats_per_cat`` subcategory ids, padded with ``-1``.
            It is stacked with int32 batch indices in ``call``, so it must be
            int32 (``from_config`` rebuilds it with that dtype).
        max_subcats_per_cat: Number of columns in ``cat_to_subcat_tensor``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
        super().__init__(**kwargs)
        self.num_subcategories = num_subcategories
        self.cat_to_subcat_tensor = cat_to_subcat_tensor
        self.max_subcats_per_cat = max_subcats_per_cat
        # Linear projection; the softmax is applied after masking in call().
        self.subcategory_dense = Dense(num_subcategories, activation=None)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        """Return masked subcategory probabilities for ``[lstm_output, category_probs]``."""
        lstm_output, category_probs = inputs
        subcat_logits = self.subcategory_dense(lstm_output)
        batch_size = tf.shape(category_probs)[0]

        # Look up the (padded) list of valid subcategory ids for each sample's
        # most likely category.
        predicted_categories = tf.argmax(category_probs, axis=1)
        valid_subcat_indices = tf.gather(self.cat_to_subcat_tensor, predicted_categories)

        # Build (row, subcategory) index pairs for every lookup slot.
        batch_indices = tf.range(batch_size)
        batch_indices_expanded = tf.tile(batch_indices[:, tf.newaxis], [1, self.max_subcats_per_cat])
        update_indices = tf.stack([batch_indices_expanded, valid_subcat_indices], axis=-1)
        update_indices = tf.reshape(update_indices, [-1, 2])

        # Drop the -1 padding slots before scattering.
        valid_mask = tf.not_equal(valid_subcat_indices, -1)
        valid_indices = tf.boolean_mask(update_indices, tf.reshape(valid_mask, [-1]))

        # 1.0 where the subcategory is allowed for the sample, 0.0 elsewhere.
        updates = tf.ones(tf.shape(valid_indices)[0], dtype=tf.float32)
        mask = tf.scatter_nd(valid_indices, updates, [batch_size, self.num_subcategories])

        # Disallowed logits are pushed to the float32 minimum so softmax
        # assigns them (effectively) zero probability.
        masked_logits = subcat_logits * mask + (1 - mask) * tf.float32.min
        return tf.nn.softmax(masked_logits)

    def get_config(self):
        """Return a serializable config that ``from_config`` can rebuild the layer from."""
        config = super().get_config()
        config.update({
            'num_subcategories': self.num_subcategories,
            # Plain nested lists serialize cleanly; np.asarray also accepts a
            # list/ndarray mapping, where calling ``.numpy()`` would fail.
            'cat_to_subcat_tensor': np.asarray(self.cat_to_subcat_tensor).tolist(),
            'max_subcats_per_cat': self.max_subcats_per_cat,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, restoring the category lookup as an int32 constant."""
        # Work on a shallow copy so the caller's config dict is not mutated.
        config = dict(config)
        config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
        return cls(**config)
53
+
54
+ # Register custom layer
55
# Make the custom layer resolvable by name when Keras deserializes a model.
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction

# Load model and preprocessing objects
try:
    model = tf.keras.models.load_model(
        'model.h5',
        custom_objects={'HierarchicalPrediction': HierarchicalPrediction},
    )
except Exception as e:
    # Fail fast at import time; chain the cause so the original traceback
    # is reported as the direct reason for the startup failure.
    raise RuntimeError(f"Failed to load model: {e}") from e
62
+
63
def _load_pickle(path):
    # Deserialize one preprocessing artifact from disk.
    # NOTE: pickle can execute arbitrary code on load; only ship trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


tokenizer = _load_pickle('tokenizer.pkl')
le_category = _load_pickle('le_category.pkl')
le_subcategory = _load_pickle('le_subcategory.pkl')

# Sequence length every description is padded/truncated to before inference
# (presumably the length used at training time -- verify against training code).
max_len = 20
# Number of subcategory labels; also used to decode the flat RL action index.
num_subcategories = len(le_subcategory.classes_)
72
+
73
+ # Load PPO model (optional)
74
# The RL policy is optional: when its file is missing the API still serves
# the plain classifier and reports RL as unavailable.
try:
    ppo_model = PPO.load('ppo_finetuned_model')
except FileNotFoundError:
    ppo_model, rl_available = None, False
else:
    rl_available = True
80
+
81
+ # Define request model
82
class TransactionRequest(BaseModel):
    # Request body accepted by POST /predict.

    # Raw transaction description text to classify.
    description: str
    # When truthy, the PPO policy picks the labels instead of the classifier's
    # own arg-max; defaults to the plain classifier.
    use_rl: Optional[bool] = False
85
+
86
+ # Define response model
87
class PredictionResponse(BaseModel):
    # Response body returned by POST /predict.

    # Decoded category and subcategory labels.
    category: str
    subcategory: str
    # Model scores for the chosen labels, multiplied by 100 by the
    # prediction functions.
    category_confidence: float
    subcategory_confidence: float
92
+
93
+ # Initialize FastAPI app
94
# ASGI application object; the route decorators below attach to it.
app = FastAPI(
    title="Transaction Classifier API",
    description="API for classifying banking transactions.",
)
95
+
96
+ # Prediction functions
97
def predict_category_subcategory(description):
    """Classify one transaction description with the Keras model.

    Args:
        description: Raw transaction text.

    Returns:
        Tuple ``(category, subcategory, category_confidence,
        subcategory_confidence)``: the decoded labels with the highest score
        in each model output, and those scores multiplied by 100 as floats.
    """
    # Tokenize and right-pad to the fixed input length.
    padded = pad_sequences(
        tokenizer.texts_to_sequences([description]),
        maxlen=max_len,
        padding='post',
    )
    outputs = model.predict(padded, verbose=0)

    # First output holds category scores, second holds subcategory scores;
    # index 0 selects the single sample in the batch.
    category_scores, subcategory_scores = outputs[0][0], outputs[1][0]
    best_category = np.argmax(category_scores)
    best_subcategory = np.argmax(subcategory_scores)

    return (
        le_category.inverse_transform([best_category])[0],
        le_subcategory.inverse_transform([best_subcategory])[0],
        float(category_scores[best_category] * 100),
        float(subcategory_scores[best_subcategory] * 100),
    )
110
+
111
def predict_with_rl(description):
    """Classify one transaction description using the PPO policy's action.

    The padded token sequence is the policy observation. The flat action is
    decoded as ``category * num_subcategories + subcategory``, and the
    confidences are the Keras model's scores for the labels the policy chose.

    Args:
        description: Raw transaction text.

    Returns:
        Tuple ``(category, subcategory, category_confidence,
        subcategory_confidence)`` with confidences multiplied by 100.

    Raises:
        HTTPException: 400 when the PPO model was not loaded at startup.
    """
    if not rl_available:
        raise HTTPException(status_code=400, detail="RL model not available")

    seq = tokenizer.texts_to_sequences([description])
    padded = pad_sequences(seq, maxlen=max_len, padding='post')
    pred = model.predict(padded, verbose=0)

    # Use the greedy action: the default stochastic sampling would let the
    # same description be classified differently from one request to the next.
    action, _ = ppo_model.predict(padded[0], deterministic=True)
    # predict() returns a NumPy array; collapse it to a plain int so the
    # arithmetic and indexing below operate on a scalar.
    action = int(np.asarray(action).item())

    cat_idx, subcat_idx = divmod(action, num_subcategories)
    cat_pred = le_category.inverse_transform([cat_idx])[0]
    subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
    cat_conf = float(pred[0][0][cat_idx] * 100)
    subcat_conf = float(pred[1][0][subcat_idx] * 100)
    return cat_pred, subcat_pred, cat_conf, subcat_conf
126
+
127
+ # Define API routes
128
@app.get("/")
async def root():
    # Landing route: points clients at the prediction endpoint.
    welcome = "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."
    return {"message": welcome}
131
+
132
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TransactionRequest):
    # Classify the posted description, using the RL policy when requested.
    # Responds 500 with "Prediction error: ..." on unexpected failures;
    # HTTP errors raised by the prediction functions pass through unchanged.
    predictor = predict_with_rl if request.use_rl else predict_category_subcategory
    try:
        cat_pred, subcat_pred, cat_conf, subcat_conf = predictor(request.description)
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # deliberate 400 ("RL model not available") would be rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
    return {
        "category": cat_pred,
        "subcategory": subcat_pred,
        "category_confidence": cat_conf,
        "subcategory_confidence": subcat_conf,
    }
147
+
148
+ # Optional: Health check endpoint
149
@app.get("/health")
async def health_check():
    # Liveness probe: always reports healthy once the module has imported
    # (model loading happens at import time, before this route can be served).
    status_payload = {"status": "healthy"}
    return status_payload