akash
commited on
Commit
·
890025a
1
Parent(s):
65eae8a
all files
Browse files- app.py +168 -0
- example.env +13 -0
- header.svg +38 -0
- laptop_data.csv +0 -0
- requirements.txt +18 -0
- src/__init__.py +6 -0
- src/preprocessing/__init__.py +4 -0
- src/preprocessing/clean_data.py +268 -0
- src/preprocessing/clean_df_fallback.py +143 -0
- src/training/__init__.py +7 -0
- src/training/hyperparametrs.py +107 -0
- src/training/model_training.py +93 -0
- src/training/test_result.py +27 -0
- src/training/train.py +140 -0
- src/ui/__init__.py +20 -0
- src/ui/css.py +377 -0
- src/ui/footer.py +10 -0
- src/ui/insight.py +30 -0
- src/ui/loading.py +119 -0
- src/ui/overview.py +47 -0
- src/ui/test_results.py +295 -0
- src/ui/visualization.py +556 -0
- src/ui/welcome.py +186 -0
- src/utils/logging.py +70 -0
app.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
# Streamlit page setup
|
| 8 |
+
st.set_page_config(
|
| 9 |
+
page_title="AutoML",
|
| 10 |
+
page_icon="🛸",
|
| 11 |
+
layout="wide",
|
| 12 |
+
initial_sidebar_state="expanded",
|
| 13 |
+
menu_items={"Get Help": None, "Report a bug": None, "About": None},
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# Add project root and src to Python path
|
| 17 |
+
sys.path.extend([
|
| 18 |
+
os.path.dirname(os.path.abspath(__file__)), # Project root
|
| 19 |
+
os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
|
| 20 |
+
])
|
| 21 |
+
|
| 22 |
+
# Import loading FIRST before any components
|
| 23 |
+
from src.ui.loading import show_loading_state
|
| 24 |
+
# Import CSS loader FIRST
|
| 25 |
+
from src.ui.css import load_css
|
| 26 |
+
|
| 27 |
+
# Load CSS immediately after imports
|
| 28 |
+
load_css()
|
| 29 |
+
|
| 30 |
+
# Cached resource loading with TTL to refresh components periodically
|
| 31 |
+
@st.cache_resource(ttl=3600) # Cache for 1 hour
|
| 32 |
+
def load_components():
|
| 33 |
+
"""Cache component imports to avoid reloading on every rerun"""
|
| 34 |
+
from src import (
|
| 35 |
+
show_footer,
|
| 36 |
+
visualize_data,
|
| 37 |
+
show_welcome_page,
|
| 38 |
+
show_overview_page,
|
| 39 |
+
clean_csv,
|
| 40 |
+
model_training_tab,
|
| 41 |
+
display_ai_insights,
|
| 42 |
+
display_model_evaluation
|
| 43 |
+
)
|
| 44 |
+
return (show_footer, visualize_data,
|
| 45 |
+
show_welcome_page, show_overview_page, clean_csv,
|
| 46 |
+
model_training_tab, display_ai_insights, display_model_evaluation)
|
| 47 |
+
|
| 48 |
+
# Cached header rendering
|
| 49 |
+
@st.cache_data(ttl=86400) # Cache for 24 hours
|
| 50 |
+
def render_header():
|
| 51 |
+
"""Cache static header HTML"""
|
| 52 |
+
return """
|
| 53 |
+
<div class='app-header' style='padding: 1rem 0; margin-bottom: 2rem; text-align: center;'>
|
| 54 |
+
<h1 class='app-title' style='margin: 0;'>AutoML</h1>
|
| 55 |
+
<p class='app-tagline' style='margin-top: 0;'>Automated Machine Learning Made Simple.</p>
|
| 56 |
+
</div>
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
# Cached data loading
|
| 60 |
+
@st.cache_data(ttl=3600) # Cache for 1 hour
|
| 61 |
+
def load_default_data():
|
| 62 |
+
"""Load and cache the default dataset"""
|
| 63 |
+
try:
|
| 64 |
+
return pd.read_csv("laptop_data.csv")
|
| 65 |
+
except Exception as e:
|
| 66 |
+
st.error(f"❌ Error loading default dataset: {str(e)}")
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
# Performance monitoring decorator
|
| 70 |
+
def measure_time(func):
|
| 71 |
+
"""Decorator to measure execution time of functions"""
|
| 72 |
+
def wrapper(*args, **kwargs):
|
| 73 |
+
start_time = time.time()
|
| 74 |
+
result = func(*args, **kwargs)
|
| 75 |
+
end_time = time.time()
|
| 76 |
+
execution_time = end_time - start_time
|
| 77 |
+
if execution_time > 1.0: # Only log slow operations
|
| 78 |
+
print(f"⏱️ {func.__name__} took {execution_time:.2f} seconds to execute")
|
| 79 |
+
return result
|
| 80 |
+
return wrapper
|
| 81 |
+
|
| 82 |
+
@measure_time
|
| 83 |
+
def main():
|
| 84 |
+
"""Optimized main function for Streamlit AutoML app"""
|
| 85 |
+
# First show loading screen before anything else
|
| 86 |
+
if "initialized" not in st.session_state:
|
| 87 |
+
# Show loading animation in full screen mode
|
| 88 |
+
with st.container():
|
| 89 |
+
show_loading_state()
|
| 90 |
+
|
| 91 |
+
# Force render loading screen first
|
| 92 |
+
st.empty().markdown("<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding: 0rem;}</style>", unsafe_allow_html=True)
|
| 93 |
+
|
| 94 |
+
# Now load components in background
|
| 95 |
+
components = load_components()
|
| 96 |
+
(show_footer, visualize_data,
|
| 97 |
+
show_welcome_page, show_overview_page, clean_csv,
|
| 98 |
+
model_training_tab, display_ai_insights, display_model_evaluation) = components
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
# Load and clean data with caching
|
| 102 |
+
default_df = load_default_data()
|
| 103 |
+
if default_df is not None:
|
| 104 |
+
cleaned_df, insights = clean_csv(default_df)
|
| 105 |
+
|
| 106 |
+
# Store everything in session state
|
| 107 |
+
st.session_state.update({
|
| 108 |
+
"df": cleaned_df,
|
| 109 |
+
"insights": insights,
|
| 110 |
+
"components": components,
|
| 111 |
+
"initialized": True,
|
| 112 |
+
"current_tab_index": 0 # Use consistent naming for tab tracking
|
| 113 |
+
})
|
| 114 |
+
|
| 115 |
+
# Rerun to hide loading screen
|
| 116 |
+
st.rerun()
|
| 117 |
+
else:
|
| 118 |
+
st.error("❌ Failed to load default dataset")
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
st.error(f"❌ Error during initialization: {str(e)}")
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
# After initialization, show main interface
|
| 126 |
+
if "initialized" in st.session_state:
|
| 127 |
+
components = st.session_state.components
|
| 128 |
+
(show_footer, visualize_data,
|
| 129 |
+
show_welcome_page, show_overview_page, clean_csv,
|
| 130 |
+
model_training_tab, display_ai_insights, display_model_evaluation) = components
|
| 131 |
+
|
| 132 |
+
# Render main interface
|
| 133 |
+
st.markdown(render_header(), unsafe_allow_html=True)
|
| 134 |
+
|
| 135 |
+
# Create tabs with tab names as constants to avoid recreation
|
| 136 |
+
TAB_NAMES = ["👋 Welcome", "📊 Overview", "📈 Visualization",
|
| 137 |
+
"🤖 Model Training", "💡 Insights", "📊 Test Results"]
|
| 138 |
+
|
| 139 |
+
# Initialize current tab index if not present
|
| 140 |
+
if "current_tab_index" not in st.session_state:
|
| 141 |
+
st.session_state.current_tab_index = 0
|
| 142 |
+
|
| 143 |
+
# Create tabs and get the current tab index
|
| 144 |
+
tab_index = st.tabs(TAB_NAMES)
|
| 145 |
+
|
| 146 |
+
# Display content in all tabs
|
| 147 |
+
with tab_index[0]:
|
| 148 |
+
show_welcome_page()
|
| 149 |
+
|
| 150 |
+
with tab_index[1]:
|
| 151 |
+
show_overview_page()
|
| 152 |
+
|
| 153 |
+
with tab_index[2]:
|
| 154 |
+
visualize_data(st.session_state.df)
|
| 155 |
+
|
| 156 |
+
with tab_index[3]:
|
| 157 |
+
model_training_tab(st.session_state.df)
|
| 158 |
+
|
| 159 |
+
with tab_index[4]:
|
| 160 |
+
display_ai_insights()
|
| 161 |
+
|
| 162 |
+
with tab_index[5]:
|
| 163 |
+
display_model_evaluation()
|
| 164 |
+
|
| 165 |
+
show_footer()
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
main()
|
example.env
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AutoML Environment Variables
|
| 2 |
+
|
| 3 |
+
# API Keys for LLM Services
|
| 4 |
+
GROQ_API_KEY=your_groq_api_key_here
|
| 5 |
+
GEMINI_API_KEY=your_gemini_api_key_here
|
| 6 |
+
|
| 7 |
+
# LangSmith Tracking (Optional)
|
| 8 |
+
LANGCHAIN_TRACING_V2=true
|
| 9 |
+
LANGCHAIN_API_KEY=your_langchain_api_key_here
|
| 10 |
+
LANGCHAIN_PROJECT=automl-project
|
| 11 |
+
|
| 12 |
+
# Optional: Logging Configuration
|
| 13 |
+
LOG_LEVEL=INFO
|
header.svg
ADDED
|
|
laptop_data.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.29.0
|
| 2 |
+
pandas>=2.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
scikit-learn>=1.2.0
|
| 5 |
+
matplotlib>=3.7.0
|
| 6 |
+
plotly>=5.14.0
|
| 7 |
+
seaborn>=0.12.0
|
| 8 |
+
langchain>=0.0.267
|
| 9 |
+
langchain-groq>=0.0.1
|
| 10 |
+
langchain-google-genai>=0.0.3
|
| 11 |
+
python-dotenv>=1.0.0
|
| 12 |
+
scipy>=1.10.0
|
| 13 |
+
joblib>=1.2.0
|
| 14 |
+
pydantic>=2.0.0
|
| 15 |
+
requests>=2.28.0
|
| 16 |
+
pillow>=9.0.0
|
| 17 |
+
altair>=4.2.0
|
| 18 |
+
beautifulsoup4>=4.11.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from .ui import *
|
| 3 |
+
from .training import *
|
| 4 |
+
from .preprocessing import *
|
| 5 |
+
from .utils import *
|
| 6 |
+
|
src/preprocessing/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .clean_data import clean_csv
|
| 2 |
+
from .clean_df_fallback import clean_dataframe_fallback
|
| 3 |
+
|
| 4 |
+
__all__ = ['clean_csv' , 'clean_dataframe_fallback']
|
src/preprocessing/clean_data.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.impute import SimpleImputer
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from scipy import stats
|
| 4 |
+
from langchain_groq import ChatGroq
|
| 5 |
+
from langchain.chains import LLMChain
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import re
|
| 9 |
+
import os
|
| 10 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 11 |
+
from langchain.prompts import PromptTemplate
|
| 12 |
+
from langchain_core.runnables import RunnableSequence
|
| 13 |
+
import streamlit as st
|
| 14 |
+
from .clean_df_fallback import clean_dataframe_fallback
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# # Load environment variables
|
| 18 |
+
|
| 19 |
+
load_dotenv()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
| 24 |
+
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if not gemini_api_key:
|
| 28 |
+
raise ValueError("GEMINI_API_KEY not found in environment variables")
|
| 29 |
+
if not groq_api_key:
|
| 30 |
+
raise ValueError("GROQ_API_KEY not found in environment variables")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Initialize the LLM model
|
| 34 |
+
try:
|
| 35 |
+
llm = ChatGoogleGenerativeAI(
|
| 36 |
+
model="gemini-2.0-flash-lite-preview-02-05",
|
| 37 |
+
google_api_key=gemini_api_key
|
| 38 |
+
)
|
| 39 |
+
print("Primary Gemini LLM loaded successfully.")
|
| 40 |
+
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Error initializing primary Gemini LLM: {e}")
|
| 43 |
+
|
| 44 |
+
# Fallback to a different LLM from Groq
|
| 45 |
+
try:
|
| 46 |
+
llm = ChatGroq(
|
| 47 |
+
model="gemma2-9b-it", # replace with your desired Groq model identifier
|
| 48 |
+
groq_api_key=groq_api_key
|
| 49 |
+
)
|
| 50 |
+
print("Fallback Groq LLM loaded successfully.")
|
| 51 |
+
|
| 52 |
+
except Exception as e2:
|
| 53 |
+
print(f"Error initializing fallback Groq LLM: {e2}")
|
| 54 |
+
llm=None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Cache the clean_csv function to prevent redundant cleaning
|
| 59 |
+
@st.cache_data(ttl=3600, show_spinner=False)
|
| 60 |
+
def cached_clean_csv(df_json, skip_cleaning=False):
|
| 61 |
+
"""Cached version of the clean_csv function to prevent redundant cleaning.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
df_json: JSON string representation of the dataframe (for hashing)
|
| 65 |
+
skip_cleaning: Whether to skip cleaning
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tuple of (cleaned_df, insights)
|
| 69 |
+
"""
|
| 70 |
+
# Convert JSON back to dataframe
|
| 71 |
+
df = pd.read_json(df_json, orient='records')
|
| 72 |
+
|
| 73 |
+
# If skip_cleaning is True, return the dataframe as is
|
| 74 |
+
if skip_cleaning:
|
| 75 |
+
return df, "No cleaning performed (user skipped)."
|
| 76 |
+
|
| 77 |
+
# Reset any test results if we're cleaning a new dataset
|
| 78 |
+
if "test_results_calculated" in st.session_state:
|
| 79 |
+
st.session_state.test_results_calculated = False
|
| 80 |
+
# Clear any previous test metrics to avoid using stale data
|
| 81 |
+
for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
|
| 82 |
+
if key in st.session_state:
|
| 83 |
+
del st.session_state[key]
|
| 84 |
+
|
| 85 |
+
# Call the actual cleaning function
|
| 86 |
+
return clean_csv(df)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def clean_csv(df):
|
| 90 |
+
"""Original clean_csv function that performs the actual cleaning."""
|
| 91 |
+
# ---------------------------
|
| 92 |
+
# Early fallback if LLM initialization failed
|
| 93 |
+
# ---------------------------
|
| 94 |
+
if llm is None:
|
| 95 |
+
print("LLM initialization failed; using hardcoded cleaning function.")
|
| 96 |
+
fallback_df = clean_dataframe_fallback(df)
|
| 97 |
+
|
| 98 |
+
return fallback_df , "LLM initialization failed; using hardcoded cleaning function, so no insights were generated."
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------------------------
|
| 103 |
+
# LLM-based cleaning function generation
|
| 104 |
+
# ---------------------------
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Escape curly braces in the JSON sample and column names
|
| 108 |
+
sample_data = df.head(3).to_json(orient='records')
|
| 109 |
+
escaped_sample_data = sample_data.replace("{", "{{").replace("}", "}}")
|
| 110 |
+
|
| 111 |
+
escaped_columns = [
|
| 112 |
+
col.replace("{", "{{").replace("}", "}}") for col in df.columns
|
| 113 |
+
]
|
| 114 |
+
column_names_str = ", ".join(escaped_columns)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Define the prompt for generating the cleaning function
|
| 119 |
+
initial_prompt = PromptTemplate.from_template(f'''
|
| 120 |
+
You are given the following sample data from a pandas DataFrame:
|
| 121 |
+
{escaped_sample_data}
|
| 122 |
+
|
| 123 |
+
column names are : [{column_names_str}].
|
| 124 |
+
|
| 125 |
+
Generate a Python function named clean_dataframe(df) considering the following:
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
1. Performs thorough data cleaning without performing feature engineering. Ensure all necessary cleaning steps are included.
|
| 129 |
+
2. Uses assignment operations (e.g., df = df.drop(...)) and avoids inplace=True for clarity.
|
| 130 |
+
3. First deeply analyze each column’s content this is the most important step , to infer its predominant data type for example if we have RS.2100 in rows remove rs and if we have (89%) remove % , if the column contains only text and no numbers then it is a text column and if it contains numbers and text then it is a mixed column and if it contains only numbers then it is a numeric column.
|
| 131 |
+
4. For columns that are intended to be numeric but contain extra characters (such as '%' in percentage values, currency symbols like 'Rs.', '$', and commas), remove all non-digit characters (except for the decimal point) and convert them to a numeric type.
|
| 132 |
+
5. For columns that are clearly text or categorical, preserve the content without removing digits or altering the textual information.
|
| 133 |
+
6. Handles missing values appropriately: fill numeric columns with the median (or 0 if the median is not available) and non-numeric columns with 'Unknown'.
|
| 134 |
+
7. For columns where more than 50% of values are strings and less than 10% are numeric, perform conservative string cleaning by removing unwanted special symbols while preserving meaningful digits.
|
| 135 |
+
8. For columns whose names contain 'name', 'Name', or 'Names' (case-insensitive), convert to string type and remove extraneous numeric characters only if they are not part of the essential text.
|
| 136 |
+
9. Preserves other categorical or text columns (such as Gender, City, State, Country, etc.) unless explicitly specified for removal.
|
| 137 |
+
10. Handles edge cases such as completely empty columns appropriately.
|
| 138 |
+
|
| 139 |
+
Return only the Python code for the function, with no explanations or extra formatting.
|
| 140 |
+
|
| 141 |
+
'''
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Define the refinement prompt
|
| 147 |
+
refine_prompt = PromptTemplate.from_template(
|
| 148 |
+
"The following Python code for cleaning a DataFrame caused an error: {error}\n"
|
| 149 |
+
"Original code:\n{code}\n"
|
| 150 |
+
"Please correct the code to fix the error and ensure it returns a cleaned DataFrame. "
|
| 151 |
+
"Return only the corrected Python code for the function, no explanations or formatting."
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# Create the chains using modern LangChain approach
|
| 158 |
+
initial_chain = initial_prompt | llm
|
| 159 |
+
refine_chain = refine_prompt | llm
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def extract_code(response):
|
| 168 |
+
|
| 169 |
+
if isinstance(response, str):
|
| 170 |
+
# Handle Markdown or plain text
|
| 171 |
+
if "```python" in response:
|
| 172 |
+
match = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
|
| 173 |
+
return match.group(1).strip() if match else response
|
| 174 |
+
|
| 175 |
+
elif "```" in response:
|
| 176 |
+
match = re.search(r'```\n(.*?)\n```', response, re.DOTALL)
|
| 177 |
+
return match.group(1).strip() if match else response
|
| 178 |
+
|
| 179 |
+
return response.strip()
|
| 180 |
+
|
| 181 |
+
# Handle LLM response objects
|
| 182 |
+
content = getattr(response, 'content', str(response))
|
| 183 |
+
|
| 184 |
+
if "```python" in content:
|
| 185 |
+
match = re.search(r'```python\n(.*?)\n```', content, re.DOTALL)
|
| 186 |
+
return match.group(1).strip() if match else content
|
| 187 |
+
|
| 188 |
+
elif "```" in content:
|
| 189 |
+
match = re.search(r'```\n(.*?)\n```', content, re.DOTALL)
|
| 190 |
+
return match.group(1).strip() if match else content
|
| 191 |
+
|
| 192 |
+
return content.strip()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
# Generate initial chain and extract the cleaned code
|
| 202 |
+
cleaning_function_code = extract_code(initial_chain.invoke({}))
|
| 203 |
+
print("Initial generated cleaning function code not executed yet is:\n", cleaning_function_code)
|
| 204 |
+
|
| 205 |
+
# Iterative refinement loop with max 5 attempts
|
| 206 |
+
max_attempts = 5
|
| 207 |
+
|
| 208 |
+
for attempt in range(max_attempts):
|
| 209 |
+
print(f"Attempt {attempt} code:\n{cleaning_function_code}") # <-- HERE
|
| 210 |
+
try:
|
| 211 |
+
# Execute the code in global namespace
|
| 212 |
+
exec(cleaning_function_code, globals())
|
| 213 |
+
# Call the function and assign the result back to df
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if 'clean_dataframe' not in globals():
|
| 217 |
+
raise NameError("Cleaning function not defined in generated code")
|
| 218 |
+
|
| 219 |
+
df = clean_dataframe(df)
|
| 220 |
+
|
| 221 |
+
print(f"Cleaning successful on attempt {attempt + 1}")
|
| 222 |
+
break
|
| 223 |
+
|
| 224 |
+
# if the cleaning fails
|
| 225 |
+
except Exception as e:
|
| 226 |
+
error_message = str(e)
|
| 227 |
+
print(f"Error on attempt {attempt + 1}: {error_message}")
|
| 228 |
+
|
| 229 |
+
if attempt < max_attempts - 1:
|
| 230 |
+
|
| 231 |
+
# Refine the code using the error message if there are still epochs left
|
| 232 |
+
refined_response = refine_chain.invoke({"error": error_message, "code": cleaning_function_code})
|
| 233 |
+
cleaning_function_code = extract_code(refined_response)
|
| 234 |
+
|
| 235 |
+
print(f"Refined cleaning function code:\n", cleaning_function_code)
|
| 236 |
+
|
| 237 |
+
else:
|
| 238 |
+
print("Failed to clean DataFrame after 5 maximum attempts")
|
| 239 |
+
# AFter all the failed attempt using the hardcoded logic
|
| 240 |
+
|
| 241 |
+
df = clean_dataframe_fallback(df)
|
| 242 |
+
|
| 243 |
+
except Exception as e:
|
| 244 |
+
print("⚡No successful cleaning done enforcing fallback")
|
| 245 |
+
df = clean_dataframe_fallback(df)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
cleaned_df = df
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
insights_prompt = f"""
|
| 252 |
+
Analyze this cleaned dataset:
|
| 253 |
+
- Columns: {cleaned_df.columns.tolist()}
|
| 254 |
+
- Sample data: {cleaned_df.head(3).to_dict()}
|
| 255 |
+
- Numeric stats: {cleaned_df.describe().to_dict()}
|
| 256 |
+
Provide key data quality insights and recommendations.
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
insights_response = llm.invoke(insights_prompt)
|
| 261 |
+
analysis_insights = insights_response.content
|
| 262 |
+
except Exception as e:
|
| 263 |
+
analysis_insights = f"Insight generation failed: {str(e)}"
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# Return the cleaned DataFrame and dummy insights
|
| 268 |
+
return cleaned_df, analysis_insights
|
src/preprocessing/clean_df_fallback.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import streamlit as st
|
| 5 |
+
|
| 6 |
+
# Define fallback cleaning function
|
| 7 |
+
|
| 8 |
+
@st.cache_data
|
| 9 |
+
def clean_dataframe_fallback(df):
|
| 10 |
+
"""Hardcoded data cleaning pipeline"""
|
| 11 |
+
|
| 12 |
+
"""Generic data cleaning pipeline with categorical preservation"""
|
| 13 |
+
df_cleaned = df.copy()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
df_cleaned = df_cleaned.applymap(
|
| 17 |
+
lambda x: re.sub(r"\(.*?\)", "", str(x)) if isinstance(x, str) else x)
|
| 18 |
+
|
| 19 |
+
# Remove 'ref.' references
|
| 20 |
+
df_cleaned = df_cleaned.applymap(
|
| 21 |
+
lambda x: re.sub(r"ref\.", "", str(x), flags=re.IGNORECASE) if isinstance(x, str) else x)
|
| 22 |
+
|
| 23 |
+
# Remove any other special characters except letters, digits, spaces, and dots
|
| 24 |
+
df_cleaned = df_cleaned.applymap(
|
| 25 |
+
lambda x: re.sub(r"[^\w\s\d\.]", "", str(x)).strip() if isinstance(x, str) else x
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Step 0 - Clean column names first
|
| 30 |
+
df_cleaned.columns = [col.strip().lower().replace(' ', '_') for col in df_cleaned.columns]
|
| 31 |
+
|
| 32 |
+
# Define measurement units to remove
|
| 33 |
+
measurement_units = {
|
| 34 |
+
'weight': r'\s*(kg|kilograms|lbs|pounds)$',
|
| 35 |
+
'height': r'\s*(cm|centimeters|inches|feet|ft)$'
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Step 1 - Remove redundant columns
|
| 40 |
+
# Preservation patterns for categorical columns
|
| 41 |
+
preserve_pattern = re.compile(r'(name|brand|model|type|category|region|text|desc|color|size)', re.IGNORECASE)
|
| 42 |
+
preserved_cols = [col for col in df_cleaned.columns if preserve_pattern.search(col)]
|
| 43 |
+
|
| 44 |
+
# ID pattern detection
|
| 45 |
+
id_pattern = re.compile(r'(_id|id_|num|no|number|identifier|code|idx|row)', re.IGNORECASE)
|
| 46 |
+
id_cols = [col for col in df_cleaned.columns if id_pattern.search(col) and col not in preserved_cols]
|
| 47 |
+
|
| 48 |
+
# Unique value columns
|
| 49 |
+
unique_cols = [col for col in df_cleaned.columns
|
| 50 |
+
if df_cleaned[col].nunique() == len(df_cleaned)
|
| 51 |
+
and col not in preserved_cols]
|
| 52 |
+
|
| 53 |
+
redundant_cols = list(set(id_cols + unique_cols))
|
| 54 |
+
df_cleaned = df_cleaned.drop(columns=redundant_cols)
|
| 55 |
+
print(f"Removed {len(redundant_cols)} redundant columns: {redundant_cols}")
|
| 56 |
+
|
| 57 |
+
# Step 2 - Enhanced numeric detection with categorical protection
|
| 58 |
+
for col in df_cleaned.columns:
|
| 59 |
+
if col in preserved_cols:
|
| 60 |
+
print(f"Preserving categorical column: {col}")
|
| 61 |
+
continue # Skip preserved columns
|
| 62 |
+
|
| 63 |
+
if any(unit in col for unit in measurement_units.keys()):
|
| 64 |
+
pattern = measurement_units.get(col.split('_')[0], r'')
|
| 65 |
+
df_cleaned[col] = df_cleaned[col].astype(str).str.replace(pattern, '', regex=True).str.strip()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if pd.api.types.is_numeric_dtype(df_cleaned[col]):
|
| 70 |
+
continue
|
| 71 |
+
|
| 72 |
+
# Strict numeric pattern detection
|
| 73 |
+
non_null_count = df_cleaned[col].dropna().shape[0]
|
| 74 |
+
sample_size = min(100, non_null_count)
|
| 75 |
+
sample = df_cleaned[col].dropna().sample(sample_size, random_state=42)
|
| 76 |
+
numeric_pattern = r'^[-+]?\d*\.?\d+$' # Full string match
|
| 77 |
+
num_matches = sample.astype(str).str.fullmatch(numeric_pattern).mean()
|
| 78 |
+
|
| 79 |
+
if num_matches > 0.8: # High threshold
|
| 80 |
+
# Conservative cleaning
|
| 81 |
+
cleaned = df_cleaned[col].replace(r'[^\d\.\-]', '', regex=True)
|
| 82 |
+
converted = pd.to_numeric(cleaned, errors='coerce')
|
| 83 |
+
success_rate = converted.notna().mean()
|
| 84 |
+
|
| 85 |
+
if success_rate > 0.9: # Strict success requirement
|
| 86 |
+
df_cleaned[col] = converted
|
| 87 |
+
print(f"Converted {col} to numeric (success: {success_rate:.1%})")
|
| 88 |
+
|
| 89 |
+
# Step 3 - Date detection
|
| 90 |
+
date_cols = []
|
| 91 |
+
for col in df_cleaned.select_dtypes(exclude=np.number).columns:
|
| 92 |
+
if col in preserved_cols:
|
| 93 |
+
continue
|
| 94 |
+
try:
|
| 95 |
+
df_cleaned[col] = pd.to_datetime(df_cleaned[col], errors='raise')
|
| 96 |
+
date_cols.append(col)
|
| 97 |
+
print(f"Detected datetime: {col}")
|
| 98 |
+
except:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
# Example manual approach:
|
| 102 |
+
currency_cols = [col for col in df_cleaned.columns if any(keyword in col.lower() for keyword in ["price", "gross", "budget"])]
|
| 103 |
+
for col in currency_cols:
|
| 104 |
+
df_cleaned[col] = df_cleaned[col].astype(str).str.replace(r'[^\d\.]', '', regex=True) # remove everything except digits & dots
|
| 105 |
+
df_cleaned[col] = pd.to_numeric(df_cleaned[col], errors='coerce')
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Step 4 - Missing value handling
|
| 110 |
+
numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
|
| 111 |
+
categorical_cols = df_cleaned.select_dtypes(exclude=np.number).columns
|
| 112 |
+
|
| 113 |
+
# Numeric imputation
|
| 114 |
+
for col in numeric_cols:
|
| 115 |
+
if df_cleaned[col].isna().any():
|
| 116 |
+
df_cleaned[f'{col}_missing'] = df_cleaned[col].isna().astype(int)
|
| 117 |
+
df_cleaned[col].fillna(df_cleaned[col].median(), inplace=True)
|
| 118 |
+
|
| 119 |
+
# Categorical imputation
|
| 120 |
+
for col in categorical_cols:
|
| 121 |
+
if df_cleaned[col].isna().any():
|
| 122 |
+
mode_val = df_cleaned[col].mode()[0] if not df_cleaned[col].mode().empty else 'Unknown'
|
| 123 |
+
df_cleaned[col] = df_cleaned[col].fillna(mode_val)
|
| 124 |
+
|
| 125 |
+
# Step 5 - Text normalization for non-preserved columns
|
| 126 |
+
text_cols = [col for col in categorical_cols if col not in preserved_cols]
|
| 127 |
+
for col in text_cols:
|
| 128 |
+
df_cleaned[col] = df_cleaned[col].astype(str).apply(lambda x: re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', x)).strip().lower())
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# Step 6 - Outlier handling (preserve categoricals)
|
| 132 |
+
numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
|
| 133 |
+
for col in numeric_cols:
|
| 134 |
+
if df_cleaned[col].nunique() > 10:
|
| 135 |
+
q1 = df_cleaned[col].quantile(0.05)
|
| 136 |
+
q3 = df_cleaned[col].quantile(0.95)
|
| 137 |
+
df_cleaned[col] = np.clip(df_cleaned[col], q1, q3)
|
| 138 |
+
|
| 139 |
+
# Step 7 - Final validation
|
| 140 |
+
df_cleaned = df_cleaned.drop_duplicates().reset_index(drop=True)
|
| 141 |
+
|
| 142 |
+
return df_cleaned
|
| 143 |
+
|
src/training/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .train import train_model
|
| 2 |
+
from .hyperparametrs import get_hyperparams_ui
|
| 3 |
+
from .model_training import model_training_tab
|
| 4 |
+
from .test_result import display_model_evaluation
|
| 5 |
+
|
| 6 |
+
__all__ = ["train_model" , "get_hyperparams_ui", "model_training_tab" , "display_model_evaluation"]
|
| 7 |
+
|
src/training/hyperparametrs.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Define Hyperparameter Options
|
| 6 |
+
def get_hyperparams_ui(model_name: str) -> dict:
    """Generate UI components for model-specific hyperparameters.

    Renders Streamlit widgets (number inputs, sliders, selectboxes) for the
    hyperparameters of the chosen model and collects the user's selections.

    Parameters
    ----------
    model_name : str
        Display name of the model, as listed in the training tab's model
        selectbox (e.g. "Random Forest", "XGBoost Regressor").

    Returns
    -------
    dict
        Mapping of hyperparameter name -> selected value, suitable for
        passing as keyword arguments to the corresponding estimator.
        Empty for models with no exposed hyperparameters
        (e.g. Linear Regression) or for unrecognized names.
    """
    hyperparams = {}

    # Tree-ensemble options shared by the regressor and classifier variants.
    if model_name in ["Random Forest Regressor", "Random Forest"]:
        hyperparams["n_estimators"] = st.number_input("Number of Trees (n_estimators)", min_value=10, max_value=500, value=100)
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
        hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)

    elif model_name in ["XGBoost Regressor", "XGBoost"]:
        hyperparams["n_estimators"] = st.number_input("Number of Boosting Rounds (n_estimators)", min_value=10, max_value=500, value=100)
        hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=6)

    elif model_name == "Linear Regression":
        # Plain OLS: nothing to tune here, so just inform the user.
        st.info("No hyperparameters required for Linear Regression.")

    # New Regression Models:
    # NOTE(review): "Polynomial Regression" has a UI branch here but does not
    # appear in the training tab's model_options or in get_model's registry —
    # confirm whether it is wired up elsewhere.
    elif model_name == "Polynomial Regression":
        hyperparams["degree"] = st.number_input("Degree of Polynomial Features", min_value=2, max_value=10, value=2)
        # You may add additional hyperparameters for the underlying LinearRegression if needed

    elif model_name == "Ridge Regression":
        hyperparams["alpha"] = st.slider("Regularization Strength (alpha)", 0.01, 10.0, 1.0)
        hyperparams["solver"] = st.selectbox("Solver", ["auto", "svd", "cholesky", "lsqr", "sparse_cg", "sag", "saga", "lbfgs"])

    elif model_name == "Lasso Regression":
        hyperparams["alpha"] = st.slider("Regularization Strength (alpha)", 0.01, 10.0, 1.0)
        hyperparams["max_iter"] = st.number_input("Max Iterations", min_value=100, max_value=1000, value=1000)

    elif model_name == "Logistic Regression":
        hyperparams["C"] = st.slider("Regularization Strength (C)", 0.01, 10.0, 1.0)
        hyperparams["max_iter"] = st.number_input("Max Iterations", min_value=100, max_value=1000, value=200)

    elif model_name == "Support Vector Regressor":
        hyperparams["C"] = st.slider("Regularization parameter (C)", 0.1, 100.0, 1.0)
        hyperparams["epsilon"] = st.slider("Epsilon", 0.0, 1.0, 0.1)
        hyperparams["kernel"] = st.selectbox("Kernel", ["linear", "rbf", "poly", "sigmoid"])

    elif model_name == "Decision Tree Regressor":
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
        hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)

    elif model_name == "K-Nearest Neighbors Regressor":
        hyperparams["n_neighbors"] = st.number_input("Number of Neighbors", min_value=1, max_value=100, value=5)
        hyperparams["weights"] = st.selectbox("Weight Function", ["uniform", "distance"])

    elif model_name == "ElasticNet":
        hyperparams["alpha"] = st.slider("Alpha", 0.01, 10.0, 1.0)
        hyperparams["l1_ratio"] = st.slider("L1 Ratio", 0.0, 1.0, 0.5)

    elif model_name == "Gradient Boosting Regressor":
        hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=100)
        hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=20, value=3)

    elif model_name == "AdaBoost Regressor":
        hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=50)
        hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)

    elif model_name == "Bayesian Ridge":
        # Gamma-prior shape/rate parameters; format chosen so the tiny values
        # remain readable in the slider label.
        hyperparams["alpha_1"] = st.slider("Alpha 1", 1e-6, 1e-1, 1e-4, format="%.6f")
        hyperparams["alpha_2"] = st.slider("Alpha 2", 1e-6, 1e-1, 1e-4, format="%.6f")
        hyperparams["lambda_1"] = st.slider("Lambda 1", 1e-6, 1e-1, 1e-4, format="%.6f")
        hyperparams["lambda_2"] = st.slider("Lambda 2", 1e-6, 1e-1, 1e-4, format="%.6f")

    # --- Additional Classification Models ---
    elif model_name == "Support Vector Classifier":
        hyperparams["C"] = st.slider("Regularization parameter (C)", 0.1, 100.0, 1.0)
        hyperparams["kernel"] = st.selectbox("Kernel", ["linear", "rbf", "poly", "sigmoid"])

    elif model_name == "Decision Tree Classifier":
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
        hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)

    elif model_name == "K-Nearest Neighbors Classifier":
        hyperparams["n_neighbors"] = st.number_input("Number of Neighbors", min_value=1, max_value=100, value=5)
        hyperparams["weights"] = st.selectbox("Weight Function", ["uniform", "distance"])

    elif model_name == "Gradient Boosting Classifier":
        hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=100)
        hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
        hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=20, value=3)

    elif model_name == "AdaBoost Classifier":
        hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=50)
        hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)

    elif model_name == "Gaussian Naive Bayes":
        hyperparams["var_smoothing"] = st.slider("Var Smoothing", 1e-12, 1e-8, 1e-9, format="%.12f")

    elif model_name == "Quadratic Discriminant Analysis":
        hyperparams["reg_param"] = st.slider("Regularization Parameter", 0.0, 1.0, 0.0)

    elif model_name == "Linear Discriminant Analysis":
        hyperparams["solver"] = st.selectbox("Solver", ["svd", "lsqr", "eigen"])

    return hyperparams
|
src/training/model_training.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from .hyperparametrs import get_hyperparams_ui
|
| 3 |
+
import pickle
|
| 4 |
+
from .train import train_model
|
| 5 |
+
|
| 6 |
+
# Model Training Tab
|
| 7 |
+
def model_training_tab(df):
    """Render the Model Training tab.

    Lets the user pick a target column and a model, tune hyperparameters,
    trigger training, and download the trained model as a pickle file.

    Parameters
    ----------
    df : pandas.DataFrame
        The (already cleaned) dataset; one column is chosen as the target.

    Side effects
    ------------
    Reads/writes ``st.session_state`` keys: ``target_column``,
    ``selected_model``, ``trained_model``, ``model_trained``.
    """
    # Ensure we have session state for model training
    if "target_column" not in st.session_state:
        st.session_state.target_column = df.columns[0] if not df.empty else None

    if "selected_model" not in st.session_state:
        st.session_state.selected_model = None

    st.subheader("📌 Model Training")

    # Use session state to maintain selection across reruns; fall back to the
    # first column if the remembered target no longer exists in df.
    target_column = st.selectbox(
        "🎯 Select Target Column (Y)",
        df.columns,
        index=list(df.columns).index(st.session_state.target_column) if st.session_state.target_column in df.columns else 0,
        key="target_column_select"
    )
    st.session_state.target_column = target_column

    # Infer task type automatically: non-numeric targets, or numeric targets
    # with few distinct values, are treated as classification.
    task_type = "classification" if df[target_column].dtype == "object" or df[target_column].nunique() <= 10 else "regression"
    st.write(f"🔍 Detected Task Type: **{task_type.capitalize()}**")

    model_options = {
        "classification": ["Random Forest", "Logistic Regression", "XGBoost", "Support Vector Classifier", "Decision Tree Classifier", "K-Nearest Neighbors Classifier", "Gradient Boosting Classifier", "AdaBoost Classifier", "Gaussian Naive Bayes", "Quadratic Discriminant Analysis", "Linear Discriminant Analysis"],
        "regression": ["Linear Regression", "Random Forest Regressor", "XGBoost Regressor", "Support Vector Regressor", "Decision Tree Regressor", "K-Nearest Neighbors Regressor", "ElasticNet", "Gradient Boosting Regressor", "AdaBoost Regressor", "Bayesian Ridge", "Ridge Regression", "Lasso Regression"],
    }

    # Initialize selected model if not already set or if task type changed
    # (the remembered model may belong to the other task's list).
    if st.session_state.selected_model not in model_options[task_type]:
        st.session_state.selected_model = model_options[task_type][0]

    selected_model_name = st.selectbox(
        "🤖 Choose Model",
        model_options[task_type],
        index=model_options[task_type].index(st.session_state.selected_model),
        key="selected_model_select"
    )
    st.session_state.selected_model = selected_model_name

    st.markdown("### 🔧 Hyperparameters")
    hyperparams = get_hyperparams_ui(selected_model_name)

    # Unique key so this button does not collide with buttons on other tabs.
    if st.button("🚀 Train Model", key="train_model_button_unique"):
        with st.spinner("Training in progress... ⏳"):
            trained_model = train_model(df, target_column, task_type, selected_model_name, hyperparams)
            st.success("✅ Model trained successfully!")
            st.session_state.trained_model = trained_model
            st.session_state.model_trained = True

        # Note: test_results_calculated is already reset in train_model function

    if "trained_model" in st.session_state:
        st.markdown("### 📥 Download Trained Model")
        try:
            # Serialize in memory instead of round-tripping through a temp
            # file: this removes the write-then-reopen race and the leaked
            # temp file when the (previously silent) deletion failed.
            model_bytes = pickle.dumps(st.session_state.trained_model)
            st.download_button(
                label="📥 Download Model",
                data=model_bytes,
                file_name="trained_model.pkl",
                mime="application/octet-stream",
            )
        except Exception as e:
            st.error(f"Error preparing model for download: {str(e)}")
|
src/training/test_result.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from ui.test_results import display_test_results
|
| 3 |
+
|
| 4 |
+
def display_model_evaluation():
    """Displays the evaluation results of the trained model on the test set.

    Reads the trained model and held-out test split from ``st.session_state``
    (populated by ``train_model``) and delegates rendering to
    ``display_test_results``. Warns the user if no model has been trained yet.
    """
    st.header("📊 Model Evaluation on Test Set")

    # Ensure model and test data exist in session state
    if "trained_model" in st.session_state and "X_test" in st.session_state:
        trained_model = st.session_state.trained_model
        X_test = st.session_state.X_test
        y_test = st.session_state.y_test
        task_type = st.session_state.task_type

        # For classification, trained_model may be a (pipeline, label_encoder)
        # tuple; display_test_results receives it unchanged either way, so the
        # previous unpack-and-repack branching was redundant and is removed.
        display_test_results(trained_model, X_test, y_test, task_type)
    else:
        st.warning("🚨 Train a model first to see test results!")
|
src/training/train.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.compose import ColumnTransformer
|
| 2 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
|
| 3 |
+
from sklearn.model_selection import train_test_split
|
| 4 |
+
from sklearn.linear_model import LinearRegression, LogisticRegression
|
| 5 |
+
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
|
| 6 |
+
from xgboost import XGBRegressor, XGBClassifier
|
| 7 |
+
from sklearn.svm import SVR, SVC
|
| 8 |
+
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
|
| 9 |
+
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
|
| 10 |
+
from sklearn.linear_model import ElasticNet, BayesianRidge
|
| 11 |
+
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, GradientBoostingClassifier, AdaBoostClassifier
|
| 12 |
+
from sklearn.naive_bayes import GaussianNB
|
| 13 |
+
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis
|
| 14 |
+
from sklearn.linear_model import Ridge, Lasso
|
| 15 |
+
from sklearn.impute import SimpleImputer
|
| 16 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
| 17 |
+
from sklearn.compose import ColumnTransformer
|
| 18 |
+
from sklearn.pipeline import Pipeline as SkPipeline
|
| 19 |
+
|
| 20 |
+
import streamlit as st
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_model(task_type, model_name, hyperparams):
    """Returns the model instance based on user selection with hyperparameters.

    Parameters
    ----------
    task_type : str
        Either "regression" or "classification".
    model_name : str
        Display name of the estimator (must match a registry key below).
    hyperparams : dict
        Keyword arguments forwarded to the estimator's constructor.

    Raises
    ------
    ValueError
        If the (task_type, model_name) combination is not registered.
    """
    regression_registry = {
        # Already existing:
        "Linear Regression": LinearRegression,
        "Random Forest Regressor": RandomForestRegressor,
        "XGBoost Regressor": XGBRegressor,
        # Additional regression models:
        "Support Vector Regressor": SVR,
        "Decision Tree Regressor": DecisionTreeRegressor,
        "K-Nearest Neighbors Regressor": KNeighborsRegressor,
        "ElasticNet": ElasticNet,
        "Gradient Boosting Regressor": GradientBoostingRegressor,
        "AdaBoost Regressor": AdaBoostRegressor,
        "Bayesian Ridge": BayesianRidge,
        "Ridge Regression": Ridge,
        "Lasso Regression": Lasso,
    }
    classification_registry = {
        # Already existing:
        "Logistic Regression": LogisticRegression,
        "Random Forest": RandomForestClassifier,
        "XGBoost": XGBClassifier,
        # Additional classification models:
        "Support Vector Classifier": SVC,
        "Decision Tree Classifier": DecisionTreeClassifier,
        "K-Nearest Neighbors Classifier": KNeighborsClassifier,
        "Gradient Boosting Classifier": GradientBoostingClassifier,
        "AdaBoost Classifier": AdaBoostClassifier,
        "Gaussian Naive Bayes": GaussianNB,
        "Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis,
        "Linear Discriminant Analysis": LinearDiscriminantAnalysis,
    }
    registries = {
        "regression": regression_registry,
        "classification": classification_registry,
    }

    # EAFP: a missing task or model name surfaces as the same ValueError
    # the previous membership-check version raised.
    try:
        estimator_cls = registries[task_type][model_name]
    except KeyError:
        raise ValueError(f"Invalid model selection: {model_name} for {task_type}")

    return estimator_cls(**hyperparams)  # Apply hyperparameters
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def train_model(df, target_column, task_type, selected_model_name, hyperparams):
    """Preprocess data, train the selected model with hyperparameters, and return the trained model.

    Parameters
    ----------
    df : pandas.DataFrame
        Full dataset including the target column.
    target_column : str
        Name of the column to predict.
    task_type : str
        "classification" or "regression"; controls target encoding and the
        shape of the return value.
    selected_model_name : str
        Display name resolved by ``get_model``.
    hyperparams : dict
        Constructor kwargs for the chosen estimator.

    Returns
    -------
    sklearn Pipeline, or (Pipeline, LabelEncoder | None) for classification
    (the encoder is needed to decode predicted class labels).

    Side effects
    ------------
    Stores ``X_test``, ``y_test``, ``task_type`` and ``label_encoder`` in
    ``st.session_state`` and clears stale test-result keys.
    """
    with st.spinner(" Training model... Please wait!"):

        # Get the model with hyperparameters
        model = get_model(task_type, selected_model_name, hyperparams)

        # Split features and target
        X = df.drop(columns=[target_column])
        y = df[target_column]

        # Label encode target if classification (for categorical labels).
        # LabelEncoder is already imported at module level; the previous
        # in-function re-import was redundant.
        label_encoder = None
        if task_type == "classification" and y.dtype == "object":
            label_encoder = LabelEncoder()
            y = label_encoder.fit_transform(y)

        # Train-Test Split (80-20)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Identify numerical and categorical columns. "number" matches every
        # numeric dtype (int32, float32, ...), not only the 64-bit variants
        # the previous explicit ["int64", "float64"] list covered — those
        # narrower dtypes would otherwise silently bypass preprocessing.
        num_cols = X.select_dtypes(include="number").columns
        cat_cols = X.select_dtypes(include=["object", "category"]).columns

        # Preprocessing Pipeline
        # Numeric pipeline: impute missing values then scale them
        num_pipeline = SkPipeline([
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler())
        ])

        # Categorical pipeline: impute missing values then one-hot encode them
        cat_pipeline = SkPipeline([
            ("imputer", SimpleImputer(strategy="most_frequent")),
            ("encoder", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
        ])

        preprocessor = ColumnTransformer([
            ("num", num_pipeline, num_cols),
            ("cat", cat_pipeline, cat_cols)
        ])

        pipeline = SkPipeline([
            ("preprocessor", preprocessor),
            ("model", model)
        ])

        # Train Model
        pipeline.fit(X_train, y_train)

        # Store test data and metadata in session state
        st.session_state.X_test = X_test
        st.session_state.y_test = y_test
        st.session_state.task_type = task_type
        st.session_state.label_encoder = label_encoder  # Store label encoder for decoding predictions

        # Reset test results calculation flag when a new model is trained
        if "test_results_calculated" in st.session_state:
            st.session_state.test_results_calculated = False

        # Clear any previous test metrics to avoid using stale data
        for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
            if key in st.session_state:
                del st.session_state[key]

        # Return trained model + label encoder (needed for decoding predictions if classification)
        if task_type == "classification":
            return pipeline, label_encoder
        else:
            return pipeline
|
src/ui/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .visualization import visualize_data
|
| 2 |
+
from .css import load_css
|
| 3 |
+
from .loading import show_loading_state
|
| 4 |
+
from .footer import show_footer
|
| 5 |
+
from .welcome import show_welcome_page
|
| 6 |
+
from .test_results import display_test_results
|
| 7 |
+
from .overview import show_overview_page
|
| 8 |
+
from .insight import display_ai_insights
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"load_css",
|
| 12 |
+
"show_footer",
|
| 13 |
+
"show_loading_state",
|
| 14 |
+
"show_welcome_page",
|
| 15 |
+
"visualize_data",
|
| 16 |
+
"display_test_results",
|
| 17 |
+
"show_overview_page" ,
|
| 18 |
+
"display_ai_insights"
|
| 19 |
+
|
| 20 |
+
]
|
src/ui/css.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/css.py
|
| 2 |
+
import streamlit as st
|
| 3 |
+
|
| 4 |
+
def load_css():
    """Inject the app-wide neon/dark theme CSS into the Streamlit page.

    Must be called once near the top of the app; uses ``unsafe_allow_html``
    because Streamlit escapes raw <style> tags otherwise.

    Fix: ``--subtitle-text`` was referenced throughout (subtitles, card list
    items, expander content, footer, metric labels, dataframe cells) but was
    never defined in ``:root``, so those rules silently fell back to the
    inherited color. It is now defined as a muted gray.
    """
    css = """
    <style>

    /* --- EXPLICIT COLOR DEFINITIONS --- */
    :root {
        --dark-bg: #0E1117;
        --light-text: #FAFAFA;
        /* Muted gray for secondary text; previously referenced but undefined. */
        --subtitle-text: #A0A0B0;

        --neon-green: #00FFA3;
        --neon-font: Arial, Helvetica, sans-serif;
        --tab-font-size: 18px;
        --tab-bottom-border-height: 4px;
        /* Component Styling Variables */
        --card-bg: rgba(30, 30, 40, 0.5);
        --card-border: rgba(0, 255, 163, 0.25);
        --expander-header-bg: rgba(40, 40, 55, 0.6);
        --expander-hover-bg: rgba(0, 255, 163, 0.1);
    }

    /* --- Base Styles --- */
    body {
        background-color: var(--dark-bg) !important;
        color: var(--light-text) !important;
    }
    .stApp {
        background-color: var(--dark-bg) !important;
    }

    /* --- App Header Styling --- */
    .app-title {
        color: var(--neon-green) !important;
        font-family: var(--neon-font);
        font-size: 80px !important;
        font-weight: 700;
        text-shadow:
            0 0 1px var(--neon-green),
            0 0 2px var(--neon-green),
            0 0 5px var(--neon-green),
            0 0 45px var(--neon-green);
        margin-bottom: 25px !important;
        line-height: 1.0 !important;
        text-align: center !important;
    }
    .app-tagline {
        color: var(--neon-green) !important;
        font-family: var(--neon-font);
        font-size: 27px !important;
        font-style: normal !important;
        font-weight: 400 !important;
        text-shadow:
            0 0 5px var(--neon-green),
            0 0 10px var(--neon-green);
        margin-top: 10px !important;
        text-align: center !important;
    }
    .app-header {
        padding: 1rem 0 !important;
        margin-bottom: 2rem !important;
        text-align: center !important;
    }
    /* --- End App Header Styling --- */


    /* --- Tab Styling --- */
    div[data-baseweb="tab-list"] button[data-baseweb="tab"],
    div[data-baseweb="tab-list"] button[data-baseweb="tab"] > div,
    div[data-baseweb="tab-list"] button[data-baseweb="tab"] > div > span {
        font-size: var(--tab-font-size) !important;
        font-family: var(--neon-font) !important;
        font-weight: 600 !important;
    }
    div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"],
    div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"] > div,
    div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"] > div > span {
        font-size: var(--tab-font-size) !important;
    }
    div[data-baseweb="tab-list"] {
        border: none !important; border-top: none !important; border-right: none !important; border-left: none !important; border-bottom: none !important;
        border-color: transparent !important; outline: none !important; box-shadow: none !important;
        margin-bottom: 25px !important; padding: 0 !important;
        display: flex !important; justify-content: space-around !important;
    }
    button[data-baseweb="tab"] {
        color: var(--light-text) !important; padding: 1rem 1.5rem !important;
        transition: color 0.3s ease, text-shadow 0.3s ease, border-bottom-color 0.3s ease !important;
        border-style: solid !important; border-width: 0 0 var(--tab-bottom-border-height) 0 !important;
        border-color: transparent transparent transparent transparent !important;
        outline: none !important; box-shadow: none !important; background-color: transparent !important;
        margin: 0 !important; line-height: normal !important; flex-shrink: 0 !important;
    }
    button[data-baseweb="tab"]::before, button[data-baseweb="tab"]::after { display: none !important; content: none !important; }
    button[data-baseweb="tab"]:hover:not([aria-selected="true"]) {
        color: var(--neon-green) !important; background-color: transparent !important;
        border-color: transparent transparent transparent transparent !important; outline: none !important; box-shadow: none !important;
        text-shadow: 0 0 3px var(--neon-green), 0 0 6px var(--neon-green);
    }
    button[data-baseweb="tab"][aria-selected="true"] {
        border: none !important; border-top: none !important; border-right: none !important; border-left: none !important; border-bottom: none !important;
        border-color: transparent !important; outline: none !important; box-shadow: none !important; background-color: transparent !important;
        color: var(--neon-green) !important; border-bottom-style: solid !important;
        border-bottom-width: var(--tab-bottom-border-height) !important; border-bottom-color: var(--neon-green) !important;
        text-shadow: 0 0 3px var(--neon-green), 0 0 6px var(--neon-green);
    }
    /* --- End Tab Styling --- */


    /* --- Welcome Page Specific Styling --- */

    /* Main Welcome Header (H1) - *** UPDATED COLOR *** */
    .welcome-header h1 {
        font-size: 2.8rem !important;
        font-weight: 700 !important;
        margin-bottom: 0.5rem !important;
        color: var(--neon-green) !important; /* Use neon green */
        text-align: left !important;
        border-bottom: none !important;
        /* Optional: Add a subtle glow like the main title */
        text-shadow: 0 0 4px rgba(0, 255, 163, 0.7);
    }
    /* Main Welcome Subtitle (P) */
    .welcome-header p.subtitle {
        font-size: 1.15rem !important;
        color: var(--subtitle-text) !important; /* Keep subtitle gray */
        margin-bottom: 0 !important;
        text-align: left !important;
    }

    /* Section Headers (H2 generated by st.markdown("## ...")) */
    .stApp h2 {
        font-size: 1.9rem !important;
        font-weight: 600 !important;
        color: var(--neon-green) !important;
        border-bottom: 1px solid rgba(0, 255, 163, 0.3);
        padding-bottom: 8px !important;
        margin-top: 40px !important;
        margin-bottom: 25px !important;
    }
    /* Override for Sidebar Title H2 */
    section[data-testid="stSidebar"] h2 {
        border-bottom: none !important;
        color: #E6E6FA !important;
        font-size: 1.8rem !important;
        font-weight: 600 !important;
        margin: 0 !important;
        padding-bottom: 0 !important;
    }

    /* Feature Card Styling */
    .feature-card {
        background-color: var(--card-bg);
        border: 1px solid var(--card-border);
        border-radius: 8px;
        padding: 1.5rem 1.75rem;
        height: 100%;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
        transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out;
        margin-bottom: 1rem;
    }
    .feature-card:hover {
        transform: translateY(-3px);
        box-shadow: 0 6px 12px rgba(0, 255, 163, 0.15);
    }
    /* Titles within cards (H3) */
    .feature-card h3 {
        font-size: 1.3rem !important;
        font-weight: 600 !important;
        color: var(--light-text) !important;
        margin-top: 0 !important;
        margin-bottom: 1rem !important;
        border-bottom: none !important;
        padding-bottom: 0 !important;
    }
    /* Lists within cards */
    .feature-card ul {
        padding-left: 0 !important; margin-left: 5px; margin-bottom: 0 !important;
        list-style-type: none;
    }
    .feature-card ul li {
        margin-bottom: 0.6rem !important; line-height: 1.5;
        color: var(--subtitle-text) !important; position: relative; padding-left: 1.2em;
    }
    .feature-card ul li::before {
        content: '▪'; color: var(--neon-green); font-weight: bold;
        display: inline-block; position: absolute; left: 0; top: 0;
    }

    /* Getting Started List (Ordered List - OL) */
    .stApp ol {
        padding-left: 0 !important; margin-left: 5px; margin-bottom: 30px !important;
        list-style-type: none; counter-reset: getting-started-counter;
    }
    .stApp ol li {
        margin-bottom: 12px !important; line-height: 1.6 !important;
        color: var(--light-text) !important; counter-increment: getting-started-counter;
        position: relative; padding-left: 2.5em;
    }
    .stApp ol li::before {
        content: counter(getting-started-counter); color: var(--dark-bg); background-color: var(--neon-green);
        font-weight: bold; border-radius: 50%; width: 1.6em; height: 1.6em; display: inline-block;
        text-align: center; line-height: 1.6em; position: absolute; left: 0; top: 0;
    }
    .stApp ol li strong { color: var(--neon-green) !important; font-weight: 600; }

    /* Expander Styling */
    div[data-testid="stExpander"] {
        border: none !important; border-radius: 8px !important; margin-bottom: 1rem !important;
        box-shadow: 0 2px 4px rgba(0,0,0,0.2); background-color: transparent !important; overflow: hidden;
    }
    div[data-testid="stExpander"] summary {
        padding: 0.8rem 1.2rem !important; font-size: 1.1rem !important; font-weight: 600 !important;
        color: var(--light-text) !important; background-color: var(--expander-header-bg) !important;
        border: none !important; border-radius: 0 !important;
        transition: background-color 0.2s ease, color 0.2s ease; cursor: pointer;
    }
    div[data-testid="stExpander"] summary:hover {
        background-color: var(--expander-hover-bg) !important; color: var(--neon-green) !important;
    }
    div[data-testid="stExpander"] summary svg { fill: var(--light-text) !important; }
    div[data-testid="stExpander"] summary:hover svg { fill: var(--neon-green) !important; }
    div[data-testid="stExpander"] div[role="button"] + div { /* Content area */
        padding: 1.2rem 1.5rem !important; background-color: var(--card-bg); border: none !important;
    }
    div[data-testid="stExpander"] div[role="button"] + div ul,
    div[data-testid="stExpander"] div[role="button"] + div ol { margin-bottom: 0 !important; padding-left: 20px !important; list-style-type: disc; }
    div[data-testid="stExpander"] div[role="button"] + div li {
        color: var(--subtitle-text) !important; margin-bottom: 0.5rem !important; list-style-type: disc; padding-left: 0;
    }
    div[data-testid="stExpander"] div[role="button"] + div li::before { content: none !important; }
    div[data-testid="stExpander"] a { color: var(--neon-green) !important; text-decoration: underline; }
    div[data-testid="stExpander"] a:hover { text-shadow: 0 0 3px var(--neon-green); }

    /* Footer Styling */
    .footer {
        margin-top: 4rem !important; padding: 1rem !important; font-size: 0.9rem !important;
        color: var(--subtitle-text) !important; text-align: center; width: 100%;
        position: relative; bottom: auto; left: auto;
        border-top: 1px solid rgba(255, 255, 255, 0.1);
    }
    /* --- End Welcome Page Specific Styling --- */


    /* --- Overview Tab Styling --- */

    /* Style for st.metric containers */
    div[data-testid="stMetric"] {
        background-color: var(--card-bg); /* Use card background */
        border: 1px solid var(--card-border); /* Use card border */
        border-radius: 8px;
        padding: 1rem 1.25rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.15);
        transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out;
        height: 100%; /* Ensure metrics in a row are same height */
    }
    div[data-testid="stMetric"]:hover {
        transform: translateY(-2px); /* Lift effect */
        box-shadow: 0 4px 8px rgba(0, 255, 163, 0.1); /* Neon glow */
    }

    /* Style for st.metric Label */
    div[data-testid="stMetric"] label[data-testid="stMetricLabel"] {
        color: var(--subtitle-text) !important; /* Dimmer label */
        font-weight: 500 !important;
        font-size: 0.95rem !important;
    }

    /* Style for st.metric Value */
    div[data-testid="stMetric"] div[data-testid="stMetricValue"] {
        color: var(--neon-green) !important; /* Neon value */
        font-size: 2.5rem !important; /* Larger value */
        font-weight: 700 !important;
        padding-top: 5px;
    }
    div[data-testid="stMetric"] div[data-testid="stMetricDelta"] {
        /* Style delta if you use it */
        font-weight: 500 !important;
    }

    /* Section Headers in Overview Tab (Reuse H3 style) */
    /* The general .stApp h3 rule should cover this if specific enough */
    /* If not, target more specifically: */
    div[data-testid="stVerticalBlock"] h3 { /* Assuming Overview content is in a vertical block */
        font-size: 1.75rem !important;
        font-weight: 600 !important;
        color: var(--neon-green) !important;
        border-bottom: 1px solid rgba(0, 255, 163, 0.3);
        padding-bottom: 8px !important;
        margin-top: 30px !important; /* Adjust spacing */
        margin-bottom: 20px !important;
    }
    /* Reset for feature card H3 if needed */
    .feature-card h3 {
        font-size: 1.3rem !important;
        border-bottom: none !important;
        padding-bottom: 0 !important;
    }


    /* DataFrame Styling */
    div[data-testid="stDataFrame"] {
        border: 1px solid var(--card-border) !important; /* Neon border */
        border-radius: 8px;
        overflow: hidden; /* Ensures border radius applies to table */
        box-shadow: 0 2px 4px rgba(0,0,0,0.2);
    }

    /* DataFrame Header */
    div[data-testid="stDataFrame"] .col_heading {
        background-color: var(--expander-header-bg) !important; /* Darker header */
        color: var(--light-text) !important;
        font-weight: 600 !important;
        font-size: 0.95rem !important;
        text-align: left !important;
        border-bottom: 1px solid var(--neon-green) !important; /* Neon underline */
    }
    div[data-testid="stDataFrame"] .col_heading:first-of-type {
        border-top-left-radius: 7px; /* Match container radius */
    }
    div[data-testid="stDataFrame"] .col_heading:last-of-type {
        border-top-right-radius: 7px; /* Match container radius */
    }


    /* DataFrame Cells */
    div[data-testid="stDataFrame"] .dataframe td,
    div[data-testid="stDataFrame"] .dataframe th { /* Also style index header */
        color: var(--subtitle-text) !important;
        border-bottom: 1px solid rgba(255, 255, 255, 0.1) !important; /* Faint row separators */
        border-right: none !important; /* Remove vertical separators */
        padding: 0.5rem 0.75rem !important;
        font-size: 0.9rem !important;
    }
    div[data-testid="stDataFrame"] .dataframe th { /* Index header specifically */
        background-color: rgba(30, 30, 40, 0.3); /* Slightly different background for index */
        color: var(--light-text) !important;
        font-weight: 500 !important;
    }

    /* DataFrame Rows Hover */
    div[data-testid="stDataFrame"] .dataframe tr:hover td,
    div[data-testid="stDataFrame"] .dataframe tr:hover th {
        background-color: rgba(0, 255, 163, 0.05) !important; /* Faint neon hover */
        color: var(--light-text) !important;
    }

    /* --- End Overview Tab Styling --- */


    /* --- Hide Streamlit elements --- */
    #MainMenu {visibility: hidden !important;}
    header {visibility: hidden !important;}
    .stDeployButton {display: none !important;}
    div[data-testid="stToolbar"] {display: none !important;}
    div[data-testid="stDecoration"] {display: none !important;}
    div[data-testid="stStatusWidget"] {display: none !important;}
    /* --- End Hide Streamlit elements --- */

    /* --- Sidebar styling --- */
    section[data-testid="stSidebar"] > div:first-child {
        background-color: var(--dark-bg) !important;
    }

    div[data-testid="stMetric"] { color: var(--light-text) !important; }
    div[data-testid="stMetric"] > div { color: var(--light-text) !important; }
    div[data-testid="stMetric"] label { color: var(--light-text) !important; }
    /* --- End Remaining Styles --- */

    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
|
src/ui/footer.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
def show_footer():
    """Render the page footer (copyright line) via raw HTML.

    The .footer class is styled by load_css(); unsafe_allow_html is required
    so Streamlit does not escape the markup.
    """
    markup = """
    <div class="footer">
        © 2025 AutoML All Rights Reserved.
    </div>
    """
    st.markdown(markup, unsafe_allow_html=True)
|
src/ui/insight.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
def display_ai_insights():
    """Display the AI-generated insights and the data-cleaning summary.

    Reads ``st.session_state.insights`` (LLM text, optionally containing an
    "ANALYSIS INSIGHTS:" section marker) and ``st.session_state.df`` (the
    cleaned dataframe). Shows a warning when either is missing.
    """
    st.header("💡 AI-Powered Insights")

    with st.expander("🧹 Data Cleaning Process", expanded=True):
        has_insights = "insights" in st.session_state
        # Guard against a present-but-None df (matches show_overview_page);
        # the original only checked key presence and would crash on .head().
        has_df = "df" in st.session_state and st.session_state.df is not None
        if has_insights and has_df:
            # Split insights into cleaning process and analysis sections.
            parts = st.session_state.insights.split("ANALYSIS INSIGHTS:")

            # Show cleaning instructions
            st.markdown(parts[0])

            # Show interactive dataframe preview using st.session_state.df
            st.subheader("Cleaned Data Sample")
            st.dataframe(
                st.session_state.df.head(),  # Use the existing df state
                use_container_width=True,
                hide_index=True,
            )

            # Show analysis insights if present
            if len(parts) > 1:
                st.markdown("---")
                st.markdown("#### Analysis Insights")
                st.markdown(parts[1])
        else:
            st.warning("No insights generated yet. Upload and process a file first.")
|
src/ui/loading.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
def show_loading_state():
    """
    Cyber-inspired loading animation with circuit-like effects

    Renders a self-contained HTML/CSS splash (rocket SVG, title, animated
    progress bar) via st.html. If st.html raises (e.g. older Streamlit
    without it — NOTE(review): presumed failure mode, confirm), falls back
    to the built-in spinner for 3 seconds.
    """
    try:
        # Everything below is one opaque HTML blob: markup first, then an
        # embedded <style> block scoped by the *-cyber class names.
        st.html("""
        <div class="loading-container-cyber">
            <div class="rocket-animation-cyber">
                <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" class="cyber-rocket">
                    <path d="M50 10 L30 70 L50 65 L70 70 Z" fill="#00FFD1"/>
                    <path d="M40 80 L50 90 L60 80" stroke="#00FFD1" stroke-width="3" fill="none"/>
                </svg>
            </div>

            <h1 class="title-cyber">AutoML</h1>

            <h2 class="subtitle-cyber">You Ask , We Deliver</h2>

            <div class="loading-content-cyber">
                <p class="loading-text-cyber">Initializing neural networks...</p>
                <div class="loading-bar-container-cyber">
                    <div class="loading-bar-cyber"></div>
                </div>
            </div>

            <style>
            body { background-color: #000000 !important; }

            .loading-container-cyber {
                display: flex;
                flex-direction: column;
                align-items: center;
                justify-content: center;
                min-height: 80vh;
                text-align: center;
                padding: 2rem;
                background: radial-gradient(circle, rgba(0,0,0,1) 0%, rgba(0,0,0,1) 100%);
            }

            .cyber-rocket {
                width: 100px;
                height: 100px;
                animation: pulse 2s infinite;
            }

            .title-cyber {
                font-size: 3rem;
                margin-bottom: 0.5rem;
                color: #00FFD1;
                text-shadow: 0 0 10px #00FFD1;
                font-family: 'Orbitron', sans-serif;
            }

            .subtitle-cyber {
                font-size: 1.5rem;
                margin-bottom: 2rem;
                color: #00A86B;
                font-family: 'Chakra Petch', sans-serif;
            }

            .loading-content-cyber {
                background: rgba(0, 255, 209, 0.05);
                border: 1px solid rgba(0, 255, 209, 0.2);
                padding: 1.5rem 2rem;
                border-radius: 8px;
                max-width: 600px;
                width: 100%;
            }

            .loading-text-cyber {
                margin: 0 0 1rem 0;
                font-size: 1.1rem;
                color: #00FFD1;
                font-family: 'Chakra Petch', sans-serif;
            }

            .loading-bar-container-cyber {
                height: 6px;
                background: rgba(0, 255, 209, 0.2);
                border-radius: 3px;
                overflow: hidden;
            }

            .loading-bar-cyber {
                height: 100%;
                width: 30%;
                background: linear-gradient(90deg, #00FFD1, #00A86B);
                animation: circuit-load 1.5s cubic-bezier(0.4, 0.0, 0.2, 1) infinite;
            }

            @keyframes pulse {
                0%, 100% { transform: scale(1); }
                50% { transform: scale(1.1); }
            }

            @keyframes circuit-load {
                0% { transform: translateX(-100%); box-shadow: 0 0 10px #00FFD1; }
                50% { box-shadow: 0 0 20px #00FFD1; }
                100% { transform: translateX(400%); box-shadow: 0 0 10px #00FFD1; }
            }
            </style>
        </div> """)
    except Exception as e:
        # Fallback to built-in Streamlit spinner if custom animation fails
        st.warning("Custom loading animation unavailable. Using default spinner...")
        with st.spinner("Loading, please wait..."):
            time.sleep(3)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
    # Manual smoke test: render the splash, hold it for 3s, then clear the
    # placeholder and confirm. Run with `streamlit run src/ui/loading.py`.
    show_loading_state()
    time.sleep(3)
    st.empty()
    st.success("App loaded successfully!")
|
| 118 |
+
|
| 119 |
+
|
src/ui/overview.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
@st.cache_data
def compute_column_info(df):
    """Build a per-column summary table: dtype, null/non-null counts, cardinality.

    Cached by Streamlit so repeated reruns over the same dataframe are free.
    """
    dtypes = df.dtypes
    summary = {
        "Column": dtypes.index,
        "Type": dtypes.astype(str),
        "Non-Null Count": df.count(),
        "Null Count": df.isnull().sum(),
        "Unique Values": df.nunique(),
    }
    return pd.DataFrame(summary)
|
| 15 |
+
|
| 16 |
+
def show_overview_page():
    """Render the dataset Overview tab: headline stats, preview, column info.

    Reads the dataframe from ``st.session_state.df``; warns and returns
    early when no dataset has been loaded.
    """
    if "df" not in st.session_state or st.session_state.df is None:
        st.warning("⚠️ No dataset loaded. Please upload a dataset first.")
        return

    df = st.session_state.df

    # Dataset Statistics
    st.markdown("## 📊 Dataset Statistics")
    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Total Rows", len(df))
    with col2:
        st.metric("Total Columns", len(df.columns))
    with col3:
        # "number" matches every numeric dtype (int32, float32, nullable
        # Int64, ...) — the original hard-coded only int64/float64 and
        # under-counted on datasets with other numeric dtypes.
        numeric_count = len(df.select_dtypes(include="number").columns)
        st.metric("Numeric Columns", numeric_count)
    with col4:
        categorical_count = len(df.select_dtypes(include=["object", "category"]).columns)
        st.metric("Categorical Columns", categorical_count)

    # Data Preview: only display the top few rows
    st.markdown("## 🔍 Data Preview")
    st.dataframe(df.head(), use_container_width=True)

    # Column Information: use cached computation for faster loading
    st.markdown("## 📌 Column Information")
    dtypes_df = compute_column_info(df)
    st.dataframe(dtypes_df, use_container_width=True)
|
src/ui/test_results.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import io
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
from sklearn.metrics import (
|
| 9 |
+
accuracy_score,
|
| 10 |
+
precision_score,
|
| 11 |
+
recall_score,
|
| 12 |
+
f1_score,
|
| 13 |
+
confusion_matrix,
|
| 14 |
+
mean_absolute_error,
|
| 15 |
+
mean_squared_error,
|
| 16 |
+
r2_score,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# ==== LLM Setup with Caching ====
|
| 20 |
+
@st.cache_resource(show_spinner=False)  # Disable default spinner
def get_llm():
    """Cached LLM initialization to prevent reloading on every rerun.

    Tries Groq first, then falls back to Google Gemini. Returns None when
    neither provider can be constructed (missing API keys, import errors,
    network problems) — callers must handle a None LLM.
    """
    from langchain_google_genai import ChatGoogleGenerativeAI
    from langchain_groq import ChatGroq
    import os

    try:
        return ChatGroq(
            model="gemma2-9b-it",
            groq_api_key=os.getenv("GROQ_API_KEY")
        )
    except Exception:
        # Groq unavailable — try Gemini instead.
        try:
            return ChatGoogleGenerativeAI(
                model="gemini-2.0-flash-lite-preview-02-05",
                google_api_key=os.getenv("GEMINI_API_KEY")
            )
        # except Exception (not bare except): a bare except would also
        # swallow KeyboardInterrupt/SystemExit.
        except Exception:
            return None
|
| 41 |
+
|
| 42 |
+
# Module-level LLM handle, initialized once at import (st.cache_resource
# memoizes across reruns). May be None when no provider is configured.
llm_insights = get_llm()
|
| 43 |
+
|
| 44 |
+
# ==== Cached Metric Calculations ====
|
| 45 |
+
@st.cache_data(show_spinner=False)  # Add to heavy computations
def _compute_classification_metrics(y_test, y_pred):
    """Cached metric computation for classification.

    Returns a dict with weighted precision/recall/f1 (zero_division=0 so
    absent classes don't raise), accuracy, and the confusion matrix.
    """
    scores = {}
    scores['accuracy'] = accuracy_score(y_test, y_pred)
    scores['precision'] = precision_score(y_test, y_pred, average="weighted", zero_division=0)
    scores['recall'] = recall_score(y_test, y_pred, average="weighted", zero_division=0)
    scores['f1'] = f1_score(y_test, y_pred, average="weighted", zero_division=0)
    scores['cm'] = confusion_matrix(y_test, y_pred)
    return scores
|
| 55 |
+
|
| 56 |
+
@st.cache_data
def _compute_regression_metrics(y_test, y_pred):
    """Cached metric computation for regression.

    Returns MAE, MSE, RMSE and R² in a dict keyed by short metric names.
    """
    # Compute MSE once and derive RMSE from it — the original scored the
    # predictions twice for no benefit.
    mse = mean_squared_error(y_test, y_pred)
    return {
        'mae': mean_absolute_error(y_test, y_pred),
        'mse': mse,
        'rmse': np.sqrt(mse),
        'r2': r2_score(y_test, y_pred)
    }
|
| 65 |
+
|
| 66 |
+
# ==== Cached Visualization Generation ====
|
| 67 |
+
@st.cache_data(show_spinner=False)  # Add to heavy computations
def _plot_confusion_matrix(cm, classes):
    """Cached confusion matrix plotting.

    Renders *cm* as an annotated heatmap and returns the PNG bytes in a
    BytesIO buffer (picklable, so it works with st.cache_data).
    """
    fig, ax = plt.subplots(figsize=(2, 2), dpi=200)
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=classes,
        yticklabels=classes,
        annot_kws={"size": 8},
    )
    plt.xticks(fontsize=5)
    plt.yticks(fontsize=5)

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", dpi=200)
    buf.seek(0)
    # Close the figure once it's rendered to the buffer: matplotlib keeps
    # every open figure alive, which leaks memory across Streamlit reruns.
    plt.close(fig)
    return buf
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ==== Optimized Insights Generation ====
|
| 90 |
+
@st.cache_data(show_spinner=False)  # Add to heavy computations
def _get_insights_classification(accuracy, precision, recall, f1, cm_shape):
    """Cached insights generation based on metrics.

    Asks the module-level LLM for a short markdown explanation of the
    classification metrics; falls back to a static explanation when no LLM
    is configured or the call fails.
    """
    if llm_insights is None:
        # Static fallback — same information, no LLM required.
        return (
            f"### Classification Metrics Explained\n\n"
            f"**Accuracy** ({accuracy:.3f}): Correct predictions ratio\n"
            f"**Precision** ({precision:.3f}): Positive prediction accuracy\n"
            f"**Recall** ({recall:.3f}): Actual positives found\n"
            f"**F1 Score** ({f1:.3f}): Precision-Recall balance\n"
            f"Confusion Matrix ({cm_shape[0]}x{cm_shape[1]}): Prediction vs Actual distribution"
        )

    try:
        response = llm_insights.invoke(f"""
        Briefly explain these classification metrics (accuracy={accuracy:.3f},
        precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f})
        and {cm_shape[0]}x{cm_shape[1]} confusion matrix.
        Use markdown bullet points.
        """)
        return response.content.strip()
    # except Exception (not bare except): a bare except would also swallow
    # KeyboardInterrupt/SystemExit, making the app impossible to stop here.
    except Exception:
        return "Could not generate AI insights - showing basic metrics explanation."
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def display_test_results(trained_model, X_test, y_test, task_type, label_encoder=None):
    """
    Displays test results, including metrics, confusion matrix (if classification),
    and LLM-based or fallback insights about the metrics.

    Args:
        trained_model: the fitted model. NOTE(review): for classification with
            label_encoder=None this is unpacked as a (pipeline, encoder) tuple —
            confirm that the training code really passes a tuple in that case.
        X_test, y_test: held-out features/targets (pandas or numpy accepted).
        task_type: "regression" or "classification". NOTE(review): any other
            value leaves `y_pred`/`enc` unbound and would raise downstream.
        label_encoder: optional encoder used to decode classification labels.

    Side effects: reads/writes several st.session_state keys
    (test_results_calculated, test_metrics, test_y_pred, test_y_test,
    sampling_message) so results survive Streamlit reruns.
    """

    # Create a placeholder for the loading message at the top of the page
    st.markdown("## Test Results")
    loading_placeholder = st.empty()

    # Show initial loading message
    with loading_placeholder.container():
        st.info("⏳ Evaluating model performance on test data. This may take a moment for large datasets.")
        progress_bar = st.progress(0)

    # Set a flag to track if results have been calculated
    if "test_results_calculated" not in st.session_state:
        st.session_state.test_results_calculated = False

    # Only perform calculations if they haven't been done yet
    if not st.session_state.test_results_calculated:

        sampling_message = None
        MAX_SAMPLES = 5000  # Increased from 50 to 5000

        # Update progress - Starting evaluation
        with loading_placeholder.container():
            progress_bar.progress(10)

        if len(X_test) <= MAX_SAMPLES:
            # Use all test data
            X_test_sample = X_test
            y_test_sample = y_test
            st.info("🔍 Using all test data for evaluation...")
        else:
            # Use sampling for large datasets
            sampling_message = f"📊 Using {MAX_SAMPLES} samples from the test set for visualization (out of {len(X_test)} total)"
            st.info("🔍 Sampling test data for evaluation...")

            # Simple random sampling; the hasattr checks let this work for
            # both pandas objects (.iloc) and plain numpy arrays.
            idx = np.random.choice(len(X_test.index if hasattr(X_test, 'index') else X_test), size=MAX_SAMPLES, replace=False)
            X_test_sample = X_test.iloc[idx] if hasattr(X_test, 'iloc') else X_test[idx]
            y_test_sample = y_test.iloc[idx] if hasattr(y_test, 'iloc') else y_test[idx]

        # Generate predictions
        with loading_placeholder.container():
            progress_bar.progress(30)
            st.info("🔄 Generating predictions... Please wait")
            # Add a spinner for visual feedback during prediction
            with st.spinner("Model working..."):
                if task_type == "regression":
                    y_pred = trained_model.predict(X_test_sample)
                elif task_type == "classification":
                    # NOTE(review): when label_encoder is None, trained_model is
                    # unpacked as a (pipeline, encoder) tuple; otherwise it is
                    # used as-is with the supplied encoder. Confirm the caller
                    # honors this dual convention.
                    pipeline, enc = trained_model if label_encoder is None else (trained_model, label_encoder)
                    y_pred = pipeline.predict(X_test_sample)

                    # Decode if label_encoder is used
                    if enc:
                        y_pred = enc.inverse_transform(y_pred)
                        y_test_decoded = enc.inverse_transform(y_test_sample)
                    else:
                        y_test_decoded = y_test_sample

        # Update progress - Computing metrics
        with loading_placeholder.container():
            progress_bar.progress(60)
            st.info("📊 Computing metrics...")

        # Compute metrics (cached helpers defined above)
        if task_type == "regression":
            metrics = _compute_regression_metrics(y_test_sample, y_pred)
        else:
            metrics = _compute_classification_metrics(y_test_decoded, y_pred)

        # Update progress - Preparing visualizations
        with loading_placeholder.container():
            progress_bar.progress(90)
            st.info("📈 Preparing visualizations...")

        # For classification, pre-calculate confusion matrix before showing "ready" message
        if task_type == "classification":
            # Pre-calculate confusion matrix (this is the slow part); the
            # results land in the st.cache_data cache and are reused below.
            _ = _plot_confusion_matrix(metrics['cm'], np.unique(y_test_decoded))
            # Pre-calculate insights (also potentially slow with LLM)
            _ = _get_insights_classification(
                metrics['accuracy'],
                metrics['precision'],
                metrics['recall'],
                metrics['f1'],
                metrics['cm'].shape
            )

        # Update progress - Complete (only after all calculations are done)
        with loading_placeholder.container():
            progress_bar.progress(100)
            st.success("✅ Test results ready!")

        # Mark results as calculated
        st.session_state.test_results_calculated = True

        # Store results in session state for reuse
        st.session_state.test_metrics = metrics
        if task_type == "classification":
            st.session_state.test_y_pred = y_pred
            st.session_state.test_y_test = y_test_decoded
        else:
            st.session_state.test_y_pred = y_pred
            st.session_state.test_y_test = y_test_sample

        # Store sampling message
        st.session_state.sampling_message = sampling_message

        # Import time only when needed (moved from global to local scope)
        import time
        time.sleep(0.5)  # Short delay to show the "Test results ready!" message

    # Display sampling message if it exists
    if "sampling_message" in st.session_state and st.session_state.sampling_message:
        st.info(st.session_state.sampling_message)

    # Display the results using stored values
    if task_type == "regression":
        st.subheader("🔍 Regression Metrics")

        # Get metrics from session state or use the ones we just calculated
        if "test_metrics" in st.session_state and st.session_state.test_results_calculated:
            metrics = st.session_state.test_metrics
            y_pred = st.session_state.test_y_pred
            y_test = st.session_state.test_y_test

            mae, mse, rmse, r2 = metrics['mae'], metrics['mse'], np.sqrt(metrics['mse']), metrics['r2']

            col1, col2, col3, col4 = st.columns(4)
            col1.metric("📉 MAE", f"{mae:.4f}")
            col2.metric("📊 MSE", f"{mse:.4f}")
            col3.metric("📈 RMSE", f"{rmse:.4f}")
            col4.metric("📌 R² Score", f"{r2:.4f}")

            # Add regression visualization
            st.subheader("📈 Prediction vs Actual")
            df_results = pd.DataFrame({
                'Actual': y_test,
                'Predicted': y_pred
            })
            fig = px.scatter(df_results, x='Actual', y='Predicted',
                             title='Predicted vs Actual Values',
                             labels={'Actual': 'Actual Values', 'Predicted': 'Predicted Values'})
            # Red dashed y = x reference line: points on it are perfect predictions.
            fig.add_shape(type='line', x0=min(y_test), y0=min(y_test),
                          x1=max(y_test), y1=max(y_test),
                          line=dict(color='red', dash='dash'))
            st.plotly_chart(fig, use_container_width=True)

    elif task_type == "classification":
        st.subheader("🔍 Classification Metrics")

        # Get metrics from session state or use the ones we just calculated
        if "test_metrics" in st.session_state and st.session_state.test_results_calculated:
            metrics = st.session_state.test_metrics
            y_pred = st.session_state.test_y_pred
            y_test_decoded = st.session_state.test_y_test

            accuracy, precision, recall, f1 = metrics['accuracy'], metrics['precision'], metrics['recall'], metrics['f1']

            col1, col2, col3, col4 = st.columns(4)
            col1.metric("✅ Accuracy", f"{accuracy:.4f}")
            col2.metric("🎯 Precision", f"{precision:.4f}")
            col3.metric("📢 Recall", f"{recall:.4f}")
            col4.metric("🔥 F1 Score", f"{f1:.4f}")

            st.subheader("📊 Confusion Matrix")
            # Use cached function for confusion matrix visualization
            buf = _plot_confusion_matrix(metrics['cm'], np.unique(y_test_decoded))
            st.image(buf, width=450)

            # === Additional Insights Section ===
            st.markdown("---")
            st.markdown("#### Test Insights")
            accuracy, precision, recall, f1 = metrics['accuracy'], metrics['precision'], metrics['recall'], metrics['f1']
            classification_insights = _get_insights_classification(accuracy, precision, recall, f1, metrics['cm'].shape)
            st.markdown(classification_insights)
|
src/ui/visualization.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import plotly.express as px
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
from plotly.subplots import make_subplots
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from src.utils.logging import log_frontend_error, log_frontend_warning
|
| 9 |
+
|
| 10 |
+
SAMPLE_SIZE = 10000 # Define a sample size for subsampling large datasets
|
| 11 |
+
|
| 12 |
+
# Efficiently hash a dataframe to detect changes
|
| 13 |
+
@st.cache_data(show_spinner=False)
def compute_df_hash(df):
    """Cheap change-detection hash for a dataframe.

    Combines the frame's shape with a content hash of the first 100 rows
    only, so very large frames stay fast. A change below row 100 that
    keeps the shape identical will therefore go undetected.
    """
    head = df.iloc[:min(100, len(df))]
    row_digest = pd.util.hash_pandas_object(head).sum()
    return hash((df.shape, row_digest))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@st.cache_data(show_spinner=False, ttl=3600)  # cache for 1 hour
def is_potential_date_column(series, sample_size=5):
    """Heuristically decide whether *series* likely contains dates.

    A column qualifies if its name contains a date-ish keyword, or if more
    than half of a small sample of its values matches a common date pattern.
    """
    # BUGFIX: series.name can be None or a non-string (e.g. integer column
    # labels); the original `series.name.lower()` raised AttributeError.
    name = str(series.name) if series.name is not None else ""
    if any(keyword in name.lower() for keyword in ['date', 'time', 'year', 'month', 'day']):
        return True

    # Check sample values against common textual date formats.
    sample = series.dropna().head(sample_size).astype(str)
    date_patterns = [
        r'\d{4}-\d{2}-\d{2}',     # YYYY-MM-DD
        r'\d{2}/\d{2}/\d{4}',     # MM/DD/YYYY
        r'\d{2}-\w{3}-\d{2,4}',   # DD-MON-YY(Y)
        r'\d{1,2} \w{3,} \d{4}',  # 1 January 2023
    ]

    if len(sample) == 0:
        return False
    date_count = sum(1 for val in sample if any(re.match(p, val) for p in date_patterns))
    return date_count / len(sample) > 0.5  # >50% of sampled values look like dates
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Cache column type detection with improved performance
|
| 42 |
+
@st.cache_data(show_spinner=False, ttl=3600)  # cache for 1 hour
def get_column_types(df):
    """Classify each column of *df* into a coarse semantic type.

    Possible types: BINARY, NUMERIC_DISCRETE, NUMERIC_CONTINUOUS, TEMPORAL,
    ID, CATEGORICAL, TEXT. Results are cached for an hour.
    """
    column_types = {}

    for column in df.columns:
        series = df[column]
        n_unique = series.nunique()

        if pd.api.types.is_numeric_dtype(series):
            # Numeric: binary -> discrete -> continuous, by cardinality.
            if n_unique <= 2:
                column_types[column] = "BINARY"
            elif n_unique < 20:
                column_types[column] = "NUMERIC_DISCRETE"
            else:
                column_types[column] = "NUMERIC_CONTINUOUS"
            continue

        # Non-numeric: try dates first.
        if is_potential_date_column(series):
            try:
                # Attempt conversion; coerce turns unparseable values into NaT.
                converted = pd.to_datetime(series, errors='coerce')
                if not converted.isnull().all():  # at least some valid dates
                    column_types[column] = "TEMPORAL"
                    continue
            except Exception:
                pass

        # ID-like: almost entirely unique values AND an id-ish name.
        if (n_unique > len(df) * 0.9 and
                any(x in column.lower() for x in ['id', 'code', 'key', 'uuid', 'identifier'])):
            column_types[column] = "ID"
        # Low-to-medium cardinality -> categorical.
        elif n_unique <= 20:
            column_types[column] = "CATEGORICAL"
        # Everything else is free text.
        else:
            column_types[column] = "TEXT"

    return column_types
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Cache correlation matrix computation with improved performance
|
| 93 |
+
@st.cache_data(show_spinner=False, ttl=3600)  # cache for 1 hour
def get_corr_matrix(df):
    """Return the (cached) correlation matrix of the numeric columns.

    At most the first 30 numeric columns are considered, and None is
    returned when fewer than two numeric columns are present.
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns

    # Cap the column count so the O(k^2) correlation stays cheap.
    if len(numeric_cols) > 30:
        numeric_cols = numeric_cols[:30]

    if len(numeric_cols) <= 1:
        return None
    return df[numeric_cols].corr()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Cache subsampled data with improved performance
|
| 111 |
+
@st.cache_data(show_spinner=False, ttl=3600)  # cache for 1 hour
def get_subsampled_data(df, column):
    """Return a subsample of one column for fast visualization.

    For low-cardinality columns on large frames a per-category (stratified)
    sample is attempted so rare categories stay represented; on failure, or
    in every other case, a plain random sample of up to SAMPLE_SIZE rows is
    returned. Missing columns yield an empty frame.
    """
    if column not in df.columns:
        return pd.DataFrame()

    if df[column].nunique() < 20 and len(df) > SAMPLE_SIZE:
        try:
            # Sample the same fraction of each category (seeded for caching).
            frac = min(0.5, SAMPLE_SIZE / len(df))
            return df[[column]].groupby(column, group_keys=False).apply(
                lambda grp: grp.sample(max(1, int(frac * len(grp))), random_state=42)
            )
        except Exception:
            pass  # fall through to plain random sampling

    return df[[column]].sample(min(len(df), SAMPLE_SIZE), random_state=42)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Cache chart creation with improved performance
|
| 137 |
+
@st.cache_data(show_spinner=False, ttl=1800, hash_funcs={  # cache for 30 minutes
    pd.DataFrame: compute_df_hash,
    pd.Series: lambda s: hash((s.name, compute_df_hash(s.to_frame())))
})
def create_chart(df, column, column_type):
    """Generate an optimized plotly figure for one column.

    The chart layout depends on *column_type* (as produced by
    get_column_types); a subsample of *df* is used for speed. Returns the
    figure, or None when the column is missing, the sample is empty, or
    chart creation fails (the error is logged, not raised).
    """
    # Check if column exists in the dataframe
    if column not in df.columns:
        return None

    # Get subsampled data for better performance
    df_sample = get_subsampled_data(df, column)
    if df_sample.empty:
        return None

    try:
        # Year-based columns (special case, keyed off the column name)
        if "year" in column.lower():
            fig = make_subplots(rows=1, cols=2, subplot_titles=("Year Distribution", "Box Plot"),
                                specs=[[{"type": "bar"}, {"type": "box"}]], column_widths=[0.7, 0.3], horizontal_spacing=0.1)
            year_counts = df_sample[column].value_counts().sort_index()
            fig.add_trace(go.Bar(x=year_counts.index, y=year_counts.values, marker_color='#7B68EE'), row=1, col=1)
            fig.add_trace(go.Box(x=df_sample[column], marker_color='#7B68EE'), row=1, col=2)

        # Binary columns (0/1, True/False)
        elif column_type == "BINARY":
            value_counts = df_sample[column].value_counts()
            fig = make_subplots(rows=1, cols=2,
                                subplot_titles=("Distribution", "Percentage"),
                                specs=[[{"type": "bar"}, {"type": "pie"}]],
                                column_widths=[0.5, 0.5],
                                horizontal_spacing=0.1)

            fig.add_trace(go.Bar(
                x=value_counts.index,
                y=value_counts.values,
                marker_color=['#FF4B4B', '#4CAF50'],
                text=value_counts.values,
                textposition='auto'
            ), row=1, col=1)

            fig.add_trace(go.Pie(
                labels=value_counts.index,
                values=value_counts.values,
                marker=dict(colors=['#FF4B4B', '#4CAF50']),
                textinfo='percent+label'
            ), row=1, col=2)

            fig.update_layout(title_text=f"Binary Distribution: {column}")

        # Numeric continuous columns
        elif column_type == "NUMERIC_CONTINUOUS":
            fig = make_subplots(rows=2, cols=2,
                                subplot_titles=("Distribution", "Box Plot", "Violin Plot", "Cumulative Distribution"),
                                specs=[[{"type": "histogram"}, {"type": "box"}],
                                       [{"type": "violin"}, {"type": "scatter"}]],
                                vertical_spacing=0.15,
                                horizontal_spacing=0.1)

            # Histogram
            fig.add_trace(go.Histogram(
                x=df_sample[column],
                nbinsx=30,
                marker_color='#FF4B4B',
                opacity=0.7
            ), row=1, col=1)

            # Box plot
            fig.add_trace(go.Box(
                x=df_sample[column],
                marker_color='#FF4B4B',
                boxpoints='outliers'
            ), row=1, col=2)

            # Violin plot
            fig.add_trace(go.Violin(
                x=df_sample[column],
                marker_color='#FF4B4B',
                box_visible=True,
                points='outliers'
            ), row=2, col=1)

            # Empirical CDF from the sorted, non-null values
            sorted_data = np.sort(df_sample[column].dropna())
            cumulative = np.arange(1, len(sorted_data) + 1) / len(sorted_data)

            fig.add_trace(go.Scatter(
                x=sorted_data,
                y=cumulative,
                mode='lines',
                line=dict(color='#FF4B4B', width=2)
            ), row=2, col=2)

            fig.update_layout(height=600, title_text=f"Continuous Variable Analysis: {column}")

        # Numeric discrete columns
        elif column_type == "NUMERIC_DISCRETE":
            value_counts = df_sample[column].value_counts().sort_index()
            fig = make_subplots(rows=1, cols=2,
                                subplot_titles=("Distribution", "Percentage"),
                                specs=[[{"type": "bar"}, {"type": "pie"}]],
                                column_widths=[0.7, 0.3],
                                horizontal_spacing=0.1)

            fig.add_trace(go.Bar(
                x=value_counts.index,
                y=value_counts.values,
                marker_color='#FF4B4B',
                text=value_counts.values,
                textposition='auto'
            ), row=1, col=1)

            fig.add_trace(go.Pie(
                labels=value_counts.index,
                values=value_counts.values,
                marker=dict(colors=px.colors.sequential.Reds),
                textinfo='percent+label'
            ), row=1, col=2)

            fig.update_layout(title_text=f"Discrete Numeric Distribution: {column}")

        # Categorical columns
        elif column_type == "CATEGORICAL":
            value_counts = df_sample[column].value_counts().head(20)  # limit to top 20 categories
            fig = make_subplots(rows=1, cols=2,
                                subplot_titles=("Category Distribution", "Percentage Breakdown"),
                                specs=[[{"type": "bar"}, {"type": "pie"}]],
                                column_widths=[0.6, 0.4],
                                horizontal_spacing=0.1)

            # Bar chart
            fig.add_trace(go.Bar(
                x=value_counts.index,
                y=value_counts.values,
                marker_color='#00FFA3',
                text=value_counts.values,
                textposition='auto'
            ), row=1, col=1)

            # Pie chart
            fig.add_trace(go.Pie(
                labels=value_counts.index,
                values=value_counts.values,
                marker=dict(colors=px.colors.sequential.Greens),
                textinfo='percent+label'
            ), row=1, col=2)

            fig.update_layout(title_text=f"Categorical Analysis: {column}")

        # Temporal/date columns
        elif column_type == "TEMPORAL":
            # Convert with safe datetime parsing; invalid values become NaT
            dates = pd.to_datetime(df_sample[column], errors='coerce', format='mixed')
            valid_dates = dates[dates.notna()]

            fig = make_subplots(
                rows=2,
                cols=2,
                subplot_titles=("Monthly Pattern", "Yearly Pattern", "Cumulative Trend", "Day of Week Distribution"),
                vertical_spacing=0.15,
                horizontal_spacing=0.1,
                specs=[[{"type": "bar"}, {"type": "bar"}],
                       [{"type": "scatter"}, {"type": "bar"}]]
            )

            # All four sub-plots need at least one parseable date
            if not valid_dates.empty:
                # Monthly pattern
                monthly_counts = valid_dates.dt.month.value_counts().sort_index()
                month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
                month_labels = [month_names[i-1] for i in monthly_counts.index]

                fig.add_trace(go.Bar(
                    x=month_labels,
                    y=monthly_counts.values,
                    marker_color='#7B68EE',
                    text=monthly_counts.values,
                    textposition='auto'
                ), row=1, col=1)

                # Yearly pattern
                yearly_counts = valid_dates.dt.year.value_counts().sort_index()

                fig.add_trace(go.Bar(
                    x=yearly_counts.index,
                    y=yearly_counts.values,
                    marker_color='#7B68EE',
                    text=yearly_counts.values,
                    textposition='auto'
                ), row=1, col=2)

                # Cumulative trend
                sorted_dates = valid_dates.sort_values()
                cumulative = np.arange(1, len(sorted_dates) + 1)

                fig.add_trace(go.Scatter(
                    x=sorted_dates,
                    y=cumulative,
                    mode='lines',
                    line=dict(color='#7B68EE', width=2)
                ), row=2, col=1)

                # Day of week distribution
                dow_counts = valid_dates.dt.dayofweek.value_counts().sort_index()
                dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
                dow_labels = [dow_names[i] for i in dow_counts.index]

                fig.add_trace(go.Bar(
                    x=dow_labels,
                    y=dow_counts.values,
                    marker_color='#7B68EE',
                    text=dow_counts.values,
                    textposition='auto'
                ), row=2, col=2)

            fig.update_layout(height=600, title_text=f"Temporal Analysis: {column}")

        # ID columns (show distribution of first few characters, length distribution)
        elif column_type == "ID":
            # Calculate ID length statistics
            id_lengths = df_sample[column].astype(str).str.len()

            # Extract first 2 characters for prefix analysis
            id_prefixes = df_sample[column].astype(str).str[:2].value_counts().head(15)

            fig = make_subplots(
                rows=1,
                cols=2,
                subplot_titles=("ID Length Distribution", "Common ID Prefixes"),
                horizontal_spacing=0.1,
                specs=[[{"type": "histogram"}, {"type": "bar"}]]
            )

            # ID length histogram
            fig.add_trace(go.Histogram(
                x=id_lengths,
                nbinsx=20,
                marker_color='#9C27B0'
            ), row=1, col=1)

            # ID prefix bar chart
            fig.add_trace(go.Bar(
                x=id_prefixes.index,
                y=id_prefixes.values,
                marker_color='#9C27B0',
                text=id_prefixes.values,
                textposition='auto'
            ), row=1, col=2)

            fig.update_layout(title_text=f"ID Analysis: {column}")

        # Text columns
        elif column_type == "TEXT":
            # For text columns, show top values and length distribution
            value_counts = df_sample[column].value_counts().head(15)

            # Calculate text length statistics
            text_lengths = df_sample[column].astype(str).str.len()

            fig = make_subplots(
                rows=2,
                cols=1,
                subplot_titles=("Top Values", "Text Length Distribution"),
                vertical_spacing=0.2,
                specs=[[{"type": "bar"}], [{"type": "histogram"}]]
            )

            # Top values bar chart
            fig.add_trace(
                go.Bar(
                    x=value_counts.index,
                    y=value_counts.values,
                    marker_color='#00B4D8',
                    text=value_counts.values,
                    textposition='auto'
                ),
                row=1, col=1
            )

            # Text length histogram
            fig.add_trace(
                go.Histogram(
                    x=text_lengths,
                    nbinsx=20,
                    marker_color='#00B4D8'
                ),
                row=2, col=1
            )

            fig.update_layout(
                height=600,
                title_text=f"Text Analysis: {column}"
            )

        # Fallback for any other column type
        else:
            fig = go.Figure(go.Histogram(x=df_sample[column], marker_color='#888'))
            fig.update_layout(title_text=f"Generic Analysis: {column}")

        # Common layout settings
        fig.update_layout(
            # BUGFIX: this was a hard-coded height=400, which silently
            # overrode the 600px heights set by the multi-row branches above.
            # Keep any branch-specified height; default to 400 otherwise.
            height=fig.layout.height or 400,
            showlegend=False,
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color='#FFFFFF'),
            margin=dict(l=40, r=40, t=50, b=40)
        )

        return fig

    except Exception as e:
        # Log and degrade gracefully so one bad column never breaks the dashboard.
        log_frontend_error("Chart Generation", f"Error creating chart for {column}: {str(e)}")
        return None
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def visualize_data(df):
    """Render the automated visualization dashboard for *df*.

    Column-type detection and the correlation matrix are expensive, so they
    are cached in ``st.session_state`` keyed by a hash of the dataframe and
    recomputed only when the underlying data actually changes.
    """
    # Nothing to visualize without data — bail out early.
    if df is None or df.empty:
        st.error("❌ No data available. Please upload and clean your data first.")
        return

    # Hash the dataframe a single time; compared against the cached hash below.
    df_hash = compute_df_hash(df)

    # First visit: pre-select up to the first four columns.
    if "selected_viz_columns" not in st.session_state:
        st.session_state.selected_viz_columns = list(df.columns[:min(4, len(df.columns))])

    # Drop any remembered selections that no longer exist in the dataframe.
    valid_columns = [name for name in st.session_state.selected_viz_columns if name in df.columns]

    def on_column_selection_change():
        # Persist the widget's value and keep the UI on the visualization tab.
        st.session_state.selected_viz_columns = st.session_state.viz_column_selector
        st.session_state.current_tab_index = 2

    selected_columns = st.multiselect(
        "Select columns to visualize",
        options=df.columns,
        default=valid_columns,
        key="viz_column_selector",
        on_change=on_column_selection_change
    )

    # Recompute cached analysis only when nothing is cached yet or the data
    # hash has changed (e.g. a new user-uploaded dataset).
    cache_is_stale = (
        "column_types" not in st.session_state
        or "df_hash" not in st.session_state
        or st.session_state.get("df_hash") != df_hash
    )

    if cache_is_stale:
        with st.spinner("🔄 Analyzing data structure..."):
            # Compute and cache the per-column type map and correlation matrix.
            st.session_state.column_types = get_column_types(df)
            st.session_state.corr_matrix = get_corr_matrix(df)
            st.session_state.df_hash = df_hash
            # Stay on the visualization tab across the rerun.
            st.session_state.current_tab_index = 2

            # New data invalidates any previously computed test results.
            if "test_results_calculated" in st.session_state:
                st.session_state.test_results_calculated = False
                for stale in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
                    if stale in st.session_state:
                        del st.session_state[stale]

    # Read back the (possibly freshly computed) cached values.
    column_types = st.session_state.column_types
    corr_matrix = st.session_state.corr_matrix

    if not selected_columns:
        st.info("👆 Please select columns to visualize")
        return

    # Lay charts out two per row inside a single container.
    with st.container():
        for row_start in range(0, len(selected_columns), 2):
            for offset, slot in enumerate(st.columns(2)):
                position = row_start + offset
                if position >= len(selected_columns):
                    continue
                column = selected_columns[position]
                with slot:
                    # Stable per-column widget key for Streamlit.
                    chart_key = f"plot_{column.replace(' ', '_')}"

                    if column not in column_types:
                        st.warning(f"⚠️ Column '{column}' not found in the dataset or its type couldn't be determined.")
                        continue

                    fig = create_chart(df, column, column_types[column])
                    if fig:
                        st.plotly_chart(fig, use_container_width=True, key=chart_key)
                        with st.expander(f"📊 Summary Statistics - {column}", expanded=False):
                            if "NUMERIC" in column_types[column]:
                                st.dataframe(df[column].describe(), key=f"stats_{column.replace(' ', '_')}")
                            else:
                                st.dataframe(df[column].value_counts(), key=f"counts_{column.replace(' ', '_')}")

    if corr_matrix is not None:
        st.subheader("🔗 Correlation Analysis")
        fig = px.imshow(corr_matrix, title="Correlation Matrix", color_continuous_scale="RdBu")
        st.plotly_chart(fig, use_container_width=True, key="corr_matrix_plot")
|
src/ui/welcome.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from src.preprocessing.clean_data import cached_clean_csv
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
# Cache static content to avoid recomputation
@lru_cache(maxsize=1)
def get_static_content():
    """Build and cache the welcome page's static HTML/markdown fragments.

    The content never changes at runtime, so it is built once per process
    via ``lru_cache`` and reused on every rerun.

    Returns:
        tuple: ``(welcome_header, features_header, feature_cards,
        getting_started, dataset_requirements, example_datasets)`` where
        ``feature_cards`` is a list of three HTML card snippets.
    """
    welcome_header = """
    <div class="welcome-header" style="text-align: left; margin-bottom: 2rem;">
        <h1>Experience AI like never before</h1>
        <p class="subtitle">
            Performance, Analysis, Insights Made Simple.
        </p>
    </div>
    """
    features_header = "## ✨ Key Features"
    feature_cards = [
        """
        <div class="feature-card">
            <h3>📊 Data Analysis</h3>
            <ul>
                <li>Automated data cleaning</li>
                <li>Interactive visualizations</li>
                <li>Statistical insights</li>
                <li>Correlation analysis</li>
            </ul>
        </div>
        """,
        """
        <div class="feature-card">
            <h3>🤖 Machine Learning</h3>
            <ul>
                <li>Multiple ML algorithms</li>
                <li>Automated model selection</li>
                <li>Hyperparameter tuning</li>
                <li>Performance metrics</li>
            </ul>
        </div>
        """,
        """
        <div class="feature-card">
            <h3>🔍 AI Insights</h3>
            <ul>
                <li>Data quality checks</li>
                <li>Feature importance</li>
                <li>Model explanations</li>
                <li>Smart recommendations</li>
            </ul>
        </div>
        """
    ]
    getting_started = """
    ## 🚀 Getting Started
    1. **Upload Your Dataset**: Use the sidebar to upload your CSV file
    2. **Explore Data**: View statistics and visualizations in the Overview tab
    3. **Train Models**: Select algorithms and tune parameters
    4. **Get Insights**: Receive AI-powered recommendations
    """
    dataset_requirements = """
    * File format: CSV
    * Maximum size: 200MB
    * Supported column types:
        * Numeric (int, float)
        * Categorical (string, boolean)
        * Temporal (date, datetime)
    * Clean data preferred, but not required
    """
    example_datasets = """
    Try these example datasets to explore the app:
    * [Iris Dataset](https://archive.ics.uci.edu/ml/datasets/iris)
    * [Boston Housing](https://www.kaggle.com/c/boston-housing)
    * [Wine Quality](https://archive.ics.uci.edu/ml/datasets/wine+quality)
    """
    return welcome_header, features_header, feature_cards, getting_started, dataset_requirements, example_datasets
|
| 77 |
+
|
| 78 |
+
def show_welcome_page():
    """Display the welcome page: feature overview, usage instructions, and
    the CSV upload flow with optional AI-powered cleaning.

    Side effects: on a successful upload, populates ``st.session_state``
    (``df``, ``insights``, ``data_cleaned``, ``dataset_loaded``,
    ``is_user_uploaded``, ``original_df_json``, ``skip_cleaning``) and
    invalidates cached visualization / test-result state.
    """
    # Load cached static content (built once per process).
    welcome_header, features_header, feature_cards, getting_started, dataset_requirements, example_datasets = get_static_content()

    # Render static sections.
    st.markdown(welcome_header, unsafe_allow_html=True)
    st.markdown(features_header, unsafe_allow_html=True)

    # Feature cards, three across.
    col1, col2, col3 = st.columns(3, gap="medium")
    for card_col, card_html in zip((col1, col2, col3), feature_cards):
        with card_col:
            st.markdown(card_html, unsafe_allow_html=True)

    st.markdown("<br>", unsafe_allow_html=True)  # Spacing

    # Getting-started guide and collapsible reference sections.
    st.markdown(getting_started, unsafe_allow_html=True)
    with st.expander("📋 Dataset Requirements"):
        st.markdown(dataset_requirements)

    with st.expander("🎯 Example Datasets"):
        st.markdown(example_datasets)

    # File uploader section.
    st.markdown("### 📤 Upload Your Dataset (Currently Using Default Dataset)")

    # Let the user skip the (potentially slow) AI cleaning pass.
    skip_cleaning = st.checkbox("My dataset is already cleaned (skip cleaning)")

    uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])

    if uploaded_file is None:
        return

    try:
        # Enforce the documented 200MB upload limit.
        if uploaded_file.size > 200 * 1024 * 1024:
            st.error("❌ File size exceeds 200MB limit. Please upload a smaller file.")
            return

        # Attempt to read the CSV.
        try:
            df = pd.read_csv(uploaded_file)
            if df.empty:
                st.error("❌ The uploaded file is empty. Please upload a file with data.")
                return

            st.success("✅ Dataset uploaded successfully!")
        except pd.errors.EmptyDataError:
            st.error("❌ The uploaded file is empty. Please upload a file with data.")
            return
        except pd.errors.ParserError:
            st.error("❌ Unable to parse the CSV file. Please ensure it's properly formatted.")
            return

        # Serialize once; this is the cache key payload for cached_clean_csv.
        df_json = df.to_json(orient='records')

        # Clean (or pass through) with a graceful fallback on failure.
        with st.spinner("🧠 AI is analyzing and cleaning the data..." if not skip_cleaning else "Processing dataset..."):
            try:
                cleaned_df, insights = cached_clean_csv(df_json, skip_cleaning)
            except Exception as cleaning_error:
                st.error(f"❌ Error during data cleaning: {str(cleaning_error)}")
                # Fall back to the raw dataframe rather than aborting the flow.
                st.warning("⚠️ Using original dataset without cleaning due to errors.")
                cleaned_df = df
                insights = "Cleaning failed, using original data."

        # Publish results to session state.
        st.session_state.df = cleaned_df
        st.session_state.insights = insights
        st.session_state.data_cleaned = True
        st.session_state.dataset_loaded = True

        # Flag that this dataset came from the user (not the bundled default).
        st.session_state.is_user_uploaded = True

        # Remember the original payload and preference so the same upload is
        # not redundantly re-cleaned on a rerun.
        st.session_state.original_df_json = df_json
        st.session_state.skip_cleaning = skip_cleaning

        # Invalidate cached analysis derived from any previous dataset.
        for stale_key in ("column_types", "corr_matrix", "df_hash"):
            st.session_state.pop(stale_key, None)
        if "test_results_calculated" in st.session_state:
            st.session_state.test_results_calculated = False

        if skip_cleaning:
            st.success("✅ Using uploaded dataset as-is (skipped cleaning).")
        else:
            st.success("✅ Data cleaned successfully!")

    except Exception as e:
        st.error(f"❌ Error processing dataset: {str(e)}")
        st.info("ℹ️ Please check that your file is a valid CSV and try again.")
|
src/utils/logging.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
# Create logs directory if it doesn't exist
|
| 6 |
+
logs_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
|
| 7 |
+
if not os.path.exists(logs_dir):
|
| 8 |
+
os.makedirs(logs_dir)
|
| 9 |
+
|
| 10 |
+
# Configure the logger
def setup_logger():
    """Create and configure the shared frontend logger.

    Logs everything (DEBUG and up) to a per-day file under ``logs_dir`` and
    mirrors ERROR and up to the console. Safe to call more than once:
    ``logging.getLogger`` returns a process-wide singleton, so handlers are
    attached only on the first call to avoid duplicated log lines.

    Returns:
        logging.Logger: the configured ``frontend_logger`` instance.
    """
    # Create (or fetch the existing) logger instance
    logger = logging.getLogger('frontend_logger')
    logger.setLevel(logging.DEBUG)

    # Guard against duplicate handlers on repeated calls/imports.
    if logger.handlers:
        return logger

    # File handler: one file per day, captures all levels.
    log_file = os.path.join(logs_dir, f'frontend_errors_{datetime.now().strftime("%Y%m%d")}.log')
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.DEBUG)

    # Console handler: errors only, to keep console noise down.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.ERROR)

    # Shared formatter for both destinations.
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Add handlers to logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
|
| 38 |
+
|
| 39 |
+
# Initialize the shared module-level logger once at import time; the helper
# functions below log through this instance.
frontend_logger = setup_logger()
|
| 41 |
+
|
| 42 |
+
def log_frontend_error(error_type: str, error_message: str, additional_info: dict = None):
    """
    Log frontend errors with detailed information.

    Args:
        error_type (str): Type of error (e.g., 'Arrow Conversion', 'Model Training', etc.)
        error_message (str): The error message
        additional_info (dict, optional): Additional context about the error
    """
    # Assemble the multi-line message; context is appended only when provided.
    lines = [f"Type: {error_type}", f"Message: {error_message}"]
    if additional_info:
        lines.append(f"Additional Info: {additional_info}")

    frontend_logger.error("\n".join(lines))
|
| 56 |
+
|
| 57 |
+
def log_frontend_warning(warning_type: str, warning_message: str, additional_info: dict = None):
    """
    Log frontend warnings with detailed information.

    Args:
        warning_type (str): Type of warning
        warning_message (str): The warning message
        additional_info (dict, optional): Additional context about the warning
    """
    # Assemble the multi-line message; context is appended only when provided.
    lines = [f"Type: {warning_type}", f"Message: {warning_message}"]
    if additional_info:
        lines.append(f"Additional Info: {additional_info}")

    frontend_logger.warning("\n".join(lines))