Hemang1915 commited on
Commit
6fc4c16
·
1 Parent(s): 7a4699d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -29
app.py CHANGED
@@ -1,16 +1,14 @@
1
- from fastapi import FastAPI, HTTPException
2
  import tensorflow as tf
3
- import numpy as np
4
  from tensorflow.keras.layers import Dense
5
  from tensorflow.keras.preprocessing.sequence import pad_sequences
 
6
  import pickle
7
- from stable_baselines3 import PPO
8
  from pydantic import BaseModel
9
  from typing import Optional
 
10
 
11
- print("Starting app.py...")
12
-
13
- # Define the custom HierarchicalPrediction layer
14
  class HierarchicalPrediction(tf.keras.layers.Layer):
15
  def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
16
  super(HierarchicalPrediction, self).__init__(**kwargs)
@@ -53,55 +51,65 @@ class HierarchicalPrediction(tf.keras.layers.Layer):
53
  config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
54
  return cls(**config)
55
 
56
- print("Defined HierarchicalPrediction layer...")
57
-
58
- # Register custom layer
59
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
60
 
61
- print("Loading BiLSTM model...")
 
 
 
62
  try:
 
63
  model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
 
64
  except Exception as e:
65
- raise RuntimeError(f"Failed to load model: {e}")
66
 
67
- print("Loading tokenizer...")
68
- with open('tokenizer.pkl', 'rb') as f:
69
- tokenizer = pickle.load(f)
 
 
 
 
70
 
71
- print("Loading category label encoder...")
72
- with open('le_category.pkl', 'rb') as f:
73
- le_category = pickle.load(f)
 
 
 
 
74
 
75
- print("Loading subcategory label encoder...")
76
- with open('le_subcategory.pkl', 'rb') as f:
77
- le_subcategory = pickle.load(f)
 
 
 
 
78
 
79
  max_len = 20
80
  num_subcategories = len(le_subcategory.classes_)
81
  print(f"max_len: {max_len}, num_subcategories: {num_subcategories}")
82
 
83
- print("Loading PPO model...")
84
  try:
 
85
  ppo_model = PPO.load('ppo_finetuned_model')
 
86
  except Exception as e:
87
  raise RuntimeError(f"Failed to load PPO model: {e}")
88
 
89
- # Define request model
90
  class TransactionRequest(BaseModel):
91
  description: str
92
  use_rl: Optional[bool] = False
93
 
94
- # Define response model
95
  class PredictionResponse(BaseModel):
96
  category: str
97
  subcategory: str
98
  category_confidence: float
99
  subcategory_confidence: float
100
 
101
- # Initialize FastAPI app
102
- app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
103
- print("FastAPI app initialized...")
104
-
105
  # Prediction functions
106
  def predict_category_subcategory(description):
107
  seq = tokenizer.texts_to_sequences([description])
@@ -131,7 +139,7 @@ def predict_with_rl(description):
131
  subcat_conf = float(pred[1][0][subcat_idx] * 100)
132
  return cat_pred, subcat_pred, cat_conf, subcat_conf
133
 
134
- # Define API routes
135
  @app.get("/")
136
  async def root():
137
  return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
@@ -156,4 +164,4 @@ async def predict(request: TransactionRequest):
156
  async def health_check():
157
  return {"status": "healthy"}
158
 
159
- print("API routes defined...")
 
 
1
  import tensorflow as tf
 
2
  from tensorflow.keras.layers import Dense
3
  from tensorflow.keras.preprocessing.sequence import pad_sequences
4
+ import numpy as np
5
  import pickle
6
+ from fastapi import FastAPI, HTTPException
7
  from pydantic import BaseModel
8
  from typing import Optional
9
+ from stable_baselines3 import PPO
10
 
11
+ # Define custom layer
 
 
12
  class HierarchicalPrediction(tf.keras.layers.Layer):
13
  def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
14
  super(HierarchicalPrediction, self).__init__(**kwargs)
 
51
  config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
52
  return cls(**config)
53
 
 
 
 
54
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
55
 
56
+ # Initialize FastAPI app at the top
57
+ app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
58
+
59
+ # Load models and resources
60
  try:
61
+ print("Loading BiLSTM model...")
62
  model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
63
+ print("BiLSTM model loaded.")
64
  except Exception as e:
65
+ raise RuntimeError(f"Failed to load BiLSTM model: {e}")
66
 
67
+ try:
68
+ print("Loading tokenizer...")
69
+ with open('tokenizer.pkl', 'rb') as f:
70
+ tokenizer = pickle.load(f)
71
+ print("Tokenizer loaded.")
72
+ except Exception as e:
73
+ raise RuntimeError(f"Failed to load tokenizer: {e}")
74
 
75
+ try:
76
+ print("Loading category label encoder...")
77
+ with open('le_category.pkl', 'rb') as f:
78
+ le_category = pickle.load(f)
79
+ print("Category label encoder loaded.")
80
+ except Exception as e:
81
+ raise RuntimeError(f"Failed to load category label encoder: {e}")
82
 
83
+ try:
84
+ print("Loading subcategory label encoder...")
85
+ with open('le_subcategory.pkl', 'rb') as f:
86
+ le_subcategory = pickle.load(f)
87
+ print("Subcategory label encoder loaded.")
88
+ except Exception as e:
89
+ raise RuntimeError(f"Failed to load subcategory label encoder: {e}")
90
 
91
  max_len = 20
92
  num_subcategories = len(le_subcategory.classes_)
93
  print(f"max_len: {max_len}, num_subcategories: {num_subcategories}")
94
 
 
95
  try:
96
+ print("Loading PPO model...")
97
  ppo_model = PPO.load('ppo_finetuned_model')
98
+ print("PPO model loaded.")
99
  except Exception as e:
100
  raise RuntimeError(f"Failed to load PPO model: {e}")
101
 
102
+ # Define models
103
  class TransactionRequest(BaseModel):
104
  description: str
105
  use_rl: Optional[bool] = False
106
 
 
107
  class PredictionResponse(BaseModel):
108
  category: str
109
  subcategory: str
110
  category_confidence: float
111
  subcategory_confidence: float
112
 
 
 
 
 
113
  # Prediction functions
114
  def predict_category_subcategory(description):
115
  seq = tokenizer.texts_to_sequences([description])
 
139
  subcat_conf = float(pred[1][0][subcat_idx] * 100)
140
  return cat_pred, subcat_pred, cat_conf, subcat_conf
141
 
142
+ # Define routes
143
  @app.get("/")
144
  async def root():
145
  return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
 
164
  async def health_check():
165
  return {"status": "healthy"}
166
 
167
+ print("API fully loaded and ready.")