Hemang1915 commited on
Commit
f6aaa53
·
1 Parent(s): 693e9be

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -153
app.py DELETED
@@ -1,153 +0,0 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from typing import Optional
4
- import tensorflow as tf
5
- from tensorflow.keras.layers import Dense
6
- from tensorflow.keras.preprocessing.sequence import pad_sequences
7
- import numpy as np
8
- import pickle
9
- from stable_baselines3 import PPO
10
-
11
- app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
12
- print("FastAPI app initialized...")
13
-
14
- model = None
15
- tokenizer = None
16
- le_category = None
17
- le_subcategory = None
18
- ppo_model = None
19
- max_len = 20
20
-
21
- class HierarchicalPrediction(tf.keras.layers.Layer):
22
- def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
23
- super(HierarchicalPrediction, self).__init__(**kwargs)
24
- self.num_subcategories = num_subcategories
25
- self.cat_to_subcat_tensor = cat_to_subcat_tensor
26
- self.max_subcats_per_cat = max_subcats_per_cat
27
- self.subcategory_dense = Dense(num_subcategories, activation=None)
28
-
29
- def build(self, input_shape):
30
- super(HierarchicalPrediction, self).build(input_shape)
31
-
32
- def call(self, inputs):
33
- lstm_output, category_probs = inputs
34
- subcat_logits = self.subcategory_dense(lstm_output)
35
- batch_size = tf.shape(category_probs)[0]
36
- predicted_categories = tf.argmax(category_probs, axis=1)
37
- valid_subcat_indices = tf.gather(self.cat_to_subcat_tensor, predicted_categories)
38
- batch_indices = tf.range(batch_size)
39
- batch_indices_expanded = tf.tile(batch_indices[:, tf.newaxis], [1, self.max_subcats_per_cat])
40
- update_indices = tf.stack([batch_indices_expanded, valid_subcat_indices], axis=-1)
41
- update_indices = tf.reshape(update_indices, [-1, 2])
42
- valid_mask = tf.not_equal(valid_subcat_indices, -1)
43
- valid_indices = tf.boolean_mask(update_indices, tf.reshape(valid_mask, [-1]))
44
- updates = tf.ones(tf.shape(valid_indices)[0], dtype=tf.float32)
45
- mask = tf.scatter_nd(valid_indices, updates, [batch_size, self.num_subcategories])
46
- masked_logits = subcat_logits * mask + (1 - mask) * tf.float32.min
47
- return tf.nn.softmax(masked_logits)
48
-
49
- def get_config(self):
50
- config = super(HierarchicalPrediction, self).get_config()
51
- config.update({
52
- 'num_subcategories': self.num_subcategories,
53
- 'cat_to_subcat_tensor': self.cat_to_subcat_tensor.numpy(),
54
- 'max_subcats_per_cat': self.max_subcats_per_cat
55
- })
56
- return config
57
-
58
- @classmethod
59
- def from_config(cls, config):
60
- config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
61
- return cls(**config)
62
-
63
- tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
64
-
65
- def load_resources():
66
- global model, tokenizer, le_category, le_subcategory
67
- if model is None:
68
- print("Loading BiLSTM model...")
69
- model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
70
- print("BiLSTM model loaded.")
71
- if tokenizer is None:
72
- print("Loading tokenizer...")
73
- with open('tokenizer.pkl', 'rb') as f:
74
- tokenizer = pickle.load(f)
75
- print("Tokenizer loaded.")
76
- if le_category is None:
77
- print("Loading category label encoder...")
78
- with open('le_category.pkl', 'rb') as f:
79
- le_category = pickle.load(f)
80
- print("Category label encoder loaded.")
81
- if le_subcategory is None:
82
- print("Loading subcategory label encoder...")
83
- with open('le_subcategory.pkl', 'rb') as f:
84
- le_subcategory = pickle.load(f)
85
- print("Subcategory label encoder loaded.")
86
- return model, tokenizer, le_category, le_subcategory
87
-
88
- def load_ppo_model():
89
- global ppo_model
90
- if ppo_model is None:
91
- print("Loading PPO model...")
92
- ppo_model = PPO.load('ppo_finetuned_model')
93
- print("PPO model loaded.")
94
- return ppo_model
95
-
96
- class TransactionRequest(BaseModel):
97
- description: str
98
- use_rl: Optional[bool] = False
99
-
100
- class PredictionResponse(BaseModel):
101
- category: str
102
- subcategory: str
103
- category_confidence: float
104
- subcategory_confidence: float
105
-
106
- @app.get("/")
107
- async def root():
108
- return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
109
-
110
- @app.post("/predict", response_model=PredictionResponse)
111
- async def predict(request: TransactionRequest):
112
- try:
113
- model, tokenizer, le_category, le_subcategory = load_resources()
114
- num_subcategories = len(le_subcategory.classes_)
115
-
116
- seq = tokenizer.texts_to_sequences([request.description])
117
- pad = pad_sequences(seq, maxlen=max_len, padding='post')
118
- pred = model.predict(pad, verbose=0)
119
-
120
- if request.use_rl:
121
- ppo = load_ppo_model()
122
- obs = pad[0]
123
- action, _ = ppo.predict(obs)
124
- print(f"RL Action: {action}, Observation: {obs}")
125
- cat_idx = action // num_subcategories
126
- subcat_idx = action % num_subcategories
127
- print(f"RL cat_idx: {cat_idx}, subcat_idx: {subcat_idx}")
128
- else:
129
- cat_probs = pred[0][0]
130
- subcat_probs = pred[1][0]
131
- cat_idx = np.argmax(cat_probs)
132
- subcat_idx = np.argmax(subcat_probs)
133
-
134
- cat_pred = le_category.inverse_transform([cat_idx])[0]
135
- subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
136
- cat_conf = float(pred[0][0][cat_idx] * 100)
137
- subcat_conf = float(pred[1][0][subcat_idx] * 100)
138
-
139
- print(f"Predicted: category={cat_pred}, subcategory={subcat_pred}")
140
- return {
141
- "category": cat_pred,
142
- "subcategory": subcat_pred,
143
- "category_confidence": cat_conf,
144
- "subcategory_confidence": subcat_conf
145
- }
146
- except Exception as e:
147
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
148
-
149
- @app.get("/health")
150
- async def health_check():
151
- return {"status": "healthy"}
152
-
153
- print("API ready with lazy-loaded resources.")