Commit
·
75289c7
1
Parent(s):
0cda1e8
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,11 +11,6 @@ from stable_baselines3 import PPO
|
|
| 11 |
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 12 |
print("FastAPI app initialized...")
|
| 13 |
|
| 14 |
-
model = None
|
| 15 |
-
tokenizer = None
|
| 16 |
-
le_category = None
|
| 17 |
-
le_subcategory = None
|
| 18 |
-
ppo_model = None
|
| 19 |
max_len = 20
|
| 20 |
|
| 21 |
class HierarchicalPrediction(tf.keras.layers.Layer):
|
|
@@ -63,35 +58,25 @@ class HierarchicalPrediction(tf.keras.layers.Layer):
|
|
| 63 |
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
|
| 64 |
|
| 65 |
def load_resources():
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
with open('le_category.pkl', 'rb') as f:
|
| 79 |
-
le_category = pickle.load(f)
|
| 80 |
-
print("Category label encoder loaded.")
|
| 81 |
-
if le_subcategory is None:
|
| 82 |
-
print("Loading subcategory label encoder...")
|
| 83 |
-
with open('le_subcategory.pkl', 'rb') as f:
|
| 84 |
-
le_subcategory = pickle.load(f)
|
| 85 |
-
print("Subcategory label encoder loaded.")
|
| 86 |
return model, tokenizer, le_category, le_subcategory
|
| 87 |
|
| 88 |
def load_ppo_model():
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
print("PPO model loaded.")
|
| 94 |
-
return ppo_model
|
| 95 |
|
| 96 |
class TransactionRequest(BaseModel):
|
| 97 |
description: str
|
|
@@ -120,7 +105,7 @@ async def predict(request: TransactionRequest):
|
|
| 120 |
if request.use_rl:
|
| 121 |
ppo = load_ppo_model()
|
| 122 |
obs = pad[0]
|
| 123 |
-
action, _ = ppo.predict(obs
|
| 124 |
print(f"RL Action: {action}, Observation: {obs}")
|
| 125 |
cat_idx = action // num_subcategories
|
| 126 |
subcat_idx = action % num_subcategories
|
|
|
|
| 11 |
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 12 |
print("FastAPI app initialized...")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
max_len = 20
|
| 15 |
|
| 16 |
class HierarchicalPrediction(tf.keras.layers.Layer):
|
|
|
|
| 58 |
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
|
| 59 |
|
| 60 |
def load_resources():
|
| 61 |
+
print("Loading BiLSTM model...")
|
| 62 |
+
model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
|
| 63 |
+
print("Loading tokenizer...")
|
| 64 |
+
with open('tokenizer.pkl', 'rb') as f:
|
| 65 |
+
tokenizer = pickle.load(f)
|
| 66 |
+
print("Loading category label encoder...")
|
| 67 |
+
with open('le_category.pkl', 'rb') as f:
|
| 68 |
+
le_category = pickle.load(f)
|
| 69 |
+
print("Loading subcategory label encoder...")
|
| 70 |
+
with open('le_subcategory.pkl', 'rb') as f:
|
| 71 |
+
le_subcategory = pickle.load(f)
|
| 72 |
+
print(f"Num categories: {len(le_category.classes_)}, Num subcategories: {len(le_subcategory.classes_)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
return model, tokenizer, le_category, le_subcategory
|
| 74 |
|
| 75 |
def load_ppo_model():
|
| 76 |
+
print("Loading PPO model...")
|
| 77 |
+
ppo = PPO.load('ppo_finetuned_model')
|
| 78 |
+
print("PPO model loaded.")
|
| 79 |
+
return ppo
|
|
|
|
|
|
|
| 80 |
|
| 81 |
class TransactionRequest(BaseModel):
|
| 82 |
description: str
|
|
|
|
| 105 |
if request.use_rl:
|
| 106 |
ppo = load_ppo_model()
|
| 107 |
obs = pad[0]
|
| 108 |
+
action, _ = ppo.predict(obs) # Match Colab: no deterministic=True
|
| 109 |
print(f"RL Action: {action}, Observation: {obs}")
|
| 110 |
cat_idx = action // num_subcategories
|
| 111 |
subcat_idx = action % num_subcategories
|