Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import joblib | |
| import pandas as pd | |
| from typing import Dict, Any, List, Union, Optional | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| import warnings | |
| import random | |
| import google.generativeai as genai | |
| import json | |
| import re | |
# Silence the UserWarnings that sklearn.base emits (e.g. on version mismatch
# when unpickling estimators).
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.base")

# --- SKLEARN VERSION COMPATIBILITY SHIM ---
# Pickled pipelines may reference `_RemainderColsList` inside
# sklearn.compose._column_transformer. When the installed sklearn does not
# define that name, a stand-in is injected so unpickling can resolve it.
try:
    import sklearn
    print(f"📦 scikit-learn version: {sklearn.__version__}")
    # Importing from the private module guarantees it is present in sys.modules.
    from sklearn.compose._column_transformer import ColumnTransformer
    _column_transformer_module = sys.modules['sklearn.compose._column_transformer']
    if not hasattr(_column_transformer_module, '_RemainderColsList'):
        class _RemainderColsList(list):
            """Compatibility shim for older sklearn pickled models"""

        # Expose the shim on the sklearn module so pickle can look it up by name.
        setattr(_column_transformer_module, '_RemainderColsList', _RemainderColsList)
        print("✅ Applied sklearn compatibility patch for _RemainderColsList")
except Exception as exc:
    # Best effort only: a failed patch must not prevent the server from starting.
    print(f"⚠️ Warning during sklearn compatibility setup: {exc}")
# --- MODEL CONFIGURATION & CONSTANTS ---
VERSION = "1.1"

# alias -> loaded pipeline object; filled in by load_pipelines().
MODELS = {}

# alias -> path of the pickled pipeline. Keep in sync with the file names
# written by the training script.
MODEL_MAP = {
    "decision_tree": f"classifier/ccfd_{VERSION}_decision-tree.pkl",
    "random_forest": f"classifier/ccfd_{VERSION}_random-forest.pkl",
    "xgboost": f"classifier/ccfd_{VERSION}_xg-boost.pkl",
}

# -------------------------------------------------------------------
# 🎯 CRITICAL FEATURE DEFINITIONS FROM TRAINING SCRIPT
# -------------------------------------------------------------------
CATEGORICAL_FEATURES = [
    "merchant", "category", "gender", "state", "job",
]
NUMERICAL_FEATURES = [
    "cc_num", "amt", "zip", "lat", "long", "city_pop", "unix_time",
    "merch_lat", "merch_long", "age", "trans_hour", "trans_day",
    "trans_month", "trans_weekday", "distance",
]
# Column order handed to the pipelines: categorical columns first, then
# numerical. NOTE(review): must match the order used at training time.
EXPECTED_FEATURES = CATEGORICAL_FEATURES + NUMERICAL_FEATURES

# --- DATA CONSTANTS ---
DATA_FILE_PATH = "data/filteredTest.parquet"
DATA_DF: Optional[pd.DataFrame] = None  # Cache for the sample dataset

# Origins allowed by the CORS middleware. Each entry must be a bare
# "scheme://host[:port]" string; the previous production entry carried a
# doubled scheme ("https://http://...") and therefore could never match the
# Origin header a browser sends.
origins = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://ai-credit-card-fraud-detection.vercel.app",  # Update with your actual frontend domain
]
# --- FASTAPI SETUP ---
app = FastAPI(
    title="Credit Card Fraud Detection API",
    version=VERSION,
    description="Pure API server for fraud detection using ML models. Returns fraud_score (probability 0-100%).",
)

# CORS policy: only the origins listed in `origins` are accepted, credentials
# (cookies / authorization headers) are permitted, and every HTTP method and
# request header is allowed.
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
class SingleTransactionPayload(BaseModel):
    # Request body for scoring one transaction: the alias of a loaded model
    # plus a single feature mapping (column name -> value).
    model_name: str = Field(
        ...,
        description="Model alias (e.g., 'decision_tree', 'random_forest', 'xgboost').",
    )
    features: Dict[str, Any] = Field(
        ...,
        description="Single transaction record for prediction.",
    )
class MultipleTransactionsPayload(BaseModel):
    # Request body for batch scoring: the alias of a loaded model plus a list
    # of feature mappings, one per transaction.
    model_name: str = Field(
        ...,
        description="Model alias (e.g., 'decision_tree', 'random_forest', 'xgboost').",
    )
    features: List[Dict[str, Any]] = Field(
        ...,
        description="List of transaction records for prediction.",
    )
class LLMAnalysePayload(BaseModel):
    # Request body for the LLM analysis endpoint: already-scored transaction
    # records that are forwarded to the language model.
    transactions: List[Dict[str, Any]] = Field(
        ...,
        description="List of transaction records with 22 fields including fraud_score, STATUS, etc.",
    )
# Configure Gemini API
# The key is read once at import time; an unset or empty value leaves the
# client unconfigured and only produces a warning here.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("⚠️ GEMINI_API_KEY not set in environment variables. LLM endpoint will fail.")
else:
    genai.configure(api_key=GEMINI_API_KEY)
    print("✅ Gemini API configured")
| # --- LOAD MODELS AT STARTUP --- | |
def load_pipelines():
    """Load all ML model pipelines"""
    # Every entry of MODEL_MAP is attempted independently: a missing or
    # unreadable file is reported on stdout and skipped, and each successful
    # load is stored in the global MODELS dict under its alias.
    import sklearn

    print(f"🚀 Loading models for server version: {VERSION}")
    print(f"📦 Using scikit-learn: {sklearn.__version__}")
    print(f"📂 Current working directory: {os.getcwd()}")

    for alias, model_path in MODEL_MAP.items():
        try:
            if os.path.exists(model_path):
                size_mb = os.path.getsize(model_path) / (1024 * 1024)
                print(f"📥 Loading {alias} from {model_path} ({size_mb:.2f} MB)...")
                # NOTE: joblib.load unpickles the file, which can execute
                # arbitrary code; only load model files from a trusted source.
                MODELS[alias] = joblib.load(model_path)
                print(f"✅ Successfully loaded {alias}")
            else:
                resolved_path = os.path.abspath(model_path)
                print(f"❌ Model file not found: {model_path}")
                print(f" Expected at: {resolved_path}")
        except AttributeError as exc:
            # Raised when the pickle references attributes the installed
            # library version does not provide.
            print(f"❌ Compatibility error loading {model_path}")
            print(f" Error: {exc}")
            print(" 💡 This usually means the model was saved with a different sklearn version")
            print(f" 💡 Try re-training and saving the model with sklearn {sklearn.__version__}")
        except Exception as exc:
            print(f"❌ Failed to load {model_path}")
            print(f" Error type: {type(exc).__name__}")
            print(f" Error message: {exc}")

    if MODELS:
        print(f"✅ Successfully loaded {len(MODELS)} model(s): {list(MODELS.keys())}")
    else:
        print("⚠️ No models loaded. Predictions will fail.")
        print(" 💡 Ensure .pkl files are in the same directory as app.py (or subdirectories like model_outputs/)")
        print(" 💡 Check that models were saved with compatible sklearn version")


# Load models on import
load_pipelines()
| # --- HELPER FUNCTION: CACHE DATA --- | |
def load_data_file() -> Optional[pd.DataFrame]:
    """Load the Parquet data file into the global DATA_DF variable."""
    # The frame is read at most once: later calls return the cached DATA_DF.
    # Returns None when the file is missing or cannot be read; the reason is
    # printed to stdout.
    global DATA_DF
    if DATA_DF is None:
        try:
            if os.path.exists(DATA_FILE_PATH):
                print(f"💾 Loading data from {DATA_FILE_PATH}...")
                # Use pyarrow engine for better performance with parquet
                DATA_DF = pd.read_parquet(DATA_FILE_PATH, engine='pyarrow')
                print(f"✅ Successfully loaded data with {len(DATA_DF)} rows.")
            else:
                resolved_path = os.path.abspath(DATA_FILE_PATH)
                print(f"❌ Data file not found: {DATA_FILE_PATH}")
                print(f" Expected at: {resolved_path}")
        except Exception as exc:
            print(f"❌ Failed to load data file: {exc}")
            return None
    return DATA_DF


# Load data on import for the new endpoint
load_data_file()
| # --- HELPER FUNCTION: PREPARE FEATURES (WITH FIX) --- | |
def prepare_features(features_list: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Validate and prepare features for prediction.

    Builds a DataFrame from the raw records, checks that every column named in
    EXPECTED_FEATURES is present, reorders the columns to that order, converts
    the numerical columns with ``pd.to_numeric`` and casts the categorical
    columns to the ``category`` dtype.

    Args:
        features_list: One mapping (column name -> value) per transaction.

    Returns:
        A new DataFrame containing exactly the EXPECTED_FEATURES columns.

    Raises:
        ValueError: If any expected column is absent; the message lists the
            missing names in EXPECTED_FEATURES order.
    """
    raw_frame = pd.DataFrame(features_list)

    # Report missing columns in a stable order (the previous set difference
    # produced a differently ordered message from run to run).
    present = set(raw_frame.columns)
    missing_features = [name for name in EXPECTED_FEATURES if name not in present]
    if missing_features:
        raise ValueError(f"Missing required features: {missing_features}")

    # Work on an explicit copy: assigning columns on a frame obtained by
    # column selection can trigger pandas' chained-assignment warning.
    df_features = raw_frame[EXPECTED_FEATURES].copy()

    # Numerical inputs may arrive as strings; object-dtype columns break the
    # downstream sparse encoding. Unparseable values become NaN (errors='coerce').
    for col in NUMERICAL_FEATURES:
        df_features[col] = pd.to_numeric(df_features[col], errors='coerce')

    # Categories unseen at training time are left for the model pipeline to handle.
    for col in CATEGORICAL_FEATURES:
        df_features[col] = df_features[col].astype("category")

    return df_features
# Characters that may legally follow a backslash inside a JSON string.
_JSON_ESCAPE_CHARS = frozenset('"\\/bfnrtu')


def _normalise_escape(match):
    """Rewrite one backslash+character pair found in the LLM output."""
    follower = match.group(1)
    if follower in 'nt':
        # Escaped newlines/tabs are flattened to a space.
        return ' '
    if follower in _JSON_ESCAPE_CHARS:
        # Valid JSON escape (including an escaped backslash): keep as is.
        return match.group(0)
    # Stray backslash such as \$ : drop it, keep the character.
    return follower


def extract_json_from_markdown(text: str) -> str:
    """
    Extract JSON content from markdown code block.

    Handles cases where the LLM wraps the output in ```json ... ``` and cleans
    up escape sequences that would make ``json.loads`` fail.

    Escapes are processed as backslash+character pairs, left to right, so an
    escaped backslash (``\\\\``) is consumed as a unit and the character after
    it is never mistaken for the target of an escape. (Examining each
    backslash on its own turned ``\\\\d`` into the invalid escape ``\\d``.)

    Args:
        text: Raw model output, with or without a fenced code block.

    Returns:
        The candidate JSON text on a single line, with whitespace runs
        collapsed to single spaces. The result is not parsed or validated.
    """
    fenced = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL | re.IGNORECASE)
    json_str = (fenced.group(1) if fenced else text).strip()

    # DOTALL so that a backslash directly before a real newline is handled too.
    json_str = re.sub(r'\\(.)', _normalise_escape, json_str, flags=re.DOTALL)

    # Collapse whitespace runs (this also removes raw newlines, which JSON
    # does not allow inside string values).
    return re.sub(r'\s+', ' ', json_str).strip()
| # --- FASTAPI ENDPOINTS --- | |
async def root():
    """Root endpoint - API information"""
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text - confirm it is registered on `app`.
    endpoints = {
        "health": "/health",
        "models": "/models",
        "predict": "/predict (POST) - Single transaction",
        "predict_multiple": "/predict_multiple (POST) - Multiple transactions",
        "random_data": "/get-random-data (GET) - Get sample data for testing",
        "llm_analyse": "/llm-analyse (POST) - LLM analysis of transactions",
        "docs": "/docs",
    }
    response_format = {
        "description": "Returns fraud_score (probability 0-100%) for fraud class",
        "single": {"fraud_score": "float (0-100)"},
        "multiple": {
            "predictions": "list of {'fraud_score': float}",
            "overall_stats": {
                "total": "int",
                "avg_fraud_score": "float",
                "min_fraud_score": "float",
                "max_fraud_score": "float",
            },
        },
        # Mirrors the keys llm_analyse validates and returns (it produces
        # "insights" and "recommendation", not a single "explanation").
        "llm_analyse": {
            "fraud_score": "float (0-1, e.g., 0.12 for 12%)",
            "insights": "str",
            "recommendation": "str",
        },
    }
    return {
        "status": "ok",
        "message": "Credit Card Fraud Detection API",
        "version": VERSION,
        "models_loaded": list(MODELS.keys()),
        "endpoints": endpoints,
        "response_format": response_format,
    }
async def health_check():
    """Health check endpoint"""
    # "healthy" requires at least one loaded model and the cached dataset.
    data_loaded = DATA_DF is not None
    return {
        "status": "healthy" if MODELS and data_loaded else "degraded",
        "version": VERSION,
        "models_loaded": list(MODELS.keys()),
        "model_count": len(MODELS),
        "data_loaded": data_loaded,
        # Truthiness, not `is not None`: an empty GEMINI_API_KEY is treated as
        # unconfigured both at startup and by the LLM endpoint, so it must not
        # be reported as configured here.
        "gemini_configured": bool(GEMINI_API_KEY),
    }
async def list_models():
    """List all available and loaded models"""
    # "available" = every alias configured in MODEL_MAP;
    # "loaded" = the aliases that are actually present in MODELS.
    configured_aliases = list(MODEL_MAP)
    loaded_aliases = list(MODELS)
    return {
        "available_models": configured_aliases,
        "loaded_models": loaded_aliases,
        "model_files": MODEL_MAP,
        "version": VERSION,
    }
async def get_random_data(
    num_rows: int = Query(
        10,
        ge=1,
        le=1000,
        description="The number of random rows to return (between 1 and 1000)."
    )
):
    """
    Retrieves a specified number of random transaction records from the dataset.
    It ensures that at least one fraudulent (is_fraud=True) record is included,
    suitable for testing the prediction endpoints.
    """
    df = load_data_file()
    if df is None:
        raise HTTPException(
            status_code=500,
            detail=f"Data file not loaded. Check server logs for {DATA_FILE_PATH}"
        )

    # Never ask for more rows than the dataset holds.
    num_rows = min(num_rows, len(df))

    try:
        # Boolean selection already yields new frames and sample() never
        # mutates its source, so the cached dataset is not copied per request.
        fraud_rows = df[df['is_fraud'] == 1]
        legit_rows = df[df['is_fraud'] == 0]

        # Collect the pieces and concatenate once; concatenating onto an empty
        # DataFrame is deprecated in recent pandas releases.
        pieces = []
        rows_needed = num_rows
        if not fraud_rows.empty:
            # Guarantee one fraudulent row when the dataset has any.
            pieces.append(fraud_rows.sample(n=1))
            rows_needed -= 1

        # Fill the remainder from non-fraudulent rows, limited by availability.
        legit_count = min(rows_needed, len(legit_rows))
        if legit_count > 0:
            pieces.append(legit_rows.sample(n=legit_count))

        sample_df = pd.concat(pieces) if pieces else df.iloc[0:0]

        # Keep only the model input columns, in EXPECTED_FEATURES order; this
        # also drops the 'is_fraud' label.
        feature_cols = [col for col in EXPECTED_FEATURES if col in sample_df.columns]
        data_records = sample_df[feature_cols].to_dict(orient='records')

        # Shuffle so the guaranteed fraud row is not always first.
        random.shuffle(data_records)

        return {
            "success": True,
            "message": f"Returned {len(data_records)} random records (guaranteed at least one fraud if available).",
            "data": data_records
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing data request: {str(e)}"
        )
async def predict_single(payload: SingleTransactionPayload):
    """
    Predict fraud score for a SINGLE transaction
    Returns fraud_score (probability 0-100% for fraud class)
    """
    model_name = payload.model_name
    features = payload.features

    # 404 when the requested alias was never loaded.
    if model_name not in MODELS:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not loaded. Available: {list(MODELS.keys())}"
        )
    pipeline = MODELS[model_name]

    # 422 when the record cannot be turned into a valid feature frame.
    try:
        feature_frame = prepare_features([features])
    except Exception as exc:
        raise HTTPException(
            status_code=422,
            detail=f"Data validation failed: {str(exc)}"
        )

    # 500 when the model itself fails.
    try:
        # Column 1 of predict_proba is read as the fraud-class probability;
        # it is scaled to percent and converted to a plain Python float so it
        # serialises to JSON.
        fraud_probabilities = pipeline.predict_proba(feature_frame)[:, 1]
        score_percent = float(fraud_probabilities[0] * 100)
        return {
            "success": True,
            "model_used": model_name,
            "fraud_score": round(score_percent, 2)
        }
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction execution failed: {type(exc).__name__}: {str(exc)}"
        )
async def predict_multiple(payload: MultipleTransactionsPayload):
    """
    Predict fraud scores for MULTIPLE transactions
    Returns fraud_score (0-100%) for each transaction, plus overall statistics
    """
    model_name = payload.model_name
    features_list = payload.features

    # 404 when the requested alias was never loaded.
    if model_name not in MODELS:
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not loaded. Available: {list(MODELS.keys())}"
        )
    pipeline = MODELS[model_name]

    # 422 when the records cannot be turned into a valid feature frame.
    try:
        feature_frame = prepare_features(features_list)
    except Exception as exc:
        raise HTTPException(
            status_code=422,
            detail=f"Data validation failed: {str(exc)}"
        )

    # 500 when the model itself fails.
    try:
        # Column 1 of predict_proba is read as the fraud-class probability,
        # scaled to percent.
        percentages = pipeline.predict_proba(feature_frame)[:, 1] * 100

        # numpy floats are converted to Python floats for JSON serialisation.
        predictions = [
            {"fraud_score": round(float(value), 2)}
            for value in percentages
        ]
        count = len(predictions)
        summary = {
            "total": count,
            "avg_fraud_score": round(float(percentages.mean()), 2),
            "max_fraud_score": round(float(percentages.max()), 2),
            "min_fraud_score": round(float(percentages.min()), 2)
        }
        return {
            "success": True,
            "model_used": model_name,
            "total_transactions": count,
            "predictions": predictions,
            "overall_stats": summary
        }
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction execution failed: {type(exc).__name__}: {str(exc)}"
        )
def _format_score_percent(x):
    """Render a 0-1 score as a percentage string, capped at '99%'."""
    try:
        val = float(x) * 100
    except (TypeError, ValueError, OverflowError):
        # Non-numeric value: pass it through with a percent sign appended.
        return f"{x}%"
    if val >= 99:
        return "99%"
    return f"{round(val, 2)}%"


async def llm_analyse(payload: LLMAnalysePayload):
    """
    LLM-based analysis of transactions using Gemini.
    Expects a list of transactions with fields including fraud_score, STATUS, cc_num, merchant, category, amt, gender, state, zip, lat, long, city_pop, job, unix_time, merch_lat, merch_long, is_fraud, age, trans_hour, trans_day, trans_month, trans_weekday, distance.
    Passes fraud_score as a percentage string (e.g., '94%') for more descriptive LLM analysis.
    """
    # NOTE(review): the percentage conversion below is applied to a column
    # named 'score'; a column named 'fraud_score' is forwarded unchanged -
    # confirm which name the frontend sends.
    if not GEMINI_API_KEY:
        raise HTTPException(
            status_code=500,
            detail="Gemini API key not configured. Set GEMINI_API_KEY environment variable."
        )
    transactions = payload.transactions
    if not transactions:
        raise HTTPException(
            status_code=422,
            detail="No transactions provided."
        )

    # Defined up front so the JSON error handler can always reference it.
    raw_response = ""
    try:
        df = pd.DataFrame(transactions)

        # Remove 'fraud_' from all merchant names
        if 'merchant' in df.columns:
            df['merchant'] = df['merchant'].str.replace('fraud_', '', regex=False)

        if 'score' in df.columns:
            df['score'] = df['score'].apply(_format_score_percent)

        # The whole table is embedded in the prompt as CSV text.
        csv_string = df.to_csv(index=False)

        prompt = f"""
You are a senior fraud analyst. Analyze the following credit card transaction dataset in CSV format. Each transaction includes a fraud_score (as percentage, e.g., '94%'), STATUS, transaction details, merchant, amount, location, time, and other relevant features.
CSV Data:
{csv_string}
Instructions:
Instructions:
1. Determine an **overall fraud risk score** (0-1 scale) reflecting the dataset’s general risk. Scale the score so that even a small number of high-risk transactions meaningfully increases the score. Mostly safe transactions should still be low, a few high-risk transactions should produce a moderate-to-high score, and many high-risk transactions should produce a higher score. Use narrative judgment to scale; do not state exact thresholds.
2. Provide a detailed **insights** paragraph (150-200 words) describing patterns, anomalies, clusters, temporal or geographic trends, and merchant behaviors. Avoid listing exact counts or percentages.
3. Provide a detailed **recommendation** paragraph (100-150 words) suggesting practical actions to mitigate risk, including monitoring, alerts, or investigation. Keep guidance non-prescriptive about individual transactions.
4. Output ONLY valid JSON in this format: {{"fraud_score": <float 0-1>, "insights": "<string insights paragraph>", "recommendation": "<string recommendation paragraph>"}}.
5. Let the fraud_score scale more sharply: even a few high-risk transactions should noticeably increase the score, and more high-risk transactions should push it even higher, while mostly safe datasets remain near the bottom of the scale.
Focus on narrative-style, descriptive analysis and make the fraud score percentages in the CSV the key reference points for your reasoning.
"""
        # Generate with Gemini
        model = genai.GenerativeModel('gemini-2.5-flash-lite-preview-09-2025')
        response = model.generate_content(prompt)

        # The reply may be wrapped in a markdown fence; unwrap before parsing.
        raw_response = response.text
        json_str = extract_json_from_markdown(raw_response)
        analysis_json = json.loads(json_str)

        # Validate the shape: a JSON object with the three keys, each of the
        # expected type. Each key is checked against its own type so that a
        # wrongly typed value (e.g. fraud_score sent as a string) is named in
        # the error instead of being silently omitted from the list.
        if not isinstance(analysis_json, dict):
            raise ValueError(
                f"Invalid JSON structure from LLM. Expected an object, got {type(analysis_json).__name__}"
            )
        expected_types = {
            'fraud_score': (int, float),
            'insights': str,
            'recommendation': str,
        }
        missing_keys = [
            key for key, types in expected_types.items()
            if not isinstance(analysis_json.get(key), types)
        ]
        if missing_keys:
            raise ValueError(f"Invalid JSON structure from LLM. Missing/Wrong type keys: {missing_keys}")

        return analysis_json
    except json.JSONDecodeError as je:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to parse LLM response as JSON: {str(je)}. Raw response: {raw_response}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"LLM analysis failed: {type(e).__name__}: {str(e)}"
        )
# For local development
if __name__ == "__main__":
    import uvicorn

    # IMPORTANT: Use the correct host and port for your deployment environment
    _server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_server_options)