Hemang1915 commited on
Commit
75289c7
·
1 Parent(s): 0cda1e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -32
app.py CHANGED
@@ -11,11 +11,6 @@ from stable_baselines3 import PPO
11
  app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
12
  print("FastAPI app initialized...")
13
 
14
- model = None
15
- tokenizer = None
16
- le_category = None
17
- le_subcategory = None
18
- ppo_model = None
19
  max_len = 20
20
 
21
  class HierarchicalPrediction(tf.keras.layers.Layer):
@@ -63,35 +58,25 @@ class HierarchicalPrediction(tf.keras.layers.Layer):
63
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
64
 
65
  def load_resources():
66
- global model, tokenizer, le_category, le_subcategory
67
- if model is None:
68
- print("Loading BiLSTM model...")
69
- model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
70
- print("BiLSTM model loaded.")
71
- if tokenizer is None:
72
- print("Loading tokenizer...")
73
- with open('tokenizer.pkl', 'rb') as f:
74
- tokenizer = pickle.load(f)
75
- print("Tokenizer loaded.")
76
- if le_category is None:
77
- print("Loading category label encoder...")
78
- with open('le_category.pkl', 'rb') as f:
79
- le_category = pickle.load(f)
80
- print("Category label encoder loaded.")
81
- if le_subcategory is None:
82
- print("Loading subcategory label encoder...")
83
- with open('le_subcategory.pkl', 'rb') as f:
84
- le_subcategory = pickle.load(f)
85
- print("Subcategory label encoder loaded.")
86
  return model, tokenizer, le_category, le_subcategory
87
 
88
  def load_ppo_model():
89
- global ppo_model
90
- if ppo_model is None:
91
- print("Loading PPO model...")
92
- ppo_model = PPO.load('ppo_finetuned_model')
93
- print("PPO model loaded.")
94
- return ppo_model
95
 
96
  class TransactionRequest(BaseModel):
97
  description: str
@@ -120,7 +105,7 @@ async def predict(request: TransactionRequest):
120
  if request.use_rl:
121
  ppo = load_ppo_model()
122
  obs = pad[0]
123
- action, _ = ppo.predict(obs, deterministic=True) # Force deterministic
124
  print(f"RL Action: {action}, Observation: {obs}")
125
  cat_idx = action // num_subcategories
126
  subcat_idx = action % num_subcategories
 
11
  app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
12
  print("FastAPI app initialized...")
13
 
 
 
 
 
 
14
  max_len = 20
15
 
16
  class HierarchicalPrediction(tf.keras.layers.Layer):
 
58
  tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
59
 
60
  def load_resources():
61
+ print("Loading BiLSTM model...")
62
+ model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
63
+ print("Loading tokenizer...")
64
+ with open('tokenizer.pkl', 'rb') as f:
65
+ tokenizer = pickle.load(f)
66
+ print("Loading category label encoder...")
67
+ with open('le_category.pkl', 'rb') as f:
68
+ le_category = pickle.load(f)
69
+ print("Loading subcategory label encoder...")
70
+ with open('le_subcategory.pkl', 'rb') as f:
71
+ le_subcategory = pickle.load(f)
72
+ print(f"Num categories: {len(le_category.classes_)}, Num subcategories: {len(le_subcategory.classes_)}")
 
 
 
 
 
 
 
 
73
  return model, tokenizer, le_category, le_subcategory
74
 
75
  def load_ppo_model():
76
+ print("Loading PPO model...")
77
+ ppo = PPO.load('ppo_finetuned_model')
78
+ print("PPO model loaded.")
79
+ return ppo
 
 
80
 
81
  class TransactionRequest(BaseModel):
82
  description: str
 
105
  if request.use_rl:
106
  ppo = load_ppo_model()
107
  obs = pad[0]
108
+ action, _ = ppo.predict(obs) # Match Colab: no deterministic=True
109
  print(f"RL Action: {action}, Observation: {obs}")
110
  cat_idx = action // num_subcategories
111
  subcat_idx = action % num_subcategories