github-actions[bot] commited on
Commit ·
26df127
0
Parent(s):
Deploy from GitHub Actions
Browse files- Dockerfile +27 -0
- LICENSE +21 -0
- README.md +140 -0
- backend/__init__.py +1 -0
- backend/app.py +349 -0
- backend/schemas.py +104 -0
- frontend/__init__.py +1 -0
- frontend/app.py +269 -0
- requirements.txt +17 -0
- src/__init__.py +1 -0
- src/custom_transformers.py +596 -0
- src/global_constants.py +502 -0
- src/pipeline.py +122 -0
- src/utils.py +20 -0
- start.sh +26 -0
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# syntax=docker/dockerfile:1
# Use Python 3.10 slim image based on Linux Debian Bookworm distribution (suitable for ML workloads)
FROM python:3.10-slim-bookworm

# Don't write .pyc files; force unbuffered stdout/stderr so log lines reach the
# container log stream (and the Hugging Face Space Logs tab) immediately
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1

# Set the working directory in the container
WORKDIR /app

# Install curl (to download GeoLite2), tar (to unzip it) & dos2unix (to fix Windows line endings)
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
        dos2unix \
        tar \
    && rm -rf /var/lib/apt/lists/*

# Copy the requirements file first (its own layer) so the dependency install
# is cached until requirements.txt itself changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application source code
COPY ./src ./src
COPY ./backend ./backend
COPY ./frontend ./frontend
COPY ./start.sh .

# Fix start script line endings and make it executable
RUN dos2unix ./start.sh && chmod +x ./start.sh

# Run as a non-root user (uid 1000, as recommended for Hugging Face Spaces).
# The app writes logs/ and geoip_db/ under /app and caches the pipeline under
# $HOME/.cache, so /app and the home directory must be writable by this user.
RUN useradd --create-home --uid 1000 appuser && chown -R appuser:appuser /app
USER appuser
ENV HOME=/home/appuser

# Expose Gradio port (documentation only; HF Spaces reads app_port from README metadata)
EXPOSE 7860

# Command to run the application (exec form so stop signals reach the process)
CMD ["./start.sh"]
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Jens Bender
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Loan Default Prediction
|
| 3 |
+
subtitle: Submit customer application data to predict loan default
|
| 4 |
+
emoji: 💰
|
| 5 |
+
colorFrom: indigo
|
| 6 |
+
colorTo: green
|
| 7 |
+
sdk: docker
|
| 8 |
+
app_port: 7860
|
| 9 |
+
pinned: false
|
| 10 |
+
models:
|
| 11 |
+
- JensBender/loan-default-prediction-pipeline
|
| 12 |
+
tags:
|
| 13 |
+
- finance
|
| 14 |
+
- credit-risk
|
| 15 |
+
- loan-default
|
| 16 |
+
- tabular-data
|
| 17 |
+
- scikit-learn
|
| 18 |
+
- random-forest
|
| 19 |
+
- gradio
|
| 20 |
+
- fastapi
|
| 21 |
+
- docker
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## 🏦 Loan Default Prediction App
|
| 25 |
+
A web application that predicts loan default based on customer application data, helping financial institutions make data-driven lending decisions.
|
| 26 |
+
Built with `Gradio`, `FastAPI`, and a `scikit-learn` Random Forest model trained on over 250,000 loan applications.
|
| 27 |
+
|
| 28 |
+
### How to Use
|
| 29 |
+
1. **Fill in Form**: Enter applicant details such as age, income, and experience.
|
| 30 |
+
2. **Click Predict**: The app will process your input and return a "Default" or "No Default" prediction along with probabilities.
|
| 31 |
+
3. **Interpret Responsibly**: Use the prediction to support decision making, do **not** use for fully automated decisions without human oversight.
|
| 32 |
+
|
| 33 |
+
### Use via API
|
| 34 |
+
You can also send requests directly to the FastAPI backend for programmatic access. This is useful for integrating the model into other applications or systems.
|
| 35 |
+
|
| 36 |
+
Example API usage with Python's `requests` library:
|
| 37 |
+
```python
|
| 38 |
+
import requests
|
| 39 |
+
|
| 40 |
+
# Create example applicant data (JSON payload)
|
| 41 |
+
applicant_data = {
|
| 42 |
+
"income": 300000,
|
| 43 |
+
"age": 30,
|
| 44 |
+
"experience": 3,
|
| 45 |
+
"married": "single",
|
| 46 |
+
"house_ownership": "rented",
|
| 47 |
+
"car_ownership": "no",
|
| 48 |
+
"profession": "artist",
|
| 49 |
+
"city": "sikar",
|
| 50 |
+
"state": "rajasthan",
|
| 51 |
+
"current_job_yrs": 3,
|
| 52 |
+
"current_house_yrs": 11,
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
# API request to FastAPI predict endpoint on Hugging Face Spaces
|
| 56 |
+
prediction_api_url = "https://jensbender-loan-default-prediction-app.hf.space/api/predict"
|
| 57 |
+
response = requests.post(prediction_api_url, json=applicant_data)
|
| 58 |
+
|
| 59 |
+
# Check if request was successful
|
| 60 |
+
response.raise_for_status()
|
| 61 |
+
|
| 62 |
+
# Extract prediction and probability of default
|
| 63 |
+
prediction_response = response.json()
|
| 64 |
+
prediction_result = prediction_response["results"][0]
|
| 65 |
+
prediction = prediction_result["prediction"]
|
| 66 |
+
default_probability = prediction_result["probabilities"]["Default"]
|
| 67 |
+
|
| 68 |
+
# Show results
|
| 69 |
+
print(f"Probability of default: {default_probability * 100:.1f}% (threshold: 29.0%)")
|
| 70 |
+
print(f"Prediction: {prediction}")
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### How It Works
|
| 74 |
+
1. **Gradio Frontend (UI Layer)**
|
| 75 |
+
- Provides a clean and simple form for data entry.
|
| 76 |
+
- Sends form data as JSON to the backend API.
|
| 77 |
+
- Displays prediction results and probabilities in real time.
|
| 78 |
+
2. **FastAPI Backend (API Layer)**
|
| 79 |
+
- Receives requests from the `Gradio` frontend or direct API requests.
|
| 80 |
+
- Loads the pre-trained pipeline from the [Hugging Face Hub](https://huggingface.co/JensBender/loan-default-prediction-pipeline).
|
| 81 |
+
- Validates and passes data through the pipeline, and applies the decision threshold.
|
| 82 |
+
- Returns JSON responses containing predictions and probabilities.
|
| 83 |
+
3. **ML Pipeline (Model Layer)**
|
| 84 |
+
- Implements a `scikit-learn` pipeline with a Random Forest Classifier model and preprocessing.
|
| 85 |
+
- Performs feature engineering, scaling, and encoding.
|
| 86 |
+
- Outputs predicted probabilities for both classes ("Default" and "No Default").
|
| 87 |
+
4. **Deployment Environment**
|
| 88 |
+
- Packaged as a single `Docker` container.
|
| 89 |
+
- Runs seamlessly on Hugging Face Spaces using the Docker SDK.
|
| 90 |
+
|
| 91 |
+
### Model Performance
|
| 92 |
+
The Random Forest model achieved an **AUC-PR of 0.59** on the test set. The most influential features are income, age, and state default rate (derived via feature engineering).
|
| 93 |
+
|
| 94 |
+
**Classification Report (Test)**
|
| 95 |
+
| | Precision | Recall | F1-Score | Samples |
|
| 96 |
+
|:-----------------------|:----------|:-------|:---------|:--------|
|
| 97 |
+
| Class 0: Non-Defaulter | 0.97 | 0.90 | 0.93 | 22,122 |
|
| 98 |
+
| Class 1: Defaulter | 0.51 | 0.79 | 0.62 | 3,078 |
|
| 99 |
+
| Accuracy | | | 0.88 | 25,200 |
|
| 100 |
+
| Macro Avg | 0.74 | 0.84 | 0.78 | 25,200 |
|
| 101 |
+
| Weighted Avg | 0.91 | 0.88 | 0.89 | 25,200 |
|
| 102 |
+
|
| 103 |
+
### Resources
|
| 104 |
+
| Component | Description | Link |
|
| 105 |
+
|------------|--------------|------|
|
| 106 |
+
| **Source Code** | Full project repository with training, evaluation, and deployment scripts | [GitHub](https://github.com/JensBender/loan-default-prediction) |
|
| 107 |
+
| **Model Pipeline** | Pre-trained `scikit-learn` pipeline with Random Forest Classifier and preprocessing | [Hugging Face Hub](https://huggingface.co/JensBender/loan-default-prediction-pipeline) |
|
| 108 |
+
| **Web App** | Live, interactive demo with Gradio frontend and FastAPI backend | [Hugging Face Spaces](https://huggingface.co/spaces/JensBender/loan-default-prediction-app) |
|
| 109 |
+
|
| 110 |
+
### Responsible Use
|
| 111 |
+
The model and by extension this web app and API are intended to be used as a tool to support credit risk assessment. They can be integrated into decision-making workflows to provide a quantitative measure of default risk for loan applicants.
|
| 112 |
+
|
| 113 |
+
This model is **not** intended for:
|
| 114 |
+
- Fully automated lending decisions without human oversight. The model's predictions should not be the sole factor in any financial decision.
|
| 115 |
+
- Evaluating applicants from demographic, geographic, or socioeconomic backgrounds not represented in the training data.
|
| 116 |
+
- Use in a production environment without rigorous, ongoing validation and fairness audits.
|
| 117 |
+
|
| 118 |
+
### Bias, Risks, and Limitations
|
| 119 |
+
The model was trained on historical data that may carry biases related to socioeconomic status, geography, or other demographic factors, potentially leading to unfair predictions for certain groups. The model can be overconfident on misclassified edge cases, assigning high probabilities to incorrect predictions. Confidence scores should not be relied upon without additional scrutiny.
|
| 120 |
+
|
| 121 |
+
**Recommendations**
|
| 122 |
+
- **Human in the Loop:** Always use this model as part of a broader decision-making framework that includes human oversight.
|
| 123 |
+
- **Fairness and Bias Audits:** Before deploying this model in a production environment, conduct thorough fairness and bias analyses to ensure it performs equally across different demographic groups.
|
| 124 |
+
- **Model Monitoring:** Continuously monitor the model's performance and predictions to detect and mitigate any performance degradation or emerging biases.
|
| 125 |
+
|
| 126 |
+
### License
|
| 127 |
+
The source code for this web app on Hugging Face Spaces and the source code of the overall project on [GitHub](https://github.com/JensBender/loan-default-prediction) is licensed under the [MIT License](LICENSE). The model pipeline is licensed under [Apache-2.0](https://huggingface.co/JensBender/loan-default-prediction-pipeline/resolve/main/LICENSE).
|
| 128 |
+
|
| 129 |
+
### Citation
|
| 130 |
+
If you use this model or app in your work, please cite it as follows:
|
| 131 |
+
```bibtex
|
| 132 |
+
@misc{bender_loan_default_prediction_2025,
|
| 133 |
+
author = {Bender, Jens},
|
| 134 |
+
title = {Loan Default Prediction Pipeline},
|
| 135 |
+
year = {2025},
|
| 136 |
+
publisher = {Hugging Face},
|
| 137 |
+
url = {https://huggingface.co/JensBender/loan-default-prediction-pipeline},
|
| 138 |
+
note = {Version 1.0. A scikit-learn Random Forest pipeline for predicting loan defaults. Trained on 252,000 loan applications. Source code available at \url{https://github.com/JensBender/loan-default-prediction}. Licensed under Apache-2.0.}
|
| 139 |
+
}
|
| 140 |
+
```
|
backend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file marks the directory as a Python package.
|
backend/app.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Imports ---
|
| 2 |
+
# Standard library imports
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
import logging
|
| 6 |
+
import logging.config
|
| 7 |
+
import json
|
| 8 |
+
import uuid
|
| 9 |
+
import time
|
| 10 |
+
from datetime import datetime, timezone
|
| 11 |
+
|
| 12 |
+
# Third-party library imports
|
| 13 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 14 |
+
from fastapi.responses import RedirectResponse
|
| 15 |
+
from sklearn.pipeline import Pipeline
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import numpy as np
|
| 19 |
+
import joblib
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
import geoip2.database
|
| 22 |
+
from geoip2.errors import AddressNotFoundError
|
| 23 |
+
|
| 24 |
+
# Local imports
|
| 25 |
+
from backend.schemas import (
|
| 26 |
+
PipelineInput,
|
| 27 |
+
PredictionEnum,
|
| 28 |
+
PredictedProbabilities,
|
| 29 |
+
PredictionResult,
|
| 30 |
+
PredictionResponse
|
| 31 |
+
)
|
| 32 |
+
from frontend.app import gradio_app
|
| 33 |
+
from src.custom_transformers import (
|
| 34 |
+
MissingValueChecker,
|
| 35 |
+
MissingValueStandardizer,
|
| 36 |
+
RobustSimpleImputer,
|
| 37 |
+
SnakeCaseFormatter,
|
| 38 |
+
BooleanColumnTransformer,
|
| 39 |
+
JobStabilityTransformer,
|
| 40 |
+
CityTierTransformer,
|
| 41 |
+
StateDefaultRateTargetEncoder,
|
| 42 |
+
RobustStandardScaler,
|
| 43 |
+
RobustOneHotEncoder,
|
| 44 |
+
RobustOrdinalEncoder,
|
| 45 |
+
FeatureSelector
|
| 46 |
+
)
|
| 47 |
+
from src.utils import get_root_directory
|
| 48 |
+
|
| 49 |
+
# --- Logging ---
# Create logs directory if it doesn't exist (RotatingFileHandler needs the parent dir to exist)
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)

# Define logging configuration
# Two independent streams: human-readable general logs to stdout, and
# machine-readable prediction records (one JSON object per line) to a rotating file.
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
        # Monitoring records are pre-serialized JSON strings, so log only the bare message
        "monitoring": {
            "format": "%(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stdout",  # write to Python standard output stream, which goes to Docker container's standard output, which goes to Hugging Face host server, which goes to Hugging Face Space Logs tab
        },
        "monitoring_file": {
            "class": "logging.handlers.RotatingFileHandler",
            "formatter": "monitoring",
            "filename": str(log_dir / "prediction_logs.jsonl"),  # JSON Lines format: one JSON object per line
            "maxBytes": 10485760,  # 10 MB
            "backupCount": 3,  # will create prediction_logs.jsonl.1, .2, and .3 for max 4 log files (40 MB), then overwrite
        },
    },
    "loggers": {
        "": {  # root logger for general logs
            "handlers": ["console"],
            "level": "INFO",
        },
        "monitoring": {  # prediction records logger for model monitoring
            "handlers": ["monitoring_file"],
            "level": "INFO",
            "propagate": False,  # keep prediction records out of the console stream
        },
    },
}

# Apply logging configuration
logging.config.dictConfig(LOGGING_CONFIG)

# Get loggers
logger = logging.getLogger(__name__)  # general application logger (console)
monitoring_logger = logging.getLogger("monitoring")  # JSONL prediction-record logger (file)
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# --- Helper Functions ---
# Function to get batch-level metadata for logging
def get_batch_metadata(
    pipeline_input_dict_ls: list[dict[str, Any]],
    request: Request,
    geoip_reader: geoip2.database.Reader | None
) -> dict[str, Any]:
    """Collect batch-level metadata for the monitoring log records.

    Args:
        pipeline_input_dict_ls: Non-empty list of pipeline input dicts; the first entry
            may carry "user_agent" / "client_ip" forwarded by the Gradio frontend.
        request: Incoming FastAPI request (fallback source for user agent and client IP).
        geoip_reader: Optional GeoLite2 reader used to resolve the client country.

    Returns:
        Dict with batch_id, batch_size, batch_timestamp, pipeline_version,
        client_country, and user_agent.
    """
    # Prefer the user agent forwarded by the frontend (use first input in list of inputs)
    user_agent = pipeline_input_dict_ls[0].get("user_agent", None)
    if user_agent is None:  # fall back for direct API request to backend
        user_agent = request.headers.get("user-agent", "unknown")  # get from request headers of FastAPI backend

    # Prefer the client IP forwarded by the frontend, else derive it from the request
    client_ip = pipeline_input_dict_ls[0].get("client_ip", None)
    if client_ip is None:
        x_forwarded_for = request.headers.get("x-forwarded-for")  # single str with one or more comma-separated IP addresses
        if x_forwarded_for:
            client_ip = x_forwarded_for.split(",")[0].strip()  # first IP is client IP address
        else:
            client_ip = request.client.host if request.client else None  # request.client can be None (e.g., in tests)

    # Get client country from IP address (best effort: a failed lookup must never fail the request)
    client_country = "unknown"
    if geoip_reader and client_ip:
        try:
            response = geoip_reader.country(client_ip)
            client_country = response.country.name or "unknown"  # country.name can be None
        except AddressNotFoundError:  # this occurs for unknown or private or reserved IPs (e.g., 127.0.0.1)
            logger.debug("IP address not found in GeoLite2 country database. Likely a private or local address.")
        except ValueError:  # malformed IP string (x-forwarded-for is client-controlled)
            logger.debug("Malformed client IP address. Skipping GeoLite2 country lookup.")

    # Create dictionary with batch-level metadata
    metadata = {
        "batch_id": str(uuid.uuid4()),
        "batch_size": len(pipeline_input_dict_ls),
        "batch_timestamp": datetime.now(timezone.utc).isoformat(),
        "pipeline_version": PIPELINE_VERSION_TAG,
        "client_country": client_country,
        "user_agent": user_agent,
    }

    return metadata
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# Function to load a scikit-learn pipeline from the local machine
def load_pipeline_from_local(path: str | Path) -> Pipeline:
    """Load a scikit-learn Pipeline from a local file and validate it.

    Accepts a string or Path, raises TypeError for any other input type,
    FileNotFoundError if the file is missing, RuntimeError if deserialization
    fails, and TypeError if the loaded object is not a usable Pipeline.
    """
    # Reject anything that is not a string or Path up front
    if not isinstance(path, (str, Path)):
        raise TypeError(f"Error when loading pipeline: 'path' must be a string or Path object, got {type(path).__name__}")

    # Keep the caller's original string for messages/loading; use a Path object for checks
    path_str = path if isinstance(path, str) else str(path)
    path = Path(path)

    # Fail fast with a clear message if the file is missing
    if not path.exists():
        raise FileNotFoundError(f"Error when loading pipeline: File not found at '{path_str}'")

    # Deserialize the pipeline, wrapping any failure in a RuntimeError
    try:
        logger.info(f"Loading pipeline from '{path_str}'...")
        loaded_object = joblib.load(path_str)
        logger.info("Successfully loaded pipeline.")
    except Exception as e:
        raise RuntimeError(f"Error when loading pipeline from '{path_str}'") from e

    # Validate the deserialized object before handing it to callers
    if not isinstance(loaded_object, Pipeline):
        raise TypeError("Error when loading pipeline: Loaded object is not a scikit-learn Pipeline")
    if not hasattr(loaded_object, "predict_proba"):
        raise TypeError("Error when loading pipeline: Loaded pipeline does not have a .predict_proba() method")

    return loaded_object
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Function to download and load a scikit-learn pipeline from a Hugging Face Hub repository
def load_pipeline_from_huggingface(repo_id: str, filename: str, revision: str) -> Pipeline:
    """Download (or reuse the cached copy of) a pipeline file from the Hugging Face Hub and load it.

    Args:
        repo_id: Hugging Face Hub repository id (e.g., "user/repo").
        filename: Name of the pipeline file inside the repository.
        revision: Git revision (tag, branch, or commit hash) to download.

    Returns:
        The loaded and validated scikit-learn Pipeline.

    Raises:
        RuntimeError: If downloading or loading the pipeline fails.
    """
    try:
        # .hf_hub_download() downloads the pipeline file and returns its local file path (inside the Docker container)
        # if the pipeline file was already downloaded, it will use the cached pipeline that is already stored inside the Docker container
        logger.info(
            f"Downloading pipeline '{filename}' with tag '{revision}' from Hugging Face Hub repo '{repo_id}'. "
            "If already cached, will use local copy."
        )
        pipeline_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)

        # Load pipeline from file inside the Docker container
        pipeline = load_pipeline_from_local(pipeline_path)

        return pipeline

    except Exception as e:
        raise RuntimeError(f"Error loading pipeline '{filename}' from Hugging Face Hub repository '{repo_id}': {e}") from e
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# --- Geolocation Database ---
# Load the GeoLite2 country database to log client country for model monitoring (download database from https://www.maxmind.com to the "geoip_db/" directory)
GEO_DB_PATH = Path("geoip_db/GeoLite2-Country.mmdb")  # relative to the container working directory
try:
    geoip_reader = geoip2.database.Reader(GEO_DB_PATH)
    logger.info(f"Successfully loaded GeoLite2 country database from '{GEO_DB_PATH}'")
except FileNotFoundError:
    # Geolocation is optional: the app still serves predictions, it just cannot log client country
    logger.error(f"GeoLite2 country database not found at '{GEO_DB_PATH}'. Client country will not be logged. Download the database from https://www.maxmind.com.")
    geoip_reader = None

# --- ML Pipeline ---
# Load loan default prediction pipeline (including data preprocessing and Random Forest Classifier model) from Hugging Face Hub
PIPELINE_VERSION_TAG = "v1.0"  # Hub git tag to pin the pipeline version; also logged with every prediction record
pipeline = load_pipeline_from_huggingface(
    repo_id="JensBender/loan-default-prediction-pipeline",
    filename="loan_default_rf_pipeline.joblib",
    revision=PIPELINE_VERSION_TAG
)

# Load pipeline from local machine (use for local setup without Hugging Face Hub)
# root_dir = get_root_directory()  # get path to root directory
# pipeline_path = root_dir / "models" / "loan_default_rf_pipeline.joblib"  # get path to pipeline file
# pipeline = load_pipeline_from_local(pipeline_path)

# --- API ---
# Create FastAPI app
fastapi_app = FastAPI(title="Loan Default Prediction API")
| 222 |
+
|
| 223 |
+
# Prediction endpoint
@fastapi_app.post("/api/predict", response_model=PredictionResponse)
def predict(pipeline_input: PipelineInput | list[PipelineInput], request: Request) -> PredictionResponse:  # JSON object -> PipelineInput | JSON array -> list[PipelineInput]
    """Predict loan default for a single applicant or a batch of applicants.

    Accepts either one PipelineInput (JSON object) or a list of them (JSON array),
    runs the ML pipeline, applies the optimized decision threshold, logs one JSON
    monitoring record per prediction, and returns a PredictionResponse. Any failure
    is logged (console + monitoring file) and surfaced as HTTP 500.
    """
    batch_metadata = None
    pipeline_input_dict_ls = None
    try:
        # Standardize input: always work with a list of plain dicts
        if isinstance(pipeline_input, list):
            if pipeline_input == []:  # handle empty batch input
                return PredictionResponse(results=[])
            pipeline_input_dict_ls = [input.model_dump() for input in pipeline_input]
        else:  # isinstance(pipeline_input, PipelineInput)
            pipeline_input_dict_ls = [pipeline_input.model_dump()]

        # Get batch metadata for logging
        batch_metadata = get_batch_metadata(pipeline_input_dict_ls, request, geoip_reader)

        # Remove "client_ip" and "user_agent" from pipeline input (metadata only; not model features)
        pipeline_input_cleaned = [
            {k: v for k, v in d.items() if k not in {"client_ip", "user_agent"}}
            for d in pipeline_input_dict_ls
        ]

        # Create DataFrame (the pipeline expects tabular input)
        pipeline_input_df: pd.DataFrame = pd.DataFrame(pipeline_input_cleaned)

        # Use pipeline to batch predict probabilities (and measure latency)
        start_time = time.perf_counter()  # use .perf_counter() for latency measurement and .time() for timestamps
        predicted_probabilities: np.ndarray = pipeline.predict_proba(pipeline_input_df)
        pipeline_prediction_latency_ms = round((time.perf_counter() - start_time) * 1000)  # rounded to milliseconds

        # Apply optimized threshold to convert probabilities to binary predictions
        optimized_threshold: float = 0.29  # see threshold optimization in training script "loan_default_prediction.ipynb"
        predictions: np.ndarray = (predicted_probabilities[:, 1] >= optimized_threshold)  # bool 1d-array based on class 1 "Default"

        # Add latency to batch metadata for logging
        batch_metadata.update({
            "batch_latency_ms": pipeline_prediction_latency_ms,
            "avg_prediction_latency_ms": round(pipeline_prediction_latency_ms / len(pipeline_input_dict_ls)),
        })

        # --- Create prediction response ---
        results: list[PredictionResult] = []
        # Iterate over each prediction
        for i, (pred, pred_proba) in enumerate(zip(predictions, predicted_probabilities)):
            # Create prediction result
            prediction_enum = PredictionEnum.DEFAULT if pred else PredictionEnum.NO_DEFAULT
            prediction_result = PredictionResult(
                prediction=prediction_enum,
                probabilities=PredictedProbabilities(
                    default=float(pred_proba[1]),
                    no_default=float(pred_proba[0])
                )
            )
            results.append(prediction_result)

            # Log single prediction record for model monitoring (including batch metadata)
            prediction_monitoring_record = {
                **batch_metadata,
                "prediction_id": str(uuid.uuid4()),
                "inputs": pipeline_input_cleaned[i],
                "prediction": prediction_enum.value,
                "probabilities": {
                    "default": float(pred_proba[1]),
                    "no_default": float(pred_proba[0])
                },
            }
            monitoring_logger.info(json.dumps(prediction_monitoring_record))  # converts record to JSON string for log

        return PredictionResponse(results=results)

    except Exception as e:
        # Log error to console
        logger.error("Error during predict: %s", e, exc_info=True)

        # Log prediction error record to file for model monitoring (including batch metadata)
        if pipeline_input_dict_ls:
            if batch_metadata is None:  # error before .get_batch_metadata()
                batch_metadata = {
                    "batch_id": str(uuid.uuid4()),
                    "batch_size": len(pipeline_input_dict_ls),
                    "batch_timestamp": datetime.now(timezone.utc).isoformat(),
                    "pipeline_version": PIPELINE_VERSION_TAG,
                    "client_country": None,
                    "user_agent": None
                }
            # Iterate over each input in batch so every input gets an error record
            for input in pipeline_input_dict_ls:
                prediction_monitoring_record = {
                    **batch_metadata,
                    "prediction_id": str(uuid.uuid4()),
                    "inputs": input,
                    "prediction": None,
                    "probabilities": None,
                    "error_message": str(e)
                }
                monitoring_logger.error(json.dumps(prediction_monitoring_record))
        else:
            # Error occurred before the input list was built; log a single minimal record
            prediction_monitoring_record = {
                "batch_id": str(uuid.uuid4()),
                "batch_size": None,
                "batch_timestamp": datetime.now(timezone.utc).isoformat(),
                "pipeline_version": PIPELINE_VERSION_TAG,
                "client_country": None,
                "user_agent": None,
                "prediction_id": str(uuid.uuid4()),
                "inputs": None,
                "prediction": None,
                "probabilities": None,
                "error_message": str(e)
            }
            monitoring_logger.error(json.dumps(prediction_monitoring_record))

        raise HTTPException(status_code=500, detail="Internal server error during loan default prediction")
| 337 |
+
|
| 338 |
+
# Mount Gradio frontend onto FastAPI backend
app = gr.mount_gradio_app(
    fastapi_app,
    gradio_app,
    path="/gradio",  # at "/gradio" not "/" due to known Gradio bug (redirect loop)
    show_api=False  # disable Gradio's auto-generated API
)

# Home route redirects to Gradio UI
@app.get("/")
def root():
    """Redirect requests for the bare domain to the Gradio UI mounted at /gradio/."""
    return RedirectResponse(url="/gradio/")
|
backend/schemas.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Imports ---
|
| 2 |
+
# Standard library imports
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import List, Annotated
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Third-party library imports
|
| 8 |
+
from pydantic import BaseModel, Field, field_validator, model_validator, computed_field
|
| 9 |
+
|
| 10 |
+
# Local imports
|
| 11 |
+
from src.global_constants import (
|
| 12 |
+
MARRIED_LABELS,
|
| 13 |
+
HOUSE_OWNERSHIP_LABELS,
|
| 14 |
+
CAR_OWNERSHIP_LABELS,
|
| 15 |
+
PROFESSION_LABELS,
|
| 16 |
+
CITY_LABELS,
|
| 17 |
+
STATE_LABELS
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# --- Constants ---
|
| 21 |
+
# Input constraints (for Pydantic data model)
|
| 22 |
+
INCOME_CONSTRAINTS = Field(strict=True, ge=0)
|
| 23 |
+
AGE_CONSTRAINTS = Field(strict=True, ge=21, le=79)
|
| 24 |
+
EXPERIENCE_CONSTRAINTS = Field(strict=True, ge=0, le=20)
|
| 25 |
+
CURRENT_JOB_YRS_CONSTRAINTS = Field(strict=True, ge=0, le=14)
|
| 26 |
+
CURRENT_HOUSE_YRS_CONSTRAINTS = Field(strict=True, ge=10, le=14)
|
| 27 |
+
|
| 28 |
+
# --- Custom Data Types ---
|
| 29 |
+
# Annotate and combine existing types with custom constraints (for Pydantic data model)
|
| 30 |
+
Income = Annotated[int, INCOME_CONSTRAINTS] | Annotated[float, INCOME_CONSTRAINTS]
|
| 31 |
+
Age = Annotated[int, AGE_CONSTRAINTS] | Annotated[float, AGE_CONSTRAINTS]
|
| 32 |
+
Experience = Annotated[int, EXPERIENCE_CONSTRAINTS] | Annotated[float, EXPERIENCE_CONSTRAINTS]
|
| 33 |
+
CurrentJobYrs = Annotated[int, CURRENT_JOB_YRS_CONSTRAINTS] | Annotated[float, CURRENT_JOB_YRS_CONSTRAINTS]
|
| 34 |
+
CurrentHouseYrs = Annotated[int, CURRENT_HOUSE_YRS_CONSTRAINTS] | Annotated[float, CURRENT_HOUSE_YRS_CONSTRAINTS]
|
| 35 |
+
|
| 36 |
+
# --- Enums ---
|
| 37 |
+
# Custom Enum classes for string inputs based on global constants (for Pydantic data model)
|
| 38 |
+
MarriedEnum = Enum("MarriedEnum", {label.upper(): label for label in MARRIED_LABELS}, type=str)
|
| 39 |
+
HouseOwnershipEnum = Enum("HouseOwnershipEnum", {label.upper(): label for label in HOUSE_OWNERSHIP_LABELS}, type=str)
|
| 40 |
+
CarOwnershipEnum = Enum("CarOwnershipEnum", {label.upper(): label for label in CAR_OWNERSHIP_LABELS}, type=str)
|
| 41 |
+
ProfessionEnum = Enum("ProfessionEnum", {label.upper(): label for label in PROFESSION_LABELS}, type=str)
|
| 42 |
+
CityEnum = Enum("CityEnum", {label.upper(): label for label in CITY_LABELS}, type=str)
|
| 43 |
+
StateEnum = Enum("StateEnum", {label.upper(): label for label in STATE_LABELS}, type=str)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# Enum for possible prediction strings
|
| 47 |
+
class PredictionEnum(str, Enum):
|
| 48 |
+
DEFAULT = "Default"
|
| 49 |
+
NO_DEFAULT = "No Default"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# --- Pydantic Data Models ---
|
| 53 |
+
# Pipeline input model
|
| 54 |
+
class PipelineInput(BaseModel):
    """Validated input schema for a single loan-default prediction request.

    Numeric fields accept int or float (constrained by the Annotated types
    defined above); float values are rounded to the nearest integer by the
    field validator below. Categorical fields are restricted to the label
    sets defined in ``src.global_constants``.
    """
    income: Income
    age: Age
    experience: Experience
    # Optional categorical fields (None allowed) — presumably imputed by a
    # downstream pipeline step; TODO confirm against the preprocessing pipeline
    married: MarriedEnum | None = None
    house_ownership: HouseOwnershipEnum | None = None
    car_ownership: CarOwnershipEnum | None = None
    profession: ProfessionEnum
    city: CityEnum
    state: StateEnum
    current_job_yrs: CurrentJobYrs
    current_house_yrs: CurrentHouseYrs
    # Request metadata for monitoring/logging — not model features
    client_ip: str | None = None
    user_agent: str | None = None

    @field_validator("income", "age", "experience", "current_job_yrs", "current_house_yrs")
    def convert_float_to_int(cls, value: float | int) -> int:
        # Coerce float inputs (e.g. JSON numbers from the UI) to the nearest
        # integer; int inputs pass through unchanged.
        if isinstance(value, float):
            return int(round(value))
        return value
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Predicted probabilities model
|
| 77 |
+
class PredictedProbabilities(BaseModel):
    """Class probabilities for one prediction.

    Serialized with the human-readable aliases "Default" / "No Default".
    Values are rounded to 3 decimals, then checked to sum to ~1.0.
    """
    default: float = Field(..., ge=0.0, le=1.0, serialization_alias="Default")
    no_default: float = Field(..., ge=0.0, le=1.0, serialization_alias="No Default")

    @field_validator("default", "no_default")
    def round_to_3_decimals(cls, value: float) -> float:
        # Round for compact, presentation-friendly payloads
        return round(value, 3)

    @model_validator(mode="after")  # happens after rounding
    def check_probabilities_sum_to_one(self) -> "PredictedProbabilities":
        # abs_tol=0.002 tolerates the worst-case rounding error of
        # 0.001 per field introduced by round_to_3_decimals
        if not math.isclose(self.default + self.no_default, 1.0, abs_tol=0.002):
            raise ValueError("Probabilities must sum to 1.0")
        return self
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Prediction result model
|
| 93 |
+
class PredictionResult(BaseModel):
    """A single prediction: the predicted label plus its class probabilities."""
    prediction: PredictionEnum
    probabilities: PredictedProbabilities
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Prediction response model
|
| 99 |
+
class PredictionResponse(BaseModel):
    """API response: one PredictionResult per input row, plus a derived
    ``n_predictions`` count included in the serialized output."""
    results: List[PredictionResult] = Field(strict=True)

    @computed_field
    def n_predictions(self) -> int:
        # Derived field; always equals len(results) at serialization time
        return len(self.results)
|
frontend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file marks the "app" directory as a Python package.
|
frontend/app.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Imports ---
|
| 2 |
+
# Standard library imports
|
| 3 |
+
import re
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
# Third-party library imports
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import requests
|
| 10 |
+
from requests.exceptions import ConnectionError, Timeout, RequestException
|
| 11 |
+
|
| 12 |
+
# Local imports
|
| 13 |
+
from src.global_constants import (
|
| 14 |
+
MARRIED_LABELS,
|
| 15 |
+
CAR_OWNERSHIP_LABELS,
|
| 16 |
+
HOUSE_OWNERSHIP_LABELS,
|
| 17 |
+
PROFESSION_LABELS,
|
| 18 |
+
CITY_LABELS,
|
| 19 |
+
STATE_LABELS
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# --- Logger ---
|
| 23 |
+
# Setup a structured logger for the frontend
|
| 24 |
+
logging.basicConfig(
|
| 25 |
+
level=logging.INFO,
|
| 26 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 27 |
+
)
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# --- Constants ---
|
| 31 |
+
# Format categorical string labels (snake_case) for display in UI
|
| 32 |
+
MARRIED_DISPLAY_LABELS = [label.title() for label in MARRIED_LABELS]
|
| 33 |
+
CAR_OWNERSHIP_DISPLAY_LABELS = [label.title() for label in CAR_OWNERSHIP_LABELS]
|
| 34 |
+
HOUSE_OWNERSHIP_DISPLAY_LABELS = [label.replace("norent_noown", "Neither Rented Nor Owned").title() for label in HOUSE_OWNERSHIP_LABELS]
|
| 35 |
+
PROFESSION_DISPLAY_LABELS = [label.replace("_", " ").title() for label in PROFESSION_LABELS]
|
| 36 |
+
CITY_DISPLAY_LABELS = [label.replace("_", " ").title() for label in CITY_LABELS]
|
| 37 |
+
STATE_DISPLAY_LABELS = [label.replace("_", " ").title() for label in STATE_LABELS]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# --- Input Preprocessing Functions ---
|
| 41 |
+
# Format a string in snake_case (return non-string unchanged)
|
| 42 |
+
def format_snake_case(value: Any) -> Any:
    """Normalize a string to snake_case; pass any non-string through untouched.

    Trims surrounding whitespace, lowercases, and collapses each run of
    hyphens, forward slashes, or internal whitespace into a single underscore.
    """
    if not isinstance(value, str):
        return value  # non-strings are returned unchanged
    normalized = value.strip().lower()
    return re.sub(r"[-/\s]+", "_", normalized)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Format all string values in a dictionary in snake_case
|
| 50 |
+
def format_snake_case_in_dict(inputs: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with every string value snake_cased; keys and
    non-string values are preserved as-is."""
    formatted: dict[str, Any] = {}
    for key, value in inputs.items():
        formatted[key] = format_snake_case(value)
    return formatted
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Format "house_ownership" label as expected by API backend
|
| 55 |
+
def format_house_ownership(display_label: Any) -> Any:
    """Translate the UI's snake_cased house-ownership label into the token
    the API backend expects; non-strings pass through unchanged."""
    if not isinstance(display_label, str):
        return display_label  # non-strings are returned unchanged
    return display_label.replace("neither_rented_nor_owned", "norent_noown")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# --- Error Handling ---
|
| 62 |
+
# Map internal input field names (snake_case) to user-friendly error messages
|
| 63 |
+
field_to_error_map = {
    # Iteration order of this dict determines the order of lines in the
    # combined error message built by _format_validation_error.
    "age": "Age: Enter a number between 21 and 79.",
    "married": "Married/Single: Select 'Married' or 'Single'.",  # period added for consistency with the other messages
    "income": "Income: Enter a number that is 0 or greater.",
    "car_ownership": "Car Ownership: Select 'Yes' or 'No'.",
    "house_ownership": "House Ownership: Select 'Rented', 'Owned', or 'Neither Rented Nor Owned'.",
    "current_house_yrs": "Current House Years: Enter a number between 10 and 14.",
    "city": "City: Select a city from the list.",
    "state": "State: Select a state from the list.",
    "profession": "Profession: Select a profession from the list.",
    "experience": "Experience: Enter a number between 0 and 20.",
    "current_job_yrs": "Current Job Years: Enter a number between 0 and 14.",
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Function to format Pydantic validation error from FastAPI backend into a user-friendly message for Gradio frontend
|
| 79 |
+
def _format_validation_error(error_detail: dict) -> str:
    """Turn a FastAPI/Pydantic 422 error payload into a user-facing message.

    Appends one friendly line per invalid field, in the fixed order of
    ``field_to_error_map``. If the payload does not match the expected
    Pydantic error structure, whatever message was built so far (at least
    the generic header) is returned instead.
    """
    error_msg = "Input Error! Please check your inputs and try again.\n"
    try:
        all_errors = error_detail["detail"]
        # Each Pydantic error carries a "loc" tuple naming the offending field
        for field, message in field_to_error_map.items():
            field_is_invalid = any(field in error["loc"] for error in all_errors)
            if field_is_invalid:
                error_msg += message + "\n"
        return error_msg
    except Exception as e:
        # Unexpected payload shape: log it and fall back to the generic message
        logger.warning("Failed to parse validation error from backend: %s", e, exc_info=True)
        return error_msg
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# --- Function to Predict Loan Default for Gradio UI ---
|
| 95 |
+
def predict_loan_default(
    age: int | float,
    married: str,
    income: int | float,
    car_ownership: str,
    house_ownership: str,
    current_house_yrs: int | float,
    city: str,
    state: str,
    profession: str,
    experience: int | float,
    current_job_yrs: int | float,
    gr_request: gr.Request
) -> tuple[str, dict[str, float]] | tuple[str, str]:
    """Gradio click handler: normalize form inputs, call the FastAPI backend,
    and return outputs for the two UI components.

    Returns:
        On success: (prediction_label, probabilities_dict) rendered by
        gr.Textbox and gr.Label respectively.
        On failure: (short_error_title, detailed_error_message) — gr.Label
        also accepts a plain string, so the same outputs are reused.

    Note: ``gr_request`` is injected automatically by Gradio (it is not
    listed in the click event's ``inputs``) and carries request headers.
    """
    try:
        # Get the end-user's IP address, prioritizing the x-forwarded-for header with fallback
        # (behind a proxy, client.host would be the proxy's address, not the user's)
        x_forwarded_for = gr_request.headers.get("x-forwarded-for")
        client_ip = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else gr_request.client.host

        # --- Input preprocessing ---
        # Create inputs dictionary (keys match the backend's PipelineInput schema)
        inputs = {
            "income": income,
            "age": age,
            "experience": experience,
            "married": married,
            "house_ownership": house_ownership,
            "car_ownership": car_ownership,
            "profession": profession,
            "city": city,
            "state": state,
            "current_job_yrs": current_job_yrs,
            "current_house_yrs": current_house_yrs,
            "client_ip": client_ip,
            "user_agent": gr_request.headers.get("user-agent", "unknown")
        }

        # Format string values in snake_case
        inputs = format_snake_case_in_dict(inputs)

        # Format "house_ownership" label as expected by API backend
        # (must run after snake_case formatting, which produces "neither_rented_nor_owned")
        inputs["house_ownership"] = format_house_ownership(inputs["house_ownership"])

        # --- Post request ---
        # Predict loan default via post request to FastAPI backend
        # (frontend and backend are mounted on the same server/port)
        response = requests.post(
            "http://127.0.0.1:7860/api/predict",
            json=inputs,
            timeout=(3, 60)  # 3s connect timeout, 60s read timeout (receive first byte of response)
        )

        # --- Error handling ---
        # Handle HTTP errors
        # 422 = Pydantic validation failure; show per-field guidance instead of raising
        if response.status_code == 422:
            error_detail = response.json()
            logger.warning("Received 422 validation error from backend: %s", error_detail)
            error_message = _format_validation_error(error_detail)
            return error_message, ""

        # Raise error for other bad status codes (4xx or 5xx)
        response.raise_for_status()

        # --- Response parsing ---
        # Get prediction and probabilities from HTTP response for Gradio output
        try:
            prediction_response = response.json()
            # Single-row request, so only the first result is relevant
            prediction_result = prediction_response["results"][0]
            prediction = prediction_result["prediction"]
            probabilities = prediction_result["probabilities"]

            # Data validation for Gradio rendering
            if not isinstance(prediction, str):
                raise TypeError(f"'prediction' used in gr.Textbox expects str, got {type(prediction).__name__}.")
            if not isinstance(probabilities, dict):
                raise TypeError(f"'probabilities' used in gr.Label expects dict, got {type(probabilities).__name__}.")
            if not probabilities:
                raise ValueError("'probabilities' dict cannot be empty.")
            if not all(isinstance(key, str) for key in probabilities.keys()):
                raise TypeError("'probabilities' dict keys must be strings.")
            if not all(isinstance(value, (int, float)) and 0 <= value <= 1 for value in probabilities.values()):
                raise TypeError("'probabilities' dict values must be numbers between 0 and 1.")

            return prediction, probabilities
        # Handle response parsing errors
        except (KeyError, IndexError, TypeError, ValueError):
            logger.error("Failed to parse prediction response from backend.", exc_info=True)
            return "Prediction Response Error", "The prediction service returned an invalid prediction format."

    except ConnectionError:
        logger.error("Connection to backend failed.", exc_info=True)
        return "Connection Error", "Could not connect to the prediction service. Please ensure the backend is running and try again."
    except Timeout:
        logger.error("Request to backend timed out.", exc_info=True)
        return "Timeout Error", "The request to the prediction service timed out. The service may be busy or slow. Please try again later."
    except RequestException:  # catches other frontend-to-backend communication errors
        logger.error("HTTP error while trying to communicate with backend.", exc_info=True)
        return "Communication Error", "There was a problem communicating with the prediction service. Please try again later."
    except Exception:
        logger.exception("Unexpected error in the frontend.")
        return "Error", f"An unexpected error has occurred. Please verify your inputs or try again later."
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --- Gradio App UI ---
|
| 198 |
+
# Custom CSS
|
| 199 |
+
# Custom CSS for layout tweaks (referenced via elem_id / elem_classes below)
custom_css = """
.narrow-centered-column {
    max-width: 700px;
    width: 100%;
    margin: 0 auto;
}
#predict-button-wrapper {
    max-width: 250px;
    margin: 0 auto;
}
#prediction-text textarea {font-size: 1.8em; font-weight: bold; text-align: center;}
#pred-proba-label {margin-top: -15px;}
#markdown-note {margin-top: -13px;}
"""

# Create Gradio app UI using Blocks
with gr.Blocks(css=custom_css) as gradio_app:
    # Title and description
    gr.Markdown(
        """
        <h1 style='text-align:center'>Loan Default Prediction</h1>
        <p style='text-align:center'>Submit the customer application data to receive an automated loan default prediction powered by machine learning.</p>
        """
    )

    # Inputs: four rows of three widgets; slider ranges mirror the
    # backend's Pydantic constraints in backend/schemas.py
    with gr.Group():
        with gr.Row():
            age = gr.Number(label="Age", value="")
            married = gr.Dropdown(label="Married/Single", choices=MARRIED_DISPLAY_LABELS, value=None)
            income = gr.Number(label="Income", value="")
        with gr.Row():
            car_ownership = gr.Dropdown(label="Car Ownership", choices=CAR_OWNERSHIP_DISPLAY_LABELS, value=None)
            house_ownership = gr.Dropdown(label="House Ownership", choices=HOUSE_OWNERSHIP_DISPLAY_LABELS, value=None)
            current_house_yrs = gr.Slider(label="Current House Years", minimum=10, maximum=14, step=1)
        with gr.Row():
            city = gr.Dropdown(label="City", choices=CITY_DISPLAY_LABELS, value=None)
            state = gr.Dropdown(label="State", choices=STATE_DISPLAY_LABELS, value=None)
            profession = gr.Dropdown(label="Profession", choices=PROFESSION_DISPLAY_LABELS, value=None)
        with gr.Row():
            experience = gr.Slider(label="Experience", minimum=0, maximum=20, step=1)
            current_job_yrs = gr.Slider(label="Current Job Years", minimum=0, maximum=14, step=1)
            gr.Markdown("")  # empty space for layout

    # Predict button
    with gr.Column(elem_id="predict-button-wrapper"):
        predict = gr.Button("Predict", elem_id="predict-button")

    # Outputs
    with gr.Column(elem_classes="narrow-centered-column"):
        prediction_text = gr.Textbox(placeholder="Prediction Result", show_label=False, container=False, elem_id="prediction-text")
        pred_proba = gr.Label(show_label=False, show_heading=False, elem_id="pred-proba-label")
        gr.Markdown(
            "<small>Note: Prediction uses an optimized decision threshold of 0.29 "
            "(predicts 'Default' if probability ≥ 29%, otherwise 'No Default').</small>",
            elem_id="markdown-note"
        )

    # Predict button click event
    # Note: predict_loan_default's gr.Request parameter is injected by
    # Gradio automatically and is therefore not listed in `inputs`
    predict.click(
        predict_loan_default,
        inputs=[
            age, married, income, car_ownership, house_ownership, current_house_yrs,
            city, state, profession, experience, current_job_yrs
        ],
        outputs=[prediction_text, pred_proba]
    )

# Launch Gradio app (standalone mode; in deployment the app is instead
# mounted onto the FastAPI backend via gr.mount_gradio_app)
if __name__ == "__main__":
    gradio_app.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend
|
| 2 |
+
fastapi==0.115.12
|
| 3 |
+
pydantic==2.11.9
|
| 4 |
+
uvicorn==0.34.2
|
| 5 |
+
joblib==1.4.2
|
| 6 |
+
huggingface-hub==0.35.0
|
| 7 |
+
geoip2==5.1.0
|
| 8 |
+
|
| 9 |
+
# Frontend
|
| 10 |
+
gradio==5.29.0
|
| 11 |
+
httpx==0.28.1
|
| 12 |
+
requests==2.32.5
|
| 13 |
+
|
| 14 |
+
# Shared
|
| 15 |
+
numpy==2.2.1
|
| 16 |
+
pandas==2.2.3
|
| 17 |
+
scikit-learn==1.6.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# This file marks the directory as a Python package.
|
src/custom_transformers.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Imports
|
| 2 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
| 3 |
+
from sklearn.impute import SimpleImputer
|
| 4 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder, OrdinalEncoder
|
| 5 |
+
from sklearn.utils.validation import check_is_fitted
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# --- Custom error classes ---
# For missing values in critical columns of the X input DataFrame (in MissingValueChecker)
class MissingValueError(ValueError):
    """Raised when critical features contain missing values, or a
    non-critical feature consists entirely of missing values."""
    pass

# For mismatch between expected and actual columns in X input DataFrame because of missing columns, unexpected columns, or wrong column order
class ColumnMismatchError(ValueError):
    """Raised when the input DataFrame's columns differ from the expected schema."""
    pass

# For invalid categorical labels (in BooleanColumnTransformer)
class CategoricalLabelError(ValueError):
    """Raised when a categorical column contains labels outside the allowed set."""
    pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# --- Custom transformer classes for data preprocessing pipeline ---
|
| 25 |
+
# Check missing values
|
| 26 |
+
class MissingValueChecker(BaseEstimator, TransformerMixin):
    """Schema and missing-value gate at the head of the preprocessing pipeline.

    Any missing value in a critical feature aborts with MissingValueError;
    missing values in non-critical features only trigger a printed warning
    (the warning message states they will be imputed downstream). X itself
    is passed through unchanged.

    Parameters
    ----------
    critical_features : list of str
        Columns that must not contain missing values.
    non_critical_features : list of str
        Columns whose missing values are tolerated.
    """

    def __init__(self, critical_features, non_critical_features):
        # Validate input data type
        if not isinstance(critical_features, list):
            raise TypeError("'critical_features' must be a list of column names.")
        if not isinstance(non_critical_features, list):
            raise TypeError("'non_critical_features' must be a list of column names.")

        # Validate input value
        if not critical_features:
            raise ValueError("'critical_features' cannot be an empty list. It must specify the names of the critical features.")
        if not non_critical_features:
            raise ValueError("'non_critical_features' cannot be an empty list. It must specify the names of the non-critical features.")

        self.critical_features = critical_features
        self.non_critical_features = non_critical_features

    def _validate_input(self, X):
        """Raise unless X is a DataFrame whose columns exactly match the expected set."""
        # Validate input data type
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        # Ensure input DataFrame contains all required columns
        input_columns = set(X.columns)
        required_columns = set(self.critical_features + self.non_critical_features)
        missing_columns = required_columns - input_columns
        if missing_columns:
            raise ColumnMismatchError(f"Input X is missing the following columns: {', '.join(missing_columns)}.")

        # Ensure input DataFrame doesn't contain any unexpected columns
        # (fixed: closing quote after 'non_critical_features' was missing)
        unexpected_columns = input_columns - required_columns
        if unexpected_columns:
            raise ColumnMismatchError(f"Input X contains the following columns that are neither defined in 'critical_features' nor 'non_critical_features': {', '.join(unexpected_columns)}.")

    def _check_missing_values(self, X):
        """Raise for missing critical values; print a warning for missing non-critical values."""
        # --- Critical features ---
        # Calculate total number of missing values
        n_missing_total_critical = X[self.critical_features].isnull().sum().sum()
        # Calculate number of rows with missing values
        n_missing_rows_critical = X[self.critical_features].isnull().any(axis=1).sum()
        # Create dictionary with number of missing values by column
        n_missing_by_column_critical = X[self.critical_features].isnull().sum().to_dict()
        if n_missing_total_critical > 0:
            # Singular/plural wording for a readable error message
            values = "value" if n_missing_total_critical == 1 else "values"
            rows = "row" if n_missing_rows_critical == 1 else "rows"
            raise MissingValueError(
                f"{n_missing_total_critical} missing {values} found in critical features "
                f"across {n_missing_rows_critical} {rows}. Please provide missing {values}.\n"
                f"Missing values by column: {n_missing_by_column_critical}"
            )

        # --- Non-critical features ---
        n_missing_total_noncritical = X[self.non_critical_features].isnull().sum().sum()
        n_missing_rows_noncritical = X[self.non_critical_features].isnull().any(axis=1).sum()
        n_missing_by_column_noncritical = X[self.non_critical_features].isnull().sum().to_dict()
        # Non-critical gaps are tolerated: warn instead of raising
        if n_missing_total_noncritical > 0:
            values = "value" if n_missing_total_noncritical == 1 else "values"
            rows = "row" if n_missing_rows_noncritical == 1 else "rows"
            print(
                f"Warning: {n_missing_total_noncritical} missing {values} found in non-critical features "
                f"across {n_missing_rows_noncritical} {rows}. Missing {values} will be imputed.\n"
                f"Missing values by column: {n_missing_by_column_noncritical}"
            )

    def fit(self, X, y=None):
        """Validate the schema, check missing values, and record the input schema.

        Raises MissingValueError additionally when a non-critical feature is
        entirely missing (there would be no observed values left to impute from).
        """
        # Validate input
        self._validate_input(X)

        # Check missing values
        self._check_missing_values(X)

        # Raise MissingValueError if a non-critical feature has only missing values
        for non_critical_feature in self.non_critical_features:
            if X[non_critical_feature].isnull().all():
                raise MissingValueError(f"'{non_critical_feature}' cannot be only missing values. Please ensure at least one non-missing value.")

        # Store input feature number and names as learned attributes
        # (scikit-learn API convention; also enables check_is_fitted)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Re-validate X against the fitted schema and pass it through unchanged."""
        # Ensure .fit() happened before
        check_is_fitted(self)

        # Validate input
        self._validate_input(X)

        # Ensure input feature names and feature order is the same as during .fit()
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        # Check missing values
        self._check_missing_values(X)

        return X
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# Standardize missing values
|
| 131 |
+
class MissingValueStandardizer(BaseEstimator, TransformerMixin):
    """Normalize every missing-value marker in a DataFrame to np.nan.

    Learns no statistics; fit() only records the input schema so that
    check_is_fitted() works and the scikit-learn API is honored.
    """

    @staticmethod
    def _require_dataframe(X):
        # Shared input-type guard for fit() and transform()
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

    def fit(self, X, y=None):
        """Record the input schema; nothing is learned from the data."""
        self._require_dataframe(X)

        # Learned attributes expected by the scikit-learn estimator API
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return X with all missing-value types (None, pd.NA, ...) as np.nan."""
        check_is_fitted(self)
        self._require_dataframe(X)
        return X.fillna(value=np.nan)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# A wrapper for SimpleImputer to passthrough empty DataFrames during .transform() instead of raising a ValueError (SimpleImputer default behavior)
|
| 156 |
+
class RobustSimpleImputer(SimpleImputer):
    """SimpleImputer that passes empty DataFrames through instead of raising."""

    def transform(self, X):
        # SimpleImputer raises ValueError on empty input; short-circuit that case
        if not X.empty:
            return super().transform(X)
        return X
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# Format categorical labels in snake_case
|
| 165 |
+
class SnakeCaseFormatter(BaseEstimator, TransformerMixin):
    """Rewrite string labels of the selected columns into snake_case.

    Parameters
    ----------
    columns : list or None
        Column names to format; None means "format every column".
    """

    def __init__(self, columns=None):
        # None is the "use all columns" sentinel; anything else must be a list
        if not isinstance(columns, list) and columns is not None:
            raise TypeError("'columns' must be a list of column names or None. If None, all columns will be used.")

        if columns == []:
            raise ValueError("'columns' cannot be an empty list. It must specify the column names for snake case formatting.")

        self.columns = columns

    @staticmethod
    def _to_snake_case(label):
        """Snake-case a single label; non-strings (e.g. NaN) pass through untouched."""
        if not isinstance(label, str):
            return label
        return (
            label
            .strip()            # remove leading/trailing spaces
            .lower()            # convert to lowercase
            .replace("-", "_")  # hyphens -> underscores
            .replace("/", "_")  # slashes -> underscores
            .replace(" ", "_")  # spaces -> underscores
        )

    def fit(self, X, y=None):
        """Resolve the working column list and record the input schema."""
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        # Resolve columns to transform (all of them when none were given)
        self.columns_ = X.columns.tolist() if self.columns is None else self.columns

        # Every requested column must be present in X
        absent = set(self.columns_) - set(X.columns)
        if absent:
            raise ColumnMismatchError(f"Input X is missing the following columns: {', '.join(absent)}.")

        # Learned schema attributes (sklearn convention)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return a copy of X with the selected columns snake-cased."""
        # Refuse to run before .fit()
        check_is_fitted(self)

        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        absent = set(self.columns_) - set(X.columns)
        if absent:
            raise ColumnMismatchError(f"Input X is missing the following columns: {', '.join(absent)}.")

        # Schema must match fit-time schema exactly (names and order)
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        formatted = X.copy()
        for column in self.columns_:
            formatted[column] = formatted[column].apply(self._to_snake_case)

        return formatted
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Convert binary categorical columns to boolean columns
|
| 233 |
+
class BooleanColumnTransformer(BaseEstimator, TransformerMixin):
    """Convert binary categorical columns to boolean columns via explicit mappings.

    Parameters
    ----------
    boolean_column_mappings : dict
        Maps each column name to a {label: bool} dictionary, e.g.
        {"married": {"married": True, "single": False}}.

    Raises
    ------
    TypeError / ValueError on malformed mappings at construction time;
    ColumnMismatchError / MissingValueError / CategoricalLabelError on bad input data.
    """

    def __init__(self, boolean_column_mappings):
        # Validate input data type
        if not isinstance(boolean_column_mappings, dict):
            raise TypeError("'boolean_column_mappings' must be a dictionary specifying the mappings.")

        # Validate input value (message typo "the the" fixed)
        if not boolean_column_mappings:
            raise ValueError("'boolean_column_mappings' cannot be an empty dictionary. It must specify the mappings.")

        # Iterate all columns in "boolean_column_mappings"
        for column, mapping in boolean_column_mappings.items():
            # Ensure the mapping of the current column is also a dictionary
            if not isinstance(mapping, dict):
                raise TypeError(f"The mapping for '{column}' must be a dictionary.")

            # Ensure the values of the current mapping are boolean
            if not all(isinstance(value, bool) for value in mapping.values()):
                raise ValueError(f"All values in the mapping for '{column}' must be boolean (True or False).")

        self.boolean_column_mappings = boolean_column_mappings

    def _validate_input(self, X):
        """Raise unless X is a DataFrame whose mapped columns are present, complete and known."""
        # Validate input data type
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        # Ensure input DataFrame contains all required binary columns (from "boolean_column_mappings")
        input_columns = set(X.columns)
        required_columns = set(self.boolean_column_mappings.keys())
        missing_columns = required_columns - input_columns
        if missing_columns:
            raise ColumnMismatchError(f"Input X is missing the following columns: {', '.join(missing_columns)}.")

        # Ensure all binary columns have no missing values
        for column in required_columns:
            if X[column].isna().any():
                raise MissingValueError(f"'{column}' column cannot contain missing values.")

        # Ensure all binary columns have valid data types (str, int, float, bool)
        for column in required_columns:
            if X[column].apply(lambda x: not isinstance(x, (str, int, float, bool))).any():
                raise TypeError(f"All values in '{column}' column must be str, int, float or bool.")

        # Ensure all binary columns contain only known labels (from "boolean_column_mappings")
        for column, mapping in self.boolean_column_mappings.items():
            known_labels = set(mapping.keys())
            input_labels = set(X[column].unique())
            unknown_labels = input_labels - known_labels
            if unknown_labels:
                raise CategoricalLabelError(f"'{column}' column contains unknown labels that are not in 'boolean_column_mappings': {', '.join(unknown_labels)}.")

    def fit(self, X, y=None):
        """Validate X and record the input schema; no statistics are learned."""
        self._validate_input(X)

        # Store input feature number and names as learned attributes
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return a copy of X with each mapped column converted to booleans."""
        # Ensure .fit() happened before
        check_is_fitted(self)

        # Validate input
        self._validate_input(X)

        # Ensure input feature names and feature order is the same as during .fit()
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        X_transformed = X.copy()
        # Unknown labels were rejected in _validate_input, so .map() cannot yield NaN here
        for column, mapping in self.boolean_column_mappings.items():
            X_transformed[column] = X_transformed[column].map(mapping)

        return X_transformed
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# Derive job stability from profession
|
| 314 |
+
class JobStabilityTransformer(BaseEstimator, TransformerMixin):
    """Add a 'job_stability' column derived from the 'profession' column.

    Parameters
    ----------
    job_stability_map : dict
        Maps each known profession to a job-stability tier.
    """

    def __init__(self, job_stability_map):
        # The mapping must be a non-empty dictionary
        if not isinstance(job_stability_map, dict):
            raise TypeError("'job_stability_map' must be a dictionary specifying the mappings from 'profession' to 'job_stability'.")

        if not job_stability_map:
            raise ValueError("'job_stability_map' cannot be an empty dictionary. It must specify the mappings from 'profession' to 'job_stability'.")

        self.job_stability_map = job_stability_map

    def _validate_input(self, X):
        """Raise unless X is a DataFrame with a complete, all-string, fully-known 'profession' column."""
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        if "profession" not in X.columns:
            raise ColumnMismatchError("Input X is missing the following columns: profession.")

        if X["profession"].isna().any():
            raise MissingValueError("'profession' column cannot contain missing values.")

        # Every entry must be a string before set-membership checks make sense
        if X["profession"].apply(lambda value: not isinstance(value, str)).any():
            raise TypeError("All values in 'profession' column must be strings.")

        # Reject professions the mapping does not cover
        unknown_professions = set(X["profession"].unique()) - set(self.job_stability_map)
        if unknown_professions:
            raise CategoricalLabelError(f"'profession' column contains unknown professions: {', '.join(unknown_professions)}.")

    def fit(self, X, y=None):
        """Validate X and record the input schema; no statistics are learned."""
        self._validate_input(X)

        # Learned schema attributes (sklearn convention)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return a copy of X with a derived 'job_stability' column appended."""
        check_is_fitted(self)

        self._validate_input(X)

        # Schema must match fit-time schema exactly (names and order)
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        # Unknown professions were rejected above, so .map() cannot produce NaN here
        result = X.copy()
        result["job_stability"] = result["profession"].map(self.job_stability_map)

        return result
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# Derive city tier from city
|
| 379 |
+
class CityTierTransformer(BaseEstimator, TransformerMixin):
    """Add a 'city_tier' column derived from the 'city' column.

    Parameters
    ----------
    city_tier_map : dict
        Maps each known city to its tier.
    """

    def __init__(self, city_tier_map):
        # The mapping must be a non-empty dictionary
        if not isinstance(city_tier_map, dict):
            raise TypeError("'city_tier_map' must be a dictionary specifying the mappings from 'city' to 'city_tier'.")

        if not city_tier_map:
            raise ValueError("'city_tier_map' cannot be an empty dictionary. It must specify the mappings from 'city' to 'city_tier'.")

        self.city_tier_map = city_tier_map

    def _validate_input(self, X):
        """Raise unless X is a DataFrame with a complete, all-string, fully-known 'city' column."""
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        if "city" not in X.columns:
            raise ColumnMismatchError("Input X is missing the following columns: city.")

        if X["city"].isna().any():
            raise MissingValueError("'city' column cannot contain missing values.")

        # Every entry must be a string before set-membership checks make sense
        if X["city"].apply(lambda value: not isinstance(value, str)).any():
            raise TypeError("All values in 'city' column must be strings.")

        # Reject cities the mapping does not cover
        unknown_cities = set(X["city"].unique()) - set(self.city_tier_map)
        if unknown_cities:
            raise CategoricalLabelError(f"'city' column contains unknown cities: {', '.join(unknown_cities)}.")

    def fit(self, X, y=None):
        """Validate X and record the input schema; no statistics are learned."""
        self._validate_input(X)

        # Learned schema attributes (sklearn convention)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return a copy of X with a derived 'city_tier' column appended."""
        check_is_fitted(self)

        self._validate_input(X)

        # Schema must match fit-time schema exactly (names and order)
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        # Unknown cities were rejected above, so .map() cannot produce NaN here
        result = X.copy()
        result["city_tier"] = result["city"].map(self.city_tier_map)

        return result
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# Target encoding of state default rate
|
| 443 |
+
class StateDefaultRateTargetEncoder(BaseEstimator, TransformerMixin):
    """Target-encode 'state' as the mean default rate observed during .fit()."""

    def _validate_X_input(self, X):
        """Raise unless X is a DataFrame with a complete, all-string 'state' column."""
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        if "state" not in X.columns:
            raise ColumnMismatchError("Input X is missing the following columns: state.")

        if X["state"].isna().any():
            raise MissingValueError("'state' column cannot contain missing values.")

        if X["state"].apply(lambda value: not isinstance(value, str)).any():
            raise TypeError("All values in 'state' column must be strings.")

    def fit(self, X, y):
        """Learn the per-state mean of the binary target y.

        y must be an integer Series of 0/1 values sharing X's index.
        """
        self._validate_X_input(X)

        # y must be a clean, binary, integer Series aligned with X
        if not isinstance(y, pd.Series):
            raise TypeError("Input y must be a pandas Series.")

        if y.isna().any():
            raise MissingValueError("Input y cannot contain missing values.")

        if not pd.api.types.is_integer_dtype(y):
            raise TypeError("Input y must be integer type.")

        if not y.isin([0, 1]).all():
            raise ValueError("All y values must be 0 (no default) or 1 (default).")

        if not X.index.equals(y.index):
            raise ValueError("Input X and y must have the same index.")

        # Learned schema attributes (sklearn convention)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        # Mean of a 0/1 target grouped by state == per-state default rate
        df = X.copy()
        df["default"] = y
        self.default_rate_by_state_ = df.groupby("state")["default"].mean()

        return self

    def transform(self, X):
        """Return a copy of X with a 'state_default_rate' column appended."""
        check_is_fitted(self)

        self._validate_X_input(X)

        # States never seen during .fit() have no learned rate -> hard error, not NaN
        unknown_states = set(X["state"].unique()) - set(self.default_rate_by_state_.index)
        if unknown_states:
            raise CategoricalLabelError(f"'state' column contains unknown states: {', '.join(unknown_states)}.")

        # Schema must match fit-time schema exactly (names and order)
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        encoded = X.copy()
        encoded["state_default_rate"] = encoded["state"].map(self.default_rate_by_state_)

        return encoded
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# A wrapper for StandardScaler to passthrough empty DataFrames during .transform() instead of raising a ValueError
|
| 522 |
+
class RobustStandardScaler(StandardScaler):
    """StandardScaler that passes empty DataFrames through instead of raising."""

    def transform(self, X):
        # Non-empty input takes the normal StandardScaler path
        if not X.empty:
            return super().transform(X)
        # Build an empty frame that still carries the scaler's output schema
        out_columns = self.get_feature_names_out(X.columns)
        return pd.DataFrame(columns=out_columns, dtype=float)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# A wrapper for OneHotEncoder to passthrough empty DataFrames during .transform() instead of raising a ValueError
class RobustOneHotEncoder(OneHotEncoder):
    def transform(self, X):
        # Must be fitted even for the empty-input shortcut, so the output schema is known
        check_is_fitted(self)
        if not X.empty:
            return super().transform(X)
        # Build an empty frame carrying the encoder's output column names
        out_columns = self.get_feature_names_out(X.columns)
        return pd.DataFrame(columns=out_columns, dtype=float)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
# A wrapper for OrdinalEncoder to passthrough empty DataFrames during .transform() instead of raising a ValueError
class RobustOrdinalEncoder(OrdinalEncoder):
    def transform(self, X):
        # Non-empty input takes the normal OrdinalEncoder path
        if not X.empty:
            return super().transform(X)
        # Build an empty frame carrying the encoder's output column names
        out_columns = self.get_feature_names_out(X.columns)
        return pd.DataFrame(columns=out_columns, dtype=float)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# Feature selection for downstream model training and inference
|
| 553 |
+
class FeatureSelector(BaseEstimator, TransformerMixin):
    """Keep only a fixed subset of columns for downstream training/inference.

    Parameters
    ----------
    columns_to_keep : list
        Names of the columns to retain, in the desired output order.
    """

    def __init__(self, columns_to_keep):
        # Must be a non-empty list of column names
        if not isinstance(columns_to_keep, list):
            raise TypeError("'columns_to_keep' must be a list of column names.")

        if not columns_to_keep:
            raise ValueError("'columns_to_keep' cannot be an empty list. It must specify the column names.")

        self.columns_to_keep = columns_to_keep

    def fit(self, X, y=None):
        """Check the requested columns exist and record the input schema."""
        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        # Every requested column must be present in X
        absent = set(self.columns_to_keep) - set(X.columns)
        if absent:
            raise ColumnMismatchError(f"Input X is missing the following columns: {', '.join(absent)}.")

        # Learned schema attributes (sklearn convention)
        self.n_features_in_ = X.shape[1]
        self.feature_names_in_ = X.columns.tolist()

        return self

    def transform(self, X):
        """Return a copy of X restricted to the selected columns."""
        check_is_fitted(self)

        if not isinstance(X, pd.DataFrame):
            raise TypeError("Input X must be a pandas DataFrame.")

        # Schema must match fit-time schema exactly (names and order)
        if X.columns.tolist() != self.feature_names_in_:
            raise ColumnMismatchError("Feature names and feature order of input X must be the same as during .fit().")

        # Subset + copy so downstream mutation cannot touch the caller's frame
        return X[self.columns_to_keep].copy()
|
src/global_constants.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Imports
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
# --- Constants for Pipeline ---
# Critical vs. non-critical features (consumed by the custom MissingValueChecker):
# critical features presumably must never be missing; non-critical ones may be imputed — TODO confirm against MissingValueChecker.
CRITICAL_FEATURES = ["income", "age", "experience", "profession", "city", "state", "current_job_yrs", "current_house_yrs"]
NON_CRITICAL_FEATURES = ["married", "car_ownership", "house_ownership"]

# Columns whose categorical labels get snake_case formatting (for custom SnakeCaseFormatter)
COLUMNS_FOR_SNAKE_CASING = ["profession", "city", "state"]

# Binary categorical column -> {label: bool} mappings (for custom BooleanColumnTransformer).
# Note: "house_ownership" is not mapped here even though it is a non-critical feature.
BOOLEAN_COLUMN_MAPPINGS = {
    "married": {"married": True, "single": False},
    "car_ownership": {"yes": True, "no": False}
}

# Profession -> job stability tier (for custom JobStabilityTransformer).
# Keys must match the snake_cased labels produced by SnakeCaseFormatter.
JOB_STABILITY_MAP = {
    # Government and highly regulated roles with exceptional job security
    "civil_servant": "very_stable",
    "army_officer": "very_stable",
    "police_officer": "very_stable",
    "magistrate": "very_stable",
    "official": "very_stable",
    "air_traffic_controller": "very_stable",
    "firefighter": "very_stable",
    "librarian": "very_stable",

    # Licensed/regulated professionals with strong job security
    "physician": "stable",
    "surgeon": "stable",
    "dentist": "stable",
    "chartered_accountant": "stable",
    "civil_engineer": "stable",
    "mechanical_engineer": "stable",
    "chemical_engineer": "stable",
    "petroleum_engineer": "stable",
    "biomedical_engineer": "stable",
    "engineer": "stable",

    # Corporate roles with steady demand
    "software_developer": "moderate",
    "computer_hardware_engineer": "moderate",
    "financial_analyst": "moderate",
    "industrial_engineer": "moderate",
    "statistician": "moderate",
    "microbiologist": "moderate",
    "scientist": "moderate",
    "geologist": "moderate",
    "economist": "moderate",
    "technology_specialist": "moderate",
    "design_engineer": "moderate",
    "architect": "moderate",
    "surveyor": "moderate",
    "secretary": "moderate",
    "flight_attendant": "moderate",
    "hotel_manager": "moderate",
    "computer_operator": "moderate",
    "technician": "moderate",

    # Project-based or variable demand roles
    "web_designer": "variable",
    "fashion_designer": "variable",
    "graphic_designer": "variable",
    "designer": "variable",
    "consultant": "variable",
    "technical_writer": "variable",
    "artist": "variable",
    "comedian": "variable",
    "chef": "variable",
    "analyst": "variable",
    "psychologist": "variable",
    "drafter": "variable",
    "aviator": "variable",
    "politician": "variable",
    "lawyer": "variable"
}
|
| 79 |
+
|
| 80 |
+
# Map city to city tier (for custom CityTierTransformer)
|
| 81 |
+
CITY_TIER_MAP = {
|
| 82 |
+
# Tier 1 cities
|
| 83 |
+
"new_delhi": "tier_1",
|
| 84 |
+
"navi_mumbai": "tier_1",
|
| 85 |
+
"kolkata": "tier_1",
|
| 86 |
+
"bangalore": "tier_1",
|
| 87 |
+
"chennai": "tier_1",
|
| 88 |
+
"hyderabad": "tier_1",
|
| 89 |
+
"mumbai": "tier_1",
|
| 90 |
+
"pune": "tier_1",
|
| 91 |
+
"ahmedabad": "tier_1",
|
| 92 |
+
"jaipur": "tier_1",
|
| 93 |
+
"lucknow": "tier_1",
|
| 94 |
+
"noida": "tier_1",
|
| 95 |
+
"coimbatore": "tier_1",
|
| 96 |
+
"surat": "tier_1",
|
| 97 |
+
"nagpur": "tier_1",
|
| 98 |
+
"kochi": "tier_1",
|
| 99 |
+
"thiruvananthapuram": "tier_1",
|
| 100 |
+
"kanpur": "tier_1",
|
| 101 |
+
"patna": "tier_1",
|
| 102 |
+
|
| 103 |
+
# Tier 2 cities
|
| 104 |
+
"bhopal": "tier_2",
|
| 105 |
+
"vijayawada": "tier_2",
|
| 106 |
+
"indore": "tier_2",
|
| 107 |
+
"jodhpur": "tier_2",
|
| 108 |
+
"vadodara": "tier_2",
|
| 109 |
+
"ludhiana": "tier_2",
|
| 110 |
+
"madurai": "tier_2",
|
| 111 |
+
"agra": "tier_2",
|
| 112 |
+
"mysore[7][8][9]": "tier_2",
|
| 113 |
+
"rajkot": "tier_2",
|
| 114 |
+
"nashik": "tier_2",
|
| 115 |
+
"amritsar": "tier_2",
|
| 116 |
+
"ranchi": "tier_2",
|
| 117 |
+
"chandigarh_city": "tier_2",
|
| 118 |
+
"allahabad": "tier_2",
|
| 119 |
+
"bhubaneswar": "tier_2",
|
| 120 |
+
"varanasi": "tier_2",
|
| 121 |
+
"jabalpur": "tier_2",
|
| 122 |
+
"guwahati": "tier_2",
|
| 123 |
+
"tiruppur": "tier_2",
|
| 124 |
+
"raipur": "tier_2",
|
| 125 |
+
"udaipur": "tier_2",
|
| 126 |
+
"gwalior": "tier_2",
|
| 127 |
+
|
| 128 |
+
# Tier 3 cities
|
| 129 |
+
"vijayanagaram": "tier_3",
|
| 130 |
+
"bulandshahr": "tier_3",
|
| 131 |
+
"saharsa[29]": "tier_3",
|
| 132 |
+
"hajipur[31]": "tier_3",
|
| 133 |
+
"satara": "tier_3",
|
| 134 |
+
"ongole": "tier_3",
|
| 135 |
+
"bellary": "tier_3",
|
| 136 |
+
"giridih": "tier_3",
|
| 137 |
+
"hospet": "tier_3",
|
| 138 |
+
"khammam": "tier_3",
|
| 139 |
+
"danapur": "tier_3",
|
| 140 |
+
"bareilly": "tier_3",
|
| 141 |
+
"satna": "tier_3",
|
| 142 |
+
"howrah": "tier_3",
|
| 143 |
+
"thanjavur": "tier_3",
|
| 144 |
+
"farrukhabad": "tier_3",
|
| 145 |
+
"buxar[37]": "tier_3",
|
| 146 |
+
"arrah": "tier_3",
|
| 147 |
+
"thrissur": "tier_3",
|
| 148 |
+
"proddatur": "tier_3",
|
| 149 |
+
"bahraich": "tier_3",
|
| 150 |
+
"nandyal": "tier_3",
|
| 151 |
+
"siwan[32]": "tier_3",
|
| 152 |
+
"barasat": "tier_3",
|
| 153 |
+
"dhule": "tier_3",
|
| 154 |
+
"begusarai": "tier_3",
|
| 155 |
+
"khandwa": "tier_3",
|
| 156 |
+
"guntakal": "tier_3",
|
| 157 |
+
"latur": "tier_3",
|
| 158 |
+
"karaikudi": "tier_3",
|
| 159 |
+
|
| 160 |
+
# Unknown tier cities
|
| 161 |
+
"rewa": "unknown",
|
| 162 |
+
"parbhani": "unknown",
|
| 163 |
+
"alappuzha": "unknown",
|
| 164 |
+
"tiruchirappalli[10]": "unknown",
|
| 165 |
+
"jalgaon": "unknown",
|
| 166 |
+
"jamnagar": "unknown",
|
| 167 |
+
"kota[6]": "unknown",
|
| 168 |
+
"karimnagar": "unknown",
|
| 169 |
+
"adoni": "unknown",
|
| 170 |
+
"erode[17]": "unknown",
|
| 171 |
+
"kollam": "unknown",
|
| 172 |
+
"anantapuram[24]": "unknown",
|
| 173 |
+
"kamarhati": "unknown",
|
| 174 |
+
"bhusawal": "unknown",
|
| 175 |
+
"sirsa": "unknown",
|
| 176 |
+
"amaravati": "unknown",
|
| 177 |
+
"secunderabad": "unknown",
|
| 178 |
+
"ajmer": "unknown",
|
| 179 |
+
"miryalaguda": "unknown",
|
| 180 |
+
"ambattur": "unknown",
|
| 181 |
+
"pondicherry": "unknown",
|
| 182 |
+
"shimoga": "unknown",
|
| 183 |
+
"gulbarga": "unknown",
|
| 184 |
+
"saharanpur": "unknown",
|
| 185 |
+
"gopalpur": "unknown",
|
| 186 |
+
"amravati": "unknown",
|
| 187 |
+
"udupi": "unknown",
|
| 188 |
+
"aurangabad[39]": "unknown",
|
| 189 |
+
"shimla": "unknown",
|
| 190 |
+
"bidhannagar": "unknown",
|
| 191 |
+
"purnia[26]": "unknown",
|
| 192 |
+
"bijapur": "unknown",
|
| 193 |
+
"patiala": "unknown",
|
| 194 |
+
"malda": "unknown",
|
| 195 |
+
"sagar": "unknown",
|
| 196 |
+
"durgapur": "unknown",
|
| 197 |
+
"junagadh": "unknown",
|
| 198 |
+
"singrauli": "unknown",
|
| 199 |
+
"agartala": "unknown",
|
| 200 |
+
"hindupur": "unknown",
|
| 201 |
+
"naihati": "unknown",
|
| 202 |
+
"north_dumdum": "unknown",
|
| 203 |
+
"panchkula": "unknown",
|
| 204 |
+
"anantapur": "unknown",
|
| 205 |
+
"serampore": "unknown",
|
| 206 |
+
"bathinda": "unknown",
|
| 207 |
+
"nadiad": "unknown",
|
| 208 |
+
"haridwar": "unknown",
|
| 209 |
+
"berhampur": "unknown",
|
| 210 |
+
"jamshedpur": "unknown",
|
| 211 |
+
"bidar": "unknown",
|
| 212 |
+
"kottayam": "unknown",
|
| 213 |
+
"solapur": "unknown",
|
| 214 |
+
"suryapet": "unknown",
|
| 215 |
+
"aizawl": "unknown",
|
| 216 |
+
"asansol": "unknown",
|
| 217 |
+
"deoghar": "unknown",
|
| 218 |
+
"eluru[25]": "unknown",
|
| 219 |
+
"ulhasnagar": "unknown",
|
| 220 |
+
"aligarh": "unknown",
|
| 221 |
+
"south_dumdum": "unknown",
|
| 222 |
+
"berhampore": "unknown",
|
| 223 |
+
"gandhinagar": "unknown",
|
| 224 |
+
"sonipat": "unknown",
|
| 225 |
+
"muzaffarpur": "unknown",
|
| 226 |
+
"raichur": "unknown",
|
| 227 |
+
"rajpur_sonarpur": "unknown",
|
| 228 |
+
"ambarnath": "unknown",
|
| 229 |
+
"katihar": "unknown",
|
| 230 |
+
"kozhikode": "unknown",
|
| 231 |
+
"vellore": "unknown",
|
| 232 |
+
"malegaon": "unknown",
|
| 233 |
+
"nagaon": "unknown",
|
| 234 |
+
"srinagar": "unknown",
|
| 235 |
+
"davanagere": "unknown",
|
| 236 |
+
"bhagalpur": "unknown",
|
| 237 |
+
"meerut": "unknown",
|
| 238 |
+
"dindigul": "unknown",
|
| 239 |
+
"bhatpara": "unknown",
|
| 240 |
+
"ghaziabad": "unknown",
|
| 241 |
+
"kulti": "unknown",
|
| 242 |
+
"chapra": "unknown",
|
| 243 |
+
"dibrugarh": "unknown",
|
| 244 |
+
"panihati": "unknown",
|
| 245 |
+
"bhiwandi": "unknown",
|
| 246 |
+
"morbi": "unknown",
|
| 247 |
+
"kalyan_dombivli": "unknown",
|
| 248 |
+
"gorakhpur": "unknown",
|
| 249 |
+
"panvel": "unknown",
|
| 250 |
+
"siliguri": "unknown",
|
| 251 |
+
"bongaigaon": "unknown",
|
| 252 |
+
"ramgarh": "unknown",
|
| 253 |
+
"ozhukarai": "unknown",
|
| 254 |
+
"mirzapur": "unknown",
|
| 255 |
+
"akola": "unknown",
|
| 256 |
+
"motihari[34]": "unknown",
|
| 257 |
+
"jalna": "unknown",
|
| 258 |
+
"jalandhar": "unknown",
|
| 259 |
+
"unnao": "unknown",
|
| 260 |
+
"karnal": "unknown",
|
| 261 |
+
"cuttack": "unknown",
|
| 262 |
+
"ichalkaranji": "unknown",
|
| 263 |
+
"warangal[11][12]": "unknown",
|
| 264 |
+
"jhansi": "unknown",
|
| 265 |
+
"narasaraopet": "unknown",
|
| 266 |
+
"chinsurah": "unknown",
|
| 267 |
+
"jehanabad[38]": "unknown",
|
| 268 |
+
"dhanbad": "unknown",
|
| 269 |
+
"gudivada": "unknown",
|
| 270 |
+
"gandhidham": "unknown",
|
| 271 |
+
"raiganj": "unknown",
|
| 272 |
+
"kishanganj[35]": "unknown",
|
| 273 |
+
"belgaum": "unknown",
|
| 274 |
+
"tirupati[21][22]": "unknown",
|
| 275 |
+
"tumkur": "unknown",
|
| 276 |
+
"kurnool[18]": "unknown",
|
| 277 |
+
"gurgaon": "unknown",
|
| 278 |
+
"muzaffarnagar": "unknown",
|
| 279 |
+
"aurangabad": "unknown",
|
| 280 |
+
"bhavnagar": "unknown",
|
| 281 |
+
"munger": "unknown",
|
| 282 |
+
"tirunelveli": "unknown",
|
| 283 |
+
"mango": "unknown",
|
| 284 |
+
"kadapa[23]": "unknown",
|
| 285 |
+
"khora,_ghaziabad": "unknown",
|
| 286 |
+
"ambala": "unknown",
|
| 287 |
+
"ratlam": "unknown",
|
| 288 |
+
"surendranagar_dudhrej": "unknown",
|
| 289 |
+
"delhi_city": "unknown",
|
| 290 |
+
"hapur": "unknown",
|
| 291 |
+
"rohtak": "unknown",
|
| 292 |
+
"durg": "unknown",
|
| 293 |
+
"korba": "unknown",
|
| 294 |
+
"shivpuri": "unknown",
|
| 295 |
+
"nangloi_jat": "unknown",
|
| 296 |
+
"madanapalle": "unknown",
|
| 297 |
+
"thoothukudi": "unknown",
|
| 298 |
+
"nagercoil": "unknown",
|
| 299 |
+
"gaya": "unknown",
|
| 300 |
+
"jammu[16]": "unknown",
|
| 301 |
+
"kakinada": "unknown",
|
| 302 |
+
"dewas": "unknown",
|
| 303 |
+
"bhalswa_jahangir_pur": "unknown",
|
| 304 |
+
"baranagar": "unknown",
|
| 305 |
+
"firozabad": "unknown",
|
| 306 |
+
"phusro": "unknown",
|
| 307 |
+
"guna": "unknown",
|
| 308 |
+
"thane": "unknown",
|
| 309 |
+
"etawah": "unknown",
|
| 310 |
+
"vasai_virar": "unknown",
|
| 311 |
+
"pallavaram": "unknown",
|
| 312 |
+
"morena": "unknown",
|
| 313 |
+
"ballia": "unknown",
|
| 314 |
+
"burhanpur": "unknown",
|
| 315 |
+
"phagwara": "unknown",
|
| 316 |
+
"mau": "unknown",
|
| 317 |
+
"mangalore": "unknown",
|
| 318 |
+
"alwar": "unknown",
|
| 319 |
+
"mahbubnagar": "unknown",
|
| 320 |
+
"maheshtala": "unknown",
|
| 321 |
+
"hazaribagh": "unknown",
|
| 322 |
+
"bihar_sharif": "unknown",
|
| 323 |
+
"faridabad": "unknown",
|
| 324 |
+
"tenali": "unknown",
|
| 325 |
+
"amroha": "unknown",
|
| 326 |
+
"medininagar": "unknown",
|
| 327 |
+
"rajahmundry[19][20]": "unknown",
|
| 328 |
+
"bhilai": "unknown",
|
| 329 |
+
"moradabad": "unknown",
|
| 330 |
+
"machilipatnam": "unknown",
|
| 331 |
+
"mira_bhayandar": "unknown",
|
| 332 |
+
"pali": "unknown",
|
| 333 |
+
"mehsana": "unknown",
|
| 334 |
+
"imphal": "unknown",
|
| 335 |
+
"sambalpur": "unknown",
|
| 336 |
+
"ujjain": "unknown",
|
| 337 |
+
"madhyamgram": "unknown",
|
| 338 |
+
"jamalpur[36]": "unknown",
|
| 339 |
+
"gangtok": "unknown",
|
| 340 |
+
"anand": "unknown",
|
| 341 |
+
"dehradun": "unknown",
|
| 342 |
+
"srikakulam": "unknown",
|
| 343 |
+
"darbhanga": "unknown",
|
| 344 |
+
"nizamabad": "unknown",
|
| 345 |
+
"dehri[30]": "unknown",
|
| 346 |
+
"jorhat": "unknown",
|
| 347 |
+
"kumbakonam": "unknown",
|
| 348 |
+
"haldia": "unknown",
|
| 349 |
+
"loni": "unknown",
|
| 350 |
+
"pimpri_chinchwad": "unknown",
|
| 351 |
+
"nanded": "unknown",
|
| 352 |
+
"kirari_suleman_nagar": "unknown",
|
| 353 |
+
"jaunpur": "unknown",
|
| 354 |
+
"bilaspur": "unknown",
|
| 355 |
+
"sambhal": "unknown",
|
| 356 |
+
"rourkela": "unknown",
|
| 357 |
+
"dharmavaram": "unknown",
|
| 358 |
+
"nellore[14][15]": "unknown",
|
| 359 |
+
"visakhapatnam[4]": "unknown",
|
| 360 |
+
"karawal_nagar": "unknown",
|
| 361 |
+
"avadi": "unknown",
|
| 362 |
+
"bhimavaram": "unknown",
|
| 363 |
+
"bardhaman": "unknown",
|
| 364 |
+
"silchar": "unknown",
|
| 365 |
+
"kavali": "unknown",
|
| 366 |
+
"tezpur": "unknown",
|
| 367 |
+
"ramagundam[27]": "unknown",
|
| 368 |
+
"yamunanagar": "unknown",
|
| 369 |
+
"sri_ganganagar": "unknown",
|
| 370 |
+
"sasaram[30]": "unknown",
|
| 371 |
+
"sikar": "unknown",
|
| 372 |
+
"bally": "unknown",
|
| 373 |
+
"bhiwani": "unknown",
|
| 374 |
+
"rampur": "unknown",
|
| 375 |
+
"uluberia": "unknown",
|
| 376 |
+
"sangli_miraj_&_kupwad": "unknown",
|
| 377 |
+
"hosur": "unknown",
|
| 378 |
+
"bikaner": "unknown",
|
| 379 |
+
"shahjahanpur": "unknown",
|
| 380 |
+
"sultan_pur_majra": "unknown",
|
| 381 |
+
"bharatpur": "unknown",
|
| 382 |
+
"tadepalligudem": "unknown",
|
| 383 |
+
"tinsukia": "unknown",
|
| 384 |
+
"salem": "unknown",
|
| 385 |
+
"mathura": "unknown",
|
| 386 |
+
"guntur[13]": "unknown",
|
| 387 |
+
"hubli_dharwad": "unknown",
|
| 388 |
+
"chittoor[28]": "unknown",
|
| 389 |
+
"tiruvottiyur": "unknown",
|
| 390 |
+
"ahmednagar": "unknown",
|
| 391 |
+
"fatehpur": "unknown",
|
| 392 |
+
"bhilwara": "unknown",
|
| 393 |
+
"kharagpur": "unknown",
|
| 394 |
+
"bettiah[33]": "unknown",
|
| 395 |
+
"bhind": "unknown",
|
| 396 |
+
"bokaro": "unknown",
|
| 397 |
+
"raebareli": "unknown",
|
| 398 |
+
"pudukkottai": "unknown",
|
| 399 |
+
"panipat": "unknown",
|
| 400 |
+
"tadipatri": "unknown",
|
| 401 |
+
"orai": "unknown",
|
| 402 |
+
"raurkela_industrial_township": "unknown",
|
| 403 |
+
"katni": "unknown",
|
| 404 |
+
"chandrapur": "unknown",
|
| 405 |
+
"kolhapur": "unknown"
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
# Define semantic column types (for ColumnTransformer to scale numerical and encode categorical columns)
NUMERICAL_COLUMNS = ["income", "age", "experience", "current_job_yrs", "current_house_yrs", "state_default_rate"]
CATEGORICAL_COLUMNS = ["house_ownership", "job_stability", "city_tier", "profession", "city", "state"]
BOOLEAN_COLUMNS = ["risk_flag", "married", "car_ownership"]

# Define the categories for the nominal column "house_ownership" (for OneHotEncoder)
NOMINAL_COLUMN_CATEGORIES = [["norent_noown", "owned", "rented"]]  # OneHotEncoder requires list of lists even for single column

# Define the explicit order of categories for all ordinal columns (for OrdinalEncoder)
# NOTE: list positions must match the column order the encoder is fitted on: ["job_stability", "city_tier"]
ORDINAL_COLUMN_ORDERS = [
    ["variable", "moderate", "stable", "very_stable"],  # Order for job_stability
    ["unknown", "tier_3", "tier_2", "tier_1"]  # Order for city_tier
]

# Define the columns to keep after preprocessing as model input (for custom FeatureSelector)
# "house_ownership_norent_noown" is absent: it is the first category and is removed by OneHotEncoder(drop="first")
COLUMNS_TO_KEEP = [
    "income", "age", "experience", "current_job_yrs", "current_house_yrs", "state_default_rate", "house_ownership_owned",
    "house_ownership_rented", "job_stability", "city_tier", "married", "car_ownership"
]

# Store the best Random Forest hyperparameter values identified with random search (for RandomForestClassifier)
RF_BEST_PARAMS = {
    "n_estimators": 225,
    "max_depth": 26,
    "min_samples_split": 2,
    "min_samples_leaf": 1,
    "max_features": np.float64(0.12974565961049356),  # fraction of features per split, exact value sampled by random search
    "class_weight": "balanced"  # weight classes inversely proportional to their frequencies
}
|
| 437 |
+
|
| 438 |
+
# --- Constants for App ---
# Lists of categorical string labels (in format expected by the pipeline)
MARRIED_LABELS = ["single", "married"]
CAR_OWNERSHIP_LABELS = ["yes", "no"]
HOUSE_OWNERSHIP_LABELS = ["rented", "owned", "norent_noown"]
PROFESSION_LABELS = [
    "air_traffic_controller", "analyst", "architect", "army_officer", "artist", "aviator",
    "biomedical_engineer", "chartered_accountant", "chef", "chemical_engineer", "civil_engineer",
    "civil_servant", "comedian", "computer_hardware_engineer", "computer_operator", "consultant",
    "dentist", "design_engineer", "designer", "drafter", "economist", "engineer",
    "fashion_designer", "financial_analyst", "firefighter", "flight_attendant", "geologist",
    "graphic_designer", "hotel_manager", "industrial_engineer", "lawyer", "librarian",
    "magistrate", "mechanical_engineer", "microbiologist", "official", "petroleum_engineer",
    "physician", "police_officer", "politician", "psychologist", "scientist", "secretary",
    "software_developer", "statistician", "surgeon", "surveyor", "technical_writer",
    "technician", "technology_specialist", "web_designer"
]
# NOTE(review): bracketed suffixes like "[24]" / "[11][12]" look like citation artifacts from the
# scraped source data — kept verbatim so labels match the exact category strings seen at training time.
CITY_LABELS = [
    "adoni", "agartala", "agra", "ahmedabad", "ahmednagar", "aizawl", "ajmer", "akola", "alappuzha", "aligarh",
    "allahabad", "alwar", "ambala", "ambarnath", "ambattur", "amravati", "amritsar", "amroha", "anand", "anantapur",
    "anantapuram[24]", "arrah", "asansol", "aurangabad", "aurangabad[39]", "avadi", "bahraich", "ballia", "bally",
    "bangalore", "baranagar", "barasat", "bardhaman", "bareilly", "bathinda", "begusarai", "belgaum", "bellary",
    "berhampore", "berhampur", "bettiah[33]", "bhadravati", "bhagalpur", "bhalswa_jahangir_pur", "bharatpur",
    "bhatpara", "bhavnagar", "bhilai", "bhilwara", "bhimavaram", "bhind", "bhiwandi", "bhiwani", "bhopal",
    "bhubaneswar", "bhusawal", "bidar", "bidhannagar", "bihar_sharif", "bijapur", "bikaner", "bilaspur", "bokaro",
    "bongaigaon", "bulandshahr", "burhanpur", "buxar[37]", "chandigarh_city", "chandrapur", "chapra", "chennai",
    "chinsurah", "chittoor[28]", "coimbatore", "cuttack", "danapur", "darbhanga", "davanagere", "dehradun",
    "dehri[30]", "delhi_city", "deoghar", "dewas", "dhanbad", "dharmavaram", "dhule", "dibrugarh", "dindigul",
    "durg", "durgapur", "eluru[25]", "erode[17]", "etawah", "faridabad", "farrukhabad", "fatehpur", "firozabad",
    "gandhidham", "gandhinagar", "gangtok", "gaya", "ghaziabad", "giridih", "gopalpur", "gorakhpur", "gudivada",
    "gulbarga", "guna", "guntakal", "guntur[13]", "gurgaon", "guwahati", "gwalior", "hajipur[31]", "haldia",
    "hapur", "haridwar", "hazaribagh", "hindupur", "hospet", "hosur", "howrah", "hubli_dharwad", "hyderabad",
    "ichalkaranji", "imphal", "indore", "jabalpur", "jaipur", "jalandhar", "jalgaon", "jalna", "jamalpur[36]",
    "jammu[16]", "jamnagar", "jamshedpur", "jaunpur", "jehanabad[38]", "jhansi", "jodhpur", "jorhat", "junagadh",
    "kadapa[23]", "kakinada", "kalyan_dombivli", "kamarhati", "kanpur", "karawal_nagar", "karaikudi", "karimnagar",
    "karnal", "katihar", "katni", "kavali", "khammam", "khandwa", "kharagpur", "khora,_ghaziabad",
    "kirari_suleman_nagar", "kishanganj[35]", "kochi", "kolhapur", "kolkata", "kollam", "korba", "kota[6]",
    "kottayam", "kozhikode", "kulti", "kumbakonam", "kurnool[18]", "latur", "loni", "lucknow", "ludhiana",
    "machilipatnam", "madanapalle", "madhyamgram", "madurai", "mahbubnagar", "maheshtala", "malda", "malegaon",
    "mango", "mangalore", "mathura", "mau", "medininagar", "meerut", "mehsana", "mira_bhayandar", "mirzapur",
    "miryalaguda", "moradabad", "morbi", "morena", "motihari[34]", "mumbai", "munger", "muzaffarnagar",
    "muzaffarpur", "mysore[7][8][9]", "nadiad", "nagaon", "nagercoil", "nagpur", "naihati", "nanded", "nandyal",
    "nangloi_jat", "narasaraopet", "nashik", "navi_mumbai", "nellore[14][15]", "new_delhi", "nizamabad", "noida",
    "north_dumdum", "ongole", "orai", "ozhukarai", "pali", "pallavaram", "panchkula", "panipat", "panihati",
    "panvel", "parbhani", "patiala", "patna", "phagwara", "phusro", "pimpri_chinchwad", "pondicherry",
    "proddatur", "pudukkottai", "pune", "purnia[26]", "raebareli", "raichur", "raiganj", "raipur",
    "rajahmundry[19][20]", "rajkot", "rajpur_sonarpur", "ramagundam[27]", "ramgarh", "rampur", "ranchi", "ratlam",
    "raurkela_industrial_township", "rewa", "rohtak", "rourkela", "sagar", "saharanpur", "saharsa[29]", "salem",
    "sambalpur", "sambhal", "sangli_miraj_&_kupwad", "sasaram[30]", "satara", "satna", "secunderabad",
    "serampore", "shahjahanpur", "shimla", "shimoga", "shivpuri", "sikar", "silchar", "siliguri", "singrauli",
    "sirsa", "siwan[32]", "solapur", "sonipat", "south_dumdum", "sri_ganganagar", "srikakulam", "srinagar",
    "sultan_pur_majra", "surat", "surendranagar_dudhrej", "suryapet", "tadipatri", "tadepalligudem", "tenali",
    "tezpur", "thane", "thanjavur", "thiruvananthapuram", "thoothukudi", "thrissur", "tinsukia",
    "tiruchirappalli[10]", "tirunelveli", "tirupati[21][22]", "tiruppur", "tiruvottiyur", "tumkur", "udaipur",
    "udupi", "ujjain", "ulhasnagar", "uluberia", "unnao", "vadodara", "varanasi", "vasai_virar", "vellore",
    "vijayawada", "vijayanagaram", "visakhapatnam[4]", "warangal[11][12]", "yamunanagar"
]
# State labels, snake_cased; "uttar_pradesh[5]" keeps a citation artifact — see NOTE above CITY_LABELS
STATE_LABELS = [
    "andhra_pradesh", "assam", "bihar", "chandigarh", "chhattisgarh",
    "delhi", "gujarat", "haryana", "himachal_pradesh", "jammu_and_kashmir",
    "jharkhand", "karnataka", "kerala", "madhya_pradesh", "maharashtra",
    "manipur", "mizoram", "odisha", "puducherry", "punjab", "rajasthan",
    "sikkim", "tamil_nadu", "telangana", "tripura", "uttar_pradesh",
    "uttar_pradesh[5]", "uttarakhand", "west_bengal"
]
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Third-party library imports
|
| 2 |
+
from sklearn.pipeline import Pipeline
|
| 3 |
+
from sklearn.compose import ColumnTransformer
|
| 4 |
+
from sklearn.ensemble import RandomForestClassifier
|
| 5 |
+
|
| 6 |
+
# Local imports
|
| 7 |
+
from .custom_transformers import (
|
| 8 |
+
MissingValueChecker,
|
| 9 |
+
MissingValueStandardizer,
|
| 10 |
+
RobustSimpleImputer,
|
| 11 |
+
SnakeCaseFormatter,
|
| 12 |
+
BooleanColumnTransformer,
|
| 13 |
+
JobStabilityTransformer,
|
| 14 |
+
CityTierTransformer,
|
| 15 |
+
StateDefaultRateTargetEncoder,
|
| 16 |
+
RobustStandardScaler,
|
| 17 |
+
RobustOneHotEncoder,
|
| 18 |
+
RobustOrdinalEncoder,
|
| 19 |
+
FeatureSelector
|
| 20 |
+
)
|
| 21 |
+
from .global_constants import (
|
| 22 |
+
CRITICAL_FEATURES,
|
| 23 |
+
NON_CRITICAL_FEATURES,
|
| 24 |
+
COLUMNS_FOR_SNAKE_CASING,
|
| 25 |
+
BOOLEAN_COLUMN_MAPPINGS,
|
| 26 |
+
JOB_STABILITY_MAP,
|
| 27 |
+
CITY_TIER_MAP,
|
| 28 |
+
NUMERICAL_COLUMNS,
|
| 29 |
+
NOMINAL_COLUMN_CATEGORIES,
|
| 30 |
+
ORDINAL_COLUMN_ORDERS,
|
| 31 |
+
COLUMNS_TO_KEEP,
|
| 32 |
+
RF_BEST_PARAMS
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# --- Helper Functions to Create Full Pipeline and Pipeline Segments ---
|
| 37 |
+
def create_data_preprocessing_and_model_pipeline():
    """Build the full end-to-end pipeline: data preprocessing followed by the Random Forest classifier."""
    # Impute non-critical categorical features with their most frequent value; all other columns pass through
    missing_value_handler = ColumnTransformer(
        transformers=[("categorical_imputer", RobustSimpleImputer(strategy="most_frequent").set_output(transform="pandas"), NON_CRITICAL_FEATURES)],
        remainder="passthrough",
        verbose_feature_names_out=False  # preserve input column names instead of adding prefix
    ).set_output(transform="pandas")  # output pd.DataFrame instead of np.array

    # Scale numerical columns, one-hot encode the nominal column, ordinal-encode the ordered categorical columns
    feature_scaler_encoder = ColumnTransformer(
        transformers=[
            ("scaler", RobustStandardScaler(), NUMERICAL_COLUMNS),
            ("nominal_encoder", RobustOneHotEncoder(categories=NOMINAL_COLUMN_CATEGORIES, drop="first", sparse_output=False), ["house_ownership"]),
            ("ordinal_encoder", RobustOrdinalEncoder(categories=ORDINAL_COLUMN_ORDERS), ["job_stability", "city_tier"])
        ],
        remainder="passthrough",
        verbose_feature_names_out=False
    ).set_output(transform="pandas")

    steps = [
        ("missing_value_checker", MissingValueChecker(critical_features=CRITICAL_FEATURES, non_critical_features=NON_CRITICAL_FEATURES)),
        ("missing_value_standardizer", MissingValueStandardizer()),
        ("missing_value_handler", missing_value_handler),
        ("snake_case_formatter", SnakeCaseFormatter(columns=COLUMNS_FOR_SNAKE_CASING)),
        ("boolean_column_transformer", BooleanColumnTransformer(boolean_column_mappings=BOOLEAN_COLUMN_MAPPINGS)),
        ("job_stability_transformer", JobStabilityTransformer(job_stability_map=JOB_STABILITY_MAP)),
        ("city_tier_transformer", CityTierTransformer(city_tier_map=CITY_TIER_MAP)),
        ("state_default_rate_target_encoder", StateDefaultRateTargetEncoder()),
        ("feature_scaler_encoder", feature_scaler_encoder),
        ("feature_selector", FeatureSelector(columns_to_keep=COLUMNS_TO_KEEP)),
        # random_state fixed for reproducible model training
        ("rf_classifier", RandomForestClassifier(**RF_BEST_PARAMS, random_state=42))
    ]
    return Pipeline(steps)
|
| 63 |
+
|
| 64 |
+
def create_data_preprocessing_pipeline():
    """Build the preprocessing-only pipeline (every step of the full pipeline except the classifier)."""
    # Missing-value stage: validate, standardize representations, then impute non-critical categoricals
    validation_steps = [
        ("missing_value_checker", MissingValueChecker(critical_features=CRITICAL_FEATURES, non_critical_features=NON_CRITICAL_FEATURES)),
        ("missing_value_standardizer", MissingValueStandardizer()),
        ("missing_value_handler", ColumnTransformer(
            transformers=[("categorical_imputer", RobustSimpleImputer(strategy="most_frequent").set_output(transform="pandas"), NON_CRITICAL_FEATURES)],
            remainder="passthrough",
            verbose_feature_names_out=False  # preserve input column names instead of adding prefix
        ).set_output(transform="pandas"))  # output pd.DataFrame instead of np.array
    ]

    # Feature-engineering stage: formatting plus domain-specific derived features
    engineering_steps = [
        ("snake_case_formatter", SnakeCaseFormatter(columns=COLUMNS_FOR_SNAKE_CASING)),
        ("boolean_column_transformer", BooleanColumnTransformer(boolean_column_mappings=BOOLEAN_COLUMN_MAPPINGS)),
        ("job_stability_transformer", JobStabilityTransformer(job_stability_map=JOB_STABILITY_MAP)),
        ("city_tier_transformer", CityTierTransformer(city_tier_map=CITY_TIER_MAP)),
        ("state_default_rate_target_encoder", StateDefaultRateTargetEncoder())
    ]

    # Model-input stage: scale/encode, then select the final feature set
    model_input_steps = [
        ("feature_scaler_encoder", ColumnTransformer(
            transformers=[
                ("scaler", RobustStandardScaler(), NUMERICAL_COLUMNS),
                ("nominal_encoder", RobustOneHotEncoder(categories=NOMINAL_COLUMN_CATEGORIES, drop="first", sparse_output=False), ["house_ownership"]),
                ("ordinal_encoder", RobustOrdinalEncoder(categories=ORDINAL_COLUMN_ORDERS), ["job_stability", "city_tier"])
            ],
            remainder="passthrough",
            verbose_feature_names_out=False
        ).set_output(transform="pandas")),
        ("feature_selector", FeatureSelector(columns_to_keep=COLUMNS_TO_KEEP))
    ]

    return Pipeline(validation_steps + engineering_steps + model_input_steps)
|
| 89 |
+
|
| 90 |
+
def create_model_preprocessing_pipeline():
    """Build the model-input preparation pipeline: scaling/encoding plus final feature selection."""
    scaler_encoder = ColumnTransformer(
        transformers=[
            ("scaler", RobustStandardScaler(), NUMERICAL_COLUMNS),
            ("nominal_encoder", RobustOneHotEncoder(categories=NOMINAL_COLUMN_CATEGORIES, drop="first", sparse_output=False), ["house_ownership"]),
            ("ordinal_encoder", RobustOrdinalEncoder(categories=ORDINAL_COLUMN_ORDERS), ["job_stability", "city_tier"])
        ],
        remainder="passthrough",
        verbose_feature_names_out=False  # keep original column names (no transformer prefix)
    ).set_output(transform="pandas")  # emit a pd.DataFrame rather than an np.array

    return Pipeline([
        ("feature_scaler_encoder", scaler_encoder),
        ("feature_selector", FeatureSelector(columns_to_keep=COLUMNS_TO_KEEP))
    ])
|
| 103 |
+
|
| 104 |
+
def create_feature_engineering_pipeline():
    """Build the feature-engineering pipeline: formatting and domain-specific feature derivation only."""
    engineering_steps = [
        ("snake_case_formatter", SnakeCaseFormatter(columns=COLUMNS_FOR_SNAKE_CASING)),
        ("boolean_column_transformer", BooleanColumnTransformer(boolean_column_mappings=BOOLEAN_COLUMN_MAPPINGS)),
        ("job_stability_transformer", JobStabilityTransformer(job_stability_map=JOB_STABILITY_MAP)),
        ("city_tier_transformer", CityTierTransformer(city_tier_map=CITY_TIER_MAP)),
        ("state_default_rate_target_encoder", StateDefaultRateTargetEncoder()),
    ]
    return Pipeline(engineering_steps)
|
| 112 |
+
|
| 113 |
+
def create_missing_value_handling_pipeline():
    """Build only the missing-value handling segment of the preprocessing pipeline.

    Returns:
        Pipeline: validates required features, standardizes missing-value
        representations, then imputes non-critical categorical features with
        their most frequent value (all other columns pass through unchanged).
    """
    return Pipeline([
        ("missing_value_checker", MissingValueChecker(critical_features=CRITICAL_FEATURES, non_critical_features=NON_CRITICAL_FEATURES)),
        ("missing_value_standardizer", MissingValueStandardizer()),
        ("missing_value_handler", ColumnTransformer(
            transformers=[("categorical_imputer", RobustSimpleImputer(strategy="most_frequent").set_output(transform="pandas"), NON_CRITICAL_FEATURES)],
            remainder="passthrough",
            verbose_feature_names_out=False # preserve input column names instead of adding prefix
        ).set_output(transform="pandas")), # output pd.DataFrame instead of np.array
    ])
|
src/utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
# Helper function to get path to root directory
|
| 4 |
+
def get_root_directory(anchor_files: str | list[str] = [".git", "Readme.md", "requirements.txt"]) -> Path:
|
| 5 |
+
# Standardize inputs to list[str]
|
| 6 |
+
if isinstance(anchor_files, str):
|
| 7 |
+
anchor_files = [anchor_files]
|
| 8 |
+
|
| 9 |
+
# Get absolute path to current file
|
| 10 |
+
file_path = Path(__file__).resolve()
|
| 11 |
+
|
| 12 |
+
# Iterate over each parent directory
|
| 13 |
+
for parent in file_path.parents:
|
| 14 |
+
# Iterate over each anchor file
|
| 15 |
+
for anchor_file in anchor_files:
|
| 16 |
+
# Check if anchor file exists in parent directory
|
| 17 |
+
if (parent / anchor_file).exists():
|
| 18 |
+
# Return the parent directory in which the anchor file was found, i.e. the root directory
|
| 19 |
+
return parent
|
| 20 |
+
raise FileNotFoundError(f"Root directory not found: None of the anchor files '{anchor_files}' were found in any parent directory.")
|
start.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
set -e  # exit immediately if any command fails

# Create geoip_db directory (if it doesn't exist already)
mkdir -p /app/geoip_db

# Download GeoLite2-Country.mmdb (if it doesn't exist already)
if [ ! -f /app/geoip_db/GeoLite2-Country.mmdb ]; then
    echo "Downloading GeoLite2-Country.mmdb..."

    # Create an account at https://www.maxmind.com/, create a license key and add it to your .env file and your Hugging Face Space secrets
    # Sanitize the key: remove single/double quotes and whitespace (Docker --env-file includes quotes in the value, which breaks the URL)
    MAXMIND_LICENSE_KEY=$(echo "$MAXMIND_LICENSE_KEY" | tr -d '"' | tr -d "'" | tr -d '[:space:]')

    # Fail fast with a clear message when the key is missing; otherwise curl would fetch an HTML error page
    if [ -z "$MAXMIND_LICENSE_KEY" ]; then
        echo "ERROR: MAXMIND_LICENSE_KEY is not set or empty." >&2
        exit 1
    fi

    # Download the database using the sanitized key
    # --fail: exit non-zero on HTTP errors (e.g. bad key) instead of saving the error body as the archive
    curl --fail -L -o /app/geoip_db/GeoLite2-Country.tar.gz \
        "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-Country&license_key=${MAXMIND_LICENSE_KEY}&suffix=tar.gz"

    # Extract the .mmdb file and remove archive
    tar -xzf /app/geoip_db/GeoLite2-Country.tar.gz -C /app/geoip_db --strip-components=1
    rm /app/geoip_db/GeoLite2-Country.tar.gz
    echo "Successfully downloaded GeoLite2-Country.mmdb."
fi

# Start the combined FastAPI and Gradio app on a uvicorn server (Hugging Face Spaces expects port 7860 even though default is 8000)
# exec replaces this shell so uvicorn becomes PID 1 and receives SIGTERM directly on container stop
exec uvicorn backend.app:app --host 0.0.0.0 --port 7860
|