Commit
·
6fc4c16
1
Parent(s):
7a4699d
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,14 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException
|
| 2 |
import tensorflow as tf
|
| 3 |
-
import numpy as np
|
| 4 |
from tensorflow.keras.layers import Dense
|
| 5 |
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
|
|
|
| 6 |
import pickle
|
| 7 |
-
from
|
| 8 |
from pydantic import BaseModel
|
| 9 |
from typing import Optional
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
# Define the custom HierarchicalPrediction layer
|
| 14 |
class HierarchicalPrediction(tf.keras.layers.Layer):
|
| 15 |
def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
|
| 16 |
super(HierarchicalPrediction, self).__init__(**kwargs)
|
|
@@ -53,55 +51,65 @@ class HierarchicalPrediction(tf.keras.layers.Layer):
|
|
| 53 |
config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
|
| 54 |
return cls(**config)
|
| 55 |
|
| 56 |
-
print("Defined HierarchicalPrediction layer...")
|
| 57 |
-
|
| 58 |
-
# Register custom layer
|
| 59 |
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
|
| 60 |
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
try:
|
|
|
|
| 63 |
model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
-
raise RuntimeError(f"Failed to load model: {e}")
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
le_category
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
le_subcategory
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
max_len = 20
|
| 80 |
num_subcategories = len(le_subcategory.classes_)
|
| 81 |
print(f"max_len: {max_len}, num_subcategories: {num_subcategories}")
|
| 82 |
|
| 83 |
-
print("Loading PPO model...")
|
| 84 |
try:
|
|
|
|
| 85 |
ppo_model = PPO.load('ppo_finetuned_model')
|
|
|
|
| 86 |
except Exception as e:
|
| 87 |
raise RuntimeError(f"Failed to load PPO model: {e}")
|
| 88 |
|
| 89 |
-
# Define
|
| 90 |
class TransactionRequest(BaseModel):
|
| 91 |
description: str
|
| 92 |
use_rl: Optional[bool] = False
|
| 93 |
|
| 94 |
-
# Define response model
|
| 95 |
class PredictionResponse(BaseModel):
|
| 96 |
category: str
|
| 97 |
subcategory: str
|
| 98 |
category_confidence: float
|
| 99 |
subcategory_confidence: float
|
| 100 |
|
| 101 |
-
# Initialize FastAPI app
|
| 102 |
-
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 103 |
-
print("FastAPI app initialized...")
|
| 104 |
-
|
| 105 |
# Prediction functions
|
| 106 |
def predict_category_subcategory(description):
|
| 107 |
seq = tokenizer.texts_to_sequences([description])
|
|
@@ -131,7 +139,7 @@ def predict_with_rl(description):
|
|
| 131 |
subcat_conf = float(pred[1][0][subcat_idx] * 100)
|
| 132 |
return cat_pred, subcat_pred, cat_conf, subcat_conf
|
| 133 |
|
| 134 |
-
# Define
|
| 135 |
@app.get("/")
|
| 136 |
async def root():
|
| 137 |
return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
|
|
@@ -156,4 +164,4 @@ async def predict(request: TransactionRequest):
|
|
| 156 |
async def health_check():
|
| 157 |
return {"status": "healthy"}
|
| 158 |
|
| 159 |
-
print("API
|
|
|
|
|
|
|
| 1 |
import tensorflow as tf
|
|
|
|
| 2 |
from tensorflow.keras.layers import Dense
|
| 3 |
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 4 |
+
import numpy as np
|
| 5 |
import pickle
|
| 6 |
+
from fastapi import FastAPI, HTTPException
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from typing import Optional
|
| 9 |
+
from stable_baselines3 import PPO
|
| 10 |
|
| 11 |
+
# Define custom layer
|
|
|
|
|
|
|
| 12 |
class HierarchicalPrediction(tf.keras.layers.Layer):
|
| 13 |
def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
|
| 14 |
super(HierarchicalPrediction, self).__init__(**kwargs)
|
|
|
|
| 51 |
config['cat_to_subcat_tensor'] = tf.constant(config['cat_to_subcat_tensor'], dtype=tf.int32)
|
| 52 |
return cls(**config)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
tf.keras.utils.get_custom_objects()['HierarchicalPrediction'] = HierarchicalPrediction
|
| 55 |
|
| 56 |
+
# Initialize FastAPI app at the top
|
| 57 |
+
app = FastAPI(title="Transaction Classifier API", description="API for classifying banking transactions.")
|
| 58 |
+
|
| 59 |
+
# Load models and resources
|
| 60 |
try:
|
| 61 |
+
print("Loading BiLSTM model...")
|
| 62 |
model = tf.keras.models.load_model('model.h5', custom_objects={'HierarchicalPrediction': HierarchicalPrediction})
|
| 63 |
+
print("BiLSTM model loaded.")
|
| 64 |
except Exception as e:
|
| 65 |
+
raise RuntimeError(f"Failed to load BiLSTM model: {e}")
|
| 66 |
|
| 67 |
+
try:
|
| 68 |
+
print("Loading tokenizer...")
|
| 69 |
+
with open('tokenizer.pkl', 'rb') as f:
|
| 70 |
+
tokenizer = pickle.load(f)
|
| 71 |
+
print("Tokenizer loaded.")
|
| 72 |
+
except Exception as e:
|
| 73 |
+
raise RuntimeError(f"Failed to load tokenizer: {e}")
|
| 74 |
|
| 75 |
+
try:
|
| 76 |
+
print("Loading category label encoder...")
|
| 77 |
+
with open('le_category.pkl', 'rb') as f:
|
| 78 |
+
le_category = pickle.load(f)
|
| 79 |
+
print("Category label encoder loaded.")
|
| 80 |
+
except Exception as e:
|
| 81 |
+
raise RuntimeError(f"Failed to load category label encoder: {e}")
|
| 82 |
|
| 83 |
+
try:
|
| 84 |
+
print("Loading subcategory label encoder...")
|
| 85 |
+
with open('le_subcategory.pkl', 'rb') as f:
|
| 86 |
+
le_subcategory = pickle.load(f)
|
| 87 |
+
print("Subcategory label encoder loaded.")
|
| 88 |
+
except Exception as e:
|
| 89 |
+
raise RuntimeError(f"Failed to load subcategory label encoder: {e}")
|
| 90 |
|
| 91 |
max_len = 20
|
| 92 |
num_subcategories = len(le_subcategory.classes_)
|
| 93 |
print(f"max_len: {max_len}, num_subcategories: {num_subcategories}")
|
| 94 |
|
|
|
|
| 95 |
try:
|
| 96 |
+
print("Loading PPO model...")
|
| 97 |
ppo_model = PPO.load('ppo_finetuned_model')
|
| 98 |
+
print("PPO model loaded.")
|
| 99 |
except Exception as e:
|
| 100 |
raise RuntimeError(f"Failed to load PPO model: {e}")
|
| 101 |
|
| 102 |
+
# Define models
|
| 103 |
class TransactionRequest(BaseModel):
|
| 104 |
description: str
|
| 105 |
use_rl: Optional[bool] = False
|
| 106 |
|
|
|
|
| 107 |
class PredictionResponse(BaseModel):
|
| 108 |
category: str
|
| 109 |
subcategory: str
|
| 110 |
category_confidence: float
|
| 111 |
subcategory_confidence: float
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
# Prediction functions
|
| 114 |
def predict_category_subcategory(description):
|
| 115 |
seq = tokenizer.texts_to_sequences([description])
|
|
|
|
| 139 |
subcat_conf = float(pred[1][0][subcat_idx] * 100)
|
| 140 |
return cat_pred, subcat_pred, cat_conf, subcat_conf
|
| 141 |
|
| 142 |
+
# Define routes
|
| 143 |
@app.get("/")
|
| 144 |
async def root():
|
| 145 |
return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}
|
|
|
|
| 164 |
async def health_check():
|
| 165 |
return {"status": "healthy"}
|
| 166 |
|
| 167 |
+
print("API fully loaded and ready.")
|