Hemang1915 commited on
Commit
aea9c98
·
1 Parent(s): 5bdbedd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -83
app.py CHANGED
@@ -1,14 +1,24 @@
 
 
 
1
  import tensorflow as tf
2
  from tensorflow.keras.layers import Dense
3
  from tensorflow.keras.preprocessing.sequence import pad_sequences
4
  import numpy as np
5
  import pickle
6
- from fastapi import FastAPI, HTTPException
7
- from pydantic import BaseModel
8
- from typing import Optional
9
  from stable_baselines3 import PPO
10
 
11
- # Define custom layer
 
 
 
 
 
 
 
 
 
 
12
  class HierarchicalPrediction(tf.keras.layers.Layer):
13
  def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
14
  super(HierarchicalPrediction, self).__init__(**kwargs)
@@ -53,53 +63,37 @@ class HierarchicalPrediction(tf.keras.layers.Layer):
53
 
54
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
55
 
56
- # Initialize FastAPI app at the top
57
- app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
58
-
59
- # Load models and resources
60
- try:
61
- print("Loading BiLSTM model...")
62
- model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
63
- print("BiLSTM model loaded.")
64
- except Exception as e:
65
- raise RuntimeError(f"Failed to load BiLSTM model: {e}")
66
-
67
- try:
68
- print("Loading tokenizer...")
69
- with open('tokenizer.pkl', 'rb') as f:
70
- tokenizer = pickle.load(f)
71
- print("Tokenizer loaded.")
72
- except Exception as e:
73
- raise RuntimeError(f"Failed to load tokenizer: {e}")
74
-
75
- try:
76
- print("Loading category label encoder...")
77
- with open('le_category.pkl', 'rb') as f:
78
- le_category = pickle.load(f)
79
- print("Category label encoder loaded.")
80
- except Exception as e:
81
- raise RuntimeError(f"Failed to load category label encoder: {e}")
82
-
83
- try:
84
- print("Loading subcategory label encoder...")
85
- with open('le_subcategory.pkl', 'rb') as f:
86
- le_subcategory = pickle.load(f)
87
- print("Subcategory label encoder loaded.")
88
- except Exception as e:
89
- raise RuntimeError(f"Failed to load subcategory label encoder: {e}")
90
-
91
- max_len = 20
92
- num_subcategories = len(le_subcategory.classes_)
93
- print(f"max_len: {max_len}, num_subcategories: {num_subcategories}")
94
-
95
- try:
96
- print("Loading PPO model...")
97
- ppo_model = PPO.load('ppo_finetuned_model')
98
- print("PPO model loaded.")
99
- except Exception as e:
100
- raise RuntimeError(f"Failed to load PPO model: {e}")
101
 
102
- # Define models
103
  class TransactionRequest(BaseModel):
104
  description: str
105
  use_rl: Optional[bool] = False
@@ -110,36 +104,6 @@ class PredictionResponse(BaseModel):
110
  category_confidence: float
111
  subcategory_confidence: float
112
 
113
- # Prediction functions
114
- def predict_category_subcategory(description):
115
- seq = tokenizer.texts_to_sequences([description])
116
- pad = pad_sequences(seq, maxlen=max_len, padding='post')
117
- pred = model.predict(pad, verbose=0)
118
- cat_probs = pred[0][0]
119
- subcat_probs = pred[1][0]
120
- cat_idx = np.argmax(cat_probs)
121
- subcat_idx = np.argmax(subcat_probs)
122
- cat_pred = le_category.inverse_transform([cat_idx])[0]
123
- subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
124
- cat_conf = float(cat_probs[cat_idx] * 100)
125
- subcat_conf = float(subcat_probs[subcat_idx] * 100)
126
- return cat_pred, subcat_pred, cat_conf, subcat_conf
127
-
128
- def predict_with_rl(description):
129
- seq = tokenizer.texts_to_sequences([description])
130
- pad = pad_sequences(seq, maxlen=max_len, padding='post')
131
- pred = model.predict(pad, verbose=0)
132
- obs = pad[0]
133
- action, _ = ppo_model.predict(obs)
134
- cat_idx = action // num_subcategories
135
- subcat_idx = action % num_subcategories
136
- cat_pred = le_category.inverse_transform([cat_idx])[0]
137
- subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
138
- cat_conf = float(pred[0][0][cat_idx] * 100)
139
- subcat_conf = float(pred[1][0][subcat_idx] * 100)
140
- return cat_pred, subcat_pred, cat_conf, subcat_conf
141
-
142
- # Define routes
143
  @app.get("/")
144
  async def root():
145
  return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
@@ -147,10 +111,30 @@ async def root():
147
  @app.post("/predict", response_model=PredictionResponse)
148
  async def predict(request: TransactionRequest):
149
  try:
 
 
 
 
 
 
 
150
  if request.use_rl:
151
- cat_pred, subcat_pred, cat_conf, subcat_conf = predict_with_rl(request.description)
 
 
 
 
152
  else:
153
- cat_pred, subcat_pred, cat_conf, subcat_conf = predict_category_subcategory(request.description)
 
 
 
 
 
 
 
 
 
154
  return {
155
  "category": cat_pred,
156
  "subcategory": subcat_pred,
@@ -164,4 +148,4 @@ async def predict(request: TransactionRequest):
164
  async def health_check():
165
  return {"status": "healthy"}
166
 
167
- print("API fully loaded and ready.")
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
  import tensorflow as tf
5
  from tensorflow.keras.layers import Dense
6
  from tensorflow.keras.preprocessing.sequence import pad_sequences
7
  import numpy as np
8
  import pickle
 
 
 
9
  from stable_baselines3 import PPO
10
 
11
+ app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
12
+ print("FastAPI app initialized...")
13
+
14
+ # Lazy-loaded resources
15
+ model = None
16
+ tokenizer = None
17
+ le_category = None
18
+ le_subcategory = None
19
+ ppo_model = None
20
+ max_len = 20
21
+
22
  class HierarchicalPrediction(tf.keras.layers.Layer):
23
  def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
24
  super(HierarchicalPrediction, self).__init__(**kwargs)
 
63
 
64
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
65
 
66
+ def load_resources():
67
+ global model, tokenizer, le_category, le_subcategory
68
+ if model is None:
69
+ print("Loading BiLSTM model...")
70
+ model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
71
+ print("BiLSTM model loaded.")
72
+ if tokenizer is None:
73
+ print("Loading tokenizer...")
74
+ with open('tokenizer.pkl', 'rb') as f:
75
+ tokenizer = pickle.load(f)
76
+ print("Tokenizer loaded.")
77
+ if le_category is None:
78
+ print("Loading category label encoder...")
79
+ with open('le_category.pkl', 'rb') as f:
80
+ le_category = pickle.load(f)
81
+ print("Category label encoder loaded.")
82
+ if le_subcategory is None:
83
+ print("Loading subcategory label encoder...")
84
+ with open('le_subcategory.pkl', 'rb') as f:
85
+ le_subcategory = pickle.load(f)
86
+ print("Subcategory label encoder loaded.")
87
+ return model, tokenizer, le_category, le_subcategory
88
+
89
+ def load_ppo_model():
90
+ global ppo_model
91
+ if ppo_model is None:
92
+ print("Loading PPO model...")
93
+ ppo_model = PPO.load('ppo_finetuned_model')
94
+ print("PPO model loaded.")
95
+ return ppo_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
97
  class TransactionRequest(BaseModel):
98
  description: str
99
  use_rl: Optional[bool] = False
 
104
  category_confidence: float
105
  subcategory_confidence: float
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  @app.get("/")
108
  async def root():
109
  return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
 
111
  @app.post("/predict", response_model=PredictionResponse)
112
  async def predict(request: TransactionRequest):
113
  try:
114
+ model, tokenizer, le_category, le_subcategory = load_resources()
115
+ num_subcategories = len(le_subcategory.classes_)
116
+
117
+ seq = tokenizer.texts_to_sequences([description])
118
+ pad = pad_sequences(seq, maxlen=max_len, padding='post')
119
+ pred = model.predict(pad, verbose=0)
120
+
121
  if request.use_rl:
122
+ ppo = load_ppo_model()
123
+ obs = pad[0]
124
+ action, _ = ppo.predict(obs)
125
+ cat_idx = action // num_subcategories
126
+ subcat_idx = action % num_subcategories
127
  else:
128
+ cat_probs = pred[0][0]
129
+ subcat_probs = pred[1][0]
130
+ cat_idx = np.argmax(cat_probs)
131
+ subcat_idx = np.argmax(subcat_probs)
132
+
133
+ cat_pred = le_category.inverse_transform([cat_idx])[0]
134
+ subcat_pred = le_subcategory.inverse_transform([subcat_idx])[0]
135
+ cat_conf = float(pred[0][0][cat_idx] * 100)
136
+ subcat_conf = float(pred[1][0][subcat_idx] * 100)
137
+
138
  return {
139
  "category": cat_pred,
140
  "subcategory": subcat_pred,
 
148
  async def health_check():
149
  return {"status": "healthy"}
150
 
151
+ print("API ready with lazy-loaded resources.")