Commit ·
693e9be
1
Parent(s): 24c8ebb
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,6 @@ from stable_baselines3 import PPO
|
|
| 11 |
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 12 |
print("FastAPI app initialized...")
|
| 13 |
|
| 14 |
-
# Lazy-loaded resources
|
| 15 |
model = None
|
| 16 |
tokenizer = None
|
| 17 |
le_category = None
|
|
@@ -114,7 +113,6 @@ async def predict(request: TransactionRequest):
|
|
| 114 |
model, tokenizer, le_category, le_subcategory = load_resources()
|
| 115 |
num_subcategories = len(le_subcategory.classes_)
|
| 116 |
|
| 117 |
-
# Fixed: Use request.description instead of description
|
| 118 |
seq = tokenizer.texts_to_sequences([request.description])
|
| 119 |
pad = pad_sequences(seq, maxlen=max_len, padding='post')
|
| 120 |
pred = model.predict(pad, verbose=0)
|
|
@@ -123,8 +121,10 @@ async def predict(request: TransactionRequest):
|
|
| 123 |
ppo = load_ppo_model()
|
| 124 |
obs = pad[0]
|
| 125 |
action, _ = ppo.predict(obs)
|
|
|
|
| 126 |
cat_idx = action // num_subcategories
|
| 127 |
subcat_idx = action % num_subcategories
|
|
|
|
| 128 |
else:
|
| 129 |
cat_probs = pred[0][0]
|
| 130 |
subcat_probs = pred[1][0]
|
|
@@ -136,6 +136,7 @@ async def predict(request: TransactionRequest):
|
|
| 136 |
cat_conf = float(pred[0][0][cat_idx] * 100)
|
| 137 |
subcat_conf = float(pred[1][0][subcat_idx] * 100)
|
| 138 |
|
|
|
|
| 139 |
return {
|
| 140 |
"category": cat_pred,
|
| 141 |
"subcategory": subcat_pred,
|
|
|
|
| 11 |
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 12 |
print("FastAPI app initialized...")
|
| 13 |
|
|
|
|
| 14 |
model = None
|
| 15 |
tokenizer = None
|
| 16 |
le_category = None
|
|
|
|
| 113 |
model, tokenizer, le_category, le_subcategory = load_resources()
|
| 114 |
num_subcategories = len(le_subcategory.classes_)
|
| 115 |
|
|
|
|
| 116 |
seq = tokenizer.texts_to_sequences([request.description])
|
| 117 |
pad = pad_sequences(seq, maxlen=max_len, padding='post')
|
| 118 |
pred = model.predict(pad, verbose=0)
|
|
|
|
| 121 |
ppo = load_ppo_model()
|
| 122 |
obs = pad[0]
|
| 123 |
action, _ = ppo.predict(obs)
|
| 124 |
+
print(f"RL Action: {action}, Observation: {obs}")
|
| 125 |
cat_idx = action // num_subcategories
|
| 126 |
subcat_idx = action % num_subcategories
|
| 127 |
+
print(f"RL cat_idx: {cat_idx}, subcat_idx: {subcat_idx}")
|
| 128 |
else:
|
| 129 |
cat_probs = pred[0][0]
|
| 130 |
subcat_probs = pred[1][0]
|
|
|
|
| 136 |
cat_conf = float(pred[0][0][cat_idx] * 100)
|
| 137 |
subcat_conf = float(pred[1][0][subcat_idx] * 100)
|
| 138 |
|
| 139 |
+
print(f"Predicted: category={cat_pred}, subcategory={subcat_pred}")
|
| 140 |
return {
|
| 141 |
"category": cat_pred,
|
| 142 |
"subcategory": subcat_pred,
|