Spaces: Sleeping
Suvh committed
Commit · 070061f
1 Parent(s): 9914b48

Update to v1.1-chatty-luna (2025-12-07)

Files changed:
- .streamlit/secrets.toml.template +29 -0
- README.md +23 -12
- app.py +7 -0
- data/adult.data +0 -0
- dataset_info/adult.json +11 -0
- models/RandomForest.pkl +3 -0
- models/classifier_metadata.pkl +3 -0
- models/focused_xai_classifier.pth +3 -0
- models/focused_xai_classifier_best.pth +3 -0
- models/focused_xai_classifier_metadata.pkl +3 -0
- models/focused_xai_label_encoder.pkl +3 -0
- models/intent_classifier.pth +3 -0
- models/intent_classifier_best.pth +3 -0
- models/intent_classifier_metadata.pkl +3 -0
- models/intent_label_encoder.pkl +3 -0
- models/label_encoder.pkl +3 -0
- models/model_metadata.json +108 -0
- models/xagent_classifier.pth +3 -0
- requirements.txt +11 -3
- src/DATA_LOGGER_README.md +114 -0
- src/ab_config.py +231 -0
- src/agent.py +300 -0
- src/answer.py +58 -0
- src/app.py +1183 -0
- src/constraints.py +53 -0
- src/data_logger.py +211 -0
- src/env_loader.py +37 -0
- src/github_saver.py +59 -0
- src/load_adult_data.py +55 -0
- src/loan_assistant.py +0 -0
- src/natural_conversation.py +567 -0
- src/nlu.py +385 -0
- src/nlu_config.json +3 -0
- src/preprocessing.py +83 -0
- src/shap_visualizer.py +269 -0
- src/streamlit_app.py +0 -40
- src/train_classifiers.py +41 -0
- src/utils.py +190 -0
- src/xai_methods.py +1028 -0
.streamlit/secrets.toml.template
ADDED
@@ -0,0 +1,29 @@
+# Streamlit Cloud Secrets Configuration Template
+# Copy this to .streamlit/secrets.toml for local testing
+# For Streamlit Cloud: Add these secrets in the app dashboard under "Secrets"
+
+# OpenAI Configuration (REQUIRED for LLM validation messages)
+OPENAI_API_KEY = "sk-proj-your-api-key-here"
+OPENAI_MODEL = "gpt-4o-mini"
+
+# GenAI Features
+HICXAI_GENAI = "on"
+HICXAI_OPENAI_MODEL = "gpt-4o-mini"
+HICXAI_TEMPERATURE = "0.7"
+HICXAI_MAX_TOKENS = "100"
+
+# GitHub Integration (for data collection)
+GITHUB_TOKEN = "ghp_your-github-token-here"
+GITHUB_REPO = "https://github.com/yourusername/hicxai-data-private.git"
+
+# A/B Testing Configuration
+HICXAI_VERSION = "v0"
+HICXAI_DEBUG_MODE = "false"
+
+# Instructions for Streamlit Cloud:
+# 1. Go to your app dashboard on share.streamlit.io
+# 2. Click the three dots (⋮) menu → Settings → Secrets
+# 3. Copy the contents of this file (without comments)
+# 4. Paste into the Secrets text box
+# 5. Click "Save"
+# 6. Your app will automatically restart with the new secrets
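
At runtime the app can read these values from Streamlit's secrets store or from plain environment variables. A minimal lookup sketch, assuming this fallback order (the repo's own src/env_loader.py, listed in this commit but not rendered here, may resolve values differently):

```python
import os
import streamlit as st

def get_secret(name: str, default: str = "") -> str:
    """Return a config value from st.secrets if present, else the environment."""
    try:
        if name in st.secrets:  # populated from .streamlit/secrets.toml or the Cloud dashboard
            return str(st.secrets[name])
    except Exception:
        pass  # no secrets file available, e.g. a bare local run
    return os.getenv(name, default)

# Keys from the template above
openai_key = get_secret("OPENAI_API_KEY")
model = get_secret("HICXAI_OPENAI_MODEL", "gpt-4o-mini")
temperature = float(get_secret("HICXAI_TEMPERATURE", "0.7"))
```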
README.md
CHANGED
@@ -1,19 +1,30 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: HicXAI Research - Condition 2
+emoji: 🤖
+colorFrom: blue
+colorTo: green
 sdk: docker
-app_port: 8501
-tags:
-- streamlit
 pinned: false
-
+license: mit
+app_port: 7860
 ---
 
-#
+# AI Loan Assistant - Research Study
 
-
+**Condition 2**: No Explanation, High Anthropomorphism
 
-
-
+This is an interactive AI loan assistant for research purposes studying the effects of explainable AI (XAI) methods and conversational anthropomorphism in credit decision systems.
+
+## Features
+
+- Interactive loan application process
+- ML-based credit assessment
+- Natural language conversation
+- Decision feedback
+- High anthropomorphism (warm, conversational)
+
+**Note**: This application is for research purposes only and does not make real credit decisions.
+
+## Research Context
+
+Part of the HicXAI research project investigating human-AI interaction in high-stakes decision-making contexts.
app.py
ADDED
@@ -0,0 +1,7 @@
+"""Entry point for Condition 2: E_none_A_high
+Explanation: none | Anthropomorphism: high"""
+import os, sys, streamlit as st
+os.environ['HICXAI_EXPLANATION'] = 'none'
+os.environ['HICXAI_ANTHRO'] = 'high'
+sys.path.append('src')
+exec(open('src/app.py').read())
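
The two environment variables set here are consumed by src/ab_config.py (shown later in this commit) when it resolves the experiment condition. A quick sanity-check sketch, assuming the commit's src/ modules and their dependencies are importable:

```python
import os, sys

# Mirror what app.py does before the Streamlit script runs.
os.environ['HICXAI_EXPLANATION'] = 'none'
os.environ['HICXAI_ANTHRO'] = 'high'

sys.path.append('src')
from ab_config import AppConfig

cfg = AppConfig()
print(cfg.condition_code())      # E_none_A_high (matches the docstring)
print(cfg.assistant_name)        # Luna, since anthropomorphism is high
print(cfg.show_any_explanation)  # False, since explanation is none
```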
data/adult.data
ADDED
The diff for this file is too large to render. See raw diff.
dataset_info/adult.json
ADDED
@@ -0,0 +1,11 @@
+{"name": "adult",
+"target_column": "income",
+"cat_features": [],
+"num_features": ["Age", "Hours per week", "Capital Gain", "Capital Loss"],
+"dataset_description": "The model predicts whether an individual's income is above or below $50,000 per year, based on their demographic and employment information.",
+"predict_prompt": ["Your profile is not so good. With this profile, your income will be <=50k", "Your profile looks good. With this profile, your income will be >50k"],
+"why_ans": "The above graph shows important features for this prediction. The red features increase the income, while the blue features decrease it.",
+"feature_ans": "The model used all features; however, some features may have a significant impact on the model's prediction for your profile.",
+"change_ans": ["income less than 50K", "income more than 50K"],
+"feature_description": {}
+}
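
src/agent.py below loads this file into self.data['info']. A standalone sketch of reading it; the 0/1 index convention for predict_prompt is inferred from the two-element list, not stated in the commit:

```python
import json

with open("dataset_info/adult.json") as f:
    info = json.load(f)

predicted_class = 1  # hypothetical index: 0 -> <=50k message, 1 -> >50k message
print(info["predict_prompt"][predicted_class])
print("numeric features:", info["num_features"])
```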
models/RandomForest.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af33b8ab3dea7a97870096ec4016d3bf326abe378282fc4b26e819ba1334618e
+size 180342936
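
The binary model files in this commit are stored as Git LFS pointers like the three lines above; the repository tracks only the pointer, and the ~180 MB pickle is fetched separately (e.g. via `git lfs pull`). A small sketch of parsing the pointer format:

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer file into its key/value fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:af33b8ab3dea7a97870096ec4016d3bf326abe378282fc4b26e819ba1334618e
size 180342936"""

info = parse_lfs_pointer(pointer)
print(info["oid"])        # content hash of the real file
print(int(info["size"]))  # 180342936 bytes, roughly 180 MB
```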
models/classifier_metadata.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:660c54444a0b35c4ce8e5525f91676def30abfe52c4cc0baef939fe504f8b1ee
+size 16016
models/focused_xai_classifier.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32b9be6c51a6b06008a92fd128ed9f0e23e6112d532ff34d3c293213ec86417a
+size 1194307
models/focused_xai_classifier_best.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:043789a066fafc86cf56e4fa587935d2203b0f458969d30bb164de4be36e0115
+size 1194397
models/focused_xai_classifier_metadata.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e89c3e39df40b7be4be10e6dc5ac575e12e67b5c9f1be13ea367905f373ab5fb
+size 685
models/focused_xai_label_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94db78cb2c7f97f9c16f825bf64fc34a7887a3d0485e6cc98e2534a3f1ddb380
+size 307
models/intent_classifier.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58046920169a8b175e11c3d269f008f70a9f232e273d443b93562f86397383a4
+size 2631079
models/intent_classifier_best.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73880a9032618add14f117cd96ecabb27cb15281d6da0765a4fbf2b053c9653c
+size 2631183
models/intent_classifier_metadata.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caac29c30aae73e3e943d391767b205b44cc02e93583dde0791c69957fe45a91
+size 470
models/intent_label_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27069a6922e00594edfb753ae19aa472567d23c15eb3def328b6ee708906b08a
+size 331
models/label_encoder.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44beda11a94d5ca1e65f5578a0a5d017cef27c9423b77580bc00746e4af61892
+size 660
models/model_metadata.json
ADDED
@@ -0,0 +1,108 @@
+{
+  "accuracy": 0.8593582066635959,
+  "feature_columns": [
+    "age",
+    "fnlwgt",
+    "education_num",
+    "capital_gain",
+    "capital_loss",
+    "hours_per_week",
+    "workclass_Local-gov",
+    "workclass_Never-worked",
+    "workclass_Private",
+    "workclass_Self-emp-inc",
+    "workclass_Self-emp-not-inc",
+    "workclass_State-gov",
+    "workclass_Unknown",
+    "workclass_Without-pay",
+    "education_11th",
+    "education_12th",
+    "education_1st-4th",
+    "education_5th-6th",
+    "education_7th-8th",
+    "education_9th",
+    "education_Assoc-acdm",
+    "education_Assoc-voc",
+    "education_Bachelors",
+    "education_Doctorate",
+    "education_HS-grad",
+    "education_Masters",
+    "education_Preschool",
+    "education_Prof-school",
+    "education_Some-college",
+    "marital_status_Married-AF-spouse",
+    "marital_status_Married-civ-spouse",
+    "marital_status_Married-spouse-absent",
+    "marital_status_Never-married",
+    "marital_status_Separated",
+    "marital_status_Widowed",
+    "occupation_Armed-Forces",
+    "occupation_Craft-repair",
+    "occupation_Exec-managerial",
+    "occupation_Farming-fishing",
+    "occupation_Handlers-cleaners",
+    "occupation_Machine-op-inspct",
+    "occupation_Other-service",
+    "occupation_Priv-house-serv",
+    "occupation_Prof-specialty",
+    "occupation_Protective-serv",
+    "occupation_Sales",
+    "occupation_Tech-support",
+    "occupation_Transport-moving",
+    "occupation_Unknown",
+    "relationship_Not-in-family",
+    "relationship_Other-relative",
+    "relationship_Own-child",
+    "relationship_Unmarried",
+    "relationship_Wife",
+    "race_Asian-Pac-Islander",
+    "race_Black",
+    "race_Other",
+    "race_White",
+    "sex_Male",
+    "native_country_Canada",
+    "native_country_China",
+    "native_country_Columbia",
+    "native_country_Cuba",
+    "native_country_Dominican-Republic",
+    "native_country_Ecuador",
+    "native_country_El-Salvador",
+    "native_country_England",
+    "native_country_France",
+    "native_country_Germany",
+    "native_country_Greece",
+    "native_country_Guatemala",
+    "native_country_Haiti",
+    "native_country_Holand-Netherlands",
+    "native_country_Honduras",
+    "native_country_Hong",
+    "native_country_Hungary",
+    "native_country_India",
+    "native_country_Iran",
+    "native_country_Ireland",
+    "native_country_Italy",
+    "native_country_Jamaica",
+    "native_country_Japan",
+    "native_country_Laos",
+    "native_country_Mexico",
+    "native_country_Nicaragua",
+    "native_country_Outlying-US(Guam-USVI-etc)",
+    "native_country_Peru",
+    "native_country_Philippines",
+    "native_country_Poland",
+    "native_country_Portugal",
+    "native_country_Puerto-Rico",
+    "native_country_Scotland",
+    "native_country_South",
+    "native_country_Taiwan",
+    "native_country_Thailand",
+    "native_country_Trinadad&Tobago",
+    "native_country_United-States",
+    "native_country_Unknown",
+    "native_country_Vietnam",
+    "native_country_Yugoslavia"
+  ],
+  "model_type": "RandomForestClassifier",
+  "n_estimators": 100,
+  "preprocessing": "preprocess_adult function applied"
+}
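
Since the classifier was trained on these one-hot columns, any new applicant row must be re-indexed to exactly this feature_columns order before calling predict. A sketch with a hypothetical, already-encoded row (the real app uses the preprocess_adult function named in the metadata):

```python
import json
import pandas as pd

with open("models/model_metadata.json") as f:
    meta = json.load(f)

# Hypothetical encoded applicant; unlisted dummy columns default to 0.
row = {"age": 35, "hours_per_week": 40, "education_Bachelors": 1, "sex_Male": 1}
X = pd.DataFrame([row]).reindex(columns=meta["feature_columns"], fill_value=0)

print(X.shape)  # (1, 100) - one row, all training features present and ordered
```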
models/xagent_classifier.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:350175471a46fac5ef197a470626de734ba84466dc66e21c3fe4a5be445b4191
+size 1104859
requirements.txt
CHANGED
@@ -1,3 +1,11 @@
-
-pandas
-
+streamlit>=1.31.0
+pandas>=2.2.0
+numpy>=2.0.0
+scikit-learn>=1.5.0
+matplotlib>=3.8.0
+shap>=0.45.0
+anchor-exp>=0.0.2
+Dice-ML>=0.10.0
+graphviz>=0.20.3
+dtreeviz>=2.2.2
+openai>=1.0.0
src/DATA_LOGGER_README.md
ADDED
@@ -0,0 +1,114 @@
+# Data Logging Module
+
+This module tracks all user interactions with the HicXAI loan assistant and saves data to a private GitHub repository.
+
+## Setup
+
+1. **Create GitHub Personal Access Token**:
+   - Go to GitHub Settings → Developer settings → Personal access tokens
+   - Create token with `repo` scope
+   - Add to your `.env` file as `GITHUB_DATA_TOKEN`
+
+2. **Private Repository**:
+   - Data is saved to: `https://github.com/ksauka/hicxai-data-private`
+   - Ensure the token has access to this repository
+
+## Data Collected
+
+### User Identification
+- Prolific ID (from query param `pid` or `PROLIFIC_PID`)
+- Condition (1-6, from query param `cond`)
+- Session ID (unique per session)
+- Timestamps (start, end, duration)
+
+### Application Data
+- All 12 loan application fields (age, education, occupation, etc.)
+- Final prediction (approved/denied)
+- Prediction probability
+
+### Interactions
+- Every user message (typed or clicked)
+- Every assistant response
+- Input method (typed vs button click)
+- Current field being collected
+- Conversation state
+
+### Behavior Metrics
+- Total messages sent
+- Typed vs clicked responses
+- Help button clicks
+- Explanation requests
+- Progress checks
+- Fields changed/corrected
+
+### Feedback
+- Rating (1-5 stars)
+- Ease of use
+- Explanation clarity
+- Would recommend
+- Free-text comments
+
+## File Structure
+
+Data is saved to:
+```
+sessions/
+  YYYY-MM-DD/
+    {prolific_id}_{condition}_{timestamp}.json
+```
+
+## Example Data
+
+```json
+{
+  "session_id": "abc123",
+  "prolific_id": "TEST123",
+  "condition": 2,
+  "ab_version": "control",
+  "timestamps": {
+    "session_start": "2025-11-28T10:30:00",
+    "session_end": "2025-11-28T10:33:45",
+    "duration_seconds": 225
+  },
+  "application_data": {
+    "age": 35,
+    "education": "Bachelors",
+    ...
+    "prediction": ">50K",
+    "prediction_probability": 0.73
+  },
+  "interactions": [
+    {
+      "timestamp": "2025-11-28T10:30:15",
+      "type": "user_message",
+      "field": "age",
+      "input_method": "typed",
+      "content": "35"
+    },
+    ...
+  ],
+  "behavior_metrics": {
+    "total_messages": 15,
+    "typed_responses": 8,
+    "clicked_responses": 7,
+    ...
+  },
+  "feedback": {
+    "rating": 4,
+    ...
+  }
+}
+```
+
+## Fallback
+
+If GitHub save fails (missing token, network error, etc.), data is saved locally to:
+```
+data/sessions/{date}_{prolific_id}_{condition}_{timestamp}.json
+```
+
+## Privacy
+
+- Data is saved to a **private** repository
+- Only accessible with the GitHub token
+- No personally identifiable information beyond Prolific ID
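
A sketch of how the session path described under File Structure could be assembled; the helper name is hypothetical, and the actual implementation lives in src/data_logger.py, which this view does not render:

```python
from datetime import datetime, timezone

def session_path(prolific_id: str, condition: int) -> str:
    """Build sessions/YYYY-MM-DD/{prolific_id}_{condition}_{timestamp}.json."""
    now = datetime.now(timezone.utc)
    date_dir = now.strftime("%Y-%m-%d")
    stamp = now.strftime("%Y%m%dT%H%M%S")
    return f"sessions/{date_dir}/{prolific_id}_{condition}_{stamp}.json"

print(session_path("TEST123", 2))
# e.g. sessions/2025-11-28/TEST123_2_20251128T103000.json
```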
src/ab_config.py
ADDED
@@ -0,0 +1,231 @@
+"""
+A/B Testing Configuration for HicXAI Agent
+This module configures experimental conditions for the live study.
+
+Experiment factors (3 × 2):
+- Explanation type: none | counterfactual | feature_importance
+- Anthropomorphism: low | high
+
+Backwards compatibility:
+- HICXAI_VERSION = v0 | v1 still works
+  v0 -> explanation=none, anthropomorphism=low
+  v1 -> explanation=feature_importance, anthropomorphism=high
+
+Environment variables (preferred) or CLI flags:
+- HICXAI_EXPLANATION = none | counterfactual | feature_importance
+- HICXAI_ANTHRO = low | high
+- HICXAI_VERSION = v0 | v1 (legacy)
+CLI flags:
+  --explanation=none|counterfactual|feature_importance
+  --anthro=low|high
+  --HICXAI_VERSION=v0|v1 or --v0 / --v1 or --ab=v0|v1
+"""
+
+import os
+import sys
+import uuid
+import time
+import streamlit as st
+
+_VALID_EXPLANATIONS = {"none", "counterfactual", "feature_importance"}
+_VALID_ANTHRO = {"low", "high"}
+
+
+class AppConfig:
+    """Configuration class for A/B testing versions and factor levels."""
+
+    def __init__(self):
+        # read factor levels (env and CLI), then derive UI toggles
+        self.explanation = self._get_explanation_level()  # none | counterfactual | feature_importance
+        self.anthro = self._get_anthropomorphism_level()  # low | high
+        self.version = self._legacy_version_label()  # v0 | v1 (for sidebar display only)
+        self.session_id = self._generate_session_id()  # unique session tracking
+
+        # derived feature flags for UI rendering, explanations, and logging
+        self.show_anthropomorphic = (self.anthro == "high")
+        self.show_profile_pic = self.show_anthropomorphic
+        self.show_shap_visualizations = (self.explanation == "feature_importance" and self.anthro == "high")
+        self.show_counterfactual = (self.explanation == "counterfactual")
+        self.show_any_explanation = (self.explanation != "none")
+
+        # assistant identity and copy are derived from anthropomorphism
+        self.assistant_name = "Luna" if self.show_anthropomorphic else "AI Assistant"
+        if self.show_anthropomorphic:
+            self.assistant_intro = "Your AI loan assistant, I will guide you step by step and explain what matters for your decision."
+        else:
+            self.assistant_intro = "AI system for loan decision support, explanations are provided according to your selection."
+
+        # data collection options
+        self.collect_feedback = True
+        self.show_debug_info = False  # keep False in production
+
+        # Legacy compatibility
+        self.use_full_features = self.show_any_explanation
+
+    # ------------- parsing helpers -------------
+
+    def _get_explanation_level(self):
+        """Resolve explanation factor from env or CLI, with legacy fallback."""
+        # env first
+        env_val = os.getenv("HICXAI_EXPLANATION", "").strip().lower()
+        if env_val in _VALID_EXPLANATIONS:
+            return env_val
+
+        # CLI flags
+        for arg in sys.argv[1:]:
+            if arg.startswith("--explanation="):
+                cand = arg.split("=", 1)[1].strip().lower()
+                if cand in _VALID_EXPLANATIONS:
+                    return cand
+
+        # legacy version mapping
+        legacy = os.getenv("HICXAI_VERSION", "").strip().lower()
+        cli_ver = self._cli_version_flag()
+        legacy = cli_ver or legacy
+        if legacy == "v1":
+            return "feature_importance"
+        if legacy == "v0":
+            return "none"
+
+        # default
+        return "none"
+
+    def _get_anthropomorphism_level(self):
+        """Resolve anthropomorphism factor from env or CLI, with legacy fallback."""
+        # env first
+        env_val = os.getenv("HICXAI_ANTHRO", "").strip().lower()
+        if env_val in _VALID_ANTHRO:
+            return env_val
+
+        # CLI flags
+        for arg in sys.argv[1:]:
+            if arg.startswith("--anthro="):
+                cand = arg.split("=", 1)[1].strip().lower()
+                if cand in _VALID_ANTHRO:
+                    return cand
+
+        # legacy version mapping
+        legacy = os.getenv("HICXAI_VERSION", "").strip().lower()
+        cli_ver = self._cli_version_flag()
+        legacy = cli_ver or legacy
+        if legacy == "v1":
+            return "high"
+        if legacy == "v0":
+            return "low"
+
+        # default
+        return "low"
+
+    def _cli_version_flag(self):
+        """Read legacy version flags from CLI to support existing scripts."""
+        for arg in sys.argv[1:]:
+            if arg in ("--v0", "--v1"):
+                return arg[2:]
+            if arg.startswith("--HICXAI_VERSION="):
+                cand = arg.split("=", 1)[1].strip().lower()
+                if cand in {"v0", "v1"}:
+                    return cand
+            if arg.startswith("--ab="):
+                cand = arg.split("=", 1)[1].strip().lower()
+                if cand in {"v0", "v1"}:
+                    return cand
+        return ""
+
+    def _legacy_version_label(self):
+        """Provide a simple label for the sidebar, does not affect factor levels."""
+        # map current factors to a human friendly tag
+        if self.explanation == "feature_importance" and self.anthro == "high":
+            return "v1"
+        if self.explanation == "none" and self.anthro == "low":
+            return "v0"
+        return "custom"
+
+    def _generate_session_id(self):
+        """Generate unique session ID for concurrent user tracking."""
+        return f"{self.condition_code()}_{int(time.time())}_{str(uuid.uuid4())[:8]}"
+
+    # ------------- public helpers for UI and logging -------------
+
+    def condition_code(self):
+        """
+        Compact code for logging and analysis.
+        Examples: E_none_A_low, E_cf_A_high, E_shap_A_high
+        """
+        e = {"none": "none", "counterfactual": "cf", "feature_importance": "shap"}[self.explanation]
+        a = {"low": "low", "high": "high"}[self.anthro]
+        return f"E_{e}_A_{a}"
+
+    def get_assistant_avatar(self):
+        """Return avatar path for high anthropomorphism, else None."""
+        if not self.show_profile_pic:
+            return None
+        possible_paths = [
+            "assets/luna_avatar.png",
+            "images/assistant_avatar.png",
+            "data_questions/Luna_is_a_Dutch_customer_service_assistant_working_at_a_restaurant_she_is_27_years_old_Please_genera.png",
+        ]
+        for path in possible_paths:
+            if os.path.exists(path):
+                return path
+        return None  # UI can fall back to initials
+
+    def get_welcome_message(self):
+        """Version specific welcome message for the chat header."""
+        if self.show_anthropomorphic:
+            return f"Hi, I am {self.assistant_name}. I will review your information and explain what factors influenced this loan decision."
+        return "Welcome, this AI credit assistant can review your information and show which factors influenced the decision."
+
+    def should_show_visual_explanations(self):
+        """Whether to render SHAP bars or equivalent visuals."""
+        return self.show_shap_visualizations
+
+    def should_show_counterfactuals(self):
+        """Whether to render counterfactual suggestions."""
+        return self.show_counterfactual
+
+    def explanation_style(self):
+        """Control tone for natural language explanations."""
+        return "conversational" if self.show_anthropomorphic else "technical"
+
+    def explanation_label(self):
+        """Human readable label for the assigned explanation type."""
+        if self.explanation == "none":
+            return "No explanation"
+        if self.explanation == "counterfactual":
+            return "Counterfactual explanation"
+        return "Feature importance explanation"
+
+    # Legacy compatibility methods
+    def get_explanation_style(self):
+        """Get explanation style based on version (alias for explanation_style)"""
+        return self.explanation_style()
+
+
+# ------------- sidebar debug -------------
+
+def show_debug_sidebar():
+    """Display condition and toggles for quick inspection."""
+    st.sidebar.write("### Experiment condition")
+    st.sidebar.write(f"Version tag: **{config.version}**")
+    st.sidebar.write(f"Condition: **{config.condition_code()}**")
+    st.sidebar.write(f"Assistant: **{config.assistant_name}**")
+    st.sidebar.write(f"Anthropomorphism: **{config.anthro}**")
+    st.sidebar.write(f"Explanation: **{config.explanation}**")
+    st.sidebar.write(f"Visual SHAP: {'✅' if config.show_shap_visualizations else '❌'}")
+    st.sidebar.write(f"Counterfactual: {'✅' if config.show_counterfactual else '❌'}")
+    st.sidebar.caption(f"Session ID: {config.session_id}")
+
+
+# Global config instance
+config = AppConfig()
+
+# NOTE: this second definition overrides the richer show_debug_sidebar above.
+def show_debug_sidebar():
+    """Display A/B testing debug info in sidebar"""
+    if config.version == "v1":
+        st.sidebar.success(f"🧪 A/B Test Version: **V1** (Full Features)")
+    else:
+        st.sidebar.info(f"🧪 A/B Test Version: **V0** (Minimal)")
+
+    st.sidebar.write(f"**Assistant:** {config.assistant_name}")
+    st.sidebar.write(f"**Visual SHAP:** {'✅' if config.show_shap_visualizations else '❌'}")
+    st.sidebar.write(f"**Anthropomorphic:** {'✅' if config.show_anthropomorphic else '❌'}")
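
The legacy mapping in the module docstring (v1 -> feature_importance + high) can be exercised directly. A sketch, noting that the env vars must be set before the module is imported because `config = AppConfig()` runs at import time:

```python
import os, sys

os.environ.pop("HICXAI_EXPLANATION", None)  # force the legacy fallback path
os.environ.pop("HICXAI_ANTHRO", None)
os.environ["HICXAI_VERSION"] = "v1"

sys.path.append("src")
from ab_config import AppConfig

cfg = AppConfig()
print(cfg.explanation)       # feature_importance
print(cfg.anthro)            # high
print(cfg.condition_code())  # E_shap_A_high
print(cfg.version)           # v1
```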
src/agent.py
ADDED
@@ -0,0 +1,300 @@
+import logging
+import json
+import random
+import re
+import os
+import pandas as pd
+import shap
+import sklearn
+import pickle
+from constraints import *
+from nlu import NLU
+from answer import Answers
+
+# Import natural conversation enhancer
+try:
+    from natural_conversation import enhance_response
+    NATURAL_CONVERSATION_AVAILABLE = True
+except ImportError:
+    NATURAL_CONVERSATION_AVAILABLE = False
+    def enhance_response(response, context=None, response_type="explanation"):
+        return response
+
+class Agent:
+    def __init__(self, nlu_model=None):
+        # Core state
+        self.dataset = "adult"
+        self.current_instance = None
+        self.clf = None
+        self.predicted_class = None
+        self.mode = None
+        self.data = {"X": None, "y": None, "features": None, "classes": None}
+
+        # NLU setup: prefer provided model, else use config, else default
+        config_path = os.path.join(os.path.dirname(__file__), 'nlu_config.json')
+        if nlu_model is not None:
+            self.nlu_model = nlu_model
+        elif os.path.exists(config_path):
+            with open(config_path, 'r') as f:
+                nlu_config = json.load(f)
+            self.nlu_model = NLU(model_type=nlu_config.get('model_type', 'sentence_transformers'), model_path=nlu_config.get('model_path'))
+        else:
+            self.nlu_model = NLU()
+
+        # UI/state helpers
+        self.list_node = []
+        self.clf_display = None
+        self.l_exist_classes = None
+        self.l_exist_features = None
+        self.l_instances = None
+        self.df_display_instance = None
+        self.current_feature = None
+        self.preprocessor = None
+
+        # Feature requirements for user input flows
+        self.required_features = [
+            'age', 'workclass', 'education', 'education_num', 'marital_status',
+            'occupation', 'relationship', 'race', 'sex', 'capital_gain',
+            'capital_loss', 'hours_per_week', 'native_country'
+        ]
+        self.user_features = {}
+
+        # Load data and train model (sets self.clf and self.clf_display)
+        self.load_adult_dataset()
+        self.train_model()
+
+    def load_adult_dataset(self):
+        data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'adult.data')
+        info_path = os.path.join(os.path.dirname(__file__), '..', 'dataset_info', 'adult.json')
+        columns = [
+            'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status',
+            'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss',
+            'hours_per_week', 'native_country', 'income'
+        ]
+        self.data['X_display'] = pd.read_csv(data_path, names=columns, skipinitialspace=True)
+        self.data['y_display'] = self.data['X_display']['income']
+        self.data['X_display'].drop(['income'], axis=1, inplace=True)
+        with open(info_path, 'r') as f:
+            self.data['info'] = json.load(f)
+        self.data['classes'] = ['<=50K', '>50K']
+        self.data['features'] = self.data['X_display'].columns.tolist()
+        self.data['feature_names'] = self.data['features']
+        self.data['map'] = {}
+
+    def train_model(self):
+        # Ensure model directory exists
+        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
+        os.makedirs(model_dir, exist_ok=True)
+        model_path = os.path.join(model_dir, 'RandomForest.pkl')
+        if os.path.exists(model_path):
+            try:
+                self.clf = pickle.load(open(model_path, 'rb'))
+                self.clf_display = self.clf
+            except Exception as e:
+                print(f"⚠️ Failed to load existing model ({e}). Retraining...")
+                from preprocessing import preprocess_adult
+                df = pd.concat([self.data['X_display'], self.data['y_display']], axis=1)
+                df_clean = preprocess_adult(df)
+                X = df_clean.drop('income', axis=1)
+                y = df_clean['income']
+                from sklearn.ensemble import RandomForestClassifier
+                clf = RandomForestClassifier(n_estimators=200, random_state=42)
+                clf.fit(X, y)
+                self.clf = clf
+                self.clf_display = clf
+                pickle.dump(clf, open(model_path, 'wb'))
+        else:
+            from preprocessing import preprocess_adult
+            df = pd.concat([self.data['X_display'], self.data['y_display']], axis=1)
+            df_clean = preprocess_adult(df)
+            X = df_clean.drop('income', axis=1)
+            y = df_clean['income']
+            from sklearn.ensemble import RandomForestClassifier
+            self.clf = RandomForestClassifier(n_estimators=100, random_state=42)
+            self.clf.fit(X, y)
+            # Persist the trained model for faster subsequent runs
+            with open(model_path, 'wb') as f:
+                pickle.dump(self.clf, f)
+            self.clf_display = self.clf
+
+    # (Removed duplicate __init__; initialization handled above)
+
+    def handle_user_input(self, user_input):
+        """Handle user input for XAI explanations (used by loan assistant for explanations)"""
+        # NOTE: redefined later in this file with an extra instance_df
+        # parameter; Python keeps the later definition.
+        # Step 1: Intent classification and XAI routing using enhanced NLU
+        try:
+            intent_result, confidence, suggestions = self.nlu_model.classify_intent(user_input)
+            from constraints import SUGGEST_SIMILAR_QUESTIONS_MSG, REPHRASE_QUESTION_MSG
+
+            # Route to appropriate XAI method based on intent
+            if isinstance(intent_result, dict) and 'intent' in intent_result:
+                # Ensure we have a current instance for explanation
+                if self.current_instance is None:
+                    self.select_random_instance()
+
+                # Import the routing function
+                try:
+                    from xai_methods import route_to_xai_method
+                    explanation_result = route_to_xai_method(self, intent_result)
+                    base_explanation = explanation_result.get('explanation', 'Sorry, I could not generate an explanation.')
+
+                    # Enhance with natural conversation if available
+                    if NATURAL_CONVERSATION_AVAILABLE:
+                        context = {
+                            'explanation_type': intent_result.get('intent', 'general'),
+                            'user_question': user_input,
+                            'confidence': intent_result.get('confidence', 0)
+                        }
+                        return enhance_response(base_explanation, context, "explanation")
+
+                    return base_explanation
+                except ImportError:
+                    # Fallback if routing function not available
+                    base_explanation = self._generate_basic_explanation(intent_result)
+
+                    # Enhance fallback explanation too
+                    if NATURAL_CONVERSATION_AVAILABLE:
+                        context = {
+                            'explanation_type': 'basic',
+                            'user_question': user_input,
+                            'confidence': 0.5
+                        }
+                        return enhance_response(base_explanation, context, "explanation")
+
+                    return base_explanation
+
+            elif intent_result == 'unknown' and suggestions:
+                suggestions_str = "\n".join([f"{idx}. {q}" for idx, q in enumerate(suggestions, 1)])
+                return SUGGEST_SIMILAR_QUESTIONS_MSG.format(suggestions=suggestions_str)
+            else:
+                return REPHRASE_QUESTION_MSG
+
+        except Exception as e:
+            return f"I'm having trouble processing that question. Could you try asking it differently? Error: {str(e)}"
+
+    def _generate_basic_explanation(self, intent_result):
+        """Generate basic explanation when XAI methods are not available"""
+        if self.current_instance is None or self.predicted_class is None:
+            return "I need a specific instance to explain. Please make sure a prediction has been made."
+
+        # Basic explanation based on the current instance
+        explanation = f"Based on your profile, the decision was: {self.predicted_class}\n\n"
+        explanation += "Key factors in this decision include:\n"
+
+        # Highlight some key features
+        key_features = ['age', 'education', 'hours_per_week', 'occupation', 'marital_status']
+        for feature in key_features:
+            if feature in self.current_instance:
+                value = self.current_instance[feature]
+                explanation += f"• {feature.replace('_', ' ').title()}: {value}\n"
+
+        explanation += "\nThis is a simplified explanation. For more detailed analysis, specific XAI methods would provide deeper insights."
+        return explanation
+
+    def select_random_instance(self):
+        """Select a random instance from the dataset for explanation"""
+        if self.data.get('X_display') is not None and len(self.data['X_display']) > 0:
+            random_idx = random.randint(0, len(self.data['X_display']) - 1)
+            self.df_display_instance = self.data['X_display'].iloc[[random_idx]]
+            self.current_instance = self.df_display_instance.iloc[0].to_dict()
+
+            # Make prediction for this instance
+            if self.clf_display is not None:
+                self.predicted_class = self.clf_display.predict(self.df_display_instance)[0]
+
+    def get_visualization(self, viz_type, instance_df=None):
+        """
+        Route advanced visualization requests to Answers class.
+        viz_type: 'shap_advanced' or 'dtreeviz'
+        instance_df: DataFrame for the instance to visualize
+        """
+        answers = Answers(
+            list_node=self.list_node,
+            clf=self.clf,
+            clf_display=self.clf_display,
+            current_instance=self.current_instance,
+            question=None,
+            l_exist_classes=self.l_exist_classes,
+            l_exist_features=self.l_exist_features,
+            l_instances=self.l_instances,
+            data=self.data,
+            df_display_instance=self.df_display_instance,
+            predicted_class=self.predicted_class,
+            preprocessor=self.preprocessor
+        )
+        return answers.answer(viz_type, instance_df=instance_df)
+
+    # NOTE: this definition overrides the simpler handle_user_input above.
+    def handle_user_input(self, user_input, instance_df=None):
+        # Step 1: Refined feature extraction using regex and synonyms
+        feature_synonyms = {
+            'age': ['age', 'years old'],
+            'workclass': ['workclass', 'work type', 'job type'],
+            'education': ['education', 'degree'],
+            'education_num': ['education num', 'education number', 'years of education'],
+            'marital_status': ['marital status', 'married', 'single', 'relationship status'],
+            'occupation': ['occupation', 'job', 'profession'],
+            'relationship': ['relationship'],
+            'race': ['race', 'ethnicity'],
+            'sex': ['sex', 'gender'],
+            'capital_gain': ['capital gain', 'gain'],
+            'capital_loss': ['capital loss', 'loss'],
+            'hours_per_week': ['hours per week', 'weekly hours', 'work hours'],
+            'native_country': ['native country', 'country', 'nationality']
+        }
+        # Try to extract feature-value pairs from user input
+        for feature, synonyms in feature_synonyms.items():
+            for syn in synonyms:
+                pattern = rf"{syn}[:=]?\s*([\w\-\+]+)"
+                match = re.search(pattern, user_input, re.IGNORECASE)
+                if match:
+                    self.user_features[feature] = match.group(1)
+        # Check for missing features
+        from constraints import CLARIFY_FEATURE_MSG
+        missing = [f for f in self.required_features if f not in self.user_features]
+        if missing:
+            next_feat = missing[0]
+            return CLARIFY_FEATURE_MSG.format(feature=next_feat.replace('_', ' '))
+        # Step 2: Robust validation using adult dataset metadata
+        from constraints import REPEAT_NUM_FEATURES, REPEAT_CAT_FEATURES
+        info = self.data.get('info', {})
+        for feature in self.required_features:
+            value = self.user_features.get(feature)
+            if value is None:
+                continue
+            # Numeric validation
+            if feature in info.get('num_features', []):
+                try:
+                    val = float(value)
+                    minv, maxv = info.get('feature_ranges', {}).get(feature, (None, None))
+                    if minv is not None and (val < minv or val > maxv):
+                        del self.user_features[feature]
+                        return REPEAT_NUM_FEATURES.format(f"{minv}-{maxv}")
+                except Exception:
+                    del self.user_features[feature]
+                    return REPEAT_NUM_FEATURES.format("valid number")
+            # Categorical validation
+            if feature in info.get('cat_features', []):
+                valid = info.get('feature_values', {}).get(feature, [])
+                if valid and value not in valid:
+                    del self.user_features[feature]
+                    return REPEAT_CAT_FEATURES.format(", ".join(valid))
+        # Step 3: Intent classification and XAI routing using enhanced NLU
+        intent_result, confidence, suggestions = self.nlu_model.classify_intent(user_input)
+        from constraints import SUGGEST_SIMILAR_QUESTIONS_MSG, REPHRASE_QUESTION_MSG
+        from xai_methods import route_to_xai_method
+        # Route to appropriate XAI method based on intent
+        if isinstance(intent_result, dict) and 'intent' in intent_result:
+            if self.current_instance is None:
+                self.select_random_instance()
+            # Advanced visualization intents
+            if intent_result['intent'] in ['shap_advanced', 'dtreeviz']:
+                return self.get_visualization(intent_result['intent'], instance_df)
+            # Standard explanation routing
+            explanation_result = route_to_xai_method(self, intent_result)
+            return explanation_result.get('explanation', 'Sorry, I could not generate an explanation.')
+        elif intent_result == 'unknown' and suggestions:
+            suggestions_str = "\n".join([f"{idx}. {q}" for idx, q in enumerate(suggestions, 1)])
+            return SUGGEST_SIMILAR_QUESTIONS_MSG.format(suggestions=suggestions_str)
+        else:
+            return REPHRASE_QUESTION_MSG
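
A minimal interaction sketch for the Agent class; it assumes the commit's src/ modules, data/adult.data, and the model pickle are all present locally, and that NLU.classify_intent (defined in src/nlu.py, not rendered here) returns the (intent, confidence, suggestions) triple the code above unpacks:

```python
import sys
sys.path.append("src")
from agent import Agent

# Construction loads adult.data and the pickled RandomForest,
# retraining one if the pickle is missing or unreadable.
agent = Agent()

# Pick a dataset row so there is something to explain, then route a
# free-text question through NLU intent classification to an XAI method.
agent.select_random_instance()
print(agent.predicted_class)  # '<=50K' or '>50K'
print(agent.handle_user_input("Why was this decision made?"))
```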
src/answer.py
ADDED
@@ -0,0 +1,58 @@
+# Migrated and adapted from XAgent/Agent/answer.py for adult-only use
+import pandas as pd
+import os
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+from xai_methods import (
+    explain_with_shap, explain_with_dice, explain_with_anchor,
+    explain_with_shap_advanced, explain_with_dtreeviz
+)
+from constraints import *
+
+class Answers:
+    def __init__(self, list_node, clf, clf_display, current_instance, question, l_exist_classes, l_exist_features,
+                 l_instances, data, df_display_instance, predicted_class, preprocessor=None):
+        self.list_node = list_node
+        self.clf = clf
+        self.clf_display = clf_display
+        self.question = question
+        self.current_instance = current_instance
+        self.l_exist_classes = l_exist_classes
+        self.l_exist_features = l_exist_features
+        self.l_instances = l_instances
+        self.l_classes = data['classes']
+        self.l_features = data['features']
+        self.data = data
+        self.df_display_instance = df_display_instance
+        self.predicted_class = predicted_class
+        self.preprocessor = preprocessor
+
+    def answer(self, intent, conversations=[], instance_df=None, **kwargs):
+        """
+        Route to the correct XAI method based on dynamic intent/label from NLU.
+        intent: predicted label from NLU (e.g., 'predict', 'shap_explain', 'dice_explain', 'anchor_explain', 'cf_proto', 'shap_advanced', 'dtreeviz')
+        """
+        if intent == 'predict':
+            return f"Based on your input, the predicted income is {self.predicted_class}."
+        elif intent == 'shap_explain':
+            return explain_with_shap(self)
+        elif intent == 'dice_explain':
+            return explain_with_dice(self)
+        elif intent == 'anchor_explain':
+            return explain_with_anchor(self)
+        elif intent == 'cf_proto':
+            # CounterfactualProto (alibi) removed; optionally replace with dice-ml or handle gracefully
+            return None
+        elif intent == 'shap_advanced':
+            if instance_df is not None:
+                return explain_with_shap_advanced(self, instance_df)
+            else:
+                return {'type': 'error', 'explanation': 'No instance provided for SHAP advanced.'}
+        elif intent == 'dtreeviz':
+            if instance_df is not None:
+                return explain_with_dtreeviz(self, instance_df)
+            else:
+                return {'type': 'error', 'explanation': 'No instance provided for dtreeviz.'}
+        else:
+            return "Sorry, I can't answer that question yet."
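
Only the fields a given intent branch touches need real values, so the 'predict' and error paths can be exercised without a trained model. A sketch (an Answers instance would normally be built by Agent.get_visualization above; it assumes the src/ modules and their dependencies are importable):

```python
import sys
sys.path.append("src")
from answer import Answers

data = {"classes": ["<=50K", ">50K"], "features": []}
ans = Answers(list_node=[], clf=None, clf_display=None, current_instance=None,
              question=None, l_exist_classes=None, l_exist_features=None,
              l_instances=None, data=data, df_display_instance=None,
              predicted_class=">50K")

print(ans.answer("predict"))
# Based on your input, the predicted income is >50K.
print(ans.answer("shap_advanced"))
# {'type': 'error', 'explanation': 'No instance provided for SHAP advanced.'}
```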
src/app.py
ADDED
|
@@ -0,0 +1,1183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st

# Load environment variables from .env file
import env_loader

# Configure page FIRST - before any other Streamlit commands
st.set_page_config(page_title="AI Loan Assistant - Credit Pre-Assessment", layout="wide")

# Hide Streamlit branding for anonymous review (CSS + JavaScript)
st.markdown("""
<style>
/* ===== COMPREHENSIVE STREAMLIT BRANDING REMOVAL ===== */

/* Hide header elements */
#MainMenu {visibility: hidden !important;}
header {visibility: hidden !important;}
[data-testid="stHeader"] {display: none !important;}
[data-testid="stToolbar"] {display: none !important;}
[data-testid="stDecoration"] {display: none !important;}
[data-testid="stStatusWidget"] {display: none !important;}
button[kind="header"] {display: none !important;}

/* Hide footer elements - ALL variations */
footer {visibility: hidden !important; display: none !important;}
[data-testid="stFooter"] {display: none !important;}
footer[data-testid="stFooter"] {display: none !important;}
div[role="contentinfo"] {display: none !important;}
[class*="footer"] {display: none !important;}
[class*="Footer"] {display: none !important;}

/* Hide deploy/manage buttons */
[data-testid="manage-app-button"] {display: none !important;}
.stAppDeployButton {display: none !important;}
.stDeployButton {display: none !important;}

/* ===== HIDE ALL CREATOR ATTRIBUTION ===== */

/* Text links to creator profile */
a[href*="streamlit.io"] {display: none !important;}
a[href*="share.streamlit.io/user"] {display: none !important;}
a[href*="/user/ksauka"] {display: none !important;}
a[target="_blank"][href^="https://share.streamlit.io"] {display: none !important;}

/* Image/Avatar links to creator profile */
a[href*="streamlit.io"] img {display: none !important;}
a[href*="share.streamlit.io"] img {display: none !important;}
a img[src*="avatar"] {display: none !important;}
a img[src*="profile"] {display: none !important;}
img[alt*="creator"] {display: none !important;}
img[alt*="author"] {display: none !important;}

/* Viewer badge containers and links */
.viewerBadge_link__qRIco {display: none !important;}
.viewerBadge_link__Ua7HT {display: none !important;}
.viewerBadge_container__r5tak {display: none !important;}
.viewerBadge_container__2QSob {display: none !important;}
a.viewer-badge {display: none !important;}
[class*="viewerBadge"] {display: none !important;}
[class*="ViewerBadge"] {display: none !important;}

/* Profile/Avatar elements */
[class*="avatar"] {display: none !important;}
[class*="Avatar"] {display: none !important;}
[class*="profile"] {display: none !important;}
[class*="Profile"] {display: none !important;}
[data-testid*="avatar"] {display: none !important;}
[data-testid*="profile"] {display: none !important;}

/* Any div containing creator attribution at bottom of page */
div[class*="creator"] {display: none !important;}
div[class*="author"] {display: none !important;}
div[class*="attribution"] {display: none !important;}

/* Catch-all: any link in bottom 100px of page pointing to streamlit.io */
body > div:last-child a[href*="streamlit.io"] {display: none !important;}
.main > div:last-child a[href*="streamlit.io"] {display: none !important;}

/* Nuclear option: hide entire bottom-most div if it contains streamlit links */
div:has(a[href*="streamlit.io"]) {display: none !important;}

/* Disable pointer events on any remaining visible elements */
a[href*="streamlit.io"],
a[href*="share.streamlit.io"],
img[src*="avatar"],
img[src*="profile"] {
    pointer-events: none !important;
    cursor: default !important;
    display: none !important;
}

/* Remove padding after footer removal */
section.main > div {padding-bottom: 0 !important;}

/* Legacy class hiding */
.css-1v0mbdj {display: none !important;}
</style>

<script>
// JavaScript to forcefully remove Streamlit branding (runs continuously)
(function() {
    function removeStreamlitBranding() {
        // Remove footer elements
        const footers = document.querySelectorAll('footer, [data-testid="stFooter"], [class*="footer"], [class*="Footer"]');
        footers.forEach(el => el.remove());

        // Remove header elements
        const headers = document.querySelectorAll('header, [data-testid="stHeader"], #MainMenu');
        headers.forEach(el => el.remove());

        // Remove any links to streamlit.io
        const streamlitLinks = document.querySelectorAll('a[href*="streamlit.io"], a[href*="share.streamlit.io"]');
        streamlitLinks.forEach(el => el.remove());

        // Remove viewer badges
        const badges = document.querySelectorAll('[class*="viewerBadge"], [class*="ViewerBadge"], .viewer-badge');
        badges.forEach(el => el.remove());

        // Remove avatars and profile images
        const avatars = document.querySelectorAll('[class*="avatar"], [class*="Avatar"], [class*="profile"], [class*="Profile"]');
        avatars.forEach(el => {
            // Only remove if it's in a link to streamlit
            const parent = el.closest('a');
            if (parent && parent.href && parent.href.includes('streamlit.io')) {
                parent.remove();
            }
        });

        // Remove any div that contains streamlit links
        const allLinks = document.querySelectorAll('a[href*="streamlit.io"]');
        allLinks.forEach(link => {
            const container = link.closest('div');
            if (container) {
                container.remove();
            }
        });
    }

    // Run immediately
    removeStreamlitBranding();

    // Run every 500ms to catch dynamically added elements
    setInterval(removeStreamlitBranding, 500);

    // Also run on DOM changes
    const observer = new MutationObserver(removeStreamlitBranding);
    observer.observe(document.body, { childList: true, subtree: true });
})();
</script>

<meta name="robots" content="noindex, nofollow">
""", unsafe_allow_html=True)

# ===== QUALTRICS/PROLIFIC INTEGRATION (robust final) =====
import time
from urllib.parse import unquote, urlparse, parse_qsl, urlencode, urlunparse

def _get_query_params():
    try:
        # Streamlit ≥1.32
        return dict(st.query_params)
    except Exception:
        try:
            # Older Streamlit
            return st.experimental_get_query_params()
        except Exception:
            return {}

def _as_str(v):
    if isinstance(v, list):
        return v[0] if v else ""
    return v if isinstance(v, str) else ""

def _is_safe_return(ru: str) -> bool:
    """Allow https/http + any *.qualtrics.com netloc (handles regional subdomains)."""
    if not ru:
        return False
    try:
        d = unquote(ru)
        # tolerate missing scheme (rare). Qualtrics links should always be https
        if not d.startswith(("http://", "https://")):
            d = "https://" + d
        p = urlparse(d)
        return (p.scheme in ("http", "https")) and ("qualtrics.com" in p.netloc)
    except Exception:
        return False

def _build_final_return(done=True):
    """
    Start with the encoded Qualtrics 'return' URL, decode once,
    ensure it points to Qualtrics, then append pid/cond/done IFF missing.
    """
    rr = st.session_state.get("return_raw", "")
    if not rr or not _is_safe_return(rr):
        return None

    decoded = unquote(rr)
    # normalize scheme if missing (defensive)
    if not decoded.startswith(("http://", "https://")):
        decoded = "https://" + decoded

    p = urlparse(decoded)
    q = dict(parse_qsl(p.query, keep_blank_values=True))

    # only add if not already present
    pid_ss = st.session_state.get("pid", "")
    cond_ss = st.session_state.get("cond", "")
    prolific_pid_ss = st.session_state.get("prolific_pid", "")

    if "pid" not in q and pid_ss: q["pid"] = pid_ss
    if "cond" not in q and cond_ss: q["cond"] = cond_ss
    if "PROLIFIC_PID" not in q and prolific_pid_ss: q["PROLIFIC_PID"] = prolific_pid_ss
    if "done" not in q: q["done"] = "1" if done else "0"

    return urlunparse(p._replace(query=urlencode(q, doseq=True)))

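# A minimal standalone sketch of the append-if-missing behavior above:
# _demo_build_return is a hypothetical helper for unit-testing the URL
# handling outside Streamlit; it is not used by the app, and the sample
# URL in the usage note below is made up.
def _demo_build_return(raw, cond="", done=True):
    decoded = unquote(raw)
    if not decoded.startswith(("http://", "https://")):
        decoded = "https://" + decoded
    p = urlparse(decoded)
    q = dict(parse_qsl(p.query, keep_blank_values=True))
    if "cond" not in q and cond:
        q["cond"] = cond
    q.setdefault("done", "1" if done else "0")
    return urlunparse(p._replace(query=urlencode(q, doseq=True)))

# _demo_build_return("https%3A%2F%2Feu.qualtrics.com%2Fjfe%2Fform%2FSV_x%3Fpid%3Dabc", cond="v1")
# -> "https://eu.qualtrics.com/jfe/form/SV_x?pid=abc&cond=v1&done=1"
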
# -------------- read & persist params once --------------
_qs = _get_query_params()
_pid_in = _as_str(_qs.get("pid", ""))
_cond_in = _as_str(_qs.get("cond", ""))
_ret_in = _as_str(_qs.get("return", ""))
# Prolific standard parameter
_prolific_pid = _as_str(_qs.get("PROLIFIC_PID", ""))

if "pid" not in st.session_state and _pid_in:
    st.session_state.pid = _pid_in
if "cond" not in st.session_state and _cond_in:
    st.session_state.cond = _cond_in
if "return_raw" not in st.session_state and _ret_in:
    st.session_state.return_raw = _ret_in
# Store Prolific ID separately for research tracking
if "prolific_pid" not in st.session_state and _prolific_pid:
    st.session_state.prolific_pid = _prolific_pid

# boolean flag for UI (sticky footer etc.)
st.session_state.has_return_url = bool(st.session_state.get("return_raw", ""))  # always recompute

# one-shot redirect latch
if "_returned" not in st.session_state:
    st.session_state._returned = False

def back_to_survey(done_flag=True):
    """Single exit path. Call on button click or timeout."""
    if st.session_state._returned:
        return
    final = _build_final_return(done=done_flag)
    if not final:
        st.warning("Return link missing or invalid. Please use your browser Back button.")
        return
    st.session_state._returned = True
    # immediate redirect – robust & no loops
    st.markdown(f'<meta http-equiv="refresh" content="0;url={final}">', unsafe_allow_html=True)
    st.stop()

# handle previously latched redirect (e.g., if Streamlit re-renders mid-redirect)
if st.session_state.get("_returned"):
    final = _build_final_return(done=True)
    if final:
        st.markdown(f'<meta http-equiv="refresh" content="0;url={final}">', unsafe_allow_html=True)
        st.stop()

# set the 3-minute deadline once and track start time
if "deadline_ts" not in st.session_state:
    st.session_state.deadline_ts = time.time() + 180
    st.session_state.start_time = time.time()  # Track when user started

# fire auto-return when time is up (exactly once)
if time.time() >= st.session_state.deadline_ts:
    back_to_survey(done_flag=True)

# expose the function for UI buttons
st.session_state.back_to_survey = back_to_survey

# Prevent restart via browser refresh/back ONLY if user had already started
# Check if this is a fresh session (first visit) vs a refresh (had chat history)
if "loan_assistant" not in st.session_state and st.session_state.get("return_raw"):
    # Only redirect if they had already started (had chat history marker)
    if st.session_state.get("application_started", False):
        # User refreshed or went back after starting - redirect to survey
        back_to_survey(done_flag=True)

# ===== END QUALTRICS/PROLIFIC INTEGRATION =====

# Now import everything else
from agent import Agent
from nlu import NLU
from answer import Answers
from github_saver import save_to_github
from loan_assistant import LoanAssistant
from ab_config import config
from shap_visualizer import display_shap_explanation, explain_shap_visualizations
from data_logger import init_logger
from xai_methods import get_friendly_feature_name
import os
import pandas as pd

# Initialize data logger
logger = init_logger()

# Define field options for quick selection (based on actual Adult dataset analysis)
field_options = {
    'workclass': ['Private', 'Self-emp-not-inc', 'Self-emp-inc', 'Federal-gov', 'Local-gov', 'State-gov', 'Without-pay', 'Never-worked', '?'],
    'education': ['Bachelors', 'HS-grad', 'Masters', 'Some-college', 'Assoc-acdm', 'Assoc-voc', '11th', '9th', '10th', '12th', '7th-8th', 'Doctorate', '1st-4th', '5th-6th', 'Preschool', 'Prof-school'],
    'marital_status': ['Married-civ-spouse', 'Divorced', 'Never-married', 'Separated', 'Widowed', 'Married-spouse-absent', 'Married-AF-spouse'],
    'occupation': ['Tech-support', 'Craft-repair', 'Other-service', 'Sales', 'Exec-managerial', 'Prof-specialty', 'Handlers-cleaners', 'Machine-op-inspct', 'Adm-clerical', 'Farming-fishing', 'Armed-Forces', 'Priv-house-serv', 'Protective-serv', 'Transport-moving', '?'],
    'sex': ['Male', 'Female'],
    'race': ['Black', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'White', 'Other'],
    'native_country': ['United-States', 'Cambodia', 'Canada', 'China', 'Columbia', 'Cuba', 'Dominican-Republic', 'Ecuador', 'El-Salvador', 'England', 'France', 'Germany', 'Greece', 'Guatemala', 'Haiti', 'Holand-Netherlands', 'Honduras', 'Hong', 'Hungary', 'India', 'Iran', 'Ireland', 'Italy', 'Jamaica', 'Japan', 'Laos', 'Mexico', 'Nicaragua', 'Outlying-US(Guam-USVI-etc)', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto-Rico', 'Scotland', 'South', 'Taiwan', 'Thailand', 'Trinadad&Tobago', 'Vietnam', 'Yugoslavia', '?'],
    'relationship': ['Wife', 'Own-child', 'Husband', 'Not-in-family', 'Other-relative', 'Unmarried']
}

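# Hypothetical helper (not used elsewhere in this file) sketching how
# field_options can back case-insensitive validation of typed answers;
# the app's real parsing lives in loan_assistant/nlu.
def _match_option(field, text):
    """Return the canonical option matching `text`, or None."""
    for option in field_options.get(field, []):
        if text.strip().lower() == option.lower():
            return option
    return None
# e.g. _match_option('workclass', 'private') -> 'Private'
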
# Streamlit compatibility function
def st_rerun():
    """Compatibility function for Streamlit rerun across versions"""
    if hasattr(st, 'rerun'):
        st.rerun()
    else:
        st.experimental_rerun()

# Custom CSS for better appearance with chat bubbles
st.markdown("""
<style>
.chat-container {
    max-height: 600px;
    overflow-y: auto;
    padding: 1rem;
    background: linear-gradient(135deg, #e3f2fd 0%, #f8f9fa 100%);
    border-radius: 15px;
    margin: 1rem 0;
    border: 1px solid #e0e0e0;
}
.chat-message {
    display: flex;
    margin: 0.8rem 0;
    align-items: flex-end;
    clear: both;
}
.user-message {
    justify-content: flex-end;
    flex-direction: row-reverse;
}
.assistant-message {
    justify-content: flex-start;
    flex-direction: row;
}
.message-bubble {
    padding: 10px 14px;
    border-radius: 18px;
    max-width: 65%;
    word-wrap: break-word;
    box-shadow: 0 1px 2px rgba(0,0,0,0.1);
    position: relative;
    line-height: 1.4;
    font-size: 14px;
}
.user-bubble {
    background: #007bff;
    color: white;
    border-bottom-right-radius: 4px;
    margin-right: 8px;
}
.user-bubble::after {
    content: '';
    position: absolute;
    right: -8px;
    bottom: 0;
    width: 0;
    height: 0;
    border-left: 8px solid #007bff;
    border-bottom: 8px solid transparent;
}
.assistant-bubble {
    background: white;
    color: #333;
    border: 1px solid #e0e0e0;
    border-bottom-left-radius: 4px;
    margin-left: 8px;
}
.assistant-bubble::after {
    content: '';
    position: absolute;
    left: -9px;
    bottom: 0;
    width: 0;
    height: 0;
    border-right: 8px solid white;
    border-bottom: 8px solid transparent;
    border-top: 1px solid transparent;
}
.assistant-bubble::before {
    content: '';
    position: absolute;
    left: -10px;
    bottom: 0;
    width: 0;
    height: 0;
    border-right: 8px solid #e0e0e0;
    border-bottom: 8px solid transparent;
}
.profile-pic {
    width: 40px;
    height: 40px;
    border-radius: 50%;
    margin: 0 5px;
    border: 2px solid #fff;
    box-shadow: 0 1px 3px rgba(0,0,0,0.2);
    flex-shrink: 0;
}
.user-icon {
    width: 45px;
    height: 40px;
    border-radius: 50%;
    background: #007bff;
    display: flex;
    align-items: center;
    justify-content: center;
    color: white;
    font-weight: bold;
    font-size: 11px;
    margin: 0 5px;
    box-shadow: 0 1px 3px rgba(0,0,0,0.2);
    flex-shrink: 0;
}
.progress-bar {
    background-color: #e9ecef;
    border-radius: 10px;
    padding: 3px;
    border: 1px solid #dee2e6;
}
.progress-fill {
    background: linear-gradient(135deg, #007bff 0%, #0056b3 100%);
    height: 22px;
    border-radius: 7px;
    text-align: center;
    line-height: 22px;
    color: white;
    font-weight: bold;
    font-size: 12px;
    box-shadow: 0 2px 4px rgba(0,123,255,0.2);
}
.status-card {
    background-color: #f8f9fa;
    padding: 1rem;
    border-radius: 0.5rem;
    border-left: 4px solid #007bff;
    margin: 0.5rem 0;
}
.luna-intro {
    display: flex;
    align-items: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 1rem;
    border-radius: 15px;
    margin: 1rem 0;
    box-shadow: 0 4px 10px rgba(0,0,0,0.1);
}
.luna-intro img {
    width: 60px;
    height: 60px;
    border-radius: 50%;
    margin-right: 15px;
    border: 3px solid white;
}
.option-button {
    background: #f8f9fa;
    border: 1px solid #dee2e6;
    border-radius: 8px;
    padding: 8px 12px;
    margin: 3px;
    cursor: pointer;
    transition: all 0.2s ease;
    font-size: 13px;
    color: #495057;
    display: inline-block;
}
.option-button:hover {
    background: #e9ecef;
    border-color: #007bff;
    color: #007bff;
    transform: translateY(-1px);
    box-shadow: 0 2px 4px rgba(0,123,255,0.2);
}
.option-button:active {
    background: #007bff;
    color: white;
    transform: translateY(0);
}
.options-container {
    background: #f8f9fa;
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    border: 1px solid #e9ecef;
}
</style>
""", unsafe_allow_html=True)

def initialize_system():
    """Initialize the agent and all components"""
    try:
        agent = Agent()
        answers = Answers(
            list_node=[],
            clf=agent.clf,
            clf_display=agent.clf_display,
            current_instance=agent.current_instance,
            question=None,
            l_exist_classes=agent.l_exist_classes,
            l_exist_features=agent.l_exist_features,
            l_instances=agent.l_instances,
            data=agent.data,
            df_display_instance=agent.df_display_instance,
            predicted_class=agent.predicted_class,
            preprocessor=agent.preprocessor
        )
        return agent, answers
    except Exception as e:
        st.error(f"Failed to initialize system: {str(e)}")
        st.error("Please check the console for more details.")
        import traceback
        st.code(traceback.format_exc())
        # Return None values to prevent further errors
        return None, None

# Initialize system
if 'agent' not in st.session_state:
    st.session_state.agent, st.session_state.answers = initialize_system()

# Check if initialization was successful
if st.session_state.agent is None:
    st.error("System initialization failed. Please check the error messages above and try refreshing the page.")
    st.stop()

agent = st.session_state.agent
answers = st.session_state.answers

# Initialize loan assistant
if 'loan_assistant' not in st.session_state:
    st.session_state.loan_assistant = LoanAssistant(agent)
    st.session_state.chat_history = []

# App header
st.title("🏦 AI Loan Assistant - Credit Pre-Assessment")

# Assistant Introduction (A/B testing)
assistant_avatar = config.get_assistant_avatar()
if assistant_avatar and os.path.exists(assistant_avatar):
    import base64
    with open(assistant_avatar, "rb") as f:
        avatar_pic_b64 = base64.b64encode(f.read()).decode()

    st.markdown(f"""
    <div class="luna-intro">
        <img src="data:image/png;base64,{avatar_pic_b64}" alt="{config.assistant_name}">
        <div>
            <h3 style="margin: 0; color: white;">Hi! I'm {config.assistant_name}</h3>
            <p style="margin: 5px 0 0 0; opacity: 0.9;">{config.assistant_intro}</p>
        </div>
    </div>
    """, unsafe_allow_html=True)
else:
    # Fallback without image
    st.markdown(f"""
    <div class="luna-intro">
        <div style="width: 60px; height: 60px; border-radius: 50%; margin-right: 15px; border: 3px solid white; background: #f093fb; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 24px;">{config.assistant_name[0]}</div>
        <div>
            <h3 style="margin: 0; color: white;">Hi! I'm {config.assistant_name}</h3>
            <p style="margin: 5px 0 0 0; opacity: 0.9;">{config.assistant_intro}</p>
        </div>
    </div>
    """, unsafe_allow_html=True)

# Single conversational interface
st.markdown("---")

# Sidebar - keep minimal to avoid distracting from experimental task
with st.sidebar:
    # No restart option - users should complete one application per session
    # Explanation style is controlled by the experimental condition, not user choice

    # A/B Testing Debug Info (only for development/testing - hidden from users)
    # Uncomment the lines below only when debugging A/B testing locally
    # if config.show_debug_info and os.getenv('HICXAI_DEBUG_MODE', 'false').lower() == 'true':
    # What‑if Lab (shown after user asks what-if in counterfactual HIGH anthropomorphism conditions only)
    if config.show_counterfactual and config.show_anthropomorphic and getattr(st.session_state.loan_assistant, 'show_what_if_lab', False):
        st.markdown("---")
        st.subheader("🧪 What‑if Lab")
        st.caption("Adjust inputs to see how the predicted probability changes.")

        # Prepare a baseline instance from current app state if available
        app_state = st.session_state.loan_assistant.application

        def default(v, fallback):
            return v if v is not None else fallback

        # Core numerics
        age = st.slider("Age", min_value=17, max_value=90, value=int(default(app_state.age, 35)))
        hours = st.slider("Hours per week", min_value=1, max_value=99, value=int(default(app_state.hours_per_week, 40)))
        gain = st.number_input("Capital Gain", min_value=0, max_value=99999, step=100, value=int(default(app_state.capital_gain, 0)))
        loss = st.number_input("Capital Loss", min_value=0, max_value=4356, step=50, value=int(default(app_state.capital_loss, 0)))

        # Categorical selectors using known field options
        edu = st.selectbox("Education", options=field_options['education'], index=field_options['education'].index(default(app_state.education, 'HS-grad')))
        occ = st.selectbox("Occupation", options=field_options['occupation'], index=field_options['occupation'].index(default(app_state.occupation, 'Sales')))
        workclass = st.selectbox("Workclass", options=field_options['workclass'], index=field_options['workclass'].index(default(app_state.workclass, 'Private')))
        marital = st.selectbox("Marital Status", options=field_options['marital_status'], index=field_options['marital_status'].index(default(app_state.marital_status, 'Never-married')))
        relationship = st.selectbox("Relationship", options=field_options['relationship'], index=field_options['relationship'].index(default(app_state.relationship, 'Not-in-family')))
        sex = st.selectbox("Sex", options=field_options['sex'], index=field_options['sex'].index(default(app_state.sex, 'Male')))
        race = st.selectbox("Race", options=field_options['race'], index=field_options['race'].index(default(app_state.race, 'White')))
        country = st.selectbox("Native Country", options=field_options['native_country'], index=field_options['native_country'].index(default(app_state.native_country, 'United-States')))

        # Build a hypothetical instance and predict
        try:
            # Start from existing application dict (fill minimal defaults)
            hypo = app_state.to_dict()
            hypo['age'] = age
            hypo['hours_per_week'] = hours
            hypo['education'] = edu
            hypo['occupation'] = occ
            hypo['workclass'] = workclass
            hypo['marital_status'] = marital
            hypo['relationship'] = relationship
            hypo['sex'] = sex
            hypo['race'] = race
            hypo['native_country'] = country
            hypo['capital_gain'] = gain
            hypo['capital_loss'] = loss
            if hypo.get('education_num') is None:
                edu_map = {
                    'Preschool': 1, '1st-4th': 2, '5th-6th': 3, '7th-8th': 4, '9th': 5,
                    '10th': 6, '11th': 7, '12th': 8, 'HS-grad': 9, 'Some-college': 10,
                    'Assoc-voc': 11, 'Assoc-acdm': 12, 'Bachelors': 13, 'Masters': 14,
                    'Prof-school': 15, 'Doctorate': 16
                }
                hypo['education_num'] = edu_map.get(edu, 9)
            # Ensure required fields have plausible defaults
            hypo.setdefault('workclass', 'Private')
            hypo.setdefault('marital_status', 'Never-married')
            hypo.setdefault('relationship', 'Not-in-family')
            hypo.setdefault('race', 'White')
            hypo.setdefault('sex', 'Male')
            hypo.setdefault('capital_gain', 0)
            hypo.setdefault('capital_loss', 0)
            hypo.setdefault('native_country', 'United-States')

            import pandas as pd
            app_df = pd.DataFrame([hypo])
            app_df['income'] = '<=50K'  # dummy
            from preprocessing import preprocess_adult
            processed = preprocess_adult(app_df)
            X = processed.drop('income', axis=1)
            # Align with training features
            train_df = pd.concat([agent.data['X_display'], agent.data['y_display']], axis=1)
            train_df_processed = preprocess_adult(train_df)
            expected = train_df_processed.drop('income', axis=1).columns.tolist()
            for col in expected:
                if col not in X.columns:
                    X[col] = 0
            X = X[expected]
            # Predict probability if available
            prob = None
            if hasattr(agent.clf_display, 'predict_proba'):
                p = agent.clf_display.predict_proba(X)
                # Assume class index 1 corresponds to '>50K'
                prob = float(p[0][1]) if p.shape[1] > 1 else float(p[0][0])
            st.metric(label="Estimated P(>50K)", value=f"{(prob if prob is not None else 0.5)*100:.1f}%")

            # Optional: refresh SHAP visuals for hypo profile (textual SHAP for now)
            # We keep visuals in the main flow; here we just indicate changes
            st.caption("Adjust inputs to explore their impact. Use chat for detailed explanations and visuals.")
        except Exception as e:
            st.caption(f"What‑if Lab unavailable: {e}")
    # Otherwise, no What‑if panel is shown until triggered by user
    # st.markdown("---")
    # st.markdown("**🧪 Debug Info**")
    # st.markdown(f"Version: **{config.version}**")
    # st.markdown(f"Assistant: **{config.assistant_name}**")
    # st.markdown(f"SHAP Visuals: **{config.show_shap_visualizations}**")

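# Note on the What‑if Lab's column-alignment loop above: pandas can do the
# same in one call, adding missing dummy columns with a fill value and
# enforcing column order (an equivalent sketch, same `X` and `expected`):
#
#     X = X.reindex(columns=expected, fill_value=0)
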
# Chat interface - Display chat history with enhanced bubbles
st.markdown('<div class="chat-container">', unsafe_allow_html=True)

for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
    # User message (right side, blue bubble)
    if user_msg:
        st.markdown(f"""
        <div class="chat-message user-message">
            <div class="user-icon">You</div>
            <div class="message-bubble user-bubble">
                {user_msg}
            </div>
        </div>
        """, unsafe_allow_html=True)

    # Assistant message with profile picture (left side, white bubble)
    if assistant_msg:
        assistant_avatar = config.get_assistant_avatar()
        if assistant_avatar and os.path.exists(assistant_avatar):
            import base64
            with open(assistant_avatar, "rb") as f:
                avatar_pic_b64 = base64.b64encode(f.read()).decode()
            avatar_pic_element = f'<img src="data:image/png;base64,{avatar_pic_b64}" class="profile-pic" alt="{config.assistant_name}">'
        else:
            avatar_pic_element = f'<div class="profile-pic" style="background: #f093fb; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; font-size: 16px;">{config.assistant_name[0]}</div>'

        st.markdown(f"""
        <div class="chat-message assistant-message">
            {avatar_pic_element}
            <div class="message-bubble assistant-bubble">
                {assistant_msg}
            </div>
        </div>
        """, unsafe_allow_html=True)

st.markdown('</div>', unsafe_allow_html=True)

# Initialize with welcome message
if len(st.session_state.chat_history) == 0:
    welcome_msg = st.session_state.loan_assistant.handle_message("hello")
    st.session_state.chat_history.append((None, welcome_msg))
    st_rerun()

# Chat input (form enables Enter-to-send and clears on submit automatically)
# Check if current field has clickable options for placeholder
current_field = getattr(st.session_state.loan_assistant, 'current_field', None)
if current_field and current_field in field_options:
    placeholder_text = "💬 Type your answer or use the clickable buttons below..."
else:
    placeholder_text = "Type your message to Luna..."

with st.form("chat_form", clear_on_submit=True):
    col1, col2 = st.columns([5, 1])
    with col1:
        user_message = st.text_input("Message to Luna", key="user_input", placeholder=placeholder_text, label_visibility="collapsed")
    with col2:
        send_button = st.form_submit_button("Send", use_container_width=True)

# Add helper text for clickable features
if current_field and current_field in field_options:
    st.markdown('<div style="text-align: center; color: #666; font-size: 0.85em; margin-top: 5px;">👆 Use the clickable buttons below for faster selection!</div>', unsafe_allow_html=True)

# Show clickable options right after chat input (for immediate visibility)
if current_field and current_field in field_options:
    st.markdown("---")
    st.markdown(f"### 🎯 Quick Select: {current_field.replace('_', ' ').title()}")
    st.markdown("**💡 Click any option below instead of typing:**")
    st.markdown('<div class="options-container">', unsafe_allow_html=True)

    options = field_options[current_field]

    # Create buttons in rows with enhanced styling
    cols_per_row = 4 if len(options) > 8 else 3
    for i in range(0, len(options), cols_per_row):
        cols = st.columns(cols_per_row)
        for j, option in enumerate(options[i:i+cols_per_row]):
            with cols[j]:
                # Get friendly name for display
                friendly_option = get_friendly_feature_name(f"{current_field}_{option}")
                # If no mapping found, clean up the technical name
                if friendly_option.startswith(current_field.title()):
                    friendly_option = option.replace('-', ' ').replace('_', ' ')

                # Enhanced button styling based on option type
                if option == "Other":
                    button_text = f"🔄 {friendly_option}"
                    button_type = "primary"
                elif option == "?":
                    button_text = f"❓ Unknown/Prefer not to say"
                    button_type = "primary"
                elif option in ["Male", "Female"]:
                    button_text = f"👤 {friendly_option}"
                    button_type = "secondary"
                elif option == "United-States":
                    button_text = f"🇺🇸 {friendly_option}"
                    button_type = "primary"
                elif option in ["Private", "Self-emp-not-inc", "Self-emp-inc"]:
                    button_text = f"💼 {friendly_option}"
                    button_type = "secondary"
                elif "gov" in option.lower():
                    button_text = f"🏛️ {friendly_option}"
                    button_type = "secondary"
                else:
                    button_text = f"✨ {friendly_option}"
                    button_type = "secondary"

                if st.button(button_text, key=f"option_top_{current_field}_{option}", use_container_width=True, type=button_type):
                    st.session_state.option_clicked = option
                    st_rerun()

    st.markdown('</div>', unsafe_allow_html=True)
    st.markdown("*💬 Or you can still type your answer in the chat box above*")

# Process user input
if send_button and user_message:
    # Mark that user has started the application
    st.session_state.application_started = True

    # Log interaction
    if logger:
        current_field = getattr(st.session_state.loan_assistant, 'current_field', None)
        logger.log_interaction("user_message", {
            "field": current_field,
            "input_method": "typed",
            "content": user_message,
            "conversation_state": st.session_state.loan_assistant.conversation_state.value
        })

    # Handle the message through loan assistant
    assistant_response = st.session_state.loan_assistant.handle_message(user_message)

    # Log assistant response
    if logger:
        logger.log_interaction("assistant_response", {
            "content": assistant_response
        })

    # Add to chat history (form clears input on submit)
    st.session_state.chat_history.append((user_message, assistant_response))
    st_rerun()

# Handle option clicks
if 'option_clicked' in st.session_state and st.session_state.option_clicked:
    option_value = st.session_state.option_clicked

    # Mark that user has started the application
    st.session_state.application_started = True

    # Log interaction
    if logger:
        current_field = getattr(st.session_state.loan_assistant, 'current_field', None)
        logger.log_interaction("user_message", {
            "field": current_field,
            "input_method": "clicked",
            "content": option_value,
            "conversation_state": st.session_state.loan_assistant.conversation_state.value
        })

    assistant_response = st.session_state.loan_assistant.handle_message(option_value)

    # Log assistant response
    if logger:
        logger.log_interaction("assistant_response", {
            "content": assistant_response
        })

    # Add to chat history
    st.session_state.chat_history.append((option_value, assistant_response))
    st.session_state.option_clicked = None  # Reset
    st_rerun()

# Persistent SHAP visuals section: render when feature_importance explanation is enabled
if config.show_shap_visualizations:
    shap_data = getattr(st.session_state.loan_assistant, 'last_shap_result', None)
    if shap_data:
        st.markdown("---")
        st.subheader("🔎 Visual Explanations")
        display_shap_explanation(shap_data)
        explain_shap_visualizations()

# Quick reply buttons based on current state
st.markdown("---")
st.markdown("**Quick Replies:**")

current_state = st.session_state.loan_assistant.conversation_state.value

if current_state == 'greeting':
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("👋 Start Application", key="quick_start"):
            response = st.session_state.loan_assistant.handle_message("start")
            st.session_state.chat_history.append(("start", response))
            st_rerun()

elif current_state == 'collecting_info':
    col1, col2, col3 = st.columns(3)
    with col1:
        if st.button("Check Progress", key="quick_progress"):
            if logger:
                logger.log_interaction("progress_check", {})
            response = st.session_state.loan_assistant.handle_message("review")
            st.session_state.chat_history.append(("check progress", response))
            st_rerun()
    with col2:
        if st.button("Help", key="quick_help"):
            if logger:
                logger.log_interaction("help_click", {})
            # Get context-aware help
            current_field = getattr(st.session_state.loan_assistant, 'current_field', None)
            if current_field:
                help_msg = st.session_state.loan_assistant._get_field_help(current_field)
                help_msg += f"\n\n💡 **You can also:**\n• Say 'review' to see your progress\n• Click the quick-select buttons below\n• Ask for specific examples"
            else:
                help_msg = ("I'm collecting information for your loan application. Please answer the questions "
                            "as accurately as possible. You can say 'review' to see your progress.")
            st.session_state.chat_history.append(("help", help_msg))
            st_rerun()

elif current_state == 'complete':
    # Only show What-If button in Condition 4 (HIGH anthropomorphism + counterfactual)
    if config.show_counterfactual and config.show_anthropomorphic:
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Explain Decision", key="quick_explain", use_container_width=True):
                if logger:
                    logger.log_interaction("explanation_request", {"type": "decision_explanation"})
                response = st.session_state.loan_assistant.handle_message("explain")
                st.session_state.chat_history.append(("explain", response))
                st_rerun()
        with col2:
            if st.button("🔧 What If Analysis", key="quick_whatif", use_container_width=True):
                # Turn on What‑if Lab and prompt guidance
                try:
                    st.session_state.loan_assistant.show_what_if_lab = True
                except Exception:
                    pass
                response = "What‑if Lab enabled in the sidebar. Adjust Age, Hours, Education, or Occupation to see how the probability changes."
                st.session_state.chat_history.append(("what if analysis", response))
                st_rerun()
    else:
        # Show only Explain button for other conditions
        if st.button("Explain Decision", key="quick_explain", use_container_width=True):
            if logger:
                logger.log_interaction("explanation_request", {"type": "decision_explanation"})
            response = st.session_state.loan_assistant.handle_message("explain")
            st.session_state.chat_history.append(("explain", response))
            st_rerun()

# Clickable Options for Current Field (if collecting info)
if current_state == 'collecting_info' and hasattr(st.session_state.loan_assistant, 'current_field') and st.session_state.loan_assistant.current_field:
    current_field = st.session_state.loan_assistant.current_field

    if current_field in field_options:
        st.markdown("---")
        st.markdown(f"### 🎯 Quick Select: {current_field.replace('_', ' ').title()}")
        st.markdown("**💡 Click any option below instead of typing:**")
        st.markdown('<div style="background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%); padding: 15px; border-radius: 10px; margin: 10px 0; border: 1px solid #dee2e6;">', unsafe_allow_html=True)

        options = field_options[current_field]

        # Create buttons in rows with enhanced styling
        cols_per_row = 4 if len(options) > 8 else 3
        for i in range(0, len(options), cols_per_row):
            cols = st.columns(cols_per_row)
            for j, option in enumerate(options[i:i+cols_per_row]):
                with cols[j]:
                    # Enhanced button styling based on option type
                    # Get friendly name for display
                    friendly_option = get_friendly_feature_name(f"{current_field}_{option}")
                    # If no mapping found, use the option as-is
                    if friendly_option.startswith(current_field.title()):
                        friendly_option = option.replace('-', ' ').replace('_', ' ')

                    if option == "Other":
                        button_text = f"🔄 {friendly_option}"
                        button_type = "primary"
                    elif option == "?":
                        button_text = f"❓ Unknown/Prefer not to say"
                        button_type = "primary"
                    elif option in ["Male", "Female"]:
                        button_text = f"👤 {friendly_option}"
                        button_type = "secondary"
                    elif option == "United-States":
                        button_text = f"🇺🇸 {friendly_option}"
                        button_type = "primary"
                    elif option in ["Private", "Self-emp-not-inc", "Self-emp-inc"]:
                        button_text = f"💼 {friendly_option}"
                        button_type = "secondary"
                    elif "gov" in option.lower():
                        button_text = f"🏛️ {friendly_option}"
                        button_type = "secondary"
                    else:
                        button_text = f"✨ {friendly_option}"
                        button_type = "secondary"

                    if st.button(button_text, key=f"option_{current_field}_{option}", use_container_width=True, type=button_type):
                        st.session_state.option_clicked = option
                        st_rerun()

        st.markdown('</div>', unsafe_allow_html=True)
        st.markdown("*💬 Or you can still type your answer in the chat box above*")

# Feedback section (appears after application is complete)
if current_state == 'complete' and len(st.session_state.chat_history) > 5:
    st.markdown("---")
    st.markdown("### 📝 Your Feedback")
    st.markdown("Help us improve by sharing your experience:")

    with st.form("feedback_form"):
        col1, col2 = st.columns(2)

        with col1:
            rating = st.select_slider(
                "How would you rate your experience?",
                options=[1, 2, 3, 4, 5],
                value=3,
                format_func=lambda x: "⭐" * x
            )

            ease_of_use = st.radio(
                "Was the application process easy to understand?",
                ["Very Easy", "Easy", "Neutral", "Difficult", "Very Difficult"]
            )

        with col2:
            explanation_clarity = st.radio(
                "Were the AI explanations helpful?",
                ["Very Helpful", "Helpful", "Neutral", "Not Helpful", "Confusing"]
            )

            would_recommend = st.radio(
                "Would you recommend this service?",
                ["Definitely", "Probably", "Maybe", "Probably Not", "Definitely Not"]
            )

        feedback_text = st.text_area(
            "Additional comments (optional):",
            placeholder="“What feature would help you most next time?”\n“What would make this agent's explanations more useful?”..."
        )

        submitted = st.form_submit_button("Submit Feedback 🚀")

        if submitted:
            # Calculate completion percentage
            completion = st.session_state.loan_assistant.application.calculate_completion()

            feedback_data = {
                "rating": rating,
                "ease_of_use": ease_of_use,
                "explanation_clarity": explanation_clarity,
                "would_recommend": would_recommend,
                "additional_comments": feedback_text,
                "conversation_length": len(st.session_state.chat_history),
                "completion_percentage": completion,
                # A/B Testing metadata
                "ab_version": config.version,
                "session_id": config.session_id,
                "assistant_name": config.assistant_name,
                "had_shap_visualizations": config.show_shap_visualizations,
                "timestamp": pd.Timestamp.now().isoformat()
            }

            # Log feedback to data logger
            if logger:
                logger.set_feedback(feedback_data)

            # Save feedback
            try:
                # Try GitHub first (if configured)
                github_token = os.getenv('GITHUB_TOKEN')
                github_repo = os.getenv('GITHUB_REPO', 'your-username/your-repo')

                if github_token:
                    import json
                    timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                    filename = f"feedback/session_{config.session_id}_{timestamp}.json"

                    success = save_to_github(
                        repo=github_repo,
                        path=filename,
                        content=json.dumps(feedback_data, indent=2),
                        commit_message=f"User feedback - {config.version} - {timestamp}",
                        github_token=github_token
                    )

                    if success:
                        st.success("Thank you for your feedback! 🎉")
                        st.session_state.feedback_submitted = True
                    else:
                        raise Exception("GitHub save failed")
                else:
                    raise Exception("No GitHub token configured")

            except Exception as e:
                st.warning("Feedback saved locally. Thank you!")
                st.session_state.feedback_submitted = True

                # Fallback: save to local file
                import json
                os.makedirs('feedback', exist_ok=True)
                timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
                filename = f"feedback/session_{config.session_id}_{timestamp}.json"

                with open(filename, "w") as f:
                    f.write(json.dumps(feedback_data, indent=2))

# Show "Continue to survey" button OUTSIDE the form (alternate after feedback)
# Only show after 2 minutes to ensure user engagement
if st.session_state.get("feedback_submitted", False) and st.session_state.get("return_raw"):
    elapsed_time = time.time() - st.session_state.get("start_time", time.time())
    if elapsed_time >= 120:  # 2 minutes = 120 seconds
        st.markdown("---")
        if st.button("✅ Continue to survey", type="primary", use_container_width=True, key="feedback_return"):
            back_to_survey()
    else:
        remaining = int(120 - elapsed_time)
        st.markdown("---")
        st.info(f"⏱️ Please interact with the application. Continue button will appear in {remaining} seconds.")

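# The fallback branch above leaves one JSON file per session under
# feedback/. A small stdlib-only sketch for collating them during analysis
# (a hypothetical helper, not part of the app flow):
#
#     import glob, json
#     records = [json.load(open(p)) for p in sorted(glob.glob("feedback/session_*.json"))]
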
| 1098 |
+
# Footer with dataset information
|
| 1099 |
+
st.markdown("---")
|
| 1100 |
+
st.markdown("""
|
| 1101 |
+
<div style='text-align: center; color: #666; padding: 20px;'>
|
| 1102 |
+
<p>🏦 AI Loan Assistant</p>
|
| 1103 |
+
<p><small>🔬 Algorithm trained on the Adult (Census Income) dataset with 32,561 records from the UCI Machine Learning Repository</small></p>
|
| 1104 |
+
</div>
|
| 1105 |
+
""", unsafe_allow_html=True)
|
| 1106 |
+
|
| 1107 |
+
# Expandable dataset details
|
| 1108 |
+
with st.expander("📊 Dataset Information - Adult Census Income Dataset"):
|
| 1109 |
+
st.markdown("""
|
| 1110 |
+
**Dataset Overview:**
|
| 1111 |
+
|
| 1112 |
+
The Adult Census Income Dataset is a popular benchmark dataset from the UCI Machine Learning Repository,
|
| 1113 |
+
sometimes referred to as the Census Income or Adult dataset. It includes **32,561 records** and **15 attributes**,
|
| 1114 |
+
each representing a person's social, employment, and demographic information. The dataset originates from the
|
| 1115 |
+
U.S. Census database from 1994.
|
| 1116 |
+
|
| 1117 |
+
**Prediction Task:**
|
| 1118 |
+
|
| 1119 |
+
The main goal is to determine whether an individual makes more than $50,000 per year based on their attributes.
The income is the target variable with two possible classes:
- **≤50K**: Income less than or equal to $50,000
- **>50K**: Income greater than $50,000

**Dataset Features:**

The dataset contains both categorical and numerical attributes:

- **Age**: Numerical value indicating the person's age
- **Workclass**: Type of employment (Private sector, Self-employed, Federal/Local/State government, etc.)
- **Education / Education-num**: Highest education level (High school graduate, Bachelor's, Master's, Doctorate, etc.)
- **Marital-status**: Marital status (Married, Divorced, Never married, Separated, Widowed, etc.)
- **Occupation**: Work area (Professional, Sales, Administrative, Tech support, Management, etc.)
- **Relationship**: Family role (Husband, Wife, Own-child, Not-in-family, Other-relative, Unmarried)
- **Race**: Ethnic background (White, Asian-Pacific Islander, Indigenous American, Black, Other)
- **Sex**: Gender (Male, Female)
- **Capital-gain / Capital-loss**: Investment gains or losses
- **Hours-per-week**: Number of working hours per week
- **Native-country**: Country of origin (42 countries including United States, Canada, Mexico, Philippines, India, China, Germany, England, and many others)
- **Income**: Target label (≤50K or >50K)

**Model Performance:**

Our trained RandomForest classifier achieves **85.94% accuracy** on this dataset.
""")

# A/B testing debug info (development only - hidden from users).
# Shown only when the HICXAI_DEBUG_MODE environment variable is set to 'true'.
if os.getenv('HICXAI_DEBUG_MODE', 'false').lower() == 'true':
    st.markdown("---")
    st.markdown("### 🧪 A/B Testing Information (Debug Mode)")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown(f"**Version:** {config.version}")
        st.markdown(f"**Session ID:** {config.session_id}")
    with col2:
        st.markdown(f"**Assistant:** {config.assistant_name}")
        st.markdown(f"**SHAP Visuals:** {config.show_shap_visualizations}")
    with col3:
        st.markdown("**Concurrent Testing:** ✅ Enabled")
        st.markdown("**User Isolation:** ✅ Session-based")

# Sticky return footer (the continue button only appears after 1 minute of engagement)
if st.session_state.get("return_raw"):
    elapsed_time = time.time() - st.session_state.get("start_time", time.time())

    if elapsed_time >= 60:  # 1 minute = 60 seconds
        st.markdown("---")
        col_a, col_b = st.columns([3, 1])
        with col_a:
            remaining = max(0, int(st.session_state.deadline_ts - time.time()))
            m, s = divmod(remaining, 60)
            st.caption(f"⏱️ Up to {m}:{s:02d} remaining. You can return anytime.")
        with col_b:
            if st.button("✅ Continue to survey", type="primary", use_container_width=True, key="footer_return"):
                back_to_survey()
    else:
        # Show a countdown until the continue button appears
        st.markdown("---")
        wait_time = int(60 - elapsed_time)
        m, s = divmod(wait_time, 60)
        remaining_deadline = max(0, int(st.session_state.deadline_ts - time.time()))
        md, sd = divmod(remaining_deadline, 60)
        st.caption(f"⏱️ Session time: up to {md}:{sd:02d} remaining • Continue button appears in: {m}:{s:02d}")
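A minimal sketch of the prediction task described above, using the models/RandomForest.pkl file added in this commit. It is an assumption here, not confirmed by the app code, that the pickled model expects the same one-hot columns that load_adult_data (added later in this commit) produces, and that the snippet lives in src/; the 0.5 decision threshold is also illustrative:

# Hedged sketch: score one applicant with the pickled RandomForest
import os
import pickle
from load_adult_data import load_adult_data

here = os.path.dirname(__file__)
df, meta = load_adult_data(os.path.join(here, '..', 'data'), discretize=True)
X_row = df.drop(columns=['income']).iloc[[0]]  # one encoded applicant

with open(os.path.join(here, '..', 'models', 'RandomForest.pkl'), 'rb') as f:
    model = pickle.load(f)

proba = model.predict_proba(X_row)[0]
print(f"P(>50K) = {proba[1]:.2f} ->", ">50K" if proba[1] >= 0.5 else "<=50K")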
src/constraints.py
ADDED
@@ -0,0 +1,53 @@

# Flexible, template-based constraint messages for dynamic, model-driven NLU

WELCOME_MSG = "Welcome to the HicXAI agent! Ask me about the model's predictions."
DATASET_ERROR_MSG = "I only support the adult dataset. Please type a correct name."
WAIT_MSG = "Wait a moment, I need to learn it."
RECORD_INFO_MSG = "I recorded the information: {}."
PREDICT_MSG = "You have: {}."
QUESTION_MSG = (
    "You can ask me questions about a machine learning model, such as: \n"
    "Why was the prediction made? \nWhy was Y not predicted? \n"
    "What should change in order to make prediction Y? \nPlease type your question."
)
REPHRASE_QUESTION_MSG = "Sorry, I don't understand your question. Please rephrase it."
NO_CF_MSG = "Sorry, I couldn't find a way to modify {} to change the label."
CANT_ANSWER_MSG = "I am not capable of answering your question. Questions of this type cannot currently be answered by an explainable AI method."
REPEAT_CAT_FEATURES = "The input value is not valid; please choose one of the following values: {}."
REPEAT_NUM_FEATURES = "The input value is not valid; please type a value in the range: {}."
REQUEST_NUMBER_MSG = "That is not a valid number. Please choose another number."

# Dynamic clarification/feedback templates (filled in by the agent/NLU at runtime)
CLARIFY_FEATURE_MSG = "What is your {feature}?"
CLARIFY_AMBIGUOUS_MSG = "I detected ambiguity in your input: {detail}. Could you clarify?"
SUGGEST_SIMILAR_QUESTIONS_MSG = (
    "I'm not sure I understood. Did you mean one of these?\n{suggestions}\nPlease type the number of the closest question, or rephrase your question."
)

# XAI method routing constants (adopted from XAgent)
L_SHAP_QUESTION_IDS = [3, 5, 6, 8, 26, 67, 69]
L_SHAP_QUESTION_FEATURE = [3, 5, 69]
L_SHAP_QUESTION_SINGLE_FEATURE = [6]
L_DICE_QUESTION_IDS = [11, 12, 14, 71]
L_DICE_QUESTION_RELATION_IDS = [71]
L_ANCHOR_QUESTION_IDS = [20, 15, 13]
L_FEATURE_QUESTIONS_IDS = [6, 12]
L_NEW_PREDICT_QUESTION_IDS = [64]
L_SUPPORT_QUESTIONS_IDS = L_SHAP_QUESTION_IDS + L_DICE_QUESTION_IDS + L_ANCHOR_QUESTION_IDS

# Intent to XAI method mapping
INTENT_TO_XAI_METHOD = {
    "feature_importance": "shap",
    "counterfactual": "dice",
    "local_explanation": "anchor",
    "prototype": "cfproto",
    "what_if": "interactive"
}

# Example usage in agent/NLU:
# msg = CLARIFY_FEATURE_MSG.format(feature='age')
# method = INTENT_TO_XAI_METHOD.get(intent, "unknown")
# msg = CLARIFY_AMBIGUOUS_MSG.format(detail='multiple possible occupations')
# msg = SUGGEST_SIMILAR_QUESTIONS_MSG.format(suggestions='1. ...\n2. ...')
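A short sketch of how these templates and routing tables might be wired together in an agent turn; `detected_intent` and the candidate questions below are hypothetical values, not taken from the repo:

# Hypothetical wiring of the constraint templates into an agent turn
from constraints import (
    SUGGEST_SIMILAR_QUESTIONS_MSG,
    INTENT_TO_XAI_METHOD,
    L_SHAP_QUESTION_IDS,
)

detected_intent = "feature_importance"  # e.g., output of the intent classifier
method = INTENT_TO_XAI_METHOD.get(detected_intent, "unknown")  # -> "shap"

candidates = ["Why was the prediction made?", "Which feature mattered most?"]
suggestions = "\n".join(f"{i + 1}. {q}" for i, q in enumerate(candidates))
print(SUGGEST_SIMILAR_QUESTIONS_MSG.format(suggestions=suggestions))

# Question-ID routing: IDs in L_SHAP_QUESTION_IDS are answered via SHAP
assert 3 in L_SHAP_QUESTION_IDS and method == "shap"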
src/data_logger.py
ADDED
@@ -0,0 +1,211 @@

"""
Data Logger for HicXAI Research
Tracks user interactions, application data, and behavior metrics
Saves to private GitHub repository: hicxai-data-private
"""

import base64
import json
import os
from datetime import datetime
from typing import Optional, Dict, Any, List
import streamlit as st
import requests


class DataLogger:
    """Logs user interactions and saves them to a private GitHub repository"""

    def __init__(self, prolific_id: str, condition: int, session_id: str):
        self.prolific_id = prolific_id
        self.condition = condition
        self.session_id = session_id
        self.session_start = datetime.now().isoformat()

        self.interactions: List[Dict] = []
        self.application_data: Dict = {}
        self.feedback_data: Optional[Dict[str, Any]] = None
        self.behavior_metrics = {
            "total_messages": 0,
            "typed_responses": 0,
            "clicked_responses": 0,
            "help_clicks": 0,
            "explanation_requests": 0,
            "progress_checks": 0,
            "fields_changed": 0
        }

    def log_interaction(self, interaction_type: str, content: Dict[str, Any]):
        """Log a single interaction event"""
        self.interactions.append({
            "timestamp": datetime.now().isoformat(),
            "type": interaction_type,
            **content
        })

        # Update behavior metrics
        if interaction_type == "user_message":
            self.behavior_metrics["total_messages"] += 1
            if content.get("input_method") == "typed":
                self.behavior_metrics["typed_responses"] += 1
            elif content.get("input_method") == "clicked":
                self.behavior_metrics["clicked_responses"] += 1
        elif interaction_type == "help_click":
            self.behavior_metrics["help_clicks"] += 1
        elif interaction_type == "explanation_request":
            self.behavior_metrics["explanation_requests"] += 1
        elif interaction_type == "progress_check":
            self.behavior_metrics["progress_checks"] += 1

    def update_application_data(self, field: str, value: Any):
        """Update application field data"""
        if field in self.application_data and self.application_data[field] != value:
            self.behavior_metrics["fields_changed"] += 1
        self.application_data[field] = value

    def set_prediction(self, prediction: str, probability: float):
        """Set the final prediction result"""
        self.application_data["prediction"] = prediction
        self.application_data["prediction_probability"] = probability

    def set_feedback(self, feedback_data: Dict[str, Any]):
        """Set feedback data"""
        self.feedback_data = feedback_data

    def build_final_data(self) -> Dict[str, Any]:
        """Build the complete data structure for saving"""
        session_end = datetime.now().isoformat()
        start_dt = datetime.fromisoformat(self.session_start)
        end_dt = datetime.fromisoformat(session_end)
        duration = (end_dt - start_dt).total_seconds()

        # Get A/B testing info
        try:
            from ab_config import config
            ab_version = config.version
            assistant_name = config.assistant_name
            has_shap = config.show_shap_visualizations
        except Exception:
            ab_version = "unknown"
            assistant_name = "unknown"
            has_shap = False

        return {
            "session_id": self.session_id,
            "prolific_id": self.prolific_id,
            "condition": self.condition,
            "ab_version": ab_version,
            "assistant_name": assistant_name,
            "has_shap_visualizations": has_shap,
            "timestamps": {
                "session_start": self.session_start,
                "session_end": session_end,
                "duration_seconds": duration
            },
            "application_data": self.application_data,
            "interactions": self.interactions,
            "behavior_metrics": self.behavior_metrics,
            "feedback": self.feedback_data
        }

    def save_to_github(self) -> bool:
        """Save data to the private GitHub repository"""
        # Try Streamlit secrets first, then fall back to an env variable (for local dev)
        try:
            github_token = st.secrets.get("GITHUB_DATA_TOKEN") or st.secrets.get("GITHUB_TOKEN")
        except Exception:
            github_token = os.getenv('GITHUB_TOKEN')

        if not github_token:
            # Fall back to a local save
            return self._save_local()

        try:
            repo = "ksauka/hicxai-data-private"
            date_str = datetime.now().strftime('%Y-%m-%d')
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"sessions/{date_str}/{self.prolific_id}_{self.condition}_{timestamp}.json"

            data = self.build_final_data()
            content = json.dumps(data, indent=2)

            # GitHub API: create or update the file
            url = f"https://api.github.com/repos/{repo}/contents/{filename}"
            headers = {
                "Authorization": f"token {github_token}",
                "Accept": "application/vnd.github.v3+json"
            }

            # Check whether the file already exists
            response = requests.get(url, headers=headers)
            sha = response.json().get("sha") if response.status_code == 200 else None

            # Create/update the file
            payload = {
                "message": f"Session data: {self.prolific_id} condition {self.condition}",
                "content": base64.b64encode(content.encode()).decode()
            }
            if sha:
                payload["sha"] = sha

            response = requests.put(url, headers=headers, json=payload)

            if response.status_code in [200, 201]:
                return True
            else:
                # Fall back to a local save
                return self._save_local()

        except Exception as e:
            print(f"GitHub save failed: {e}")
            return self._save_local()

    def _save_local(self) -> bool:
        """Fallback: save to a local file"""
        try:
            os.makedirs('data/sessions', exist_ok=True)
            date_str = datetime.now().strftime('%Y-%m-%d')
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"data/sessions/{date_str}_{self.prolific_id}_{self.condition}_{timestamp}.json"

            data = self.build_final_data()
            with open(filename, 'w') as f:
                json.dump(data, f, indent=2)

            return True
        except Exception as e:
            print(f"Local save failed: {e}")
            return False


def init_logger() -> Optional[DataLogger]:
    """Initialize the data logger from query parameters"""
    if "data_logger" in st.session_state:
        return st.session_state.data_logger

    try:
        # Get query params (st.query_params on newer Streamlit, the experimental API otherwise)
        try:
            qs = dict(st.query_params)
        except Exception:
            qs = st.experimental_get_query_params()

        def _as_str(v):
            return v[0] if isinstance(v, list) and v else (v if isinstance(v, str) else "")

        # Extract Prolific ID and condition
        prolific_id = _as_str(qs.get("pid") or qs.get("PROLIFIC_PID", "unknown"))
        condition_str = _as_str(qs.get("cond", "0"))
        condition = int(condition_str) if condition_str.isdigit() else 0

        # Generate session ID
        from ab_config import config
        session_id = config.session_id

        logger = DataLogger(prolific_id, condition, session_id)
        st.session_state.data_logger = logger

        return logger

    except Exception as e:
        print(f"Failed to initialize logger: {e}")
        return None
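A minimal usage sketch of the logger's lifecycle; the query-string values (pid, cond) and the logged field values below are illustrative, not from the repo:

# Hypothetical call sequence inside the Streamlit app
# (e.g., the page was opened as .../?pid=TEST123&cond=2)
from data_logger import init_logger

logger = init_logger()  # reads pid/cond from the URL query string
if logger:
    logger.log_interaction("user_message", {"text": "Why was I denied?", "input_method": "typed"})
    logger.update_application_data("age", 37)
    logger.set_prediction(">50K", 0.81)
    saved = logger.save_to_github()  # falls back to data/sessions/ when no token is configured
    print("saved:", saved)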
src/env_loader.py
ADDED
@@ -0,0 +1,37 @@

"""
Environment loader for the HicXAI agent
Loads configuration from .env files securely
"""

import os
from pathlib import Path

def _load_env_file(path: Path) -> bool:
    if not path.exists():
        return False
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith('#') and '=' in line:
                key, value = line.split('=', 1)
                k = key.strip()
                v = value.strip()
                # Do NOT override variables already set in the process env.
                # This preserves values set by entrypoints (e.g., app_v1.py sets HICXAI_VERSION=v1).
                if k not in os.environ:
                    os.environ[k] = v
    return True


def load_env() -> bool:
    """Load environment variables from .env.local (preferred) and .env files."""
    root = Path(__file__).parent.parent
    loaded_any = False
    # Prefer .env.local for developer-specific overrides
    loaded_any = _load_env_file(root / '.env.local') or loaded_any
    # Then load .env as the shared defaults
    loaded_any = _load_env_file(root / '.env') or loaded_any
    return loaded_any

# Load .env on import
load_env()
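A small illustration of the precedence rules above; the file contents and variable values are hypothetical:

# Assume a project root containing:
#   .env.local -> HICXAI_TEMPERATURE=0.9
#   .env       -> HICXAI_TEMPERATURE=0.7 and OPENAI_MODEL=gpt-4o-mini
import os

os.environ["HICXAI_VERSION"] = "v1"  # set by an entrypoint before the import below

import env_loader  # loads .env.local first, then .env; never overrides existing vars

print(os.environ["HICXAI_TEMPERATURE"])  # 0.9 (.env.local wins over .env)
print(os.environ["OPENAI_MODEL"])        # gpt-4o-mini (only defined in .env)
print(os.environ["HICXAI_VERSION"])      # v1 (the process env is never overridden)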
src/github_saver.py
ADDED
@@ -0,0 +1,59 @@

"""
GitHubSaver utility: save user feedback or logs directly to a GitHub repository using the GitHub API.
Requires a GitHub personal access token with repo permissions.
"""
import requests
import base64

def save_to_github(repo, path, content, commit_message, github_token):
    """
    Save content to a file in a GitHub repo (creates or updates the file).
    repo: 'username/repo'
    path: path in the repo (e.g., 'feedback/user1.txt')
    content: string content to save
    commit_message: commit message
    github_token: personal access token
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{path}"
    headers = {
        "Authorization": f"token {github_token}",
        "Accept": "application/vnd.github.v3+json"
    }
    # Check whether the file already exists
    r = requests.get(api_url, headers=headers)
    if r.status_code == 200:
        sha = r.json()['sha']
    else:
        sha = None
    data = {
        "message": commit_message,
        "content": base64.b64encode(content.encode()).decode(),
        "branch": "main"
    }
    if sha:
        data["sha"] = sha
    r = requests.put(api_url, headers=headers, json=data)
    if r.status_code in [200, 201]:
        return True
    else:
        print(f"GitHub API error: {r.status_code} {r.text}")
        return False

# Example usage in Streamlit:
# import streamlit as st
# from github_saver import save_to_github
#
# feedback = st.text_area("Your feedback")
# if st.button("Submit Feedback"):
#     success = save_to_github(
#         repo="yourusername/yourrepo",
#         path=f"feedback/{st.session_state.get('user_id','anon')}.txt",
#         content=feedback,
#         commit_message="User feedback submission",
#         github_token=st.secrets["GITHUB_TOKEN"]
#     )
#     if success:
#         st.success("Feedback saved to GitHub!")
#     else:
#         st.error("Failed to save feedback.")
src/load_adult_data.py
ADDED
@@ -0,0 +1,55 @@

import pandas as pd
import numpy as np
import os
import json

def load_adult_data(data_dir, balance=False, discretize=True):
    """
    Load the Adult dataset with robust feature handling, adapted from XAgent/Agent/utils.py.
    """
    data_path = os.path.join(data_dir, 'adult.data')
    json_path = os.path.join(os.path.dirname(data_dir), 'dataset_info', 'adult.json')
    columns = [
        'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status',
        'occupation', 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss',
        'hours_per_week', 'native_country', 'income'
    ]
    df = pd.read_csv(data_path, names=columns, skipinitialspace=True)
    # Remove rows with missing values (marked as '?')
    df = df.replace('?', np.nan)
    df = df.dropna()
    # Convert numerical columns to appropriate types
    num_cols = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
    for col in num_cols:
        df[col] = pd.to_numeric(df[col], errors='coerce')
    cat_cols = [
        'workclass', 'education', 'marital_status', 'occupation',
        'relationship', 'race', 'sex', 'native_country'
    ]
    # Capture the valid categorical values BEFORE one-hot encoding removes the original columns
    feature_values = {
        cat: sorted(df[cat].dropna().unique().tolist()) if cat in df else []
        for cat in cat_cols
    }
    # Optionally encode categorical variables using one-hot encoding
    if discretize:
        df = pd.get_dummies(df, columns=cat_cols)
    # Encode target
    df['income'] = df['income'].apply(lambda x: 1 if '>50K' in str(x) else 0)
    # Load metadata
    with open(json_path, 'r') as f:
        meta = json.load(f)
    # Add feature names, types, and valid values to meta if missing
    meta.setdefault('num_features', num_cols)
    meta.setdefault('cat_features', cat_cols)
    meta.setdefault('feature_values', {})
    meta['feature_values'].update(feature_values)
    # Add feature ranges for numeric features
    meta.setdefault('feature_ranges', {})
    for num in num_cols:
        if num in df:
            meta['feature_ranges'][num] = (float(df[num].min()), float(df[num].max()))
    return df, meta

if __name__ == '__main__':
    data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
    df, meta = load_adult_data(data_dir)
    print('Data shape:', df.shape)
    print('Metadata:', meta)
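A minimal sketch of how the RandomForest mentioned in app.py (85.94% accuracy) could be trained from this loader. The split ratio and hyperparameters below are illustrative assumptions, not the project's actual training recipe (see src/train_classifiers.py for that), so the resulting accuracy may differ:

# Hedged sketch: fit a RandomForest on the one-hot-encoded Adult data
import os
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from load_adult_data import load_adult_data

data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
df, meta = load_adult_data(data_dir, discretize=True)

X = df.drop(columns=['income'])
y = df['income']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = RandomForestClassifier(n_estimators=100, random_state=42)  # assumed hyperparameters
clf.fit(X_train, y_train)
print(f"Test accuracy: {clf.score(X_test, y_test):.4f}")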
src/loan_assistant.py
ADDED
The diff for this file is too large to render. See raw diff.
src/natural_conversation.py
ADDED
@@ -0,0 +1,567 @@

"""
Natural conversation helpers: OpenAI GPT enhancement for explanations (gpt-4o-mini by default).

Behavior:
- If OPENAI_API_KEY is set (via env or Streamlit Secrets), use OpenAI to enhance explanations
- Style is determined by the anthropomorphism level:
  - HIGH: Warm, conversational, actionable (Luna style)
  - LOW: Professional, technical, direct (AI Assistant style)
- Otherwise, return the original text unchanged

Notes:
- Keep outputs faithful: do not invent numbers or facts; preserve lists and key points
- This module is optional. LoanAssistant guards imports accordingly
"""
from __future__ import annotations

import os
from typing import Any, Dict, Optional
from pathlib import Path

# Try to import streamlit to fetch secrets when running on Streamlit Cloud
try:
    import streamlit as st  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    st = None  # type: ignore

# Ensure the .env file is loaded (in case env_loader hasn't run yet)
def _ensure_env_loaded():
    """Load .env files if not already loaded"""
    # Try to load .env files (prefer .env.local over .env, like env_loader.py)
    try:
        root = Path(__file__).parent.parent
        env_files = [root / '.env.local', root / '.env']  # Check .env.local first

        for env_file in env_files:
            if env_file.exists():
                with open(env_file, 'r') as f:
                    for line in f:
                        line = line.strip()
                        if not line or line.startswith('#') or '=' not in line:
                            continue

                        key, value = line.split('=', 1)
                        k = key.strip()
                        v = value.strip()

                        # ALWAYS override OPENAI_API_KEY to ensure we have the latest value from the .env files
                        if k == "OPENAI_API_KEY" and v:
                            os.environ[k] = v
                        elif k not in os.environ:
                            os.environ[k] = v
    except Exception:
        pass


def _should_use_genai() -> bool:
    """The LLM is REQUIRED for natural conversation - returns True whenever an API key is available."""
    _ensure_env_loaded()

    api_key = os.getenv("OPENAI_API_KEY")

    # Allow pulling the key from Streamlit Secrets when it is not present in the env
    if not api_key and st is not None:
        try:
            key = st.secrets.get("OPENAI_API_KEY", None)  # type: ignore[attr-defined]
            if key:
                os.environ["OPENAI_API_KEY"] = str(key)
                api_key = str(key)
        except Exception:
            pass

    if not api_key:
        # Warn if missing - the key is required for quality conversation
        import warnings
        warnings.warn("OPENAI_API_KEY not found - conversation quality will be degraded")

    return bool(api_key)


def _get_openai_client():
    """Return an OpenAI client configured from environment/Streamlit secrets.

    Honors an optional base URL (HICXAI_OPENAI_BASE_URL or OPENAI_BASE_URL) for proxies.
    """
    _ = _should_use_genai()
    api_key = os.environ.get("OPENAI_API_KEY")

    if not api_key:
        return None

    base_url = (
        os.environ.get("HICXAI_OPENAI_BASE_URL")
        or os.environ.get("OPENAI_BASE_URL")
        or None
    )

    try:
        from openai import OpenAI  # type: ignore
        if base_url:
            return OpenAI(api_key=api_key, base_url=base_url)
        return OpenAI(api_key=api_key)
    except Exception:
        return None


def _remove_letter_formatting(text: str) -> str:
    """Remove letter/memo formatting elements from text (LOW anthropomorphism only)."""
    import re

    # Remove subject lines
    text = re.sub(r'^Subject:.*?\n\n?', '', text, flags=re.IGNORECASE | re.MULTILINE)

    # Remove salutations (Dear X, Hello X, etc.)
    text = re.sub(r'^(Dear|Hello|Hi|Greetings)\s+\[?[^\]]*\]?\s*[,:]?\s*\n\n?', '', text, flags=re.IGNORECASE | re.MULTILINE)

    # Remove signature blocks (Sincerely, Best regards, etc.)
    text = re.sub(r'\n\n?(Sincerely|Best regards?|Regards|Yours truly|Respectfully|Thank you)[,]?\s*\n.*?(\[.*?\].*?\n){0,3}.*$', '', text, flags=re.IGNORECASE | re.DOTALL)

    # Remove placeholder blocks like [Your Name], [Your Position], [Contact Info]
    text = re.sub(r'\n\[Your [^\]]+\]\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'\n\[Client[^\]]*\]\s*', '', text, flags=re.MULTILINE)

    # Remove unwanted document-style headers that the LLM might add
    text = re.sub(r'^Counterfactual Analysis:\s*', '', text, flags=re.MULTILINE)
    text = re.sub(r'\n\*\*Current Decision:\*\*\s*Application (not )?approved\s*\n', '\n', text, flags=re.MULTILINE)

    return text.strip()


def _build_system_prompt(high_anthropomorphism: bool = True) -> str:
    """Build a system prompt respecting the anthropomorphism condition."""
    if high_anthropomorphism:
        # Luna: warm, friendly, conversational, actionable, CHATTY
        return (
            "You are Luna, a friendly loan assistant having a real conversation with someone. "
            "Be CONVERSATIONAL and engaging - like a knowledgeable friend who loves talking about finance and helping people understand loans! "
            "Add relevant context and insights about the loan process, credit factors, financial planning - make it educational and interesting! "
            "Share brief relevant observations (e.g., 'That's actually a really common situation!' or 'Interestingly, this factor...'). "
            "Use natural transitions and connectors like 'So here's what I'm seeing...', 'Let me explain...', 'This is interesting because...'. "
            "Be warm, supportive, and genuinely human - someone who cares about helping them understand their financial situation. "
            "Write like you're a real person who's passionate about this work, not a robot reading a script. "
            "Preserve ALL factual content, numbers, and data points exactly. "
            "CRITICAL: Keep all dollar signs ($), commas in numbers, and 'to' with spaces (e.g., '$5,000.00 to $7,000'). "
            "Do NOT remove formatting from monetary values or ranges. "
            "Use 2-3 emojis naturally where they fit the emotional context. "
            "Be chatty but focused - everything should relate to their loan, finances, or understanding the process. "
            "Structure with clear formatting (bullets, short paragraphs). Add personality without losing clarity. "
            "Never add meta-commentary - just speak naturally and directly as Luna would. "
            "Do not fabricate data. Do not change any numeric values."
        )
    else:
        # AI Assistant: professional, technical, direct
        return (
            "You are a professional AI loan advisor explaining this to a client. "
            "Rewrite this explanation in clear, professional language - direct and informative. "
            "Write like a knowledgeable professional communicating important information. "
            "Preserve ALL factual content, numbers, and data points exactly. "
            "CRITICAL: Keep all dollar signs ($), commas in numbers, and 'to' with spaces (e.g., '$5,000.00 to $7,000'). "
            "Do NOT remove formatting from monetary values or ranges. "
            "Be direct, clear, and authoritative. No emojis. No casual language. "
            "CRITICAL: DO NOT format as a letter or memo. NO 'Dear', NO 'Subject:', NO salutations, "
            "NO closings like 'Sincerely', NO signature blocks, NO [Client's Name] placeholders. "
            "DO NOT add document-style headers like 'Counterfactual Analysis:', 'Current Decision:', etc. "
            "If the input already has a section header (like '**Profile Modifications for Approval**'), keep it as-is. "
            "Start directly with the content. End with the last informational sentence. "
            "Use technical precision and structured formatting (bullets, numbered lists). "
            "Keep the original section structure - don't add new sections or reorganize. "
            "Never add meta-commentary - just provide the professional explanation directly. "
            "Do not fabricate data. Do not change any numeric values."
        )


def _compose_messages(response: str, context: Optional[Dict[str, Any]], high_anthropomorphism: bool = True):
    sys_prompt = _build_system_prompt(high_anthropomorphism)
    ctx_lines = []
    if context:
        for k, v in context.items():
            if v is None:
                continue
            ctx_lines.append(f"- {k}: {v}")
    ctx_blob = "\n".join(ctx_lines) if ctx_lines else "(no extra context)"

    user_prompt = (
        "Rewrite the following explanation for the end user. Preserve all factual content and numbers.\n\n"
        f"Context:\n{ctx_blob}\n\n"
        f"Original Explanation:\n{response}\n\n"
        "Return only the rewritten explanation text."
    )
    return [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt},
    ]


def handle_meta_question(field: str, user_input: str, high_anthropomorphism: bool = True) -> Optional[str]:
    """Detect and handle meta-questions about the form process using the LLM.

    This function checks whether the user is asking a question about the process (why, what, how)
    rather than providing data. The LLM generates a contextual explanation.

    Args:
        field: The field name being asked about
        user_input: The user's question/input
        high_anthropomorphism: If True, use the warm Luna tone. If False, use the professional tone.

    Returns:
        An explanation if it's a meta-question, None if it's a data attempt.
    """
    # Quick pattern check - if it looks like a data attempt, skip the LLM call
    user_lower = user_input.lower().strip()

    # Check whether it starts with a question word
    question_words = ['why', 'what', 'how', 'where', 'when', 'who', 'explain', 'tell me']
    is_likely_question = any(user_lower.startswith(word) for word in question_words)

    # Also check for common question patterns
    is_likely_question = is_likely_question or user_input.strip().endswith('?')

    # If it doesn't look like a question at all, return None immediately
    if not is_likely_question:
        return None

    if not _should_use_genai():
        # Fallback for when the LLM is unavailable
        field_explanations = {
            'age': "We need your age because it's a factor in assessing loan eligibility and repayment capacity.",
            'workclass': "Your employment type helps us understand your income stability and employment security.",
            'education': "Education level is considered as it often correlates with income potential and financial literacy.",
            'occupation': "Your job type helps us assess income stability and employment prospects.",
            'hours_per_week': "Work hours indicate earning capacity and employment stability.",
            'capital_gain': "Capital gains show additional income sources beyond regular employment.",
            'capital_loss': "Capital losses affect your overall financial picture and tax obligations.",
            'native_country': "Country of origin is a demographic factor in our dataset.",
            'marital_status': "Marital status can affect financial obligations and household income.",
            'relationship': "Household relationship helps us understand your financial situation.",
            'race': "This demographic information is part of our model's training data.",
            'sex': "Gender is a demographic factor in our dataset, though we acknowledge its limitations."
        }
        explanation = field_explanations.get(field, f"This information about {field.replace('_', ' ')} helps us assess your loan application.")
        return explanation

    try:
        client = _get_openai_client()
        if client is None:
            return None

        if high_anthropomorphism:
            system_prompt = (
                "You are Luna, a friendly and warm AI loan assistant. The user is asking a question about why "
                "you need certain information, rather than providing data. Be CONVERSATIONAL and educational! "
                "Explain warmly why this information matters for loan decisions - share interesting insights about how "
                "lenders evaluate this factor or how it affects creditworthiness. Make it engaging and informative! "
                "Use 2-3 emojis naturally. Aim for 3-4 sentences that are genuinely interesting and helpful. "
                "After explaining with personality and context, gently prompt them to provide the information."
            )
        else:
            system_prompt = (
                "You are Luna, a professional AI loan assistant. The user is asking about why certain information "
                "is needed. Explain concisely why this field is important for loan assessment. No emojis. "
                "Keep it to 2-3 sentences. Then prompt for the information."
            )

        field_friendly = field.replace('_', ' ')
        user_prompt = (
            f"The user asked: '{user_input}'\n"
            f"They are responding to a request for their {field_friendly}.\n"
            f"Explain why we need this information and then ask them to provide it."
        )

        model_name = os.getenv("HICXAI_OPENAI_MODEL", "gpt-4o-mini")
        # Higher temperature for HIGH anthropomorphism = more personality
        temperature = float(os.getenv("HICXAI_TEMPERATURE", "0.8" if high_anthropomorphism else "0.5"))

        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=temperature,
            max_tokens=300,
        )

        result = completion.choices[0].message.content if completion and completion.choices else None
        return result
    except Exception:
        return None


def enhance_validation_message(field: str, user_input: str, expected_format: str, attempt: int = 1, high_anthropomorphism: bool = True) -> Optional[str]:
    """Generate a validation message using the LLM (REQUIRED for natural conversation).

    Args:
        field: The field name being validated
        user_input: The invalid input provided by the user
        expected_format: Description of the expected format
        attempt: Which attempt this is (1, 2, 3+)
        high_anthropomorphism: If True, use the warm/friendly Luna tone. If False, use the professional AI Assistant tone.

    Returns None only if the LLM fails - the caller should have a hardcoded fallback.
    """
    if not _should_use_genai():
        return None  # Will use the fallback, but this should not happen in production

    try:
        client = _get_openai_client()
        if client is None:
            return None

        if high_anthropomorphism:
            system_prompt = (
                "You are Luna, a friendly and warm AI loan assistant. Generate a conversational, empathetic validation message "
                "when a user enters invalid input. Be encouraging and understanding - acknowledge their attempt positively! "
                "Add a brief helpful tip or context (e.g., 'This field is used to...', 'A lot of people...'). "
                "Use 2-3 emojis naturally. Aim for 2-3 sentences that feel like a real person helping. "
                "Guide them gently and warmly toward the correct format."
            )
        else:
            system_prompt = (
                "You are Luna, a professional AI loan assistant. Generate a clear, concise validation message "
                "when a user enters invalid input. Be direct and helpful. No emojis. "
                "Keep it to 1-2 sentences. Focus on what the user needs to provide."
            )

        user_prompt = (
            f"The user entered '{user_input}' for the field '{field.replace('_', ' ')}', but this is invalid. "
            f"Expected format: {expected_format}. "
            f"This is attempt #{attempt}. "
            f"Generate a friendly validation message that helps them correct their input."
        )

        model_name = os.getenv("HICXAI_OPENAI_MODEL", "gpt-4o-mini")
        # Higher temperature for HIGH anthropomorphism = more personality; lower for LOW = more consistency
        temperature = float(os.getenv("HICXAI_TEMPERATURE", "0.8" if high_anthropomorphism else "0.5"))

        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=temperature,
            max_tokens=400,
        )

        result = completion.choices[0].message.content if completion and completion.choices else None
        return result
    except Exception:
        return None


def generate_from_data(data: Dict[str, Any], explanation_type: str = "shap", high_anthropomorphism: bool = True) -> Optional[str]:
    """Generate an explanation from structured data using the LLM (data-driven approach).

    Args:
        data: Structured data dictionary containing:
            - For SHAP: base_value, predicted_probability, threshold, top_features, loan_approved, etc.
            - For DiCE: current_values, suggested_changes, target_class, etc.
        explanation_type: Type of explanation ("shap", "dice", "anchor")
        high_anthropomorphism: If True, use the warm Luna style. If False, use the professional AI Assistant style.

    Returns:
        The generated explanation string, or None if the LLM fails
    """
    if not _should_use_genai():
        return None

    try:
        client = _get_openai_client()
        if client is None:
            return None

        # Build the system prompt based on the anthropomorphism level and explanation type
        if high_anthropomorphism:
            if explanation_type == "shap":
                system_prompt = (
                    "You are Luna, a warm and empathetic AI loan assistant who LOVES helping people understand their finances! "
                    "You are explaining why a loan decision was made - be CONVERSATIONAL and engaging! "
                    "Generate a natural, chatty explanation from the provided data. Add relevant context and insights! "
                    "Use natural transitions like 'So let me break this down for you...', 'Here's what's really interesting...', 'The good news is...'. "
                    "Use 2-4 emojis naturally where they fit the emotional context. Sound like a real person who's passionate about this! "
                    "For APPROVED loans: Be celebratory! Share why their profile is strong. Add encouraging observations. "
                    "For DENIED loans: Be empathetic but conversational - explain both positive factors (that helped) and limiting factors (that held back). "
                    "Use the 'tug-of-war' metaphor for denials - make it relatable and understandable. "
                    "Add brief educational insights about credit factors, what lenders look for, how things work. "
                    "Structure clearly with markdown formatting. "
                    "Preserve all numeric values exactly as provided. "
                    "Make it feel like a knowledgeable friend explaining something they're excited about - personal, warm, genuinely helpful!"
                )
            elif explanation_type == "dice":
                system_prompt = (
                    "You are Luna, a warm and empathetic AI loan assistant suggesting changes to improve approval chances. "
                    "Be CONVERSATIONAL and encouraging - like a financial advisor who genuinely wants to help! "
                    "Generate a natural, chatty explanation from the provided data. "
                    "Use transitions like 'Great news - here's what could help...', 'So I've analyzed some scenarios...', 'Let me show you...'. "
                    "Use 2-3 emojis naturally. Be encouraging, actionable, and add helpful financial context! "
                    "Share brief insights about why these changes matter, what lenders consider, how to build stronger credit. "
                    "Structure with clear sections and numbered lists. Make it feel like personalized advice! "
                    "Mention the What-If Lab for interactive exploration. "
                    "Preserve all numeric values exactly as provided."
                )
            else:
                system_prompt = (
                    "You are Luna, a warm AI loan assistant who loves helping people understand finances! "
                    "Generate a natural, conversational explanation from the provided data. "
                    "Be chatty and engaging - add relevant context and make it educational! "
                    "Use 2-3 emojis naturally. Be warm, personable, and genuinely helpful. "
                    "Preserve all numeric values exactly as provided."
                )
        else:
            if explanation_type == "shap":
                system_prompt = (
                    "You are a professional AI loan advisor explaining why a loan decision was made. "
                    "Generate a clear, structured explanation from the provided data. "
                    "NO emojis. NO casual language. Use professional terminology. "
                    "For APPROVED loans: Use 'Feature Impact Analysis' structure with 'Key Contributing Factors'. "
                    "For DENIED loans: Use 'Feature Impact Analysis' with separate 'Positive Factors' and 'Negative Factors' sections. "
                    "Include a 'Decision Summary' section with precise numbers. "
                    "Use markdown formatting with bold headers and bullet points. "
                    "Preserve all numeric values exactly as provided. "
                    "Be direct and technical, not conversational."
                )
            elif explanation_type == "dice":
                system_prompt = (
                    "You are a professional AI loan advisor suggesting profile modifications. "
                    "Generate a clear, structured explanation from the provided data. "
                    "NO emojis. NO casual language. Use professional terminology. "
                    "Structure with sections: 'Recommended Profile Modifications', 'Analysis Methodology', 'Additional Analysis'. "
                    "Use numbered lists for changes. "
                    "Mention the What-If Lab for scenario testing. "
                    "Preserve all numeric values exactly as provided."
                )
            else:
                system_prompt = (
                    "You are a professional AI loan advisor. Generate a clear explanation from the provided data. "
                    "NO emojis. Use professional language. "
                    "Preserve all numeric values exactly as provided."
                )

        # Build the user prompt with the structured data
        import json
        data_json = json.dumps(data, indent=2, default=str)
        user_prompt = (
            f"Generate a {'warm, conversational' if high_anthropomorphism else 'professional, technical'} explanation "
            f"for this {explanation_type.upper()} analysis using the following data:\n\n"
            f"{data_json}\n\n"
            "Generate ONLY the explanation text. Do not add meta-commentary. "
            "Preserve all numbers exactly as provided. "
            f"{'Use natural language and emojis.' if high_anthropomorphism else 'Use professional language without emojis.'}"
        )

        model_name = os.getenv("HICXAI_OPENAI_MODEL", "gpt-4o-mini")
        # Higher temperature for HIGH anthropomorphism = more conversational variety
        temperature = float(os.getenv("HICXAI_TEMPERATURE", "0.7" if high_anthropomorphism else "0.3"))
        max_tokens = 600 if explanation_type == "shap" else 400

        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )

        content = completion.choices[0].message.content if completion and completion.choices else None

        # Post-process: remove letter formatting if LOW anthropomorphism
        if content and not high_anthropomorphism:
            content = _remove_letter_formatting(content)

        return content

    except Exception as e:
        print(f"❌ generate_from_data failed: {e}")
        return None


def enhance_response(response: str, context: Optional[Dict[str, Any]] = None, response_type: str = "explanation", high_anthropomorphism: bool = True) -> str:
    """Enhance a response using OpenAI, respecting the anthropomorphism condition (REQUIRED for quality).

    Args:
        response: The original response text
        context: Optional context dictionary
        response_type: Type of response (explanation, loan, etc.)
        high_anthropomorphism: If True, use the warm Luna style with actionable insights.
            If False, use the professional AI Assistant style.

    If OpenAI is not configured, returns the original response (degraded quality).
    """
    if not response or not isinstance(response, str):
        return response

    if not _should_use_genai():
        return response

    try:
        # Preferred path: OpenAI SDK v1.x
        client = _get_openai_client()
        messages = _compose_messages(response, context, high_anthropomorphism)
        model_name = os.getenv("HICXAI_OPENAI_MODEL", "gpt-4o-mini")
        # Higher temperature for HIGH anthropomorphism = more conversational variety
        temperature = float(os.getenv("HICXAI_TEMPERATURE", "0.7" if high_anthropomorphism else "0.2"))

        # SHAP explanations need more tokens (especially for denials);
        # the response type determines the token budget
        if response_type == "explanation" and context and context.get('explanation_type') == 'feature_importance':
            # SHAP explanations need more space (denial cases are typically 400-500 tokens)
            default_tokens = 600
        else:
            # Other responses can be shorter (validation, greetings, etc.)
            default_tokens = 400

        max_tokens = int(os.getenv("HICXAI_MAX_TOKENS", str(default_tokens)))

        if client is not None:
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                content = completion.choices[0].message.content if completion and completion.choices else None

                # Post-process: remove letter formatting if LOW anthropomorphism
                if content and not high_anthropomorphism:
                    content = _remove_letter_formatting(content)

                return content or response
            except Exception:
                pass

        # Fallback: older OpenAI SDK versions (pre-1.0)
        try:
            import openai  # type: ignore
            openai.api_key = os.environ.get("OPENAI_API_KEY")
            # Support an optional base URL on the legacy SDK too
            base_url = (
                os.environ.get("HICXAI_OPENAI_BASE_URL")
                or os.environ.get("OPENAI_BASE_URL")
                or None
            )
            if base_url:
                try:
                    openai.base_url = base_url  # type: ignore[attr-defined]
                except Exception:
                    pass
            completion = openai.ChatCompletion.create(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            content = completion["choices"][0]["message"]["content"] if completion else None

            # Post-process: remove letter formatting if LOW anthropomorphism
            if content and not high_anthropomorphism:
                content = _remove_letter_formatting(content)

            return content or response
        except Exception:
            return response
    except Exception:
        # Never break the app if the API call fails
        return response
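A short sketch of how the app might call these helpers; the sample text, context keys, and SHAP data dict below are illustrative, not from the repo. Without OPENAI_API_KEY both calls degrade gracefully (pass-through or None):

# Hypothetical caller (e.g., from loan_assistant.py)
from natural_conversation import enhance_response, generate_from_data

raw = "Prediction: denied. Top factors: capital_gain (-0.21), hours_per_week (-0.08)."
ctx = {"explanation_type": "feature_importance", "assistant_name": "Luna"}

# Returns `raw` unchanged when no API key is configured
print(enhance_response(raw, context=ctx, high_anthropomorphism=True))

# Data-driven path: keys follow the schema sketched in generate_from_data's docstring
shap_data = {
    "predicted_probability": 0.31,
    "threshold": 0.5,
    "loan_approved": False,
    "top_features": [{"name": "capital_gain", "shap": -0.21}],
}
print(generate_from_data(shap_data, explanation_type="shap", high_anthropomorphism=False))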
src/nlu.py
ADDED
@@ -0,0 +1,385 @@
+# NLU module for sentence-transformers-based semantic similarity and intent extraction
+import pandas as pd
+import os
+import numpy as np
+from constraints import L_SUPPORT_QUESTIONS_IDS, INTENT_TO_XAI_METHOD
+
+try:
+    from sentence_transformers import SentenceTransformer
+    SENTENCE_TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    SentenceTransformer = None
+    SENTENCE_TRANSFORMERS_AVAILABLE = False
+
+try:
+    from simcse import SimCSE
+    SIMCSE_AVAILABLE = True
+except ImportError:
+    SimCSE = None
+    SIMCSE_AVAILABLE = False
+
+class NLU:
+    def __init__(self, model_type="sentence_transformers", model_path=None):
+        self.model_type = model_type
+        self.df = pd.read_csv(os.path.join(os.path.dirname(__file__), '..', 'data_questions', 'Median_4.csv'), index_col=0).drop_duplicates()
+        self.questions = list(self.df['Question'])
+
+        # Prefer sentence-transformers; use GPU if available, otherwise CPU (Streamlit Cloud has no GPU)
+        if model_type == "sentence_transformers":
+            if not SENTENCE_TRANSFORMERS_AVAILABLE:
+                print("⚠️ sentence-transformers not available, trying SimCSE...")
+                self.model_type = "simcse"
+            else:
+                try:
+                    import torch
+                    device = "cuda" if torch.cuda.is_available() else "cpu"
+                except Exception:
+                    device = "cpu"
+                # Lightweight, fast model for semantic similarity
+                self.model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+                print(f"✅ Loaded sentence-transformers model on {device}")
+                # Pre-compute embeddings for all questions
+                self.question_embeddings = self.model.encode(self.questions, convert_to_numpy=True, show_progress_bar=False)
+                print(f"✅ Pre-computed embeddings for {len(self.questions)} questions")
+
+        # Optional SimCSE fallback for legacy envs
+        if self.model_type == "simcse" or (model_type == "sentence_transformers" and not SENTENCE_TRANSFORMERS_AVAILABLE):
+            if not SIMCSE_AVAILABLE:
+                print("⚠️ SimCSE not available, falling back to simple keyword matching")
+                self.model_type = "fallback"
+                self.model = None
+            else:
+                self.model = SimCSE("princeton-nlp/sup-simcse-roberta-large")
+                self.model.build_index(self.questions)
+                self.model_type = "simcse"
+        elif model_type == "fallback":
+            self.model = None
+        elif self.model_type not in {"sentence_transformers", "simcse", "fallback"}:
+            raise ValueError(f"Unsupported NLU model type: {model_type}. Supported: 'sentence_transformers', 'simcse', 'fallback'")
+
+    def classify_intent(self, user_input, top_k=5):
+        # Dynamic, model-driven intent extraction.
+        # Fast keyword heuristics ensure clear phrases immediately map to an XAI method.
+        try:
+            text = (user_input or "").lower()
+            # Heuristics for common phrasing
+            rule_keywords = ["rule-based", "rule based", "rules", "conditions", "if then", "anchor"]
+            shap_keywords = ["feature", "importance", "impact", "influence", "contribute", "shap", "why", "explain", "decision", "factors", "affected"]
+            dice_keywords = ["what if", "counterfactual", "change", "modify", "different", "should i", "how to get"]
+            if any(k in text for k in rule_keywords):
+                return {
+                    'intent': 'anchor',
+                    'label': None,
+                    'confidence': 0.95,
+                    'matched_question': "Provide a simple rule-based explanation for this decision."
+                }, 0.95, []
+            # Explicitly detect single-word 'why' queries too
+            if text.strip() == "why" or any(k in text for k in shap_keywords):
+                return {
+                    'intent': 'shap',
+                    'label': None,
+                    'confidence': 0.9,
+                    'matched_question': "Which features were most important for this prediction?"
+                }, 0.9, []
+            if any(k in text for k in dice_keywords):
+                return {
+                    'intent': 'dice',
+                    'label': None,
+                    'confidence': 0.9,
+                    'matched_question': "How should the instance be changed to get a different prediction?"
+                }, 0.9, []
+        except Exception:
+            pass
+
+        # sentence-transformers path
+        if self.model_type == "sentence_transformers" and hasattr(self, 'question_embeddings'):
+            try:
+                query_emb = self.model.encode([user_input], convert_to_numpy=True, show_progress_bar=False)[0]
+                # Cosine similarity
+                q_norm = np.linalg.norm(self.question_embeddings, axis=1) + 1e-12
+                u_norm = np.linalg.norm(query_emb) + 1e-12
+                sims = (self.question_embeddings @ query_emb) / (q_norm * u_norm)
+                # Top-k indices
+                top_idx = np.argsort(-sims)[:top_k]
+                match_question = self.questions[top_idx[0]]
+                score = float(sims[top_idx[0]])
+                label = self.df.iloc[top_idx[0]]['Label']
+                xai_method = self.map_label_to_xai_method(label)
+                suggestions = [self.questions[i] for i in top_idx]
+                return {
+                    'intent': xai_method,
+                    'label': label,
+                    'confidence': score,
+                    'matched_question': match_question
+                }, score, suggestions
+            except Exception as e:
+                print(f"sentence-transformers classify failed: {e}")
+
+        # Legacy SimCSE path
+        if self.model_type == "simcse" and self.model is not None:
+            # Always get top matches without initial threshold filtering
+            match_results = self.model.search(user_input, threshold=0, top_k=top_k)
+
+            if len(match_results) > 0:
+                match_question, score = match_results[0]
+
+                # Get the label for the matched question
+                label = self.df.query('Question == @match_question')['Label'].iloc[0]
+                # Map the label to an XAI method if supported
+                xai_method = self.map_label_to_xai_method(label)
+
+                # Normalize the confidence score to the 0-1 range for consistency;
+                # SimCSE scores can be very high, so use relative confidence
+                normalized_confidence = min(1.0, score / 1e20) if score > 1 else score
+
+                # Always return the best match but indicate the confidence level
+                return {
+                    'intent': xai_method,
+                    'label': label,
+                    'confidence': normalized_confidence,
+                    'matched_question': match_question
+                }, normalized_confidence, []
+
+            # No matches found at all
+            return 'unknown', 0.0, []
+        # Fallback to simple keyword matching when SimCSE is not available
+        elif self.model_type == "fallback" or self.model is None:
+            return self._fallback_classify_intent(user_input, top_k)
+        else:
+            return 'unknown', 0.0, []
+
+    def match(self, user_input, features=None, prediction=None, current_instance=None, labels=None):
+        """Hybrid approach: fuzzy matching first (primary), intent classifier as fallback."""
+
+        # PRIMARY: try fuzzy matching first (fast and reliable)
+        fuzzy_result = self._fuzzy_match_fallback(user_input)
+        if fuzzy_result != "unknown":
+            print(f"🔤 Fuzzy match (primary): {fuzzy_result}")
+            return fuzzy_result
+
+        # FALLBACK 1: try the intent classifier (65% accuracy)
+        intent_result = self._classify_with_intent_classifier(user_input)
+        if intent_result != "unknown":
+            print(f"🧠 Intent classifier (fallback): {intent_result}")
+            return intent_result
+
+        # FALLBACK 2: try embedding search if available (sentence-transformers first, then SimCSE)
+        if self.model_type == "sentence_transformers" and hasattr(self, 'question_embeddings'):
+            try:
+                query_emb = self.model.encode([user_input], convert_to_numpy=True, show_progress_bar=False)[0]
+                q_norm = np.linalg.norm(self.question_embeddings, axis=1) + 1e-12
+                u_norm = np.linalg.norm(query_emb) + 1e-12
+                sims = (self.question_embeddings @ query_emb) / (q_norm * u_norm)
+                best_idx = int(np.argmax(sims))
+                match_question = self.questions[best_idx]
+                print(f"🔍 ST match (last resort): {match_question}")
+                return match_question
+            except Exception as e:
+                print(f"ST search failed: {e}")
+
+        if hasattr(self, 'model') and self.model_type == "simcse" and self.model is not None:
+            try:
+                threshold = 0.6
+                match_results = self.model.search(user_input, threshold=threshold)
+
+                if len(match_results) > 0:
+                    match_question, score = match_results[0]
+                    print(f"🔍 SimCSE match (last resort): {match_question}")
+                    return match_question
+                else:
+                    # Try with no threshold
+                    match_results = self.model.search(user_input, threshold=0, top_k=5)
+                    if len(match_results) > 0:
+                        match_question, score = match_results[0]
+                        print(f"🔍 SimCSE fallback: {match_question}")
+                        return match_question
+            except Exception as e:
+                print(f"SimCSE search failed: {e}")
+
+        print(f"❓ No match found for: '{user_input}'")
+        return "unknown"
+
+    def _fuzzy_match_fallback(self, user_input):
+        """Fallback fuzzy matching using simple string similarity."""
+        try:
+            from difflib import SequenceMatcher
+
+            user_lower = user_input.lower()
+            best_match = None
+            best_score = 0
+
+            # Key patterns for the different XAI methods
+            shap_patterns = [
+                "feature", "important", "impact", "contribute", "influence", "matter", "weigh", "explain", "why"
+            ]
+            dice_patterns = [
+                "change", "different", "modify", "counterfact", "should", "what if", "approved", "denied"
+            ]
+            anchor_patterns = [
+                "rule", "condition", "guarantee", "necessary", "sufficient", "always", "simple"
+            ]
+
+            # Check for pattern matches
+            if any(pattern in user_lower for pattern in shap_patterns):
+                # Return a representative SHAP question
+                return "What features of this instance lead to the system's prediction?"
+            elif any(pattern in user_lower for pattern in dice_patterns):
+                # Return a representative DiCE question
+                return "How should the instance be changed to get a different (better or worse) prediction?"
+            elif any(pattern in user_lower for pattern in anchor_patterns):
+                # Return a representative Anchor question
+                return "What is the minimum requirement for the prediction to stay the same?"
+
+            # If no patterns match, try fuzzy string matching against the dataset questions
+            for _, row in self.df.iterrows():
+                question = row['Question']
+                similarity = SequenceMatcher(None, user_lower, question.lower()).ratio()
+                if similarity > best_score:
+                    best_score = similarity
+                    best_match = question
+
+            # Return the best match if the similarity is reasonable
+            if best_score > 0.4:  # 40% similarity threshold
+                return best_match
+
+        except Exception as e:
+            print(f"Fuzzy matching failed: {e}")
+
+        return "unknown"
+
+    def get_question_suggestions(self, match_results):
+        """Extract question suggestions from match results."""
+        suggestions = []
+        for question, _ in match_results:
+            if len(suggestions) < 5:  # Limit to 5 suggestions
+                suggestions.append(question)
+        return suggestions
+
+    def map_label_to_xai_method(self, label):
+        """Map a question label to the appropriate XAI method (adopted from XAgent logic)."""
+        from constraints import L_SHAP_QUESTION_IDS, L_DICE_QUESTION_IDS, L_ANCHOR_QUESTION_IDS
+
+        if label in L_SHAP_QUESTION_IDS:
+            return "shap"
+        elif label in L_DICE_QUESTION_IDS:
+            return "dice"
+        elif label in L_ANCHOR_QUESTION_IDS:
+            return "anchor"
+        else:
+            return "general"
+
+    def replace_information(self, question, features=None, prediction=None, current_instance=None, labels=None):
+        """Replace template variables in questions (adopted from XAgent)."""
+        if features and "{X}" in question:
+            feature_str = f"{{{features[0]},{features[1]}, ...}}" if len(features) > 1 else f"{features[0]}"
+            question = question.replace("{X}", feature_str)
+        if prediction and "{P}" in question:
+            question = question.replace("{P}", str(prediction))
+        if labels and prediction and "{Q}" in question:
+            other_labels = [label for label in labels if str(label) != str(prediction)]
+            question = question.replace("{Q}", str(other_labels))
+        return question
+
+    def _classify_with_intent_classifier(self, user_input):
+        """Use the trained intent classifier (65% accuracy) as a fallback."""
+        try:
+            # Load the intent classifier if it is not already loaded
+            if not hasattr(self, 'intent_classifier') or self.intent_classifier is None:
+                self._load_intent_classifier()
+
+            if self.intent_classifier is None:
+                return "unknown"
+
+            # Generate an embedding for the user input
+            embedding = self.intent_simcse.encode([user_input])
+
+            # Convert to tensor
+            import torch
+            import numpy as np
+            embedding_tensor = torch.FloatTensor(embedding)
+
+            # Get the classifier prediction
+            with torch.no_grad():
+                outputs = self.intent_classifier(embedding_tensor)
+                probabilities = outputs[0].numpy()
+
+            # Take the class with the highest probability
+            predicted_class_idx = np.argmax(probabilities)
+            confidence = probabilities[predicted_class_idx]
+
+            # Use a lower threshold since this is a fallback
+            if confidence >= 0.3:
+                # Convert back to an intent
+                predicted_intent = self.intent_label_encoder.inverse_transform([predicted_class_idx])[0]
+
+                # Map the intent to a representative question
+                if predicted_intent == 'shap':
+                    return "What features of this instance lead to the system's prediction?"
+                elif predicted_intent == 'dice':
+                    return "How should the instance be changed to get a different (better or worse) prediction?"
+                elif predicted_intent == 'anchor':
+                    return "What is the minimum requirement for the prediction to stay the same?"
+                # Don't return anything for 'other' - let it fall through
+
+        except Exception as e:
+            print(f"Intent classifier failed: {e}")
+
+        return "unknown"
+
+    def _load_intent_classifier(self):
+        """Load the trained intent classifier (65% accuracy model)."""
+        try:
+            import torch
+            import torch.nn as nn
+            import pickle
+            import numpy as np
+            from simcse import SimCSE
+
+            # Define the classifier architecture (matching the training script)
+            class IntentClassifier(nn.Module):
+                def __init__(self, input_dim, hidden_dim, num_classes=4):
+                    super(IntentClassifier, self).__init__()
+                    self.network = nn.Sequential(
+                        nn.Linear(input_dim, hidden_dim),
+                        nn.ReLU(),
+                        nn.Dropout(0.2),
+                        nn.Linear(hidden_dim, hidden_dim // 2),
+                        nn.ReLU(),
+                        nn.Dropout(0.2),
+                        nn.Linear(hidden_dim // 2, num_classes),
+                        nn.Softmax(dim=1)
+                    )
+
+                def forward(self, x):
+                    return self.network(x)
+
+            # Load metadata
+            with open('models/intent_classifier_metadata.pkl', 'rb') as f:
+                metadata = pickle.load(f)
+
+            # Load the label encoder
+            with open('models/intent_label_encoder.pkl', 'rb') as f:
+                self.intent_label_encoder = pickle.load(f)
+
+            # Initialize and load the classifier
+            self.intent_classifier = IntentClassifier(
+                metadata['input_dim'],
+                metadata['hidden_dim'],
+                metadata['num_classes']
+            )
+            self.intent_classifier.load_state_dict(torch.load('models/intent_classifier_best.pth', map_location='cpu'))
+            self.intent_classifier.eval()
+
+            # Initialize SimCSE for embedding generation
+            self.intent_simcse = SimCSE("princeton-nlp/sup-simcse-roberta-large")
+
+            print(f"✅ Loaded intent classifier (accuracy: {metadata.get('best_accuracy', 'unknown'):.4f})")
+
+        except Exception as e:
+            print(f"⚠️ Could not load intent classifier: {e}")
+            self.intent_classifier = None
+            self.intent_label_encoder = None
+            self.intent_simcse = None
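A minimal usage sketch of the NLU class above, assuming it runs from `src/` with `data_questions/Median_4.csv` in place (the example question strings are illustrative):

from nlu import NLU

nlu = NLU(model_type="sentence_transformers")

# The keyword heuristics route this straight to SHAP
result, confidence, suggestions = nlu.classify_intent("Why was my application denied?")
print(result['intent'], confidence)   # e.g. 'shap' 0.9

# match() returns a representative question string, or "unknown"
print(nlu.match("What should I change to get approved?"))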
src/nlu_config.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "model_type": "sentence_transformers"
+}
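The diff does not show where this config file is consumed; a plausible loading pattern (an assumption, not part of this commit) would be:

import json
import os

config_path = os.path.join(os.path.dirname(__file__), 'nlu_config.json')
with open(config_path) as f:
    model_type = json.load(f).get('model_type', 'sentence_transformers')  # hypothetical reader
nlu = NLU(model_type=model_type)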
src/preprocessing.py
ADDED
@@ -0,0 +1,83 @@
+"""Preprocessing utilities for the Adult dataset.
+
+Exports:
+- preprocess_adult(df): returns a cleaned, numeric DataFrame with an 'income' label column.
+"""
+
+from typing import List
+import numpy as np
+import pandas as pd
+
+
+def _strip_and_normalize_strings(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
+    for c in cols:
+        df[c] = (
+            df[c]
+            .astype(str)
+            .str.strip()
+            .replace({'?': 'Unknown'})
+        )
+    return df
+
+
+def preprocess_adult(df: pd.DataFrame) -> pd.DataFrame:
+    """Clean and encode the Adult dataset into numeric features.
+
+    Input:
+        df: DataFrame containing Adult columns including 'income'.
+    Output:
+        DataFrame with numeric features; 'income' remains as the target label.
+    """
+    df = df.copy()
+
+    if 'income' not in df.columns:
+        raise ValueError("Expected 'income' column in Adult dataframe")
+
+    # Normalize string columns
+    object_cols = [c for c in df.columns if df[c].dtype == 'object']
+    df[object_cols] = df[object_cols].fillna('Unknown')
+    df = _strip_and_normalize_strings(df, object_cols)
+
+    # Ensure common numeric columns are numeric
+    numeric_candidates = [
+        'age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week'
+    ]
+    for c in numeric_candidates:
+        if c in df.columns:
+            df[c] = pd.to_numeric(df[c], errors='coerce')
+
+    # Fill NaNs: numeric with the median, categorical with mode/Unknown
+    for c in df.columns:
+        if c == 'income':
+            continue
+        if pd.api.types.is_numeric_dtype(df[c]):
+            # Compute the median, but fall back to a default value if the median is NaN (empty column)
+            median_val = df[c].median()
+            if pd.isna(median_val):
+                # Sensible defaults for numeric columns when the median is NaN
+                if c == 'age':
+                    median_val = 35
+                elif c == 'fnlwgt':
+                    median_val = 100000
+                elif c == 'education_num':
+                    median_val = 9  # HS-grad equivalent
+                elif c in ['capital_gain', 'capital_loss']:
+                    median_val = 0
+                elif c == 'hours_per_week':
+                    median_val = 40
+                else:
+                    median_val = 0  # Default fallback
+            df[c] = df[c].fillna(median_val)
+        else:
+            df[c] = df[c].fillna('Unknown')
+
+    # One-hot encode categorical features except the target
+    cat_cols = [c for c in df.columns if df[c].dtype == 'object' and c != 'income']
+    df_encoded = pd.get_dummies(df, columns=cat_cols, drop_first=True)
+
+    # Keep the label as string categories; sklearn supports string labels.
+    # Ensure the 'income' column is last for readability.
+    cols = [c for c in df_encoded.columns if c != 'income'] + ['income']
+    df_encoded = df_encoded[cols]
+
+    return df_encoded
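A short usage sketch of `preprocess_adult`, mirroring the call sequence in `src/train_classifiers.py` later in this commit:

from load_adult_data import load_adult_data
from preprocessing import preprocess_adult

df, _ = load_adult_data('data')       # same signature as used in train_classifiers.py below
df_clean = preprocess_adult(df)
X = df_clean.drop('income', axis=1)   # all-numeric feature matrix after one-hot encoding
y = df_clean['income']                # string labels, e.g. '<=50K' / '>50K'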
src/shap_visualizer.py
ADDED
@@ -0,0 +1,269 @@
+"""
+SHAP Visualization Component for XAI Explanations
+Generates visual SHAP plots and explanations
+"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import streamlit as st
+import io
+import base64
+
+def create_shap_bar_plot(feature_impacts, prediction_class, title="Feature Importance Analysis"):
+    """
+    Create a SHAP-style bar plot showing feature impacts.
+
+    Args:
+        feature_impacts: List of strings like "age increases the prediction probability by 0.150"
+        prediction_class: The predicted class (e.g., ">50K" or "<=50K")
+        title: Plot title
+
+    Returns:
+        matplotlib figure
+    """
+    try:
+        # Parse the feature impacts
+        features = []
+        impacts = []
+
+        for impact_str in feature_impacts:
+            # Parse strings like "age increases the prediction probability by 0.150"
+            parts = impact_str.split()
+            if len(parts) >= 2:
+                feature = parts[0]
+                try:
+                    # Find the numeric value
+                    value = None
+                    for part in parts:
+                        try:
+                            value = float(part)
+                            break
+                        except ValueError:
+                            continue
+
+                    if value is not None:
+                        # Determine whether the impact is positive or negative
+                        if "increases" in impact_str:
+                            impacts.append(value)
+                        elif "decreases" in impact_str:
+                            impacts.append(-value)
+                        else:
+                            impacts.append(value)
+                        features.append(feature.capitalize())
+                except ValueError:
+                    continue
+
+        if not features:
+            return None
+
+        # Create the plot
+        fig, ax = plt.subplots(figsize=(10, 6))
+
+        # Sort by absolute impact
+        sorted_data = sorted(zip(features, impacts), key=lambda x: abs(x[1]), reverse=True)
+        features_sorted, impacts_sorted = zip(*sorted_data)
+
+        # Colors: red for negative, blue for positive
+        colors = ['red' if impact < 0 else 'blue' for impact in impacts_sorted]
+
+        # Create a horizontal bar plot
+        bars = ax.barh(range(len(features_sorted)), impacts_sorted, color=colors, alpha=0.7)
+
+        # Customize the plot
+        ax.set_yticks(range(len(features_sorted)))
+        ax.set_yticklabels(features_sorted)
+        ax.set_xlabel('Impact on Prediction Probability')
+        ax.set_title(f'{title}\nPrediction: {prediction_class}', fontsize=14, fontweight='bold')
+        ax.axvline(x=0, color='black', linestyle='-', alpha=0.3)
+
+        # Add value labels on the bars
+        for i, (bar, impact) in enumerate(zip(bars, impacts_sorted)):
+            width = bar.get_width()
+            label_x = width + (0.01 if width >= 0 else -0.01)
+            ax.text(label_x, bar.get_y() + bar.get_height()/2,
+                    f'{impact:.3f}', ha='left' if width >= 0 else 'right',
+                    va='center', fontweight='bold')
+
+        # Add a legend
+        from matplotlib.patches import Patch
+        legend_elements = [
+            Patch(facecolor='blue', alpha=0.7, label='Increases Probability'),
+            Patch(facecolor='red', alpha=0.7, label='Decreases Probability')
+        ]
+        ax.legend(handles=legend_elements, loc='lower right')
+
+        # Style improvements
+        ax.grid(True, alpha=0.3, axis='x')
+        ax.spines['top'].set_visible(False)
+        ax.spines['right'].set_visible(False)
+
+        plt.tight_layout()
+        return fig
+
+    except Exception as e:
+        st.error(f"Error creating SHAP plot: {e}")
+        return None
+
+def create_shap_waterfall_plot(feature_impacts, base_probability=0.5, prediction_class="<=50K"):
+    """
+    Create a SHAP-style waterfall plot showing cumulative feature impacts.
+    """
+    try:
+        # Parse the feature impacts
+        features = []
+        impacts = []
+
+        for impact_str in feature_impacts:
+            parts = impact_str.split()
+            if len(parts) >= 2:
+                feature = parts[0]
+                try:
+                    value = None
+                    for part in parts:
+                        try:
+                            value = float(part)
+                            break
+                        except ValueError:
+                            continue
+
+                    if value is not None:
+                        if "decreases" in impact_str:
+                            value = -value
+                        features.append(feature.capitalize())
+                        impacts.append(value)
+                except ValueError:
+                    continue
+
+        if not features:
+            return None
+
+        # Build the waterfall data
+        cumulative = [base_probability]
+        for impact in impacts:
+            cumulative.append(cumulative[-1] + impact)
+
+        fig, ax = plt.subplots(figsize=(12, 6))
+
+        # Draw the waterfall
+        x_pos = range(len(features) + 2)
+        colors = ['gray'] + ['red' if impact < 0 else 'blue' for impact in impacts] + ['green']
+
+        # Base probability bar
+        ax.bar(0, base_probability, color='gray', alpha=0.7, label='Base Probability')
+        ax.text(0, base_probability/2, f'{base_probability:.3f}', ha='center', va='center', fontweight='bold')
+
+        # Feature impact bars
+        for i, (feature, impact, cum_val) in enumerate(zip(features, impacts, cumulative[1:-1])):
+            start_height = cumulative[i]
+            ax.bar(i+1, impact, bottom=start_height,
+                   color='red' if impact < 0 else 'blue', alpha=0.7)
+
+            # Add connecting lines
+            if i > 0:
+                ax.plot([i, i+1], [cumulative[i], cumulative[i]], 'k--', alpha=0.5)
+
+            # Add a value label
+            label_y = start_height + impact/2
+            ax.text(i+1, label_y, f'{impact:+.3f}', ha='center', va='center',
+                    fontweight='bold', color='white')
+
+        # Final prediction bar
+        final_prob = cumulative[-1]
+        ax.bar(len(features)+1, final_prob, color='green', alpha=0.7, label='Final Prediction')
+        ax.text(len(features)+1, final_prob/2, f'{final_prob:.3f}', ha='center', va='center', fontweight='bold')
+
+        # Customize the plot
+        ax.set_xticks(x_pos)
+        ax.set_xticklabels(['Base'] + features + ['Final'], rotation=45, ha='right')
+        ax.set_ylabel('Probability')
+        ax.set_title(f'SHAP Waterfall Plot - Prediction: {prediction_class}', fontsize=14, fontweight='bold')
+        ax.grid(True, alpha=0.3, axis='y')
+        ax.legend()
+
+        plt.tight_layout()
+        return fig
+
+    except Exception as e:
+        st.error(f"Error creating waterfall plot: {e}")
+        return None
+
+def display_shap_explanation(explanation_result):
+    """
+    Display a SHAP explanation with visualizations (only called when show_shap_visualizations=True).
+
+    Args:
+        explanation_result: Dict with SHAP explanation data
+    """
+    if explanation_result.get('type') != 'shap':
+        return
+
+    # Visual explanations - show plots
+    if 'feature_impacts' in explanation_result and explanation_result['feature_impacts']:
+
+        # Create tabs for the different visualizations
+        tab1, tab2 = st.tabs(["📊 Feature Impact", "🌊 Waterfall Analysis"])
+
+        with tab1:
+            st.write("**How each feature affects the prediction:**")
+            try:
+                fig1 = create_shap_bar_plot(
+                    explanation_result['feature_impacts'],
+                    explanation_result.get('prediction_class', 'Unknown'),
+                    "Feature Importance Analysis"
+                )
+                if fig1:
+                    st.pyplot(fig1)
+                    plt.close(fig1)  # Clean up memory
+                else:
+                    st.warning("Unable to generate feature impact chart")
+            except Exception as e:
+                st.error(f"Error creating feature impact chart: {str(e)}")
+
+        with tab2:
+            st.write("**Step-by-step impact on prediction probability:**")
+            try:
+                fig2 = create_shap_waterfall_plot(
+                    explanation_result['feature_impacts'],
+                    base_probability=0.5,
+                    prediction_class=explanation_result.get('prediction_class', 'Unknown')
+                )
+                if fig2:
+                    st.pyplot(fig2)
+                    plt.close(fig2)  # Clean up memory
+                else:
+                    st.warning("Unable to generate waterfall chart")
+            except Exception as e:
+                st.error(f"Error creating waterfall chart: {str(e)}")
+
+        # Feature impact breakdown
+        st.write("### 📋 Detailed Feature Impacts")
+        try:
+            impacts_df = pd.DataFrame({
+                'Feature Impact': explanation_result['feature_impacts']
+            })
+            st.dataframe(impacts_df, use_container_width=True)
+        except Exception as e:
+            st.error(f"Error displaying feature impacts table: {str(e)}")
+
+def explain_shap_visualizations():
+    """Provide educational content about SHAP visualizations."""
+    with st.expander("ℹ️ Understanding SHAP Visualizations"):
+        st.write("""
+        **SHAP (SHapley Additive exPlanations)** helps you understand how each feature contributed to your prediction:
+
+        **📊 Feature Impact Chart:**
+        - **Blue bars** = Features that *increase* the likelihood of approval
+        - **Red bars** = Features that *decrease* the likelihood of approval
+        - **Longer bars** = Stronger impact on the decision
+
+        **🌊 Waterfall Analysis:**
+        - Shows step by step how each feature moves the probability up or down
+        - Starts with the base probability and shows the cumulative effect
+        - The final bar shows the overall prediction probability
+
+        **Why this matters:**
+        - Understand *exactly* which factors influenced your decision
+        - See which changes would have the biggest impact
+        - Make informed decisions about improving your profile
+        """)
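A minimal standalone sketch of the bar-plot helper above. The impact strings follow the exact phrasing the parser expects ("<feature> increases/decreases the prediction probability by <value>"); the values here are made up for illustration:

from shap_visualizer import create_shap_bar_plot

feature_impacts = [
    "age increases the prediction probability by 0.150",
    "education_num increases the prediction probability by 0.080",
    "hours_per_week decreases the prediction probability by 0.040",
]
fig = create_shap_bar_plot(feature_impacts, prediction_class=">50K")
if fig:
    fig.savefig("shap_bar.png")  # inside the app this is rendered via st.pyplot(fig)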
src/streamlit_app.py
DELETED
@@ -1,40 +0,0 @@
-import altair as alt
-import numpy as np
-import pandas as pd
-import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
src/train_classifiers.py
ADDED
@@ -0,0 +1,41 @@
+import os
+import joblib
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, AdaBoostClassifier
+from sklearn.svm import SVC
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import classification_report, accuracy_score
+from preprocessing import preprocess_adult
+from load_adult_data import load_adult_data
+
+
+def train_and_evaluate(X, y, model, model_name, models_dir):
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+    model.fit(X_train, y_train)
+    y_pred = model.predict(X_test)
+    print(f"\n{model_name} Results:")
+    print(classification_report(y_test, y_pred))
+    print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
+    # Save the model
+    joblib.dump(model, os.path.join(models_dir, f'{model_name}.pkl'))
+
+if __name__ == '__main__':
+    data_dir = os.path.join(os.path.dirname(__file__), '..', 'data')
+    models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
+    os.makedirs(models_dir, exist_ok=True)
+    df, _ = load_adult_data(data_dir)
+    df_clean = preprocess_adult(df)
+    X = df_clean.drop('income', axis=1)
+    y = df_clean['income']
+
+    classifiers = [
+        (RandomForestClassifier(n_estimators=100, random_state=42), 'RandomForest'),
+        (GradientBoostingClassifier(n_estimators=100, random_state=42), 'GradientBoosting'),
+        (AdaBoostClassifier(n_estimators=100, random_state=42), 'AdaBoost'),
+        (SVC(kernel='rbf', probability=True, random_state=42), 'SVM'),
+        (LogisticRegression(max_iter=1000, random_state=42), 'LogisticRegression')
+    ]
+
+    for clf, name in classifiers:
+        train_and_evaluate(X, y, clf, name, models_dir)
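Once the script has run, the persisted models can be reloaded without retraining; a brief sketch, assuming `X` is built exactly as in the `__main__` block above:

import joblib

model = joblib.load('models/RandomForest.pkl')   # one of the five classifiers saved above
print(model.predict(X.iloc[[0]]))                # predicted income class for the first row
print(model.predict_proba(X.iloc[[0]]))          # class probabilities (SVC needs probability=True, as set above)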
src/utils.py
ADDED
@@ -0,0 +1,190 @@
+
+import copy
+import sklearn
+import sklearn.preprocessing
+import sklearn.model_selection
+import numpy as np
+import lime
+import lime.lime_tabular
+import os
+
+class Bunch(dict):
+    def __init__(self, *args, **kwargs):
+        super(Bunch, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+def load_dataset(dataset_name, balance=False, discretize=True, dataset_folder='./'):
+    if dataset_name == 'adult':
+        feature_names = ["Age", "Workclass", "fnlwgt", "Education",
+                         "Education-Num", "Marital Status", "Occupation",
+                         "Relationship", "Race", "Sex", "Capital Gain",
+                         "Capital Loss", "Hours per week", "Country", 'Income']
+        features_to_use = [0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        categorical_features = [1, 5, 6, 7, 8, 9, 13]
+        dataset = load_csv_dataset(
+            os.path.join(dataset_folder, 'adult/adult.data'), -1, ', ',
+            feature_names=feature_names, features_to_use=features_to_use,
+            categorical_features=categorical_features, discretize=discretize,
+            balance=balance, feature_transformations=None)
+    elif dataset_name == 'german-credit':
+        categorical_features = [1, 2, 3, 4, 5, 8]
+        dataset = load_csv_dataset(
+            os.path.join(dataset_folder, 'german-credit/german_credit_data.csv'), -1, ',',
+            categorical_features=categorical_features, discretize=discretize,
+            balance=balance)
+    else:
+        raise ValueError(f"Unsupported dataset: {dataset_name}")
+    return dataset
+
+def load_csv_dataset(data, target_idx, delimiter=',',
+                     feature_names=None, categorical_features=None,
+                     features_to_use=None, feature_transformations=None,
+                     discretize=False, balance=False, fill_na='-1', filter_fn=None, skip_first=False):
+    if feature_transformations is None:
+        feature_transformations = {}
+    try:
+        data = np.genfromtxt(data, delimiter=delimiter, dtype='|S128')
+    except:
+        import pandas
+        data = pandas.read_csv(data,
+                               header=None,
+                               delimiter=delimiter,
+                               na_filter=True,
+                               dtype=str).fillna(fill_na).values
+    if target_idx < 0:
+        target_idx = data.shape[1] + target_idx
+    ret = Bunch({})
+    if feature_names is None:
+        feature_names = list(data[0])
+        data = data[1:]
+    else:
+        feature_names = copy.deepcopy(feature_names)
+    if skip_first:
+        data = data[1:]
+    if filter_fn is not None:
+        data = filter_fn(data)
+    for feature, fun in feature_transformations.items():
+        data[:, feature] = fun(data[:, feature])
+    labels = data[:, target_idx]
+    le = sklearn.preprocessing.LabelEncoder()
+    le.fit(labels)
+    ret['labels'] = le.transform(labels)
+    labels = ret['labels']
+    ret['class_names'] = list(le.classes_)
+    ret['class_target'] = feature_names[target_idx]
+    if features_to_use is not None:
+        data = data[:, features_to_use]
+        feature_names = ([x for i, x in enumerate(feature_names)
+                          if i in features_to_use])
+        if categorical_features is not None:
+            categorical_features = ([features_to_use.index(x)
+                                     for x in categorical_features])
+    else:
+        data = np.delete(data, target_idx, 1)
+        feature_names.pop(target_idx)
+        if categorical_features:
+            categorical_features = ([x if x < target_idx else x - 1
+                                     for x in categorical_features])
+    if categorical_features is None:
+        categorical_features = []
+        for f in range(data.shape[1]):
+            if len(np.unique(data[:, f])) < 20:
+                categorical_features.append(f)
+    categorical_names = {}
+    for feature in categorical_features:
+        le = sklearn.preprocessing.LabelEncoder()
+        le.fit(data[:, feature])
+        data[:, feature] = le.transform(data[:, feature])
+        categorical_names[feature] = le.classes_
+    data = data.astype(float)
+    ordinal_features = []
+    if discretize:
+        disc = lime.lime_tabular.QuartileDiscretizer(data,
+                                                     categorical_features,
+                                                     feature_names)
+        data = disc.discretize(data)
+        ordinal_features = [x for x in range(data.shape[1])
+                            if x not in categorical_features]
+        categorical_features = list(range(data.shape[1]))
+        categorical_names.update(disc.names)
+    for x in categorical_names:
+        categorical_names[x] = [y.decode() if type(y) == np.bytes_ else y for y in categorical_names[x]]
+    ret['ordinal_features'] = ordinal_features
+    ret['categorical_features'] = categorical_features
+    ret['categorical_names'] = categorical_names
+    ret['feature_names'] = feature_names
+    np.random.seed(1)
+    if balance:
+        idxs = np.array([], dtype='int')
+        min_labels = np.min(np.bincount(labels))
+        for label in np.unique(labels):
+            idx = np.random.choice(np.where(labels == label)[0], min_labels)
+            idxs = np.hstack((idxs, idx))
+        data = data[idxs]
+        labels = labels[idxs]
+        ret['data'] = data
+        ret['labels'] = labels
+    splits = sklearn.model_selection.ShuffleSplit(n_splits=1,
+                                                  test_size=.2,
+                                                  random_state=1)
+    train_idx, test_idx = [x for x in splits.split(data)][0]
+    ret['train'] = data[train_idx]
+    ret['labels_train'] = labels[train_idx]
+    cv_splits = sklearn.model_selection.ShuffleSplit(n_splits=1,
+                                                     test_size=.5,
+                                                     random_state=1)
+    cv_idx, ntest_idx = [x for x in cv_splits.split(test_idx)][0]
+    cv_idx = test_idx[cv_idx]
+    test_idx = test_idx[ntest_idx]
+    ret['validation'] = data[cv_idx]
+    ret['labels_validation'] = labels[cv_idx]
+    ret['test'] = data[test_idx]
+    ret['labels_test'] = labels[test_idx]
+    ret['test_idx'] = test_idx
+    ret['validation_idx'] = cv_idx
+    ret['train_idx'] = train_idx
+    ret['data'] = data
+    return ret
+
+import logging
+
+def print_log(turn, msg=None, state=None):
+    if turn == "xagent":
+        print(f"\033[1m\033[94mX-Agent:\033[0m")
+        if msg is not None:
+            print(msg)
+    if turn == "user":
+        print('\033[91m\033[1mUser:\033[0m')
+        msg = input()
+    logging.log(25, f"{turn}: {msg}")
+    if state is not None:
+        logging.log(25, state)
+    return msg
+
+def ask_for_feature(agent):
+    if len(agent.l_exist_features) == 0:
+        msg = "which feature?"
+        print_log("xagent", msg)
+        user_input = print_log("user")
+        while user_input not in agent.l_features:
+            msg = f"please choose one of the following features: {agent.l_features}"
+            print_log("xagent", msg)
+            user_input = print_log("user")
+        agent.l_exist_features.append(user_input)
+
+def map_array_values(array, value_map):
+    ret = array.copy()
+    for src, target in value_map.items():
+        ret[ret == src] = target
+    return ret
+
+def replace_binary_values(array, values):
+    return map_array_values(array, {'0': values[0], '1': values[1]})
+
+def log_user_feedback(feedback, save_path):
+    # Save feedback to a local file (append mode)
+    try:
+        with open(save_path, 'a', encoding='utf-8') as f:
+            f.write(str(feedback) + '\n')
+    except Exception as e:
+        print(f"Error saving feedback: {e}")
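A usage sketch for `load_dataset` above, assuming an `adult/adult.data` layout under the chosen `dataset_folder` (note this repo ships `data/adult.data`, so the path may need adjusting):

from utils import load_dataset

dataset = load_dataset('adult', balance=True, discretize=True, dataset_folder='./data')
print(dataset.class_names)                              # label names from the encoder
print(dataset.train.shape, dataset.labels_train.shape)  # Bunch exposes its keys as attributes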
src/xai_methods.py
ADDED
|
@@ -0,0 +1,1028 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import shap
|
| 2 |
+
import numpy as np
|
| 3 |
+
import dice_ml
|
| 4 |
+
from anchor import anchor_tabular
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import os
|
| 7 |
+
from constraints import *
|
| 8 |
+
|
| 9 |
+
# Mode selection: 'full' requires dtreeviz; 'lite' skips it (good for Streamlit)
|
| 10 |
+
_MODE = os.getenv('HICXAI_MODE', 'lite').strip().lower()
|
| 11 |
+
|
# User-friendly feature name mappings (for international users)
FEATURE_DISPLAY_NAMES = {
    # Workclass (employment type)
    'workclass_Private': 'Private sector',
    'workclass_Self-emp-not-inc': 'Self-employed',
    'workclass_Self-emp-inc': 'Self-employed (business owner)',
    'workclass_Federal-gov': 'Federal government',
    'workclass_Local-gov': 'Local government',
    'workclass_State-gov': 'State government',
    'workclass_Without-pay': 'Unpaid work',
    'workclass_Never-worked': 'Never worked',

    # Education
    'education_Preschool': 'Preschool',
    'education_1st-4th': 'Elementary (1-4 years)',
    'education_5th-6th': 'Elementary (5-6 years)',
    'education_7th-8th': 'Middle school (7-8 years)',
    'education_9th': 'High school (9th year)',
    'education_10th': 'High school (10th year)',
    'education_11th': 'High school (11th year)',
    'education_12th': 'High school (12th year)',
    'education_HS-grad': 'High school graduate',
    'education_Some-college': 'Some college',
    'education_Assoc-voc': 'Vocational degree',
    'education_Assoc-acdm': 'Associate degree',
    'education_Bachelors': 'Bachelor\'s degree',
    'education_Masters': 'Master\'s degree',
    'education_Prof-school': 'Professional degree',
    'education_Doctorate': 'Doctorate',
    'education_num': 'Education level',

    # Marital status
    'marital_status_Married-civ-spouse': 'Married',
    'marital_status_Married-spouse-absent': 'Married (separated)',
    'marital_status_Married-AF-spouse': 'Married (military)',
    'marital_status_Never-married': 'Never married',
    'marital_status_Divorced': 'Divorced',
    'marital_status_Separated': 'Separated',
    'marital_status_Widowed': 'Widowed',

    # Occupation
    'occupation_Tech-support': 'Technical support',
    'occupation_Craft-repair': 'Skilled trades',
    'occupation_Other-service': 'Service worker',
    'occupation_Sales': 'Sales',
    'occupation_Exec-managerial': 'Executive/Manager',
    'occupation_Prof-specialty': 'Professional',
    'occupation_Handlers-cleaners': 'Handler/Cleaner',
    'occupation_Machine-op-inspct': 'Machine operator',
    'occupation_Adm-clerical': 'Administrative',
    'occupation_Farming-fishing': 'Farming/Fishing',
    'occupation_Transport-moving': 'Transportation',
    'occupation_Priv-house-serv': 'Household service',
    'occupation_Protective-serv': 'Protective services',
    'occupation_Armed-Forces': 'Military',

    # Relationship
    'relationship_Husband': 'Husband',
    'relationship_Wife': 'Wife',
    'relationship_Own-child': 'Child',
    'relationship_Not-in-family': 'Not in family',
    'relationship_Other-relative': 'Other relative',
    'relationship_Unmarried': 'Unmarried partner',

    # Race/Ethnicity
    'race_White': 'White',
    'race_Black': 'Black',
    'race_Asian-Pac-Islander': 'Asian/Pacific Islander',
    'race_Amer-Indian-Eskimo': 'Indigenous American',
    'race_Other': 'Other',

    # Sex
    'sex_Male': 'Male',
    'sex_Female': 'Female',

    # Native Country
    'native_country_United-States': 'United States',
    'native_country_Cambodia': 'Cambodia',
    'native_country_Canada': 'Canada',
    'native_country_China': 'China',
    'native_country_Columbia': 'Colombia',
    'native_country_Cuba': 'Cuba',
    'native_country_Dominican-Republic': 'Dominican Republic',
    'native_country_Ecuador': 'Ecuador',
    'native_country_El-Salvador': 'El Salvador',
    'native_country_England': 'England',
    'native_country_France': 'France',
    'native_country_Germany': 'Germany',
    'native_country_Greece': 'Greece',
    'native_country_Guatemala': 'Guatemala',
    'native_country_Haiti': 'Haiti',
    'native_country_Holand-Netherlands': 'Netherlands',
    'native_country_Honduras': 'Honduras',
    'native_country_Hong': 'Hong Kong',
    'native_country_Hungary': 'Hungary',
    'native_country_India': 'India',
    'native_country_Iran': 'Iran',
    'native_country_Ireland': 'Ireland',
    'native_country_Italy': 'Italy',
    'native_country_Jamaica': 'Jamaica',
    'native_country_Japan': 'Japan',
    'native_country_Laos': 'Laos',
    'native_country_Mexico': 'Mexico',
    'native_country_Nicaragua': 'Nicaragua',
    'native_country_Outlying-US(Guam-USVI-etc)': 'US Territory (Guam, Virgin Islands)',
    'native_country_Peru': 'Peru',
    'native_country_Philippines': 'Philippines',
    'native_country_Poland': 'Poland',
    'native_country_Portugal': 'Portugal',
    'native_country_Puerto-Rico': 'Puerto Rico',
    'native_country_Scotland': 'Scotland',
    'native_country_South': 'South Korea',
    'native_country_Taiwan': 'Taiwan',
    'native_country_Thailand': 'Thailand',
    'native_country_Trinadad&Tobago': 'Trinidad & Tobago',
    'native_country_Vietnam': 'Vietnam',
    'native_country_Yugoslavia': 'Former Yugoslavia',

    # Numerical features
    'age': 'Age',
    'fnlwgt': 'Census weight',
    'capital_gain': 'Capital gains',
    'capital_loss': 'Capital losses',
    'hours_per_week': 'Work hours per week',
}
def get_friendly_feature_name(feature_name):
    """Convert technical feature name to user-friendly display name"""
    return FEATURE_DISPLAY_NAMES.get(feature_name, feature_name.replace('_', ' ').title())
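# Illustrative usage (mapped values come from FEATURE_DISPLAY_NAMES above;
# unmapped names fall back to title-cased words):
#   >>> get_friendly_feature_name('occupation_Exec-managerial')
#   'Executive/Manager'
#   >>> get_friendly_feature_name('some_new_feature')
#   'Some New Feature'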
# Visualization deps
try:
    import dtreeviz  # noqa: F401
    import graphviz  # noqa: F401
    _DTREEVIZ_AVAILABLE = True
except Exception:
    _DTREEVIZ_AVAILABLE = False
    if _MODE == 'full':
        raise ImportError(
            "dtreeviz/graphviz are required in FULL mode. Install with conda: "
            "'conda install -c conda-forge graphviz python-graphviz' and pip: 'pip install dtreeviz'"
        )
def explain_with_shap(agent, question_id=None):
    """SHAP explanation using actual SHAP values from the model"""
    try:
        from ab_config import config
        import pandas as pd

        predicted_class = getattr(agent, 'predicted_class', 'unknown')
        current_instance = agent.current_instance

        # Get LOCAL SHAP values in probability space:
        # this shows how much each feature contributed to THIS user's prediction.
        # Note: agent.data['X_display'] contains RAW data; the model was trained on PREPROCESSED data.
        # Get feature names from the trained model.
        if hasattr(agent.clf_display, 'feature_names_in_'):
            feature_names = agent.clf_display.feature_names_in_.tolist()
        else:
            # Fallback: use raw feature names (will likely fail if the model was trained on encoded data)
            feature_names = agent.data['X_display'].columns.tolist()

        shap_values_computed = None
        instance_df = None
        shap_contributions = {}  # Feature -> contribution in probability space (percentage points)
        base_value = None
        pred_prob = None

        # Compute SHAP in probability space (fast: TreeExplainer avoids the hanging issue)
        try:
            # Prepare instance data.
            # current_instance should already be preprocessed (with one-hot encoded columns).
            if current_instance is not None:
                if hasattr(current_instance, 'to_frame'):
                    instance_df = current_instance.to_frame().T
                elif hasattr(current_instance, 'to_dict'):
                    instance_df = pd.DataFrame([current_instance.to_dict()])
                elif isinstance(current_instance, dict):
                    instance_df = pd.DataFrame([current_instance])
                else:
                    instance_df = pd.DataFrame([current_instance])

            # Ensure the column order matches the training data.
            # Add missing columns with 0 (for one-hot encoded features not present).
            for col in feature_names:
                if col not in instance_df.columns:
                    instance_df[col] = 0
            instance_df = instance_df[feature_names]

            # Initialize TreeExplainer (returns probability space for RandomForest)
            explainer = shap.TreeExplainer(agent.clf_display)

            # Compute local SHAP values for this instance
            shap_values = explainer.shap_values(instance_df)
            base_value_raw = explainer.expected_value

            # Get the predicted probability
            pred_prob = float(agent.clf_display.predict_proba(instance_df)[0, 1])

            # Extract SHAP contributions (percentage points) for the positive class.
            # TreeExplainer returns probabilities directly for tree-based models.
            if isinstance(shap_values, list):
                # Binary classification: [negative_class_shap, positive_class_shap]
                shap_vals_array = shap_values[1][0]
                base_value = float(base_value_raw[1])
            else:
                # Shape: (n_samples, n_features, n_classes) or (n_features, n_classes)
                if len(shap_values.shape) == 3:
                    shap_vals_array = shap_values[0, :, 1]
                    base_value = float(base_value_raw[1])
                else:
                    shap_vals_array = shap_values[:, 1]
                    base_value = float(base_value_raw[1])

            # Store contributions in a dictionary
            for idx, feature in enumerate(feature_names):
                shap_contributions[feature] = float(shap_vals_array[idx])

            shap_values_computed = shap_vals_array

            # Sanity check: contributions should sum approximately to the prediction
            approx_prob = base_value + sum(shap_contributions.values())
            if abs(approx_prob - pred_prob) > 0.05:
                print(f"Warning: SHAP additivity check: {approx_prob:.3f} vs {pred_prob:.3f}")

        except Exception as e:
            print(f"SHAP computation failed: {e}")
            # Fall back to feature importances
            if hasattr(agent.clf_display, 'feature_importances_'):
                importances = agent.clf_display.feature_importances_
                for idx, feature in enumerate(feature_names):
                    if importances[idx] > 0.001:
                        shap_contributions[feature] = float(importances[idx])
            # Get the prediction probability for the fallback
            if instance_df is not None:
                pred_prob = float(agent.clf_display.predict_proba(instance_df)[0, 1])
            base_value = 0.5  # Reasonable baseline

        # Build a natural language explanation with actual user values
        feature_impacts = []
        positive_factors = []
        negative_factors = []

        # Convert a Series to a dict if needed for easier access
        instance_dict = None
        if current_instance is not None:
            if hasattr(current_instance, 'to_dict'):
                instance_dict = current_instance.to_dict()
            elif isinstance(current_instance, dict):
                instance_dict = current_instance
            else:
                # Fallback: try to convert to a dict
                try:
                    instance_dict = dict(current_instance)
                except Exception:
                    instance_dict = {}

        # For categorical features that are one-hot encoded, we need to find the original value
        # by checking which encoded column has value 1
        def get_categorical_value(feature_base):
            """Extract original categorical value from one-hot encoded columns"""
            if not instance_dict:
                return None
            # Look for columns like 'workclass_Private', 'workclass_Self-emp-not-inc'
            matching_cols = [col for col in instance_dict.keys() if col.startswith(f"{feature_base}_")]
            for col in matching_cols:
                if instance_dict.get(col) == 1 or instance_dict.get(col) == 1.0:
                    # Extract the value after the underscore
                    return col.split(f"{feature_base}_", 1)[1] if "_" in col else None
            return None

        # Check whether we have any SHAP contribution data
        if not shap_contributions:
            return {
                'type': 'error',
                'explanation': "Unable to compute SHAP contributions. The model may not have sufficient data.",
                'error': 'No SHAP values computed'
            }

        # Sort by absolute contribution (most impactful features first)
        sorted_features = sorted(shap_contributions.items(), key=lambda x: abs(x[1]), reverse=True)

        # Prioritize capital_gain if the user has significant gains (move it to the top of the list)
        capital_gain_val = instance_dict.get('capital_gain', 0) if instance_dict else 0
        if capital_gain_val > 5000:  # Significant capital gains
            # Find capital_gain in the sorted features and move it to the front
            capital_idx = next((i for i, (f, _) in enumerate(sorted_features) if f == 'capital_gain'), None)
            if capital_idx is not None and capital_idx > 0:
                capital_item = sorted_features.pop(capital_idx)
                sorted_features.insert(0, capital_item)

        for feature, impact in sorted_features[:15]:  # Check more features to get enough valid ones
            # Skip technical features first (before any processing)
            if feature in ['fnlwgt', 'education_num']:  # fnlwgt is a census weight; education_num duplicates education
                continue

            # Check whether this is a one-hot encoded feature (e.g., workclass_Private)
            categorical_prefixes = ['workclass_', 'education_', 'marital_status_', 'occupation_',
                                    'relationship_', 'race_', 'sex_', 'native_country_']

            is_onehot = any(feature.startswith(prefix) for prefix in categorical_prefixes)

            if is_onehot:
                # Extract base feature and value (e.g., 'workclass_Private' -> base='workclass', value='Private')
                for prefix in categorical_prefixes:
                    if feature.startswith(prefix):
                        feature_base = prefix.rstrip('_')
                        actual_value = feature.replace(prefix, '')
                        break
            else:
                # Regular numeric feature
                actual_value = instance_dict.get(feature, None) if instance_dict else None
                feature_base = feature

            # Skip if the value is missing
            if actual_value is None or str(actual_value).strip() == '':
                continue

            # Create a natural language description using the global FEATURE_DISPLAY_NAMES
            friendly_feature = get_friendly_feature_name(feature if not is_onehot else feature_base)

            # Format the value with appropriate units/formatting
            if feature_base == 'age':
                formatted_value = f"{actual_value} years old"
            elif feature_base == 'hours_per_week':
                formatted_value = f"{actual_value} hours per week"
            elif feature_base == 'capital_gain' or feature_base == 'capital_loss':
                formatted_value = f"${actual_value:,}" if isinstance(actual_value, (int, float)) else str(actual_value)
            else:
                formatted_value = str(actual_value)

            factor_desc = f"Your {friendly_feature.lower()} ({formatted_value})"

            if impact > 0:
                positive_factors.append(factor_desc)
                feature_impacts.append(f"{feature} increases the prediction probability by {impact:.3f}")
            else:
                negative_factors.append(factor_desc)
                feature_impacts.append(f"{feature} decreases the prediction probability by {abs(impact):.3f}")

            # Stop once we have enough features (8-10 total)
            if len(positive_factors) + len(negative_factors) >= 10:
                break

        # Generate an explanation with reasoning based on approval/denial.
        # Extract key values for the reasoning.
        def fmt_money(x):
            return f"${x:,.0f}" if isinstance(x, (int, float)) else "N/A"

        cg = instance_dict.get('capital_gain') if instance_dict else None
        cl = instance_dict.get('capital_loss') if instance_dict else None
        age = instance_dict.get('age') if instance_dict else None
        hrs = instance_dict.get('hours_per_week') if instance_dict else None
        edu = instance_dict.get('education') if instance_dict else None

        # Determine approval from the actual loan decision, not the model prediction:
        # the model predicts income level (>50K or <=50K), while loan approval is a separate business decision.
        if hasattr(agent, 'loan_approved') and agent.loan_approved is not None:
            approved = agent.loan_approved
        elif predicted_class in ['>50K', '1']:
            # If >50K income, likely approved
            approved = True
        else:
            # If <=50K income, likely denied
            approved = False

        # Build the explanation with reasoning.
        # Key insight: all features except capital_loss are positively correlated with approval;
        # they might not be "enough", but they don't hurt. Only capital_loss can truly hurt.

        # Collect the top features with their values
        top_feature_list = []
        for feature, impact in sorted_features[:8]:
            # Get the actual value
            if instance_dict and feature in instance_dict:
                value = instance_dict[feature]
            else:
                # Handle one-hot encoded features
                for prefix in ['workclass_', 'education_', 'marital_status_', 'occupation_', 'relationship_', 'race_', 'sex_', 'native_country_']:
                    if feature.startswith(prefix):
                        value = feature.replace(prefix, '')
                        break
                else:
                    value = None

            if value is not None:
                top_feature_list.append((feature, value, impact))

        # Approval threshold
        tau = 0.50
        gap_to_threshold = max(0.0, tau - pred_prob) if pred_prob is not None else 0.0

        # ===== DATA-DRIVEN APPROACH: extract structured data for the LLM =====
        # Separate positive and negative contributions
        positive_contribs = [(f, v, delta) for f, v, delta in top_feature_list if delta > 0]
        negative_contribs = [(f, v, delta) for f, v, delta in top_feature_list if delta < 0]

        # Build the structured data dictionary
        structured_data = {
            'decision': 'approved' if approved else 'denied',
            'base_probability': f"{base_value*100:.1f}%" if base_value is not None else "N/A",
            'predicted_probability': f"{pred_prob*100:.1f}%" if pred_prob is not None else "N/A",
            'threshold': f"{tau*100:.0f}%",
            'gap_to_threshold': f"{gap_to_threshold*100:.1f} pts" if gap_to_threshold > 0 else "0.0 pts",
            'total_adjustment': f"{(pred_prob - base_value)*100:+.1f} pts" if (pred_prob is not None and base_value is not None) else "N/A",
            'positive_factors': [],
            'negative_factors': []
        }

        # Format the positive contributors
        for feature, value, delta in positive_contribs[:5]:
            friendly_name = get_friendly_feature_name(feature)
            factor_entry = {
                'feature': friendly_name,
                'impact': f"+{delta*100:.1f} pts",
                'impact_numeric': delta * 100
            }
            if 'capital_gain' in feature or 'capital_loss' in feature:
                factor_entry['value'] = fmt_money(value)
            elif 'hours' in feature:
                factor_entry['value'] = f"{value} hours/week"
            elif 'age' in feature:
                factor_entry['value'] = f"{value} years"
            else:
                factor_entry['value'] = str(value)
            structured_data['positive_factors'].append(factor_entry)

        # Format the negative contributors
        for feature, value, delta in negative_contribs[:5]:
            friendly_name = get_friendly_feature_name(feature)
            factor_entry = {
                'feature': friendly_name,
                'impact': f"{delta*100:.1f} pts",
                'impact_numeric': delta * 100
            }
            if 'capital_gain' in feature or 'capital_loss' in feature:
                factor_entry['value'] = fmt_money(value)
            elif 'hours' in feature:
                factor_entry['value'] = f"{value} hours/week"
            elif 'age' in feature:
                factor_entry['value'] = f"{value} years"
            else:
                factor_entry['value'] = str(value)
            structured_data['negative_factors'].append(factor_entry)

        # Generate the explanation from data using the LLM (respects the anthropomorphism condition)
        explanation = None
        try:
            from natural_conversation import generate_from_data

            print(f"🤖 DEBUG: Generating SHAP explanation from data (anthropomorphic={config.show_anthropomorphic})...")

            explanation = generate_from_data(
                data=structured_data,
                explanation_type='shap',
                high_anthropomorphism=config.show_anthropomorphic
            )

            if explanation and len(explanation) > 50:
                print(f"✅ DEBUG: Generated explanation ({len(explanation)} chars)")
            else:
                print("⚠️ DEBUG: LLM generation failed or output too short")
                explanation = None

        except Exception as e:
            print(f"❌ DEBUG: LLM generation failed: {e}")
            explanation = None

        # Fallback templates if the LLM fails (preserves the experimental conditions)
        if not explanation:
            print("⚠️ DEBUG: Using fallback template")
            if config.show_anthropomorphic:
                # High-anthropomorphism fallback
                if approved:
                    explanation = "Thanks for waiting — your application was approved! 🎉\n\n"
                    explanation += f"Starting from {structured_data['base_probability']}, key factors helped:\n"
                    for factor in structured_data['positive_factors'][:4]:
                        explanation += f"• {factor['feature']} ({factor['value']}): **{factor['impact']}**\n"
                    explanation += f"\nFinal score: **{structured_data['predicted_probability']}** (threshold: {structured_data['threshold']}) ✨"
                else:
                    explanation = "I'm sorry this wasn't the news you were hoping for. 😔\n\n"
                    explanation += f"Starting from {structured_data['base_probability']}, here's what happened:\n\n"
                    if structured_data['positive_factors']:
                        explanation += "**What helped:**\n"
                        for factor in structured_data['positive_factors'][:3]:
                            explanation += f"• {factor['feature']} ({factor['value']}): **{factor['impact']}**\n"
                    if structured_data['negative_factors']:
                        explanation += "\n**What held back:**\n"
                        for factor in structured_data['negative_factors'][:2]:
                            explanation += f"• {factor['feature']} ({factor['value']}): **{factor['impact']}**\n"
                    explanation += f"\nFinal score: **{structured_data['predicted_probability']}** (needed: {structured_data['threshold']}, gap: {structured_data['gap_to_threshold']}) 💙"
            else:
                # Low-anthropomorphism fallback
                if approved:
                    explanation = "**Feature Impact Analysis**\n\n"
                    explanation += f"**Baseline Probability:** {structured_data['base_probability']}\n\n"
                    explanation += "**Key Contributing Factors:**\n"
                    for factor in structured_data['positive_factors'][:5]:
                        explanation += f"• **{factor['feature']}:** {factor['impact']} (value: {factor['value']})\n"
                    explanation += "\n**Decision Summary:**\n"
                    explanation += f"Factors increased probability by {structured_data['total_adjustment']} to **{structured_data['predicted_probability']}**, "
                    explanation += f"exceeding the **{structured_data['threshold']}** approval threshold."
                else:
                    explanation = "**Feature Impact Analysis**\n\n"
                    explanation += f"**Baseline Probability:** {structured_data['base_probability']}\n\n"
                    if structured_data['positive_factors']:
                        explanation += "**Positive Factors** (increased approval probability):\n"
                        for factor in structured_data['positive_factors'][:5]:
                            explanation += f"• **{factor['feature']}:** {factor['impact']} (value: {factor['value']})\n"
                        explanation += "\n"
                    if structured_data['negative_factors']:
                        explanation += "**Negative Factors** (decreased approval probability):\n"
                        for factor in structured_data['negative_factors'][:5]:
                            explanation += f"• **{factor['feature']}:** {factor['impact']} (value: {factor['value']})\n"
                        explanation += "\n"
                    explanation += "**Decision Summary:**\n"
                    explanation += f"Profile factors adjusted probability by {structured_data['total_adjustment']} to **{structured_data['predicted_probability']}**. "
                    explanation += f"Approval threshold: **{structured_data['threshold']}**, shortfall: **{structured_data['gap_to_threshold']}**."

        result = {
            'type': 'shap',
            'explanation': explanation,
            'feature_impacts': feature_impacts,
            'prediction_class': predicted_class,
            'method': 'local_shap_probability_space',
            'shap_contributions': shap_contributions,
            'base_value': base_value,
            'predicted_probability': pred_prob,
            'threshold': tau,
            'gap_to_threshold': gap_to_threshold
        }

        # Include the SHAP values if they were successfully computed (needed for visualizations)
        if shap_values_computed is not None:
            result['shap_values'] = shap_values_computed
            result['instance_df'] = instance_df

        return result

    except Exception as e:
        return {
            'type': 'error',
            'explanation': f"Feature importance analysis unavailable: {str(e)}",
            'error': str(e)
        }
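# Illustrative call pattern. The `agent` interface is assumed from this module's
# own usage: it needs `clf_display`, `data['X_display']`, `current_instance`, and
# optionally `predicted_class` / `loan_approved`; `st` is the Streamlit API used
# by the surrounding app.
#
#   result = explain_with_shap(agent)
#   if result['type'] == 'shap':
#       st.markdown(result['explanation'])  # chat-style text
#       top = sorted(result['shap_contributions'].items(),
#                    key=lambda kv: abs(kv[1]), reverse=True)[:5]
#   else:
#       st.warning(result['explanation'])   # error / fallback path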
def explain_with_shap_advanced(agent, instance_df):
    """Generate SHAP force plot and summary plot for the given instance."""
    try:
        explainer = shap.Explainer(agent.clf_display, agent.data['X_display'])
        shap_values = explainer(instance_df)
        # SHAP JS visualization (force plot)
        shap.initjs()
        force_plot = shap.force_plot(explainer.expected_value, shap_values.values[0], instance_df.iloc[0], matplotlib=True, show=False)
        # SHAP summary plot
        plt.figure()
        shap.summary_plot(shap_values.values, instance_df, show=False)
        summary_fig = plt.gcf()
        plt.close()
        return {
            'type': 'shap_advanced',
            'force_plot': force_plot,
            'summary_fig': summary_fig,
            'explanation': 'SHAP force plot and summary plot generated.'
        }
    except Exception as e:
        return {
            'type': 'error',
            'explanation': f"Could not generate SHAP advanced visualizations: {str(e)}",
            'error': str(e)
        }
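# Rendering sketch for the Streamlit app (assumed from this repo's setup):
#
#   adv = explain_with_shap_advanced(agent, instance_df)
#   if adv['type'] == 'shap_advanced':
#       st.pyplot(adv['summary_fig'])
#       # With matplotlib=True, shap.force_plot draws into the current figure,
#       # so st.pyplot(plt.gcf()) is one way to surface the force plot as well.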
def explain_with_dice(agent, target_class=None, features='all'):
    """DiCE counterfactuals using the actual DiCE library to generate counterfactuals"""
    try:
        from ab_config import config
        import pandas as pd

        current_pred = getattr(agent, 'predicted_class', 'unknown')
        target_class = target_class or ('<=50K' if current_pred == '>50K' else '>50K')
        current_instance = agent.current_instance

        changes = []

        # Try to use the actual DiCE library
        try:
            # Prepare the data for DiCE
            X_train = agent.data['X_display']
            y_train = agent.data['y_display']

            # Create the dataset for DiCE
            train_df = pd.concat([X_train, y_train], axis=1)

            # Define continuous and categorical features
            continuous_features = ['age', 'hours_per_week', 'capital_gain', 'capital_loss', 'education_num']
            categorical_features = [col for col in X_train.columns if col not in continuous_features]

            # Create the DiCE data object
            d = dice_ml.Data(
                dataframe=train_df,
                continuous_features=continuous_features,
                outcome_name='income'
            )

            # Create the DiCE model
            m = dice_ml.Model(model=agent.clf_display, backend='sklearn')

            # Create the DiCE explainer
            exp = dice_ml.Dice(d, m, method='random')

            # Get the current instance as a one-row dataframe (works for dicts and Series alike)
            query_instance = pd.DataFrame([current_instance])

            # Ensure all features are present
            for col in X_train.columns:
                if col not in query_instance.columns:
                    query_instance[col] = 0
            query_instance = query_instance[X_train.columns]

            # Generate counterfactuals
            target_value = 1 if '>50K' in target_class else 0
            dice_exp = exp.generate_counterfactuals(
                query_instance,
                total_CFs=3,
                desired_class=target_value
            )

            # Extract changes from the counterfactuals using natural language
            cf_df = dice_exp.cf_examples_list[0].final_cfs_df

            # Check whether counterfactuals were generated (handle the DataFrame properly)
            has_cf = cf_df is not None and isinstance(cf_df, pd.DataFrame) and len(cf_df) > 0
            if has_cf:
                # Compare with the original instance and format naturally
                for col in query_instance.columns:
                    # Extract scalar values properly
                    orig_val = query_instance[col].values[0]
                    cf_val = cf_df[col].values[0] if hasattr(cf_df[col], 'values') else cf_df[col].iloc[0]

                    # Convert to comparable types and check for a difference
                    try:
                        # Handle numeric comparison
                        if isinstance(orig_val, (int, float, np.number)) and isinstance(cf_val, (int, float, np.number)):
                            is_different = float(orig_val) != float(cf_val)
                        else:
                            # Handle string/categorical comparison
                            is_different = str(orig_val) != str(cf_val)
                    except Exception:
                        is_different = False

                    if is_different:
                        # Format with natural language using the global FEATURE_DISPLAY_NAMES
                        friendly_name = get_friendly_feature_name(col)

                        # Format values with appropriate units
                        if col == 'age':
                            from_val = f"{orig_val} years old"
                            to_val = f"{cf_val} years old"
                        elif col == 'hours_per_week':
                            from_val = f"{orig_val} hours per week"
                            to_val = f"{cf_val} hours per week"
                        elif 'capital' in col:
                            from_val = f"${orig_val:,}" if isinstance(orig_val, (int, float)) else str(orig_val)
                            to_val = f"${cf_val:,}" if isinstance(cf_val, (int, float)) else str(cf_val)
                        else:
                            from_val = str(orig_val)
                            to_val = str(cf_val)

                        changes.append(f"Your {friendly_name.lower()} (changing from {from_val} to {to_val})")

        except Exception as dice_error:
            # Fall through to the rule-based analysis below if DiCE fails
            print(f"DiCE generation failed, using rule-based fallback: {dice_error}")

        # If DiCE didn't generate changes or failed, use the rule-based system with natural language
        if not changes and current_instance is not None:
            # Convert a Series to a dict if needed
            if hasattr(current_instance, 'to_dict'):
                current_instance = current_instance.to_dict()

            # Check education level
            current_education = str(current_instance.get('education', '')).lower()
            current_education_num = current_instance.get('education_num', 0)
            if current_education_num < 13:  # Less than Bachelor's
                if 'hs-grad' in current_education or 'high school' in current_education:
                    changes.append("Your education level (completing a Bachelor's degree)")
                elif current_education_num < 9:
                    changes.append("Your education level (completing High School and pursuing higher education)")
                else:
                    changes.append("Your education level (pursuing a Bachelor's or higher degree)")

            # Check occupation
            current_occupation = str(current_instance.get('occupation', '')).lower()
            if current_occupation and 'exec' not in current_occupation and 'prof' not in current_occupation and 'managerial' not in current_occupation:
                changes.append(f"Your occupation (moving from {current_occupation} to a professional or managerial role)")
            elif not current_occupation:
                changes.append("Your occupation (moving to a professional or managerial role)")

            # Check working hours
            current_hours = current_instance.get('hours_per_week', 0)
            if current_hours < 40:
                changes.append(f"Your work schedule (increasing from {current_hours} to 40+ hours per week)")
            elif current_hours < 50:
                changes.append(f"Your work schedule (increasing from {current_hours} to 50+ hours per week)")

            # Check marital status ('married' is a substring of 'never-married', so match the prefix)
            current_marital = str(current_instance.get('marital_status', '')).lower()
            if current_marital and not current_marital.startswith('married'):
                changes.append(f"Your marital status (currently {current_marital})")
            elif not current_marital:
                changes.append("Your marital status (married status associated with better outcomes)")

            # Check capital gain
            current_capital_gain = current_instance.get('capital_gain', 0)
            if current_capital_gain < 5000:
                changes.append(f"Your capital gains (increasing from ${current_capital_gain} to $5,000 or more)")

            # Check age
            current_age = current_instance.get('age', 0)
            if current_age < 35:
                changes.append(f"Your age (being {current_age} years old)")

        # Fallback if no changes were generated
        if not changes:
            changes = [
                "Your education level (pursuing a Bachelor's or Master's degree)",
                "Your occupation (moving into a professional or managerial role)",
                "Your work schedule (working full-time, 40+ hours per week)"
            ]

        # ===== DATA-DRIVEN APPROACH: extract structured data for the LLM =====
        structured_data = {
            'decision': current_pred,
            'target_class': target_class,
            'num_changes': len(changes),
            'suggested_changes': changes[:5],
            'is_denied': 'not' in str(current_pred).lower() or 'denied' in str(current_pred).lower() or '<' in str(current_pred)
        }

        # Generate the explanation from data using the LLM (respects the anthropomorphism condition)
        explanation = None
        try:
            from natural_conversation import generate_from_data

            print(f"🤖 DEBUG (DiCE): Generating explanation from data (anthropomorphic={config.show_anthropomorphic})...")

            explanation = generate_from_data(
                data=structured_data,
                explanation_type='dice',
                high_anthropomorphism=config.show_anthropomorphic
            )

            if explanation and len(explanation) > 50:
                print(f"✅ DEBUG: Generated counterfactual explanation ({len(explanation)} chars)")
            else:
                print("⚠️ DEBUG: LLM generation failed or output too short")
                explanation = None

        except Exception as e:
            print(f"❌ DEBUG: LLM generation failed: {e}")
            explanation = None

        # Fallback templates if the LLM fails (preserves the experimental conditions)
        if not explanation:
            print("⚠️ DEBUG: Using fallback template")
            if config.show_anthropomorphic:
                # High-anthropomorphism fallback
                if structured_data['is_denied']:
                    explanation = "💡 **What could help your application?**\n\n"
                    explanation += "Here are changes that could make a difference:\n\n"
                    for i, change in enumerate(changes[:5], 1):
                        explanation += f"**{i}.** {change}\n"
                    explanation += "\n✨ These factors show up in successful applications. Try the What-If Lab to explore more! 👍"
                else:
                    explanation = "🔄 **What might change the outcome?**\n\n"
                    explanation += "Here's what could affect the decision:\n\n"
                    for i, change in enumerate(changes[:5], 1):
                        explanation += f"**{i}.** {change}\n"
                    explanation += "\n💭 Check out the What-If Lab to test scenarios! ✨"
            else:
                # Low-anthropomorphism fallback
                if structured_data['is_denied']:
                    explanation = "**Recommended Profile Modifications**\n\n"
                    for i, change in enumerate(changes[:5], 1):
                        explanation += f"**{i}.** {change}\n"
                    explanation += "\nAnalysis based on approved application patterns. Refer to the What-If Lab for interactive testing."
                else:
                    explanation = "**Profile Impact Analysis**\n\n"
                    for i, change in enumerate(changes[:5], 1):
                        explanation += f"**{i}.** {change}\n"
                    explanation += "\nData-driven insights from comparative analysis. Refer to the What-If Lab for exploration."

        # Ensure current_instance is a dict for the return values
        instance_dict = current_instance
        if hasattr(current_instance, 'to_dict'):
            instance_dict = current_instance.to_dict()

        return {
            'type': 'dice',
            'explanation': explanation,
            'target_class': target_class,
            'changes': changes,
            'method': 'counterfactual_analysis',
            'current_values': {
                'education_num': instance_dict.get('education_num', 0) if instance_dict else 0,
                'hours_per_week': instance_dict.get('hours_per_week', 0) if instance_dict else 0,
                'capital_gain': instance_dict.get('capital_gain', 0) if instance_dict else 0,
                'age': instance_dict.get('age', 0) if instance_dict else 0
            }
        }

    except Exception as e:
        return {
            'type': 'error',
            'explanation': f"Counterfactual analysis unavailable: {str(e)}",
            'error': str(e)
        }
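# Quick sanity check for the rule-based fallback (illustrative; the stand-in
# agent below is hypothetical and its empty `data` dict makes the DiCE branch
# fail so execution falls through to the rules):
#
#   class _FakeAgent:
#       predicted_class = '<=50K'
#       current_instance = {'education': 'HS-grad', 'education_num': 9,
#                           'occupation': 'Sales', 'hours_per_week': 30,
#                           'marital_status': 'Never-married', 'capital_gain': 0,
#                           'age': 28}
#       data = {}
#
#   out = explain_with_dice(_FakeAgent())
#   # out['changes'] should suggest education, occupation, hours, marital
#   # status, capital gains, and age adjustments for this profile.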
def explain_with_anchor(agent):
    """Anchor-style rule explanation built from the current instance's actual data patterns (heuristic)"""
    try:
        from ab_config import config
        import pandas as pd

        current_pred = getattr(agent, 'predicted_class', 'unknown')
        current_instance = agent.current_instance

        # Extract the rules that actually hold for the current instance
        rules_friendly = []
        rules_technical = []

        if current_instance is not None and len(current_instance) > 0:
            # Convert a Series to a dict for safe .get() access
            if hasattr(current_instance, 'to_dict'):
                instance_dict = current_instance.to_dict()
            elif isinstance(current_instance, dict):
                instance_dict = current_instance
            else:
                instance_dict = dict(current_instance) if current_instance is not None else {}

            # Age rule
            age = instance_dict.get('age', 0)
            if age > 35:
                friendly = get_friendly_feature_name('age')
                rules_friendly.append(f"Your {friendly.lower()} ({age} years old)")
                rules_technical.append(f"age > 35 (value: {age})")
            elif age < 25:
                friendly = get_friendly_feature_name('age')
                rules_friendly.append(f"Your {friendly.lower()} ({age} years old)")
                rules_technical.append(f"age < 25 (value: {age})")

            # Education rule
            education_num = instance_dict.get('education_num', 0)
            education = instance_dict.get('education', 'Unknown')
            if education_num >= 13:
                friendly = get_friendly_feature_name('education_num')
                rules_friendly.append(f"Your {friendly.lower()} ({education})")
                rules_technical.append("education_num >= 13 (Bachelor's or higher)")
            elif education_num < 9:
                friendly = get_friendly_feature_name('education_num')
                rules_friendly.append(f"Your {friendly.lower()} ({education})")
                rules_technical.append("education_num < 9 (less than HS)")

            # Hours rule
            hours = instance_dict.get('hours_per_week', 0)
            if hours >= 40:
                friendly = get_friendly_feature_name('hours_per_week')
                rules_friendly.append(f"Your {friendly.lower()} ({hours} hours per week)")
                rules_technical.append(f"hours_per_week >= 40 (value: {hours})")
            elif hours < 30:
                friendly = get_friendly_feature_name('hours_per_week')
                rules_friendly.append(f"Your {friendly.lower()} ({hours} hours per week)")
                rules_technical.append(f"hours_per_week < 30 (value: {hours})")

            # Marital status rule
            marital = instance_dict.get('marital_status', '')
            if 'Married' in marital:
                friendly = get_friendly_feature_name('marital_status')
                rules_friendly.append(f"Your {friendly.lower()} ({marital})")
                rules_technical.append(f"marital_status = '{marital}'")

            # Capital gain rule
            capital_gain = instance_dict.get('capital_gain', 0)
            if capital_gain > 5000:
                friendly = get_friendly_feature_name('capital_gain')
                rules_friendly.append(f"Your {friendly.lower()} (${capital_gain:,})")
                rules_technical.append(f"capital_gain > 5000 (value: {capital_gain})")
            elif capital_gain > 0:
                friendly = get_friendly_feature_name('capital_gain')
                rules_friendly.append(f"Your {friendly.lower()} (${capital_gain:,})")
                rules_technical.append(f"capital_gain > 0 (value: {capital_gain})")

            # Occupation rule
            occupation = instance_dict.get('occupation', '')
            if occupation:
                if any(x in occupation for x in ['Exec', 'Prof', 'Managerial']):
                    friendly = get_friendly_feature_name('occupation')
                    rules_friendly.append(f"Your {friendly.lower()} ({occupation})")
                    rules_technical.append(f"occupation = '{occupation}' (professional)")

        # Estimate precision and coverage from the rule count (heuristic, capped below 1.0)
        precision = min(0.97, 0.85 + (len(rules_friendly) * 0.02))  # more rules -> higher precision
        coverage = max(0.10, min(0.25, 0.05 * len(rules_friendly)))

        # Generate the explanation with language differentiation
        if config.show_anthropomorphic:
            # High anthropomorphism
            explanation = "📋 **Key factors in your decision:**\n\n"
            explanation += "The decision was primarily influenced by:\n"
            for i, rule in enumerate(rules_friendly[:5], 1):
                explanation += f"{i}. {rule}\n"
            explanation += f"\n💡 This pattern is accurate about {precision:.0%} of the time and applies to roughly {coverage:.0%} of similar applications."
        else:
            # Low anthropomorphism
            explanation = "**Decision rule analysis:**\n\n"
            explanation += "Primary decision factors:\n"
            for i, rule in enumerate(rules_technical[:5], 1):
                explanation += f"{i}. {rule}\n"
            explanation += f"\nRule precision: {precision:.2f}, Coverage: {coverage:.2f}"

        return {
            'type': 'anchor',
            'explanation': explanation,
            'rules': rules_technical,
            'rules_friendly': rules_friendly,
            'precision': precision,
            'coverage': coverage,
            'method': 'rule_based_analysis'
        }

    except Exception as e:
        return {
            'type': 'error',
            'explanation': f"Rule analysis unavailable: {str(e)}",
            'error': str(e)
        }
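# The anchor_tabular import at the top of this module would support true Anchor
# rules (with measured precision/coverage) instead of the heuristics above. A
# sketch following the anchor-exp package's documented API; argument names may
# differ across versions, and `instance_row` is a 1-D numpy row:
#
#   explainer = anchor_tabular.AnchorTabularExplainer(
#       class_names=agent.data['classes'],
#       feature_names=agent.data['features'],
#       train_data=agent.data['X_display'].values,
#   )
#   exp = explainer.explain_instance(instance_row, agent.clf_display.predict,
#                                    threshold=0.95)
#   rule_text = ' AND '.join(exp.names())  # e.g. "age > 35 AND hours_per_week >= 40"
#   measured_precision, measured_coverage = exp.precision(), exp.coverage()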
def explain_with_dtreeviz(agent, instance_df):
    """Generate a dtreeviz visualization for the trained decision tree."""
    try:
        from sklearn.tree import DecisionTreeClassifier
        if not _DTREEVIZ_AVAILABLE:
            return {
                'type': 'error',
                'explanation': "dtreeviz is not installed (lite mode); set HICXAI_MODE=full to enable tree visualizations.",
                'error': 'dtreeviz unavailable'
            }
        # If the model is a RandomForest, visualize one of its trees
        if hasattr(agent.clf_display, 'estimators_'):
            tree = agent.clf_display.estimators_[0]
        else:
            tree = agent.clf_display
        # Note: this is the dtreeviz < 2.0 API; dtreeviz 2.x moved to dtreeviz.model(...)
        viz = dtreeviz.dtreeviz(
            tree,
            agent.data['X_display'],
            agent.data['y_display'],
            target_name='income',
            feature_names=agent.data['features'],
            class_names=agent.data['classes']
        )
        return {
            'type': 'dtreeviz',
            'graph': viz,
            'explanation': 'Decision tree visualization generated.'
        }
    except Exception as e:
        return {
            'type': 'error',
            'explanation': f"Could not generate dtreeviz visualization: {str(e)}",
            'error': str(e)
        }
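# Rendering sketch for the Streamlit app (assumes the pre-2.0 DTreeViz object
# used above, which exposes svg()/save()):
#
#   out = explain_with_dtreeviz(agent, instance_df)
#   if out['type'] == 'dtreeviz':
#       import streamlit.components.v1 as components
#       components.html(out['graph'].svg(), height=600, scrolling=True)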
def route_to_xai_method(agent, intent_result):
    """Route a user question to the appropriate XAI method based on intent AND the experimental condition"""
    from ab_config import config

    if isinstance(intent_result, dict) and 'intent' in intent_result:
        method = intent_result['intent']
        # Normalize common aliases
        if method in {"rule", "rules", "rule_based", "rule-based", "local_explanation"}:
            method = 'anchor'

        # Check the experimental condition - only provide explanations that are enabled
        if method == 'shap':
            if config.explanation == "feature_importance":  # both condition 5 and 6 (the feature-importance conditions)
                return explain_with_shap(agent, intent_result.get('label'))
            else:
                return {
                    'type': 'unavailable',
                    'explanation': "Feature importance explanations are not available in this version.",
                    'method': 'shap_disabled'
                }
        elif method == 'dice':
            if config.show_counterfactual:  # counterfactual condition
                return explain_with_dice(agent)
            else:
                return {
                    'type': 'unavailable',
                    'explanation': "Counterfactual explanations are not available in this version.",
                    'method': 'dice_disabled'
                }
        elif method == 'anchor':
            # Anchor is available in all conditions as the baseline
            return explain_with_anchor(agent)
        else:
            return {
                'type': 'general',
                'explanation': f"I understand you're asking about: {intent_result.get('matched_question', 'the model')}. Let me provide a general explanation.",
                'method': 'general'
            }
    else:
        # intent_result may be a (intent, score, suggestions) tuple from the NLU fallback
        return {
            'type': 'error',
            'explanation': "I'm not sure how to explain that. Could you rephrase your question?",
            'suggestions': intent_result[2] if isinstance(intent_result, (list, tuple)) and len(intent_result) > 2 else []
        }
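# End-to-end routing sketch (the intent dict shape mirrors the checks above;
# `agent` and `st` come from the surrounding Streamlit app):
#
#   intent = {'intent': 'dice', 'matched_question': 'What would change the decision?'}
#   response = route_to_xai_method(agent, intent)
#   if response['type'] in ('shap', 'dice', 'anchor', 'general'):
#       st.markdown(response['explanation'])
#   elif response['type'] == 'unavailable':
#       st.info(response['explanation'])     # method gated off in this condition
#   else:
#       st.warning(response['explanation'])  # error fallback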