Spaces:
Runtime error
Runtime error
Commit ·
a2dbe57
0
Parent(s):
Initial clean commit (no runtime data)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .claude/settings.local.json +8 -0
- .dockerignore +28 -0
- .env.example +17 -0
- .gitattributes +35 -0
- .gitignore +13 -0
- DEPLOYMENT_GUIDE.md +570 -0
- Dockerfile +33 -0
- README.md +12 -0
- credily.egg-info/PKG-INFO +33 -0
- credily.egg-info/SOURCES.txt +18 -0
- credily.egg-info/dependency_links.txt +1 -0
- credily.egg-info/entry_points.txt +2 -0
- credily.egg-info/requires.txt +11 -0
- credily.egg-info/top_level.txt +1 -0
- credily/__init__.py +38 -0
- credily/__pycache__/__init__.cpython-314.pyc +0 -0
- credily/__pycache__/agnostic_pipeline.cpython-314.pyc +0 -0
- credily/__pycache__/analyzer.cpython-314.pyc +0 -0
- credily/__pycache__/automl.cpython-314.pyc +0 -0
- credily/__pycache__/balancing.cpython-314.pyc +0 -0
- credily/__pycache__/cleaning.cpython-314.pyc +0 -0
- credily/__pycache__/cli.cpython-314.pyc +0 -0
- credily/__pycache__/profiler.cpython-314.pyc +0 -0
- credily/__pycache__/reporting.cpython-314.pyc +0 -0
- credily/__pycache__/safety.cpython-314.pyc +0 -0
- credily/__pycache__/utils.cpython-314.pyc +0 -0
- credily/agnostic_pipeline.py +537 -0
- credily/analyzer.py +214 -0
- credily/api/__init__.py +8 -0
- credily/api/__pycache__/__init__.cpython-314.pyc +0 -0
- credily/api/__pycache__/database.cpython-314.pyc +0 -0
- credily/api/__pycache__/errors.cpython-314.pyc +0 -0
- credily/api/__pycache__/main.cpython-314.pyc +0 -0
- credily/api/__pycache__/schemas.cpython-314.pyc +0 -0
- credily/api/database.py +368 -0
- credily/api/errors.py +232 -0
- credily/api/main.py +1035 -0
- credily/api/schemas.py +229 -0
- credily/automl.py +1073 -0
- credily/balancing.py +375 -0
- credily/cleaning.py +643 -0
- credily/cli.py +367 -0
- credily/metrics.py +49 -0
- credily/model.py +240 -0
- credily/preprocessing.py +63 -0
- credily/profiler.py +184 -0
- credily/reporting.py +257 -0
- credily/safety.py +634 -0
- credily/utils.py +199 -0
- debug_output/model.pkl +3 -0
.claude/settings.local.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(python -c \"from credily.cli import cli; cli\\(\\)\":*)",
|
| 5 |
+
"Bash(pip install:*)"
|
| 6 |
+
]
|
| 7 |
+
}
|
| 8 |
+
}
|
.dockerignore
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
|
| 7 |
+
# Virtual env
|
| 8 |
+
venv/
|
| 9 |
+
.venv/
|
| 10 |
+
|
| 11 |
+
# ML artifacts
|
| 12 |
+
credily_data/
|
| 13 |
+
credily_models/
|
| 14 |
+
credily_output/
|
| 15 |
+
debug_output/
|
| 16 |
+
debug_output_smote/
|
| 17 |
+
*.pkl
|
| 18 |
+
*.zip
|
| 19 |
+
|
| 20 |
+
# Build artifacts
|
| 21 |
+
credily.egg-info/
|
| 22 |
+
|
| 23 |
+
# Git
|
| 24 |
+
.git/
|
| 25 |
+
.gitignore
|
| 26 |
+
|
| 27 |
+
# Env
|
| 28 |
+
.env
|
.env.example
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Credily Backend Environment Configuration
|
| 2 |
+
# Copy this file to .env and fill in your values
|
| 3 |
+
|
| 4 |
+
# Database Configuration
|
| 5 |
+
# For development: Leave unset to use SQLite (default)
|
| 6 |
+
# For production: Set to your PostgreSQL connection string
|
| 7 |
+
# DATABASE_URL=postgresql://user:password@host:port/database
|
| 8 |
+
|
| 9 |
+
# Production PostgreSQL with PgBouncer (Supabase example)
|
| 10 |
+
# DATABASE_URL=postgresql://postgres.xxxxx:password@aws-1-eu-central-2.pooler.supabase.com:6543/postgres?pgbouncer=true
|
| 11 |
+
|
| 12 |
+
# API Configuration (optional)
|
| 13 |
+
# HOST=0.0.0.0
|
| 14 |
+
# PORT=8000
|
| 15 |
+
|
| 16 |
+
# CORS Origins (comma-separated for production)
|
| 17 |
+
# CORS_ORIGINS=https://your-frontend.com,https://api.your-domain.com
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
credily_data/
|
| 2 |
+
credily_models/
|
| 3 |
+
credily_output/
|
| 4 |
+
|
| 5 |
+
*.db
|
| 6 |
+
*.pkl
|
| 7 |
+
*.zip
|
| 8 |
+
|
| 9 |
+
.env
|
| 10 |
+
venv/
|
| 11 |
+
.venv/
|
| 12 |
+
__pycache__/
|
| 13 |
+
credily.egg-info/
|
DEPLOYMENT_GUIDE.md
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Credily Backend Deployment Guide
|
| 2 |
+
|
| 3 |
+
A comprehensive guide to deploying and using the Credily credit risk prediction backend.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## Table of Contents
|
| 8 |
+
|
| 9 |
+
1. [Prerequisites](#prerequisites)
|
| 10 |
+
2. [Installation](#installation)
|
| 11 |
+
3. [Configuration](#configuration)
|
| 12 |
+
4. [Running the API Server](#running-the-api-server)
|
| 13 |
+
5. [Testing the Deployment](#testing-the-deployment)
|
| 14 |
+
6. [API Endpoints](#api-endpoints)
|
| 15 |
+
7. [Model Training Workflow](#model-training-workflow)
|
| 16 |
+
8. [Prediction Workflow](#prediction-workflow)
|
| 17 |
+
9. [Data Requirements](#data-requirements)
|
| 18 |
+
10. [Troubleshooting](#troubleshooting)
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## Prerequisites
|
| 23 |
+
|
| 24 |
+
### System Requirements
|
| 25 |
+
- Python 3.10+ (tested with Python 3.14)
|
| 26 |
+
- 4GB+ RAM recommended
|
| 27 |
+
- Windows/Linux/macOS
|
| 28 |
+
|
| 29 |
+
### Required Python Packages
|
| 30 |
+
```
|
| 31 |
+
fastapi
|
| 32 |
+
uvicorn
|
| 33 |
+
pandas
|
| 34 |
+
numpy
|
| 35 |
+
scikit-learn
|
| 36 |
+
joblib
|
| 37 |
+
imbalanced-learn
|
| 38 |
+
xgboost (optional)
|
| 39 |
+
lightgbm (optional)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Installation
|
| 45 |
+
|
| 46 |
+
### 1. Clone/Download the Repository
|
| 47 |
+
```bash
|
| 48 |
+
cd Credily_backend
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### 2. Create Virtual Environment (Recommended)
|
| 52 |
+
```bash
|
| 53 |
+
python -m venv venv
|
| 54 |
+
|
| 55 |
+
# Windows
|
| 56 |
+
.\venv\Scripts\activate
|
| 57 |
+
|
| 58 |
+
# Linux/macOS
|
| 59 |
+
source venv/bin/activate
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### 3. Install Dependencies
|
| 63 |
+
```bash
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### 4. Verify Installation
|
| 68 |
+
```bash
|
| 69 |
+
python debug_pipeline.py
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Expected output:
|
| 73 |
+
```
|
| 74 |
+
ALL TESTS PASSED - Pipeline ready for deployment!
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## Configuration
|
| 80 |
+
|
| 81 |
+
### Environment Variables
|
| 82 |
+
Create a `.env` file in the project root (see `.env.example`):
|
| 83 |
+
|
| 84 |
+
```env
|
| 85 |
+
# Database Configuration
|
| 86 |
+
# For development: Leave unset to use SQLite (stored in credily_data/credily.db)
|
| 87 |
+
# For production: Set to your PostgreSQL connection string
|
| 88 |
+
DATABASE_URL=postgresql://user:password@host:port/database
|
| 89 |
+
|
| 90 |
+
# Production PostgreSQL with PgBouncer (Supabase example)
|
| 91 |
+
# DATABASE_URL=postgresql://postgres.xxxxx:password@aws-1-eu-central-2.pooler.supabase.com:6543/postgres?pgbouncer=true
|
| 92 |
+
|
| 93 |
+
# Server settings (optional)
|
| 94 |
+
HOST=0.0.0.0
|
| 95 |
+
PORT=8000
|
| 96 |
+
|
| 97 |
+
# CORS Origins (comma-separated for production)
|
| 98 |
+
# CORS_ORIGINS=https://your-frontend.com
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### Database Configuration
|
| 102 |
+
|
| 103 |
+
The API supports two database backends:
|
| 104 |
+
|
| 105 |
+
| Environment | Database | Configuration |
|
| 106 |
+
|-------------|----------|---------------|
|
| 107 |
+
| Development | SQLite | Default - no config needed |
|
| 108 |
+
| Production | PostgreSQL | Set `DATABASE_URL` environment variable |
|
| 109 |
+
|
| 110 |
+
**SQLite (Development)**
|
| 111 |
+
- Used automatically when `DATABASE_URL` is not set
|
| 112 |
+
- Data stored in `credily_data/credily.db`
|
| 113 |
+
- No additional setup required
|
| 114 |
+
|
| 115 |
+
**PostgreSQL (Production)**
|
| 116 |
+
- Set the `DATABASE_URL` environment variable
|
| 117 |
+
- Supports connection pooling (e.g., PgBouncer)
|
| 118 |
+
- Tables are created automatically on first run
|
| 119 |
+
|
| 120 |
+
### API Configuration
|
| 121 |
+
The API can be configured via `credily/api/main.py`:
|
| 122 |
+
|
| 123 |
+
| Setting | Default | Description |
|
| 124 |
+
|---------|---------|-------------|
|
| 125 |
+
| `host` | `0.0.0.0` | Server host |
|
| 126 |
+
| `port` | `8000` | Server port |
|
| 127 |
+
| `reload` | `True` | Auto-reload on code changes (dev only) |
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Running the API Server
|
| 132 |
+
|
| 133 |
+
### Development Mode
|
| 134 |
+
```bash
|
| 135 |
+
cd Credily_backend
|
| 136 |
+
python -m uvicorn credily.api.main:app --reload --host 0.0.0.0 --port 8000
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### Production Mode
|
| 140 |
+
```bash
|
| 141 |
+
python -m uvicorn credily.api.main:app --host 0.0.0.0 --port 8000 --workers 4
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
### Using the CLI
|
| 145 |
+
```bash
|
| 146 |
+
# Start server via CLI
|
| 147 |
+
python -m credily.cli serve --port 8000
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Verify Server is Running
|
| 151 |
+
Open browser: http://localhost:8000/docs
|
| 152 |
+
|
| 153 |
+
You should see the Swagger UI with all available endpoints.
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
|
| 157 |
+
## Testing the Deployment
|
| 158 |
+
|
| 159 |
+
### 1. Health Check
|
| 160 |
+
```bash
|
| 161 |
+
curl http://localhost:8000/health
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
Expected response:
|
| 165 |
+
```json
|
| 166 |
+
{"status": "healthy"}
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
### 2. Run Debug Tests
|
| 170 |
+
```bash
|
| 171 |
+
python debug_pipeline.py
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### 3. Test with Sample Data
|
| 175 |
+
```bash
|
| 176 |
+
curl -X POST "http://localhost:8000/api/profile" \
|
| 177 |
+
-H "Content-Type: application/json" \
|
| 178 |
+
-d '[{"age": 25, "income": 50000, "target": 0}]'
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
---
|
| 182 |
+
|
| 183 |
+
## API Endpoints
|
| 184 |
+
|
| 185 |
+
### Core Endpoints
|
| 186 |
+
|
| 187 |
+
| Method | Endpoint | Description |
|
| 188 |
+
|--------|----------|-------------|
|
| 189 |
+
| `GET` | `/health` | Health check |
|
| 190 |
+
| `POST` | `/api/train` | Train a new model |
|
| 191 |
+
| `POST` | `/api/predict` | Make predictions |
|
| 192 |
+
| `POST` | `/api/predict/single` | Single record prediction |
|
| 193 |
+
| `POST` | `/api/profile` | Profile dataset |
|
| 194 |
+
| `GET` | `/api/models` | List saved models |
|
| 195 |
+
|
| 196 |
+
### Training Endpoint
|
| 197 |
+
|
| 198 |
+
**POST `/api/train`**
|
| 199 |
+
|
| 200 |
+
Request body:
|
| 201 |
+
```json
|
| 202 |
+
{
|
| 203 |
+
"data": [
|
| 204 |
+
{"feature1": 1, "feature2": "a", "target": 0},
|
| 205 |
+
{"feature1": 2, "feature2": "b", "target": 1}
|
| 206 |
+
],
|
| 207 |
+
"target_column": "target",
|
| 208 |
+
"clean_data": true,
|
| 209 |
+
"clean_mode": "thorough",
|
| 210 |
+
"balance_data": true,
|
| 211 |
+
"balance_method": "smote",
|
| 212 |
+
"calibrate": true,
|
| 213 |
+
"optimize_threshold": true
|
| 214 |
+
}
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
Response:
|
| 218 |
+
```json
|
| 219 |
+
{
|
| 220 |
+
"success": true,
|
| 221 |
+
"model_path": "credily_output/model.pkl",
|
| 222 |
+
"download_url": "/download/model_20240101_120000.zip",
|
| 223 |
+
"results": {
|
| 224 |
+
"best_model": "RandomForest",
|
| 225 |
+
"test_auc": 0.85,
|
| 226 |
+
"optimal_threshold": 0.42
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
```
|
| 230 |
+
|
| 231 |
+
### Prediction Endpoint
|
| 232 |
+
|
| 233 |
+
**POST `/api/predict`**
|
| 234 |
+
|
| 235 |
+
Request body:
|
| 236 |
+
```json
|
| 237 |
+
{
|
| 238 |
+
"data": [
|
| 239 |
+
{"feature1": 1, "feature2": "a"},
|
| 240 |
+
{"feature1": 2, "feature2": "b"}
|
| 241 |
+
],
|
| 242 |
+
"model_path": "C:/path/to/model.pkl",
|
| 243 |
+
"include_proba": true,
|
| 244 |
+
"threshold": null,
|
| 245 |
+
"save_results": false
|
| 246 |
+
}
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
Response:
|
| 250 |
+
```json
|
| 251 |
+
{
|
| 252 |
+
"success": true,
|
| 253 |
+
"predictions": [
|
| 254 |
+
{"index": 0, "prediction": 0, "probability": 0.23, "risk_level": "low"},
|
| 255 |
+
{"index": 1, "prediction": 1, "probability": 0.78, "risk_level": "high"}
|
| 256 |
+
],
|
| 257 |
+
"summary": {
|
| 258 |
+
"total_records": 2,
|
| 259 |
+
"predicted_positive": 1,
|
| 260 |
+
"positive_rate": 0.5
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
---
|
| 266 |
+
|
| 267 |
+
## Model Training Workflow
|
| 268 |
+
|
| 269 |
+
### Step 1: Prepare Your Data
|
| 270 |
+
|
| 271 |
+
Your training data should be a CSV or JSON with:
|
| 272 |
+
- Feature columns (numeric and/or categorical)
|
| 273 |
+
- Target column (binary: 0/1 or string labels)
|
| 274 |
+
|
| 275 |
+
**Example CSV:**
|
| 276 |
+
```csv
|
| 277 |
+
age,income,employment,education,target
|
| 278 |
+
25,50000,employed,bachelor,0
|
| 279 |
+
45,80000,self-employed,master,1
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
### Step 2: Data Cleaning (Automatic)
|
| 283 |
+
|
| 284 |
+
The pipeline automatically:
|
| 285 |
+
- Removes ID columns (`id`, `customer_id`, `loan_id`, etc.)
|
| 286 |
+
- Removes unnamed columns
|
| 287 |
+
- Replaces invalid values (`?`, `N/A`, `NULL`) with NaN
|
| 288 |
+
- Handles missing values (creates `_missing` indicator columns)
|
| 289 |
+
- Handles outliers (IQR capping)
|
| 290 |
+
- Removes low-variance features
|
| 291 |
+
- Removes highly correlated features
|
| 292 |
+
- Standardizes categorical values
|
| 293 |
+
|
| 294 |
+
### Step 3: Train via API
|
| 295 |
+
|
| 296 |
+
```python
|
| 297 |
+
import requests
|
| 298 |
+
import pandas as pd
|
| 299 |
+
|
| 300 |
+
# Load your data
|
| 301 |
+
df = pd.read_csv('your_data.csv')
|
| 302 |
+
data = df.to_dict(orient='records')
|
| 303 |
+
|
| 304 |
+
# Train model
|
| 305 |
+
response = requests.post(
|
| 306 |
+
'http://localhost:8000/api/train',
|
| 307 |
+
json={
|
| 308 |
+
'data': data,
|
| 309 |
+
'target_column': 'target',
|
| 310 |
+
'clean_data': True,
|
| 311 |
+
'balance_data': True,
|
| 312 |
+
'calibrate': True
|
| 313 |
+
}
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
result = response.json()
|
| 317 |
+
print(f"Model saved to: {result['model_path']}")
|
| 318 |
+
print(f"Test AUC: {result['results']['test_auc']}")
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
### Step 4: Download Model
|
| 322 |
+
|
| 323 |
+
The trained model is saved as a `.pkl` file containing:
|
| 324 |
+
- Trained sklearn pipeline
|
| 325 |
+
- Feature names
|
| 326 |
+
- Expected columns (for prediction alignment)
|
| 327 |
+
- Optimal threshold
|
| 328 |
+
- Model metadata
|
| 329 |
+
|
| 330 |
+
---
|
| 331 |
+
|
| 332 |
+
## Prediction Workflow
|
| 333 |
+
|
| 334 |
+
### Step 1: Prepare Prediction Data
|
| 335 |
+
|
| 336 |
+
Your prediction data should have the same features as training data.
|
| 337 |
+
|
| 338 |
+
**Important:**
|
| 339 |
+
- Target column is NOT required
|
| 340 |
+
- ID columns will be automatically ignored
|
| 341 |
+
- Missing columns will be filled with NaN (imputed)
|
| 342 |
+
- Extra columns will be removed
|
| 343 |
+
|
| 344 |
+
### Step 2: Make Predictions
|
| 345 |
+
|
| 346 |
+
```python
|
| 347 |
+
import requests
|
| 348 |
+
import pandas as pd
|
| 349 |
+
|
| 350 |
+
# Load prediction data
|
| 351 |
+
df = pd.read_csv('new_data.csv')
|
| 352 |
+
data = df.to_dict(orient='records')
|
| 353 |
+
|
| 354 |
+
# Predict
|
| 355 |
+
response = requests.post(
|
| 356 |
+
'http://localhost:8000/api/predict',
|
| 357 |
+
json={
|
| 358 |
+
'data': data,
|
| 359 |
+
'model_path': 'C:/path/to/model.pkl',
|
| 360 |
+
'include_proba': True
|
| 361 |
+
}
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
result = response.json()
|
| 365 |
+
for pred in result['predictions']:
|
| 366 |
+
print(f"Record {pred['index']}: {pred['risk_level']} ({pred['probability']:.2%})")
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
### Step 3: Interpret Results
|
| 370 |
+
|
| 371 |
+
| Risk Level | Probability Range | Interpretation |
|
| 372 |
+
|------------|------------------|----------------|
|
| 373 |
+
| `very_low` | 0.00 - 0.25 | Low risk of default |
|
| 374 |
+
| `low` | 0.25 - 0.50 | Below average risk |
|
| 375 |
+
| `medium` | 0.50 - 0.75 | Above average risk |
|
| 376 |
+
| `high` | 0.75 - 1.00 | High risk of default |
|
| 377 |
+
|
| 378 |
+
---
|
| 379 |
+
|
| 380 |
+
## Data Requirements
|
| 381 |
+
|
| 382 |
+
### Supported Data Types
|
| 383 |
+
|
| 384 |
+
| Type | Examples | Handling |
|
| 385 |
+
|------|----------|----------|
|
| 386 |
+
| Numeric | `age`, `income`, `score` | StandardScaler + median imputation |
|
| 387 |
+
| Categorical | `employment`, `education` | OneHotEncoder + mode imputation |
|
| 388 |
+
| Binary target | `0/1`, `yes/no`, `+/-` | Auto-converted to 0/1 |
|
| 389 |
+
|
| 390 |
+
### Columns Automatically Removed
|
| 391 |
+
|
| 392 |
+
The following columns are automatically detected and removed:
|
| 393 |
+
|
| 394 |
+
1. **ID columns** (by name pattern):
|
| 395 |
+
- `id`, `ID`, `_id`
|
| 396 |
+
- `customer_id`, `user_id`, `account_id`
|
| 397 |
+
- `loan_id`, `application_id`, `transaction_id`
|
| 398 |
+
- `index`, `idx`, `key`, `pk`
|
| 399 |
+
- `uuid`, `guid`
|
| 400 |
+
|
| 401 |
+
2. **ID columns** (by characteristics):
|
| 402 |
+
- 100% unique values (object type)
|
| 403 |
+
- Sequential integers (auto-increment pattern)
|
| 404 |
+
|
| 405 |
+
3. **Other:**
|
| 406 |
+
- Unnamed columns (`Unnamed: 0`, etc.)
|
| 407 |
+
- High-missing columns (>50% missing)
|
| 408 |
+
- Low-variance columns (<0.01 variance)
|
| 409 |
+
- Highly correlated columns (>0.95 correlation)
|
| 410 |
+
|
| 411 |
+
### Missing Value Handling
|
| 412 |
+
|
| 413 |
+
| Column Type | Missing < 5% | Missing 5-50% | Missing > 50% |
|
| 414 |
+
|-------------|--------------|---------------|---------------|
|
| 415 |
+
| Numeric | Median impute | Median impute + `_missing` flag | Drop column |
|
| 416 |
+
| Categorical | Mode impute | Mode impute + `_missing` flag | Drop column |
|
| 417 |
+
|
| 418 |
+
---
|
| 419 |
+
|
| 420 |
+
## Troubleshooting
|
| 421 |
+
|
| 422 |
+
### Common Issues
|
| 423 |
+
|
| 424 |
+
#### 1. "Model not trained" Error
|
| 425 |
+
```
|
| 426 |
+
ValueError: Model not trained. Call train() first or load a saved model.
|
| 427 |
+
```
|
| 428 |
+
**Solution:** Ensure you're passing the correct `model_path` to the predict endpoint.
|
| 429 |
+
|
| 430 |
+
#### 2. Missing Columns Warning
|
| 431 |
+
```
|
| 432 |
+
Column alignment applied: {'missing_columns': ['feature_x'], ...}
|
| 433 |
+
```
|
| 434 |
+
**This is normal.** Missing columns are automatically filled with NaN and imputed. The model will still make predictions.
|
| 435 |
+
|
| 436 |
+
#### 3. SMOTE Fails
|
| 437 |
+
```
|
| 438 |
+
Warning: SMOTE failed (...). Using random oversampling instead.
|
| 439 |
+
```
|
| 440 |
+
**This can happen when:**
|
| 441 |
+
- Minority class has too few samples (< 6)
|
| 442 |
+
- All numeric columns have NaN
|
| 443 |
+
|
| 444 |
+
**Solution:** Use `balance_method='random_oversample'` or increase data size.
|
| 445 |
+
|
| 446 |
+
#### 4. Import Errors
|
| 447 |
+
```
|
| 448 |
+
ModuleNotFoundError: No module named 'imblearn'
|
| 449 |
+
```
|
| 450 |
+
**Solution:**
|
| 451 |
+
```bash
|
| 452 |
+
pip install imbalanced-learn
|
| 453 |
+
```
|
| 454 |
+
|
| 455 |
+
#### 5. Memory Errors
|
| 456 |
+
```
|
| 457 |
+
MemoryError: Unable to allocate...
|
| 458 |
+
```
|
| 459 |
+
**Solution:**
|
| 460 |
+
- Reduce dataset size
|
| 461 |
+
- Use `clean_mode='aggressive'` to drop more columns
|
| 462 |
+
- Increase system RAM
|
| 463 |
+
|
| 464 |
+
### Debug Commands
|
| 465 |
+
|
| 466 |
+
```bash
|
| 467 |
+
# Test full pipeline
|
| 468 |
+
python debug_pipeline.py
|
| 469 |
+
|
| 470 |
+
# Check installed packages
|
| 471 |
+
pip list | grep -E "sklearn|pandas|imblearn"
|
| 472 |
+
|
| 473 |
+
# Verify API is running
|
| 474 |
+
curl http://localhost:8000/health
|
| 475 |
+
|
| 476 |
+
# Check model contents
|
| 477 |
+
python -c "import joblib; m = joblib.load('model.pkl'); print(m.keys())"
|
| 478 |
+
```
|
| 479 |
+
|
| 480 |
+
### Log Files
|
| 481 |
+
|
| 482 |
+
API logs are printed to stdout. For production, redirect to a file:
|
| 483 |
+
```bash
|
| 484 |
+
python -m uvicorn credily.api.main:app > api.log 2>&1
|
| 485 |
+
```
|
| 486 |
+
|
| 487 |
+
---
|
| 488 |
+
|
| 489 |
+
## Production Deployment
|
| 490 |
+
|
| 491 |
+
### Using Docker
|
| 492 |
+
|
| 493 |
+
```dockerfile
|
| 494 |
+
FROM python:3.11-slim
|
| 495 |
+
|
| 496 |
+
WORKDIR /app
|
| 497 |
+
COPY requirements.txt .
|
| 498 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 499 |
+
|
| 500 |
+
COPY . .
|
| 501 |
+
|
| 502 |
+
EXPOSE 8000
|
| 503 |
+
CMD ["uvicorn", "credily.api.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 504 |
+
```
|
| 505 |
+
|
| 506 |
+
Build and run:
|
| 507 |
+
```bash
|
| 508 |
+
docker build -t credily-backend .
|
| 509 |
+
docker run -p 8000:8000 credily-backend
|
| 510 |
+
```
|
| 511 |
+
|
| 512 |
+
### Using Gunicorn (Linux)
|
| 513 |
+
|
| 514 |
+
```bash
|
| 515 |
+
pip install gunicorn
|
| 516 |
+
gunicorn credily.api.main:app -w 4 -k uvicorn.workers.UvicornWorker -b 0.0.0.0:8000
|
| 517 |
+
```
|
| 518 |
+
|
| 519 |
+
### Nginx Reverse Proxy
|
| 520 |
+
|
| 521 |
+
```nginx
|
| 522 |
+
server {
|
| 523 |
+
listen 80;
|
| 524 |
+
server_name your-domain.com;
|
| 525 |
+
|
| 526 |
+
location / {
|
| 527 |
+
proxy_pass http://127.0.0.1:8000;
|
| 528 |
+
proxy_set_header Host $host;
|
| 529 |
+
proxy_set_header X-Real-IP $remote_addr;
|
| 530 |
+
}
|
| 531 |
+
}
|
| 532 |
+
```
|
| 533 |
+
|
| 534 |
+
---
|
| 535 |
+
|
| 536 |
+
## Quick Reference
|
| 537 |
+
|
| 538 |
+
### CLI Commands
|
| 539 |
+
```bash
|
| 540 |
+
# Train model
|
| 541 |
+
python -m credily.cli train --data data.csv --target target
|
| 542 |
+
|
| 543 |
+
# Predict
|
| 544 |
+
python -m credily.cli predict --model model.pkl --data new_data.csv
|
| 545 |
+
|
| 546 |
+
# Start API server
|
| 547 |
+
python -m credily.cli serve --port 8000
|
| 548 |
+
```
|
| 549 |
+
|
| 550 |
+
### Python SDK
|
| 551 |
+
```python
|
| 552 |
+
from credily.automl import CredilyPipeline
|
| 553 |
+
|
| 554 |
+
# Train
|
| 555 |
+
pipeline = CredilyPipeline(target_column='target')
|
| 556 |
+
results = pipeline.train(df)
|
| 557 |
+
|
| 558 |
+
# Load and predict
|
| 559 |
+
pipeline = CredilyPipeline.load('model.pkl')
|
| 560 |
+
predictions = pipeline.predict(new_df, include_proba=True)
|
| 561 |
+
```
|
| 562 |
+
|
| 563 |
+
---
|
| 564 |
+
|
| 565 |
+
## Support
|
| 566 |
+
|
| 567 |
+
For issues and feature requests, please check:
|
| 568 |
+
- Debug tests: `python debug_pipeline.py`
|
| 569 |
+
- API docs: http://localhost:8000/docs
|
| 570 |
+
- Swagger UI: http://localhost:8000/redoc
|
Dockerfile
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# System deps for ML + Postgres
|
| 4 |
+
RUN apt-get update && apt-get install -y \
|
| 5 |
+
build-essential \
|
| 6 |
+
gcc \
|
| 7 |
+
g++ \
|
| 8 |
+
libpq-dev \
|
| 9 |
+
curl \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 13 |
+
ENV PYTHONUNBUFFERED=1
|
| 14 |
+
|
| 15 |
+
WORKDIR /app
|
| 16 |
+
|
| 17 |
+
# Install deps first (better caching)
|
| 18 |
+
COPY requirements.txt .
|
| 19 |
+
RUN pip install --upgrade pip && pip install -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# Copy app code
|
| 22 |
+
COPY . .
|
| 23 |
+
|
| 24 |
+
# Runtime directories (mounted as volumes)
|
| 25 |
+
RUN mkdir -p credily_data credily_models credily_output
|
| 26 |
+
|
| 27 |
+
# Non-root user (important for prod)
|
| 28 |
+
RUN useradd -m credilyuser && chown -R credilyuser:credilyuser /app
|
| 29 |
+
USER credilyuser
|
| 30 |
+
|
| 31 |
+
EXPOSE 7860
|
| 32 |
+
|
| 33 |
+
CMD ["uvicorn", "credily.api.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "4"]
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Credily Backend Test
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
license: apache-2.0
|
| 9 |
+
short_description: 'A backend test to try out engine '
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
credily.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: credily
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Fast, Explainable AutoML for Tabular Data (Finance Focus)
|
| 5 |
+
Author: Your Name
|
| 6 |
+
Author-email: your.email@example.com
|
| 7 |
+
Classifier: Development Status :: 3 - Alpha
|
| 8 |
+
Classifier: Intended Audience :: Developers
|
| 9 |
+
Classifier: Intended Audience :: Financial and Insurance Industry
|
| 10 |
+
Classifier: License :: OSI Approved :: MIT License
|
| 11 |
+
Classifier: Programming Language :: Python :: 3
|
| 12 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 13 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 14 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 15 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 16 |
+
Requires-Python: >=3.8
|
| 17 |
+
Requires-Dist: pandas>=1.3.0
|
| 18 |
+
Requires-Dist: numpy>=1.21.0
|
| 19 |
+
Requires-Dist: scikit-learn>=1.0.0
|
| 20 |
+
Requires-Dist: matplotlib>=3.4.0
|
| 21 |
+
Requires-Dist: joblib>=1.0.0
|
| 22 |
+
Requires-Dist: click>=8.0.0
|
| 23 |
+
Provides-Extra: full
|
| 24 |
+
Requires-Dist: xgboost>=1.5.0; extra == "full"
|
| 25 |
+
Requires-Dist: lightgbm>=3.3.0; extra == "full"
|
| 26 |
+
Requires-Dist: imbalanced-learn>=0.9.0; extra == "full"
|
| 27 |
+
Dynamic: author
|
| 28 |
+
Dynamic: author-email
|
| 29 |
+
Dynamic: classifier
|
| 30 |
+
Dynamic: provides-extra
|
| 31 |
+
Dynamic: requires-dist
|
| 32 |
+
Dynamic: requires-python
|
| 33 |
+
Dynamic: summary
|
credily.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
setup.py
|
| 2 |
+
credily/__init__.py
|
| 3 |
+
credily/analyzer.py
|
| 4 |
+
credily/automl.py
|
| 5 |
+
credily/balancing.py
|
| 6 |
+
credily/cleaning.py
|
| 7 |
+
credily/cli.py
|
| 8 |
+
credily/metrics.py
|
| 9 |
+
credily/model.py
|
| 10 |
+
credily/preprocessing.py
|
| 11 |
+
credily/profiler.py
|
| 12 |
+
credily/reporting.py
|
| 13 |
+
credily.egg-info/PKG-INFO
|
| 14 |
+
credily.egg-info/SOURCES.txt
|
| 15 |
+
credily.egg-info/dependency_links.txt
|
| 16 |
+
credily.egg-info/entry_points.txt
|
| 17 |
+
credily.egg-info/requires.txt
|
| 18 |
+
credily.egg-info/top_level.txt
|
credily.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
credily.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
credily = credily.cli:main
|
credily.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas>=1.3.0
|
| 2 |
+
numpy>=1.21.0
|
| 3 |
+
scikit-learn>=1.0.0
|
| 4 |
+
matplotlib>=3.4.0
|
| 5 |
+
joblib>=1.0.0
|
| 6 |
+
click>=8.0.0
|
| 7 |
+
|
| 8 |
+
[full]
|
| 9 |
+
xgboost>=1.5.0
|
| 10 |
+
lightgbm>=3.3.0
|
| 11 |
+
imbalanced-learn>=0.9.0
|
credily.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
credily
|
credily/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credily - Fast, Explainable AutoML for Tabular Data (Finance Focus)
|
| 3 |
+
|
| 4 |
+
A command-line tool for automated machine learning that:
|
| 5 |
+
- Profiles your data automatically
|
| 6 |
+
- Infers the ML task type (classification/regression)
|
| 7 |
+
- Preprocesses data (imputation, scaling, encoding)
|
| 8 |
+
- Trains multiple models (Logistic Regression, Random Forest, XGBoost, LightGBM)
|
| 9 |
+
- Selects the best performer using cross-validation
|
| 10 |
+
- Exports model (.pkl) and reports (HTML/JSON)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
# Suppress joblib resource tracker warnings on Windows
|
| 18 |
+
if sys.platform == 'win32':
|
| 19 |
+
os.environ.setdefault('LOKY_PICKLER', 'pickle')
|
| 20 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='joblib')
|
| 21 |
+
|
| 22 |
+
from .automl import CredilyPipeline
|
| 23 |
+
from .profiler import DataProfiler
|
| 24 |
+
from .analyzer import BusinessAnalyzer
|
| 25 |
+
from .cleaning import DataCleaner
|
| 26 |
+
from .balancing import DataBalancer
|
| 27 |
+
from .agnostic_pipeline import AgnosticPipeline, QuickPipeline
|
| 28 |
+
|
| 29 |
+
__version__ = '0.1.0'
|
| 30 |
+
__all__ = [
|
| 31 |
+
'CredilyPipeline',
|
| 32 |
+
'DataProfiler',
|
| 33 |
+
'BusinessAnalyzer',
|
| 34 |
+
'DataCleaner',
|
| 35 |
+
'DataBalancer',
|
| 36 |
+
'AgnosticPipeline',
|
| 37 |
+
'QuickPipeline'
|
| 38 |
+
]
|
credily/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (1.41 kB). View file
|
|
|
credily/__pycache__/agnostic_pipeline.cpython-314.pyc
ADDED
|
Binary file (27.1 kB). View file
|
|
|
credily/__pycache__/analyzer.cpython-314.pyc
ADDED
|
Binary file (8.61 kB). View file
|
|
|
credily/__pycache__/automl.cpython-314.pyc
ADDED
|
Binary file (43.7 kB). View file
|
|
|
credily/__pycache__/balancing.cpython-314.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
credily/__pycache__/cleaning.cpython-314.pyc
ADDED
|
Binary file (35.2 kB). View file
|
|
|
credily/__pycache__/cli.cpython-314.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
credily/__pycache__/profiler.cpython-314.pyc
ADDED
|
Binary file (9.67 kB). View file
|
|
|
credily/__pycache__/reporting.cpython-314.pyc
ADDED
|
Binary file (11.8 kB). View file
|
|
|
credily/__pycache__/safety.cpython-314.pyc
ADDED
|
Binary file (28.1 kB). View file
|
|
|
credily/__pycache__/utils.cpython-314.pyc
ADDED
|
Binary file (8.19 kB). View file
|
|
|
credily/agnostic_pipeline.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agnostic ML Preprocessing Pipeline for Credily.
|
| 3 |
+
A flexible, data-agnostic pipeline for preprocessing any dataset for ML tasks.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sklearn.preprocessing import LabelEncoder
|
| 9 |
+
from typing import Optional, Callable, Dict, Any, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AgnosticPipeline:
    """
    Data-agnostic preprocessing pipeline for ML tasks (binary/multiclass
    classification or regression).

    Responsibilities:
    - Flexible target handling (numeric, categorical, or already binary)
    - Missing-value handling with optional indicator flags
    - Automatic encoding for categorical variables
    - Automatic dropping of high-cardinality identifier columns
    - Returns X (features) and y (target) ready for modeling

    No domain-specific assumptions are made about the input dataset.
    """

    def __init__(
        self,
        binary_threshold: Optional[float] = None,
        binary_rule: Optional[Callable[[pd.Series], int]] = None,
        positive_classes: Optional[List[str]] = None,
        task_type: str = 'binary',
        id_uniqueness_threshold: float = 0.9,
        low_cardinality_threshold: int = 20,
        flag_missing: bool = True,
        verbose: bool = True
    ):
        """
        Configure the pipeline; no data is touched until fit_transform().

        Args:
            binary_threshold: Numeric cut-off for binarizing the target;
                values strictly BELOW it map to 1 (positive/default class).
                Example: binary_threshold=600 for credit score.
            binary_rule: Row-wise callable returning 0 or 1; takes precedence
                over the other binarization options.
            positive_classes: Class labels mapped to 1; all other labels
                map to 0. Example: ['Poor', 'Standard'] -> 1, others -> 0.
            task_type: 'binary', 'multiclass', or 'regression'.
            id_uniqueness_threshold: Columns whose unique-value ratio exceeds
                this are treated as identifiers and dropped.
            low_cardinality_threshold: Max distinct values for one-hot
                encoding; higher-cardinality columns get label encoded.
            flag_missing: If True, create <col>_missing indicator columns.
            verbose: If True, print each processing step.
        """
        # User-supplied configuration
        self.binary_threshold = binary_threshold
        self.binary_rule = binary_rule
        self.positive_classes = positive_classes
        self.task_type = task_type
        self.id_uniqueness_threshold = id_uniqueness_threshold
        self.low_cardinality_threshold = low_cardinality_threshold
        self.flag_missing = flag_missing
        self.verbose = verbose

        # Fitted state, populated by fit_transform() and reused by transform()
        self.id_cols: List[str] = []          # identifier columns that were dropped
        self.num_cols: List[str] = []         # numeric feature columns
        self.cat_cols: List[str] = []         # categorical feature columns
        self.low_card_cols: List[str] = []    # columns chosen for one-hot encoding
        self.high_card_cols: List[str] = []   # columns chosen for label encoding
        self.label_encoders: Dict[str, LabelEncoder] = {}
        self.target_label_encoder: Optional[LabelEncoder] = None  # multiclass targets only
        self.numeric_medians: Dict[str, float] = {}
        self.feature_columns: List[str] = []  # final training column order
        self.class_names: List[str] = []      # class labels (classification only)
        self.n_classes: int = 2               # number of classes; 0 for regression
        self.is_fitted: bool = False
        self.processing_report: Dict[str, Any] = {}
|
| 79 |
+
|
| 80 |
+
def _log(self, message: str):
|
| 81 |
+
"""Print message if verbose mode is enabled."""
|
| 82 |
+
if self.verbose:
|
| 83 |
+
print(message)
|
| 84 |
+
|
| 85 |
+
def _detect_columns(self, df: pd.DataFrame, target_column: Optional[str] = None) -> pd.DataFrame:
|
| 86 |
+
"""
|
| 87 |
+
Detect column types and drop high-cardinality identifier columns.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
df: Input dataframe
|
| 91 |
+
target_column: Name of target column (excluded from ID detection)
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
DataFrame with ID columns removed
|
| 95 |
+
"""
|
| 96 |
+
df = df.copy()
|
| 97 |
+
n_rows = len(df)
|
| 98 |
+
nunique = df.nunique()
|
| 99 |
+
|
| 100 |
+
# Identify ID columns (>threshold unique values ratio)
|
| 101 |
+
self.id_cols = []
|
| 102 |
+
for col in df.columns:
|
| 103 |
+
if col == target_column:
|
| 104 |
+
continue
|
| 105 |
+
unique_ratio = nunique[col] / n_rows
|
| 106 |
+
if unique_ratio > self.id_uniqueness_threshold:
|
| 107 |
+
self.id_cols.append(col)
|
| 108 |
+
|
| 109 |
+
if self.id_cols:
|
| 110 |
+
df = df.drop(columns=self.id_cols, errors='ignore')
|
| 111 |
+
self._log(f" [1] Dropped {len(self.id_cols)} ID columns: {self.id_cols}")
|
| 112 |
+
|
| 113 |
+
# Detect column types (excluding target)
|
| 114 |
+
feature_df = df.drop(columns=[target_column], errors='ignore') if target_column else df
|
| 115 |
+
self.num_cols = feature_df.select_dtypes(include=['number']).columns.tolist()
|
| 116 |
+
self.cat_cols = feature_df.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 117 |
+
|
| 118 |
+
self._log(f" [2] Detected {len(self.num_cols)} numeric columns")
|
| 119 |
+
self._log(f" [3] Detected {len(self.cat_cols)} categorical columns")
|
| 120 |
+
|
| 121 |
+
return df
|
| 122 |
+
|
| 123 |
+
    def _handle_target(
        self,
        df: pd.DataFrame,
        target_column: str
    ) -> Tuple[pd.DataFrame, pd.Series]:
        """
        Create y (target) based on task_type and user settings.

        For binary tasks the resolution order is: binary_rule, then
        binary_threshold, then positive_classes, then auto-mapping of an
        already-binary target; any other case raises ValueError.

        Args:
            df: Input dataframe with target column
            target_column: Name of target column

        Returns:
            Tuple of (X features dataframe, y target series)

        Raises:
            ValueError: If task_type is unknown, or a binary task's target
                cannot be binarized with the provided settings.
        """
        df = df.copy()

        if self.task_type == 'regression':
            # For regression, keep target as-is (must be numeric)
            y = df[target_column].astype(float)
            self._log(f" [4] Target '{target_column}' kept as numeric for regression")
            self.n_classes = 0  # Regression has no classes

        elif self.task_type == 'multiclass':
            # Multi-class classification: encode categorical target to integers
            unique_vals = df[target_column].dropna().unique()
            self.n_classes = len(unique_vals)
            self.class_names = sorted([str(v) for v in unique_vals])

            # Create label encoder for target (values stringified first so
            # mixed-type targets encode consistently)
            self.target_label_encoder = LabelEncoder()
            y = pd.Series(
                self.target_label_encoder.fit_transform(df[target_column].astype(str)),
                index=df.index
            )

            label_map = dict(zip(self.target_label_encoder.classes_, range(len(self.target_label_encoder.classes_))))
            self._log(f" [4] Target '{target_column}' encoded for multiclass: {label_map}")
            self._log(f" Number of classes: {self.n_classes}")

        elif self.task_type == 'binary':
            self.n_classes = 2

            if self.binary_rule is not None:
                # Apply custom rule (function takes row, returns 0 or 1)
                y = df.apply(self.binary_rule, axis=1)
                self._log(f" [4] Target created using custom binary_rule")

            elif self.binary_threshold is not None:
                # Apply threshold: values BELOW threshold = 1 (positive/default)
                y = df[target_column].apply(lambda x: 1 if x < self.binary_threshold else 0)
                self._log(f" [4] Target '{target_column}' binarized: < {self.binary_threshold} → 1 (positive)")

            elif self.positive_classes is not None:
                # Binary grouping: specified classes become positive (1), others become negative (0)
                y = df[target_column].apply(
                    lambda x: 1 if str(x) in [str(c) for c in self.positive_classes] else 0
                )
                self._log(f" [4] Target '{target_column}' grouped: {self.positive_classes} → 1 (positive), others → 0")

            else:
                # Check if already binary
                unique_vals = df[target_column].dropna().unique()
                if len(unique_vals) == 2:
                    # Auto-convert to 0/1; lexicographic order of str(value)
                    # decides which label becomes the positive class
                    sorted_vals = sorted(unique_vals, key=lambda x: str(x))
                    label_map = {sorted_vals[0]: 0, sorted_vals[1]: 1}
                    y = df[target_column].map(label_map)
                    self.class_names = [str(sorted_vals[0]), str(sorted_vals[1])]
                    self._log(f" [4] Target '{target_column}' auto-mapped: {label_map}")
                else:
                    raise ValueError(
                        f"For binary task with non-binary target, provide binary_threshold, positive_classes, or binary_rule. "
                        f"Target has {len(unique_vals)} unique values: {list(unique_vals)[:5]}"
                    )
        else:
            raise ValueError(f"task_type must be 'binary', 'multiclass', or 'regression', got '{self.task_type}'")

        X = df.drop(columns=[target_column])

        # Log class distribution
        if self.task_type in ['binary', 'multiclass']:
            class_counts = y.value_counts()
            self._log(f" Class distribution: {dict(class_counts)}")

        return X, y
|
| 209 |
+
|
| 210 |
+
    def _preprocess_features(self, df: pd.DataFrame, fit: bool = True) -> pd.DataFrame:
        """
        Handle missing values and encode categorical variables.

        Numeric columns are median-imputed; categorical columns are filled
        with the "Missing" placeholder, then one-hot encoded (low
        cardinality) or label encoded (high cardinality). When
        ``flag_missing`` is set, a <col>_missing indicator is added for
        every column that had any NaNs.

        Args:
            df: Features dataframe (X)
            fit: Whether to fit encoders/medians (True for training,
                False for inference)

        Returns:
            Preprocessed dataframe
        """
        df = df.copy()
        missing_flags_created = []

        # ===== NUMERIC COLUMNS =====
        for col in self.num_cols:
            if col not in df.columns:
                continue

            # Create missing flag (only when the column actually has NaNs)
            if self.flag_missing:
                missing_count = df[col].isna().sum()
                if missing_count > 0:
                    df[col + "_missing"] = df[col].isna().astype(int)
                    missing_flags_created.append(col + "_missing")

            # Fill missing with median learned at fit time; falls back to 0
            # for columns never seen during fit
            if fit:
                self.numeric_medians[col] = df[col].median()
            median_val = self.numeric_medians.get(col, 0)
            df[col] = df[col].fillna(median_val)

        if missing_flags_created:
            self._log(f" [5] Created {len(missing_flags_created)} numeric missing flags")

        # ===== CATEGORICAL COLUMNS =====
        cat_missing_flags = []
        for col in self.cat_cols:
            if col not in df.columns:
                continue

            # Create missing flag
            if self.flag_missing:
                missing_count = df[col].isna().sum()
                if missing_count > 0:
                    df[col + "_missing"] = df[col].isna().astype(int)
                    cat_missing_flags.append(col + "_missing")

            # Fill missing with "Missing" placeholder
            df[col] = df[col].fillna("Missing")

        if cat_missing_flags:
            self._log(f" [6] Created {len(cat_missing_flags)} categorical missing flags")

        # ===== ENCODING =====
        if fit:
            # Determine low vs high cardinality (decided once, at fit time)
            self.low_card_cols = [c for c in self.cat_cols if c in df.columns and df[c].nunique() <= self.low_cardinality_threshold]
            self.high_card_cols = [c for c in self.cat_cols if c in df.columns and df[c].nunique() > self.low_cardinality_threshold]

        # One-hot encode low-cardinality columns.
        # NOTE(review): at inference get_dummies may emit a different dummy
        # set than training; transform() re-aligns columns afterwards,
        # filling absent ones with 0 — confirm this covers all callers.
        if self.low_card_cols:
            existing_low_card = [c for c in self.low_card_cols if c in df.columns]
            if existing_low_card:
                df = pd.get_dummies(df, columns=existing_low_card, drop_first=True, dtype=int)
                self._log(f" [7] One-hot encoded {len(existing_low_card)} low-cardinality columns")

        # Label encode high-cardinality columns
        if self.high_card_cols:
            for col in self.high_card_cols:
                if col not in df.columns:
                    continue
                if fit:
                    le = LabelEncoder()
                    # Fit on all values including "Missing"
                    df[col] = le.fit_transform(df[col].astype(str))
                    self.label_encoders[col] = le
                else:
                    le = self.label_encoders.get(col)
                    if le:
                        # Handle unseen categories (mapped to -1)
                        df[col] = df[col].astype(str).apply(
                            lambda x: le.transform([x])[0] if x in le.classes_ else -1
                        )
        if self.high_card_cols:
            self._log(f" [8] Label encoded {len(self.high_card_cols)} high-cardinality columns")

        return df
|
| 298 |
+
|
| 299 |
+
def fit_transform(
|
| 300 |
+
self,
|
| 301 |
+
df: pd.DataFrame,
|
| 302 |
+
target_column: str
|
| 303 |
+
) -> Tuple[pd.DataFrame, pd.Series]:
|
| 304 |
+
"""
|
| 305 |
+
Main method to preprocess the dataset for training.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
df: Raw dataframe with features and target
|
| 309 |
+
target_column: Name of the target column
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
Tuple of (X features, y target) ready for ML modeling
|
| 313 |
+
"""
|
| 314 |
+
self._log(f"\n{'='*60}")
|
| 315 |
+
self._log("AGNOSTIC PIPELINE - FIT TRANSFORM")
|
| 316 |
+
self._log(f"{'='*60}")
|
| 317 |
+
self._log(f"Input shape: {df.shape[0]} rows, {df.shape[1]} columns")
|
| 318 |
+
|
| 319 |
+
# Reset state
|
| 320 |
+
self.is_fitted = False
|
| 321 |
+
self.processing_report = {'input_shape': df.shape}
|
| 322 |
+
|
| 323 |
+
# Step 1: Detect columns & drop IDs
|
| 324 |
+
df = self._detect_columns(df, target_column)
|
| 325 |
+
|
| 326 |
+
# Step 2: Handle target
|
| 327 |
+
X, y = self._handle_target(df, target_column)
|
| 328 |
+
|
| 329 |
+
# Step 3: Preprocess features
|
| 330 |
+
X = self._preprocess_features(X, fit=True)
|
| 331 |
+
|
| 332 |
+
# Store feature columns for transform
|
| 333 |
+
self.feature_columns = X.columns.tolist()
|
| 334 |
+
self.is_fitted = True
|
| 335 |
+
|
| 336 |
+
# Summary
|
| 337 |
+
if self.task_type == 'regression':
|
| 338 |
+
target_info = {'mean': float(y.mean()), 'std': float(y.std())}
|
| 339 |
+
else:
|
| 340 |
+
target_info = y.value_counts().to_dict()
|
| 341 |
+
|
| 342 |
+
self.processing_report.update({
|
| 343 |
+
'output_shape': X.shape,
|
| 344 |
+
'task_type': self.task_type,
|
| 345 |
+
'n_classes': self.n_classes,
|
| 346 |
+
'class_names': self.class_names if self.class_names else None,
|
| 347 |
+
'target_distribution': target_info,
|
| 348 |
+
'positive_classes': self.positive_classes,
|
| 349 |
+
'binary_threshold': self.binary_threshold,
|
| 350 |
+
'id_columns_dropped': self.id_cols,
|
| 351 |
+
'numeric_columns': self.num_cols,
|
| 352 |
+
'categorical_columns': self.cat_cols,
|
| 353 |
+
'low_cardinality_encoded': self.low_card_cols,
|
| 354 |
+
'high_cardinality_encoded': self.high_card_cols
|
| 355 |
+
})
|
| 356 |
+
|
| 357 |
+
self._log(f"\n{'='*60}")
|
| 358 |
+
self._log("PREPROCESSING COMPLETE")
|
| 359 |
+
self._log(f" Output shape: {X.shape[0]} rows, {X.shape[1]} features")
|
| 360 |
+
self._log(f"{'='*60}\n")
|
| 361 |
+
|
| 362 |
+
return X, y
|
| 363 |
+
|
| 364 |
+
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
| 365 |
+
"""
|
| 366 |
+
Transform new data using fitted preprocessing.
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
df: New dataframe with features only (no target)
|
| 370 |
+
|
| 371 |
+
Returns:
|
| 372 |
+
Preprocessed features dataframe
|
| 373 |
+
"""
|
| 374 |
+
if not self.is_fitted:
|
| 375 |
+
raise ValueError("Pipeline not fitted. Call fit_transform() first.")
|
| 376 |
+
|
| 377 |
+
df = df.copy()
|
| 378 |
+
|
| 379 |
+
# Drop ID columns
|
| 380 |
+
df = df.drop(columns=self.id_cols, errors='ignore')
|
| 381 |
+
|
| 382 |
+
# Preprocess features
|
| 383 |
+
df = self._preprocess_features(df, fit=False)
|
| 384 |
+
|
| 385 |
+
# Align columns with training
|
| 386 |
+
for col in self.feature_columns:
|
| 387 |
+
if col not in df.columns:
|
| 388 |
+
df[col] = 0 # Default value for missing columns
|
| 389 |
+
|
| 390 |
+
# Keep only expected columns in correct order
|
| 391 |
+
df = df[self.feature_columns]
|
| 392 |
+
|
| 393 |
+
return df
|
| 394 |
+
|
| 395 |
+
def get_report(self) -> Dict[str, Any]:
|
| 396 |
+
"""Get the processing report."""
|
| 397 |
+
return self.processing_report
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class QuickPipeline:
    """
    Convenience wrapper chaining DataCleaner with AgnosticPipeline so a
    raw dataframe can be cleaned and preprocessed in a single call.
    """

    def __init__(
        self,
        target_column: str,
        binary_threshold: Optional[float] = None,
        binary_rule: Optional[Callable[[pd.Series], int]] = None,
        task_type: str = 'binary',
        clean_data: bool = True,
        clean_mode: str = 'thorough',
        verbose: bool = True
    ):
        """
        Configure the combined cleaning + preprocessing pipeline.

        Args:
            target_column: Name of target column
            binary_threshold: Threshold for binary classification
            binary_rule: Custom rule for binary classification
            task_type: 'binary' or 'regression'
            clean_data: Whether to apply DataCleaner first
            clean_mode: Cleaning mode ('basic', 'thorough', 'aggressive')
            verbose: Print processing steps
        """
        self.target_column = target_column
        self.clean_data = clean_data
        self.clean_mode = clean_mode
        self.verbose = verbose

        # Delegate the actual preprocessing to an AgnosticPipeline instance.
        self.agnostic_pipeline = AgnosticPipeline(
            binary_threshold=binary_threshold,
            binary_rule=binary_rule,
            task_type=task_type,
            verbose=verbose
        )

        # Populated lazily by fit_transform() when clean_data is True.
        self.cleaner = None
        self.cleaning_report = None

    def fit_transform(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.Series]:
        """
        Run the full chain: optional cleaning, then agnostic preprocessing.

        Args:
            df: Raw dataframe

        Returns:
            Tuple of (X features, y target)
        """
        if self.clean_data:
            # Deferred import: the cleaner is only needed on this path.
            from .cleaning import DataCleaner
            self.cleaner = DataCleaner(
                target_column=self.target_column,
                clean_mode=self.clean_mode
            )
            df = self.cleaner.clean(df, verbose=self.verbose)
            self.cleaning_report = self.cleaner.get_report()

        X, y = self.agnostic_pipeline.fit_transform(df, self.target_column)
        return X, y

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply the fitted agnostic pipeline to new feature data."""
        return self.agnostic_pipeline.transform(df)

    def get_full_report(self) -> Dict[str, Any]:
        """Return reports from both the cleaning and preprocessing stages."""
        return {
            'cleaning': self.cleaning_report,
            'preprocessing': self.agnostic_pipeline.get_report()
        }
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# ============================================================
# Example usage
# ============================================================

if __name__ == "__main__":
    # Example 1: Basic usage with credit score threshold
    print("=" * 70)
    print("EXAMPLE 1: Credit Score Binary Classification")
    print("=" * 70)

    # Create sample data (unseeded np.random: output varies between runs)
    sample_data = pd.DataFrame({
        'customer_id': range(1000, 1100),  # Will be dropped as ID
        'age': np.random.randint(18, 70, 100),
        'income': np.random.randint(20000, 150000, 100),
        'employment_type': np.random.choice(['Employed', 'Self-Employed', 'Unemployed', None], 100),
        'loan_amount': np.random.randint(5000, 50000, 100),
        'credit_score': np.random.randint(400, 850, 100)  # Target
    })

    # Initialize pipeline: credit score < 600 = default (1)
    pipeline = AgnosticPipeline(
        binary_threshold=600,
        task_type='binary'
    )

    # Preprocess
    X, y = pipeline.fit_transform(sample_data, target_column='credit_score')

    print(f"\nFeatures shape: {X.shape}")
    print(f"Target distribution:\n{y.value_counts()}")
    print(f"\nFeature columns: {X.columns.tolist()}")

    # Example 2: Custom binary rule
    print("\n" + "=" * 70)
    print("EXAMPLE 2: Custom Binary Rule")
    print("=" * 70)

    # Custom rule: default if score < 600 AND income < 50000
    custom_rule = lambda row: 1 if (row['credit_score'] < 600 and row['income'] < 50000) else 0

    pipeline2 = AgnosticPipeline(
        binary_rule=custom_rule,
        task_type='binary'
    )

    X2, y2 = pipeline2.fit_transform(sample_data, target_column='credit_score')
    print(f"\nTarget distribution with custom rule:\n{y2.value_counts()}")

    # Example 3: Regression task (target kept numeric)
    print("\n" + "=" * 70)
    print("EXAMPLE 3: Regression Task")
    print("=" * 70)

    pipeline3 = AgnosticPipeline(task_type='regression')
    X3, y3 = pipeline3.fit_transform(sample_data, target_column='credit_score')

    print(f"\nTarget stats: mean={y3.mean():.2f}, std={y3.std():.2f}")
|
credily/analyzer.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Business context analyzer for TabulaML.
|
| 3 |
+
Analyzes model performance in finance-specific contexts.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from typing import Dict, Any, List
|
| 9 |
+
from sklearn.metrics import precision_score, recall_score, confusion_matrix
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BusinessAnalyzer:
    """
    Analyzes model performance in business contexts.
    Optimized for finance use cases like credit scoring, fraud detection, etc.

    Each context assigns a cost/revenue to the four confusion-matrix cells;
    `analyze` sweeps decision thresholds and reports the profit-maximizing one.
    """

    # Cost/revenue assumptions per context, in arbitrary currency units.
    # NOTE(review): these are illustrative defaults, not calibrated figures.
    CONTEXTS = {
        'credit_scoring': {
            'description': 'Loan default prediction',
            'positive_label': 'default',
            'cost_fp': 100,     # Cost of rejecting good customer (lost revenue)
            'cost_fn': 1000,    # Cost of approving bad customer (default loss)
            'revenue_tp': 50,   # Revenue from correctly rejecting bad customer
            'revenue_tn': 200,  # Revenue from correctly approving good customer
        },
        'fraud_detection': {
            'description': 'Transaction fraud detection',
            'positive_label': 'fraud',
            'cost_fp': 10,      # Cost of blocking legitimate transaction
            'cost_fn': 500,     # Cost of missing fraud
            'revenue_tp': 500,  # Savings from catching fraud
            'revenue_tn': 0,    # Normal transaction
        },
        'churn_prediction': {
            'description': 'Customer churn prediction',
            'positive_label': 'churn',
            'cost_fp': 50,      # Cost of unnecessary retention effort
            'cost_fn': 300,     # Cost of losing customer
            'revenue_tp': 250,  # Value of retained customer
            'revenue_tn': 0,    # No action needed
        },
        'insurance_claims': {
            'description': 'Insurance claims prediction',
            'positive_label': 'claim',
            'cost_fp': 20,      # Cost of extra investigation
            'cost_fn': 1000,    # Cost of missing fraudulent claim
            'revenue_tp': 800,  # Savings from detecting bad claim
            'revenue_tn': 0,    # Normal claim processing
        },
        'collections': {
            'description': 'Debt collection prioritization',
            'positive_label': 'will_pay',
            'cost_fp': 30,      # Cost of unnecessary collection effort
            'cost_fn': 200,     # Lost recovery
            'revenue_tp': 150,  # Successful recovery
            'revenue_tn': 0,    # No action
        },
    }

    def __init__(self, context: str = 'credit_scoring'):
        """
        Args:
            context: One of the keys of CONTEXTS.

        Raises:
            ValueError: If `context` is not a known business context.
        """
        if context not in self.CONTEXTS:
            raise ValueError(f"Unknown context: {context}. Available: {list(self.CONTEXTS.keys())}")
        self.context = context
        self.config = self.CONTEXTS[context]

    def analyze(
        self,
        pipeline,
        df: pd.DataFrame,
        target_column: str
    ) -> Dict[str, Any]:
        """
        Analyze model performance in business context.

        Args:
            pipeline: Trained TabulaMLPipeline; must expose
                `best_model.predict_proba` (class 1 = positive business event).
            df: Test dataframe with features and target
            target_column: Name of target column (binary, encoded 0/1)

        Returns:
            dict: Business analysis report (optimal threshold, expected
                profit, risk exposure, confusion matrix, per-threshold
                diagnostics and recommendations).
        """
        X = df.drop(columns=[target_column])
        y_true = df[target_column].values

        # Probability of the positive business event (default/fraud/churn...).
        y_proba = pipeline.best_model.predict_proba(X)[:, 1]

        # Sweep candidate thresholds and keep the most profitable one.
        thresholds = np.arange(0.1, 0.9, 0.05)
        best_threshold = 0.5
        best_profit = float('-inf')

        threshold_analysis = []
        for thresh in thresholds:
            y_pred = (y_proba >= thresh).astype(int)
            profit = self._calculate_profit(y_true, y_pred)
            threshold_analysis.append({
                'threshold': float(thresh),  # plain float keeps report JSON-serializable
                'profit': profit,
                'precision': precision_score(y_true, y_pred, zero_division=0),
                'recall': recall_score(y_true, y_pred, zero_division=0)
            })
            if profit > best_profit:
                best_profit = profit
                best_threshold = float(thresh)

        # Metrics at the optimal threshold. Pinning labels=[0, 1] keeps the
        # matrix 2x2 even when one class is absent from y_pred (or y_true);
        # otherwise ravel() would not unpack into four values.
        y_pred_optimal = (y_proba >= best_threshold).astype(int)
        cm = confusion_matrix(y_true, y_pred_optimal, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        # Financial calculations
        expected_profit = self._calculate_profit(y_true, y_pred_optimal)
        risk_exposure = fn * self.config['cost_fn']

        # Compute once and reuse (previously recomputed three times each).
        precision = precision_score(y_true, y_pred_optimal, zero_division=0)
        recall = recall_score(y_true, y_pred_optimal, zero_division=0)

        recommendations = self._generate_recommendations(precision, recall, best_threshold)

        return {
            'context': self.context,
            'context_description': self.config['description'],
            'optimal_threshold': best_threshold,
            'expected_profit': expected_profit,
            'risk_exposure': risk_exposure,
            'precision': precision,
            'recall': recall,
            'confusion_matrix': {
                'true_negatives': int(tn),
                'false_positives': int(fp),
                'false_negatives': int(fn),
                'true_positives': int(tp)
            },
            'threshold_analysis': threshold_analysis,
            'recommendations': recommendations
        }

    def _calculate_profit(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate expected profit based on confusion matrix.

        Labels are pinned to [0, 1] so the matrix is always 2x2 even when a
        class is missing from the predictions (e.g. extreme thresholds).
        """
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        # Profit = revenue from correct decisions minus cost of mistakes.
        return (
            tp * self.config['revenue_tp'] +
            tn * self.config['revenue_tn'] -
            fp * self.config['cost_fp'] -
            fn * self.config['cost_fn']
        )

    def _generate_recommendations(
        self,
        precision: float,
        recall: float,
        threshold: float
    ) -> List[str]:
        """Generate business recommendations based on metrics.

        The 0.5-0.8 cut-offs below are per-context heuristics, not tuned values.
        """
        recommendations = []

        if self.context == 'credit_scoring':
            if precision < 0.7:
                recommendations.append(
                    "Low precision: Consider raising the approval threshold to reduce bad debt"
                )
            if recall < 0.6:
                recommendations.append(
                    "Low recall: Many defaulters are being approved. Review underwriting criteria"
                )
            if threshold > 0.6:
                recommendations.append(
                    f"High threshold ({threshold:.2f}): May be rejecting too many good applicants"
                )

        elif self.context == 'fraud_detection':
            if recall < 0.8:
                recommendations.append(
                    "Critical: Low fraud detection rate. Lower threshold or add features"
                )
            if precision < 0.5:
                recommendations.append(
                    "High false positive rate causing customer friction. Review flagging rules"
                )

        elif self.context == 'churn_prediction':
            if recall < 0.7:
                recommendations.append(
                    "Missing too many churners. Expand retention campaigns"
                )
            if precision < 0.5:
                recommendations.append(
                    "Retention budget being wasted on non-churners. Refine targeting"
                )

        # General recommendations
        if precision > 0.8 and recall > 0.8:
            recommendations.append(
                "Model performing well. Consider A/B testing in production"
            )

        if not recommendations:
            recommendations.append(
                "Model metrics are within acceptable range for this context"
            )

        return recommendations

    @classmethod
    def list_contexts(cls) -> Dict[str, str]:
        """List all available business contexts."""
        return {k: v['description'] for k, v in cls.CONTEXTS.items()}
|
credily/api/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Credily FastAPI REST API module.
|
| 3 |
+
Exposes ML functionality as HTTP endpoints.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .main import app
|
| 7 |
+
|
| 8 |
+
__all__ = ['app']
|
credily/api/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (335 Bytes). View file
|
|
|
credily/api/__pycache__/database.cpython-314.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
credily/api/__pycache__/errors.cpython-314.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
credily/api/__pycache__/main.cpython-314.pyc
ADDED
|
Binary file (48.1 kB). View file
|
|
|
credily/api/__pycache__/schemas.cpython-314.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
credily/api/database.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Database module for storing prediction history and reports.
|
| 3 |
+
Supports both SQLite (development) and PostgreSQL (production).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional, List, Dict, Any
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
from urllib.parse import urlparse
|
| 13 |
+
|
| 14 |
+
# Database configuration.
# DATABASE_URL selects the backend: a postgresql:// URL enables PostgreSQL,
# anything else (or unset) falls back to a local SQLite file.
DATABASE_URL = os.environ.get("DATABASE_URL")

# Determine database type
if DATABASE_URL and DATABASE_URL.startswith("postgresql"):
    DB_TYPE = "postgresql"
    try:
        import psycopg2
        from psycopg2.extras import RealDictCursor  # NOTE(review): imported but not used below — confirm
    except ImportError:
        raise ImportError(
            "psycopg2 is required for PostgreSQL. Install with: pip install psycopg2-binary"
        )
else:
    DB_TYPE = "sqlite"
    import sqlite3
    # DB_PATH is only defined on the SQLite branch; PostgreSQL code paths
    # must not reference it.
    DB_PATH = Path("credily_data") / "credily.db"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def init_db():
    """Initialize the database with required tables.

    Creates three tables if they do not already exist:
      * prediction_sessions -- one row per batch-prediction run
      * prediction_results  -- one row per scored record (FK to sessions)
      * training_reports    -- one row per completed training run

    The DDL is written twice because the two engines differ in
    auto-increment syntax (SERIAL vs INTEGER ... AUTOINCREMENT) and float
    column types (DOUBLE PRECISION vs REAL). Structured payloads
    (model_scores, confusion_matrix, input_data, ...) are stored as JSON
    text columns.
    """
    if DB_TYPE == "sqlite":
        # Ensure the data directory exists before sqlite3 opens the file.
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)

    with get_db() as conn:
        cursor = conn.cursor()

        if DB_TYPE == "postgresql":
            # PostgreSQL syntax
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS prediction_sessions (
                    id TEXT PRIMARY KEY,
                    model_path TEXT NOT NULL,
                    model_name TEXT,
                    threshold_used DOUBLE PRECISION,
                    total_records INTEGER,
                    predicted_positive INTEGER,
                    predicted_negative INTEGER,
                    positive_rate DOUBLE PRECISION,
                    avg_probability DOUBLE PRECISION,
                    created_at TEXT NOT NULL
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS prediction_results (
                    id SERIAL PRIMARY KEY,
                    session_id TEXT NOT NULL,
                    record_index INTEGER,
                    prediction INTEGER,
                    probability DOUBLE PRECISION,
                    risk_level TEXT,
                    input_data TEXT,
                    FOREIGN KEY (session_id) REFERENCES prediction_sessions(id) ON DELETE CASCADE
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS training_reports (
                    id TEXT PRIMARY KEY,
                    model_name TEXT,
                    best_model TEXT,
                    best_score DOUBLE PRECISION,
                    test_auc DOUBLE PRECISION,
                    test_pr_auc DOUBLE PRECISION,
                    optimal_threshold DOUBLE PRECISION,
                    model_scores TEXT,
                    classification_report TEXT,
                    confusion_matrix TEXT,
                    feature_importances TEXT,
                    cleaning_report TEXT,
                    balancing_report TEXT,
                    created_at TEXT NOT NULL
                )
            """)
        else:
            # SQLite syntax
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS prediction_sessions (
                    id TEXT PRIMARY KEY,
                    model_path TEXT NOT NULL,
                    model_name TEXT,
                    threshold_used REAL,
                    total_records INTEGER,
                    predicted_positive INTEGER,
                    predicted_negative INTEGER,
                    positive_rate REAL,
                    avg_probability REAL,
                    created_at TEXT NOT NULL
                )
            """)

            # NOTE: no ON DELETE CASCADE here — delete_prediction_session()
            # removes child rows explicitly before deleting the session.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS prediction_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    record_index INTEGER,
                    prediction INTEGER,
                    probability REAL,
                    risk_level TEXT,
                    input_data TEXT,
                    FOREIGN KEY (session_id) REFERENCES prediction_sessions(id)
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS training_reports (
                    id TEXT PRIMARY KEY,
                    model_name TEXT,
                    best_model TEXT,
                    best_score REAL,
                    test_auc REAL,
                    test_pr_auc REAL,
                    optimal_threshold REAL,
                    model_scores TEXT,
                    classification_report TEXT,
                    confusion_matrix TEXT,
                    feature_importances TEXT,
                    cleaning_report TEXT,
                    balancing_report TEXT,
                    created_at TEXT NOT NULL
                )
            """)

        conn.commit()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@contextmanager
def get_db():
    """Yield a database connection; it is always closed on exit."""
    if DB_TYPE == "postgresql":
        connection = psycopg2.connect(DATABASE_URL)
        # Explicit transactions: callers commit themselves.
        connection.autocommit = False
    else:
        connection = sqlite3.connect(str(DB_PATH))
        # Row factory lets rows behave like mappings for _dict_from_row().
        connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _dict_from_row(row, cursor_description=None):
    """Convert a database row to a plain dictionary."""
    if DB_TYPE != "postgresql":
        # sqlite3.Row supports dict() directly thanks to the row_factory.
        return dict(row)
    if cursor_description:
        column_names = [column[0] for column in cursor_description]
        return dict(zip(column_names, row))
    return dict(row) if hasattr(row, 'keys') else row
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _p(query: str) -> str:
    """Convert SQLite-style ? placeholders to PostgreSQL %s if needed."""
    return query.replace("?", "%s") if DB_TYPE == "postgresql" else query
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ============== Prediction Sessions ==============
|
| 179 |
+
|
| 180 |
+
def save_prediction_session(
    session_id: str,
    model_path: str,
    model_name: Optional[str],
    threshold_used: float,
    total_records: int,
    predicted_positive: int,
    predicted_negative: int,
    positive_rate: float,
    avg_probability: Optional[float] = None
) -> str:
    """Persist one batch-prediction session row and return its id."""
    row = (
        session_id, model_path, model_name, threshold_used, total_records,
        predicted_positive, predicted_negative, positive_rate, avg_probability,
        datetime.now().isoformat()
    )
    with get_db() as conn:
        conn.cursor().execute(_p("""
            INSERT INTO prediction_sessions
            (id, model_path, model_name, threshold_used, total_records,
             predicted_positive, predicted_negative, positive_rate, avg_probability, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """), row)
        conn.commit()
    return session_id
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def save_prediction_results(session_id: str, results: List[Dict[str, Any]]):
    """Save individual prediction results for a session.

    Args:
        session_id: Parent prediction_sessions.id the rows belong to.
        results: Dicts with optional keys 'index', 'prediction',
            'probability', 'risk_level' and 'input_data' (stored as JSON text).
    """
    rows = [
        (
            session_id,
            result.get('index'),
            result.get('prediction'),
            result.get('probability'),
            result.get('risk_level'),
            json.dumps(result.get('input_data', {}))
        )
        for result in results
    ]
    with get_db() as conn:
        cursor = conn.cursor()
        # executemany lets the driver batch the inserts instead of one
        # Python-level execute() round-trip per row (original looped).
        cursor.executemany(_p("""
            INSERT INTO prediction_results
            (session_id, record_index, prediction, probability, risk_level, input_data)
            VALUES (?, ?, ?, ?, ?, ?)
        """), rows)
        conn.commit()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_prediction_sessions(limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
    """Return prediction sessions, newest first, with paging."""
    query = _p("""
        SELECT * FROM prediction_sessions
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?
    """)
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (limit, offset))
        return [_dict_from_row(r, cursor.description) for r in cursor.fetchall()]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def get_prediction_session(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single prediction session by id, or None if absent."""
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(_p("SELECT * FROM prediction_sessions WHERE id = ?"), (session_id,))
        row = cursor.fetchone()
        return _dict_from_row(row, cursor.description) if row else None
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def get_prediction_results(session_id: str) -> List[Dict[str, Any]]:
    """Return all per-record results for a session, ordered by record_index."""
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(_p("""
            SELECT * FROM prediction_results
            WHERE session_id = ?
            ORDER BY record_index
        """), (session_id,))
        decoded = []
        for raw in cursor.fetchall():
            record = _dict_from_row(raw, cursor.description)
            # input_data is stored as a JSON string; decode it for callers.
            if record.get('input_data'):
                record['input_data'] = json.loads(record['input_data'])
            decoded.append(record)
    return decoded
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def delete_prediction_session(session_id: str) -> bool:
    """Delete a prediction session and its results.

    Returns True if the session row itself was removed.
    """
    with get_db() as conn:
        cursor = conn.cursor()
        # Child rows first, then the parent (SQLite schema has no cascade).
        for statement in (
            "DELETE FROM prediction_results WHERE session_id = ?",
            "DELETE FROM prediction_sessions WHERE id = ?",
        ):
            cursor.execute(_p(statement), (session_id,))
        conn.commit()
        # rowcount reflects the last statement, i.e. the session delete.
        return cursor.rowcount > 0
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ============== Training Reports ==============
|
| 282 |
+
|
| 283 |
+
def save_training_report(
    report_id: str,
    model_name: str,
    results: Dict[str, Any]
) -> str:
    """Persist one training report row and return its id.

    Structured sub-reports are serialized to JSON text columns.
    """
    params = (
        report_id,
        model_name,
        results.get('best_model'),
        results.get('best_score'),
        results.get('test_auc'),
        results.get('test_pr_auc'),
        results.get('optimal_threshold'),
        json.dumps(results.get('model_scores', {})),
        json.dumps(results.get('classification_report', {})),
        json.dumps(results.get('confusion_matrix', [])),
        json.dumps(results.get('feature_importances', {})),
        json.dumps(results.get('cleaning_report')),
        json.dumps(results.get('balancing_report')),
        datetime.now().isoformat()
    )
    with get_db() as conn:
        conn.cursor().execute(_p("""
            INSERT INTO training_reports
            (id, model_name, best_model, best_score, test_auc, test_pr_auc,
             optimal_threshold, model_scores, classification_report, confusion_matrix,
             feature_importances, cleaning_report, balancing_report, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """), params)
        conn.commit()
    return report_id
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def get_training_reports(limit: int = 50, offset: int = 0) -> List[Dict[str, Any]]:
    """Return training report summaries (heavy JSON columns omitted), newest first."""
    summary_query = _p("""
        SELECT id, model_name, best_model, best_score, test_auc,
               test_pr_auc, optimal_threshold, created_at
        FROM training_reports
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?
    """)
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(summary_query, (limit, offset))
        return [_dict_from_row(r, cursor.description) for r in cursor.fetchall()]
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def get_training_report(report_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a full training report by id, decoding its JSON columns."""
    json_fields = ('model_scores', 'classification_report', 'confusion_matrix',
                   'feature_importances', 'cleaning_report', 'balancing_report')
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(_p("SELECT * FROM training_reports WHERE id = ?"), (report_id,))
        row = cursor.fetchone()
        if row is None:
            return None
        report = _dict_from_row(row, cursor.description)
        for field in json_fields:
            if report.get(field):
                report[field] = json.loads(report[field])
        return report
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def delete_training_report(report_id: str) -> bool:
    """Delete a training report; True if a row was actually removed."""
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute(_p("DELETE FROM training_reports WHERE id = ?"), (report_id,))
        conn.commit()
        removed = cursor.rowcount > 0
    return removed
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def get_db_info() -> Dict[str, Any]:
    """Get database connection info (for debugging/health checks).

    The connection URL may embed credentials
    (e.g. postgresql://user:password@host/db), so the password component is
    masked before being returned. The original code returned the first 30
    raw characters, which could leak the password.
    """
    masked_url = None
    if DATABASE_URL:
        parsed = urlparse(DATABASE_URL)
        if parsed.password:
            # Replace only the password portion of user:password@host.
            masked_url = parsed._replace(
                netloc=parsed.netloc.replace(parsed.password, "***", 1)
            ).geturl()
        else:
            masked_url = DATABASE_URL
    return {
        "type": DB_TYPE,
        "url": masked_url,
        "path": str(DB_PATH) if DB_TYPE == "sqlite" else None
    }
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# Initialize database on module import so tables exist before any query runs.
# NOTE(review): import-time side effect — importing this module touches the
# filesystem (SQLite) or opens a connection (PostgreSQL); confirm acceptable.
init_db()
|
credily/api/errors.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Error handling module for Credily API.
|
| 3 |
+
Provides user-friendly error messages while logging detailed errors for developers.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import traceback
|
| 8 |
+
from typing import Optional, Dict, Any
|
| 9 |
+
from fastapi import HTTPException
|
| 10 |
+
|
| 11 |
+
# Configure logging for developer errors.
# NOTE(review): basicConfig configures the *root* logger at import time,
# which affects the whole process — confirm this is intended for a library.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by handle_api_error / log_warning / log_info.
logger = logging.getLogger('credily.api')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UserFriendlyError(Exception):
    """Exception with user-friendly message and optional developer details.

    Raise this from endpoint code when you already know the safe message to
    show the client; handle_api_error() passes user_message and status_code
    straight through.
    """

    def __init__(self, user_message: str, detail: Optional[str] = None, status_code: int = 400):
        # Message safe to display to end users.
        self.user_message = user_message
        # Optional technical detail for developers (not shown to users).
        self.detail = detail
        # HTTP status code to use when this becomes an HTTPException.
        self.status_code = status_code
        super().__init__(user_message)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Error message mappings for common exceptions.
# Patterns are matched case-insensitively against the exception message AND
# the exception class name (see get_user_friendly_message).
ERROR_MAPPINGS = {
    # Data errors
    "No data provided": "Please provide data to process. The data field cannot be empty.",
    "Target column": "The target column was not found in your data. Please check the column name.",
    "not found": "The requested resource was not found. Please check the path and try again.",

    # Model errors
    "Model not trained": "No trained model available. Please train a model first or provide a valid model path.",
    "Failed to load model": "Unable to load the model file. Please ensure the file is valid and not corrupted.",
    "model.pkl": "Invalid model file format. Please provide a valid Credily model file.",

    # Data quality errors
    "Binary classification requires 2 classes": "Your data must have exactly 2 classes in the target column for binary classification.",
    "less than 2 classes": "Your target column has only one class. Binary classification requires at least 2 different classes.",
    "more than 2 classes": "Your target column has more than 2 classes. This tool currently supports binary classification only.",

    # File errors
    "not found at": "The specified file could not be found. Please check the path and try again.",
    "Permission denied": "Unable to access the file. Please check file permissions.",

    # Memory errors
    "MemoryError": "The dataset is too large to process. Please try with a smaller dataset or contact support.",

    # Connection errors
    "Connection": "Unable to connect to the service. Please try again later.",

    # Validation errors
    "validation error": "Invalid data format. Please check your input data matches the required format.",
}


def get_user_friendly_message(error: Exception) -> str:
    """
    Convert a technical error message to a user-friendly message.

    Patterns from ERROR_MAPPINGS are matched against both the exception
    message and the exception class name. (Previously the class name was
    computed but never used, so type-based patterns such as "MemoryError" —
    whose str() is typically empty — could never match.)

    Args:
        error: The exception that occurred

    Returns:
        A user-friendly error message
    """
    error_str = str(error).lower()
    error_type = type(error).__name__.lower()

    # Check for specific error patterns in the message or the type name.
    for pattern, friendly_message in ERROR_MAPPINGS.items():
        needle = pattern.lower()
        if needle in error_str or needle in error_type:
            return friendly_message

    # Handle specific exception types
    if isinstance(error, FileNotFoundError):
        return "The requested file was not found. Please check the path and try again."

    if isinstance(error, PermissionError):
        return "Unable to access the file due to permission restrictions."

    if isinstance(error, ValueError):
        # Try to make ValueError messages more user-friendly
        if "column" in error_str:
            return "There was an issue with the data columns. Please check your data format."
        if "shape" in error_str:
            return "The data dimensions are incorrect. Please ensure your data is properly formatted."
        if "dtype" in error_str or "type" in error_str:
            return "There was a data type mismatch. Please ensure all values are in the correct format."
        return "Invalid value provided. Please check your input data."

    if isinstance(error, KeyError):
        return "A required field is missing from your data. Please check the data format."

    if isinstance(error, TypeError):
        return "Invalid data type provided. Please check your input format."

    if "memory" in error_str:
        return "The operation requires more memory than available. Please try with a smaller dataset."

    if "timeout" in error_str:
        return "The operation timed out. Please try again with a smaller dataset or later."

    # Default message for unknown errors
    return "An unexpected error occurred. Please try again or contact support if the problem persists."
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def handle_api_error(
    error: Exception,
    operation: str = "operation",
    context: Optional[Dict[str, Any]] = None
) -> HTTPException:
    """Translate an exception into a user-friendly HTTPException.

    The full technical detail is written to the log for developers, while
    the returned HTTPException carries only a safe, human-readable message.

    Args:
        error: The exception that occurred.
        operation: Description of the operation being performed (used in logs).
        context: Optional extra key/value context attached to the log record.

    Returns:
        HTTPException carrying a user-friendly detail message.
    """
    # Resolve the (message, status) pair depending on the exception flavor.
    if isinstance(error, UserFriendlyError):
        detail = error.user_message
        status = error.status_code
    elif isinstance(error, HTTPException):
        # Already HTTP-shaped, but the detail text may still be too technical.
        detail = get_user_friendly_message(Exception(error.detail))
        status = error.status_code
    else:
        detail = get_user_friendly_message(error)
        # Input-validation style errors map to 400; everything else is a 500.
        status = 400 if isinstance(error, (ValueError, KeyError, TypeError)) else 500

    # Keep the technical details in the logs for developers.
    logger.error(
        f"Error during {operation}: {type(error).__name__}: {str(error)}",
        extra={'context': context or {}}
    )
    logger.debug(f"Full traceback:\n{traceback.format_exc()}")

    return HTTPException(status_code=status, detail=detail)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def log_warning(message: str, context: Optional[Dict[str, Any]] = None):
    """Emit a warning-level log record with optional structured context."""
    extra_context = {} if context is None else context
    logger.warning(message, extra={'context': extra_context})
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def log_info(message: str, context: Optional[Dict[str, Any]] = None):
    """Emit an info-level log record with optional structured context."""
    extra_context = {} if context is None else context
    logger.info(message, extra={'context': extra_context})
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Pre-defined user-friendly error responses
|
| 161 |
+
class APIErrors:
    """Factory methods for common, user-friendly API error responses.

    Each method *returns* (rather than raises) an HTTPException, so callers
    can either ``raise APIErrors.no_data()`` or pass the exception along.
    """

    @staticmethod
    def no_data() -> HTTPException:
        """400 - request body contained no data records."""
        return HTTPException(
            status_code=400,
            detail="Please provide data to process. The data field cannot be empty."
        )

    @staticmethod
    def model_not_found(path: Optional[str] = None) -> HTTPException:
        """404 - model file missing.

        The path itself is deliberately not echoed back to the client to
        avoid leaking server filesystem layout.
        """
        msg = "Model not found."
        if path:
            msg = "Model not found at the specified location. Please check the path and try again."
        return HTTPException(status_code=404, detail=msg)

    @staticmethod
    def invalid_model() -> HTTPException:
        """400 - file exists but is not a loadable Credily model."""
        return HTTPException(
            status_code=400,
            detail="Invalid model file. Please ensure you're using a valid Credily model."
        )

    @staticmethod
    def target_not_found(column: str) -> HTTPException:
        """400 - the configured target column is absent from the data."""
        return HTTPException(
            status_code=400,
            detail=f"Target column '{column}' not found in your data. Please check the column name and try again."
        )

    @staticmethod
    def insufficient_classes() -> HTTPException:
        """400 - target does not contain exactly two classes."""
        return HTTPException(
            status_code=400,
            detail="Your data must have exactly 2 classes for binary classification. Please check your target column."
        )

    @staticmethod
    def training_failed() -> HTTPException:
        """500 - training pipeline raised an unrecoverable error."""
        return HTTPException(
            status_code=500,
            detail="Model training failed. Please check your data format and try again."
        )

    @staticmethod
    def prediction_failed() -> HTTPException:
        """500 - inference raised an unrecoverable error."""
        return HTTPException(
            status_code=500,
            detail="Prediction failed. Please ensure your data matches the format used for training."
        )

    @staticmethod
    def file_not_found(filename: Optional[str] = None) -> HTTPException:
        """404 - requested file missing; names the file when provided."""
        msg = "The requested file was not found."
        if filename:
            # Bug fix: the message previously contained a literal "(unknown)"
            # placeholder instead of interpolating the actual filename.
            msg = f"File '{filename}' not found. It may have expired or been removed."
        return HTTPException(status_code=404, detail=msg)

    @staticmethod
    def invalid_data_format() -> HTTPException:
        """400 - payload is not a list of JSON records."""
        return HTTPException(
            status_code=400,
            detail="Invalid data format. Please provide data as a list of records (JSON objects)."
        )

    @staticmethod
    def server_error() -> HTTPException:
        """500 - generic fallback for unexpected internal failures."""
        return HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later."
        )
|
credily/api/main.py
ADDED
|
@@ -0,0 +1,1035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for Credily Credit Scoring API.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import io
|
| 7 |
+
import uuid
|
| 8 |
+
import json
|
| 9 |
+
import shutil
|
| 10 |
+
import zipfile
|
| 11 |
+
import tempfile
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from fastapi import FastAPI, HTTPException, Query, File, UploadFile, Form
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
from fastapi.responses import StreamingResponse
|
| 21 |
+
|
| 22 |
+
from .schemas import (
|
| 23 |
+
TrainRequest, TrainResponse, TrainConfig,
|
| 24 |
+
PredictRequest, PredictResponse, PredictionResult,
|
| 25 |
+
SinglePredictRequest,
|
| 26 |
+
ProfileRequest, ProfileResponse, ColumnProfile,
|
| 27 |
+
PredictionHistoryResponse, PredictionSessionInfo, PredictionSessionDetailResponse,
|
| 28 |
+
TrainingReportsResponse, TrainingReportInfo, TrainingReportDetailResponse,
|
| 29 |
+
HealthResponse, SafetyReportSchema
|
| 30 |
+
)
|
| 31 |
+
from .database import (
|
| 32 |
+
init_db, save_prediction_session, save_prediction_results,
|
| 33 |
+
get_prediction_sessions, get_prediction_session, get_prediction_results,
|
| 34 |
+
delete_prediction_session, save_training_report, get_training_reports,
|
| 35 |
+
get_training_report, delete_training_report, get_db_info, DB_TYPE
|
| 36 |
+
)
|
| 37 |
+
from .errors import handle_api_error, APIErrors, logger, log_info
|
| 38 |
+
from ..automl import CredilyPipeline
|
| 39 |
+
from ..profiler import DataProfiler
|
| 40 |
+
from .. import __version__
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============== App Configuration ==============
|
| 44 |
+
|
| 45 |
+
app = FastAPI(
    title="Credily Credit Scoring API",
    description="REST API for AI-powered credit scoring and risk assessment",
    version=__version__,
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS configuration - allow frontend to communicate
# NOTE(review): origins are hard-coded (including a LAN IP, 172.20.10.3);
# consider sourcing this list from configuration per deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",
        "http://localhost:8080",
        "http://localhost:8081",
        "http://localhost:8082",
        "http://localhost:3000",
        "http://127.0.0.1:5173",
        "http://127.0.0.1:3000",
        "http://127.0.0.1:8080",
        "http://127.0.0.1:8081",
        "http://127.0.0.1:8082",
        # Network IP access
        "http://172.20.10.3:8080",
        "http://172.20.10.3:5173",
        "http://172.20.10.3:3000",
        # Production frontend URLs (Lovable)
        "https://id-preview--449e5a89-ee15-4a86-a93c-cf19ebb9c17e.lovable.app",
        "https://credily-credit-scoring-ai.lovable.app",
        "https://credily-six.vercel.app",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Temporary storage for model downloads (cleaned up after download)
TEMP_MODELS_DIR = Path(tempfile.gettempdir()) / "credily_temp_models"
TEMP_MODELS_DIR.mkdir(exist_ok=True)

# Storage for uploaded models (model_id -> model_path mapping)
UPLOADED_MODELS_DIR = TEMP_MODELS_DIR / "uploaded_models"
UPLOADED_MODELS_DIR.mkdir(exist_ok=True)
# In-memory registry of uploaded models; process-local and lost on restart
# (on-disk model files may then be orphaned in UPLOADED_MODELS_DIR).
uploaded_models: dict[str, str] = {}  # model_id -> model.pkl path
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ============== Helper Functions ==============
|
| 92 |
+
|
| 93 |
+
def convert_numpy_types(obj):
    """Recursively replace numpy containers/scalars with native Python types.

    Dicts and lists are walked recursively; ndarrays become (nested) lists;
    numpy integer, float, and bool scalars become int/float/bool. Any other
    value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(element) for element in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.int64, np.int32, np.int16, np.int8)):
        return int(obj)
    if isinstance(obj, (np.float64, np.float32, np.float16)):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    return obj
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def classify_risk(probability: float, threshold: float) -> str:
    """Map a default probability to a coarse risk bucket.

    Bands are defined relative to the decision threshold:
    >= threshold + 0.2 -> "high"; >= threshold -> "medium";
    >= threshold - 0.15 -> "low"; anything below -> "very_low".
    """
    bands = (
        (threshold + 0.2, "high"),
        (threshold, "medium"),
        (threshold - 0.15, "low"),
    )
    for lower_bound, label in bands:
        if probability >= lower_bound:
            return label
    return "very_low"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def create_model_zip(model_dir: Path, model_name: str) -> Path:
    """Bundle every file under model_dir into <model_name>.zip.

    The archive is written into TEMP_MODELS_DIR and entries are stored
    with paths relative to model_dir.

    Returns:
        Path of the created zip file.
    """
    archive_path = TEMP_MODELS_DIR / f"{model_name}.zip"

    with zipfile.ZipFile(archive_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for entry in model_dir.rglob('*'):
            if not entry.is_file():
                continue
            archive.write(entry, entry.relative_to(model_dir))

    return archive_path
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_model_from_path(model_path: str) -> CredilyPipeline:
    """Load a model from the specified path.

    Accepts either a direct path to a pickled model or a .zip archive
    that contains a model.pkl; zips are extracted into TEMP_MODELS_DIR
    before loading.

    Raises:
        HTTPException: 404 if the path does not exist; 400 if the zip is
            invalid, contains no model.pkl, or the model fails to load.
    """
    path = Path(model_path)

    if not path.exists():
        logger.warning(f"Model file not found: {model_path}")
        # APIErrors.model_not_found returns an HTTPException instance,
        # which is raised directly here.
        raise APIErrors.model_not_found(model_path)

    # If it's a zip file, extract it first
    if path.suffix == '.zip':
        extract_dir = TEMP_MODELS_DIR / f"extract_{uuid.uuid4().hex[:8]}"
        extract_dir.mkdir(exist_ok=True)

        try:
            with zipfile.ZipFile(path, 'r') as zipf:
                zipf.extractall(extract_dir)
        except zipfile.BadZipFile:
            # Remove the partially-extracted directory before failing.
            shutil.rmtree(extract_dir, ignore_errors=True)
            logger.error(f"Invalid zip file: {model_path}")
            raise APIErrors.invalid_model()

        # Find the model.pkl file
        pkl_files = list(extract_dir.rglob('model.pkl'))
        if not pkl_files:
            shutil.rmtree(extract_dir)
            logger.error(f"No model.pkl in zip: {model_path}")
            raise APIErrors.invalid_model()

        # NOTE(review): extract_dir is never removed after a successful
        # load, so extracted archives accumulate in TEMP_MODELS_DIR.
        model_path = str(pkl_files[0])

    try:
        pipeline = CredilyPipeline.load(model_path)
        log_info(f"Model loaded successfully: {model_path}")
        return pipeline
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise HTTPException(
            status_code=400,
            detail="Unable to load the model file. Please ensure it's a valid Credily model."
        )
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def process_uploaded_model(file_content: bytes, filename: str) -> tuple[str, str]:
    """
    Process an uploaded model file and return (model_id, model_path).

    Supports .zip files containing model.pkl or direct .pkl files.
    The file is persisted under UPLOADED_MODELS_DIR/<model_id>, validated
    by actually loading it, and registered in the in-memory
    `uploaded_models` mapping.

    Raises:
        HTTPException: 400 for bad zip / missing model.pkl / unsupported
            extension / unloadable model; 500 for unexpected failures.

    # NOTE(review): CredilyPipeline.load presumably unpickles the upload;
    # loading pickled data from untrusted users can execute arbitrary
    # code — confirm and restrict who may call this endpoint.
    """
    model_id = f"model_{uuid.uuid4().hex[:12]}"
    model_dir = UPLOADED_MODELS_DIR / model_id
    model_dir.mkdir(exist_ok=True)

    try:
        if filename.endswith('.zip'):
            # Extract zip file
            zip_path = model_dir / filename
            zip_path.write_bytes(file_content)

            try:
                with zipfile.ZipFile(zip_path, 'r') as zipf:
                    zipf.extractall(model_dir)
            except zipfile.BadZipFile:
                shutil.rmtree(model_dir, ignore_errors=True)
                raise HTTPException(
                    status_code=400,
                    detail="Invalid zip file. Please upload a valid model zip file."
                )

            # Find model.pkl
            pkl_files = list(model_dir.rglob('model.pkl'))
            if not pkl_files:
                shutil.rmtree(model_dir, ignore_errors=True)
                raise HTTPException(
                    status_code=400,
                    detail="No model.pkl found in the zip file. Please upload a valid Credily model."
                )

            model_path = str(pkl_files[0])

        elif filename.endswith('.pkl'):
            # Direct pkl file
            model_path = str(model_dir / 'model.pkl')
            Path(model_path).write_bytes(file_content)
        else:
            shutil.rmtree(model_dir, ignore_errors=True)
            raise HTTPException(
                status_code=400,
                detail="Unsupported file format. Please upload a .zip or .pkl file."
            )

        # Validate model can be loaded (the loaded pipeline itself is
        # discarded; only successful loading matters here).
        try:
            pipeline = CredilyPipeline.load(model_path)
            log_info(f"Uploaded model validated: {model_id}")
        except Exception as e:
            shutil.rmtree(model_dir, ignore_errors=True)
            logger.error(f"Invalid model file: {e}")
            raise HTTPException(
                status_code=400,
                detail="Invalid model file. Please upload a valid Credily model."
            )

        # Store mapping
        uploaded_models[model_id] = model_path

        return model_id, model_path

    except HTTPException:
        # Already shaped for the client; the branches above have cleaned up.
        raise
    except Exception as e:
        shutil.rmtree(model_dir, ignore_errors=True)
        logger.error(f"Failed to process uploaded model: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Failed to process the uploaded model file."
        )
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# ============== Startup Event ==============
|
| 256 |
+
|
| 257 |
+
@app.on_event("startup")
async def startup_event():
    """Initialize database on startup."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI releases;
    # consider migrating to a lifespan context manager when upgrading.
    init_db()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# ============== API Endpoints ==============
|
| 264 |
+
|
| 265 |
+
@app.get("/", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports the API version and the configured database backend type.
    # NOTE(review): the reported status is unconditionally "healthy" —
    # get_db_info() appears to describe the backend, not probe it; confirm.
    """
    db_info = get_db_info()
    db_status = f"healthy ({db_info['type']})"
    return HealthResponse(
        status="healthy",
        version=__version__,
        database_status=db_status
    )
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
@app.get("/api/health", response_model=HealthResponse)
async def api_health():
    """API health check endpoint (alias of the root `/` health check)."""
    return await health_check()
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ============== Training Endpoints ==============
|
| 284 |
+
|
| 285 |
+
@app.post("/api/train", response_model=TrainResponse)
async def train_model(request: TrainRequest):
    """
    Train a new credit scoring model.

    - Accepts training data as JSON records
    - Automatically profiles, cleans, and balances data
    - Trains multiple models and selects the best performer
    - Returns model as a downloadable zip file
    - Saves training report to database
    """
    try:
        # Convert request data to DataFrame
        df = pd.DataFrame(request.data)

        if df.empty:
            # APIErrors.no_data() returns an HTTPException; raised directly.
            raise APIErrors.no_data()

        # Get config or use defaults
        config = request.config or TrainConfig()

        # Generate unique model name/ID
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_name = request.model_name or f"credily_model_{timestamp}"
        model_name = model_name.replace(' ', '_')
        report_id = f"report_{timestamp}_{uuid.uuid4().hex[:8]}"

        # Create temporary output directory
        # NOTE(review): if training raises below, this directory is not
        # cleaned up (cleanup only happens after the zip is created).
        temp_output_dir = TEMP_MODELS_DIR / f"train_{uuid.uuid4().hex[:8]}"

        # Create pipeline with config
        pipeline = CredilyPipeline(
            target_column=config.target_column,
            output_dir=str(temp_output_dir),
            test_size=config.test_size,
            cv_folds=config.cv_folds,
            clean_data=config.clean_data,
            clean_mode=config.clean_mode,
            flag_missing=config.flag_missing,
            balance_data=config.balance_data,
            balance_method=config.balance_method,
            calibrate=config.calibrate,
            optimize_threshold=config.optimize_threshold,
            conservative_mode=config.conservative_mode,
            binary_threshold=config.binary_threshold,
            positive_classes=config.positive_classes
        )

        # Profile and train
        pipeline.profile(df)
        results = pipeline.train(df)

        # Convert numpy types to native Python types
        results = convert_numpy_types(results)

        # Save training report to database
        save_training_report(report_id, model_name, results)

        # Create zip file for download
        zip_path = create_model_zip(temp_output_dir, model_name)

        # Clean up temp training directory
        shutil.rmtree(temp_output_dir)

        # Prepare test predictions data for visualization
        test_preds = results.get("test_predictions")
        test_predictions_data = None
        if test_preds:
            test_predictions_data = {
                "y_true": test_preds["y_true"],
                "y_pred": test_preds["y_pred"],
                "y_proba": test_preds["y_proba"],
                "n_samples": test_preds["n_samples"]
            }

        # Prepare ROC curve data (fpr/tpr renamed to generic x/y for the UI)
        roc_data = results.get("roc_curve")
        roc_curve_data = None
        if roc_data:
            roc_curve_data = {
                "x": roc_data["fpr"],
                "y": roc_data["tpr"]
            }

        # Prepare PR curve data (recall on x, precision on y)
        pr_data = results.get("pr_curve")
        pr_curve_data = None
        if pr_data:
            pr_curve_data = {
                "x": pr_data["recall"],
                "y": pr_data["precision"]
            }

        # Prepare safety report data
        safety_report_data = None
        model_valid = True  # Default to valid if no safety report
        raw_safety = results.get("safety_report")
        if raw_safety:
            safety_report_data = SafetyReportSchema(
                status=raw_safety.get("status", "PASS"),
                model_valid=raw_safety.get("model_valid", True),
                dropped_features=raw_safety.get("dropped_features", {}),
                warnings=raw_safety.get("warnings", []),
                errors=raw_safety.get("errors", []),
                leakage_detected=raw_safety.get("leakage_detected", {}),
                redundant_features=raw_safety.get("redundant_features", []),
                feature_dominance=raw_safety.get("feature_dominance", {}),
                overfitting_metrics=raw_safety.get("overfitting_metrics", {})
            )
            model_valid = raw_safety.get("model_valid", True)

        # Convert model_test_metrics to serializable format
        model_test_metrics = results.get("model_test_metrics")
        if model_test_metrics:
            model_test_metrics = {
                name: {
                    "pr_auc": float(m["pr_auc"]),
                    "roc_auc": float(m["roc_auc"]),
                    "default_recall": float(m["default_recall"]),
                    "fp_count": int(m["fp_count"]),
                    "threshold": float(m["threshold"]),
                    "cv_score": float(m["cv_score"])
                }
                for name, m in model_test_metrics.items()
            }

        return TrainResponse(
            success=True,
            report_id=report_id,
            model_name=model_name,
            best_model=results["best_model"],
            best_score=float(results["best_score"]),
            test_auc=float(results["test_auc"]),
            test_pr_auc=float(results["test_pr_auc"]),
            optimal_threshold=float(results["optimal_threshold"]),
            model_scores={k: float(v) for k, v in results["model_scores"].items()},
            model_test_metrics=model_test_metrics,
            model_ranking=results.get("model_ranking"),
            classification_report=results["classification_report"],
            confusion_matrix=results["confusion_matrix"],
            feature_importances={k: float(v) for k, v in results["feature_importances"].items()},
            download_url=f"/api/train/download/{model_name}",
            message=f"Model trained successfully. Best model: {results['best_model']} with ROC-AUC: {float(results['test_auc']):.4f}. Download your model using the download URL.",
            model_valid=model_valid,
            safety_report=safety_report_data,
            test_predictions=test_predictions_data,
            roc_curve=roc_curve_data,
            pr_curve=pr_curve_data,
            sanity_warnings=results.get("sanity_warnings", [])
        )

    except HTTPException:
        raise
    # NOTE(review): both handlers below delegate to handle_api_error and
    # differ only in log detail; handle_api_error *returns* the
    # HTTPException, which is then raised here.
    except ValueError as e:
        logger.error(f"Training validation error: {e}")
        raise handle_api_error(e, "model training")
    except Exception as e:
        logger.error(f"Training failed: {e}", exc_info=True)
        raise handle_api_error(e, "model training")
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
@app.get("/api/train/download/{model_name}")
async def download_model(model_name: str):
    """
    Download a trained model zip file.

    The zip is streamed to the client and removed from the temporary
    store once the stream completes, so each model can be downloaded
    once. The model should be extracted and the model.pkl path used for
    predictions.

    Raises:
        HTTPException: 404 if the zip is missing or the name would
            resolve outside the temporary model directory.
    """
    base_dir = TEMP_MODELS_DIR.resolve()
    zip_path = (TEMP_MODELS_DIR / f"{model_name}.zip").resolve()

    # Reject names that escape the temp directory (path-traversal guard),
    # reporting the same 404 as a genuinely missing file.
    if zip_path.parent != base_dir or not zip_path.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' not found. It may have expired or been downloaded already."
        )

    def iterfile():
        # Stream the file; clean up after the last chunk has been yielded.
        with open(zip_path, 'rb') as f:
            yield from f
        try:
            zip_path.unlink()
        except OSError:
            # Best-effort cleanup: a failed delete must not break the
            # response that was already streamed. (Was a bare `except:`.)
            pass

    return StreamingResponse(
        iterfile(),
        media_type="application/zip",
        headers={
            "Content-Disposition": f"attachment; filename={model_name}.zip"
        }
    )
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
# ============== Model Upload Endpoints ==============
|
| 480 |
+
|
| 481 |
+
@app.post("/api/models/upload")
async def upload_model(model_file: UploadFile = File(...)):
    """
    Upload a trained model file for predictions.

    - Accepts .zip files (from model training) or .pkl files
    - Returns a model_id that can be used for predictions
    - Models are stored temporarily on the server

    # NOTE(review): uploads are ultimately unpickled during validation;
    # treat this endpoint as trusted-user-only (see process_uploaded_model).
    """
    if not model_file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    # Validate file extension
    if not (model_file.filename.endswith('.zip') or model_file.filename.endswith('.pkl')):
        raise HTTPException(
            status_code=400,
            detail="Unsupported file format. Please upload a .zip or .pkl file."
        )

    # Read file content
    content = await model_file.read()

    # Process and validate the model (model_path is registered internally
    # and not exposed in the response).
    model_id, model_path = process_uploaded_model(content, model_file.filename)

    return {
        "success": True,
        "model_id": model_id,
        "filename": model_file.filename,
        "message": "Model uploaded successfully. Use the model_id for predictions."
    }
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
@app.get("/api/models/{model_id}")
async def get_uploaded_model_info(model_id: str):
    """Get information about an uploaded model.

    Loads the registered pipeline from disk on every call and reports its
    best model name, decision threshold, and up to 10 feature names.

    Raises:
        HTTPException: 404 if model_id is unknown, 500 if loading fails.
    """
    if model_id not in uploaded_models:
        raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")

    model_path = uploaded_models[model_id]

    try:
        pipeline = CredilyPipeline.load(model_path)
        return {
            "success": True,
            "model_id": model_id,
            "best_model": pipeline.best_model_name,
            "optimal_threshold": pipeline.optimal_threshold,
            # Older models may lack feature_columns; fall back to [].
            "features": pipeline.feature_columns[:10] if hasattr(pipeline, 'feature_columns') else []
        }
    except Exception as e:
        logger.error(f"Failed to load model info: {e}")
        raise HTTPException(status_code=500, detail="Failed to load model information")
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@app.delete("/api/models/{model_id}")
|
| 537 |
+
async def delete_uploaded_model(model_id: str):
|
| 538 |
+
"""Delete an uploaded model."""
|
| 539 |
+
if model_id not in uploaded_models:
|
| 540 |
+
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
|
| 541 |
+
|
| 542 |
+
model_path = uploaded_models[model_id]
|
| 543 |
+
model_dir = UPLOADED_MODELS_DIR / model_id
|
| 544 |
+
|
| 545 |
+
try:
|
| 546 |
+
if model_dir.exists():
|
| 547 |
+
shutil.rmtree(model_dir)
|
| 548 |
+
del uploaded_models[model_id]
|
| 549 |
+
return {"success": True, "message": f"Model '{model_id}' deleted"}
|
| 550 |
+
except Exception as e:
|
| 551 |
+
logger.error(f"Failed to delete model: {e}")
|
| 552 |
+
raise HTTPException(status_code=500, detail="Failed to delete model")
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
# ============== Prediction Endpoints ==============
|
| 556 |
+
|
| 557 |
+
@app.post("/api/predict/with-upload", response_model=PredictResponse)
|
| 558 |
+
async def predict_with_upload(
|
| 559 |
+
data_file: UploadFile = File(...),
|
| 560 |
+
model_file: UploadFile = File(...),
|
| 561 |
+
threshold: Optional[float] = Form(None),
|
| 562 |
+
save_results: bool = Form(True)
|
| 563 |
+
):
|
| 564 |
+
"""
|
| 565 |
+
Make predictions by uploading both model and data files.
|
| 566 |
+
|
| 567 |
+
- model_file: .zip or .pkl model file
|
| 568 |
+
- data_file: .csv file with data to predict
|
| 569 |
+
- threshold: optional custom threshold (0.0-1.0)
|
| 570 |
+
- save_results: whether to save to database
|
| 571 |
+
"""
|
| 572 |
+
# Validate file types
|
| 573 |
+
if not model_file.filename or not (model_file.filename.endswith('.zip') or model_file.filename.endswith('.pkl')):
|
| 574 |
+
raise HTTPException(status_code=400, detail="Model must be a .zip or .pkl file")
|
| 575 |
+
|
| 576 |
+
if not data_file.filename or not data_file.filename.endswith('.csv'):
|
| 577 |
+
raise HTTPException(status_code=400, detail="Data must be a .csv file")
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
# Process model upload
|
| 581 |
+
model_content = await model_file.read()
|
| 582 |
+
model_id, model_path = process_uploaded_model(model_content, model_file.filename)
|
| 583 |
+
|
| 584 |
+
# Read and parse CSV data
|
| 585 |
+
data_content = await data_file.read()
|
| 586 |
+
df = pd.read_csv(io.BytesIO(data_content))
|
| 587 |
+
|
| 588 |
+
if df.empty:
|
| 589 |
+
raise APIErrors.no_data()
|
| 590 |
+
|
| 591 |
+
# Load model and make predictions
|
| 592 |
+
pipeline = CredilyPipeline.load(model_path)
|
| 593 |
+
result_df = pipeline.predict(df, include_proba=True, threshold=threshold)
|
| 594 |
+
|
| 595 |
+
actual_threshold = threshold or pipeline.optimal_threshold
|
| 596 |
+
|
| 597 |
+
# Build response
|
| 598 |
+
predictions = []
|
| 599 |
+
for idx, row in result_df.iterrows():
|
| 600 |
+
prob = row.get("proba_1", None)
|
| 601 |
+
pred_result = PredictionResult(
|
| 602 |
+
index=int(idx),
|
| 603 |
+
prediction=int(row["prediction"]),
|
| 604 |
+
probability=float(prob) if prob is not None else None,
|
| 605 |
+
risk_level=classify_risk(prob, actual_threshold) if prob is not None else None
|
| 606 |
+
)
|
| 607 |
+
predictions.append(pred_result)
|
| 608 |
+
|
| 609 |
+
# Summary
|
| 610 |
+
pred_series = result_df["prediction"]
|
| 611 |
+
total_records = len(predictions)
|
| 612 |
+
predicted_positive = int(pred_series.sum())
|
| 613 |
+
predicted_negative = int((pred_series == 0).sum())
|
| 614 |
+
positive_rate = float(pred_series.mean())
|
| 615 |
+
|
| 616 |
+
summary = {
|
| 617 |
+
"total_records": total_records,
|
| 618 |
+
"predicted_positive": predicted_positive,
|
| 619 |
+
"predicted_negative": predicted_negative,
|
| 620 |
+
"positive_rate": positive_rate,
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
avg_probability = None
|
| 624 |
+
if "proba_1" in result_df.columns:
|
| 625 |
+
avg_probability = float(result_df["proba_1"].mean())
|
| 626 |
+
summary["avg_probability"] = avg_probability
|
| 627 |
+
summary["risk_distribution"] = {
|
| 628 |
+
"very_low": sum(1 for p in predictions if p.risk_level == "very_low"),
|
| 629 |
+
"low": sum(1 for p in predictions if p.risk_level == "low"),
|
| 630 |
+
"medium": sum(1 for p in predictions if p.risk_level == "medium"),
|
| 631 |
+
"high": sum(1 for p in predictions if p.risk_level == "high"),
|
| 632 |
+
}
|
| 633 |
+
|
| 634 |
+
# Save to database if requested
|
| 635 |
+
session_id = None
|
| 636 |
+
if save_results:
|
| 637 |
+
session_id = f"pred_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
| 638 |
+
save_prediction_session(
|
| 639 |
+
session_id=session_id,
|
| 640 |
+
model_path=f"uploaded:{model_id}",
|
| 641 |
+
model_name=pipeline.best_model_name,
|
| 642 |
+
threshold_used=actual_threshold,
|
| 643 |
+
total_records=total_records,
|
| 644 |
+
predicted_positive=predicted_positive,
|
| 645 |
+
predicted_negative=predicted_negative,
|
| 646 |
+
positive_rate=positive_rate,
|
| 647 |
+
avg_probability=avg_probability
|
| 648 |
+
)
|
| 649 |
+
# Save individual results
|
| 650 |
+
results_to_save = []
|
| 651 |
+
for i, pred in enumerate(predictions):
|
| 652 |
+
results_to_save.append({
|
| 653 |
+
'index': pred.index,
|
| 654 |
+
'prediction': pred.prediction,
|
| 655 |
+
'probability': pred.probability,
|
| 656 |
+
'risk_level': pred.risk_level,
|
| 657 |
+
'input_data': df.iloc[i].to_dict() if i < len(df) else {}
|
| 658 |
+
})
|
| 659 |
+
save_prediction_results(session_id, results_to_save)
|
| 660 |
+
|
| 661 |
+
return PredictResponse(
|
| 662 |
+
success=True,
|
| 663 |
+
session_id=session_id,
|
| 664 |
+
model_path=f"uploaded:{model_id}",
|
| 665 |
+
threshold_used=actual_threshold,
|
| 666 |
+
predictions=predictions,
|
| 667 |
+
summary=summary
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
except HTTPException:
|
| 671 |
+
raise
|
| 672 |
+
except Exception as e:
|
| 673 |
+
logger.error(f"Prediction with upload failed: {e}", exc_info=True)
|
| 674 |
+
raise handle_api_error(e, "prediction")
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
@app.post("/api/predict", response_model=PredictResponse)
|
| 678 |
+
async def predict(request: PredictRequest):
|
| 679 |
+
"""
|
| 680 |
+
Make credit risk predictions on new data.
|
| 681 |
+
|
| 682 |
+
- Requires the absolute path to a trained model file (.pkl or .zip)
|
| 683 |
+
- Returns predictions with probability scores
|
| 684 |
+
- Optionally saves results to database for history tracking
|
| 685 |
+
"""
|
| 686 |
+
try:
|
| 687 |
+
# Load model from path
|
| 688 |
+
pipeline = load_model_from_path(request.model_path)
|
| 689 |
+
|
| 690 |
+
# Convert data to DataFrame
|
| 691 |
+
df = pd.DataFrame(request.data)
|
| 692 |
+
if df.empty:
|
| 693 |
+
raise APIErrors.no_data()
|
| 694 |
+
|
| 695 |
+
# Make predictions
|
| 696 |
+
result_df = pipeline.predict(
|
| 697 |
+
df,
|
| 698 |
+
include_proba=request.include_proba,
|
| 699 |
+
threshold=request.threshold
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
threshold = request.threshold or pipeline.optimal_threshold
|
| 703 |
+
|
| 704 |
+
# Build response
|
| 705 |
+
predictions = []
|
| 706 |
+
for idx, row in result_df.iterrows():
|
| 707 |
+
prob = row.get("proba_1", None)
|
| 708 |
+
pred_result = PredictionResult(
|
| 709 |
+
index=int(idx),
|
| 710 |
+
prediction=int(row["prediction"]),
|
| 711 |
+
probability=float(prob) if prob is not None else None,
|
| 712 |
+
risk_level=classify_risk(prob, threshold) if prob is not None else None
|
| 713 |
+
)
|
| 714 |
+
predictions.append(pred_result)
|
| 715 |
+
|
| 716 |
+
# Summary statistics
|
| 717 |
+
pred_series = result_df["prediction"]
|
| 718 |
+
total_records = len(predictions)
|
| 719 |
+
predicted_positive = int(pred_series.sum())
|
| 720 |
+
predicted_negative = int((pred_series == 0).sum())
|
| 721 |
+
positive_rate = float(pred_series.mean())
|
| 722 |
+
|
| 723 |
+
summary = {
|
| 724 |
+
"total_records": total_records,
|
| 725 |
+
"predicted_positive": predicted_positive,
|
| 726 |
+
"predicted_negative": predicted_negative,
|
| 727 |
+
"positive_rate": positive_rate,
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
avg_probability = None
|
| 731 |
+
if request.include_proba and "proba_1" in result_df.columns:
|
| 732 |
+
avg_probability = float(result_df["proba_1"].mean())
|
| 733 |
+
summary["avg_probability"] = avg_probability
|
| 734 |
+
summary["risk_distribution"] = {
|
| 735 |
+
"very_low": sum(1 for p in predictions if p.risk_level == "very_low"),
|
| 736 |
+
"low": sum(1 for p in predictions if p.risk_level == "low"),
|
| 737 |
+
"medium": sum(1 for p in predictions if p.risk_level == "medium"),
|
| 738 |
+
"high": sum(1 for p in predictions if p.risk_level == "high"),
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
+
# Save to database if requested
|
| 742 |
+
session_id = None
|
| 743 |
+
if request.save_results:
|
| 744 |
+
session_id = f"pred_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
| 745 |
+
|
| 746 |
+
# Save session
|
| 747 |
+
save_prediction_session(
|
| 748 |
+
session_id=session_id,
|
| 749 |
+
model_path=request.model_path,
|
| 750 |
+
model_name=pipeline.best_model_name,
|
| 751 |
+
threshold_used=threshold,
|
| 752 |
+
total_records=total_records,
|
| 753 |
+
predicted_positive=predicted_positive,
|
| 754 |
+
predicted_negative=predicted_negative,
|
| 755 |
+
positive_rate=positive_rate,
|
| 756 |
+
avg_probability=avg_probability
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
# Save individual results with input data
|
| 760 |
+
results_to_save = []
|
| 761 |
+
for i, (pred, input_row) in enumerate(zip(predictions, request.data)):
|
| 762 |
+
results_to_save.append({
|
| 763 |
+
'index': pred.index,
|
| 764 |
+
'prediction': pred.prediction,
|
| 765 |
+
'probability': pred.probability,
|
| 766 |
+
'risk_level': pred.risk_level,
|
| 767 |
+
'input_data': input_row
|
| 768 |
+
})
|
| 769 |
+
save_prediction_results(session_id, results_to_save)
|
| 770 |
+
|
| 771 |
+
return PredictResponse(
|
| 772 |
+
success=True,
|
| 773 |
+
session_id=session_id,
|
| 774 |
+
model_path=request.model_path,
|
| 775 |
+
threshold_used=threshold,
|
| 776 |
+
predictions=predictions,
|
| 777 |
+
summary=summary
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
except HTTPException:
|
| 781 |
+
raise
|
| 782 |
+
except Exception as e:
|
| 783 |
+
logger.error(f"Prediction failed: {e}", exc_info=True)
|
| 784 |
+
raise handle_api_error(e, "prediction")
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
@app.post("/api/predict/single")
|
| 788 |
+
async def predict_single(request: SinglePredictRequest):
|
| 789 |
+
"""
|
| 790 |
+
Make a single prediction (convenience endpoint).
|
| 791 |
+
|
| 792 |
+
Accepts a single record and returns prediction.
|
| 793 |
+
"""
|
| 794 |
+
predict_request = PredictRequest(
|
| 795 |
+
data=[request.data],
|
| 796 |
+
model_path=request.model_path,
|
| 797 |
+
include_proba=True,
|
| 798 |
+
threshold=request.threshold,
|
| 799 |
+
save_results=False # Don't save single predictions by default
|
| 800 |
+
)
|
| 801 |
+
response = await predict(predict_request)
|
| 802 |
+
|
| 803 |
+
if response.predictions:
|
| 804 |
+
pred = response.predictions[0]
|
| 805 |
+
return {
|
| 806 |
+
"prediction": pred.prediction,
|
| 807 |
+
"probability": pred.probability,
|
| 808 |
+
"risk_level": pred.risk_level,
|
| 809 |
+
"threshold": response.threshold_used
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
raise APIErrors.prediction_failed()
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
# ============== Profiling Endpoints ==============
|
| 816 |
+
|
| 817 |
+
@app.post("/api/profile", response_model=ProfileResponse)
|
| 818 |
+
async def profile_data(request: ProfileRequest):
|
| 819 |
+
"""
|
| 820 |
+
Profile a dataset to understand its structure and quality.
|
| 821 |
+
|
| 822 |
+
- Analyzes column types and distributions
|
| 823 |
+
- Identifies missing values and anomalies
|
| 824 |
+
- Provides recommendations for preprocessing
|
| 825 |
+
"""
|
| 826 |
+
try:
|
| 827 |
+
df = pd.DataFrame(request.data)
|
| 828 |
+
if df.empty:
|
| 829 |
+
raise APIErrors.no_data()
|
| 830 |
+
|
| 831 |
+
# Create profiler
|
| 832 |
+
profiler = DataProfiler(target_column=request.target_column)
|
| 833 |
+
profile = profiler.profile(df)
|
| 834 |
+
|
| 835 |
+
# Build column profiles
|
| 836 |
+
columns = []
|
| 837 |
+
for col in df.columns:
|
| 838 |
+
col_data = df[col]
|
| 839 |
+
columns.append(ColumnProfile(
|
| 840 |
+
name=col,
|
| 841 |
+
dtype=str(col_data.dtype),
|
| 842 |
+
non_null_count=int(col_data.notna().sum()),
|
| 843 |
+
null_count=int(col_data.isna().sum()),
|
| 844 |
+
null_percentage=float(col_data.isna().mean() * 100),
|
| 845 |
+
unique_count=int(col_data.nunique()),
|
| 846 |
+
sample_values=col_data.dropna().head(5).tolist()
|
| 847 |
+
))
|
| 848 |
+
|
| 849 |
+
# Numeric summary
|
| 850 |
+
numeric_cols = df.select_dtypes(include=["number"])
|
| 851 |
+
numeric_summary = None
|
| 852 |
+
if not numeric_cols.empty:
|
| 853 |
+
numeric_summary = numeric_cols.describe().to_dict()
|
| 854 |
+
|
| 855 |
+
# Target analysis
|
| 856 |
+
target_analysis = None
|
| 857 |
+
if request.target_column and request.target_column in df.columns:
|
| 858 |
+
target = df[request.target_column]
|
| 859 |
+
target_analysis = {
|
| 860 |
+
"class_distribution": target.value_counts().to_dict(),
|
| 861 |
+
"class_balance": float(target.value_counts(normalize=True).min()),
|
| 862 |
+
"is_binary": len(target.unique()) == 2
|
| 863 |
+
}
|
| 864 |
+
|
| 865 |
+
# Generate recommendations
|
| 866 |
+
recommendations = []
|
| 867 |
+
for col_profile in columns:
|
| 868 |
+
if col_profile.null_percentage > 30:
|
| 869 |
+
recommendations.append(f"Column '{col_profile.name}' has {col_profile.null_percentage:.1f}% missing values - consider removing or imputing")
|
| 870 |
+
if col_profile.unique_count == 1:
|
| 871 |
+
recommendations.append(f"Column '{col_profile.name}' has only one unique value - consider removing (no predictive value)")
|
| 872 |
+
if col_profile.unique_count == len(df) and col_profile.dtype == "object":
|
| 873 |
+
recommendations.append(f"Column '{col_profile.name}' appears to be an ID column - consider removing")
|
| 874 |
+
|
| 875 |
+
if target_analysis and target_analysis["class_balance"] < 0.2:
|
| 876 |
+
recommendations.append("Target class is highly imbalanced - SMOTE or other balancing techniques recommended")
|
| 877 |
+
|
| 878 |
+
return ProfileResponse(
|
| 879 |
+
success=True,
|
| 880 |
+
n_rows=len(df),
|
| 881 |
+
n_columns=len(df.columns),
|
| 882 |
+
columns=columns,
|
| 883 |
+
numeric_summary=numeric_summary,
|
| 884 |
+
target_analysis=target_analysis,
|
| 885 |
+
recommendations=recommendations
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
except HTTPException:
|
| 889 |
+
raise
|
| 890 |
+
except Exception as e:
|
| 891 |
+
logger.error(f"Profiling failed: {e}", exc_info=True)
|
| 892 |
+
raise handle_api_error(e, "data profiling")
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
# ============== Prediction History Endpoints ==============
|
| 896 |
+
|
| 897 |
+
@app.get("/api/predictions", response_model=PredictionHistoryResponse)
|
| 898 |
+
async def list_prediction_history(
|
| 899 |
+
limit: int = Query(50, ge=1, le=100),
|
| 900 |
+
offset: int = Query(0, ge=0)
|
| 901 |
+
):
|
| 902 |
+
"""List prediction history sessions."""
|
| 903 |
+
sessions = get_prediction_sessions(limit=limit, offset=offset)
|
| 904 |
+
return PredictionHistoryResponse(
|
| 905 |
+
success=True,
|
| 906 |
+
sessions=[PredictionSessionInfo(**s) for s in sessions],
|
| 907 |
+
total=len(sessions)
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@app.get("/api/predictions/{session_id}", response_model=PredictionSessionDetailResponse)
|
| 912 |
+
async def get_prediction_detail(session_id: str):
|
| 913 |
+
"""Get detailed prediction session including all results."""
|
| 914 |
+
session = get_prediction_session(session_id)
|
| 915 |
+
if not session:
|
| 916 |
+
raise HTTPException(status_code=404, detail=f"Prediction session '{session_id}' not found")
|
| 917 |
+
|
| 918 |
+
results = get_prediction_results(session_id)
|
| 919 |
+
|
| 920 |
+
return PredictionSessionDetailResponse(
|
| 921 |
+
success=True,
|
| 922 |
+
session=PredictionSessionInfo(**session),
|
| 923 |
+
results=results
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
@app.delete("/api/predictions/{session_id}")
|
| 928 |
+
async def delete_prediction(session_id: str):
|
| 929 |
+
"""Delete a prediction session and its results."""
|
| 930 |
+
if delete_prediction_session(session_id):
|
| 931 |
+
return {"success": True, "message": f"Prediction session '{session_id}' deleted"}
|
| 932 |
+
raise HTTPException(status_code=404, detail=f"Prediction session '{session_id}' not found")
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
@app.get("/api/predictions/{session_id}/export")
|
| 936 |
+
async def export_prediction_results(session_id: str, format: str = Query("csv", enum=["csv", "json"])):
|
| 937 |
+
"""Export prediction results as CSV or JSON."""
|
| 938 |
+
session = get_prediction_session(session_id)
|
| 939 |
+
if not session:
|
| 940 |
+
raise HTTPException(status_code=404, detail=f"Prediction session '{session_id}' not found")
|
| 941 |
+
|
| 942 |
+
results = get_prediction_results(session_id)
|
| 943 |
+
|
| 944 |
+
if format == "csv":
|
| 945 |
+
# Create CSV
|
| 946 |
+
df = pd.DataFrame(results)
|
| 947 |
+
if 'input_data' in df.columns:
|
| 948 |
+
# Expand input_data into columns
|
| 949 |
+
input_df = pd.DataFrame(df['input_data'].tolist())
|
| 950 |
+
df = pd.concat([df.drop('input_data', axis=1), input_df], axis=1)
|
| 951 |
+
|
| 952 |
+
output = io.StringIO()
|
| 953 |
+
df.to_csv(output, index=False)
|
| 954 |
+
output.seek(0)
|
| 955 |
+
|
| 956 |
+
return StreamingResponse(
|
| 957 |
+
iter([output.getvalue()]),
|
| 958 |
+
media_type="text/csv",
|
| 959 |
+
headers={"Content-Disposition": f"attachment; filename=predictions_{session_id}.csv"}
|
| 960 |
+
)
|
| 961 |
+
else:
|
| 962 |
+
return {
|
| 963 |
+
"session": session,
|
| 964 |
+
"results": results
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
|
| 968 |
+
# ============== Training Reports Endpoints ==============
|
| 969 |
+
|
| 970 |
+
@app.get("/api/reports", response_model=TrainingReportsResponse)
|
| 971 |
+
async def list_training_reports(
|
| 972 |
+
limit: int = Query(50, ge=1, le=100),
|
| 973 |
+
offset: int = Query(0, ge=0)
|
| 974 |
+
):
|
| 975 |
+
"""List training reports history."""
|
| 976 |
+
reports = get_training_reports(limit=limit, offset=offset)
|
| 977 |
+
return TrainingReportsResponse(
|
| 978 |
+
success=True,
|
| 979 |
+
reports=[TrainingReportInfo(**r) for r in reports],
|
| 980 |
+
total=len(reports)
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
@app.get("/api/reports/{report_id}", response_model=TrainingReportDetailResponse)
|
| 985 |
+
async def get_report_detail(report_id: str):
|
| 986 |
+
"""Get detailed training report."""
|
| 987 |
+
report = get_training_report(report_id)
|
| 988 |
+
if not report:
|
| 989 |
+
raise HTTPException(status_code=404, detail=f"Training report '{report_id}' not found")
|
| 990 |
+
|
| 991 |
+
return TrainingReportDetailResponse(
|
| 992 |
+
success=True,
|
| 993 |
+
report=report
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
@app.delete("/api/reports/{report_id}")
|
| 998 |
+
async def delete_report(report_id: str):
|
| 999 |
+
"""Delete a training report."""
|
| 1000 |
+
if delete_training_report(report_id):
|
| 1001 |
+
return {"success": True, "message": f"Training report '{report_id}' deleted"}
|
| 1002 |
+
raise HTTPException(status_code=404, detail=f"Training report '{report_id}' not found")
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
# ============== File Upload Endpoint ==============
|
| 1006 |
+
|
| 1007 |
+
@app.post("/api/upload")
|
| 1008 |
+
async def upload_csv(file_content: str, filename: str):
|
| 1009 |
+
"""
|
| 1010 |
+
Upload CSV data as base64 encoded string.
|
| 1011 |
+
Returns the data as JSON records for further processing.
|
| 1012 |
+
"""
|
| 1013 |
+
import base64
|
| 1014 |
+
|
| 1015 |
+
try:
|
| 1016 |
+
# Decode base64 content
|
| 1017 |
+
decoded = base64.b64decode(file_content)
|
| 1018 |
+
|
| 1019 |
+
# Read as CSV
|
| 1020 |
+
df = pd.read_csv(io.BytesIO(decoded))
|
| 1021 |
+
|
| 1022 |
+
return {
|
| 1023 |
+
"success": True,
|
| 1024 |
+
"filename": filename,
|
| 1025 |
+
"rows": len(df),
|
| 1026 |
+
"columns": list(df.columns),
|
| 1027 |
+
"data": df.to_dict(orient="records")
|
| 1028 |
+
}
|
| 1029 |
+
|
| 1030 |
+
except Exception as e:
|
| 1031 |
+
logger.error(f"CSV parsing failed: {e}", exc_info=True)
|
| 1032 |
+
raise HTTPException(
|
| 1033 |
+
status_code=400,
|
| 1034 |
+
detail="Unable to parse the CSV file. Please ensure it's a valid CSV format."
|
| 1035 |
+
)
|
credily/api/schemas.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for API request/response validation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Dict, List, Any, Optional
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# ============== Training Schemas ==============

class TrainConfig(BaseModel):
    """Configuration for model training."""
    target_column: str = Field(default="target", description="Name of the target column")
    test_size: float = Field(default=0.2, ge=0.1, le=0.5, description="Test set proportion")
    cv_folds: int = Field(default=5, ge=2, le=10, description="Cross-validation folds")
    clean_data: bool = Field(default=True, description="Whether to clean data before training")
    clean_mode: str = Field(default="thorough", description="Cleaning mode: basic, thorough, aggressive")
    flag_missing: bool = Field(default=True, description="Create missing value indicator columns (set False if *_missing features dominate)")
    balance_data: bool = Field(default=True, description="Whether to balance imbalanced classes")
    balance_method: str = Field(default="smote", description="Balancing method: smote, random_oversample, random_undersample, smote_tomek, tomek, nearmiss, none")
    calibrate: bool = Field(default=True, description="Whether to calibrate probabilities")
    optimize_threshold: bool = Field(default=True, description="Whether to optimize classification threshold")
    conservative_mode: str = Field(default="auto", description="Regularization mode: auto (detect small datasets), always, never")
    # Agnostic pipeline options
    task_type: str = Field(default="binary", description="Classification type: 'binary' for binary classification, 'multiclass' for multi-class classification")
    binary_threshold: Optional[float] = Field(default=None, description="Threshold to convert numeric target to binary (values BELOW threshold = positive class). Example: 600 for credit score means score < 600 = default")
    positive_classes: Optional[List[str]] = Field(default=None, description="List of class labels to treat as positive (1) for binary grouping. All other classes become negative (0). Example: ['Poor', 'Standard'] → these become 1, 'Good' becomes 0")


class TrainRequest(BaseModel):
    """Request body for training a model."""
    data: List[Dict[str, Any]] = Field(..., description="Training data as list of records")
    config: Optional[TrainConfig] = Field(default=None, description="Training configuration")
    model_name: Optional[str] = Field(default=None, description="Custom name for the model")


class TestPredictions(BaseModel):
    """Test set predictions for visualization."""
    # Ground-truth labels, hard predictions, and positive-class probabilities
    # are parallel lists of length n_samples.
    y_true: List[int]
    y_pred: List[int]
    y_proba: List[float]
    n_samples: int


class CurveData(BaseModel):
    """Data points for ROC or PR curves."""
    x: List[float]  # FPR for ROC, Recall for PR
    y: List[float]  # TPR for ROC, Precision for PR


class SafetyReportSchema(BaseModel):
    """Safety validation report."""
    status: str  # PASS, WARN, FAIL
    model_valid: bool
    dropped_features: Dict[str, str] = {}  # feature: reason
    warnings: List[str] = []
    errors: List[str] = []
    leakage_detected: Dict[str, float] = {}  # feature: correlation
    redundant_features: List[Dict[str, Any]] = []
    feature_dominance: Dict[str, float] = {}
    overfitting_metrics: Dict[str, Any] = {}


class ModelTestMetrics(BaseModel):
    """Test set metrics for a single model."""
    pr_auc: float
    roc_auc: float
    default_recall: float
    fp_count: int
    threshold: float
    cv_score: float


class TrainResponse(BaseModel):
    """Response from model training."""
    success: bool
    report_id: str
    model_name: str
    best_model: str
    best_score: float
    test_auc: float
    test_pr_auc: float
    optimal_threshold: float
    model_scores: Dict[str, float]
    model_test_metrics: Optional[Dict[str, ModelTestMetrics]] = None  # Full test metrics for each model
    model_ranking: Optional[List[str]] = None  # Ordered list of model names by rank
    classification_report: Dict[str, Any]
    confusion_matrix: List[List[int]]
    feature_importances: Dict[str, float]
    message: str
    download_url: str  # URL to download the model zip file
    # Safety validation
    model_valid: bool = True  # Whether model passes all safety checks
    safety_report: Optional[SafetyReportSchema] = None
    # Model performance visualization data
    test_predictions: Optional[TestPredictions] = None
    roc_curve: Optional[CurveData] = None
    pr_curve: Optional[CurveData] = None
    sanity_warnings: Optional[List[str]] = None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ============== Prediction Schemas ==============

class PredictRequest(BaseModel):
    """Request body for making predictions."""
    data: List[Dict[str, Any]] = Field(..., description="Data to predict on as list of records")
    model_path: str = Field(..., description="Absolute path to the trained model file (.pkl)")
    include_proba: bool = Field(default=True, description="Whether to include probability scores")
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Custom threshold")
    save_results: bool = Field(default=True, description="Whether to save results to database")


class PredictionResult(BaseModel):
    """Single prediction result."""
    # Row index in the submitted data; probability/risk_level are only set
    # when the model produced probability scores.
    index: int
    prediction: int
    probability: Optional[float] = None
    risk_level: Optional[str] = None


class PredictResponse(BaseModel):
    """Response from prediction."""
    success: bool
    session_id: Optional[str] = None  # set only when results were persisted
    model_path: str
    threshold_used: float
    predictions: List[PredictionResult]
    summary: Dict[str, Any]


class SinglePredictRequest(BaseModel):
    """Request for single prediction."""
    data: Dict[str, Any] = Field(..., description="Single record to predict")
    model_path: str = Field(..., description="Absolute path to the trained model file (.pkl)")
    threshold: Optional[float] = Field(default=None, ge=0.0, le=1.0, description="Custom threshold")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ============== Profiling Schemas ==============

class ProfileRequest(BaseModel):
    """Request body for data profiling."""
    data: List[Dict[str, Any]] = Field(..., description="Data to profile as list of records")
    target_column: Optional[str] = Field(default=None, description="Target column name if known")


class ColumnProfile(BaseModel):
    """Profile of a single column."""
    name: str
    dtype: str
    non_null_count: int
    null_count: int
    null_percentage: float  # 0-100 scale
    unique_count: int
    sample_values: List[Any]  # up to a few non-null example values


class ProfileResponse(BaseModel):
    """Response from data profiling."""
    success: bool
    n_rows: int
    n_columns: int
    columns: List[ColumnProfile]
    numeric_summary: Optional[Dict[str, Any]] = None  # describe() output for numeric cols
    target_analysis: Optional[Dict[str, Any]] = None  # class distribution/balance info
    recommendations: List[str]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ============== History Schemas ==============
|
| 169 |
+
|
| 170 |
+
class PredictionSessionInfo(BaseModel):
|
| 171 |
+
"""Summary info about a prediction session."""
|
| 172 |
+
id: str
|
| 173 |
+
model_path: str
|
| 174 |
+
model_name: Optional[str]
|
| 175 |
+
threshold_used: float
|
| 176 |
+
total_records: int
|
| 177 |
+
predicted_positive: int
|
| 178 |
+
predicted_negative: int
|
| 179 |
+
positive_rate: float
|
| 180 |
+
avg_probability: Optional[float]
|
| 181 |
+
created_at: str
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class PredictionHistoryResponse(BaseModel):
    """Response listing prediction history.

    ``total`` is the overall session count; ``sessions`` may be a page of it.
    """
    success: bool  # True when the lookup succeeded
    sessions: List[PredictionSessionInfo]  # Session summaries, most likely newest-first
    total: int  # Total number of sessions available
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class PredictionSessionDetailResponse(BaseModel):
    """Detailed response for a single prediction session.

    Extends the summary with the per-record prediction results.
    """
    success: bool  # True when the session was found
    session: PredictionSessionInfo  # Summary metadata for the session
    results: List[Dict[str, Any]]  # Per-record prediction payloads (schema defined by the API layer)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class TrainingReportInfo(BaseModel):
    """Summary info about a training report."""
    id: str  # Report identifier
    model_name: Optional[str]  # Human-readable model name, if recorded
    best_model: str  # Name of the winning model (e.g. "GradientBoosting")
    best_score: float  # Cross-validation score of the winning model
    test_auc: float  # ROC-AUC on the held-out test set
    test_pr_auc: float  # Precision-recall AUC on the held-out test set
    optimal_threshold: float  # Decision threshold chosen during training
    created_at: str  # Creation timestamp (string-encoded)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class TrainingReportsResponse(BaseModel):
    """Response listing training reports.

    ``total`` is the overall report count; ``reports`` may be a page of it.
    """
    success: bool  # True when the lookup succeeded
    reports: List[TrainingReportInfo]  # Report summaries
    total: int  # Total number of reports available
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class TrainingReportDetailResponse(BaseModel):
    """Detailed response for a training report.

    ``report`` is the full (free-form) training-results dictionary as
    produced by the training pipeline.
    """
    success: bool  # True when the report was found
    report: Dict[str, Any]  # Complete training report payload
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# ============== Health Check Schemas ==============
|
| 224 |
+
|
| 225 |
+
class HealthResponse(BaseModel):
    """Health check response."""
    status: str  # Overall service status (e.g. "ok")
    version: str  # API / package version string
    database_status: str  # Connectivity status of the backing database
|
credily/automl.py
ADDED
|
@@ -0,0 +1,1073 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AutoML pipeline module for Credily.
|
| 3 |
+
Trains and compares multiple models with cross-validation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
import warnings
|
| 10 |
+
import joblib
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
# Suppress joblib resource tracker warnings on Windows
|
| 15 |
+
if sys.platform == 'win32':
|
| 16 |
+
os.environ.setdefault('LOKY_PICKLER', 'pickle')
|
| 17 |
+
warnings.filterwarnings('ignore', message='.*resource_tracker.*')
|
| 18 |
+
warnings.filterwarnings('ignore', message='.*Cannot register.*')
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Optional, Dict, Any, List
|
| 21 |
+
|
| 22 |
+
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
|
| 23 |
+
from sklearn.pipeline import Pipeline
|
| 24 |
+
from sklearn.compose import ColumnTransformer
|
| 25 |
+
from sklearn.impute import SimpleImputer
|
| 26 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
| 27 |
+
from sklearn.linear_model import LogisticRegression
|
| 28 |
+
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
|
| 29 |
+
from sklearn.metrics import (
|
| 30 |
+
classification_report, roc_auc_score, precision_recall_curve,
|
| 31 |
+
confusion_matrix, f1_score, average_precision_score
|
| 32 |
+
)
|
| 33 |
+
from sklearn.calibration import CalibratedClassifierCV
|
| 34 |
+
|
| 35 |
+
from .profiler import DataProfiler
|
| 36 |
+
from .reporting import ReportGenerator
|
| 37 |
+
from .cleaning import DataCleaner
|
| 38 |
+
from .balancing import DataBalancer
|
| 39 |
+
from .agnostic_pipeline import AgnosticPipeline
|
| 40 |
+
from .safety import SafetyValidator, SafetyConfig, check_perfect_score_warning
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============== Sanity Check Thresholds ==============
|
| 44 |
+
LEAKAGE_CORRELATION_THRESHOLD = 0.95 # Features with >95% correlation to target are suspicious
|
| 45 |
+
CV_TEST_DROP_THRESHOLD = 0.05 # Warn if test score drops >5% from CV score
|
| 46 |
+
FEATURE_DOMINANCE_THRESHOLD = 0.50 # Warn if single feature has >50% importance
|
| 47 |
+
MIN_MINORITY_SAMPLES = 50 # Minimum minority class samples for reliable PR-AUC
|
| 48 |
+
|
| 49 |
+
# ============== Small Dataset Thresholds ==============
|
| 50 |
+
SMALL_DATASET_THRESHOLD = 10000 # Datasets below this trigger conservative mode
|
| 51 |
+
VERY_SMALL_DATASET_THRESHOLD = 5000 # Datasets below this use extra regularization
|
| 52 |
+
|
| 53 |
+
# ============== Model Selection Thresholds ==============
|
| 54 |
+
RANDOM_FOREST_MIN_SAMPLES = 20000 # Only train RF for datasets > 20k rows
|
| 55 |
+
RF_PRAUC_TOLERANCE = 0.02 # Discard RF if PR-AUC < GB PR-AUC by this margin
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CredilyPipeline:
|
| 59 |
+
"""
|
| 60 |
+
AutoML pipeline for binary classification tasks.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
    def __init__(
        self,
        target_column: str = 'target',
        output_dir: str = 'credily_output',
        test_size: float = 0.2,
        cv_folds: int = 5,
        random_state: int = 42,
        # Cleaning options
        clean_data: bool = True,
        clean_mode: str = 'thorough',
        outlier_method: str = 'iqr',
        flag_missing: bool = True,  # Create missing value indicator columns
        # Balancing options
        balance_data: bool = True,
        balance_method: str = 'smote',
        # Parallel processing (default=1 to avoid Windows joblib issues)
        n_jobs: int = 1,
        # Advanced options
        calibrate: bool = True,
        calibration_method: str = 'isotonic',
        optimize_threshold: bool = True,
        threshold_metric: str = 'f1',
        # Regularization options (auto-tuned for small datasets)
        conservative_mode: str = 'auto',  # 'auto', 'always', 'never'
        # Agnostic pipeline options (for dynamic target handling)
        binary_threshold: Optional[float] = None,  # Threshold to convert numeric target to binary
        binary_rule: Optional[callable] = None,  # Custom function for target conversion
        positive_classes: Optional[list] = None  # List of classes to treat as positive for binary grouping
    ):
        """Configure the AutoML pipeline; no work is done until ``train``.

        Args:
            target_column: Name of the label column in training frames.
            output_dir: Directory where artifacts/reports are written.
            test_size: Fraction of data held out for the test split.
            cv_folds: Number of stratified CV folds used for model selection.
            random_state: Seed for all randomized steps.
            clean_data: Run :class:`DataCleaner` before training.
            clean_mode: Cleaning aggressiveness passed to the cleaner.
            outlier_method: Outlier detection strategy (e.g. ``'iqr'``).
            flag_missing: Add missing-value indicator columns while cleaning.
            balance_data: Rebalance the *training* split only.
            balance_method: Balancing strategy (e.g. ``'smote'``).
            n_jobs: Parallelism for models/CV (1 avoids Windows joblib issues).
            calibrate: Wrap the fitted model in probability calibration.
            calibration_method: Calibration type (e.g. ``'isotonic'``).
            optimize_threshold: Search for a decision threshold on the test set.
            threshold_metric: Metric optimized when searching the threshold.
            conservative_mode: ``'auto'`` regularizes small datasets;
                ``'always'``/``'never'`` force the choice.
            binary_threshold: If set, numeric targets >= threshold become 1.
            binary_rule: Optional callable mapping raw target -> binary label.
            positive_classes: Classes grouped as the positive label.
        """
        self.target_column = target_column
        self.output_dir = Path(output_dir)
        self.test_size = test_size
        self.cv_folds = cv_folds
        self.random_state = random_state

        # Cleaning and balancing options
        self.clean_data = clean_data
        self.clean_mode = clean_mode
        self.outlier_method = outlier_method
        self.flag_missing = flag_missing
        self.balance_data = balance_data
        self.balance_method = balance_method
        self.n_jobs = n_jobs

        # Advanced options
        self.calibrate = calibrate
        self.calibration_method = calibration_method
        self.optimize_threshold = optimize_threshold
        self.threshold_metric = threshold_metric
        self.optimal_threshold = 0.5  # Default decision threshold until optimized
        self.conservative_mode = conservative_mode
        self.is_small_dataset = False  # Set during training

        # Agnostic pipeline options
        self.binary_threshold = binary_threshold
        self.binary_rule = binary_rule
        self.positive_classes = positive_classes
        self.agnostic_pipeline = None  # Initialized when needed

        self.preprocessor = None
        self.best_model = None
        self.best_model_name = None
        self.best_score = None
        self.feature_names = None
        self.numeric_columns = None
        self.categorical_columns = None
        self.expected_columns = None  # All expected input columns (for prediction alignment)
        self.profiler = DataProfiler(target_column=target_column)
        self.profile_report = None
        self.training_results = None
        self.cleaning_report = None
        self.balancing_report = None
        self.agnostic_report = None  # Report from AgnosticPipeline
        self.class_ratio = None
        self.sanity_warnings = []  # Sanity check warnings
        self.safety_report = None  # Safety validation report

        # Models will be initialized after we know the class ratio
        self.models = None
|
| 142 |
+
|
| 143 |
+
def _init_models(self, class_ratio: float = 1.0, n_samples: int = 0) -> Dict[str, Any]:
|
| 144 |
+
"""
|
| 145 |
+
Initialize all models with proper n_jobs and class weight settings.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
class_ratio: Ratio of negative to positive samples for scale_pos_weight
|
| 149 |
+
n_samples: Number of samples in dataset (for conservative mode tuning)
|
| 150 |
+
"""
|
| 151 |
+
# Determine if we should use conservative mode
|
| 152 |
+
use_conservative = False
|
| 153 |
+
if self.conservative_mode == 'always':
|
| 154 |
+
use_conservative = True
|
| 155 |
+
elif self.conservative_mode == 'auto' and n_samples < SMALL_DATASET_THRESHOLD:
|
| 156 |
+
use_conservative = True
|
| 157 |
+
self.is_small_dataset = True
|
| 158 |
+
|
| 159 |
+
# Extra conservative for very small datasets
|
| 160 |
+
very_small = n_samples < VERY_SMALL_DATASET_THRESHOLD
|
| 161 |
+
|
| 162 |
+
if use_conservative:
|
| 163 |
+
print(f"\n[CONSERVATIVE MODE] Dataset has {n_samples} samples - using regularized hyperparameters")
|
| 164 |
+
|
| 165 |
+
# Configure hyperparameters based on dataset size
|
| 166 |
+
if use_conservative:
|
| 167 |
+
# Conservative settings to prevent overfitting
|
| 168 |
+
rf_depth = 6 if very_small else 8
|
| 169 |
+
rf_estimators = 100
|
| 170 |
+
rf_min_samples_leaf = 10 if very_small else 5
|
| 171 |
+
|
| 172 |
+
gb_depth = 3 if very_small else 4
|
| 173 |
+
gb_estimators = 100
|
| 174 |
+
gb_learning_rate = 0.05 if very_small else 0.1
|
| 175 |
+
gb_subsample = 0.8
|
| 176 |
+
gb_min_samples_leaf = 10 if very_small else 5
|
| 177 |
+
|
| 178 |
+
xgb_depth = 4 if very_small else 5
|
| 179 |
+
xgb_estimators = 100
|
| 180 |
+
xgb_learning_rate = 0.05 if very_small else 0.1
|
| 181 |
+
xgb_subsample = 0.8
|
| 182 |
+
xgb_colsample = 0.8
|
| 183 |
+
|
| 184 |
+
lgb_depth = 4 if very_small else 5
|
| 185 |
+
lgb_estimators = 100
|
| 186 |
+
lgb_learning_rate = 0.05 if very_small else 0.1
|
| 187 |
+
else:
|
| 188 |
+
# Standard settings for larger datasets
|
| 189 |
+
rf_depth = 10
|
| 190 |
+
rf_estimators = 200
|
| 191 |
+
rf_min_samples_leaf = 1
|
| 192 |
+
|
| 193 |
+
gb_depth = 5
|
| 194 |
+
gb_estimators = 100
|
| 195 |
+
gb_learning_rate = 0.1
|
| 196 |
+
gb_subsample = 1.0
|
| 197 |
+
gb_min_samples_leaf = 1
|
| 198 |
+
|
| 199 |
+
xgb_depth = 6
|
| 200 |
+
xgb_estimators = 200
|
| 201 |
+
xgb_learning_rate = 0.1
|
| 202 |
+
xgb_subsample = 1.0
|
| 203 |
+
xgb_colsample = 1.0
|
| 204 |
+
|
| 205 |
+
lgb_depth = 6
|
| 206 |
+
lgb_estimators = 200
|
| 207 |
+
lgb_learning_rate = 0.1
|
| 208 |
+
|
| 209 |
+
# Core models: Logistic Regression (baseline) + Gradient Boosting (champion)
|
| 210 |
+
models = {
|
| 211 |
+
'LogisticRegression': LogisticRegression(
|
| 212 |
+
max_iter=1000,
|
| 213 |
+
class_weight='balanced',
|
| 214 |
+
random_state=self.random_state
|
| 215 |
+
# Note: n_jobs removed - deprecated in sklearn 1.8+
|
| 216 |
+
),
|
| 217 |
+
'GradientBoosting': GradientBoostingClassifier(
|
| 218 |
+
n_estimators=gb_estimators,
|
| 219 |
+
max_depth=gb_depth,
|
| 220 |
+
learning_rate=gb_learning_rate,
|
| 221 |
+
subsample=gb_subsample,
|
| 222 |
+
min_samples_leaf=gb_min_samples_leaf,
|
| 223 |
+
random_state=self.random_state
|
| 224 |
+
),
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
# RandomForest: Only train for larger datasets (>20k rows) with strict regularization
|
| 228 |
+
# RF tends to overfit on smaller datasets - GB is more reliable
|
| 229 |
+
if n_samples >= RANDOM_FOREST_MIN_SAMPLES:
|
| 230 |
+
models['RandomForest'] = RandomForestClassifier(
|
| 231 |
+
n_estimators=rf_estimators,
|
| 232 |
+
max_depth=min(rf_depth, 8), # Strict regularization: max_depth <= 8
|
| 233 |
+
min_samples_leaf=max(rf_min_samples_leaf, 5),
|
| 234 |
+
class_weight='balanced',
|
| 235 |
+
random_state=self.random_state,
|
| 236 |
+
n_jobs=self.n_jobs
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Try to import optional models
|
| 240 |
+
try:
|
| 241 |
+
from xgboost import XGBClassifier
|
| 242 |
+
# Use scale_pos_weight = (negatives / positives) for imbalanced data
|
| 243 |
+
models['XGBoost'] = XGBClassifier(
|
| 244 |
+
n_estimators=xgb_estimators,
|
| 245 |
+
max_depth=xgb_depth,
|
| 246 |
+
learning_rate=xgb_learning_rate,
|
| 247 |
+
subsample=xgb_subsample,
|
| 248 |
+
colsample_bytree=xgb_colsample,
|
| 249 |
+
scale_pos_weight=class_ratio,
|
| 250 |
+
random_state=self.random_state,
|
| 251 |
+
n_jobs=self.n_jobs,
|
| 252 |
+
use_label_encoder=False,
|
| 253 |
+
eval_metric='logloss'
|
| 254 |
+
)
|
| 255 |
+
except ImportError:
|
| 256 |
+
pass
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
from lightgbm import LGBMClassifier
|
| 260 |
+
models['LightGBM'] = LGBMClassifier(
|
| 261 |
+
n_estimators=lgb_estimators,
|
| 262 |
+
max_depth=lgb_depth,
|
| 263 |
+
learning_rate=lgb_learning_rate,
|
| 264 |
+
class_weight='balanced',
|
| 265 |
+
random_state=self.random_state,
|
| 266 |
+
n_jobs=self.n_jobs,
|
| 267 |
+
verbose=-1
|
| 268 |
+
)
|
| 269 |
+
except ImportError:
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
return models
|
| 273 |
+
|
| 274 |
+
def profile(self, df: pd.DataFrame) -> Dict[str, Any]:
|
| 275 |
+
"""Profile the dataset."""
|
| 276 |
+
self.profile_report = self.profiler.profile(df)
|
| 277 |
+
return self.profile_report
|
| 278 |
+
|
| 279 |
+
def train(self, df: pd.DataFrame) -> Dict[str, Any]:
|
| 280 |
+
"""
|
| 281 |
+
Train multiple models and select the best performer.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
df: Training dataframe with features and target
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
dict: Training results with best model and scores
|
| 288 |
+
"""
|
| 289 |
+
if self.target_column not in df.columns:
|
| 290 |
+
raise ValueError(f"Target column '{self.target_column}' not found")
|
| 291 |
+
|
| 292 |
+
# Create output directory
|
| 293 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 294 |
+
|
| 295 |
+
# Step 1: Clean data
|
| 296 |
+
if self.clean_data:
|
| 297 |
+
cleaner = DataCleaner(
|
| 298 |
+
target_column=self.target_column,
|
| 299 |
+
clean_mode=self.clean_mode,
|
| 300 |
+
outlier_method=self.outlier_method,
|
| 301 |
+
flag_missing=self.flag_missing # Control missing value indicators
|
| 302 |
+
)
|
| 303 |
+
df = cleaner.clean(df)
|
| 304 |
+
self.cleaning_report = cleaner.get_report()
|
| 305 |
+
|
| 306 |
+
# Step 1.2: Safety Validation - Pre-training checks
|
| 307 |
+
print("\n[SAFETY] Running pre-training safety validation...")
|
| 308 |
+
safety_validator = SafetyValidator(verbose=True)
|
| 309 |
+
df, self.safety_report = safety_validator.run_pre_training_checks(df, self.target_column)
|
| 310 |
+
|
| 311 |
+
# Add dropped features to sanity warnings
|
| 312 |
+
if self.safety_report.dropped_features:
|
| 313 |
+
self.sanity_warnings.append(
|
| 314 |
+
f"Safety: Dropped {len(self.safety_report.dropped_features)} features due to leakage/redundancy"
|
| 315 |
+
)
|
| 316 |
+
for warning in self.safety_report.warnings:
|
| 317 |
+
self.sanity_warnings.append(f"Safety: {warning}")
|
| 318 |
+
|
| 319 |
+
# Step 1.5: Apply AgnosticPipeline for dynamic target handling (if configured)
|
| 320 |
+
if self.binary_threshold is not None or self.binary_rule is not None or self.positive_classes is not None:
|
| 321 |
+
print("\n[AGNOSTIC PIPELINE] Dynamic target transformation enabled")
|
| 322 |
+
self.agnostic_pipeline = AgnosticPipeline(
|
| 323 |
+
binary_threshold=self.binary_threshold,
|
| 324 |
+
binary_rule=self.binary_rule,
|
| 325 |
+
positive_classes=self.positive_classes,
|
| 326 |
+
task_type='binary',
|
| 327 |
+
flag_missing=self.flag_missing,
|
| 328 |
+
verbose=True
|
| 329 |
+
)
|
| 330 |
+
X, y = self.agnostic_pipeline.fit_transform(df, self.target_column)
|
| 331 |
+
self.agnostic_report = self.agnostic_pipeline.get_report()
|
| 332 |
+
else:
|
| 333 |
+
X = df.drop(columns=[self.target_column])
|
| 334 |
+
y = df[self.target_column]
|
| 335 |
+
|
| 336 |
+
n_samples = len(X)
|
| 337 |
+
|
| 338 |
+
# Validate we have exactly 2 classes for binary classification
|
| 339 |
+
class_counts = y.value_counts()
|
| 340 |
+
n_classes = len(class_counts)
|
| 341 |
+
if n_classes < 2:
|
| 342 |
+
raise ValueError(
|
| 343 |
+
f"Binary classification requires 2 classes, but found {n_classes}. "
|
| 344 |
+
f"Classes: {list(class_counts.index)}. "
|
| 345 |
+
"Check if data cleaning removed one class or if target column has issues."
|
| 346 |
+
)
|
| 347 |
+
elif n_classes > 2:
|
| 348 |
+
print(f"Warning: Found {n_classes} classes. Treating as multi-class problem.")
|
| 349 |
+
|
| 350 |
+
# Calculate class ratio for XGBoost scale_pos_weight
|
| 351 |
+
if len(class_counts) == 2:
|
| 352 |
+
n_negative = class_counts.get(0, class_counts.iloc[0])
|
| 353 |
+
n_positive = class_counts.get(1, class_counts.iloc[1])
|
| 354 |
+
self.class_ratio = n_negative / n_positive if n_positive > 0 else 1.0
|
| 355 |
+
print(f"Class ratio (neg/pos): {self.class_ratio:.2f}")
|
| 356 |
+
else:
|
| 357 |
+
self.class_ratio = 1.0
|
| 358 |
+
|
| 359 |
+
# Initialize models with class ratio and sample count for proper tuning
|
| 360 |
+
self.models = self._init_models(class_ratio=self.class_ratio, n_samples=n_samples)
|
| 361 |
+
|
| 362 |
+
# Identify column types
|
| 363 |
+
self.numeric_columns = X.select_dtypes(include=[np.number]).columns.tolist()
|
| 364 |
+
self.categorical_columns = X.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 365 |
+
|
| 366 |
+
# Store all expected columns for prediction alignment
|
| 367 |
+
self.expected_columns = X.columns.tolist()
|
| 368 |
+
|
| 369 |
+
# Create preprocessor
|
| 370 |
+
self.preprocessor = self._create_preprocessor()
|
| 371 |
+
|
| 372 |
+
# Split data
|
| 373 |
+
X_train, X_test, y_train, y_test = train_test_split(
|
| 374 |
+
X, y,
|
| 375 |
+
test_size=self.test_size,
|
| 376 |
+
stratify=y,
|
| 377 |
+
random_state=self.random_state
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Step 2: Balance training data (only on training set to avoid data leakage)
|
| 381 |
+
if self.balance_data:
|
| 382 |
+
balancer = DataBalancer(
|
| 383 |
+
method=self.balance_method,
|
| 384 |
+
random_state=self.random_state
|
| 385 |
+
)
|
| 386 |
+
X_train, y_train = balancer.balance(X_train, y_train)
|
| 387 |
+
self.balancing_report = balancer.get_report()
|
| 388 |
+
|
| 389 |
+
# Train and evaluate models
|
| 390 |
+
model_scores = {}
|
| 391 |
+
cv = StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state)
|
| 392 |
+
|
| 393 |
+
print(f"\nTraining {len(self.models)} models with {self.cv_folds}-fold CV...")
|
| 394 |
+
print("-" * 50)
|
| 395 |
+
|
| 396 |
+
for name, model in self.models.items():
|
| 397 |
+
pipeline = Pipeline([
|
| 398 |
+
('preprocessor', self.preprocessor),
|
| 399 |
+
('classifier', model)
|
| 400 |
+
])
|
| 401 |
+
|
| 402 |
+
try:
|
| 403 |
+
scores = cross_val_score(
|
| 404 |
+
pipeline, X_train, y_train,
|
| 405 |
+
cv=cv, scoring='roc_auc', n_jobs=self.n_jobs
|
| 406 |
+
)
|
| 407 |
+
mean_score = scores.mean()
|
| 408 |
+
std_score = scores.std()
|
| 409 |
+
model_scores[name] = mean_score
|
| 410 |
+
print(f" {name}: ROC-AUC = {mean_score:.4f} (+/- {std_score:.4f})")
|
| 411 |
+
except Exception as e:
|
| 412 |
+
print(f" {name}: Failed - {str(e)}")
|
| 413 |
+
model_scores[name] = 0.0
|
| 414 |
+
|
| 415 |
+
# Auto-discard RandomForest if it underperforms GradientBoosting
|
| 416 |
+
# This prevents RF from being selected when it overfits
|
| 417 |
+
if 'RandomForest' in model_scores and 'GradientBoosting' in model_scores:
|
| 418 |
+
rf_score = model_scores['RandomForest']
|
| 419 |
+
gb_score = model_scores['GradientBoosting']
|
| 420 |
+
if rf_score < gb_score - RF_PRAUC_TOLERANCE:
|
| 421 |
+
print(f" [AUTO-DISCARD] RandomForest ({rf_score:.4f}) underperforms GradientBoosting ({gb_score:.4f}) - removing from candidates")
|
| 422 |
+
del model_scores['RandomForest']
|
| 423 |
+
|
| 424 |
+
print("-" * 50)
|
| 425 |
+
print(f"\nTraining and evaluating all candidate models on test set...")
|
| 426 |
+
print("-" * 50)
|
| 427 |
+
|
| 428 |
+
# Train ALL candidate models on full training data and evaluate on test set
|
| 429 |
+
# This enables proper ranking based on test set metrics
|
| 430 |
+
trained_models = {}
|
| 431 |
+
model_test_metrics = {}
|
| 432 |
+
|
| 433 |
+
for name in model_scores.keys():
|
| 434 |
+
print(f"\n Training {name}...")
|
| 435 |
+
pipeline = Pipeline([
|
| 436 |
+
('preprocessor', self.preprocessor),
|
| 437 |
+
('classifier', self.models[name])
|
| 438 |
+
])
|
| 439 |
+
pipeline.fit(X_train, y_train)
|
| 440 |
+
|
| 441 |
+
# Temporarily store feature names for this model
|
| 442 |
+
temp_model = self.best_model
|
| 443 |
+
self.best_model = pipeline
|
| 444 |
+
self._extract_feature_names()
|
| 445 |
+
|
| 446 |
+
# Apply calibration if enabled
|
| 447 |
+
if self.calibrate:
|
| 448 |
+
pipeline = self._calibrate_model(X_train, y_train)
|
| 449 |
+
|
| 450 |
+
self.best_model = temp_model # Restore
|
| 451 |
+
|
| 452 |
+
trained_models[name] = pipeline
|
| 453 |
+
|
| 454 |
+
# Evaluate on test set
|
| 455 |
+
y_proba = pipeline.predict_proba(X_test)[:, 1]
|
| 456 |
+
|
| 457 |
+
# Find optimal threshold for this model
|
| 458 |
+
if self.optimize_threshold:
|
| 459 |
+
threshold = self._find_optimal_threshold(y_test, y_proba)
|
| 460 |
+
else:
|
| 461 |
+
threshold = 0.5
|
| 462 |
+
|
| 463 |
+
y_pred = (y_proba >= threshold).astype(int)
|
| 464 |
+
|
| 465 |
+
# Compute metrics for ranking
|
| 466 |
+
pr_auc = average_precision_score(y_test, y_proba)
|
| 467 |
+
roc_auc = roc_auc_score(y_test, y_proba)
|
| 468 |
+
|
| 469 |
+
# Compute confusion matrix to get Default Recall and FP count
|
| 470 |
+
# Default class is 1 (positive class), recall = TP / (TP + FN)
|
| 471 |
+
cm = confusion_matrix(y_test, y_pred)
|
| 472 |
+
tn, fp, fn, tp = cm.ravel()
|
| 473 |
+
default_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
|
| 474 |
+
fp_count = fp
|
| 475 |
+
|
| 476 |
+
model_test_metrics[name] = {
|
| 477 |
+
'pr_auc': pr_auc,
|
| 478 |
+
'roc_auc': roc_auc,
|
| 479 |
+
'default_recall': default_recall,
|
| 480 |
+
'fp_count': fp_count,
|
| 481 |
+
'threshold': threshold,
|
| 482 |
+
'cv_score': model_scores[name]
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
print(f" CV ROC-AUC: {model_scores[name]:.4f}")
|
| 486 |
+
print(f" Test PR-AUC: {pr_auc:.4f} | Test ROC-AUC: {roc_auc:.4f}")
|
| 487 |
+
print(f" Default Recall: {default_recall:.4f} | FP Count: {fp_count}")
|
| 488 |
+
|
| 489 |
+
# Rank models: PR-AUC (desc) → Default Recall (desc) → FP Count (asc)
|
| 490 |
+
print("\n" + "-" * 50)
|
| 491 |
+
print("Model Ranking (PR-AUC → Default Recall → FP Count):")
|
| 492 |
+
print("-" * 50)
|
| 493 |
+
|
| 494 |
+
def model_ranking_key(name):
|
| 495 |
+
metrics = model_test_metrics[name]
|
| 496 |
+
# Higher PR-AUC is better (negate for descending)
|
| 497 |
+
# Higher Default Recall is better (negate for descending)
|
| 498 |
+
# Lower FP Count is better (keep positive for ascending)
|
| 499 |
+
return (-metrics['pr_auc'], -metrics['default_recall'], metrics['fp_count'])
|
| 500 |
+
|
| 501 |
+
ranked_models = sorted(model_test_metrics.keys(), key=model_ranking_key)
|
| 502 |
+
|
| 503 |
+
for rank, name in enumerate(ranked_models, 1):
|
| 504 |
+
m = model_test_metrics[name]
|
| 505 |
+
marker = " ★" if rank == 1 else ""
|
| 506 |
+
print(f" {rank}. {name}: PR-AUC={m['pr_auc']:.4f}, Recall={m['default_recall']:.4f}, FP={m['fp_count']}{marker}")
|
| 507 |
+
|
| 508 |
+
# Select the best model based on ranking
|
| 509 |
+
self.best_model_name = ranked_models[0]
|
| 510 |
+
self.best_model = trained_models[self.best_model_name]
|
| 511 |
+
self.best_score = model_scores[self.best_model_name] # CV score for consistency
|
| 512 |
+
self.optimal_threshold = model_test_metrics[self.best_model_name]['threshold']
|
| 513 |
+
|
| 514 |
+
# Store all model test metrics for reporting
|
| 515 |
+
self.model_test_metrics = model_test_metrics
|
| 516 |
+
|
| 517 |
+
print("-" * 50)
|
| 518 |
+
print(f"Selected model: {self.best_model_name}")
|
| 519 |
+
print(f" CV ROC-AUC: {self.best_score:.4f}")
|
| 520 |
+
print(f" Test PR-AUC: {model_test_metrics[self.best_model_name]['pr_auc']:.4f}")
|
| 521 |
+
print(f" Default Recall: {model_test_metrics[self.best_model_name]['default_recall']:.4f}")
|
| 522 |
+
print(f" FP Count: {model_test_metrics[self.best_model_name]['fp_count']}")
|
| 523 |
+
|
| 524 |
+
# Re-extract feature names for the selected model
|
| 525 |
+
self._extract_feature_names()
|
| 526 |
+
|
| 527 |
+
# Get test metrics for the selected model
|
| 528 |
+
y_proba = self.best_model.predict_proba(X_test)[:, 1]
|
| 529 |
+
test_auc = roc_auc_score(y_test, y_proba)
|
| 530 |
+
test_pr_auc = average_precision_score(y_test, y_proba)
|
| 531 |
+
|
| 532 |
+
# Optimize threshold if enabled
|
| 533 |
+
if self.optimize_threshold:
|
| 534 |
+
self.optimal_threshold = self._find_optimal_threshold(y_test, y_proba)
|
| 535 |
+
print(f"Optimal threshold ({self.threshold_metric}): {self.optimal_threshold:.3f}")
|
| 536 |
+
|
| 537 |
+
# Use optimal threshold for predictions
|
| 538 |
+
y_pred = (y_proba >= self.optimal_threshold).astype(int)
|
| 539 |
+
|
| 540 |
+
print(f"\nTest set evaluation (threshold={self.optimal_threshold:.3f}):")
|
| 541 |
+
print(classification_report(y_test, y_pred))
|
| 542 |
+
print(f"Test ROC-AUC: {test_auc:.4f}")
|
| 543 |
+
print(f"Test PR-AUC: {test_pr_auc:.4f}")
|
| 544 |
+
|
| 545 |
+
# Get feature importances
|
| 546 |
+
feature_importances = self._get_feature_importances()
|
| 547 |
+
|
| 548 |
+
# Run sanity checks
|
| 549 |
+
self.sanity_warnings = self._run_sanity_checks(
|
| 550 |
+
X=X,
|
| 551 |
+
y=y,
|
| 552 |
+
cv_score=self.best_score,
|
| 553 |
+
test_score=test_auc,
|
| 554 |
+
importances=feature_importances
|
| 555 |
+
)
|
| 556 |
+
self._print_sanity_warnings()
|
| 557 |
+
|
| 558 |
+
# Step 6: Post-training safety validation
|
| 559 |
+
print("\n[SAFETY] Running post-training safety validation...")
|
| 560 |
+
self.safety_report = safety_validator.run_post_training_checks(
|
| 561 |
+
feature_importances=feature_importances,
|
| 562 |
+
cv_score=self.best_score,
|
| 563 |
+
test_auc=test_auc
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
# Check for perfect scores (likely leakage)
|
| 567 |
+
perfect_score_warnings = check_perfect_score_warning({
|
| 568 |
+
'Test AUC': test_auc,
|
| 569 |
+
'Test PR-AUC': test_pr_auc,
|
| 570 |
+
'CV Score': self.best_score
|
| 571 |
+
})
|
| 572 |
+
for warning in perfect_score_warnings:
|
| 573 |
+
self.sanity_warnings.append(warning)
|
| 574 |
+
self.safety_report.add_warning(warning)
|
| 575 |
+
|
| 576 |
+
# Add post-training safety warnings/errors to sanity warnings
|
| 577 |
+
for error in self.safety_report.errors:
|
| 578 |
+
self.sanity_warnings.append(f"SAFETY FAIL: {error}")
|
| 579 |
+
for warning in self.safety_report.warnings:
|
| 580 |
+
if warning not in self.sanity_warnings:
|
| 581 |
+
self.sanity_warnings.append(f"Safety: {warning}")
|
| 582 |
+
|
| 583 |
+
# Compute ROC curve points for visualization
|
| 584 |
+
from sklearn.metrics import roc_curve, precision_recall_curve
|
| 585 |
+
fpr, tpr, roc_thresholds = roc_curve(y_test, y_proba)
|
| 586 |
+
precision_curve, recall_curve, pr_thresholds = precision_recall_curve(y_test, y_proba)
|
| 587 |
+
|
| 588 |
+
# Compile results
|
| 589 |
+
self.training_results = {
|
| 590 |
+
'best_model': self.best_model_name,
|
| 591 |
+
'best_score': self.best_score,
|
| 592 |
+
'test_auc': test_auc,
|
| 593 |
+
'test_pr_auc': test_pr_auc,
|
| 594 |
+
'optimal_threshold': self.optimal_threshold,
|
| 595 |
+
'class_ratio': self.class_ratio,
|
| 596 |
+
'model_scores': model_scores,
|
| 597 |
+
'model_test_metrics': model_test_metrics, # Full test metrics for all models
|
| 598 |
+
'model_ranking': ranked_models, # Ordered list of model names by rank
|
| 599 |
+
'classification_report': classification_report(y_test, y_pred, output_dict=True),
|
| 600 |
+
'confusion_matrix': confusion_matrix(y_test, y_pred).tolist(),
|
| 601 |
+
'feature_importances': feature_importances,
|
| 602 |
+
'cleaning_report': self.cleaning_report,
|
| 603 |
+
'balancing_report': self.balancing_report,
|
| 604 |
+
'agnostic_report': self.agnostic_report,
|
| 605 |
+
'safety_report': self.safety_report.to_dict() if self.safety_report else None,
|
| 606 |
+
'model_valid': self.safety_report.model_valid if self.safety_report else True,
|
| 607 |
+
'calibrated': self.calibrate,
|
| 608 |
+
'sanity_warnings': self.sanity_warnings, # Include warnings in results
|
| 609 |
+
# Test predictions for visualization
|
| 610 |
+
'test_predictions': {
|
| 611 |
+
'y_true': y_test.tolist(),
|
| 612 |
+
'y_pred': y_pred.tolist(),
|
| 613 |
+
'y_proba': y_proba.tolist(),
|
| 614 |
+
'n_samples': len(y_test)
|
| 615 |
+
},
|
| 616 |
+
# ROC and PR curve data for charting
|
| 617 |
+
'roc_curve': {
|
| 618 |
+
'fpr': fpr.tolist(),
|
| 619 |
+
'tpr': tpr.tolist()
|
| 620 |
+
},
|
| 621 |
+
'pr_curve': {
|
| 622 |
+
'precision': precision_curve.tolist(),
|
| 623 |
+
'recall': recall_curve.tolist()
|
| 624 |
+
}
|
| 625 |
+
}
|
| 626 |
+
|
| 627 |
+
# Save outputs
|
| 628 |
+
self._save_outputs()
|
| 629 |
+
|
| 630 |
+
return self.training_results
|
| 631 |
+
|
| 632 |
+
def predict(
    self,
    df: pd.DataFrame,
    include_proba: bool = False,
    threshold: Optional[float] = None
) -> pd.DataFrame:
    """
    Score new rows with the trained model.

    Args:
        df: Input dataframe with features (the target column, if present,
            is ignored for scoring but kept in the output).
        include_proba: Whether to include prediction probabilities.
        threshold: Custom decision threshold (uses the optimal one if None).

    Returns:
        Copy of *df* with a 'prediction' column, plus probability columns
        when requested.

    Raises:
        ValueError: If no trained model is available.
    """
    if self.best_model is None:
        raise ValueError("Model not trained. Call train() first or load a saved model.")

    # Strip the target column if the caller left it in.
    if self.target_column in df.columns:
        features = df.drop(columns=[self.target_column])
    else:
        features = df.copy()

    if self.agnostic_pipeline is not None:
        # The agnostic pipeline performs its own column alignment.
        features = self.agnostic_pipeline.transform(features)
        info = {}
    else:
        # Align prediction data with the columns seen during training.
        features, info = self._align_prediction_data(features)

    proba_matrix = self.best_model.predict_proba(features)
    cutoff = self.optimal_threshold if threshold is None else threshold

    result = df.copy()
    result['prediction'] = (proba_matrix[:, 1] >= cutoff).astype(int)

    if include_proba:
        result['proba_0'] = proba_matrix[:, 0]
        result['proba_1'] = proba_matrix[:, 1]
        result['threshold_used'] = cutoff

    # Surface alignment adjustments so the caller knows columns changed.
    if info.get('missing_columns') or info.get('extra_columns'):
        print(f"Column alignment applied: {info}")

    return result
|
| 684 |
+
|
| 685 |
+
def _align_prediction_data(self, X: pd.DataFrame) -> tuple:
    """
    Align prediction data to match the expected columns from training.

    This handles:
    1. Missing columns: added with NaN (imputed later by the pipeline)
    2. Extra columns: removed (not needed for prediction)
    3. Column order: reordered to match training order

    Args:
        X: Input dataframe for prediction

    Returns:
        Tuple of (aligned dataframe, alignment info dict)
    """
    alignment_info = {
        'missing_columns': [],
        'extra_columns': [],
        'columns_added_with_nan': []
    }

    if self.expected_columns is None:
        # Fallback for older saved models without expected_columns.
        if self.numeric_columns is not None and self.categorical_columns is not None:
            self.expected_columns = self.numeric_columns + self.categorical_columns

    if self.expected_columns is None:
        # No column info available at all, return as-is.
        return X, alignment_info

    missing = set(self.expected_columns) - set(X.columns)
    extra = set(X.columns) - set(self.expected_columns)

    alignment_info['missing_columns'] = list(missing)
    alignment_info['extra_columns'] = list(extra)
    # Both numeric and categorical missing columns are filled with NaN
    # (the original code had identical branches for the two cases), so
    # every missing column is recorded here.
    alignment_info['columns_added_with_nan'] = list(missing)

    # reindex adds the missing columns as NaN, drops the extras, and
    # restores the training column order in a single pass -- the previous
    # implementation added columns one at a time, fragmenting the frame.
    X = X.reindex(columns=self.expected_columns)

    return X, alignment_info
|
| 743 |
+
|
| 744 |
+
def _create_preprocessor(self) -> ColumnTransformer:
    """Assemble the numeric/categorical preprocessing ColumnTransformer."""
    # Numeric branch: median imputation followed by standardization.
    num_pipe = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ])

    # Categorical branch: mode imputation followed by dense one-hot
    # encoding that tolerates unseen categories at prediction time.
    cat_pipe = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False)),
    ])

    return ColumnTransformer(
        transformers=[
            ('num', num_pipe, self.numeric_columns),
            ('cat', cat_pipe, self.categorical_columns),
        ],
        remainder='drop',
    )
|
| 763 |
+
|
| 764 |
+
def _extract_feature_names(self):
    """Collect post-transform feature names from the fitted preprocessor."""
    transformer = self.best_model.named_steps['preprocessor']
    # Numeric columns keep their original names.
    names = list(self.numeric_columns)

    # One-hot encoding expands each categorical column into several
    # named indicator columns; append those expanded names.
    if self.categorical_columns:
        encoder = transformer.named_transformers_['cat'].named_steps['encoder']
        names += encoder.get_feature_names_out(self.categorical_columns).tolist()

    self.feature_names = names
|
| 775 |
+
|
| 776 |
+
def _get_feature_importances(self) -> Dict[str, float]:
    """
    Get feature importances from the trained model.

    Unwraps CalibratedClassifierCV if calibration was applied, then reads
    either tree-style ``feature_importances_`` or linear-model ``coef_``.

    Returns:
        Mapping of feature name -> importance, or {} when the model
        exposes no importances or feature names are unavailable.
    """
    # Handle calibrated models: CalibratedClassifierCV wraps the pipeline.
    if hasattr(self.best_model, 'calibrated_classifiers_'):
        base_classifier = self.best_model.calibrated_classifiers_[0].estimator
        if hasattr(base_classifier, 'named_steps'):
            classifier = base_classifier.named_steps['classifier']
        else:
            classifier = base_classifier
    elif hasattr(self.best_model, 'named_steps'):
        classifier = self.best_model.named_steps['classifier']
    else:
        return {}

    if hasattr(classifier, 'feature_importances_'):
        importances = classifier.feature_importances_
    elif hasattr(classifier, 'coef_'):
        # Linear models: use absolute coefficients of the positive class.
        importances = np.abs(classifier.coef_[0])
    else:
        return {}

    # Robustness fix: feature names may never have been extracted (e.g.
    # name extraction failed); zip(None, ...) would raise TypeError.
    if not self.feature_names:
        return {}

    return dict(zip(self.feature_names, importances.tolist()))
|
| 799 |
+
|
| 800 |
+
def _calibrate_model(self, X_train: pd.DataFrame, y_train: pd.Series) -> CalibratedClassifierCV:
    """
    Wrap the best model in probability calibration.

    Uses isotonic regression or Platt scaling depending on
    ``self.calibration_method``.

    Args:
        X_train: Training features
        y_train: Training labels

    Returns:
        Fitted CalibratedClassifierCV wrapping ``self.best_model``.
    """
    wrapper = CalibratedClassifierCV(
        self.best_model,
        method=self.calibration_method,
        cv=3  # internal 3-fold split keeps calibration data held out
    )
    wrapper.fit(X_train, y_train)
    return wrapper
|
| 818 |
+
|
| 819 |
+
def _find_optimal_threshold(self, y_true: pd.Series, y_proba: np.ndarray) -> float:
    """
    Find the best classification threshold for ``self.threshold_metric``.

    Args:
        y_true: True labels
        y_proba: Predicted probabilities for the positive class

    Returns:
        The selected threshold (defaults to 0.5 when no candidate beats
        the baseline or the metric name is unrecognized).
    """
    candidates = np.arange(0.1, 0.9, 0.01)
    chosen = 0.5

    if self.threshold_metric == 'f1':
        # Maximize F1 over the candidate grid.
        top = 0
        for cand in candidates:
            preds = (y_proba >= cand).astype(int)
            current = f1_score(y_true, preds, zero_division=0)
            if current > top:
                top, chosen = current, cand

    elif self.threshold_metric == 'precision_recall_balance':
        # Pick the threshold where precision and recall are closest.
        precision, recall, pr_thresholds = precision_recall_curve(y_true, y_proba)
        balance_gap = np.abs(precision[:-1] - recall[:-1])
        chosen = pr_thresholds[np.argmin(balance_gap)]

    elif self.threshold_metric == 'youden':
        # Youden's J statistic: sensitivity + specificity - 1.
        top = 0
        for cand in candidates:
            preds = (y_proba >= cand).astype(int)
            tn, fp, fn, tp = confusion_matrix(y_true, preds).ravel()
            sens = tp / (tp + fn) if (tp + fn) > 0 else 0
            spec = tn / (tn + fp) if (tn + fp) > 0 else 0
            j_stat = sens + spec - 1
            if j_stat > top:
                top, chosen = j_stat, cand

    elif self.threshold_metric == 'cost':
        # Cost-sensitive: a missed default (FN) costs 5x a false alarm (FP).
        fn_cost, fp_cost = 5.0, 1.0
        cheapest = float('inf')
        for cand in candidates:
            preds = (y_proba >= cand).astype(int)
            tn, fp, fn, tp = confusion_matrix(y_true, preds).ravel()
            expense = fn * fn_cost + fp * fp_cost
            if expense < cheapest:
                cheapest, chosen = expense, cand

    return chosen
|
| 877 |
+
|
| 878 |
+
def _save_outputs(self):
    """Persist the trained model, the JSON report, and the HTML report."""
    # Bundle everything needed to reload and serve the model later.
    model_path = self.output_dir / 'model.pkl'
    payload = {
        'pipeline': self.best_model,
        'model_name': self.best_model_name,
        'feature_names': self.feature_names,
        'numeric_columns': self.numeric_columns,
        'categorical_columns': self.categorical_columns,
        'expected_columns': self.expected_columns,  # For prediction alignment
        'target_column': self.target_column,
        'optimal_threshold': self.optimal_threshold,
        'class_ratio': self.class_ratio,
        'calibrated': self.calibrate,
        # Agnostic pipeline info
        'agnostic_pipeline': self.agnostic_pipeline,
        'binary_threshold': self.binary_threshold,
        'binary_rule': self.binary_rule,
        'positive_classes': self.positive_classes
    }
    joblib.dump(payload, model_path)
    print(f"\nModel saved to: {model_path}")

    # Machine-readable training results.
    report_path = self.output_dir / 'report.json'
    with open(report_path, 'w') as f:
        json.dump(self.training_results, f, indent=2, default=str)
    print(f"Report saved to: {report_path}")

    # Human-readable HTML report built from the same results.
    reporter = ReportGenerator(self.output_dir)
    reporter.generate_html_report(
        self.training_results,
        self.profile_report
    )
|
| 914 |
+
|
| 915 |
+
# ============== Sanity Checks ==============
|
| 916 |
+
|
| 917 |
+
def _check_leakage(self, X: pd.DataFrame, y: pd.Series) -> List[str]:
    """
    Flag numeric features whose correlation with the target is so high
    that it suggests data leakage.

    Returns:
        List of warning messages for suspicious features.
    """
    findings = []
    for col in X.select_dtypes(include=[np.number]).columns:
        try:
            strength = np.abs(X[col].corr(y))
        except Exception:
            continue  # skip columns that cannot be correlated
        if strength >= LEAKAGE_CORRELATION_THRESHOLD:
            findings.append(
                f"LEAKAGE WARNING: Feature '{col}' has {strength:.2%} correlation with target. "
                f"This may indicate data leakage - investigate if this feature contains future information."
            )
    return findings
|
| 940 |
+
|
| 941 |
+
def _check_minority_samples(self, y: pd.Series) -> List[str]:
    """
    Warn when the minority class is too small for reliable metrics.

    Returns:
        List of warning messages (empty when the class is large enough).
    """
    smallest = y.value_counts().min()
    if smallest >= MIN_MINORITY_SAMPLES:
        return []
    return [
        f"MINORITY CLASS WARNING: Only {smallest} samples in minority class. "
        f"PR-AUC and other metrics may be unreliable. Consider collecting more data "
        f"or using techniques like SMOTE cautiously."
    ]
|
| 960 |
+
|
| 961 |
+
def _check_cv_test_gap(self, cv_score: float, test_score: float) -> List[str]:
    """
    Warn when the test score falls well below the CV score, which
    indicates overfitting.

    Returns:
        List of warning messages (empty when the gap is acceptable).
    """
    gap = cv_score - test_score
    if gap <= CV_TEST_DROP_THRESHOLD:
        return []
    return [
        f"OVERFITTING WARNING: Test AUC ({test_score:.4f}) is {gap:.4f} lower than "
        f"CV AUC ({cv_score:.4f}). This indicates potential overfitting. "
        f"Consider: more regularization, simpler model, or more training data."
    ]
|
| 979 |
+
|
| 980 |
+
def _check_feature_dominance(self, importances: Dict[str, float]) -> List[str]:
    """
    Flag features whose share of total importance is suspiciously large
    (potential leakage or over-reliance on one variable).

    Returns:
        List of warning messages.
    """
    total = sum(importances.values())
    if total == 0:
        # Empty dict or all-zero importances: nothing to evaluate.
        return []

    shares = ((name, weight / total) for name, weight in importances.items())
    return [
        f"FEATURE DOMINANCE WARNING: Feature '{name}' accounts for {share:.1%} "
        f"of model importance. Investigate for potential leakage or consider "
        f"if the model is too dependent on a single variable."
        for name, share in shares
        if share > FEATURE_DOMINANCE_THRESHOLD
    ]
|
| 1005 |
+
|
| 1006 |
+
def _run_sanity_checks(
    self,
    X: pd.DataFrame,
    y: pd.Series,
    cv_score: float,
    test_score: float,
    importances: Dict[str, float]
) -> List[str]:
    """
    Run every sanity check and gather their warnings in a fixed order:
    leakage, minority-class size, CV/test gap, feature dominance.

    Returns:
        Combined list of all warning messages.
    """
    return (
        self._check_leakage(X, y)
        + self._check_minority_samples(y)
        + self._check_cv_test_gap(cv_score, test_score)
        + self._check_feature_dominance(importances)
    )
|
| 1035 |
+
|
| 1036 |
+
def _print_sanity_warnings(self):
    """Pretty-print the collected sanity warnings, if any."""
    if not self.sanity_warnings:
        return
    separator = '=' * 60
    print(f"\n{separator}")
    print("SANITY CHECK WARNINGS")
    print(f"{separator}")
    for idx, message in enumerate(self.sanity_warnings, 1):
        print(f"\n[{idx}] {message}")
    print(f"\n{separator}")
|
| 1045 |
+
|
| 1046 |
+
@classmethod
def load(cls, model_path: str) -> 'CredilyPipeline':
    """Restore a trained pipeline from a model.pkl artifact on disk."""
    saved = joblib.load(model_path)

    restored = cls(target_column=saved['target_column'])
    restored.best_model = saved['pipeline']
    restored.best_model_name = saved['model_name']
    restored.feature_names = saved['feature_names']
    restored.numeric_columns = saved['numeric_columns']
    restored.categorical_columns = saved['categorical_columns']
    # .get() defaults keep compatibility with artifacts written by
    # earlier versions that lacked these keys.
    restored.expected_columns = saved.get('expected_columns', None)
    restored.optimal_threshold = saved.get('optimal_threshold', 0.5)
    restored.class_ratio = saved.get('class_ratio', 1.0)
    restored.calibrate = saved.get('calibrated', False)

    # Agnostic pipeline info
    restored.agnostic_pipeline = saved.get('agnostic_pipeline', None)
    restored.binary_threshold = saved.get('binary_threshold', None)
    restored.binary_rule = saved.get('binary_rule', None)
    restored.positive_classes = saved.get('positive_classes', None)

    # Older artifacts predate expected_columns; rebuild it when possible.
    if restored.expected_columns is None:
        if restored.numeric_columns is not None and restored.categorical_columns is not None:
            restored.expected_columns = restored.numeric_columns + restored.categorical_columns

    return restored
|
credily/balancing.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data balancing module for Credily.
|
| 3 |
+
Handles imbalanced datasets for binary classification.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Tuple, Optional, Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DataBalancer:
|
| 12 |
+
"""
|
| 13 |
+
Balances imbalanced datasets for binary classification.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
METHODS = ['smote', 'random_oversample', 'random_undersample', 'smote_tomek', 'tomek', 'nearmiss', 'none']
|
| 17 |
+
|
| 18 |
+
def __init__(
    self,
    method: str = 'smote',
    sampling_strategy: str = 'auto',
    random_state: int = 42
):
    """
    Initialize the DataBalancer.

    Args:
        method: Balancing method; one of 'smote', 'random_oversample',
            'random_undersample', 'smote_tomek', 'tomek', 'nearmiss', or
            'none' (see METHODS).
        sampling_strategy: Sampling strategy ('auto', 'minority', or a
            float ratio) forwarded to the underlying sampler.
        random_state: Random seed for reproducibility.

    Raises:
        ValueError: If *method* is not one of METHODS.
    """
    # Validate up front so a misconfigured balancer fails fast.
    if method not in self.METHODS:
        raise ValueError(f"Unknown method: {method}. Choose from {self.METHODS}")

    self.method = method
    self.sampling_strategy = sampling_strategy
    self.random_state = random_state
    # Populated by balance() with before/after class-distribution details.
    self.balancing_report = {}
|
| 39 |
+
|
| 40 |
+
def balance(
    self,
    X: pd.DataFrame,
    y: pd.Series,
    verbose: bool = True
) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Balance the dataset according to the configured method.

    Skips work entirely when the method is 'none' or when the classes are
    already close to balanced (majority/minority ratio below 1.5).

    Args:
        X: Feature dataframe
        y: Target series
        verbose: Print balancing steps

    Returns:
        Balanced (X, y) tuple.
    """
    if self.method == 'none':
        return X, y

    original_counts = y.value_counts().to_dict()

    if verbose:
        print(f"\n{'='*50}")
        print("DATA BALANCING")
        print(f"{'='*50}")
        print(f"Original class distribution: {original_counts}")
        print(f"Method: {self.method}")

    # How lopsided are the classes?
    minority = min(original_counts.values())
    majority = max(original_counts.values())
    imbalance_ratio = majority / minority if minority > 0 else float('inf')

    if imbalance_ratio < 1.5:
        if verbose:
            print("Dataset is already balanced (ratio < 1.5). Skipping.")
        return X, y

    # Dispatch to the configured balancing strategy.
    handlers = {
        'smote': self._apply_smote,
        'random_oversample': self._apply_random_oversample,
        'random_undersample': self._apply_random_undersample,
        'smote_tomek': self._apply_smote_tomek,
        'tomek': self._apply_tomek_links,
        'nearmiss': self._apply_nearmiss,
    }
    handler = handlers.get(self.method)
    X_out, y_out = handler(X, y) if handler is not None else (X, y)

    final_counts = y_out.value_counts().to_dict()

    if verbose:
        print(f"Final class distribution: {final_counts}")
        print(f"Samples before: {len(y)}, after: {len(y_out)}")
        print(f"{'='*50}\n")

    self.balancing_report = {
        'method': self.method,
        'original_counts': original_counts,
        'final_counts': final_counts,
        'original_imbalance_ratio': imbalance_ratio,
        'samples_before': len(y),
        'samples_after': len(y_out)
    }

    return X_out, y_out
|
| 112 |
+
|
| 113 |
+
def _apply_smote(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Oversample the minority class with SMOTE.

    Routes mixed numeric/categorical frames to SMOTENC; falls back to
    random oversampling when imbalanced-learn is missing or SMOTE fails.
    """
    try:
        cat_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()

        if cat_cols:
            # Mixed data types need SMOTENC, which keeps categories discrete.
            return self._apply_smotenc(X, y, cat_cols)

        # All-numeric data: regular SMOTE.
        from imblearn.over_sampling import SMOTE
        sampler = SMOTE(
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state,
            # k must be smaller than the minority-class count.
            k_neighbors=min(5, y.value_counts().min() - 1)
        )
        X_res, y_res = sampler.fit_resample(X, y)
        return pd.DataFrame(X_res, columns=X.columns), pd.Series(y_res, name=y.name)
    except ImportError:
        print(" Warning: imbalanced-learn not installed. Using random oversampling instead.")
        print(" Install with: pip install imbalanced-learn")
        return self._apply_random_oversample(X, y)
    except Exception as e:
        print(f" Warning: SMOTE failed ({e}). Using random oversampling instead.")
        return self._apply_random_oversample(X, y)
|
| 142 |
+
|
| 143 |
+
def _apply_smotenc(self, X: pd.DataFrame, y: pd.Series, categorical_cols: list) -> Tuple[pd.DataFrame, pd.Series]:
    """
    Apply SMOTENC for datasets with mixed numeric and categorical features.

    SMOTENC preserves the discrete nature of categorical columns. They are
    label-encoded to integer codes before resampling and decoded back to
    their original string values afterwards.

    Falls back to random oversampling if imblearn is unavailable or the
    resampling fails.
    """
    try:
        from imblearn.over_sampling import SMOTENC

        # Get categorical feature indices (SMOTENC wants positions).
        categorical_indices = [X.columns.get_loc(col) for col in categorical_cols]
        numeric_cols = [col for col in X.columns if col not in categorical_cols]

        # Work on a copy; SMOTENC cannot handle NaN.
        X_encoded = X.copy()

        # Track which numeric values were NaN (for restoration later)
        numeric_nan_masks = {}
        for col in numeric_cols:
            if X_encoded[col].isna().any():
                numeric_nan_masks[col] = X_encoded[col].isna()
                # Fill with median for SMOTE (will be restored for original samples)
                X_encoded[col] = X_encoded[col].fillna(X_encoded[col].median())

        # Encode categorical columns to integer codes, remembering the
        # code -> category mapping for decoding after resampling.
        encoders = {}
        for col in categorical_cols:
            # Fill NaN with a placeholder before encoding
            X_encoded[col] = X_encoded[col].fillna('__MISSING__')
            X_encoded[col] = X_encoded[col].astype('category')
            encoders[col] = dict(enumerate(X_encoded[col].cat.categories))
            X_encoded[col] = X_encoded[col].cat.codes

        smotenc = SMOTENC(
            categorical_features=categorical_indices,
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state,
            k_neighbors=min(5, y.value_counts().min() - 1)
        )

        X_res, y_res = smotenc.fit_resample(X_encoded, y)

        # Convert back to DataFrame
        X_res = pd.DataFrame(X_res, columns=X.columns)

        # Decode categorical columns back to original string values.
        # BUG FIX: encoders[col] maps code -> category, which is exactly
        # the direction needed here. The previous code built the reversed
        # dict (category -> code) and mapped the integer codes through it,
        # which matched nothing and turned every value into NaN/"nan".
        for col in categorical_cols:
            # Ensure integer type for mapping, then convert to string
            X_res[col] = X_res[col].round().astype(int).map(encoders[col])
            # Convert to string type (required for sklearn's categorical handling)
            X_res[col] = X_res[col].astype(str)
            # Restore NaN for __MISSING__ values
            X_res.loc[X_res[col] == '__MISSING__', col] = np.nan

        # Note: For synthetic samples, NaN values are not restored because
        # SMOTE creates new samples by interpolating existing values.
        # This is the expected behavior.

        return X_res, pd.Series(y_res, name=y.name)

    except ImportError:
        print(" Warning: SMOTENC not available. Using random oversampling instead.")
        return self._apply_random_oversample(X, y)
    except Exception as e:
        print(f" Warning: SMOTENC failed ({e}). Using random oversampling instead.")
        return self._apply_random_oversample(X, y)
|
| 210 |
+
|
| 211 |
+
def _apply_random_oversample(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Apply random oversampling (manual fallback without imblearn)."""
    try:
        from imblearn.over_sampling import RandomOverSampler
    except ImportError:
        # imbalanced-learn missing: duplicate minority rows ourselves.
        return self._manual_oversample(X, y)

    sampler = RandomOverSampler(
        sampling_strategy=self.sampling_strategy,
        random_state=self.random_state
    )
    X_res, y_res = sampler.fit_resample(X, y)
    return pd.DataFrame(X_res, columns=X.columns), pd.Series(y_res, name=y.name)
|
| 224 |
+
|
| 225 |
+
def _apply_random_undersample(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Randomly drop majority-class rows via imbalanced-learn.

    Uses ``self.sampling_strategy`` and ``self.random_state``. When
    imbalanced-learn is unavailable, delegates to the dependency-free
    manual fallback.
    """
    try:
        from imblearn.under_sampling import RandomUnderSampler
        sampler = RandomUnderSampler(
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state
        )
        resampled_X, resampled_y = sampler.fit_resample(X, y)
        return (pd.DataFrame(resampled_X, columns=X.columns),
                pd.Series(resampled_y, name=y.name))
    except ImportError:
        # imbalanced-learn missing: fall back to the manual implementation.
        return self._manual_undersample(X, y)
|
| 238 |
+
|
| 239 |
+
def _apply_smote_tomek(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Combined resampling: SMOTE oversampling followed by Tomek-link cleanup.

    Falls back to plain SMOTE (with a warning) when imbalanced-learn is
    not installed.
    """
    try:
        from imblearn.combine import SMOTETomek
        combiner = SMOTETomek(
            sampling_strategy=self.sampling_strategy,
            random_state=self.random_state
        )
        resampled_X, resampled_y = combiner.fit_resample(X, y)
        return (pd.DataFrame(resampled_X, columns=X.columns),
                pd.Series(resampled_y, name=y.name))
    except ImportError:
        print(" Warning: imbalanced-learn not installed. Using SMOTE instead.")
        return self._apply_smote(X, y)
|
| 252 |
+
|
| 253 |
+
def _apply_tomek_links(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Clean noisy/borderline samples using Tomek links.

    Removes majority-class samples that form a Tomek link with a
    minority-class sample. If imbalanced-learn is not installed, the
    data is returned unchanged (with a warning).
    """
    try:
        from imblearn.under_sampling import TomekLinks
        cleaner = TomekLinks(sampling_strategy='majority')
        resampled_X, resampled_y = cleaner.fit_resample(X, y)
        return (pd.DataFrame(resampled_X, columns=X.columns),
                pd.Series(resampled_y, name=y.name))
    except ImportError:
        print(" Warning: imbalanced-learn not installed. Skipping Tomek Links.")
        return X, y
|
| 266 |
+
|
| 267 |
+
def _apply_nearmiss(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Undersample the majority class with NearMiss-1.

    NearMiss-1 keeps the majority samples closest to minority samples.
    Falls back to random undersampling (with a warning) when
    imbalanced-learn is not installed.
    """
    try:
        from imblearn.under_sampling import NearMiss
        selector = NearMiss(
            sampling_strategy=self.sampling_strategy,
            version=1  # NearMiss-1: closest to minority samples
        )
        resampled_X, resampled_y = selector.fit_resample(X, y)
        return (pd.DataFrame(resampled_X, columns=X.columns),
                pd.Series(resampled_y, name=y.name))
    except ImportError:
        print(" Warning: imbalanced-learn not installed. Using random undersampling instead.")
        return self._apply_random_undersample(X, y)
|
| 283 |
+
|
| 284 |
+
def _manual_oversample(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Manual random oversampling without imbalanced-learn.

    Every non-majority class is upsampled with replacement until all
    classes match the majority-class count. Results are reproducible:
    sampling is driven by ``self.random_state``, matching the seeded
    behaviour of the imblearn-backed samplers (previously this used the
    unseeded global ``np.random`` state, so repeat runs differed).

    Args:
        X: Feature dataframe.
        y: Target series aligned with ``X``.

    Returns:
        Tuple of (oversampled features, oversampled target), re-indexed
        from 0 with original class rows first per class.
    """
    # Private RNG: does not depend on, or mutate, NumPy's global state.
    rng = np.random.RandomState(getattr(self, 'random_state', None))

    counts = y.value_counts()
    majority_count = counts.max()

    parts = []
    for class_val in counts.index:
        class_mask = y == class_val
        class_df = X[class_mask]
        class_y = y[class_mask]

        if len(class_df) < majority_count:
            # Draw the shortfall with replacement from this class's rows.
            n_samples = majority_count - len(class_df)
            sampled_idx = rng.choice(class_df.index, size=n_samples, replace=True)
            extra_X = class_df.loc[sampled_idx].reset_index(drop=True)
            extra_y = class_y.loc[sampled_idx].reset_index(drop=True)
            parts.append((pd.concat([class_df, extra_X], ignore_index=True),
                          pd.concat([class_y, extra_y], ignore_index=True)))
        else:
            parts.append((class_df.reset_index(drop=True), class_y.reset_index(drop=True)))

    X_res = pd.concat([p[0] for p in parts], ignore_index=True)
    y_res = pd.concat([p[1] for p in parts], ignore_index=True)
    return X_res, y_res
|
| 309 |
+
|
| 310 |
+
def _manual_undersample(self, X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    """Manual random undersampling without imbalanced-learn.

    Every non-minority class is downsampled (without replacement) to the
    minority-class count. Results are reproducible: sampling is driven
    by ``self.random_state``, matching the seeded behaviour of the
    imblearn-backed samplers (previously this used the unseeded global
    ``np.random`` state, so repeat runs differed).

    Args:
        X: Feature dataframe.
        y: Target series aligned with ``X``.

    Returns:
        Tuple of (undersampled features, undersampled target),
        re-indexed from 0.
    """
    # Private RNG: does not depend on, or mutate, NumPy's global state.
    rng = np.random.RandomState(getattr(self, 'random_state', None))

    counts = y.value_counts()
    minority_count = counts.min()

    parts = []
    for class_val in counts.index:
        class_mask = y == class_val
        class_df = X[class_mask]
        class_y = y[class_mask]

        if len(class_df) > minority_count:
            # Keep a random subset equal in size to the minority class.
            sampled_idx = rng.choice(class_df.index, size=minority_count, replace=False)
            parts.append((class_df.loc[sampled_idx].reset_index(drop=True),
                          class_y.loc[sampled_idx].reset_index(drop=True)))
        else:
            parts.append((class_df.reset_index(drop=True), class_y.reset_index(drop=True)))

    X_res = pd.concat([p[0] for p in parts], ignore_index=True)
    y_res = pd.concat([p[1] for p in parts], ignore_index=True)
    return X_res, y_res
|
| 332 |
+
|
| 333 |
+
def get_report(self) -> Dict[str, Any]:
    """Return the report assembled during the most recent balancing run."""
    return self.balancing_report
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def check_imbalance(y: pd.Series) -> Dict[str, Any]:
    """
    Check class imbalance in target variable.

    Args:
        y: Target series

    Returns:
        Dictionary with imbalance statistics: per-class counts, the
        majority/minority class labels, the majority/minority ratio,
        whether the ratio exceeds 1.5, and a textual recommendation.
    """
    counts = y.value_counts()
    largest = counts.max()
    smallest = counts.min()
    ratio = float('inf') if smallest == 0 else largest / smallest

    return {
        'class_counts': counts.to_dict(),
        'majority_class': counts.idxmax(),
        'minority_class': counts.idxmin(),
        'imbalance_ratio': ratio,
        'is_imbalanced': ratio > 1.5,
        'recommendation': _get_recommendation(ratio, len(y)),
    }


def _get_recommendation(ratio: float, n_samples: int) -> str:
    """Get balancing recommendation based on imbalance ratio."""
    if ratio < 1.5:
        return "Dataset is balanced. No action needed."
    if ratio < 3:
        return "Mild imbalance. Consider using class_weight='balanced' or SMOTE."
    if ratio < 10:
        return "Moderate imbalance. Use SMOTE or random oversampling."
    # Severe imbalance: advice depends on how much data is available.
    if n_samples > 10000:
        return "Severe imbalance. Use SMOTE-Tomek or random undersampling."
    return "Severe imbalance with limited data. Use SMOTE carefully to avoid overfitting."
|
credily/cleaning.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data cleaning module for Credily.
|
| 3 |
+
Comprehensive data cleaning for ML readiness with thorough defaults.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import re
|
| 9 |
+
from typing import Optional, Dict, Any, List, Tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DataCleaner:
|
| 13 |
+
"""
|
| 14 |
+
Comprehensive data cleaner for tabular datasets.
|
| 15 |
+
Default mode is 'thorough' which applies all best practices automatically.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
CLEAN_MODES = ['basic', 'thorough', 'aggressive']
|
| 19 |
+
|
| 20 |
+
def __init__(
    self,
    target_column: Optional[str] = None,
    clean_mode: str = 'thorough',

    # Outlier settings
    outlier_method: str = 'iqr',
    outlier_threshold: float = 3.0,  # wide IQR fence: preserve tail risk

    # Missing value settings
    max_missing_threshold: float = 0.5,
    flag_missing: bool = True,

    # Duplicate settings
    remove_duplicates: bool = True,

    # Feature settings
    remove_low_variance: bool = True,
    variance_threshold: float = 0.001,  # low bar: keep rare but important risk flags
    remove_high_correlation: bool = True,
    correlation_threshold: float = 0.95,
    max_cardinality: int = 100,  # generous: preserve categorical risk segmentation

    # Validation
    validate_negative: bool = True

):
    """
    Initialize the DataCleaner.

    Args:
        target_column: Name of target column (excluded from cleaning)
        clean_mode: 'basic', 'thorough' (default), or 'aggressive'
        outlier_method: 'iqr', 'zscore', or 'none'
        outlier_threshold: Threshold for outlier detection
        max_missing_threshold: Drop columns with more than this % missing
        flag_missing: Create indicator columns for missing values
        remove_duplicates: Whether to remove duplicate rows
        remove_low_variance: Remove near-constant features
        variance_threshold: Minimum variance to keep feature
        remove_high_correlation: Remove highly correlated features
        correlation_threshold: Max correlation allowed
        max_cardinality: Max unique values for categorical encoding
        validate_negative: Flag unexpected negative values

    Raises:
        ValueError: If ``clean_mode`` is not a recognized mode.
    """
    if clean_mode not in self.CLEAN_MODES:
        raise ValueError(f"clean_mode must be one of {self.CLEAN_MODES}")

    self.target_column = target_column
    self.clean_mode = clean_mode
    self.outlier_method = outlier_method
    self.outlier_threshold = outlier_threshold
    self.max_missing_threshold = max_missing_threshold
    self.flag_missing = flag_missing
    self.remove_duplicates = remove_duplicates
    self.remove_low_variance = remove_low_variance
    self.variance_threshold = variance_threshold
    self.remove_high_correlation = remove_high_correlation
    self.correlation_threshold = correlation_threshold
    self.max_cardinality = max_cardinality
    self.validate_negative = validate_negative
    self.cleaning_report = {}

    # Mode-specific overrides happen last so they win over the defaults.
    self._apply_mode_settings()
|
| 85 |
+
|
| 86 |
+
def _apply_mode_settings(self):
    """Override individual settings to match the selected clean_mode."""
    if self.clean_mode == 'basic':
        # Basic mode: disable all optional analyses.
        for attr in ('flag_missing', 'remove_low_variance',
                     'remove_high_correlation', 'validate_negative'):
            setattr(self, attr, False)
    elif self.clean_mode == 'aggressive':
        # Aggressive mode: tighten every pruning threshold.
        self.max_missing_threshold = 0.3
        self.variance_threshold = 0.05
        self.correlation_threshold = 0.9
        self.max_cardinality = 30
|
| 98 |
+
|
| 99 |
+
def clean(self, df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame:
    """
    Clean the dataset comprehensively.

    Runs the cleaning pipeline in a fixed order: drop pandas artifacts
    and ID columns, de-duplicate, normalize invalid values and dtypes,
    handle infinities/missing data/outliers, standardize categoricals,
    prune low-variance and highly correlated features, then validate
    target labels. A step-by-step summary is stored in
    ``self.cleaning_report``. The input ``df`` is not modified; a copy
    is cleaned and returned.

    Args:
        df: Input dataframe
        verbose: Print cleaning steps

    Returns:
        Cleaned dataframe
    """
    original_shape = df.shape
    df_clean = df.copy()

    # Initialize tracking
    report = {
        'original_shape': original_shape,
        'steps_applied': [],
        'warnings': []
    }

    if verbose:
        print(f"\n{'='*60}")
        print(f"DATA CLEANING (Mode: {self.clean_mode})")
        print(f"{'='*60}")
        print(f"Original shape: {original_shape[0]} rows, {original_shape[1]} columns")

    # ===== STEP 0: Drop Unnamed columns (common pandas artifact) =====
    # NOTE(review): assumes all column labels are strings — confirm upstream.
    unnamed_cols = [col for col in df_clean.columns if col.startswith('Unnamed')]
    if unnamed_cols:
        df_clean = df_clean.drop(columns=unnamed_cols)
        report['unnamed_dropped'] = unnamed_cols
        report['steps_applied'].append('drop_unnamed')
        if verbose:
            print(f" [0] Dropped {len(unnamed_cols)} 'Unnamed' columns")

    # ===== STEP 0.5: Drop ID columns =====
    df_clean, id_cols_dropped = self._drop_id_columns(df_clean)
    if id_cols_dropped:
        report['id_columns_dropped'] = id_cols_dropped
        report['steps_applied'].append('drop_id_columns')
        if verbose:
            print(f" [0.5] Dropped {len(id_cols_dropped)} ID columns: {id_cols_dropped}")

    # ===== STEP 1: Remove exact duplicates =====
    if self.remove_duplicates:
        before = len(df_clean)
        df_clean = df_clean.drop_duplicates()
        removed = before - len(df_clean)
        if removed > 0:
            report['duplicates_removed'] = removed
            report['steps_applied'].append('remove_duplicates')
            if verbose:
                print(f" [1] Removed {removed} duplicate rows")

    # ===== STEP 2: Handle invalid/placeholder values =====
    df_clean, invalid_count = self._replace_invalid_values(df_clean)
    if invalid_count > 0:
        report['invalid_values_replaced'] = invalid_count
        report['steps_applied'].append('replace_invalid_values')
        if verbose:
            print(f" [2] Replaced {invalid_count} invalid/placeholder values")

    # ===== STEP 3: Fix data types =====
    # Runs after step 2 so placeholder strings don't block numeric coercion.
    df_clean, type_fixes = self._fix_data_types(df_clean)
    if type_fixes:
        report['type_fixes'] = type_fixes
        report['steps_applied'].append('fix_data_types')
        if verbose:
            print(f" [3] Fixed data types for {len(type_fixes)} columns")

    # ===== STEP 4: Handle infinite values =====
    df_clean, inf_count = self._handle_infinite_values(df_clean)
    if inf_count > 0:
        report['infinite_replaced'] = inf_count
        report['steps_applied'].append('handle_infinite')
        if verbose:
            print(f" [4] Replaced {inf_count} infinite values")

    # ===== STEP 5: Drop high-missing columns =====
    df_clean, cols_dropped = self._drop_high_missing_columns(df_clean)
    if cols_dropped:
        report['high_missing_dropped'] = cols_dropped
        report['steps_applied'].append('drop_high_missing')
        if verbose:
            print(f" [5] Dropped {len(cols_dropped)} high-missing columns: {cols_dropped[:5]}{'...' if len(cols_dropped) > 5 else ''}")

    # ===== STEP 6: Flag missing values (create indicator columns) =====
    if self.flag_missing:
        df_clean, missing_flags = self._flag_missing_values(df_clean)
        if missing_flags:
            report['missing_flags_created'] = missing_flags
            report['steps_applied'].append('flag_missing')
            if verbose:
                print(f" [6] Created {len(missing_flags)} missing value indicator columns")

    # ===== STEP 7: Handle outliers =====
    if self.outlier_method != 'none':
        df_clean, outlier_count = self._handle_outliers(df_clean)
        if outlier_count > 0:
            report['outliers_capped'] = outlier_count
            report['steps_applied'].append('handle_outliers')
            if verbose:
                print(f" [7] Capped {outlier_count} outlier values (method: {self.outlier_method})")

    # ===== STEP 8: Validate negative values =====
    if self.validate_negative:
        df_clean, neg_issues = self._validate_negative_values(df_clean)
        if neg_issues:
            report['negative_value_issues'] = neg_issues
            report['warnings'].append(f"Unexpected negative values in: {list(neg_issues.keys())}")
            if verbose:
                print(f" [8] Found unexpected negatives in {len(neg_issues)} columns")

    # ===== STEP 9: Fix categorical inconsistencies =====
    df_clean, cat_fixes = self._fix_categorical_inconsistencies(df_clean)
    if cat_fixes:
        report['categorical_fixes'] = cat_fixes
        report['steps_applied'].append('fix_categorical')
        if verbose:
            print(f" [9] Standardized {len(cat_fixes)} categorical columns")

    # ===== STEP 10: Handle high cardinality categoricals =====
    df_clean, high_card = self._handle_high_cardinality(df_clean)
    if high_card:
        report['high_cardinality_handled'] = high_card
        report['steps_applied'].append('handle_high_cardinality')
        if verbose:
            print(f" [10] Handled {len(high_card)} high-cardinality columns")

    # ===== STEP 11: Remove low variance features =====
    if self.remove_low_variance:
        df_clean, low_var_removed = self._remove_low_variance_features(df_clean)
        if low_var_removed:
            report['low_variance_removed'] = low_var_removed
            report['steps_applied'].append('remove_low_variance')
            if verbose:
                print(f" [11] Removed {len(low_var_removed)} low-variance columns")

    # ===== STEP 12: Remove highly correlated features =====
    if self.remove_high_correlation:
        df_clean, corr_removed = self._remove_correlated_features(df_clean)
        if corr_removed:
            report['correlated_removed'] = corr_removed
            report['steps_applied'].append('remove_correlated')
            if verbose:
                print(f" [12] Removed {len(corr_removed)} highly correlated columns")

    # ===== STEP 13: Validate target labels =====
    if self.target_column and self.target_column in df_clean.columns:
        df_clean, label_issues = self._validate_labels(df_clean)
        report['label_validation'] = label_issues
        report['steps_applied'].append('validate_labels')
        if verbose:
            print(f" [13] Target validation: {label_issues.get('status', 'complete')}")

    # ===== STEP 14: Final missing value check =====
    # Remaining NaNs are reported only — imputation is deferred to training.
    missing_summary = df_clean.isnull().sum()
    cols_with_missing = missing_summary[missing_summary > 0]
    if len(cols_with_missing) > 0:
        report['remaining_missing'] = cols_with_missing.to_dict()
        if verbose:
            print(f" [14] Remaining missing values in {len(cols_with_missing)} columns (will be imputed during training)")

    # Final summary
    final_shape = df_clean.shape
    report['final_shape'] = final_shape
    report['rows_removed'] = original_shape[0] - final_shape[0]
    report['columns_removed'] = original_shape[1] - final_shape[1]

    if verbose:
        print(f"\n{'='*60}")
        print(f"CLEANING COMPLETE")
        print(f" Final shape: {final_shape[0]} rows, {final_shape[1]} columns")
        print(f" Rows removed: {report['rows_removed']}")
        print(f" Columns removed: {report['columns_removed']}")
        print(f" Steps applied: {len(report['steps_applied'])}")
        if report['warnings']:
            print(f" Warnings: {len(report['warnings'])}")
        print(f"{'='*60}\n")

    self.cleaning_report = report
    return df_clean
|
| 282 |
+
|
| 283 |
+
def _drop_id_columns(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """
    Detect and drop ID columns that provide no predictive value.

    A column is treated as an identifier when any of these hold:
    1. Its name matches a known ID naming pattern.
    2. It is an object column in which every value is unique.
    3. Its integer values form a near-sequential run spanning the data.

    The target column is never dropped.
    """
    n_rows = len(df)

    # Common ID column name patterns (case-insensitive)
    id_patterns = [
        r'^id$', r'^_id$', r'_id$', r'^index$', r'^idx$',
        r'^row_?num', r'^row_?id', r'^record_?id', r'^key$',
        r'^pk$', r'^primary_?key', r'^unique_?id', r'^uuid$',
        r'^guid$', r'^serial', r'^sequence', r'^customer_?id$',
        r'^user_?id$', r'^account_?id$', r'^transaction_?id$',
        r'^loan_?id$', r'^application_?id$', r'^case_?id$',
        r'^member_?id$', r'^client_?id$', r'^order_?id$',
        r'^sk_id', r'^member_id$'
    ]

    dropped: List[str] = []
    for column in df.columns:
        if column == self.target_column:
            continue

        lowered = column.lower()

        # Criterion 1: name matches a known ID pattern.
        name_hit = any(re.match(pattern, lowered) for pattern in id_patterns)

        # Criterion 2: object column with 100% unique values.
        unique_hit = (df[column].dtype == 'object'
                      and df[column].nunique() == n_rows)

        # Criterion 3: near-sequential integers spanning ~the full range
        # (likely an auto-increment key).
        sequential_hit = False
        if pd.api.types.is_integer_dtype(df[column]):
            ordered = df[column].dropna().sort_values()
            if len(ordered) > 1:
                steps = ordered.diff().dropna()
                if (steps == 1).mean() > 0.95:
                    if ordered.iloc[-1] - ordered.iloc[0] >= n_rows * 0.9:
                        sequential_hit = True

        if name_hit or unique_hit or sequential_hit:
            dropped.append(column)

    if dropped:
        df = df.drop(columns=dropped)

    return df, dropped
|
| 344 |
+
|
| 345 |
+
def _replace_invalid_values(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
    """Replace common invalid/placeholder values with NaN.

    Only object columns are scanned; the target column is left intact
    so labels are never destroyed. Returns the dataframe and the number
    of cells replaced.
    """
    # Note: '-' is intentionally not listed — it can be a valid label
    # (e.g. +/- classification).
    invalid_patterns = ['?', 'N/A', 'n/a', 'NA', 'na', 'NULL', 'null',
                        'None', 'none', 'NaN', 'nan', '', ' ', '--',
                        'missing', 'Missing', 'MISSING', 'unknown', 'Unknown']

    replaced = 0
    for col in df.columns:
        if col == self.target_column:
            continue
        if df[col].dtype != 'object':
            continue
        hits = df[col].isin(invalid_patterns)
        replaced += hits.sum()
        df.loc[hits, col] = np.nan

    return df, replaced
|
| 363 |
+
|
| 364 |
+
def _fix_data_types(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """Coerce numeric-looking object columns to numeric dtype.

    A column is converted only when coercion keeps at least 90% of its
    non-null values, so genuinely textual columns are left alone.
    Returns the dataframe and the list of converted columns.
    """
    converted: List[str] = []

    for col in df.columns:
        if col == self.target_column or df[col].dtype != 'object':
            continue

        as_numeric = pd.to_numeric(df[col], errors='coerce')
        # Only convert if we don't lose too much data.
        if as_numeric.notna().sum() >= df[col].notna().sum() * 0.9:
            df[col] = as_numeric
            converted.append(col)

    return df, converted
|
| 384 |
+
|
| 385 |
+
def _handle_infinite_values(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
    """Replace +/-inf in numeric columns with NaN; return (df, count)."""
    total = 0
    for col in df.select_dtypes(include=[np.number]).columns:
        inf_mask = np.isinf(df[col])
        total += inf_mask.sum()
        df.loc[inf_mask, col] = np.nan
    return df, total
|
| 394 |
+
|
| 395 |
+
def _drop_high_missing_columns(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """Drop columns whose missing fraction exceeds the threshold.

    The target column is exempt even if it is mostly missing.
    """
    missing_frac = df.isnull().sum() / len(df)
    to_drop = [col for col in missing_frac[missing_frac > self.max_missing_threshold].index
               if col != self.target_column]
    return df.drop(columns=to_drop), to_drop
|
| 405 |
+
|
| 406 |
+
def _flag_missing_values(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
    """Add 0/1 indicator columns for features with meaningful missingness.

    Only columns missing between 5% and 50% of values are flagged
    (below 5% the signal is noise; above 50% the column is dropped
    elsewhere). The target column is never flagged.
    """
    missing_frac = df.isnull().sum() / len(df)
    candidates = [col for col in
                  missing_frac[(missing_frac >= 0.05) & (missing_frac <= 0.5)].index
                  if col != self.target_column]

    created: List[str] = []
    for col in candidates:
        indicator = f"{col}_missing"
        df[indicator] = df[col].isnull().astype(int)
        created.append(indicator)

    return df, created
|
| 423 |
+
|
| 424 |
+
def _handle_outliers(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, int]:
    """Cap numeric outliers via the configured method ('iqr' or 'zscore').

    The target column and `_missing` indicator columns are excluded.
    Outliers are clipped to the computed bounds rather than removed.
    Returns the dataframe and the number of values capped.
    """
    feature_cols = [c for c in df.select_dtypes(include=[np.number]).columns
                    if c != self.target_column and not c.endswith('_missing')]

    capped = 0

    for col in feature_cols:
        # Near-constant columns carry no outlier signal — skip them.
        if df[col].nunique() < 3:
            continue

        if self.outlier_method == 'iqr':
            q1 = df[col].quantile(0.25)
            q3 = df[col].quantile(0.75)
            spread = q3 - q1
            if spread == 0:
                continue
            lower = q1 - self.outlier_threshold * spread
            upper = q3 + self.outlier_threshold * spread
        elif self.outlier_method == 'zscore':
            center = df[col].mean()
            scale = df[col].std()
            if scale == 0:
                continue
            lower = center - self.outlier_threshold * scale
            upper = center + self.outlier_threshold * scale
        else:
            # Unknown method: leave the column untouched.
            continue

        capped += ((df[col] < lower) | (df[col] > upper)).sum()
        df[col] = df[col].clip(lower=lower, upper=upper)

    return df, capped
|
| 463 |
+
|
| 464 |
+
def _validate_negative_values(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
|
| 465 |
+
"""Check for unexpected negative values in typically positive columns."""
|
| 466 |
+
issues = {}
|
| 467 |
+
positive_keywords = ['age', 'income', 'salary', 'amount', 'balance', 'count',
|
| 468 |
+
'quantity', 'price', 'rate', 'score', 'years', 'months']
|
| 469 |
+
|
| 470 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 471 |
+
|
| 472 |
+
for col in numeric_cols:
|
| 473 |
+
if col == self.target_column:
|
| 474 |
+
continue
|
| 475 |
+
col_lower = col.lower()
|
| 476 |
+
if any(kw in col_lower for kw in positive_keywords):
|
| 477 |
+
neg_count = (df[col] < 0).sum()
|
| 478 |
+
if neg_count > 0:
|
| 479 |
+
issues[col] = neg_count
|
| 480 |
+
# Optionally clip to 0
|
| 481 |
+
df[col] = df[col].clip(lower=0)
|
| 482 |
+
|
| 483 |
+
return df, issues
|
| 484 |
+
|
| 485 |
+
def _fix_categorical_inconsistencies(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
| 486 |
+
"""Fix common inconsistencies in categorical columns."""
|
| 487 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 488 |
+
|
| 489 |
+
# Skip target column - labels should be handled separately
|
| 490 |
+
if self.target_column in categorical_cols:
|
| 491 |
+
categorical_cols.remove(self.target_column)
|
| 492 |
+
|
| 493 |
+
fixed_cols = []
|
| 494 |
+
|
| 495 |
+
for col in categorical_cols:
|
| 496 |
+
original_unique = df[col].nunique()
|
| 497 |
+
|
| 498 |
+
if df[col].dtype == 'object':
|
| 499 |
+
# Strip whitespace and standardize case
|
| 500 |
+
df[col] = df[col].astype(str).str.strip().str.lower()
|
| 501 |
+
|
| 502 |
+
# Common replacements (for features, not labels)
|
| 503 |
+
replacements = {
|
| 504 |
+
'yes': 'yes', 'y': 'yes', 'true': 'yes',
|
| 505 |
+
'no': 'no', 'n': 'no', 'false': 'no',
|
| 506 |
+
'male': 'male', 'm': 'male', 'man': 'male',
|
| 507 |
+
'female': 'female', 'f': 'female', 'woman': 'female',
|
| 508 |
+
'nan': np.nan, 'none': np.nan, 'null': np.nan, '': np.nan
|
| 509 |
+
}
|
| 510 |
+
df[col] = df[col].replace(replacements)
|
| 511 |
+
|
| 512 |
+
new_unique = df[col].nunique()
|
| 513 |
+
if new_unique < original_unique:
|
| 514 |
+
fixed_cols.append(col)
|
| 515 |
+
|
| 516 |
+
return df, fixed_cols
|
| 517 |
+
|
| 518 |
+
def _handle_high_cardinality(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
| 519 |
+
"""Handle categorical columns with too many unique values."""
|
| 520 |
+
handled = []
|
| 521 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 522 |
+
|
| 523 |
+
if self.target_column in categorical_cols:
|
| 524 |
+
categorical_cols.remove(self.target_column)
|
| 525 |
+
|
| 526 |
+
for col in categorical_cols:
|
| 527 |
+
n_unique = df[col].nunique()
|
| 528 |
+
if n_unique > self.max_cardinality:
|
| 529 |
+
# Keep top categories, group rest as 'other'
|
| 530 |
+
top_cats = df[col].value_counts().head(self.max_cardinality - 1).index.tolist()
|
| 531 |
+
df[col] = df[col].apply(lambda x: x if x in top_cats else 'other')
|
| 532 |
+
handled.append(col)
|
| 533 |
+
|
| 534 |
+
return df, handled
|
| 535 |
+
|
| 536 |
+
def _remove_low_variance_features(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
| 537 |
+
"""Remove features with very low variance (near-constant)."""
|
| 538 |
+
removed = []
|
| 539 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 540 |
+
|
| 541 |
+
if self.target_column in numeric_cols:
|
| 542 |
+
numeric_cols.remove(self.target_column)
|
| 543 |
+
|
| 544 |
+
# Exclude flag columns
|
| 545 |
+
numeric_cols = [c for c in numeric_cols if not c.endswith('_missing')]
|
| 546 |
+
|
| 547 |
+
for col in numeric_cols:
|
| 548 |
+
variance = df[col].var()
|
| 549 |
+
if variance is not None and variance < self.variance_threshold:
|
| 550 |
+
df = df.drop(columns=[col])
|
| 551 |
+
removed.append(col)
|
| 552 |
+
|
| 553 |
+
return df, removed
|
| 554 |
+
|
| 555 |
+
def _remove_correlated_features(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]:
|
| 556 |
+
"""Remove highly correlated features (keep first, remove duplicates)."""
|
| 557 |
+
removed = []
|
| 558 |
+
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
| 559 |
+
|
| 560 |
+
if self.target_column in numeric_cols:
|
| 561 |
+
numeric_cols.remove(self.target_column)
|
| 562 |
+
|
| 563 |
+
# Exclude flag columns
|
| 564 |
+
numeric_cols = [c for c in numeric_cols if not c.endswith('_missing')]
|
| 565 |
+
|
| 566 |
+
if len(numeric_cols) < 2:
|
| 567 |
+
return df, removed
|
| 568 |
+
|
| 569 |
+
corr_matrix = df[numeric_cols].corr().abs()
|
| 570 |
+
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
|
| 571 |
+
|
| 572 |
+
to_drop = [col for col in upper.columns if any(upper[col] > self.correlation_threshold)]
|
| 573 |
+
df = df.drop(columns=to_drop)
|
| 574 |
+
removed = to_drop
|
| 575 |
+
|
| 576 |
+
return df, removed
|
| 577 |
+
|
| 578 |
+
def _validate_labels(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
|
| 579 |
+
"""
|
| 580 |
+
Validate and standardize target labels for binary classification.
|
| 581 |
+
Data-agnostic: simply maps any two classes to 0 and 1.
|
| 582 |
+
"""
|
| 583 |
+
issues = {}
|
| 584 |
+
target = df[self.target_column]
|
| 585 |
+
|
| 586 |
+
unique_values = target.dropna().unique()
|
| 587 |
+
n_classes = len(unique_values)
|
| 588 |
+
|
| 589 |
+
issues['n_classes'] = n_classes
|
| 590 |
+
issues['unique_values'] = [str(v) for v in unique_values]
|
| 591 |
+
|
| 592 |
+
# Check for missing labels
|
| 593 |
+
missing_labels = target.isnull().sum()
|
| 594 |
+
if missing_labels > 0:
|
| 595 |
+
issues['missing_labels'] = int(missing_labels)
|
| 596 |
+
df = df.dropna(subset=[self.target_column])
|
| 597 |
+
# Recalculate unique values after dropping nulls
|
| 598 |
+
unique_values = df[self.target_column].dropna().unique()
|
| 599 |
+
n_classes = len(unique_values)
|
| 600 |
+
|
| 601 |
+
# Standardize binary labels to 0/1
|
| 602 |
+
if n_classes == 2:
|
| 603 |
+
# Data-agnostic approach: sort values and map first to 0, second to 1
|
| 604 |
+
sorted_vals = sorted(unique_values, key=lambda x: str(x))
|
| 605 |
+
label_map = {sorted_vals[0]: 0, sorted_vals[1]: 1}
|
| 606 |
+
|
| 607 |
+
df[self.target_column] = df[self.target_column].map(label_map)
|
| 608 |
+
issues['label_mapping'] = {str(k): v for k, v in label_map.items()}
|
| 609 |
+
issues['status'] = 'standardized to 0/1'
|
| 610 |
+
elif n_classes > 2:
|
| 611 |
+
issues['status'] = 'warning: more than 2 classes'
|
| 612 |
+
elif n_classes < 2:
|
| 613 |
+
issues['status'] = 'error: less than 2 classes'
|
| 614 |
+
|
| 615 |
+
return df, issues
|
| 616 |
+
|
| 617 |
+
    def get_report(self) -> Dict[str, Any]:
        """Return the accumulated cleaning report.

        Returns:
            The ``self.cleaning_report`` dict itself (not a copy), populated
            by the cleaning steps that ran before this call.
        """
        return self.cleaning_report
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def detect_outliers(df: pd.DataFrame, column: str, method: str = 'iqr') -> pd.Series:
    """
    Detect outliers in a column.

    Args:
        df: Input dataframe
        column: Column name to check
        method: Detection method ('iqr' for 1.5*IQR fences, 'zscore' for |z| > 3)

    Returns:
        Boolean series aligned with ``df`` — True where the value is an outlier

    Raises:
        ValueError: If ``method`` is not 'iqr' or 'zscore'
    """
    if method == 'iqr':
        Q1 = df[column].quantile(0.25)
        Q3 = df[column].quantile(0.75)
        IQR = Q3 - Q1
        return (df[column] < Q1 - 1.5 * IQR) | (df[column] > Q3 + 1.5 * IQR)
    elif method == 'zscore':
        std = df[column].std()
        # Fix: a constant (or all-NaN) column has std == 0 or NaN; dividing
        # would emit a divide-by-zero warning and yield NaN z-scores. No
        # spread means no outliers.
        if std == 0 or pd.isna(std):
            return pd.Series(False, index=df.index)
        z_scores = np.abs((df[column] - df[column].mean()) / std)
        return z_scores > 3
    else:
        raise ValueError(f"Unknown method: {method}")
|
credily/cli.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Command Line Interface for Credily.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import warnings
|
| 8 |
+
import multiprocessing
|
| 9 |
+
|
| 10 |
+
# Suppress joblib resource tracker warnings on Windows (must be set before ANY imports)
|
| 11 |
+
if sys.platform == 'win32':
|
| 12 |
+
# Force spawn method to avoid fork-related issues
|
| 13 |
+
try:
|
| 14 |
+
multiprocessing.set_start_method('spawn', force=True)
|
| 15 |
+
except RuntimeError:
|
| 16 |
+
pass # Already set
|
| 17 |
+
|
| 18 |
+
os.environ['LOKY_PICKLER'] = 'pickle'
|
| 19 |
+
os.environ['JOBLIB_MULTIPROCESSING'] = '0'
|
| 20 |
+
os.environ['LOKY_MAX_CPU_COUNT'] = '1'
|
| 21 |
+
|
| 22 |
+
# Suppress all joblib/multiprocessing warnings
|
| 23 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='joblib')
|
| 24 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='loky')
|
| 25 |
+
warnings.filterwarnings('ignore', message='.*resource_tracker.*')
|
| 26 |
+
warnings.filterwarnings('ignore', message='.*Cannot register.*')
|
| 27 |
+
warnings.filterwarnings('ignore', message='.*leaked.*')
|
| 28 |
+
|
| 29 |
+
import click
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@click.group()
@click.version_option(version='0.1.0', prog_name='Credily')
def cli():
    """
    Credily — Data-Agnostic Credit Risk Modeling Engine for Tabular Data

    Credily is a command-line tool for building, evaluating, and deploying
    binary classification models for credit risk and default prediction.

    It is designed to adapt to diverse tabular datasets without hard-coded
    assumptions, while exposing interpretable risk signals and tunable
    decision policies.

    \b
    Key capabilities:
      - Automatically profiles tabular datasets and detects imbalance
      - Trains and evaluates multiple tree-based and linear models
      - Selects the best model using cross-validated ROC-AUC
      - Calibrates predicted probabilities for decision-making
      - Optimizes decision thresholds for imbalanced data
      - Generates credit-ready reports (ROC, PR-AUC, confusion matrix)
      - Exports trained models and structured reports (HTML / JSON)

    \b
    Typical use cases:
      - Credit scoring and default risk assessment
      - Loan approval / rejection systems
      - Risk-based customer segmentation
      - Decision support for financial products

    \b
    Quick Start:
        credily train data.csv
        credily train              # Interactive mode
    """
    # Root command group: subcommands register themselves via @cli.command().
    # The docstring above is the user-facing `credily --help` text.
    pass
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@cli.command()
@click.argument('data', type=click.Path(exists=True), required=False)
@click.option('-t', '--target', help='Target column name')
@click.option('-o', '--output', default='credily_output', help='Output directory')
@click.option('--test-size', type=float, default=0.2, help='Test split ratio (default: 0.2)')
@click.option('--cv', type=int, default=5, help='Cross-validation folds (default: 5)')
@click.option('--no-profile', is_flag=True, help='Skip data profiling')
@click.option('--no-clean', is_flag=True, help='Skip data cleaning')
@click.option('--clean-mode', type=click.Choice(['basic', 'thorough', 'aggressive']), default='thorough', help='Cleaning mode: basic, thorough (default), or aggressive')
@click.option('--outlier-method', type=click.Choice(['iqr', 'zscore', 'none']), default='iqr', help='Outlier detection method')
@click.option('--no-balance', is_flag=True, help='Skip data balancing')
@click.option('--balance-method', type=click.Choice(['smote', 'random_oversample', 'random_undersample', 'smote_tomek', 'tomek', 'nearmiss', 'none']), default='smote', help='Balancing method')
@click.option('--parallel', is_flag=True, help='Enable parallel processing (may cause warnings on Windows)')
@click.option('--no-calibrate', is_flag=True, help='Skip probability calibration')
@click.option('--calibration-method', type=click.Choice(['isotonic', 'sigmoid']), default='isotonic', help='Calibration method (isotonic or sigmoid/Platt)')
@click.option('--threshold-metric', type=click.Choice(['f1', 'cost', 'youden', 'precision_recall_balance']), default='f1', help='Metric for threshold optimization')
@click.option('--no-threshold-opt', is_flag=True, help='Skip threshold optimization (use 0.5)')
@click.option('--binary-threshold', type=float, default=None, help='Threshold to convert numeric target to binary (values BELOW threshold = positive class)')
def train(data, target, output, test_size, cv, no_profile, no_clean, clean_mode, outlier_method, no_balance, balance_method, parallel, no_calibrate, calibration_method, threshold_metric, no_threshold_opt, binary_threshold):
    """
    Train a model on your dataset.

    Supported file formats: CSV (.csv), Text (.txt, .tsv), Excel (.xlsx, .xls)
    """
    # Deferred imports keep `credily --help` fast (automl pulls in sklearn etc.).
    from .automl import CredilyPipeline
    from .utils import load_data, get_supported_formats

    if data is None:
        # Interactive mode: prompt for the file path when it was not given
        # as a positional argument.
        click.echo(f"\n{get_supported_formats()}")
        data = click.prompt('Enter path to data file')
        if not Path(data).exists():
            click.echo(click.style(f"Error: File '{data}' not found", fg='red'))
            sys.exit(1)

    click.echo(click.style(f"\n📊 Loading data from: {data}", fg='cyan'))
    try:
        df = load_data(data)
    except ValueError as e:
        # Unsupported/unparseable file format.
        click.echo(click.style(f"Error: {e}", fg='red'))
        sys.exit(1)
    except ImportError as e:
        # e.g. a missing optional reader dependency (Excel support).
        click.echo(click.style(f"Error: {e}", fg='red'))
        sys.exit(1)
    click.echo(f"   Shape: {df.shape[0]} rows, {df.shape[1]} columns")

    if target is None:
        # Interactive mode: show the available columns, then ask.
        click.echo(f"\n   Columns: {', '.join(df.columns.tolist())}")
        target = click.prompt('Enter target column name')

    if target not in df.columns:
        click.echo(click.style(f"Error: Target '{target}' not found in data", fg='red'))
        sys.exit(1)

    # Flag-style CLI options are negated here: e.g. --no-clean -> clean_data=False.
    pipeline = CredilyPipeline(
        target_column=target,
        output_dir=output,
        test_size=test_size,
        cv_folds=cv,
        clean_data=not no_clean,
        clean_mode=clean_mode,
        outlier_method=outlier_method,
        balance_data=not no_balance,
        balance_method=balance_method,
        n_jobs=-1 if parallel else 1,
        calibrate=not no_calibrate,
        calibration_method=calibration_method,
        optimize_threshold=not no_threshold_opt,
        threshold_metric=threshold_metric,
        binary_threshold=binary_threshold
    )

    if not no_profile:
        click.echo(click.style("\n🔍 Profiling data...", fg='cyan'))
        pipeline.profile(df)

    click.echo(click.style("\n🚀 Training models...", fg='cyan'))
    results = pipeline.train(df)

    # Summarize the winning model for the user.
    click.echo(click.style(f"\n✅ Training complete!", fg='green'))
    click.echo(f"   Best model: {results['best_model']}")
    click.echo(f"   ROC-AUC: {results['best_score']:.4f}")
    click.echo(f"   Output saved to: {output}/")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@cli.command()
@click.argument('data', type=click.Path(exists=True))
@click.option('-m', '--model', default='credily_output/model.pkl', help='Path to trained model')
@click.option('-o', '--output', default='predictions', help='Output file path (without extension)')
@click.option('--format', 'output_format', type=click.Choice(['excel', 'csv', 'both']), default='excel', help='Output format: excel (default), csv, or both')
@click.option('--proba', is_flag=True, default=True, help='Include prediction probabilities (default: True)')
@click.option('--no-proba', is_flag=True, help='Exclude prediction probabilities')
def predict(data, model, output, output_format, proba, no_proba):
    """
    Make predictions using a trained model.

    Outputs the FULL dataset with predictions and probabilities.
    Supported input formats: CSV (.csv), Text (.txt, .tsv), Excel (.xlsx, .xls)
    """
    from .automl import CredilyPipeline
    from .utils import load_data, save_to_excel

    # NOTE(review): --proba is declared is_flag=True with default=True, so the
    # flag by itself can never turn probabilities off; --no-proba exists as the
    # actual opt-out. Consider a single --proba/--no-proba boolean option.
    include_proba = proba and not no_proba

    click.echo(click.style(f"\n📦 Loading model from: {model}", fg='cyan'))
    pipeline = CredilyPipeline.load(model)

    click.echo(click.style(f"📊 Loading data from: {data}", fg='cyan'))
    try:
        df = load_data(data)
    except (ValueError, ImportError) as e:
        click.echo(click.style(f"Error: {e}", fg='red'))
        sys.exit(1)
    click.echo(f"   Shape: {df.shape[0]} rows, {df.shape[1]} columns")

    click.echo(click.style("\n🔮 Making predictions...", fg='cyan'))
    predictions = pipeline.predict(df, include_proba=include_proba)

    # Show prediction summary
    pred_counts = predictions['prediction'].value_counts().sort_index()
    click.echo(f"\n   Prediction distribution:")
    for pred_val, count in pred_counts.items():
        pct = count / len(predictions) * 100
        click.echo(f"     - Class {pred_val}: {count} ({pct:.1f}%)")

    # Save outputs
    saved_files = []

    if output_format in ['excel', 'both']:
        try:
            excel_path = save_to_excel(predictions, output)
            saved_files.append(excel_path)
            click.echo(click.style(f"\n✅ Excel report saved to: {excel_path}", fg='green'))
        except ImportError as e:
            # Excel writer (openpyxl) missing: degrade gracefully to CSV.
            click.echo(click.style(f"Warning: {e}", fg='yellow'))
            if output_format == 'excel':
                output_format = 'csv'
                click.echo("   Falling back to CSV format...")

    if output_format in ['csv', 'both']:
        csv_path = f"{output}.csv"
        predictions.to_csv(csv_path, index=False)
        saved_files.append(csv_path)
        click.echo(click.style(f"\n✅ CSV saved to: {csv_path}", fg='green'))

    click.echo(f"\n   Total records: {len(predictions)}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@cli.command()
@click.argument('data', type=click.Path(exists=True))
@click.option('-t', '--target', help='Target column name (optional)')
def profile(data, target):
    """
    Profile a dataset without training.

    Supported formats: CSV (.csv), Text (.txt, .tsv), Excel (.xlsx, .xls)
    """
    from .profiler import DataProfiler
    from .utils import load_data

    click.echo(click.style(f"\n📊 Loading data from: {data}", fg='cyan'))
    try:
        df = load_data(data)
    except (ValueError, ImportError) as e:
        click.echo(click.style(f"Error: {e}", fg='red'))
        sys.exit(1)

    # target is optional; when given, the profiler adds target/imbalance stats.
    profiler = DataProfiler(target_column=target)
    report = profiler.profile(df)

    # Render the profile report as a simple terminal summary.
    click.echo(click.style("\n" + "=" * 60, fg='cyan'))
    click.echo(click.style("DATA PROFILE REPORT", fg='cyan', bold=True))
    click.echo(click.style("=" * 60, fg='cyan'))

    click.echo(f"\n📋 Basic Info:")
    click.echo(f"   Rows: {report['n_rows']}")
    click.echo(f"   Columns: {report['n_cols']}")
    click.echo(f"   Memory: {report['memory_mb']:.2f} MB")

    click.echo(f"\n📊 Column Types:")
    click.echo(f"   Numeric: {report['n_numeric']}")
    click.echo(f"   Categorical: {report['n_categorical']}")

    click.echo(f"\n⚠️  Data Quality:")
    click.echo(f"   Missing values: {report['missing_pct']:.1f}%")
    click.echo(f"   Duplicate rows: {report['duplicate_rows']}")

    # Only present when a target column was supplied AND the profiler found it.
    if target and 'target_info' in report:
        click.echo(f"\n🎯 Target Analysis:")
        click.echo(f"   Task type: {report['target_info']['task_type']}")
        click.echo(f"   Classes: {report['target_info']['n_classes']}")
        click.echo(f"   Class balance: {report['target_info']['balance']}")
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
@cli.command()
@click.argument('report_path', type=click.Path(exists=True))
def show(report_path):
    """Display a training report."""
    import json
    from pathlib import Path

    path = Path(report_path)

    # HTML reports are handed off to the default browser.
    if path.suffix == '.html':
        import webbrowser
        webbrowser.open(f'file://{path.absolute()}')
        click.echo(click.style(f"Opening report in browser...", fg='cyan'))
        return

    # Anything that is neither HTML nor JSON is unsupported.
    if path.suffix != '.json':
        click.echo(click.style(f"Error: Unsupported format '{path.suffix}'", fg='red'))
        return

    # JSON reports are summarized directly in the terminal.
    with open(path) as f:
        report = json.load(f)

    banner = "=" * 60
    click.echo(click.style("\n" + banner, fg='cyan'))
    click.echo(click.style("TRAINING REPORT", fg='cyan', bold=True))
    click.echo(click.style(banner, fg='cyan'))

    click.echo(f"\n🏆 Best Model: {report.get('best_model', 'N/A')}")
    click.echo(f"   ROC-AUC: {report.get('best_score', 0):.4f}")

    if 'model_scores' in report:
        click.echo(f"\n📊 All Models:")
        for model_name, model_score in report['model_scores'].items():
            click.echo(f"   {model_name}: {model_score:.4f}")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
@cli.command()
@click.argument('data', type=click.Path(exists=True))
@click.option('-m', '--model', default='credily_output/model.pkl', help='Path to trained model')
@click.option('-t', '--target', required=True, help='Target column name')
@click.option('-c', '--context', default='credit_scoring', help='Business context')
def analyze(data, model, target, context):
    """
    Analyze model performance in different business contexts.

    Supported formats: CSV (.csv), Text (.txt, .tsv), Excel (.xlsx, .xls)
    """
    from .automl import CredilyPipeline
    from .analyzer import BusinessAnalyzer
    from .utils import load_data

    click.echo(click.style(f"\n📦 Loading model from: {model}", fg='cyan'))
    pipeline = CredilyPipeline.load(model)

    click.echo(click.style(f"📊 Loading data from: {data}", fg='cyan'))
    try:
        df = load_data(data)
    except (ValueError, ImportError) as e:
        click.echo(click.style(f"Error: {e}", fg='red'))
        sys.exit(1)

    # `context` selects the cost/benefit profile; see `credily list-contexts`
    # for the available names.
    analyzer = BusinessAnalyzer(context=context)
    report = analyzer.analyze(pipeline, df, target)

    click.echo(click.style("\n" + "=" * 60, fg='cyan'))
    click.echo(click.style(f"BUSINESS ANALYSIS: {context.upper()}", fg='cyan', bold=True))
    click.echo(click.style("=" * 60, fg='cyan'))

    click.echo(f"\n💰 Financial Impact:")
    click.echo(f"   Expected profit: ${report['expected_profit']:,.2f}")
    click.echo(f"   Risk exposure: ${report['risk_exposure']:,.2f}")

    click.echo(f"\n📊 Threshold Analysis:")
    click.echo(f"   Optimal threshold: {report['optimal_threshold']:.2f}")
    click.echo(f"   Precision at threshold: {report['precision']:.2%}")
    click.echo(f"   Recall at threshold: {report['recall']:.2%}")

    click.echo(f"\n🎯 Recommendations:")
    for rec in report['recommendations']:
        click.echo(f"   • {rec}")
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@cli.command('list-contexts')
def list_contexts():
    """List all available business contexts with descriptions."""
    # Name -> human-readable description; insertion order is display order.
    contexts = {
        'credit_scoring': 'Loan default prediction - optimizes for minimizing bad debt',
        'fraud_detection': 'Transaction fraud - optimizes for catching fraud with low false positives',
        'churn_prediction': 'Customer churn - optimizes for retention ROI',
        'insurance_claims': 'Claims prediction - optimizes for loss ratio',
        'collections': 'Debt collection - optimizes for recovery rate',
    }

    click.echo(click.style("\n📋 Available Business Contexts:", fg='cyan', bold=True))
    click.echo()

    for context_name, description in contexts.items():
        click.echo(f"  {click.style(context_name, fg='green', bold=True)}")
        click.echo(f"    {description}\n")
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def main():
    """Console-script entry point: dispatch to the click command group."""
    cli()


if __name__ == '__main__':
    main()
|
credily/metrics.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Metrics module for TabulaML.
|
| 3 |
+
Handles model evaluation and visualization.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from sklearn.metrics import classification_report, roc_auc_score
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def print_classification_metrics(y_true, y_pred, y_proba=None):
    """
    Print classification metrics including precision, recall, F1-score.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        y_proba: Predicted probabilities for positive class (optional);
            when given, ROC-AUC is computed and printed as well
    """
    print("\n" + "=" * 50)
    print("CLASSIFICATION REPORT")
    print(" =" .replace(" ", "") * 0 + "=" * 50)
    print(classification_report(y_true, y_pred))

    if y_proba is not None:
        roc_auc = roc_auc_score(y_true, y_proba)
        print(f"ROC-AUC Score: {roc_auc:.4f}")

    # Fix: close the report block unconditionally — previously the footer
    # separator only appeared when probabilities were supplied, leaving the
    # report visually unterminated for label-only calls.
    print("=" * 50 + "\n")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def plot_feature_importance(feature_names: list, importances: np.ndarray, top_n: int = 20):
    """
    Generate and display a feature importance chart.

    Args:
        feature_names: List of feature names
        importances: Array of feature importances aligned with ``feature_names``
        top_n: Number of top features to display
    """
    # Indices of the top_n largest importances, most important first.
    ranked = np.argsort(importances)[::-1][:top_n]
    labels = [feature_names[i] for i in ranked]
    values = importances[ranked]

    plt.figure(figsize=(10, 8))
    plt.title("Feature Importances (Top {})".format(min(top_n, len(feature_names))))
    # Reverse so the most important feature appears at the top of the chart.
    plt.barh(range(len(ranked)), values[::-1], align='center')
    plt.yticks(range(len(ranked)), labels[::-1])
    plt.xlabel("Importance")
    plt.ylabel("Feature")
    plt.tight_layout()
    plt.show()
|
credily/model.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core TabulaML model module.
|
| 3 |
+
Contains the TabulaMLModel class for training, prediction, and evaluation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import joblib
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
from sklearn.pipeline import Pipeline
|
| 13 |
+
|
| 14 |
+
from .preprocessing import identify_column_types, create_preprocessor
|
| 15 |
+
from .metrics import print_classification_metrics, plot_feature_importance
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TabulaMLModel:
    """
    A data-generic binary classification model for credit scoring.

    This class provides methods to train, predict, and evaluate
    a Random Forest model on any tabular dataset with numeric
    and categorical features.
    """

    def __init__(
        self,
        target_column: str = 'target',
        model_path: str = 'credit_model.pkl',
        n_estimators: int = 200,
        max_depth: int = 10,
        random_state: int = 42
    ):
        """
        Initialize the TabulaMLModel.

        Args:
            target_column: Name of the binary target column
            model_path: Path to save/load the trained model
            n_estimators: Number of trees in the forest
            max_depth: Maximum depth of trees
            random_state: Random seed for reproducibility
        """
        self.target_column = target_column
        self.model_path = model_path
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state

        # All four are populated by train() (or _load_model()); None until then.
        self.pipeline = None
        self.feature_names = None
        self.numeric_columns = None
        self.categorical_columns = None

    def train(self, dataframe: pd.DataFrame, test_size: float = 0.2) -> dict:
        """
        Train the model on the provided dataframe.

        Args:
            dataframe: Input dataframe with features and target column
            test_size: Proportion of data for testing (default 0.2)

        Returns:
            dict: Training results including test metrics

        Raises:
            ValueError: If the target column is not present in ``dataframe``.
        """
        if self.target_column not in dataframe.columns:
            raise ValueError(f"Target column '{self.target_column}' not found in dataframe")

        X = dataframe.drop(columns=[self.target_column])
        y = dataframe[self.target_column]

        # Split features by dtype so the preprocessor can treat them differently.
        self.numeric_columns, self.categorical_columns = identify_column_types(
            dataframe, self.target_column
        )

        print(f"Numeric features: {len(self.numeric_columns)}")
        print(f"Categorical features: {len(self.categorical_columns)}")

        # Stratified split keeps the class ratio identical in train and test sets.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y,
            test_size=test_size,
            stratify=y,
            random_state=self.random_state
        )

        preprocessor = create_preprocessor(
            self.numeric_columns,
            self.categorical_columns
        )

        # class_weight='balanced' reweights samples to compensate for target imbalance.
        model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            class_weight='balanced',
            random_state=self.random_state,
            n_jobs=-1
        )

        # One pipeline object so the exact preprocessing is persisted with the model.
        self.pipeline = Pipeline(steps=[
            ('preprocessor', preprocessor),
            ('classifier', model)
        ])

        print("\nTraining model...")
        self.pipeline.fit(X_train, y_train)

        # Must happen after fit(): the one-hot encoder exposes output names only once fitted.
        self._extract_feature_names()

        # Persist before evaluation so a trained model survives any reporting failure.
        self._save_model()

        print("\nEvaluating on test set...")
        y_pred = self.pipeline.predict(X_test)
        # Column 1 of predict_proba is the probability of the positive class.
        y_proba = self.pipeline.predict_proba(X_test)[:, 1]

        print_classification_metrics(y_test, y_pred, y_proba)

        return {
            'X_test': X_test,
            'y_test': y_test,
            'y_pred': y_pred,
            'y_proba': y_proba
        }

    def predict(self, dataframe: pd.DataFrame) -> np.ndarray:
        """
        Make predictions on new data.

        Args:
            dataframe: Input dataframe with features (no target column needed)

        Returns:
            np.ndarray: Predicted class labels
        """
        # Lazily load a previously persisted model if train() was never called.
        if self.pipeline is None:
            self._load_model()

        # Tolerate callers passing a frame that still contains the target.
        if self.target_column in dataframe.columns:
            dataframe = dataframe.drop(columns=[self.target_column])

        return self.pipeline.predict(dataframe)

    def predict_proba(self, dataframe: pd.DataFrame) -> np.ndarray:
        """
        Get prediction probabilities for new data.

        Args:
            dataframe: Input dataframe with features

        Returns:
            np.ndarray: Predicted probabilities for each class
        """
        if self.pipeline is None:
            self._load_model()

        if self.target_column in dataframe.columns:
            dataframe = dataframe.drop(columns=[self.target_column])

        return self.pipeline.predict_proba(dataframe)

    def evaluate(self, dataframe: pd.DataFrame, show_feature_importance: bool = True) -> None:
        """
        Evaluate the model on a dataset and print metrics.

        Args:
            dataframe: Input dataframe with features and target column
            show_feature_importance: Whether to display feature importance chart

        Raises:
            ValueError: If the target column is missing (labels are required here).
        """
        if self.pipeline is None:
            self._load_model()

        if self.target_column not in dataframe.columns:
            raise ValueError(f"Target column '{self.target_column}' required for evaluation")

        X = dataframe.drop(columns=[self.target_column])
        y = dataframe[self.target_column]

        y_pred = self.pipeline.predict(X)
        y_proba = self.pipeline.predict_proba(X)[:, 1]

        print_classification_metrics(y, y_pred, y_proba)

        # feature_names may be None for models saved by other tooling; skip the plot then.
        if show_feature_importance and self.feature_names is not None:
            importances = self.pipeline.named_steps['classifier'].feature_importances_
            plot_feature_importance(self.feature_names, importances)

    def _extract_feature_names(self) -> None:
        """Extract feature names from the fitted preprocessor."""
        preprocessor = self.pipeline.named_steps['preprocessor']

        feature_names = []

        # ColumnTransformer outputs numeric columns first ('num' transformer), then 'cat'.
        feature_names.extend(self.numeric_columns)

        if self.categorical_columns:
            cat_encoder = preprocessor.named_transformers_['cat'].named_steps['encoder']
            cat_feature_names = cat_encoder.get_feature_names_out(self.categorical_columns)
            feature_names.extend(cat_feature_names.tolist())

        self.feature_names = feature_names

    def _save_model(self) -> None:
        """Save the trained model to disk."""
        # Bundle the pipeline with the column metadata so prediction can be
        # reconstructed without re-running column-type detection.
        model_data = {
            'pipeline': self.pipeline,
            'feature_names': self.feature_names,
            'numeric_columns': self.numeric_columns,
            'categorical_columns': self.categorical_columns,
            'target_column': self.target_column
        }
        joblib.dump(model_data, self.model_path)
        print(f"\nModel saved to: {self.model_path}")

    def _load_model(self) -> None:
        """Load a trained model from disk."""
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"No model found at: {self.model_path}")

        model_data = joblib.load(self.model_path)
        self.pipeline = model_data['pipeline']
        self.feature_names = model_data['feature_names']
        self.numeric_columns = model_data['numeric_columns']
        self.categorical_columns = model_data['categorical_columns']
        # Restore the target column the model was trained with (overrides __init__ default).
        self.target_column = model_data['target_column']
        print(f"Model loaded from: {self.model_path}")

    @classmethod
    def load(cls, model_path: str = 'credit_model.pkl'):
        """
        Class method to load a pre-trained model.

        Args:
            model_path: Path to the saved model file

        Returns:
            TabulaMLModel: Loaded model instance
        """
        instance = cls(model_path=model_path)
        instance._load_model()
        return instance
|
credily/preprocessing.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Preprocessing module for TabulaML.
|
| 3 |
+
Handles missing value imputation and categorical encoding.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from sklearn.impute import SimpleImputer
|
| 9 |
+
from sklearn.preprocessing import OneHotEncoder
|
| 10 |
+
from sklearn.compose import ColumnTransformer
|
| 11 |
+
from sklearn.pipeline import Pipeline
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def identify_column_types(df: pd.DataFrame, target_column: str = None):
    """
    Identify numeric and categorical columns in the dataframe.

    Args:
        df: Input dataframe
        target_column: Name of target column to exclude from features

    Returns:
        tuple: (numeric_columns, categorical_columns)
    """
    # Drop the target (if any) before classifying dtypes; column order is preserved.
    feature_cols = [col for col in df.columns if col != target_column]
    features = df[feature_cols]

    numeric = features.select_dtypes(include=[np.number]).columns.tolist()
    categorical = features.select_dtypes(include=['object', 'category']).columns.tolist()

    return numeric, categorical
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_preprocessor(numeric_columns: list, categorical_columns: list):
    """
    Create a preprocessing pipeline for numeric and categorical features.

    Args:
        numeric_columns: List of numeric column names
        categorical_columns: List of categorical column names

    Returns:
        ColumnTransformer: Preprocessing pipeline
    """
    # Numeric branch: fill missing values with the column median.
    numeric_steps = [('imputer', SimpleImputer(strategy='median'))]

    # Categorical branch: impute with the mode, then one-hot encode.
    # handle_unknown='ignore' zero-encodes categories unseen at fit time.
    categorical_steps = [
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('encoder', OneHotEncoder(handle_unknown='ignore', sparse_output=False)),
    ]

    return ColumnTransformer(
        transformers=[
            ('num', Pipeline(steps=numeric_steps), numeric_columns),
            ('cat', Pipeline(steps=categorical_steps), categorical_columns),
        ],
        remainder='drop',
    )
|
credily/profiler.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data profiling module for TabulaML.
|
| 3 |
+
Analyzes datasets and infers ML task types.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Optional, Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DataProfiler:
    """
    Profiles tabular datasets for ML readiness.
    """

    def __init__(self, target_column: Optional[str] = None):
        # Name of the target column; excluded from feature stats and
        # analyzed separately by _analyze_target().
        self.target_column = target_column

    def profile(self, df: pd.DataFrame) -> Dict[str, Any]:
        """
        Generate a comprehensive profile of the dataset.

        Args:
            df: Input dataframe

        Returns:
            dict: Profile report with statistics and recommendations
        """
        # Classify feature columns by dtype, excluding the target column.
        numeric_cols = [
            col for col in df.select_dtypes(include=[np.number]).columns
            if col != self.target_column
        ]
        categorical_cols = [
            col for col in df.select_dtypes(include=['object', 'category']).columns
            if col != self.target_column
        ]

        total_cells = df.shape[0] * df.shape[1]
        missing_total = df.isnull().sum().sum()

        report: Dict[str, Any] = {
            # Basic info
            'n_rows': len(df),
            'n_cols': len(df.columns),
            'memory_mb': df.memory_usage(deep=True).sum() / (1024 * 1024),
            # Column types
            'n_numeric': len(numeric_cols),
            'n_categorical': len(categorical_cols),
            'numeric_columns': numeric_cols,
            'categorical_columns': categorical_cols,
            # Missing values (guard against the empty-frame division)
            'missing_pct': (missing_total / total_cells) * 100 if total_cells > 0 else 0,
            'missing_by_column': df.isnull().sum().to_dict(),
            # Duplicates
            'duplicate_rows': df.duplicated().sum(),
        }

        report['column_stats'] = self._get_column_stats(df, numeric_cols, categorical_cols)

        if self.target_column and self.target_column in df.columns:
            report['target_info'] = self._analyze_target(df[self.target_column])

        # Warnings are derived last since they read the fields computed above.
        report['warnings'] = self._generate_warnings(df, report)

        return report

    def _get_column_stats(
        self,
        df: pd.DataFrame,
        numeric_cols: list,
        categorical_cols: list
    ) -> Dict[str, Dict]:
        """Get statistics for each column."""
        stats: Dict[str, Dict] = {}

        for name in numeric_cols:
            series = df[name]
            stats[name] = {
                'type': 'numeric',
                'mean': series.mean(),
                'std': series.std(),
                'min': series.min(),
                'max': series.max(),
                'missing': series.isnull().sum(),
                'zeros': (series == 0).sum(),
                'unique': series.nunique(),
            }

        for name in categorical_cols:
            series = df[name]
            counts = series.value_counts()
            has_values = len(counts) > 0
            stats[name] = {
                'type': 'categorical',
                'unique': series.nunique(),
                'missing': series.isnull().sum(),
                'top_value': counts.index[0] if has_values else None,
                'top_freq': counts.iloc[0] if has_values else 0,
            }

        return stats

    def _analyze_target(self, target: pd.Series) -> Dict[str, Any]:
        """Analyze the target variable."""
        # NaNs are excluded from class counting.
        valid = target.dropna()
        n_unique = valid.nunique()

        if n_unique == 0:
            # Degenerate case: nothing to learn from.
            return {
                'task_type': 'unknown',
                'n_classes': 0,
                'value_counts': {},
                'balance': 'N/A (no valid values)',
                'imbalance_ratio': 0,
                'is_imbalanced': False,
                'warning': 'Target column has no valid values'
            }

        # Heuristic task inference by cardinality, then dtype.
        if n_unique == 2:
            task_type = 'binary_classification'
        elif n_unique <= 10:
            task_type = 'multiclass_classification'
        elif target.dtype in [np.float64, np.float32]:
            task_type = 'regression'
        else:
            task_type = 'multiclass_classification' if n_unique <= 50 else 'regression'

        # value_counts sorts descending, so first = majority, last = minority.
        counts = valid.value_counts()
        majority = counts.iloc[0] if len(counts) > 0 else 0
        minority = counts.iloc[-1] if len(counts) > 0 else 0
        ratio = majority / minority if minority > 0 else float('inf')

        return {
            'task_type': task_type,
            'n_classes': n_unique,
            'value_counts': counts.to_dict(),
            'balance': f"{minority}:{majority} (1:{ratio:.1f})",
            'imbalance_ratio': ratio,
            'is_imbalanced': ratio > 3
        }

    def _generate_warnings(self, df: pd.DataFrame, report: Dict) -> list:
        """Generate data quality warnings."""
        notes = []

        if report['missing_pct'] > 20:
            notes.append(f"High missing value rate: {report['missing_pct']:.1f}%")

        if report['duplicate_rows'] > 0:
            notes.append(f"Found {report['duplicate_rows']} duplicate rows")

        # High-cardinality categoricals blow up one-hot encoded width.
        for col, stats in report['column_stats'].items():
            if stats['type'] == 'categorical' and stats['unique'] > 50:
                notes.append(f"High cardinality in '{col}': {stats['unique']} unique values")

        target_info = report.get('target_info')
        if target_info and target_info.get('is_imbalanced'):
            notes.append(f"Target is imbalanced (ratio: 1:{target_info['imbalance_ratio']:.1f})")

        return notes

    def infer_task_type(self, df: pd.DataFrame) -> str:
        """
        Infer the ML task type from the target column.

        Returns:
            str: 'binary_classification', 'multiclass_classification', or 'regression'
        """
        if not self.target_column or self.target_column not in df.columns:
            raise ValueError("Target column must be specified and present in dataframe")

        target = df[self.target_column]
        n_unique = target.nunique()

        if n_unique == 2:
            return 'binary_classification'
        if n_unique <= 10 or target.dtype == 'object':
            return 'multiclass_classification'
        return 'regression'
|
credily/reporting.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Report generation module for Credily.
|
| 3 |
+
Generates HTML and JSON reports.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, Any, Optional
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ReportGenerator:
    """Generates training reports in various formats."""

    def __init__(self, output_dir: str):
        # Reports are written under this directory; created eagerly if absent.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def generate_html_report(
        self,
        training_results: Dict[str, Any],
        profile_report: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Generate an HTML report.

        Args:
            training_results: Results from training
            profile_report: Optional data profile report

        Returns:
            Path to generated HTML file
        """
        html = self._build_html(training_results, profile_report)
        # Fixed filename: a re-run overwrites the previous report.
        report_path = self.output_dir / 'report.html'

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html)

        print(f"HTML report saved to: {report_path}")
        return str(report_path)

    def _build_html(
        self,
        results: Dict[str, Any],
        profile: Optional[Dict[str, Any]]
    ) -> str:
        """Build HTML content."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Extract label mapping to show original class names
        label_mapping = {}
        cleaning_report = results.get('cleaning_report', {})
        if cleaning_report:
            label_validation = cleaning_report.get('label_validation', {})
            label_mapping = label_validation.get('label_mapping', {})

        # Create reverse mapping: {0: 'original_negative', 1: 'original_positive'}
        # Keys are stringified because classification-report labels arrive as strings.
        reverse_label_map = {str(v): k for k, v in label_mapping.items()} if label_mapping else {}

        # Model comparison table; the winning model's row is highlighted.
        model_rows = ""
        for model, score in results.get('model_scores', {}).items():
            is_best = model == results.get('best_model', '')
            highlight = 'style="background-color: #d4edda; font-weight: bold;"' if is_best else ''
            model_rows += f"<tr {highlight}><td>{model}</td><td>{score:.4f}</td></tr>"

        # Feature importance table (top 20 by descending importance)
        feature_rows = ""
        importances = results.get('feature_importances', {})
        sorted_features = sorted(importances.items(), key=lambda x: x[1], reverse=True)[:20]
        for feature, importance in sorted_features:
            feature_rows += f"<tr><td>{feature}</td><td>{importance:.4f}</td></tr>"

        # Classification report with original class labels
        clf_report = results.get('classification_report', {})
        clf_rows = ""
        for label, metrics in clf_report.items():
            # Non-dict entries (e.g. the scalar 'accuracy') are skipped.
            if isinstance(metrics, dict):
                # Use original class name if available, otherwise use the label
                display_label = reverse_label_map.get(str(label), label)
                # For aggregate metrics like 'macro avg', keep the original name
                if label in ['macro avg', 'weighted avg', 'accuracy']:
                    display_label = label
                clf_rows += f"""<tr>
                    <td>{display_label}</td>
                    <td>{metrics.get('precision', 0):.3f}</td>
                    <td>{metrics.get('recall', 0):.3f}</td>
                    <td>{metrics.get('f1-score', 0):.3f}</td>
                    <td>{metrics.get('support', 0)}</td>
                </tr>"""

        # Profile section (only rendered when a profile report was supplied)
        profile_section = ""
        if profile:
            warnings_html = "".join([f"<li>{w}</li>" for w in profile.get('warnings', [])])
            profile_section = f"""
            <section class="card">
                <h2>Data Profile</h2>
                <div class="stats-grid">
                    <div class="stat">
                        <span class="stat-value">{profile.get('n_rows', 0):,}</span>
                        <span class="stat-label">Rows</span>
                    </div>
                    <div class="stat">
                        <span class="stat-value">{profile.get('n_cols', 0)}</span>
                        <span class="stat-label">Columns</span>
                    </div>
                    <div class="stat">
                        <span class="stat-value">{profile.get('n_numeric', 0)}</span>
                        <span class="stat-label">Numeric</span>
                    </div>
                    <div class="stat">
                        <span class="stat-value">{profile.get('n_categorical', 0)}</span>
                        <span class="stat-label">Categorical</span>
                    </div>
                    <div class="stat">
                        <span class="stat-value">{profile.get('missing_pct', 0):.1f}%</span>
                        <span class="stat-label">Missing</span>
                    </div>
                </div>
                {f'<h3>Warnings</h3><ul class="warnings">{warnings_html}</ul>' if warnings_html else ''}
            </section>
            """

        # Full standalone HTML document; note doubled braces for literal CSS braces.
        html = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Credily Training Report</title>
    <style>
        * {{ box-sizing: border-box; margin: 0; padding: 0; }}
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
            line-height: 1.6;
            color: #333;
            background: #f5f7fa;
            padding: 2rem;
        }}
        .container {{ max-width: 1200px; margin: 0 auto; }}
        header {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 2rem;
            border-radius: 12px;
            margin-bottom: 2rem;
        }}
        header h1 {{ font-size: 2rem; margin-bottom: 0.5rem; }}
        header p {{ opacity: 0.9; }}
        .card {{
            background: white;
            border-radius: 12px;
            padding: 1.5rem;
            margin-bottom: 1.5rem;
            box-shadow: 0 2px 8px rgba(0,0,0,0.08);
        }}
        .card h2 {{
            color: #667eea;
            margin-bottom: 1rem;
            padding-bottom: 0.5rem;
            border-bottom: 2px solid #f0f0f0;
        }}
        .stats-grid {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
            gap: 1rem;
        }}
        .stat {{
            text-align: center;
            padding: 1rem;
            background: #f8f9fa;
            border-radius: 8px;
        }}
        .stat-value {{
            display: block;
            font-size: 1.5rem;
            font-weight: bold;
            color: #667eea;
        }}
        .stat-label {{ color: #666; font-size: 0.9rem; }}
        .best-model {{
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            padding: 1.5rem;
            border-radius: 8px;
            text-align: center;
            margin-bottom: 1rem;
        }}
        .best-model h3 {{ font-size: 1.2rem; margin-bottom: 0.5rem; }}
        .best-model .score {{ font-size: 2rem; font-weight: bold; }}
        table {{
            width: 100%;
            border-collapse: collapse;
            margin-top: 1rem;
        }}
        th, td {{
            padding: 0.75rem;
            text-align: left;
            border-bottom: 1px solid #eee;
        }}
        th {{ background: #f8f9fa; font-weight: 600; }}
        tr:hover {{ background: #f8f9fa; }}
        .warnings {{ color: #856404; background: #fff3cd; padding: 1rem; border-radius: 8px; margin-top: 1rem; }}
        .warnings li {{ margin-left: 1.5rem; }}
        footer {{ text-align: center; color: #666; margin-top: 2rem; font-size: 0.9rem; }}
    </style>
</head>
<body>
    <div class="container">
        <header>
            <h1>Credily Training Report</h1>
            <p>Generated: {timestamp}</p>
        </header>

        {profile_section}

        <section class="card">
            <h2>Model Performance</h2>
            <div class="best-model">
                <h3>Best Model: {results.get('best_model', 'N/A')}</h3>
                <div class="score">ROC-AUC: {results.get('best_score', 0):.4f}</div>
                <p>Test AUC: {results.get('test_auc', 0):.4f}</p>
            </div>
            <h3>Model Comparison</h3>
            <table>
                <thead><tr><th>Model</th><th>CV ROC-AUC</th></tr></thead>
                <tbody>{model_rows}</tbody>
            </table>
        </section>

        <section class="card">
            <h2>Classification Report</h2>
            <table>
                <thead>
                    <tr><th>Class</th><th>Precision</th><th>Recall</th><th>F1-Score</th><th>Support</th></tr>
                </thead>
                <tbody>{clf_rows}</tbody>
            </table>
        </section>

        <section class="card">
            <h2>Feature Importance (Top 20)</h2>
            <table>
                <thead><tr><th>Feature</th><th>Importance</th></tr></thead>
                <tbody>{feature_rows}</tbody>
            </table>
        </section>

        <footer>
            <p>Generated by Credily - Fast, Explainable AutoML for Tabular Data</p>
        </footer>
    </div>
</body>
</html>"""

        return html
|
credily/safety.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AutoML Safety & Leakage Prevention Module for Credily.
|
| 3 |
+
|
| 4 |
+
This module implements comprehensive checks to detect and prevent:
|
| 5 |
+
- Data leakage (features that directly encode the target)
|
| 6 |
+
- Feature dominance (single feature explaining too much variance)
|
| 7 |
+
- Overfitting (CV vs Test performance gap)
|
| 8 |
+
- Feature redundancy (highly correlated features)
|
| 9 |
+
|
| 10 |
+
These protections ensure production-safe, reliable models.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import numpy as np
|
| 15 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 16 |
+
from scipy import stats
|
| 17 |
+
from sklearn.feature_selection import mutual_info_classif
|
| 18 |
+
import warnings
|
| 19 |
+
|
| 20 |
+
# Suppress warnings for cleaner output.
# NOTE(review): this runs at import time and disables ALL warnings
# process-wide for every consumer of this module, not just for this file.
# Consider scoping suppression with warnings.catch_warnings() around the
# noisy calls instead.
warnings.filterwarnings('ignore')
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ============== Configuration Thresholds ==============
|
| 25 |
+
|
| 26 |
+
class SafetyConfig:
    """Configuration thresholds for safety checks.

    All values are class attributes, so they act as shared defaults;
    callers may subclass or assign on an instance to override them.
    Correlation/importance thresholds are compared with ``>=`` by
    ``SafetyValidator``.
    """

    # Feature-Target Leakage
    LEAKAGE_DROP_THRESHOLD = 0.95  # Auto-drop if |corr| >= this
    LEAKAGE_WARN_THRESHOLD = 0.90  # Warn if |corr| >= this

    # Feature-Feature Redundancy
    REDUNDANCY_THRESHOLD = 0.98  # Drop one if |corr| >= this

    # Feature Dominance (Post-Training)
    DOMINANCE_INVALID_THRESHOLD = 0.85  # Mark model INVALID if max importance >= this
    DOMINANCE_WARN_THRESHOLD = 0.70  # Warn if max importance >= this

    # Overfitting Guard
    OVERFIT_INVALID_GAP = 0.10  # Mark INVALID if CV-Test gap >= this
    OVERFIT_WARN_GAP = 0.05  # Warn if gap >= this

    # Minimum acceptable test performance
    MIN_TEST_AUC = 0.60  # Must beat random baseline

    # Target-related column name patterns to exclude.
    # NOTE: SafetyValidator.check_column_hygiene matches these as
    # case-insensitive SUBSTRINGS, so e.g. a column named
    # 'classification_code' would match 'class' and be dropped.
    TARGET_NAME_PATTERNS = [
        'target', 'label', 'outcome', 'result',
        'default', 'churn', 'flag', 'class', 'status'
    ]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SafetyReport:
    """Accumulates the findings of one safety-validation run.

    The overall ``status`` starts at ``"PASS"``, is downgraded to
    ``"WARN"`` by the first warning, and becomes ``"FAIL"`` (with
    ``model_valid = False``) as soon as any error is recorded.
    """

    def __init__(self):
        self.status = "PASS"  # one of: PASS, WARN, FAIL
        # Maps each dropped feature name to the human-readable reason.
        self.dropped_features: Dict[str, str] = {}
        self.warnings: List[str] = []
        self.errors: List[str] = []
        # Per-feature |correlation| values flagged as potential leakage.
        self.leakage_report: Dict[str, float] = {}
        # (feature1, feature2, correlation) triples for redundant pairs.
        self.redundancy_report: List[Tuple[str, str, float]] = []
        self.dominance_report: Dict[str, float] = {}
        self.overfitting_report: Dict[str, float] = {}
        self.model_valid = True

    def add_dropped_feature(self, feature: str, reason: str):
        """Record that *feature* was removed, together with why."""
        self.dropped_features[feature] = reason

    def add_warning(self, message: str):
        """Record a non-fatal finding; downgrades PASS to WARN."""
        self.warnings.append(message)
        # Never upgrade a FAIL back to WARN.
        self.status = "WARN" if self.status == "PASS" else self.status

    def add_error(self, message: str):
        """Record a fatal finding and invalidate the model."""
        self.errors.append(message)
        self.status = "FAIL"
        self.model_valid = False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the report into JSON-friendly builtin types."""
        redundant_pairs = [
            {'feature1': first, 'feature2': second, 'correlation': rho}
            for first, second, rho in self.redundancy_report
        ]
        return {
            'status': self.status,
            'model_valid': self.model_valid,
            'dropped_features': self.dropped_features,
            'warnings': self.warnings,
            'errors': self.errors,
            'leakage_detected': self.leakage_report,
            'redundant_features': redundant_pairs,
            'feature_dominance': self.dominance_report,
            'overfitting_metrics': self.overfitting_report,
        }
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class SafetyValidator:
    """
    Main safety validation class for AutoML pipelines.

    Implements comprehensive checks to prevent data leakage,
    feature dominance, and overfitting.

    Checks are split into two phases:
      * pre-training  (``run_pre_training_checks``)  - column hygiene,
        feature-target leakage, feature-feature redundancy; may drop columns.
      * post-training (``run_post_training_checks``) - feature dominance,
        overfitting gap, minimum test performance; may invalidate the model.
    All findings accumulate in ``self.report`` (a ``SafetyReport``).
    """

    def __init__(self, config: Optional[SafetyConfig] = None, verbose: bool = True):
        """
        Initialize the safety validator.

        Args:
            config: Safety configuration thresholds (defaults to SafetyConfig)
            verbose: Print detailed logs to stdout
        """
        self.config = config or SafetyConfig()
        self.verbose = verbose
        # Single report shared by all checks run through this instance.
        self.report = SafetyReport()

    def _log(self, message: str) -> None:
        """Print message if verbose mode is enabled."""
        if self.verbose:
            print(message)

    # ============== Step 1: Column Hygiene ==============

    def check_column_hygiene(
        self,
        df: pd.DataFrame,
        target_column: str
    ) -> Tuple[pd.DataFrame, List[str]]:
        """
        Check and clean column names, removing potential target leaks.

        Any column (other than the target itself) whose lowercased name
        contains one of ``SafetyConfig.TARGET_NAME_PATTERNS`` as a
        substring is dropped. The input frame is not mutated.

        Args:
            df: Input dataframe
            target_column: Name of target column

        Returns:
            Tuple of (cleaned dataframe, list of dropped columns)
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 1: Column Hygiene Check")
        self._log("="*60)

        dropped = []
        # Work on a copy so the caller's frame is left untouched.
        df = df.copy()

        for col in df.columns:
            if col == target_column:
                continue

            col_lower = col.lower()

            # Check for target-related names (substring match, so short
            # patterns like 'class' can catch longer column names too).
            for pattern in self.config.TARGET_NAME_PATTERNS:
                if pattern in col_lower:
                    dropped.append(col)
                    self.report.add_dropped_feature(
                        col,
                        f"Column name contains target-related pattern: '{pattern}'"
                    )
                    self._log(f" [DROP] '{col}' - contains pattern '{pattern}'")
                    break

        if dropped:
            df = df.drop(columns=dropped, errors='ignore')
            self._log(f" Dropped {len(dropped)} columns with target-related names")
        else:
            self._log(" No target-related column names detected")

        return df, dropped

    # ============== Step 2: Feature-Target Leakage Detection ==============

    def _compute_correlation(
        self,
        feature: pd.Series,
        target: pd.Series
    ) -> float:
        """
        Compute correlation between a feature and binary/numeric target.

        Uses Point-Biserial for numeric features, Cramér's V for categorical.

        Rows with a missing value in either series are excluded first;
        fewer than 10 paired observations is treated as "no association"
        (returns 0.0), as is any computation failure.
        """
        # Handle missing values
        mask = ~(feature.isna() | target.isna())
        feature = feature[mask]
        target = target[mask]

        if len(feature) < 10:
            return 0.0

        try:
            # Check if feature is numeric
            if pd.api.types.is_numeric_dtype(feature):
                # Point-biserial or Pearson correlation
                # (pearsonr on a float-cast binary target IS the
                # point-biserial coefficient).
                corr, _ = stats.pearsonr(feature.astype(float), target.astype(float))
                return abs(corr) if not np.isnan(corr) else 0.0
            else:
                # Cramér's V for categorical
                return self._cramers_v(feature, target)
        except Exception:
            # Any numeric failure (constant series, cast error, ...)
            # is reported as "no correlation" rather than raised.
            return 0.0

    def _cramers_v(self, x: pd.Series, y: pd.Series) -> float:
        """Compute Cramér's V statistic for categorical-categorical association.

        Returns a value in [0, 1]; 0.0 on degenerate tables or any failure.
        """
        try:
            confusion_matrix = pd.crosstab(x, y)
            chi2 = stats.chi2_contingency(confusion_matrix)[0]
            n = confusion_matrix.sum().sum()
            min_dim = min(confusion_matrix.shape) - 1
            if min_dim == 0 or n == 0:
                return 0.0
            return np.sqrt(chi2 / (n * min_dim))
        except Exception:
            return 0.0

    def detect_leakage(
        self,
        X: pd.DataFrame,
        y: pd.Series
    ) -> Tuple[pd.DataFrame, List[str]]:
        """
        Detect and remove features with high correlation to target.

        Features at or above ``LEAKAGE_DROP_THRESHOLD`` are dropped;
        those at or above ``LEAKAGE_WARN_THRESHOLD`` only raise a warning.
        Both kinds are recorded in ``report.leakage_report``.

        Args:
            X: Feature dataframe
            y: Target series

        Returns:
            Tuple of (cleaned X, list of dropped features)
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 2: Feature-Target Leakage Detection")
        self._log("="*60)

        dropped = []
        # NOTE(review): 'correlations' collects every feature's score but is
        # never read after the loop — kept for parity / future reporting.
        correlations = {}

        for col in X.columns:
            corr = self._compute_correlation(X[col], y)
            correlations[col] = corr

            if corr >= self.config.LEAKAGE_DROP_THRESHOLD:
                dropped.append(col)
                self.report.add_dropped_feature(
                    col,
                    f"High correlation with target: {corr:.4f} (threshold: {self.config.LEAKAGE_DROP_THRESHOLD})"
                )
                self.report.leakage_report[col] = corr
                self._log(f" [DROP] '{col}' - correlation: {corr:.4f} >= {self.config.LEAKAGE_DROP_THRESHOLD}")

            elif corr >= self.config.LEAKAGE_WARN_THRESHOLD:
                self.report.add_warning(
                    f"Feature '{col}' has high correlation with target: {corr:.4f}"
                )
                self.report.leakage_report[col] = corr
                self._log(f" [WARN] '{col}' - correlation: {corr:.4f} (high-risk)")

        if dropped:
            X = X.drop(columns=dropped, errors='ignore')
            self._log(f"\n Dropped {len(dropped)} leaky features")
        else:
            self._log(" No leakage detected")

        return X, dropped

    # ============== Step 3: Feature-Feature Redundancy ==============

    def remove_redundant_features(
        self,
        X: pd.DataFrame,
        y: pd.Series
    ) -> Tuple[pd.DataFrame, List[str]]:
        """
        Remove highly correlated feature pairs, keeping the more informative one.

        Only numeric columns are considered. For each pair at or above
        ``REDUNDANCY_THRESHOLD`` the column with more missing values is
        dropped; ties are broken by mutual information with the target
        (keep the higher MI).

        Args:
            X: Feature dataframe
            y: Target series

        Returns:
            Tuple of (cleaned X, list of dropped features)
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 3: Feature-Feature Redundancy Check")
        self._log("="*60)

        dropped = []

        # Only check numeric columns for correlation
        numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()

        if len(numeric_cols) < 2:
            self._log(" Not enough numeric features to check redundancy")
            return X, dropped

        # Compute correlation matrix
        try:
            corr_matrix = X[numeric_cols].corr().abs()
        except Exception as e:
            # Redundancy removal is best-effort: on failure, keep all columns.
            self._log(f" Could not compute correlation matrix: {e}")
            return X, dropped

        # Find highly correlated pairs.
        # NOTE(review): 'pairs_checked' is redundant — the inner loop starts
        # at i+1 so no (col1, col2) pair is ever visited twice. Also note
        # that columns already queued in 'dropped' still take part in later
        # comparisons, which can cause extra drops in correlated groups.
        pairs_checked = set()

        for i, col1 in enumerate(numeric_cols):
            for col2 in numeric_cols[i+1:]:
                if (col1, col2) in pairs_checked or (col2, col1) in pairs_checked:
                    continue

                pairs_checked.add((col1, col2))

                corr = corr_matrix.loc[col1, col2]

                if corr >= self.config.REDUNDANCY_THRESHOLD:
                    # Decide which to drop based on:
                    # 1. Lower missingness
                    # 2. Higher mutual information with target

                    miss1 = X[col1].isna().sum()
                    miss2 = X[col2].isna().sum()

                    # Keep the one with less missing
                    if miss1 != miss2:
                        to_drop = col1 if miss1 > miss2 else col2
                    else:
                        # Use mutual information as tiebreaker
                        try:
                            mi1 = mutual_info_classif(
                                X[[col1]].fillna(0), y, random_state=42
                            )[0]
                            mi2 = mutual_info_classif(
                                X[[col2]].fillna(0), y, random_state=42
                            )[0]
                            to_drop = col1 if mi1 < mi2 else col2
                        except Exception:
                            to_drop = col2  # Default to dropping second

                    if to_drop not in dropped:
                        dropped.append(to_drop)
                        self.report.add_dropped_feature(
                            to_drop,
                            f"Redundant with '{col1 if to_drop == col2 else col2}' (corr: {corr:.4f})"
                        )
                        self.report.redundancy_report.append((col1, col2, corr))
                        self._log(f" [DROP] '{to_drop}' - redundant with '{col1 if to_drop == col2 else col2}' (corr: {corr:.4f})")

        if dropped:
            X = X.drop(columns=dropped, errors='ignore')
            self._log(f"\n Dropped {len(dropped)} redundant features")
        else:
            self._log(" No redundant feature pairs detected")

        return X, dropped

    # ============== Step 4: Feature Dominance Validation ==============

    def validate_feature_dominance(
        self,
        feature_importances: Dict[str, float]
    ) -> bool:
        """
        Check if any single feature dominates the model.

        Importances are normalized to sum to 1 before comparison.
        NOTE(review): normalization assumes non-negative importances
        (true for tree impurity importances); signed values (e.g. raw
        linear coefficients) would make the shares misleading — confirm
        what callers pass in.

        Args:
            feature_importances: Dictionary of feature -> importance

        Returns:
            True if model passes validation, False otherwise
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 4: Feature Dominance Validation")
        self._log("="*60)

        if not feature_importances:
            # No information to judge on — treat as passing.
            self._log(" No feature importances provided")
            return True

        # Normalize importances
        total = sum(feature_importances.values())
        if total == 0:
            return True

        normalized = {k: v/total for k, v in feature_importances.items()}
        self.report.dominance_report = normalized

        # Find max importance
        max_feature = max(normalized, key=normalized.get)
        max_importance = normalized[max_feature]

        self._log(f" Top feature: '{max_feature}' with {max_importance:.2%} importance")

        if max_importance >= self.config.DOMINANCE_INVALID_THRESHOLD:
            self.report.add_error(
                f"Feature dominance detected: '{max_feature}' has {max_importance:.2%} importance. "
                f"Model likely learned the target indirectly. Threshold: {self.config.DOMINANCE_INVALID_THRESHOLD:.0%}"
            )
            self._log(f" [FAIL] Feature dominance violation - model INVALID")
            return False

        elif max_importance >= self.config.DOMINANCE_WARN_THRESHOLD:
            self.report.add_warning(
                f"High feature importance: '{max_feature}' explains {max_importance:.2%} of predictions. "
                f"Consider investigating this feature."
            )
            self._log(f" [WARN] High feature importance detected")

        else:
            self._log(" Feature importances are well-distributed")

        return True

    # ============== Step 5: Overfitting Guard ==============

    def validate_overfitting(
        self,
        cv_score: float,
        test_score: float,
        metric_name: str = "AUC"
    ) -> bool:
        """
        Check for overfitting by comparing CV and test performance.

        Only a positive gap (CV better than test) can trip the thresholds;
        a test score above CV always passes.

        Args:
            cv_score: Cross-validation score
            test_score: Held-out test score
            metric_name: Name of the metric being compared

        Returns:
            True if model passes validation, False otherwise
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 5: Overfitting Guard")
        self._log("="*60)

        gap = cv_score - test_score

        self.report.overfitting_report = {
            'cv_score': cv_score,
            'test_score': test_score,
            'gap': gap,
            'metric': metric_name
        }

        self._log(f" CV {metric_name}: {cv_score:.4f}")
        self._log(f" Test {metric_name}: {test_score:.4f}")
        self._log(f" Gap: {gap:.4f}")

        if gap >= self.config.OVERFIT_INVALID_GAP:
            self.report.add_error(
                f"Overfitting detected: CV-Test gap is {gap:.4f} "
                f"(threshold: {self.config.OVERFIT_INVALID_GAP}). "
                f"Model performance will not generalize."
            )
            self._log(f" [FAIL] Overfitting violation - model INVALID")
            return False

        elif gap >= self.config.OVERFIT_WARN_GAP:
            self.report.add_warning(
                f"Potential overfitting: CV-Test gap is {gap:.4f}. "
                f"Monitor model performance closely in production."
            )
            self._log(f" [WARN] Potential overfitting detected")

        else:
            self._log(" No significant overfitting detected")

        return True

    # ============== Step 6: Model Acceptance Criteria ==============

    def validate_model_acceptance(
        self,
        test_auc: float,
        feature_importances: Dict[str, float],
        cv_score: float
    ) -> bool:
        """
        Final validation to determine if model can be exported.

        Combines the minimum-AUC floor with the dominance (Step 4) and
        overfitting (Step 5) checks; all three must pass. On failure,
        ``report.model_valid`` is set to False.

        Args:
            test_auc: Test set AUC score
            feature_importances: Feature importance dictionary
            cv_score: Cross-validation score

        Returns:
            True if model passes all criteria, False otherwise
        """
        self._log("\n" + "="*60)
        self._log("[SAFETY] Step 6: Model Acceptance Criteria")
        self._log("="*60)

        passed = True

        # Check minimum performance
        if test_auc < self.config.MIN_TEST_AUC:
            self.report.add_error(
                f"Test AUC ({test_auc:.4f}) is below minimum threshold "
                f"({self.config.MIN_TEST_AUC}). Model does not beat baseline."
            )
            passed = False
            self._log(f" [FAIL] Test AUC below minimum threshold")

        # Check feature dominance
        dominance_ok = self.validate_feature_dominance(feature_importances)
        if not dominance_ok:
            passed = False

        # Check overfitting
        overfit_ok = self.validate_overfitting(cv_score, test_auc)
        if not overfit_ok:
            passed = False

        # Final status — also fails if an earlier check already set FAIL.
        if passed and self.report.status != "FAIL":
            self._log("\n [PASS] Model meets all acceptance criteria")
        else:
            self._log("\n [FAIL] Model does NOT meet acceptance criteria")
            self.report.model_valid = False

        return passed

    # ============== Main Validation Pipeline ==============

    def run_pre_training_checks(
        self,
        df: pd.DataFrame,
        target_column: str
    ) -> Tuple[pd.DataFrame, SafetyReport]:
        """
        Run all pre-training safety checks.

        Applies, in order: column hygiene (Step 1), feature-target
        leakage removal (Step 2), and redundancy removal (Step 3), then
        reassembles features and target into one frame.

        Args:
            df: Input dataframe with features and target
            target_column: Name of target column

        Returns:
            Tuple of (cleaned dataframe, safety report)
        """
        self._log("\n" + "#"*60)
        self._log("# SAFETY VALIDATION - PRE-TRAINING CHECKS")
        self._log("#"*60)

        # Step 1: Column hygiene
        df, _ = self.check_column_hygiene(df, target_column)

        # Extract X and y
        X = df.drop(columns=[target_column])
        y = df[target_column]

        # Step 2: Leakage detection
        X, _ = self.detect_leakage(X, y)

        # Step 3: Redundancy removal
        X, _ = self.remove_redundant_features(X, y)

        # Reconstruct dataframe (target column ends up last)
        df_clean = pd.concat([X, y], axis=1)

        self._log("\n" + "#"*60)
        self._log(f"# PRE-TRAINING CHECKS COMPLETE")
        self._log(f"# Status: {self.report.status}")
        self._log(f"# Features dropped: {len(self.report.dropped_features)}")
        self._log(f"# Warnings: {len(self.report.warnings)}")
        self._log("#"*60 + "\n")

        return df_clean, self.report

    def run_post_training_checks(
        self,
        feature_importances: Dict[str, float],
        cv_score: float,
        test_auc: float
    ) -> SafetyReport:
        """
        Run all post-training safety checks.

        Thin wrapper over ``validate_model_acceptance`` that adds the
        banner logging; results land in the shared ``self.report``.

        Args:
            feature_importances: Feature importance dictionary
            cv_score: Cross-validation score
            test_auc: Test set AUC

        Returns:
            Updated safety report
        """
        self._log("\n" + "#"*60)
        self._log("# SAFETY VALIDATION - POST-TRAINING CHECKS")
        self._log("#"*60)

        # Run acceptance validation
        self.validate_model_acceptance(test_auc, feature_importances, cv_score)

        self._log("\n" + "#"*60)
        self._log(f"# POST-TRAINING CHECKS COMPLETE")
        self._log(f"# Final Status: {self.report.status}")
        self._log(f"# Model Valid: {self.report.model_valid}")
        self._log("#"*60 + "\n")

        return self.report
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
# ============== Utility Functions ==============
|
| 608 |
+
|
| 609 |
+
def check_perfect_score_warning(metrics: Dict[str, float]) -> List[str]:
    """
    Check for suspiciously perfect scores that may indicate leakage.

    Scores >= 0.999 produce a CRITICAL message (almost certainly data
    leakage); scores >= 0.98 produce a WARNING message.

    Args:
        metrics: Dictionary of metric names to scores

    Returns:
        List of warning messages (empty when all scores look plausible)
    """
    # Named 'messages' (not 'warnings') so the local list does not shadow
    # the stdlib 'warnings' module imported at the top of this file.
    messages = []

    for metric, score in metrics.items():
        if score >= 0.999:
            messages.append(
                f"CRITICAL: {metric} = {score:.4f} is suspiciously perfect. "
                f"This almost certainly indicates data leakage. "
                f"Do NOT trust this model."
            )
        elif score >= 0.98:
            messages.append(
                f"WARNING: {metric} = {score:.4f} is very high. "
                f"Verify there is no data leakage before using this model."
            )

    return messages
|
credily/utils.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for Credily.
|
| 3 |
+
File loading, format detection, and export utilities.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_data(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:
    """
    Load a tabular dataset, choosing a reader from the file extension.

    Args:
        file_path: Path to the data file
        **kwargs: Extra keyword arguments forwarded to the pandas reader

    Returns:
        DataFrame with loaded data

    Raises:
        ImportError: for Excel files when 'openpyxl' is not installed
        ValueError: for an unrecognized file extension

    Supported formats:
        - .csv: Comma-separated values
        - .txt: delimiter auto-detected from the header line
        - .tsv: Tab-separated values
        - .xlsx, .xls: Excel files
    """
    source = Path(file_path)
    ext = source.suffix.lower()

    if ext == '.csv':
        return pd.read_csv(source, **kwargs)

    if ext == '.tsv':
        return pd.read_csv(source, delimiter='\t', **kwargs)

    if ext == '.txt':
        # Sniff the separator from the first line: tab beats ';',
        # which beats '|'; otherwise fall back to a comma.
        with open(source, 'r', encoding='utf-8', errors='ignore') as handle:
            header_line = handle.readline()

        if '\t' in header_line:
            sep = '\t'
        elif ';' in header_line:
            sep = ';'
        elif '|' in header_line:
            sep = '|'
        else:
            sep = ','

        return pd.read_csv(source, delimiter=sep, **kwargs)

    if ext in ('.xlsx', '.xls'):
        # openpyxl is an optional dependency; surface a clear install hint.
        try:
            if ext == '.xlsx':
                return pd.read_excel(source, engine='openpyxl', **kwargs)
            return pd.read_excel(source, **kwargs)
        except ImportError:
            raise ImportError(
                "Excel support requires 'openpyxl' package. "
                "Install it with: pip install openpyxl"
            )

    raise ValueError(
        f"Unsupported file format: '{ext}'. "
        f"Supported formats: .csv, .txt, .tsv, .xlsx, .xls"
    )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def save_to_excel(
    df: pd.DataFrame,
    file_path: Union[str, Path],
    sheet_name: str = 'Predictions',
    include_summary: bool = True
) -> str:
    """
    Save DataFrame to Excel with styled headers and an optional summary sheet.

    Args:
        df: DataFrame to save
        file_path: Output path (extension is forced to .xlsx if it is not
            already .xlsx/.xls)
        sheet_name: Name of the main data sheet
        include_summary: Whether to add a 'Summary' sheet (only written when
            a 'prediction' column is present in df)

    Returns:
        Path to the saved file as a string.

    Raises:
        ImportError: If the 'openpyxl' package is not installed.
    """
    try:
        from openpyxl import Workbook
        from openpyxl.utils.dataframe import dataframe_to_rows
        from openpyxl.styles import Font, PatternFill, Alignment, Border, Side
    except ImportError:
        # 'from None' keeps the user-facing install hint, without the noisy
        # chained traceback of the original ImportError.
        raise ImportError(
            "Excel export requires 'openpyxl' package. "
            "Install it with: pip install openpyxl"
        ) from None

    path = Path(file_path)
    if path.suffix.lower() not in ['.xlsx', '.xls']:
        path = path.with_suffix('.xlsx')

    wb = Workbook()
    ws = wb.active
    ws.title = sheet_name

    # Style definitions for the header row and cell borders.
    header_font = Font(bold=True, color='FFFFFF')
    header_fill = PatternFill(start_color='667eea', end_color='667eea', fill_type='solid')
    thin_border = Border(
        left=Side(style='thin'),
        right=Side(style='thin'),
        top=Side(style='thin'),
        bottom=Side(style='thin')
    )

    # Write data; row 1 is the header and gets the accent styling.
    for r_idx, row in enumerate(dataframe_to_rows(df, index=False, header=True), 1):
        for c_idx, value in enumerate(row, 1):
            cell = ws.cell(row=r_idx, column=c_idx, value=value)
            cell.border = thin_border

            if r_idx == 1:  # Header row
                cell.font = header_font
                cell.fill = header_fill
                cell.alignment = Alignment(horizontal='center')

    # Auto-adjust column widths, capped at 50 characters.
    # Bug fix: the previous bare `except: pass` silently swallowed every
    # exception (including SystemExit/KeyboardInterrupt) around an expression
    # that cannot reasonably fail; compute the width directly instead.
    # Note: str(None) == 'None' (width 4), matching the original behavior
    # for empty cells.
    for column in ws.columns:
        column_letter = column[0].column_letter
        max_length = max(
            (len(str(cell.value)) for cell in column),
            default=0,
        )
        ws.column_dimensions[column_letter].width = min(max_length + 2, 50)

    # Add summary sheet if predictions exist.
    if include_summary and 'prediction' in df.columns:
        summary_ws = wb.create_sheet('Summary')

        # Prediction distribution (class label -> row count).
        pred_counts = df['prediction'].value_counts().sort_index()

        summary_data = [
            ['Prediction Summary Report'],
            [''],
            ['Total Records', len(df)],
            [''],
            ['Prediction Distribution'],
        ]

        for pred_val, count in pred_counts.items():
            pct = count / len(df) * 100
            summary_data.append([f'Class {pred_val}', count, f'{pct:.1f}%'])

        # Add probability stats if the pipeline emitted class-1 probabilities.
        if 'proba_1' in df.columns:
            summary_data.extend([
                [''],
                ['Probability Statistics (Class 1)'],
                ['Mean', f"{df['proba_1'].mean():.4f}"],
                ['Median', f"{df['proba_1'].median():.4f}"],
                ['Min', f"{df['proba_1'].min():.4f}"],
                ['Max', f"{df['proba_1'].max():.4f}"],
            ])

        # The threshold column is constant per run, so the first row suffices.
        if 'threshold_used' in df.columns:
            summary_data.extend([
                [''],
                ['Threshold Used', df['threshold_used'].iloc[0]],
            ])

        for r_idx, row in enumerate(summary_data, 1):
            for c_idx, value in enumerate(row, 1):
                cell = summary_ws.cell(row=r_idx, column=c_idx, value=value)
                if r_idx == 1:
                    # Report title.
                    cell.font = Font(bold=True, size=14)
                elif value in ['Prediction Distribution', 'Probability Statistics (Class 1)']:
                    # Section headings.
                    cell.font = Font(bold=True)

        # Fixed widths for the three summary columns.
        summary_ws.column_dimensions['A'].width = 30
        summary_ws.column_dimensions['B'].width = 15
        summary_ws.column_dimensions['C'].width = 15

    wb.save(path)
    return str(path)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_supported_formats() -> str:
    """Describe the file formats the data loader understands."""
    formats = ["CSV (.csv)", "Text (.txt, .tsv)", "Excel (.xlsx, .xls)"]
    return "Supported formats: " + ", ".join(formats)
|
debug_output/model.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1df29e3b079f840118439eec948ecbafe781fc11d53971fbfe288fef7d919aab
|
| 3 |
+
size 2274753
|