"""FastAPI service that classifies banking transactions into a category and subcategory.

The Keras model, tokenizer and label encoders are loaded lazily on the first
``/predict`` request and then cached for the lifetime of the process.
"""

import functools
import logging
import pickle
import threading
from typing import Optional

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from tensorflow.keras.layers import Dense
from tensorflow.keras.preprocessing.sequence import pad_sequences

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Transaction Classifier API",
    description="API for classifying banking transactions.",
)
logger.info("FastAPI app initialized...")

# Token-sequence length fed to the model; sequences are padded at the end ('post') to this length.
max_len = 20

# Serializes resource loading and inference: makes the lazy load single-flight and avoids
# running Keras ``predict`` concurrently from several threadpool workers.
_inference_lock = threading.Lock()


class HierarchicalPrediction(tf.keras.layers.Layer):
    """Subcategory softmax restricted to the subcategories of the predicted category.

    Subcategory logits are produced from ``lstm_output`` by a dense layer. Every
    subcategory not listed for the argmax category (per row) in
    ``cat_to_subcat_tensor`` has its logit replaced by ``tf.float32.min`` before
    the softmax, so it receives (effectively) zero probability.

    Args:
        num_subcategories: Width of the subcategory output.
        cat_to_subcat_tensor: int32 table of shape
            ``(num_categories, max_subcats_per_cat)``; row ``c`` lists the
            subcategory ids allowed for category ``c``, padded with ``-1``.
        max_subcats_per_cat: Number of columns of ``cat_to_subcat_tensor``.
    """

    def __init__(self, num_subcategories, cat_to_subcat_tensor, max_subcats_per_cat, **kwargs):
        super().__init__(**kwargs)
        self.num_subcategories = num_subcategories
        self.cat_to_subcat_tensor = cat_to_subcat_tensor
        self.max_subcats_per_cat = max_subcats_per_cat
        self.subcategory_dense = Dense(num_subcategories, activation=None)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        """Return masked subcategory probabilities.

        Args:
            inputs: Pair ``(lstm_output, category_probs)``; ``category_probs`` is
                indexed as ``(batch, num_categories)``.
        """
        lstm_output, category_probs = inputs
        subcat_logits = self.subcategory_dense(lstm_output)

        batch_size = tf.shape(category_probs)[0]
        predicted_categories = tf.argmax(category_probs, axis=1)
        # (batch, max_subcats_per_cat): allowed subcategory ids per row, -1 = padding.
        valid_subcat_indices = tf.gather(self.cat_to_subcat_tensor, predicted_categories)

        # Pair each allowed id with its batch row to build (row, subcat) scatter coordinates.
        batch_indices = tf.range(batch_size)
        batch_indices_expanded = tf.tile(batch_indices[:, tf.newaxis], [1, self.max_subcats_per_cat])
        update_indices = tf.stack([batch_indices_expanded, valid_subcat_indices], axis=-1)
        update_indices = tf.reshape(update_indices, [-1, 2])

        # Drop the -1 padding entries before scattering.
        valid_mask = tf.not_equal(valid_subcat_indices, -1)
        valid_indices = tf.boolean_mask(update_indices, tf.reshape(valid_mask, [-1]))

        updates = tf.ones(tf.shape(valid_indices)[0], dtype=tf.float32)
        mask = tf.scatter_nd(valid_indices, updates, [batch_size, self.num_subcategories])
        # scatter_nd sums duplicate coordinates; clamp so the mask stays strictly 0/1.
        mask = tf.minimum(mask, 1.0)

        # Disallowed logits become float32.min so the softmax assigns them ~0 probability.
        masked_logits = subcat_logits * mask + (1 - mask) * tf.float32.min
        return tf.nn.softmax(masked_logits)

    def get_config(self):
        """Return a serialisable config; the lookup table is stored as a nested list."""
        config = super().get_config()
        config.update({
            "num_subcategories": self.num_subcategories,
            # Plain nested list (not a tensor/ndarray) so the config is JSON-friendly.
            "cat_to_subcat_tensor": np.asarray(self.cat_to_subcat_tensor).tolist(),
            "max_subcats_per_cat": self.max_subcats_per_cat,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, converting the stored table back to an int32 tensor."""
        config = dict(config)  # do not mutate the caller's dict
        config["cat_to_subcat_tensor"] = tf.constant(config["cat_to_subcat_tensor"], dtype=tf.int32)
        return cls(**config)


tf.keras.utils.get_custom_objects()["HierarchicalPrediction"] = HierarchicalPrediction


@functools.lru_cache(maxsize=1)
def load_resources():
    """Load the model, tokenizer and label encoders once and cache them.

    Reads ``model.h5``, ``tokenizer.pkl``, ``le_category.pkl`` and
    ``le_subcategory.pkl`` relative to the current working directory. Later
    calls return the same objects. A load that raises is not cached, so the
    next call retries. Use ``load_resources.cache_clear()`` to force a reload.

    Returns:
        tuple: ``(model, tokenizer, le_category, le_subcategory)``.
    """
    logger.info("Loading BiLSTM model...")
    model = tf.keras.models.load_model(
        "model.h5", custom_objects={"HierarchicalPrediction": HierarchicalPrediction}
    )

    # SECURITY: pickle.load can execute arbitrary code; these files must come from a trusted source.
    logger.info("Loading tokenizer...")
    with open("tokenizer.pkl", "rb") as f:
        tokenizer = pickle.load(f)

    logger.info("Loading category label encoder...")
    with open("le_category.pkl", "rb") as f:
        le_category = pickle.load(f)

    logger.info("Loading subcategory label encoder...")
    with open("le_subcategory.pkl", "rb") as f:
        le_subcategory = pickle.load(f)

    logger.info(
        "Num categories: %d, Num subcategories: %d",
        len(le_category.classes_),
        len(le_subcategory.classes_),
    )
    return model, tokenizer, le_category, le_subcategory


class TransactionRequest(BaseModel):
    """Request body for ``/predict``."""

    description: str  # raw transaction description text


class PredictionResponse(BaseModel):
    """Response body for ``/predict``; confidences are percentages (0-100)."""

    category: str
    subcategory: str
    category_confidence: float
    subcategory_confidence: float


@app.get("/")
async def root():
    """Return a short usage hint."""
    return {"message": "Welcome to the Transaction Classifier API. Use POST /predict to classify transactions."}


def _classify(description):
    """Tokenize, pad, run the model on one description and decode both labels (blocking)."""
    with _inference_lock:
        model, tokenizer, le_category, le_subcategory = load_resources()
        seq = tokenizer.texts_to_sequences([description])
        # NOTE(review): pad_sequences truncates 'pre' by default (keeps the LAST max_len tokens);
        # confirm this matches the preprocessing used at training time.
        pad = pad_sequences(seq, maxlen=max_len, padding="post")
        pred = model.predict(pad, verbose=0)

    # Assumes the model has two outputs, [category_probs, subcategory_probs], each (batch, n_classes).
    cat_probs = pred[0][0]
    subcat_probs = pred[1][0]
    cat_idx = int(np.argmax(cat_probs))
    subcat_idx = int(np.argmax(subcat_probs))

    return {
        "category": str(le_category.inverse_transform([cat_idx])[0]),
        "subcategory": str(le_subcategory.inverse_transform([subcat_idx])[0]),
        "category_confidence": float(cat_probs[cat_idx] * 100),
        "subcategory_confidence": float(subcat_probs[subcat_idx] * 100),
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: TransactionRequest):
    """Classify one transaction description.

    Raises:
        HTTPException: 500 with ``"Prediction error: ..."`` if loading or inference fails.
    """
    try:
        # Loading and Keras inference block; run them off the event loop so other routes stay responsive.
        result = await run_in_threadpool(_classify, request.description)
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e

    logger.info("Predicted: category=%s, subcategory=%s", result["category"], result["subcategory"])
    return result


@app.get("/health")
async def health_check():
    """Liveness probe; does not load the model."""
    return {"status": "healthy"}


logger.info("API ready with lazy-loaded resources.")