akash committed on
Commit
890025a
·
1 Parent(s): 65eae8a
app.py ADDED
@@ -0,0 +1,168 @@
+ import functools
+ import os
+ import sys
+ import time
+
+ import pandas as pd
+ import streamlit as st
+
+ # Streamlit page setup
+ st.set_page_config(
+     page_title="AutoML",
+     page_icon="🛸",
+     layout="wide",
+     initial_sidebar_state="expanded",
+     menu_items={"Get Help": None, "Report a bug": None, "About": None},
+ )
+
+ # Add the project root and src to the Python path
+ sys.path.extend([
+     os.path.dirname(os.path.abspath(__file__)),  # project root
+     os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"),
+ ])
+
+ # Import the loading screen and CSS loader before any other components
+ from src.ui.loading import show_loading_state
+ from src.ui.css import load_css
+
+ # Load CSS immediately after the imports
+ load_css()
+
+ # Cached resource loading with a TTL so components refresh periodically
+ @st.cache_resource(ttl=3600)  # cache for 1 hour
+ def load_components():
+     """Cache component imports to avoid reloading them on every rerun."""
+     from src import (
+         show_footer,
+         visualize_data,
+         show_welcome_page,
+         show_overview_page,
+         clean_csv,
+         model_training_tab,
+         display_ai_insights,
+         display_model_evaluation,
+     )
+     return (show_footer, visualize_data,
+             show_welcome_page, show_overview_page, clean_csv,
+             model_training_tab, display_ai_insights, display_model_evaluation)
+
+ # Cached header rendering
+ @st.cache_data(ttl=86400)  # cache for 24 hours
+ def render_header():
+     """Cache the static header HTML."""
+     return """
+     <div class='app-header' style='padding: 1rem 0; margin-bottom: 2rem; text-align: center;'>
+         <h1 class='app-title' style='margin: 0;'>AutoML</h1>
+         <p class='app-tagline' style='margin-top: 0;'>Automated Machine Learning Made Simple.</p>
+     </div>
+     """
+
+ # Cached data loading
+ @st.cache_data(ttl=3600)  # cache for 1 hour
+ def load_default_data():
+     """Load and cache the default dataset."""
+     try:
+         return pd.read_csv("laptop_data.csv")
+     except Exception as e:
+         st.error(f"❌ Error loading default dataset: {e}")
+         return None
+
+ # Performance monitoring decorator
+ def measure_time(func):
+     """Decorator that measures a function's execution time."""
+     @functools.wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.time()
+         result = func(*args, **kwargs)
+         execution_time = time.time() - start_time
+         if execution_time > 1.0:  # only log slow operations
+             print(f"⏱️ {func.__name__} took {execution_time:.2f} seconds to execute")
+         return result
+     return wrapper
+
+ @measure_time
+ def main():
+     """Main entry point for the Streamlit AutoML app."""
+     # Show the loading screen before anything else on the first run
+     if "initialized" not in st.session_state:
+         # Show the loading animation in full-screen mode
+         with st.container():
+             show_loading_state()
+
+         # Force the loading screen to render full-bleed by collapsing the section padding
+         st.empty().markdown("<style>#root > div:nth-child(1) > div > div > div > div > section > div {padding: 0rem;}</style>", unsafe_allow_html=True)
+
+         # Now load the components in the background
+         components = load_components()
+         (show_footer, visualize_data,
+          show_welcome_page, show_overview_page, clean_csv,
+          model_training_tab, display_ai_insights, display_model_evaluation) = components
+
+         try:
+             # Load and clean the data, with caching
+             default_df = load_default_data()
+             if default_df is not None:
+                 cleaned_df, insights = clean_csv(default_df)
+
+                 # Store everything in session state
+                 st.session_state.update({
+                     "df": cleaned_df,
+                     "insights": insights,
+                     "components": components,
+                     "initialized": True,
+                     "current_tab_index": 0,  # consistent naming for tab tracking
+                 })
+
+                 # Rerun to hide the loading screen
+                 st.rerun()
+             else:
+                 st.error("❌ Failed to load default dataset")
+                 return
+
+         except Exception as e:
+             st.error(f"❌ Error during initialization: {e}")
+             return
+
+     # After initialization, show the main interface
+     if "initialized" in st.session_state:
+         components = st.session_state.components
+         (show_footer, visualize_data,
+          show_welcome_page, show_overview_page, clean_csv,
+          model_training_tab, display_ai_insights, display_model_evaluation) = components
+
+         # Render the main interface
+         st.markdown(render_header(), unsafe_allow_html=True)
+
+         # Keep the tab names in a constant so they are not rebuilt on every rerun
+         TAB_NAMES = ["👋 Welcome", "📊 Overview", "📈 Visualization",
+                      "🤖 Model Training", "💡 Insights", "📊 Test Results"]
+
+         # Initialize the current tab index if it is not present
+         if "current_tab_index" not in st.session_state:
+             st.session_state.current_tab_index = 0
+
+         # st.tabs returns a list of tab containers, one per name
+         tabs = st.tabs(TAB_NAMES)
+
+         # Display content in all tabs
+         with tabs[0]:
+             show_welcome_page()
+
+         with tabs[1]:
+             show_overview_page()
+
+         with tabs[2]:
+             visualize_data(st.session_state.df)
+
+         with tabs[3]:
+             model_training_tab(st.session_state.df)
+
+         with tabs[4]:
+             display_ai_insights()
+
+         with tabs[5]:
+             display_model_evaluation()
+
+         show_footer()
+
+ if __name__ == "__main__":
+     main()
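A side note on the two cache decorators used above: `st.cache_resource` keeps one shared, non-serialized object per process (suited to the imported component functions), while `st.cache_data` memoizes serializable return values keyed on the arguments and hands back a fresh copy each call. A minimal sketch of the distinction, independent of this app:

import streamlit as st
import pandas as pd

@st.cache_resource(ttl=3600)
def get_shared_object():
    # one instance shared across reruns and sessions; never copied or pickled
    return {"loaded_at": pd.Timestamp.now()}

@st.cache_data(ttl=3600)
def summarize(csv_path: str) -> pd.DataFrame:
    # result is keyed on csv_path and returned as a fresh copy on every call
    return pd.read_csv(csv_path).describe()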
example.env ADDED
@@ -0,0 +1,13 @@
+ # AutoML Environment Variables
+
+ # API keys for LLM services
+ GROQ_API_KEY=your_groq_api_key_here
+ GEMINI_API_KEY=your_gemini_api_key_here
+
+ # LangSmith tracking (optional)
+ LANGCHAIN_TRACING_V2=true
+ LANGCHAIN_API_KEY=your_langchain_api_key_here
+ LANGCHAIN_PROJECT=automl-project
+
+ # Optional: logging configuration
+ LOG_LEVEL=INFO
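For these values to reach `load_dotenv()` in src/preprocessing/clean_data.py, the template presumably gets copied to `.env` with real keys (python-dotenv reads a file named `.env` from the working directory by default). A minimal sanity check, assuming that copy has been made:

import os
from dotenv import load_dotenv

load_dotenv()  # python-dotenv picks up .env by default
for key in ("GROQ_API_KEY", "GEMINI_API_KEY"):
    print(key, "set" if os.getenv(key) else "MISSING")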
header.svg ADDED
laptop_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ streamlit>=1.29.0
+ pandas>=2.0.0
+ numpy>=1.24.0
+ scikit-learn>=1.2.0
+ matplotlib>=3.7.0
+ plotly>=5.14.0
+ seaborn>=0.12.0
+ langchain>=0.0.267
+ langchain-groq>=0.0.1
+ langchain-google-genai>=0.0.3
+ python-dotenv>=1.0.0
+ scipy>=1.10.0
+ joblib>=1.2.0
+ pydantic>=2.0.0
+ requests>=2.28.0
+ pillow>=9.0.0
+ altair>=4.2.0
+ beautifulsoup4>=4.11.0
+ # imported by src/training/train.py (XGBRegressor / XGBClassifier) but missing above
+ xgboost
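The usual setup, assuming a fresh virtual environment at the repo root, is to install these pins and launch the Streamlit entry point:

pip install -r requirements.txt
streamlit run app.py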
src/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+ from .ui import *
+ from .training import *
+ from .preprocessing import *
+ from .utils import *
+
src/preprocessing/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .clean_data import clean_csv
+ from .clean_df_fallback import clean_dataframe_fallback
+
+ __all__ = ["clean_csv", "clean_dataframe_fallback"]
src/preprocessing/clean_data.py ADDED
@@ -0,0 +1,268 @@
+ from dotenv import load_dotenv
+ from langchain_groq import ChatGroq
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain.prompts import PromptTemplate
+ import pandas as pd
+ import re
+ import os
+ import streamlit as st
+ from .clean_df_fallback import clean_dataframe_fallback
+
+
+ # Load environment variables
+ load_dotenv()
+
+ groq_api_key = os.getenv("GROQ_API_KEY")
+ gemini_api_key = os.getenv("GEMINI_API_KEY")
+
+ if not gemini_api_key:
+     raise ValueError("GEMINI_API_KEY not found in environment variables")
+ if not groq_api_key:
+     raise ValueError("GROQ_API_KEY not found in environment variables")
+
+
+ # Initialize the LLM, falling back from Gemini to Groq
+ try:
+     llm = ChatGoogleGenerativeAI(
+         model="gemini-2.0-flash-lite-preview-02-05",
+         google_api_key=gemini_api_key
+     )
+     print("Primary Gemini LLM loaded successfully.")
+
+ except Exception as e:
+     print(f"Error initializing primary Gemini LLM: {e}")
+
+     # Fall back to a Groq-hosted model
+     try:
+         llm = ChatGroq(
+             model="gemma2-9b-it",  # replace with your desired Groq model identifier
+             groq_api_key=groq_api_key
+         )
+         print("Fallback Groq LLM loaded successfully.")
+
+     except Exception as e2:
+         print(f"Error initializing fallback Groq LLM: {e2}")
+         llm = None
+
+
+ # Cache clean_csv to prevent redundant cleaning
+ @st.cache_data(ttl=3600, show_spinner=False)
+ def cached_clean_csv(df_json, skip_cleaning=False):
+     """Cached wrapper around clean_csv that prevents redundant cleaning.
+
+     Args:
+         df_json: JSON string representation of the dataframe (used for hashing)
+         skip_cleaning: Whether to skip cleaning
+
+     Returns:
+         Tuple of (cleaned_df, insights)
+     """
+     # Convert the JSON back to a dataframe
+     df = pd.read_json(df_json, orient='records')
+
+     # If skip_cleaning is True, return the dataframe as-is
+     if skip_cleaning:
+         return df, "No cleaning performed (user skipped)."
+
+     # Reset any test results when cleaning a new dataset
+     if "test_results_calculated" in st.session_state:
+         st.session_state.test_results_calculated = False
+         # Clear any previous test metrics to avoid using stale data
+         for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
+             if key in st.session_state:
+                 del st.session_state[key]
+
+     # Call the actual cleaning function
+     return clean_csv(df)
+
+
+ def clean_csv(df):
+     """Perform the actual cleaning, generating the cleaning code with an LLM."""
+     # ---------------------------
+     # Early fallback if LLM initialization failed
+     # ---------------------------
+     if llm is None:
+         print("LLM initialization failed; using hardcoded cleaning function.")
+         fallback_df = clean_dataframe_fallback(df)
+         return fallback_df, "LLM initialization failed; the hardcoded cleaning function was used, so no insights were generated."
+
+     # ---------------------------
+     # LLM-based cleaning function generation
+     # ---------------------------
+
+     # Escape curly braces in the JSON sample and column names
+     sample_data = df.head(3).to_json(orient='records')
+     escaped_sample_data = sample_data.replace("{", "{{").replace("}", "}}")
+
+     escaped_columns = [
+         col.replace("{", "{{").replace("}", "}}") for col in df.columns
+     ]
+     column_names_str = ", ".join(escaped_columns)
+
+     # Define the prompt for generating the cleaning function
+     initial_prompt = PromptTemplate.from_template(f'''
+     You are given the following sample data from a pandas DataFrame:
+     {escaped_sample_data}
+
+     The column names are: [{column_names_str}].
+
+     Generate a Python function named clean_dataframe(df) that:
+
+     1. Performs thorough data cleaning without performing feature engineering. Ensure all necessary cleaning steps are included.
+     2. Uses assignment operations (e.g., df = df.drop(...)) and avoids inplace=True for clarity.
+     3. First deeply analyzes each column's content (this is the most important step) to infer its predominant data type. For example, if rows contain "Rs.2100" remove "Rs.", and if they contain "(89%)" remove "%". A column with only text and no numbers is a text column, one with both numbers and text is a mixed column, and one with only numbers is a numeric column.
+     4. For columns that are intended to be numeric but contain extra characters (such as '%' in percentage values, currency symbols like 'Rs.' or '$', and commas), removes all non-digit characters except the decimal point and converts them to a numeric type.
+     5. For columns that are clearly text or categorical, preserves the content without removing digits or altering the textual information.
+     6. Handles missing values appropriately: fills numeric columns with the median (or 0 if the median is not available) and non-numeric columns with 'Unknown'.
+     7. For columns where more than 50% of values are strings and less than 10% are numeric, performs conservative string cleaning by removing unwanted special symbols while preserving meaningful digits.
+     8. For columns whose names contain 'name' (case-insensitive), converts to string type and removes extraneous numeric characters only if they are not part of the essential text.
+     9. Preserves other categorical or text columns (such as Gender, City, State, Country, etc.) unless explicitly specified for removal.
+     10. Handles edge cases such as completely empty columns appropriately.
+
+     Return only the Python code for the function, with no explanations or extra formatting.
+     ''')
+
+     # Define the refinement prompt
+     refine_prompt = PromptTemplate.from_template(
+         "The following Python code for cleaning a DataFrame caused an error: {error}\n"
+         "Original code:\n{code}\n"
+         "Please correct the code to fix the error and ensure it returns a cleaned DataFrame. "
+         "Return only the corrected Python code for the function, no explanations or formatting."
+     )
+
+     # Create the chains using the modern LangChain composition style
+     initial_chain = initial_prompt | llm
+     refine_chain = refine_prompt | llm
+
+     def extract_code(response):
+         """Pull the Python code out of a raw string or an LLM response object."""
+         if isinstance(response, str):
+             # Handle Markdown or plain text
+             if "```python" in response:
+                 match = re.search(r'```python\n(.*?)\n```', response, re.DOTALL)
+                 return match.group(1).strip() if match else response
+             elif "```" in response:
+                 match = re.search(r'```\n(.*?)\n```', response, re.DOTALL)
+                 return match.group(1).strip() if match else response
+             return response.strip()
+
+         # Handle LLM response objects
+         content = getattr(response, 'content', str(response))
+
+         if "```python" in content:
+             match = re.search(r'```python\n(.*?)\n```', content, re.DOTALL)
+             return match.group(1).strip() if match else content
+         elif "```" in content:
+             match = re.search(r'```\n(.*?)\n```', content, re.DOTALL)
+             return match.group(1).strip() if match else content
+
+         return content.strip()
+
+     try:
+         # Generate the initial code and extract it from the response
+         cleaning_function_code = extract_code(initial_chain.invoke({}))
+         print("Initial generated cleaning function code (not executed yet):\n", cleaning_function_code)
+
+         # Iterative refinement loop with at most 5 attempts
+         max_attempts = 5
+
+         for attempt in range(max_attempts):
+             print(f"Attempt {attempt + 1} code:\n{cleaning_function_code}")
+             try:
+                 # Execute the generated code in the global namespace
+                 exec(cleaning_function_code, globals())
+
+                 if 'clean_dataframe' not in globals():
+                     raise NameError("Cleaning function not defined in generated code")
+
+                 # Call the generated function and assign the result back to df
+                 df = clean_dataframe(df)
+
+                 print(f"Cleaning successful on attempt {attempt + 1}")
+                 break
+
+             # If the cleaning fails, refine the code with the error message
+             except Exception as e:
+                 error_message = str(e)
+                 print(f"Error on attempt {attempt + 1}: {error_message}")
+
+                 if attempt < max_attempts - 1:
+                     # Refine the code using the error message while attempts remain
+                     refined_response = refine_chain.invoke({"error": error_message, "code": cleaning_function_code})
+                     cleaning_function_code = extract_code(refined_response)
+                     print("Refined cleaning function code:\n", cleaning_function_code)
+                 else:
+                     print("Failed to clean DataFrame after 5 attempts")
+                     # After all attempts failed, use the hardcoded logic
+                     df = clean_dataframe_fallback(df)
+
+     except Exception as e:
+         print(f"⚡ No successful cleaning was done; enforcing the fallback ({e})")
+         df = clean_dataframe_fallback(df)
+
+     cleaned_df = df
+
+     insights_prompt = f"""
+     Analyze this cleaned dataset:
+     - Columns: {cleaned_df.columns.tolist()}
+     - Sample data: {cleaned_df.head(3).to_dict()}
+     - Numeric stats: {cleaned_df.describe().to_dict()}
+     Provide key data quality insights and recommendations.
+     """
+
+     try:
+         insights_response = llm.invoke(insights_prompt)
+         analysis_insights = insights_response.content
+     except Exception as e:
+         analysis_insights = f"Insight generation failed: {e}"
+
+     # Return the cleaned DataFrame and the generated insights
+     return cleaned_df, analysis_insights
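Note that `cached_clean_csv` hashes a JSON string rather than the DataFrame itself, so a call site has to serialize first. A hypothetical call (the path is illustrative; outside a Streamlit runtime the cache decorator simply logs a warning and runs uncached):

import pandas as pd
from src.preprocessing.clean_data import cached_clean_csv

df = pd.read_csv("laptop_data.csv")
cleaned_df, insights = cached_clean_csv(df.to_json(orient="records"))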
src/preprocessing/clean_df_fallback.py ADDED
@@ -0,0 +1,143 @@
+ import re
+ import pandas as pd
+ import numpy as np
+ import streamlit as st
+
+
+ # Fallback cleaning function used when LLM-generated cleaning fails
+ @st.cache_data
+ def clean_dataframe_fallback(df):
+     """Hardcoded, generic data-cleaning pipeline with categorical preservation."""
+     df_cleaned = df.copy()
+
+     # Remove parenthesised annotations such as "(89%)"
+     df_cleaned = df_cleaned.applymap(
+         lambda x: re.sub(r"\(.*?\)", "", str(x)) if isinstance(x, str) else x)
+
+     # Remove 'ref.' references
+     df_cleaned = df_cleaned.applymap(
+         lambda x: re.sub(r"ref\.", "", str(x), flags=re.IGNORECASE) if isinstance(x, str) else x)
+
+     # Remove any other special characters except letters, digits, spaces, and dots
+     df_cleaned = df_cleaned.applymap(
+         lambda x: re.sub(r"[^\w\s\d\.]", "", str(x)).strip() if isinstance(x, str) else x
+     )
+
+     # Step 0 - Clean the column names first
+     df_cleaned.columns = [col.strip().lower().replace(' ', '_') for col in df_cleaned.columns]
+
+     # Define measurement units to remove
+     measurement_units = {
+         'weight': r'\s*(kg|kilograms|lbs|pounds)$',
+         'height': r'\s*(cm|centimeters|inches|feet|ft)$'
+     }
+
+     # Step 1 - Remove redundant columns
+     # Preservation patterns for categorical columns
+     preserve_pattern = re.compile(r'(name|brand|model|type|category|region|text|desc|color|size)', re.IGNORECASE)
+     preserved_cols = [col for col in df_cleaned.columns if preserve_pattern.search(col)]
+
+     # ID pattern detection
+     id_pattern = re.compile(r'(_id|id_|num|no|number|identifier|code|idx|row)', re.IGNORECASE)
+     id_cols = [col for col in df_cleaned.columns if id_pattern.search(col) and col not in preserved_cols]
+
+     # Columns where every value is unique
+     unique_cols = [col for col in df_cleaned.columns
+                    if df_cleaned[col].nunique() == len(df_cleaned)
+                    and col not in preserved_cols]
+
+     redundant_cols = list(set(id_cols + unique_cols))
+     df_cleaned = df_cleaned.drop(columns=redundant_cols)
+     print(f"Removed {len(redundant_cols)} redundant columns: {redundant_cols}")
+
+     # Step 2 - Numeric detection with categorical protection
+     for col in df_cleaned.columns:
+         if col in preserved_cols:
+             print(f"Preserving categorical column: {col}")
+             continue  # skip preserved columns
+
+         if any(unit in col for unit in measurement_units.keys()):
+             pattern = measurement_units.get(col.split('_')[0], r'')
+             df_cleaned[col] = df_cleaned[col].astype(str).str.replace(pattern, '', regex=True).str.strip()
+
+         if pd.api.types.is_numeric_dtype(df_cleaned[col]):
+             continue
+
+         # Strict numeric pattern detection on a sample of the column
+         non_null_count = df_cleaned[col].dropna().shape[0]
+         sample_size = min(100, non_null_count)
+         sample = df_cleaned[col].dropna().sample(sample_size, random_state=42)
+         numeric_pattern = r'^[-+]?\d*\.?\d+$'  # full-string match
+         num_matches = sample.astype(str).str.fullmatch(numeric_pattern).mean()
+
+         if num_matches > 0.8:  # high threshold
+             # Conservative cleaning
+             cleaned = df_cleaned[col].replace(r'[^\d\.\-]', '', regex=True)
+             converted = pd.to_numeric(cleaned, errors='coerce')
+             success_rate = converted.notna().mean()
+
+             if success_rate > 0.9:  # strict success requirement
+                 df_cleaned[col] = converted
+                 print(f"Converted {col} to numeric (success: {success_rate:.1%})")
+
+     # Step 3 - Date detection
+     date_cols = []
+     for col in df_cleaned.select_dtypes(exclude=np.number).columns:
+         if col in preserved_cols:
+             continue
+         try:
+             df_cleaned[col] = pd.to_datetime(df_cleaned[col], errors='raise')
+             date_cols.append(col)
+             print(f"Detected datetime: {col}")
+         except Exception:
+             pass
+
+     # Manual currency handling for price-like columns
+     currency_cols = [col for col in df_cleaned.columns if any(keyword in col.lower() for keyword in ["price", "gross", "budget"])]
+     for col in currency_cols:
+         df_cleaned[col] = df_cleaned[col].astype(str).str.replace(r'[^\d\.]', '', regex=True)  # keep only digits and dots
+         df_cleaned[col] = pd.to_numeric(df_cleaned[col], errors='coerce')
+
+     # Step 4 - Missing value handling
+     numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
+     categorical_cols = df_cleaned.select_dtypes(exclude=np.number).columns
+
+     # Numeric imputation with the median, plus a missingness indicator column
+     for col in numeric_cols:
+         if df_cleaned[col].isna().any():
+             df_cleaned[f'{col}_missing'] = df_cleaned[col].isna().astype(int)
+             df_cleaned[col] = df_cleaned[col].fillna(df_cleaned[col].median())
+
+     # Categorical imputation with the mode
+     for col in categorical_cols:
+         if df_cleaned[col].isna().any():
+             mode_val = df_cleaned[col].mode()[0] if not df_cleaned[col].mode().empty else 'Unknown'
+             df_cleaned[col] = df_cleaned[col].fillna(mode_val)
+
+     # Step 5 - Text normalization for non-preserved columns
+     text_cols = [col for col in categorical_cols if col not in preserved_cols]
+     for col in text_cols:
+         df_cleaned[col] = df_cleaned[col].astype(str).apply(lambda x: re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', x)).strip().lower())
+
+     # Step 6 - Outlier clipping to the 5th/95th percentiles (categoricals preserved)
+     numeric_cols = df_cleaned.select_dtypes(include=np.number).columns
+     for col in numeric_cols:
+         if df_cleaned[col].nunique() > 10:
+             lower = df_cleaned[col].quantile(0.05)
+             upper = df_cleaned[col].quantile(0.95)
+             df_cleaned[col] = np.clip(df_cleaned[col], lower, upper)
+
+     # Step 7 - Final validation
+     df_cleaned = df_cleaned.drop_duplicates().reset_index(drop=True)
+
+     return df_cleaned
+
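To illustrate the Step 2 heuristic above: a column is only converted when more than 80% of the sampled values fully match the numeric pattern, which keeps mixed text/number columns textual. A small worked example:

import pandas as pd

s = pd.Series(["1,299", "Rs.2100", "899", None])
ratio = s.dropna().astype(str).str.fullmatch(r"^[-+]?\d*\.?\d+$").mean()
print(ratio)  # ≈ 0.33: only "899" matches, so this column would stay as text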
src/training/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .train import train_model
+ from .hyperparametrs import get_hyperparams_ui
+ from .model_training import model_training_tab
+ from .test_result import display_model_evaluation
+
+ __all__ = ["train_model", "get_hyperparams_ui", "model_training_tab", "display_model_evaluation"]
+
src/training/hyperparametrs.py ADDED
@@ -0,0 +1,107 @@
+ import streamlit as st
+
+
+ # Define the hyperparameter options per model
+ def get_hyperparams_ui(model_name):
+     """Generate UI components for model-specific hyperparameters."""
+     hyperparams = {}
+
+     if model_name in ["Random Forest Regressor", "Random Forest"]:
+         hyperparams["n_estimators"] = st.number_input("Number of Trees (n_estimators)", min_value=10, max_value=500, value=100)
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
+         hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)
+
+     elif model_name in ["XGBoost Regressor", "XGBoost"]:
+         hyperparams["n_estimators"] = st.number_input("Number of Boosting Rounds (n_estimators)", min_value=10, max_value=500, value=100)
+         hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=6)
+
+     elif model_name == "Linear Regression":
+         st.info("No hyperparameters required for Linear Regression.")
+
+     # Additional regression models:
+     elif model_name == "Polynomial Regression":
+         hyperparams["degree"] = st.number_input("Degree of Polynomial Features", min_value=2, max_value=10, value=2)
+         # Hyperparameters for the underlying LinearRegression could be added here if needed
+
+     elif model_name == "Ridge Regression":
+         hyperparams["alpha"] = st.slider("Regularization Strength (alpha)", 0.01, 10.0, 1.0)
+         hyperparams["solver"] = st.selectbox("Solver", ["auto", "svd", "cholesky", "lsqr", "sparse_cg", "sag", "saga", "lbfgs"])
+
+     elif model_name == "Lasso Regression":
+         hyperparams["alpha"] = st.slider("Regularization Strength (alpha)", 0.01, 10.0, 1.0)
+         hyperparams["max_iter"] = st.number_input("Max Iterations", min_value=100, max_value=1000, value=1000)
+
+     elif model_name == "Logistic Regression":
+         hyperparams["C"] = st.slider("Regularization Strength (C)", 0.01, 10.0, 1.0)
+         hyperparams["max_iter"] = st.number_input("Max Iterations", min_value=100, max_value=1000, value=200)
+
+     elif model_name == "Support Vector Regressor":
+         hyperparams["C"] = st.slider("Regularization parameter (C)", 0.1, 100.0, 1.0)
+         hyperparams["epsilon"] = st.slider("Epsilon", 0.0, 1.0, 0.1)
+         hyperparams["kernel"] = st.selectbox("Kernel", ["linear", "rbf", "poly", "sigmoid"])
+
+     elif model_name == "Decision Tree Regressor":
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
+         hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)
+
+     elif model_name == "K-Nearest Neighbors Regressor":
+         hyperparams["n_neighbors"] = st.number_input("Number of Neighbors", min_value=1, max_value=100, value=5)
+         hyperparams["weights"] = st.selectbox("Weight Function", ["uniform", "distance"])
+
+     elif model_name == "ElasticNet":
+         hyperparams["alpha"] = st.slider("Alpha", 0.01, 10.0, 1.0)
+         hyperparams["l1_ratio"] = st.slider("L1 Ratio", 0.0, 1.0, 0.5)
+
+     elif model_name == "Gradient Boosting Regressor":
+         hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=100)
+         hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=20, value=3)
+
+     elif model_name == "AdaBoost Regressor":
+         hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=50)
+         hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
+
+     elif model_name == "Bayesian Ridge":
+         hyperparams["alpha_1"] = st.slider("Alpha 1", 1e-6, 1e-1, 1e-4, format="%.6f")
+         hyperparams["alpha_2"] = st.slider("Alpha 2", 1e-6, 1e-1, 1e-4, format="%.6f")
+         hyperparams["lambda_1"] = st.slider("Lambda 1", 1e-6, 1e-1, 1e-4, format="%.6f")
+         hyperparams["lambda_2"] = st.slider("Lambda 2", 1e-6, 1e-1, 1e-4, format="%.6f")
+
+     # --- Additional classification models ---
+     elif model_name == "Support Vector Classifier":
+         hyperparams["C"] = st.slider("Regularization parameter (C)", 0.1, 100.0, 1.0)
+         hyperparams["kernel"] = st.selectbox("Kernel", ["linear", "rbf", "poly", "sigmoid"])
+
+     elif model_name == "Decision Tree Classifier":
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=50, value=10)
+         hyperparams["min_samples_split"] = st.number_input("Min Samples Split", min_value=2, max_value=10, value=2)
+
+     elif model_name == "K-Nearest Neighbors Classifier":
+         hyperparams["n_neighbors"] = st.number_input("Number of Neighbors", min_value=1, max_value=100, value=5)
+         hyperparams["weights"] = st.selectbox("Weight Function", ["uniform", "distance"])
+
+     elif model_name == "Gradient Boosting Classifier":
+         hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=100)
+         hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
+         hyperparams["max_depth"] = st.number_input("Max Depth", min_value=1, max_value=20, value=3)
+
+     elif model_name == "AdaBoost Classifier":
+         hyperparams["n_estimators"] = st.number_input("Number of Estimators", min_value=10, max_value=500, value=50)
+         hyperparams["learning_rate"] = st.slider("Learning Rate", 0.01, 1.0, 0.1)
+
+     elif model_name == "Gaussian Naive Bayes":
+         hyperparams["var_smoothing"] = st.slider("Var Smoothing", 1e-12, 1e-8, 1e-9, format="%.12f")
+
+     elif model_name == "Quadratic Discriminant Analysis":
+         hyperparams["reg_param"] = st.slider("Regularization Parameter", 0.0, 1.0, 0.0)
+
+     elif model_name == "Linear Discriminant Analysis":
+         hyperparams["solver"] = st.selectbox("Solver", ["svd", "lsqr", "eigen"])
+
+     return hyperparams
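The dict built here is ultimately splatted into the estimator constructor in src/training/train.py (`models[task_type][model_name](**hyperparams)`). The equivalent in plain Python, outside the UI:

from sklearn.ensemble import RandomForestRegressor

hyperparams = {"n_estimators": 100, "max_depth": 10, "min_samples_split": 2}
model = RandomForestRegressor(**hyperparams)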
src/training/model_training.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import pickle
+ import tempfile
+
+ import streamlit as st
+
+ from .hyperparametrs import get_hyperparams_ui
+ from .train import train_model
+
+
+ # Model Training tab
+ def model_training_tab(df):
+     # Ensure session state is set up for model training
+     if "target_column" not in st.session_state:
+         st.session_state.target_column = df.columns[0] if not df.empty else None
+
+     if "selected_model" not in st.session_state:
+         st.session_state.selected_model = None
+
+     st.subheader("📌 Model Training")
+
+     # Use session state to maintain the selection across reruns
+     target_column = st.selectbox(
+         "🎯 Select Target Column (Y)",
+         df.columns,
+         index=list(df.columns).index(st.session_state.target_column) if st.session_state.target_column in df.columns else 0,
+         key="target_column_select"
+     )
+
+     # Update session state after selection
+     st.session_state.target_column = target_column
+
+     # Infer the task type automatically
+     task_type = "classification" if df[target_column].dtype == "object" or df[target_column].nunique() <= 10 else "regression"
+     st.write(f"🔍 Detected Task Type: **{task_type.capitalize()}**")
+
+     model_options = {
+         "classification": ["Random Forest", "Logistic Regression", "XGBoost", "Support Vector Classifier", "Decision Tree Classifier", "K-Nearest Neighbors Classifier", "Gradient Boosting Classifier", "AdaBoost Classifier", "Gaussian Naive Bayes", "Quadratic Discriminant Analysis", "Linear Discriminant Analysis"],
+         "regression": ["Linear Regression", "Random Forest Regressor", "XGBoost Regressor", "Support Vector Regressor", "Decision Tree Regressor", "K-Nearest Neighbors Regressor", "ElasticNet", "Gradient Boosting Regressor", "AdaBoost Regressor", "Bayesian Ridge", "Ridge Regression", "Lasso Regression"],
+     }
+
+     # Initialize the selected model if unset or if the task type changed
+     if st.session_state.selected_model not in model_options[task_type]:
+         st.session_state.selected_model = model_options[task_type][0]
+
+     # Use session state to maintain the selection across reruns
+     selected_model_name = st.selectbox(
+         "🤖 Choose Model",
+         model_options[task_type],
+         index=model_options[task_type].index(st.session_state.selected_model),
+         key="selected_model_select"
+     )
+
+     # Update session state after selection
+     st.session_state.selected_model = selected_model_name
+
+     st.markdown("### 🔧 Hyperparameters")
+     hyperparams = get_hyperparams_ui(selected_model_name)
+
+     # Use a unique key for the button to avoid conflicts
+     if st.button("🚀 Train Model", key="train_model_button_unique"):
+         with st.spinner("Training in progress... ⏳"):
+             trained_model = train_model(df, target_column, task_type, selected_model_name, hyperparams)
+             st.success("✅ Model trained successfully!")
+             st.session_state.trained_model = trained_model
+             st.session_state.model_trained = True
+
+             # Note: test_results_calculated is already reset inside train_model
+
+     if "trained_model" in st.session_state:
+         st.markdown("### 📥 Download Trained Model")
+
+         # Use a temporary file, with proper cleanup, for the pickled model
+         try:
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as temp_file:
+                 pickle.dump(st.session_state.trained_model, temp_file)
+                 temp_file_path = temp_file.name
+
+             # Read the file back for the download button
+             with open(temp_file_path, "rb") as f:
+                 st.download_button(
+                     label="📥 Download Model",
+                     data=f,
+                     file_name="trained_model.pkl",
+                     mime="application/octet-stream",
+                 )
+
+             # Clean up the temporary file
+             try:
+                 os.unlink(temp_file_path)
+             except OSError:
+                 pass  # silently ignore deletion errors
+
+         except Exception as e:
+             st.error(f"Error preparing model for download: {e}")
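The download is a plain pickle, so reusing it elsewhere is a `pickle.load` away. A sketch, keeping in mind that for classification the file holds a `(pipeline, label_encoder)` tuple, and that `X_new` below is a hypothetical DataFrame with the original feature columns:

import pickle

with open("trained_model.pkl", "rb") as f:
    obj = pickle.load(f)

pipeline = obj[0] if isinstance(obj, tuple) else obj
# predictions = pipeline.predict(X_new)  # X_new: DataFrame with the training feature columns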
src/training/test_result.py ADDED
@@ -0,0 +1,27 @@
+ import streamlit as st
+ from ui.test_results import display_test_results
+
+
+ def display_model_evaluation():
+     """Display the evaluation results of the trained model on the test set."""
+
+     st.header("📊 Model Evaluation on Test Set")
+
+     # Ensure the model and test data exist in session state
+     if "trained_model" in st.session_state and "X_test" in st.session_state:
+         trained_model = st.session_state.trained_model
+         X_test = st.session_state.X_test
+         y_test = st.session_state.y_test
+         task_type = st.session_state.task_type
+
+         # Handle the classification case, where the model may include a label encoder
+         if task_type == "classification" and isinstance(trained_model, tuple):
+             pipeline, label_encoder = trained_model
+             display_test_results((pipeline, label_encoder), X_test, y_test, task_type)
+         else:
+             display_test_results(trained_model, X_test, y_test, task_type)
+
+     else:
+         st.warning("🚨 Train a model first to see test results!")
src/training/train.py ADDED
@@ -0,0 +1,140 @@
+ from sklearn.compose import ColumnTransformer
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
+ from sklearn.model_selection import train_test_split
+ from sklearn.linear_model import (LinearRegression, LogisticRegression,
+                                   ElasticNet, BayesianRidge, Ridge, Lasso)
+ from sklearn.ensemble import (RandomForestRegressor, RandomForestClassifier,
+                               GradientBoostingRegressor, AdaBoostRegressor,
+                               GradientBoostingClassifier, AdaBoostClassifier)
+ from xgboost import XGBRegressor, XGBClassifier
+ from sklearn.svm import SVR, SVC
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
+ from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis
+ from sklearn.impute import SimpleImputer
+ from sklearn.pipeline import Pipeline as SkPipeline
+
+ import streamlit as st
+
+
+ def get_model(task_type, model_name, hyperparams):
+     """Return the model instance for the user's selection, with hyperparameters applied."""
+     models = {
+         "regression": {
+             # Originally available:
+             "Linear Regression": LinearRegression,
+             "Random Forest Regressor": RandomForestRegressor,
+             "XGBoost Regressor": XGBRegressor,
+             # Additional regression models:
+             "Support Vector Regressor": SVR,
+             "Decision Tree Regressor": DecisionTreeRegressor,
+             "K-Nearest Neighbors Regressor": KNeighborsRegressor,
+             "ElasticNet": ElasticNet,
+             "Gradient Boosting Regressor": GradientBoostingRegressor,
+             "AdaBoost Regressor": AdaBoostRegressor,
+             "Bayesian Ridge": BayesianRidge,
+             "Ridge Regression": Ridge,
+             "Lasso Regression": Lasso,
+         },
+         "classification": {
+             # Originally available:
+             "Logistic Regression": LogisticRegression,
+             "Random Forest": RandomForestClassifier,
+             "XGBoost": XGBClassifier,
+             # Additional classification models:
+             "Support Vector Classifier": SVC,
+             "Decision Tree Classifier": DecisionTreeClassifier,
+             "K-Nearest Neighbors Classifier": KNeighborsClassifier,
+             "Gradient Boosting Classifier": GradientBoostingClassifier,
+             "AdaBoost Classifier": AdaBoostClassifier,
+             "Gaussian Naive Bayes": GaussianNB,
+             "Quadratic Discriminant Analysis": QuadraticDiscriminantAnalysis,
+             "Linear Discriminant Analysis": LinearDiscriminantAnalysis
+         }
+     }
+
+     if task_type in models and model_name in models[task_type]:
+         return models[task_type][model_name](**hyperparams)  # apply hyperparameters
+     else:
+         raise ValueError(f"Invalid model selection: {model_name} for {task_type}")
+
+
+ def train_model(df, target_column, task_type, selected_model_name, hyperparams):
+     """Preprocess the data, train the selected model with its hyperparameters, and return it."""
+
+     with st.spinner("Training model... Please wait!"):
+
+         # Get the model with hyperparameters
+         model = get_model(task_type, selected_model_name, hyperparams)
+
+         # Split features and target
+         X = df.drop(columns=[target_column])
+         y = df[target_column]
+
+         # Label-encode the target for classification with categorical labels
+         label_encoder = None
+         if task_type == "classification" and y.dtype == "object":
+             label_encoder = LabelEncoder()
+             y = label_encoder.fit_transform(y)
+
+         # Train/test split (80/20)
+         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+         # Identify numerical and categorical columns
+         num_cols = X.select_dtypes(include=["int64", "float64"]).columns
+         cat_cols = X.select_dtypes(include=["object", "category"]).columns
+
+         # Preprocessing pipeline
+         # Numeric pipeline: impute missing values, then scale
+         num_pipeline = SkPipeline([
+             ("imputer", SimpleImputer(strategy="median")),
+             ("scaler", StandardScaler())
+         ])
+
+         # Categorical pipeline: impute missing values, then one-hot encode
+         cat_pipeline = SkPipeline([
+             ("imputer", SimpleImputer(strategy="most_frequent")),
+             ("encoder", OneHotEncoder(handle_unknown="ignore", sparse_output=False))
+         ])
+
+         preprocessor = ColumnTransformer([
+             ("num", num_pipeline, num_cols),
+             ("cat", cat_pipeline, cat_cols)
+         ])
+
+         pipeline = SkPipeline([
+             ("preprocessor", preprocessor),
+             ("model", model)
+         ])
+
+         # Train the model
+         pipeline.fit(X_train, y_train)
+
+         # Store the test data and metadata in session state
+         st.session_state.X_test = X_test
+         st.session_state.y_test = y_test
+         st.session_state.task_type = task_type
+         st.session_state.label_encoder = label_encoder  # needed for decoding predictions
+
+         # Reset the test-results flag when a new model is trained
+         if "test_results_calculated" in st.session_state:
+             st.session_state.test_results_calculated = False
+
+         # Clear any previous test metrics to avoid using stale data
+         for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
+             if key in st.session_state:
+                 del st.session_state[key]
+
+         # Return the trained pipeline, plus the label encoder for decoding classification predictions
+         if task_type == "classification":
+             return pipeline, label_encoder
+         else:
+             return pipeline
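Because the return type differs by task, callers (such as model_training_tab above) can normalize it as below. The target name is illustrative, and the call assumes a running Streamlit session for the spinner and session state:

result = train_model(df, target_column="price", task_type="regression",
                     selected_model_name="Linear Regression", hyperparams={})
pipeline = result[0] if isinstance(result, tuple) else result  # tuple only for classification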
src/ui/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from .visualization import visualize_data
+ from .css import load_css
+ from .loading import show_loading_state
+ from .footer import show_footer
+ from .welcome import show_welcome_page
+ from .test_results import display_test_results
+ from .overview import show_overview_page
+ from .insight import display_ai_insights
+
+ __all__ = [
+     "load_css",
+     "show_footer",
+     "show_loading_state",
+     "show_welcome_page",
+     "visualize_data",
+     "display_test_results",
+     "show_overview_page",
+     "display_ai_insights",
+ ]
src/ui/css.py ADDED
@@ -0,0 +1,377 @@
+ # src/css.py
+ import streamlit as st
+
+ def load_css():
+     css = """
+     <style>
+
+     /* --- EXPLICIT COLOR DEFINITIONS --- */
+     :root {
+         --dark-bg: #0E1117;
+         --light-text: #FAFAFA;
+         /* NOTE: --subtitle-text is referenced throughout but was not defined in the
+            original; the muted gray below is an assumed value */
+         --subtitle-text: #A3A8B8;
+         --neon-green: #00FFA3;
+         --neon-font: Arial, Helvetica, sans-serif;
+         --tab-font-size: 18px;
+         --tab-bottom-border-height: 4px;
+         /* Component styling variables */
+         --card-bg: rgba(30, 30, 40, 0.5);
+         --card-border: rgba(0, 255, 163, 0.25);
+         --expander-header-bg: rgba(40, 40, 55, 0.6);
+         --expander-hover-bg: rgba(0, 255, 163, 0.1);
+     }
+
+     /* --- Base Styles --- */
+     body {
+         background-color: var(--dark-bg) !important;
+         color: var(--light-text) !important;
+     }
+     .stApp {
+         background-color: var(--dark-bg) !important;
+     }
+
+     /* --- App Header Styling --- */
+     .app-title {
+         color: var(--neon-green) !important;
+         font-family: var(--neon-font);
+         font-size: 80px !important;
+         font-weight: 700;
+         text-shadow:
+             0 0 1px var(--neon-green),
+             0 0 2px var(--neon-green),
+             0 0 5px var(--neon-green),
+             0 0 45px var(--neon-green);
+         margin-bottom: 25px !important;
+         line-height: 1.0 !important;
+         text-align: center !important;
+     }
+     .app-tagline {
+         color: var(--neon-green) !important;
+         font-family: var(--neon-font);
+         font-size: 27px !important;
+         font-style: normal !important;
+         font-weight: 400 !important;
+         text-shadow:
+             0 0 5px var(--neon-green),
+             0 0 10px var(--neon-green);
+         margin-top: 10px !important;
+         text-align: center !important;
+     }
+     .app-header {
+         padding: 1rem 0 !important;
+         margin-bottom: 2rem !important;
+         text-align: center !important;
+     }
+     /* --- End App Header Styling --- */
+
+
+     /* --- Tab Styling --- */
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"],
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"] > div,
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"] > div > span {
+         font-size: var(--tab-font-size) !important;
+         font-family: var(--neon-font) !important;
+         font-weight: 600 !important;
+     }
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"],
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"] > div,
+     div[data-baseweb="tab-list"] button[data-baseweb="tab"][aria-selected="true"] > div > span {
+         font-size: var(--tab-font-size) !important;
+     }
+     div[data-baseweb="tab-list"] {
+         border: none !important; border-top: none !important; border-right: none !important; border-left: none !important; border-bottom: none !important;
+         border-color: transparent !important; outline: none !important; box-shadow: none !important;
+         margin-bottom: 25px !important; padding: 0 !important;
+         display: flex !important; justify-content: space-around !important;
+     }
+     button[data-baseweb="tab"] {
+         color: var(--light-text) !important; padding: 1rem 1.5rem !important;
+         transition: color 0.3s ease, text-shadow 0.3s ease, border-bottom-color 0.3s ease !important;
+         border-style: solid !important; border-width: 0 0 var(--tab-bottom-border-height) 0 !important;
+         border-color: transparent transparent transparent transparent !important;
+         outline: none !important; box-shadow: none !important; background-color: transparent !important;
+         margin: 0 !important; line-height: normal !important; flex-shrink: 0 !important;
+     }
+     button[data-baseweb="tab"]::before, button[data-baseweb="tab"]::after { display: none !important; content: none !important; }
+     button[data-baseweb="tab"]:hover:not([aria-selected="true"]) {
+         color: var(--neon-green) !important; background-color: transparent !important;
+         border-color: transparent transparent transparent transparent !important; outline: none !important; box-shadow: none !important;
+         text-shadow: 0 0 3px var(--neon-green), 0 0 6px var(--neon-green);
+     }
+     button[data-baseweb="tab"][aria-selected="true"] {
+         border: none !important; border-top: none !important; border-right: none !important; border-left: none !important; border-bottom: none !important;
+         border-color: transparent !important; outline: none !important; box-shadow: none !important; background-color: transparent !important;
+         color: var(--neon-green) !important; border-bottom-style: solid !important;
+         border-bottom-width: var(--tab-bottom-border-height) !important; border-bottom-color: var(--neon-green) !important;
+         text-shadow: 0 0 3px var(--neon-green), 0 0 6px var(--neon-green);
+     }
+     /* --- End Tab Styling --- */
+
+
+     /* --- Welcome Page Specific Styling --- */
+
+     /* Main welcome header (H1) - uses the neon green */
+     .welcome-header h1 {
+         font-size: 2.8rem !important;
+         font-weight: 700 !important;
+         margin-bottom: 0.5rem !important;
+         color: var(--neon-green) !important;
+         text-align: left !important;
+         border-bottom: none !important;
+         /* Optional: a subtle glow like the main title */
+         text-shadow: 0 0 4px rgba(0, 255, 163, 0.7);
+     }
+     /* Main welcome subtitle (P) */
+     .welcome-header p.subtitle {
+         font-size: 1.15rem !important;
+         color: var(--subtitle-text) !important; /* keep the subtitle gray */
+         margin-bottom: 0 !important;
+         text-align: left !important;
+     }
+
+     /* Section headers (H2 generated by st.markdown("## ...")) */
+     .stApp h2 {
+         font-size: 1.9rem !important;
+         font-weight: 600 !important;
+         color: var(--neon-green) !important;
+         border-bottom: 1px solid rgba(0, 255, 163, 0.3);
+         padding-bottom: 8px !important;
+         margin-top: 40px !important;
+         margin-bottom: 25px !important;
+     }
+     /* Override for the sidebar title H2 */
+     section[data-testid="stSidebar"] h2 {
+         border-bottom: none !important;
+         color: #E6E6FA !important;
+         font-size: 1.8rem !important;
+         font-weight: 600 !important;
+         margin: 0 !important;
+         padding-bottom: 0 !important;
+     }
+
+     /* Feature card styling */
+     .feature-card {
+         background-color: var(--card-bg);
+         border: 1px solid var(--card-border);
+         border-radius: 8px;
+         padding: 1.5rem 1.75rem;
+         height: 100%;
+         box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
+         transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out;
+         margin-bottom: 1rem;
+     }
+     .feature-card:hover {
+         transform: translateY(-3px);
+         box-shadow: 0 6px 12px rgba(0, 255, 163, 0.15);
+     }
+     /* Titles within cards (H3) */
+     .feature-card h3 {
+         font-size: 1.3rem !important;
+         font-weight: 600 !important;
+         color: var(--light-text) !important;
+         margin-top: 0 !important;
+         margin-bottom: 1rem !important;
+         border-bottom: none !important;
+         padding-bottom: 0 !important;
+     }
+     /* Lists within cards */
+     .feature-card ul {
+         padding-left: 0 !important; margin-left: 5px; margin-bottom: 0 !important;
+         list-style-type: none;
+     }
+     .feature-card ul li {
+         margin-bottom: 0.6rem !important; line-height: 1.5;
+         color: var(--subtitle-text) !important; position: relative; padding-left: 1.2em;
+     }
+     .feature-card ul li::before {
+         content: '▪'; color: var(--neon-green); font-weight: bold;
+         display: inline-block; position: absolute; left: 0; top: 0;
+     }
+
+     /* Getting-started list (ordered list - OL) */
+     .stApp ol {
+         padding-left: 0 !important; margin-left: 5px; margin-bottom: 30px !important;
+         list-style-type: none; counter-reset: getting-started-counter;
+     }
+     .stApp ol li {
+         margin-bottom: 12px !important; line-height: 1.6 !important;
+         color: var(--light-text) !important; counter-increment: getting-started-counter;
+         position: relative; padding-left: 2.5em;
+     }
+     .stApp ol li::before {
+         content: counter(getting-started-counter); color: var(--dark-bg); background-color: var(--neon-green);
+         font-weight: bold; border-radius: 50%; width: 1.6em; height: 1.6em; display: inline-block;
+         text-align: center; line-height: 1.6em; position: absolute; left: 0; top: 0;
+     }
+     .stApp ol li strong { color: var(--neon-green) !important; font-weight: 600; }
+
+     /* Expander styling */
+     div[data-testid="stExpander"] {
+         border: none !important; border-radius: 8px !important; margin-bottom: 1rem !important;
+         box-shadow: 0 2px 4px rgba(0,0,0,0.2); background-color: transparent !important; overflow: hidden;
+     }
+     div[data-testid="stExpander"] summary {
+         padding: 0.8rem 1.2rem !important; font-size: 1.1rem !important; font-weight: 600 !important;
+         color: var(--light-text) !important; background-color: var(--expander-header-bg) !important;
+         border: none !important; border-radius: 0 !important;
+         transition: background-color 0.2s ease, color 0.2s ease; cursor: pointer;
+     }
+     div[data-testid="stExpander"] summary:hover {
+         background-color: var(--expander-hover-bg) !important; color: var(--neon-green) !important;
+     }
+     div[data-testid="stExpander"] summary svg { fill: var(--light-text) !important; }
+     div[data-testid="stExpander"] summary:hover svg { fill: var(--neon-green) !important; }
+     div[data-testid="stExpander"] div[role="button"] + div { /* content area */
+         padding: 1.2rem 1.5rem !important; background-color: var(--card-bg); border: none !important;
+     }
+     div[data-testid="stExpander"] div[role="button"] + div ul,
+     div[data-testid="stExpander"] div[role="button"] + div ol { margin-bottom: 0 !important; padding-left: 20px !important; list-style-type: disc; }
+     div[data-testid="stExpander"] div[role="button"] + div li {
+         color: var(--subtitle-text) !important; margin-bottom: 0.5rem !important; list-style-type: disc; padding-left: 0;
+     }
+     div[data-testid="stExpander"] div[role="button"] + div li::before { content: none !important; }
+     div[data-testid="stExpander"] a { color: var(--neon-green) !important; text-decoration: underline; }
+     div[data-testid="stExpander"] a:hover { text-shadow: 0 0 3px var(--neon-green); }
+
+     /* Footer styling */
+     .footer {
+         margin-top: 4rem !important; padding: 1rem !important; font-size: 0.9rem !important;
+         color: var(--subtitle-text) !important; text-align: center; width: 100%;
+         position: relative; bottom: auto; left: auto;
+         border-top: 1px solid rgba(255, 255, 255, 0.1);
+     }
+     /* --- End Welcome Page Specific Styling --- */
+
+
+     /* --- Overview Tab Styling --- */
+
+     /* Style for st.metric containers */
+     div[data-testid="stMetric"] {
+         background-color: var(--card-bg); /* use the card background */
+         border: 1px solid var(--card-border); /* use the card border */
+         border-radius: 8px;
+         padding: 1rem 1.25rem;
+         box-shadow: 0 2px 4px rgba(0,0,0,0.15);
+         transition: transform 0.2s ease-in-out, box-shadow 0.2s ease-in-out;
+         height: 100%; /* ensure metrics in a row are the same height */
+     }
+     div[data-testid="stMetric"]:hover {
+         transform: translateY(-2px); /* lift effect */
+         box-shadow: 0 4px 8px rgba(0, 255, 163, 0.1); /* neon glow */
+     }
+
+     /* Style for the st.metric label */
+     div[data-testid="stMetric"] label[data-testid="stMetricLabel"] {
+         color: var(--subtitle-text) !important; /* dimmer label */
+         font-weight: 500 !important;
+         font-size: 0.95rem !important;
+     }
+
+     /* Style for the st.metric value */
+     div[data-testid="stMetric"] div[data-testid="stMetricValue"] {
+         color: var(--neon-green) !important; /* neon value */
+         font-size: 2.5rem !important; /* larger value */
+         font-weight: 700 !important;
+         padding-top: 5px;
+     }
+     div[data-testid="stMetric"] div[data-testid="stMetricDelta"] {
+         /* style the delta if it is used */
+         font-weight: 500 !important;
+     }
+
+     /* Section headers in the Overview tab (reuses the H3 style) */
+     /* The general .stApp h3 rule should cover this if specific enough; otherwise target directly: */
+     div[data-testid="stVerticalBlock"] h3 { /* assuming the Overview content is in a vertical block */
+         font-size: 1.75rem !important;
+         font-weight: 600 !important;
+         color: var(--neon-green) !important;
+         border-bottom: 1px solid rgba(0, 255, 163, 0.3);
+         padding-bottom: 8px !important;
+         margin-top: 30px !important;
+         margin-bottom: 20px !important;
+     }
+     /* Reset for the feature-card H3 */
+     .feature-card h3 {
+         font-size: 1.3rem !important;
+         border-bottom: none !important;
+         padding-bottom: 0 !important;
+     }
+
+
+     /* DataFrame styling */
+     div[data-testid="stDataFrame"] {
+         border: 1px solid var(--card-border) !important; /* neon border */
+         border-radius: 8px;
+         overflow: hidden; /* ensures the border radius applies to the table */
+         box-shadow: 0 2px 4px rgba(0,0,0,0.2);
+     }
+
+     /* DataFrame header */
+     div[data-testid="stDataFrame"] .col_heading {
+         background-color: var(--expander-header-bg) !important; /* darker header */
+         color: var(--light-text) !important;
+         font-weight: 600 !important;
+         font-size: 0.95rem !important;
+         text-align: left !important;
+         border-bottom: 1px solid var(--neon-green) !important; /* neon underline */
+     }
+     div[data-testid="stDataFrame"] .col_heading:first-of-type {
+         border-top-left-radius: 7px; /* match the container radius */
+     }
+     div[data-testid="stDataFrame"] .col_heading:last-of-type {
+         border-top-right-radius: 7px; /* match the container radius */
+     }
+
+     /* DataFrame cells */
+     div[data-testid="stDataFrame"] .dataframe td,
+     div[data-testid="stDataFrame"] .dataframe th { /* also style the index header */
+         color: var(--subtitle-text) !important;
+         border-bottom: 1px solid rgba(255, 255, 255, 0.1) !important; /* faint row separators */
+         border-right: none !important; /* remove vertical separators */
+         padding: 0.5rem 0.75rem !important;
+         font-size: 0.9rem !important;
+     }
+     div[data-testid="stDataFrame"] .dataframe th { /* index header specifically */
+         background-color: rgba(30, 30, 40, 0.3); /* slightly different background for the index */
+         color: var(--light-text) !important;
+         font-weight: 500 !important;
+     }
+
+     /* DataFrame row hover */
+     div[data-testid="stDataFrame"] .dataframe tr:hover td,
+     div[data-testid="stDataFrame"] .dataframe tr:hover th {
+         background-color: rgba(0, 255, 163, 0.05) !important; /* faint neon hover */
+         color: var(--light-text) !important;
+     }
+
+     /* --- End Overview Tab Styling --- */
+
+
+     /* --- Hide Streamlit elements --- */
+     #MainMenu {visibility: hidden !important;}
+     header {visibility: hidden !important;}
+     .stDeployButton {display: none !important;}
+     div[data-testid="stToolbar"] {display: none !important;}
+     div[data-testid="stDecoration"] {display: none !important;}
+     div[data-testid="stStatusWidget"] {display: none !important;}
+     /* --- End Hide Streamlit elements --- */
+
+     /* --- Sidebar styling --- */
+     section[data-testid="stSidebar"] > div:first-child {
+         background-color: var(--dark-bg) !important;
+     }
+
+     div[data-testid="stMetric"] { color: var(--light-text) !important; }
+     div[data-testid="stMetric"] > div { color: var(--light-text) !important; }
+     div[data-testid="stMetric"] label { color: var(--light-text) !important; }
+     /* --- End Remaining Styles --- */
+
+     </style>
+     """
+     st.markdown(css, unsafe_allow_html=True)
src/ui/footer.py ADDED
@@ -0,0 +1,10 @@
+ import streamlit as st
+
+ def show_footer():
+     """Display footer with copyright information."""
+     footer_html = """
+     <div class="footer">
+         © 2025 AutoML. All Rights Reserved.
+     </div>
+     """
+     st.markdown(footer_html, unsafe_allow_html=True)
src/ui/insight.py ADDED
@@ -0,0 +1,30 @@
+ import streamlit as st
+
+ def display_ai_insights():
+     """Displays AI-Powered Insights and Data Cleaning Process."""
+
+     st.header("💡 AI-Powered Insights")
+
+     with st.expander("🧹 Data Cleaning Process", expanded=True):
+         if "insights" in st.session_state and "df" in st.session_state:
+             # Split insights into cleaning process and analysis
+             parts = st.session_state.insights.split("ANALYSIS INSIGHTS:")
+
+             # Show cleaning instructions
+             st.markdown(parts[0])
+
+             # Show interactive dataframe preview using st.session_state.df
+             st.subheader("Cleaned Data Sample")
+             st.dataframe(
+                 st.session_state.df.head(),  # Use the existing df state
+                 use_container_width=True,
+                 hide_index=True,
+             )
+
+             # Show analysis insights if present
+             if len(parts) > 1:
+                 st.markdown("---")
+                 st.markdown("#### Analysis Insights")
+                 st.markdown(parts[1])
+         else:
+             st.warning("No insights generated yet. Upload and process a file first.")
src/ui/loading.py ADDED
@@ -0,0 +1,119 @@
+ import streamlit as st
+ import time
+
+ def show_loading_state():
+     """
+     Cyber-inspired loading animation with circuit-like effects.
+     """
+     try:
+         st.html("""
+         <div class="loading-container-cyber">
+             <div class="rocket-animation-cyber">
+                 <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100" class="cyber-rocket">
+                     <path d="M50 10 L30 70 L50 65 L70 70 Z" fill="#00FFD1"/>
+                     <path d="M40 80 L50 90 L60 80" stroke="#00FFD1" stroke-width="3" fill="none"/>
+                 </svg>
+             </div>
+
+             <h1 class="title-cyber">AutoML</h1>
+             <h2 class="subtitle-cyber">You Ask, We Deliver</h2>
+
+             <div class="loading-content-cyber">
+                 <p class="loading-text-cyber">Initializing neural networks...</p>
+                 <div class="loading-bar-container-cyber">
+                     <div class="loading-bar-cyber"></div>
+                 </div>
+             </div>
+
+             <style>
+             body { background-color: #000000 !important; }
+
+             .loading-container-cyber {
+                 display: flex;
+                 flex-direction: column;
+                 align-items: center;
+                 justify-content: center;
+                 min-height: 80vh;
+                 text-align: center;
+                 padding: 2rem;
+                 background: radial-gradient(circle, rgba(0,0,0,1) 0%, rgba(0,0,0,1) 100%);
+             }
+
+             .cyber-rocket {
+                 width: 100px;
+                 height: 100px;
+                 animation: pulse 2s infinite;
+             }
+
+             .title-cyber {
+                 font-size: 3rem;
+                 margin-bottom: 0.5rem;
+                 color: #00FFD1;
+                 text-shadow: 0 0 10px #00FFD1;
+                 font-family: 'Orbitron', sans-serif;
+             }
+
+             .subtitle-cyber {
+                 font-size: 1.5rem;
+                 margin-bottom: 2rem;
+                 color: #00A86B;
+                 font-family: 'Chakra Petch', sans-serif;
+             }
+
+             .loading-content-cyber {
+                 background: rgba(0, 255, 209, 0.05);
+                 border: 1px solid rgba(0, 255, 209, 0.2);
+                 padding: 1.5rem 2rem;
+                 border-radius: 8px;
+                 max-width: 600px;
+                 width: 100%;
+             }
+
+             .loading-text-cyber {
+                 margin: 0 0 1rem 0;
+                 font-size: 1.1rem;
+                 color: #00FFD1;
+                 font-family: 'Chakra Petch', sans-serif;
+             }
+
+             .loading-bar-container-cyber {
+                 height: 6px;
+                 background: rgba(0, 255, 209, 0.2);
+                 border-radius: 3px;
+                 overflow: hidden;
+             }
+
+             .loading-bar-cyber {
+                 height: 100%;
+                 width: 30%;
+                 background: linear-gradient(90deg, #00FFD1, #00A86B);
+                 animation: circuit-load 1.5s cubic-bezier(0.4, 0.0, 0.2, 1) infinite;
+             }
+
+             @keyframes pulse {
+                 0%, 100% { transform: scale(1); }
+                 50% { transform: scale(1.1); }
+             }
+
+             @keyframes circuit-load {
+                 0% { transform: translateX(-100%); box-shadow: 0 0 10px #00FFD1; }
+                 50% { box-shadow: 0 0 20px #00FFD1; }
+                 100% { transform: translateX(400%); box-shadow: 0 0 10px #00FFD1; }
+             }
+             </style>
+         </div>
+         """)
+     except Exception:
+         # Fall back to the built-in Streamlit spinner if the custom animation fails
+         st.warning("Custom loading animation unavailable. Using default spinner...")
+         with st.spinner("Loading, please wait..."):
+             time.sleep(3)
+
+
+ if __name__ == "__main__":
+     show_loading_state()
+     time.sleep(3)
+     st.empty()
+     st.success("App loaded successfully!")
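Editor's note: st.html requires a reasonably recent Streamlit release; the try/except above falls back to st.spinner when it is unavailable. A hedged usage sketch (the app_ready session key is illustrative, not part of this module): show the loader only on the first script run, then set a flag so reruns skip it.

    import streamlit as st
    from src.ui.loading import show_loading_state

    if not st.session_state.get("app_ready", False):
        show_loading_state()  # cyber animation, or the spinner fallback
        st.session_state.app_ready = True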
src/ui/overview.py ADDED
@@ -0,0 +1,47 @@
+ import streamlit as st
+ import pandas as pd
+
+ @st.cache_data
+ def compute_column_info(df):
+     """Compute summary statistics for each column."""
+     return pd.DataFrame({
+         "Column": df.dtypes.index,
+         "Type": df.dtypes.astype(str),
+         "Non-Null Count": df.count(),
+         "Null Count": df.isnull().sum(),
+         "Unique Values": df.nunique(),
+     })
+
+ def show_overview_page():
+     """Displays dataset statistics, preview, and column information."""
+
+     if "df" not in st.session_state or st.session_state.df is None:
+         st.warning("⚠️ No dataset loaded. Please upload a dataset first.")
+         return
+
+     df = st.session_state.df
+
+     # Dataset Statistics
+     st.markdown("## 📊 Dataset Statistics")
+     col1, col2, col3, col4 = st.columns(4)
+
+     with col1:
+         st.metric("Total Rows", len(df))
+     with col2:
+         st.metric("Total Columns", len(df.columns))
+     with col3:
+         numeric_count = len(df.select_dtypes(include=["int64", "float64"]).columns)
+         st.metric("Numeric Columns", numeric_count)
+     with col4:
+         categorical_count = len(df.select_dtypes(include=["object", "category"]).columns)
+         st.metric("Categorical Columns", categorical_count)
+
+     # Data Preview: only display the top few rows
+     st.markdown("## 🔍 Data Preview")
+     st.dataframe(df.head(), use_container_width=True)
+
+     # Column Information: use cached computation for faster loading
+     st.markdown("## 📌 Column Information")
+     dtypes_df = compute_column_info(df)
+     st.dataframe(dtypes_df, use_container_width=True)
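Editor's note: a quick way to sanity-check compute_column_info is to run it on a toy frame. It is wrapped in st.cache_data, so inside a Streamlit app it is cached; outside one it still runs, just uncached with a warning. The toy data below is illustrative:

    import pandas as pd
    from src.ui.overview import compute_column_info

    toy = pd.DataFrame({"price": [999.0, None, 1299.0], "brand": ["A", "B", "A"]})
    info = compute_column_info(toy)
    # Columns: Column, Type, Non-Null Count, Null Count, Unique Values
    print(info.to_string(index=False))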
src/ui/test_results.py ADDED
@@ -0,0 +1,295 @@
+ import streamlit as st
+ import io
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import plotly.express as px
+ from sklearn.metrics import (
+     accuracy_score,
+     precision_score,
+     recall_score,
+     f1_score,
+     confusion_matrix,
+     mean_absolute_error,
+     mean_squared_error,
+     r2_score,
+ )
+
+ # ==== LLM Setup with Caching ====
+ @st.cache_resource(show_spinner=False)  # Disable the default spinner
+ def get_llm():
+     """Cached LLM initialization to prevent reloading on every rerun."""
+     from langchain_google_genai import ChatGoogleGenerativeAI
+     from langchain_groq import ChatGroq
+     import os
+
+     try:
+         return ChatGroq(
+             model="gemma2-9b-it",
+             groq_api_key=os.getenv("GROQ_API_KEY")
+         )
+     except Exception:
+         try:
+             return ChatGoogleGenerativeAI(
+                 model="gemini-2.0-flash-lite-preview-02-05",
+                 google_api_key=os.getenv("GEMINI_API_KEY")
+             )
+         except Exception:
+             return None
+
+ llm_insights = get_llm()
+
+ # ==== Cached Metric Calculations ====
+ @st.cache_data(show_spinner=False)
+ def _compute_classification_metrics(y_test, y_pred):
+     """Cached metric computation for classification."""
+     return {
+         'accuracy': accuracy_score(y_test, y_pred),
+         'precision': precision_score(y_test, y_pred, average="weighted", zero_division=0),
+         'recall': recall_score(y_test, y_pred, average="weighted", zero_division=0),
+         'f1': f1_score(y_test, y_pred, average="weighted", zero_division=0),
+         'cm': confusion_matrix(y_test, y_pred)
+     }
+
+ @st.cache_data
+ def _compute_regression_metrics(y_test, y_pred):
+     """Cached metric computation for regression."""
+     return {
+         'mae': mean_absolute_error(y_test, y_pred),
+         'mse': mean_squared_error(y_test, y_pred),
+         'rmse': np.sqrt(mean_squared_error(y_test, y_pred)),
+         'r2': r2_score(y_test, y_pred)
+     }
+
+ # ==== Cached Visualization Generation ====
+ @st.cache_data(show_spinner=False)
+ def _plot_confusion_matrix(cm, classes):
+     """Cached confusion matrix plotting."""
+     fig, ax = plt.subplots(figsize=(2, 2), dpi=200)
+     sns.heatmap(
+         cm,
+         annot=True,
+         fmt="d",
+         cmap="Blues",
+         xticklabels=classes,
+         yticklabels=classes,
+         annot_kws={"size": 8},
+     )
+     plt.xticks(fontsize=5)
+     plt.yticks(fontsize=5)
+
+     buf = io.BytesIO()
+     fig.savefig(buf, format="png", bbox_inches="tight", dpi=200)
+     plt.close(fig)  # Free the figure; only the PNG buffer is returned
+     buf.seek(0)
+     return buf
+
+
+ # ==== Optimized Insights Generation ====
+ @st.cache_data(show_spinner=False)
+ def _get_insights_classification(accuracy, precision, recall, f1, cm_shape):
+     """Cached insights generation based on metrics."""
+     if llm_insights is None:
+         # Fallback text; blank lines so Markdown renders each metric on its own line
+         return (
+             f"### Classification Metrics Explained\n\n"
+             f"**Accuracy** ({accuracy:.3f}): Correct predictions ratio\n\n"
+             f"**Precision** ({precision:.3f}): Positive prediction accuracy\n\n"
+             f"**Recall** ({recall:.3f}): Actual positives found\n\n"
+             f"**F1 Score** ({f1:.3f}): Precision-Recall balance\n\n"
+             f"Confusion Matrix ({cm_shape[0]}x{cm_shape[1]}): Prediction vs Actual distribution"
+         )
+
+     try:
+         response = llm_insights.invoke(f"""
+         Briefly explain these classification metrics (accuracy={accuracy:.3f},
+         precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f})
+         and the {cm_shape[0]}x{cm_shape[1]} confusion matrix.
+         Use markdown bullet points.
+         """)
+         return response.content.strip()
+     except Exception:
+         return "Could not generate AI insights - showing basic metrics explanation."
+
+
+ def display_test_results(trained_model, X_test, y_test, task_type, label_encoder=None):
+     """
+     Displays test results, including metrics, confusion matrix (if classification),
+     and LLM-based or fallback insights about the metrics.
+     """
+     # Create a placeholder for the loading message at the top of the page
+     st.markdown("## Test Results")
+     loading_placeholder = st.empty()
+
+     # Show the initial loading message
+     with loading_placeholder.container():
+         st.info("⏳ Evaluating model performance on test data. This may take a moment for large datasets.")
+         progress_bar = st.progress(0)
+
+     # Set a flag to track whether results have been calculated
+     if "test_results_calculated" not in st.session_state:
+         st.session_state.test_results_calculated = False
+
+     # Only perform calculations if they haven't been done yet
+     if not st.session_state.test_results_calculated:
+
+         sampling_message = None
+         MAX_SAMPLES = 5000  # Increased from 50 to 5000
+
+         # Update progress - starting evaluation
+         with loading_placeholder.container():
+             progress_bar.progress(10)
+
+         if len(X_test) <= MAX_SAMPLES:
+             # Use all test data
+             X_test_sample = X_test
+             y_test_sample = y_test
+             st.info("🔍 Using all test data for evaluation...")
+         else:
+             # Use sampling for large datasets
+             sampling_message = f"📊 Using {MAX_SAMPLES} samples from the test set for visualization (out of {len(X_test)} total)"
+             st.info("🔍 Sampling test data for evaluation...")
+
+             # Simple random sampling (len() works for both DataFrames and arrays)
+             idx = np.random.choice(len(X_test), size=MAX_SAMPLES, replace=False)
+             X_test_sample = X_test.iloc[idx] if hasattr(X_test, 'iloc') else X_test[idx]
+             y_test_sample = y_test.iloc[idx] if hasattr(y_test, 'iloc') else y_test[idx]
+
+         # Generate predictions
+         with loading_placeholder.container():
+             progress_bar.progress(30)
+             st.info("🔄 Generating predictions... Please wait")
+             # Add a spinner for visual feedback during prediction
+             with st.spinner("Model working..."):
+                 if task_type == "regression":
+                     y_pred = trained_model.predict(X_test_sample)
+                 elif task_type == "classification":
+                     # trained_model may itself be a (pipeline, encoder) tuple;
+                     # otherwise pair the bare pipeline with the explicit encoder
+                     pipeline, enc = trained_model if label_encoder is None else (trained_model, label_encoder)
+                     y_pred = pipeline.predict(X_test_sample)
+
+                     # Decode if a label encoder is used
+                     if enc:
+                         y_pred = enc.inverse_transform(y_pred)
+                         y_test_decoded = enc.inverse_transform(y_test_sample)
+                     else:
+                         y_test_decoded = y_test_sample
+
+         # Update progress - computing metrics
+         with loading_placeholder.container():
+             progress_bar.progress(60)
+             st.info("📊 Computing metrics...")
+
+         # Compute metrics
+         if task_type == "regression":
+             metrics = _compute_regression_metrics(y_test_sample, y_pred)
+         else:
+             metrics = _compute_classification_metrics(y_test_decoded, y_pred)
+
+         # Update progress - preparing visualizations
+         with loading_placeholder.container():
+             progress_bar.progress(90)
+             st.info("📈 Preparing visualizations...")
+
+         # For classification, pre-calculate the confusion matrix before showing the "ready" message
+         if task_type == "classification":
+             # Pre-calculate the confusion matrix (this is the slow part)
+             _ = _plot_confusion_matrix(metrics['cm'], np.unique(y_test_decoded))
+             # Pre-calculate insights (also potentially slow with an LLM)
+             _ = _get_insights_classification(
+                 metrics['accuracy'],
+                 metrics['precision'],
+                 metrics['recall'],
+                 metrics['f1'],
+                 metrics['cm'].shape
+             )
+
+         # Update progress - complete (only after all calculations are done)
+         with loading_placeholder.container():
+             progress_bar.progress(100)
+             st.success("✅ Test results ready!")
+
+         # Mark results as calculated
+         st.session_state.test_results_calculated = True
+
+         # Store results in session state for reuse
+         st.session_state.test_metrics = metrics
+         if task_type == "classification":
+             st.session_state.test_y_pred = y_pred
+             st.session_state.test_y_test = y_test_decoded
+         else:
+             st.session_state.test_y_pred = y_pred
+             st.session_state.test_y_test = y_test_sample
+
+         # Store the sampling message
+         st.session_state.sampling_message = sampling_message
+
+         # Import time only when needed (moved from global to local scope)
+         import time
+         time.sleep(0.5)  # Short delay to show the "Test results ready!" message
+
+     # Display the sampling message if it exists
+     if "sampling_message" in st.session_state and st.session_state.sampling_message:
+         st.info(st.session_state.sampling_message)
+
+     # Display the results using stored values
+     if task_type == "regression":
+         st.subheader("🔍 Regression Metrics")
+
+         # Get metrics from session state or use the ones we just calculated
+         if "test_metrics" in st.session_state and st.session_state.test_results_calculated:
+             metrics = st.session_state.test_metrics
+             y_pred = st.session_state.test_y_pred
+             y_test = st.session_state.test_y_test
+
+             # Reuse the cached RMSE instead of recomputing it
+             mae, mse, rmse, r2 = metrics['mae'], metrics['mse'], metrics['rmse'], metrics['r2']
+
+             col1, col2, col3, col4 = st.columns(4)
+             col1.metric("📉 MAE", f"{mae:.4f}")
+             col2.metric("📊 MSE", f"{mse:.4f}")
+             col3.metric("📈 RMSE", f"{rmse:.4f}")
+             col4.metric("📌 R² Score", f"{r2:.4f}")
+
+             # Add the regression visualization
+             st.subheader("📈 Prediction vs Actual")
+             df_results = pd.DataFrame({
+                 'Actual': y_test,
+                 'Predicted': y_pred
+             })
+             fig = px.scatter(df_results, x='Actual', y='Predicted',
+                              title='Predicted vs Actual Values',
+                              labels={'Actual': 'Actual Values', 'Predicted': 'Predicted Values'})
+             fig.add_shape(type='line', x0=min(y_test), y0=min(y_test),
+                           x1=max(y_test), y1=max(y_test),
+                           line=dict(color='red', dash='dash'))
+             st.plotly_chart(fig, use_container_width=True)
+
+     elif task_type == "classification":
+         st.subheader("🔍 Classification Metrics")
+
+         # Get metrics from session state or use the ones we just calculated
+         if "test_metrics" in st.session_state and st.session_state.test_results_calculated:
+             metrics = st.session_state.test_metrics
+             y_pred = st.session_state.test_y_pred
+             y_test_decoded = st.session_state.test_y_test
+
+             accuracy, precision, recall, f1 = metrics['accuracy'], metrics['precision'], metrics['recall'], metrics['f1']
+
+             col1, col2, col3, col4 = st.columns(4)
+             col1.metric("✅ Accuracy", f"{accuracy:.4f}")
+             col2.metric("🎯 Precision", f"{precision:.4f}")
+             col3.metric("📢 Recall", f"{recall:.4f}")
+             col4.metric("🔥 F1 Score", f"{f1:.4f}")
+
+             st.subheader("📊 Confusion Matrix")
+             # Use the cached function for the confusion matrix visualization
+             buf = _plot_confusion_matrix(metrics['cm'], np.unique(y_test_decoded))
+             st.image(buf, width=450)
+
+             # === Additional Insights Section ===
+             st.markdown("---")
+             st.markdown("#### Test Insights")
+             classification_insights = _get_insights_classification(accuracy, precision, recall, f1, metrics['cm'].shape)
+             st.markdown(classification_insights)
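Editor's note: for classification, display_test_results expects either trained_model=(pipeline, encoder) with label_encoder=None, or a bare pipeline plus an explicit label_encoder (see the tuple unpack above). A minimal invocation sketch inside a Streamlit app, using scikit-learn toys (the dataset and model choices are illustrative):

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelEncoder
    from src.ui.test_results import display_test_results

    X, y = load_iris(return_X_y=True)
    enc = LabelEncoder().fit(y)
    X_tr, X_te, y_tr, y_te = train_test_split(X, enc.transform(y), random_state=42)
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    # Pass the (model, encoder) tuple; label_encoder stays None
    display_test_results((clf, enc), X_te, y_te, task_type="classification")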
src/ui/visualization.py ADDED
@@ -0,0 +1,556 @@
+ import re
+ import streamlit as st
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+ import pandas as pd
+ import numpy as np
+ from src.utils.logging import log_frontend_error, log_frontend_warning
+
+ SAMPLE_SIZE = 10000  # Define a sample size for subsampling large datasets
+
+ # Efficiently hash a dataframe to detect changes
+ @st.cache_data(show_spinner=False)
+ def compute_df_hash(df):
+     """Optimized, sample-based dataframe hashing."""
+     return hash((df.shape, pd.util.hash_pandas_object(df.iloc[:min(100, len(df))]).sum()))
+
+
+ @st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
+ def is_potential_date_column(series, sample_size=5):
+     """Check if a column might contain dates."""
+     # Check the column name first
+     if any(keyword in series.name.lower() for keyword in ['date', 'time', 'year', 'month', 'day']):
+         return True
+
+     # Check sample values
+     sample = series.dropna().head(sample_size).astype(str)
+     date_patterns = [
+         r'\d{4}-\d{2}-\d{2}',    # YYYY-MM-DD
+         r'\d{2}/\d{2}/\d{4}',    # MM/DD/YYYY
+         r'\d{2}-\w{3}-\d{2,4}',  # DD-MON-YY(Y)
+         r'\d{1,2} \w{3,} \d{4}'  # 1 January 2023
+     ]
+
+     date_count = sum(1 for val in sample if any(re.match(p, val) for p in date_patterns))
+     return date_count / len(sample) > 0.5 if len(sample) > 0 else False  # >50% match
+
+
+ # Cache column type detection with improved performance
+ @st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
+ def get_column_types(df):
+     """Detect column types efficiently and cache the results."""
+     column_types = {}
+
+     # Process columns in batches for better performance
+     for chunk_start in range(0, len(df.columns), 10):
+         chunk_end = min(chunk_start + 10, len(df.columns))
+         chunk_columns = df.columns[chunk_start:chunk_end]
+
+         for column in chunk_columns:
+             # Check for numeric columns
+             if pd.api.types.is_numeric_dtype(df[column]):
+                 # Binary column (0/1, True/False)
+                 if df[column].nunique() <= 2:
+                     column_types[column] = "BINARY"
+                 # Discrete numeric column (few unique values)
+                 elif df[column].nunique() < 20:
+                     column_types[column] = "NUMERIC_DISCRETE"
+                 # Otherwise it's a continuous numeric column
+                 else:
+                     column_types[column] = "NUMERIC_CONTINUOUS"
+             else:
+                 # Check for temporal/date columns
+                 if is_potential_date_column(df[column]):
+                     try:
+                         # Attempt conversion with coerce
+                         converted = pd.to_datetime(df[column], errors='coerce')
+                         if not converted.isnull().all():  # At least some valid dates
+                             column_types[column] = "TEMPORAL"
+                             continue
+                     except Exception:
+                         pass
+
+                 # ID-like columns (high cardinality plus a name hint)
+                 if (df[column].nunique() > len(df) * 0.9 and
+                         any(x in column.lower() for x in ['id', 'code', 'key', 'uuid', 'identifier'])):
+                     column_types[column] = "ID"
+                 # Categorical columns (low to medium cardinality)
+                 elif df[column].nunique() <= 20:
+                     column_types[column] = "CATEGORICAL"
+                 # Otherwise it's a text column
+                 else:
+                     column_types[column] = "TEXT"
+
+     return column_types
+
+
+ # Cache correlation matrix computation with improved performance
+ @st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
+ def get_corr_matrix(df):
+     """Compute and cache the correlation matrix for numeric columns."""
+     # Only select numeric columns to avoid errors
+     numeric_cols = df.select_dtypes(include=[np.number]).columns
+
+     # If there are too many numeric columns, take the first 30 for better performance
+     if len(numeric_cols) > 30:
+         numeric_cols = numeric_cols[:30]
+
+     # Return the correlation matrix if there are at least 2 numeric columns
+     return df[numeric_cols].corr() if len(numeric_cols) > 1 else None
+
+
+ # Cache subsampled data with improved performance
+ @st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
+ def get_subsampled_data(df, column):
+     """Return subsampled data for faster visualization."""
+     # Check that the column exists
+     if column not in df.columns:
+         return pd.DataFrame()
+
+     # Use stratified sampling for categorical columns if possible
+     if df[column].nunique() < 20 and len(df) > SAMPLE_SIZE:
+         try:
+             # Try to get a representative sample
+             fraction = min(0.5, SAMPLE_SIZE / len(df))
+             return df[[column]].groupby(column, group_keys=False).apply(
+                 lambda x: x.sample(max(1, int(fraction * len(x))), random_state=42)
+             )
+         except Exception:
+             # Fall back to random sampling
+             pass
+
+     # Use random sampling
+     return df[[column]].sample(min(len(df), SAMPLE_SIZE), random_state=42)
+
+
+ # Cache chart creation with improved performance
+ @st.cache_data(show_spinner=False, ttl=1800, hash_funcs={  # Cache for 30 minutes
+     pd.DataFrame: compute_df_hash,
+     pd.Series: lambda s: hash((s.name, compute_df_hash(s.to_frame())))
+ })
+ def create_chart(df, column, column_type):
+     """Generate optimized charts based on column type."""
+     # Check that the column exists in the dataframe
+     if column not in df.columns:
+         return None
+
+     # Get subsampled data for better performance
+     df_sample = get_subsampled_data(df, column)
+     if df_sample.empty:
+         return None
+
+     try:
+         # Year-based columns (special case)
+         if "year" in column.lower():
+             fig = make_subplots(rows=1, cols=2, subplot_titles=("Year Distribution", "Box Plot"),
+                                 specs=[[{"type": "bar"}, {"type": "box"}]],
+                                 column_widths=[0.7, 0.3], horizontal_spacing=0.1)
+             year_counts = df_sample[column].value_counts().sort_index()
+             fig.add_trace(go.Bar(x=year_counts.index, y=year_counts.values, marker_color='#7B68EE'), row=1, col=1)
+             fig.add_trace(go.Box(x=df_sample[column], marker_color='#7B68EE'), row=1, col=2)
+
+         # Binary columns (0/1, True/False)
+         elif column_type == "BINARY":
+             value_counts = df_sample[column].value_counts()
+             fig = make_subplots(rows=1, cols=2,
+                                 subplot_titles=("Distribution", "Percentage"),
+                                 specs=[[{"type": "bar"}, {"type": "pie"}]],
+                                 column_widths=[0.5, 0.5],
+                                 horizontal_spacing=0.1)
+
+             fig.add_trace(go.Bar(
+                 x=value_counts.index,
+                 y=value_counts.values,
+                 marker_color=['#FF4B4B', '#4CAF50'],
+                 text=value_counts.values,
+                 textposition='auto'
+             ), row=1, col=1)
+
+             fig.add_trace(go.Pie(
+                 labels=value_counts.index,
+                 values=value_counts.values,
+                 marker=dict(colors=['#FF4B4B', '#4CAF50']),
+                 textinfo='percent+label'
+             ), row=1, col=2)
+
+             fig.update_layout(title_text=f"Binary Distribution: {column}")
+
+         # Continuous numeric columns
+         elif column_type == "NUMERIC_CONTINUOUS":
+             fig = make_subplots(rows=2, cols=2,
+                                 subplot_titles=("Distribution", "Box Plot", "Violin Plot", "Cumulative Distribution"),
+                                 specs=[[{"type": "histogram"}, {"type": "box"}],
+                                        [{"type": "violin"}, {"type": "scatter"}]],
+                                 vertical_spacing=0.15,
+                                 horizontal_spacing=0.1)
+
+             # Histogram
+             fig.add_trace(go.Histogram(
+                 x=df_sample[column],
+                 nbinsx=30,
+                 marker_color='#FF4B4B',
+                 opacity=0.7
+             ), row=1, col=1)
+
+             # Box plot
+             fig.add_trace(go.Box(
+                 x=df_sample[column],
+                 marker_color='#FF4B4B',
+                 boxpoints='outliers'
+             ), row=1, col=2)
+
+             # Violin plot
+             fig.add_trace(go.Violin(
+                 x=df_sample[column],
+                 marker_color='#FF4B4B',
+                 box_visible=True,
+                 points='outliers'
+             ), row=2, col=1)
+
+             # Empirical CDF
+             sorted_data = np.sort(df_sample[column].dropna())
+             cumulative = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
+
+             fig.add_trace(go.Scatter(
+                 x=sorted_data,
+                 y=cumulative,
+                 mode='lines',
+                 line=dict(color='#FF4B4B', width=2)
+             ), row=2, col=2)
+
+             fig.update_layout(height=600, title_text=f"Continuous Variable Analysis: {column}")
+
+         # Discrete numeric columns
+         elif column_type == "NUMERIC_DISCRETE":
+             value_counts = df_sample[column].value_counts().sort_index()
+             fig = make_subplots(rows=1, cols=2,
+                                 subplot_titles=("Distribution", "Percentage"),
+                                 specs=[[{"type": "bar"}, {"type": "pie"}]],
+                                 column_widths=[0.7, 0.3],
+                                 horizontal_spacing=0.1)
+
+             fig.add_trace(go.Bar(
+                 x=value_counts.index,
+                 y=value_counts.values,
+                 marker_color='#FF4B4B',
+                 text=value_counts.values,
+                 textposition='auto'
+             ), row=1, col=1)
+
+             fig.add_trace(go.Pie(
+                 labels=value_counts.index,
+                 values=value_counts.values,
+                 marker=dict(colors=px.colors.sequential.Reds),
+                 textinfo='percent+label'
+             ), row=1, col=2)
+
+             fig.update_layout(title_text=f"Discrete Numeric Distribution: {column}")
+
+         # Categorical columns
+         elif column_type == "CATEGORICAL":
+             value_counts = df_sample[column].value_counts().head(20)  # Limit to the top 20 categories
+             fig = make_subplots(rows=1, cols=2,
+                                 subplot_titles=("Category Distribution", "Percentage Breakdown"),
+                                 specs=[[{"type": "bar"}, {"type": "pie"}]],
+                                 column_widths=[0.6, 0.4],
+                                 horizontal_spacing=0.1)
+
+             # Bar chart
+             fig.add_trace(go.Bar(
+                 x=value_counts.index,
+                 y=value_counts.values,
+                 marker_color='#00FFA3',
+                 text=value_counts.values,
+                 textposition='auto'
+             ), row=1, col=1)
+
+             # Pie chart
+             fig.add_trace(go.Pie(
+                 labels=value_counts.index,
+                 values=value_counts.values,
+                 marker=dict(colors=px.colors.sequential.Greens),
+                 textinfo='percent+label'
+             ), row=1, col=2)
+
+             fig.update_layout(title_text=f"Categorical Analysis: {column}")
+
+         # Temporal/date columns
+         elif column_type == "TEMPORAL":
+             # Convert with safe datetime parsing
+             dates = pd.to_datetime(df_sample[column], errors='coerce', format='mixed')
+             valid_dates = dates[dates.notna()]
+
+             fig = make_subplots(
+                 rows=2,
+                 cols=2,
+                 subplot_titles=("Monthly Pattern", "Yearly Pattern", "Cumulative Trend", "Day of Week Distribution"),
+                 vertical_spacing=0.15,
+                 horizontal_spacing=0.1,
+                 specs=[[{"type": "bar"}, {"type": "bar"}],
+                        [{"type": "scatter"}, {"type": "bar"}]]
+             )
+
+             if not valid_dates.empty:
+                 # Monthly pattern
+                 monthly_counts = valid_dates.dt.month.value_counts().sort_index()
+                 month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+                 month_labels = [month_names[i-1] for i in monthly_counts.index]
+
+                 fig.add_trace(go.Bar(
+                     x=month_labels,
+                     y=monthly_counts.values,
+                     marker_color='#7B68EE',
+                     text=monthly_counts.values,
+                     textposition='auto'
+                 ), row=1, col=1)
+
+                 # Yearly pattern
+                 yearly_counts = valid_dates.dt.year.value_counts().sort_index()
+
+                 fig.add_trace(go.Bar(
+                     x=yearly_counts.index,
+                     y=yearly_counts.values,
+                     marker_color='#7B68EE',
+                     text=yearly_counts.values,
+                     textposition='auto'
+                 ), row=1, col=2)
+
+                 # Cumulative trend
+                 sorted_dates = valid_dates.sort_values()
+                 cumulative = np.arange(1, len(sorted_dates) + 1)
+
+                 fig.add_trace(go.Scatter(
+                     x=sorted_dates,
+                     y=cumulative,
+                     mode='lines',
+                     line=dict(color='#7B68EE', width=2)
+                 ), row=2, col=1)
+
+                 # Day-of-week distribution
+                 dow_counts = valid_dates.dt.dayofweek.value_counts().sort_index()
+                 dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+                 dow_labels = [dow_names[i] for i in dow_counts.index]
+
+                 fig.add_trace(go.Bar(
+                     x=dow_labels,
+                     y=dow_counts.values,
+                     marker_color='#7B68EE',
+                     text=dow_counts.values,
+                     textposition='auto'
+                 ), row=2, col=2)
+
+             fig.update_layout(height=600, title_text=f"Temporal Analysis: {column}")
+
+         # ID columns (show length distribution and common prefixes)
+         elif column_type == "ID":
+             # Calculate ID length statistics
+             id_lengths = df_sample[column].astype(str).str.len()
+
+             # Extract the first 2 characters for prefix analysis
+             id_prefixes = df_sample[column].astype(str).str[:2].value_counts().head(15)
+
+             fig = make_subplots(
+                 rows=1,
+                 cols=2,
+                 subplot_titles=("ID Length Distribution", "Common ID Prefixes"),
+                 horizontal_spacing=0.1,
+                 specs=[[{"type": "histogram"}, {"type": "bar"}]]
+             )
+
+             # ID length histogram
+             fig.add_trace(go.Histogram(
+                 x=id_lengths,
+                 nbinsx=20,
+                 marker_color='#9C27B0'
+             ), row=1, col=1)
+
+             # ID prefix bar chart
+             fig.add_trace(go.Bar(
+                 x=id_prefixes.index,
+                 y=id_prefixes.values,
+                 marker_color='#9C27B0',
+                 text=id_prefixes.values,
+                 textposition='auto'
+             ), row=1, col=2)
+
+             fig.update_layout(title_text=f"ID Analysis: {column}")
+
+         # Text columns
+         elif column_type == "TEXT":
+             # For text columns, show top values and the length distribution
+             value_counts = df_sample[column].value_counts().head(15)
+
+             # Calculate text length statistics
+             text_lengths = df_sample[column].astype(str).str.len()
+
+             fig = make_subplots(
+                 rows=2,
+                 cols=1,
+                 subplot_titles=("Top Values", "Text Length Distribution"),
+                 vertical_spacing=0.2,
+                 specs=[[{"type": "bar"}], [{"type": "histogram"}]]
+             )
+
+             # Top values bar chart
+             fig.add_trace(
+                 go.Bar(
+                     x=value_counts.index,
+                     y=value_counts.values,
+                     marker_color='#00B4D8',
+                     text=value_counts.values,
+                     textposition='auto'
+                 ),
+                 row=1, col=1
+             )
+
+             # Text length histogram
+             fig.add_trace(
+                 go.Histogram(
+                     x=text_lengths,
+                     nbinsx=20,
+                     marker_color='#00B4D8'
+                 ),
+                 row=2, col=1
+             )
+
+             fig.update_layout(
+                 height=600,
+                 title_text=f"Text Analysis: {column}"
+             )
+
+         # Fallback for any other column type
+         else:
+             fig = go.Figure(go.Histogram(x=df_sample[column], marker_color='#888'))
+             fig.update_layout(title_text=f"Generic Analysis: {column}")
+
+         # Common layout settings; branches that build taller multi-row figures set
+         # their own height, so only apply the 400px default when none was set
+         if fig.layout.height is None:
+             fig.update_layout(height=400)
+         fig.update_layout(
+             showlegend=False,
+             plot_bgcolor='rgba(0,0,0,0)',
+             paper_bgcolor='rgba(0,0,0,0)',
+             font=dict(color='#FFFFFF'),
+             margin=dict(l=40, r=40, t=50, b=40)
+         )
+
+         return fig
+
+     except Exception as e:
+         log_frontend_error("Chart Generation", f"Error creating chart for {column}: {str(e)}")
+         return None
+
+
+ def visualize_data(df):
+     """Automated dashboard with optimized visualizations."""
+     if df is None or df.empty:
+         st.error("❌ No data available. Please upload and clean your data first.")
+         return
+
+     # Calculate the dataframe hash only once
+     df_hash = compute_df_hash(df)
+
+     # Initialize selected columns in session state if not already present
+     if "selected_viz_columns" not in st.session_state:
+         # Initialize with the first 4 columns, or fewer if df has fewer columns
+         initial_columns = list(df.columns[:min(4, len(df.columns))])
+         st.session_state.selected_viz_columns = initial_columns
+
+     # Filter out any columns that no longer exist in the dataframe
+     valid_columns = [col for col in st.session_state.selected_viz_columns if col in df.columns]
+
+     # Callback to update the selected columns
+     def on_column_selection_change():
+         # Store the selected columns in session state
+         st.session_state.selected_viz_columns = st.session_state.viz_column_selector
+         # Ensure we stay on the visualization tab (index 2)
+         st.session_state.current_tab_index = 2
+
+     # Use session state for the multiselect with a consistent key and callback
+     selected_columns = st.multiselect(
+         "Select columns to visualize",
+         options=df.columns,
+         default=valid_columns,
+         key="viz_column_selector",
+         on_change=on_column_selection_change
+     )
+
+     # Recompute column types and the correlation matrix only when:
+     # 1. column_types is not in session_state yet
+     # 2. The dataframe hash has changed (new data)
+     # 3. A user-uploaded dataset is being used for the first time
+     recompute_needed = (
+         "column_types" not in st.session_state or
+         "df_hash" not in st.session_state or
+         st.session_state.get("df_hash") != df_hash
+     )
+
+     if recompute_needed:
+         with st.spinner("🔄 Analyzing data structure..."):
+             # Compute and cache column types
+             st.session_state.column_types = get_column_types(df)
+             # Compute and cache the correlation matrix
+             st.session_state.corr_matrix = get_corr_matrix(df)
+             # Update the dataframe hash
+             st.session_state.df_hash = df_hash
+             # Ensure we stay on the visualization tab
+             st.session_state.current_tab_index = 2
+
+             # Reset any test results if the data has changed
+             if "test_results_calculated" in st.session_state:
+                 st.session_state.test_results_calculated = False
+                 # Clear previous test metrics to avoid using stale data
+                 for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
+                     if key in st.session_state:
+                         del st.session_state[key]
+
+     # Use cached values from session state
+     column_types = st.session_state.column_types
+     corr_matrix = st.session_state.corr_matrix
+
+     if selected_columns:
+         # Use a container to wrap all visualizations
+         viz_container = st.container()
+
+         with viz_container:
+             for idx in range(0, len(selected_columns), 2):
+                 col1, col2 = st.columns(2)
+
+                 for i, col in enumerate([col1, col2]):
+                     if idx + i < len(selected_columns):
+                         column = selected_columns[idx + i]
+                         with col:
+                             # Use consistent keys for charts based on the column name
+                             chart_key = f"plot_{column.replace(' ', '_')}"
+
+                             # Only create a chart if the column has a detected type
+                             if column in column_types:
+                                 fig = create_chart(df, column, column_types[column])
+                                 if fig:
+                                     st.plotly_chart(fig, use_container_width=True, key=chart_key)
+                                     with st.expander(f"📊 Summary Statistics - {column}", expanded=False):
+                                         if "NUMERIC" in column_types[column]:
+                                             st.dataframe(df[column].describe(), key=f"stats_{column.replace(' ', '_')}")
+                                         else:
+                                             st.dataframe(df[column].value_counts(), key=f"counts_{column.replace(' ', '_')}")
+                             else:
+                                 st.warning(f"⚠️ Column '{column}' not found in the dataset or its type couldn't be determined.")
+
+         if corr_matrix is not None:
+             st.subheader("🔗 Correlation Analysis")
+             fig = px.imshow(corr_matrix, title="Correlation Matrix", color_continuous_scale="RdBu")
+             st.plotly_chart(fig, use_container_width=True, key="corr_matrix_plot")
+
+     else:
+         st.info("👆 Please select columns to visualize")
src/ui/welcome.py ADDED
@@ -0,0 +1,186 @@
+ import streamlit as st
+ from src.preprocessing.clean_data import cached_clean_csv
+ import pandas as pd
+ from functools import lru_cache
+
+ # Cache static content to avoid recomputation
+ @lru_cache(maxsize=1)
+ def get_static_content():
+     """Cache static HTML content to avoid regeneration."""
+     welcome_header = """
+     <div class="welcome-header" style="text-align: left; margin-bottom: 2rem;">
+         <h1>Experience AI like never before</h1>
+         <p class="subtitle">
+             Performance, Analysis, Insights Made Simple.
+         </p>
+     </div>
+     """
+     features_header = "## ✨ Key Features"
+     feature_cards = [
+         """
+         <div class="feature-card">
+             <h3>📊 Data Analysis</h3>
+             <ul>
+                 <li>Automated data cleaning</li>
+                 <li>Interactive visualizations</li>
+                 <li>Statistical insights</li>
+                 <li>Correlation analysis</li>
+             </ul>
+         </div>
+         """,
+         """
+         <div class="feature-card">
+             <h3>🤖 Machine Learning</h3>
+             <ul>
+                 <li>Multiple ML algorithms</li>
+                 <li>Automated model selection</li>
+                 <li>Hyperparameter tuning</li>
+                 <li>Performance metrics</li>
+             </ul>
+         </div>
+         """,
+         """
+         <div class="feature-card">
+             <h3>🔍 AI Insights</h3>
+             <ul>
+                 <li>Data quality checks</li>
+                 <li>Feature importance</li>
+                 <li>Model explanations</li>
+                 <li>Smart recommendations</li>
+             </ul>
+         </div>
+         """
+     ]
+     getting_started = """
+     ## 🚀 Getting Started
+     1. **Upload Your Dataset**: Use the sidebar to upload your CSV file
+     2. **Explore Data**: View statistics and visualizations in the Overview tab
+     3. **Train Models**: Select algorithms and tune parameters
+     4. **Get Insights**: Receive AI-powered recommendations
+     """
+     dataset_requirements = """
+     * File format: CSV
+     * Maximum size: 200MB
+     * Supported column types:
+         * Numeric (int, float)
+         * Categorical (string, boolean)
+         * Temporal (date, datetime)
+     * Clean data preferred, but not required
+     """
+     example_datasets = """
+     Try these example datasets to explore the app:
+     * [Iris Dataset](https://archive.ics.uci.edu/ml/datasets/iris)
+     * [Boston Housing](https://www.kaggle.com/c/boston-housing)
+     * [Wine Quality](https://archive.ics.uci.edu/ml/datasets/wine+quality)
+     """
+     return welcome_header, features_header, feature_cards, getting_started, dataset_requirements, example_datasets
+
+ def show_welcome_page():
+     """Display the welcome page with features and instructions efficiently."""
+     # Load cached static content
+     welcome_header, features_header, feature_cards, getting_started, dataset_requirements, example_datasets = get_static_content()
+
+     # Render static content
+     st.markdown(welcome_header, unsafe_allow_html=True)
+     st.markdown(features_header, unsafe_allow_html=True)
+
+     # Feature columns with minimal overhead
+     col1, col2, col3 = st.columns(3, gap="medium")
+     with col1:
+         st.markdown(feature_cards[0], unsafe_allow_html=True)
+     with col2:
+         st.markdown(feature_cards[1], unsafe_allow_html=True)
+     with col3:
+         st.markdown(feature_cards[2], unsafe_allow_html=True)
+
+     st.markdown("<br>", unsafe_allow_html=True)  # Spacing
+
+     # Getting Started and expanders
+     st.markdown(getting_started, unsafe_allow_html=True)
+     with st.expander("📋 Dataset Requirements"):
+         st.markdown(dataset_requirements)
+
+     with st.expander("🎯 Example Datasets"):
+         st.markdown(example_datasets)
+
+     # File uploader section
+     st.markdown("### 📤 Upload Your Dataset (Currently Using Default Dataset)")
+
+     # Checkbox to indicate the dataset is already cleaned
+     skip_cleaning = st.checkbox("My dataset is already cleaned (skip cleaning)")
+
+     uploaded_file = st.file_uploader("Upload CSV file", type=["csv"])
+
+     if uploaded_file is not None:
+         try:
+             # Validate file size
+             file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type, "FileSize": uploaded_file.size}
+             if uploaded_file.size > 200 * 1024 * 1024:  # 200MB limit
+                 st.error("❌ File size exceeds 200MB limit. Please upload a smaller file.")
+                 return
+
+             # Attempt to read the CSV
+             try:
+                 df = pd.read_csv(uploaded_file)
+                 if df.empty:
+                     st.error("❌ The uploaded file is empty. Please upload a file with data.")
+                     return
+
+                 st.success("✅ Dataset uploaded successfully!")
+             except pd.errors.EmptyDataError:
+                 st.error("❌ The uploaded file is empty. Please upload a file with data.")
+                 return
+             except pd.errors.ParserError:
+                 st.error("❌ Unable to parse the CSV file. Please ensure it's properly formatted.")
+                 return
+
+             # Convert the dataframe to JSON for caching
+             df_json = df.to_json(orient='records')
+
+             # Use the cached cleaning function with proper error handling
+             with st.spinner("🧠 AI is analyzing and cleaning the data..." if not skip_cleaning else "Processing dataset..."):
+                 try:
+                     cleaned_df, insights = cached_clean_csv(df_json, skip_cleaning)
+                 except Exception as cleaning_error:
+                     st.error(f"❌ Error during data cleaning: {str(cleaning_error)}")
+                     # Fall back to using the original dataframe
+                     st.warning("⚠️ Using original dataset without cleaning due to errors.")
+                     cleaned_df = df
+                     insights = "Cleaning failed, using original data."
+
+             # Save results to session state
+             st.session_state.df = cleaned_df
+             st.session_state.insights = insights
+             st.session_state.data_cleaned = True
+             st.session_state.dataset_loaded = True
+
+             # Flag that this is a user-uploaded dataset
+             st.session_state.is_user_uploaded = True
+
+             # Store the original dataframe JSON and the skip_cleaning preference;
+             # this helps prevent redundant cleaning
+             st.session_state.original_df_json = df_json
+             st.session_state.skip_cleaning = skip_cleaning
+
+             # Reset visualization- and model-training-related session state
+             if "column_types" in st.session_state:
+                 del st.session_state.column_types
+             if "corr_matrix" in st.session_state:
+                 del st.session_state.corr_matrix
+             if "df_hash" in st.session_state:
+                 del st.session_state.df_hash
+             if "test_results_calculated" in st.session_state:
+                 st.session_state.test_results_calculated = False
+
+             if skip_cleaning:
+                 st.success("✅ Using uploaded dataset as-is (skipped cleaning).")
+             else:
+                 st.success("✅ Data cleaned successfully!")
+
+         except Exception as e:
+             st.error(f"❌ Error processing dataset: {str(e)}")
+             st.info("ℹ️ Please check that your file is a valid CSV and try again.")
src/utils/logging.py ADDED
@@ -0,0 +1,70 @@
+ import logging
+ import os
+ from datetime import datetime
+
+ # Create the logs directory if it doesn't exist
+ logs_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
+ if not os.path.exists(logs_dir):
+     os.makedirs(logs_dir)
+
+ # Configure the logger
+ def setup_logger():
+     """
+     Set up and configure the frontend error logger.
+     """
+     # Create a logger instance
+     logger = logging.getLogger('frontend_logger')
+     logger.setLevel(logging.DEBUG)
+
+     # Guard against duplicate handlers when Streamlit re-imports this module
+     if logger.handlers:
+         return logger
+
+     # Create a file handler
+     log_file = os.path.join(logs_dir, f'frontend_errors_{datetime.now().strftime("%Y%m%d")}.log')
+     file_handler = logging.FileHandler(log_file)
+     file_handler.setLevel(logging.DEBUG)
+
+     # Create a console handler
+     console_handler = logging.StreamHandler()
+     console_handler.setLevel(logging.ERROR)
+
+     # Create a formatter
+     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+     file_handler.setFormatter(formatter)
+     console_handler.setFormatter(formatter)
+
+     # Add handlers to the logger
+     logger.addHandler(file_handler)
+     logger.addHandler(console_handler)
+
+     return logger
+
+ # Initialize the logger
+ frontend_logger = setup_logger()
+
+ def log_frontend_error(error_type: str, error_message: str, additional_info: dict = None):
+     """
+     Log frontend errors with detailed information.
+
+     Args:
+         error_type (str): Type of error (e.g., 'Arrow Conversion', 'Model Training', etc.)
+         error_message (str): The error message
+         additional_info (dict, optional): Additional context about the error
+     """
+     error_details = f"Type: {error_type}\nMessage: {error_message}"
+     if additional_info:
+         error_details += f"\nAdditional Info: {additional_info}"
+
+     frontend_logger.error(error_details)
+
+ def log_frontend_warning(warning_type: str, warning_message: str, additional_info: dict = None):
+     """
+     Log frontend warnings with detailed information.
+
+     Args:
+         warning_type (str): Type of warning
+         warning_message (str): The warning message
+         additional_info (dict, optional): Additional context about the warning
+     """
+     warning_details = f"Type: {warning_type}\nMessage: {warning_message}"
+     if additional_info:
+         warning_details += f"\nAdditional Info: {additional_info}"
+
+     frontend_logger.warning(warning_details)
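Editor's note: a short usage sketch for the two helpers; the message contents below are illustrative:

    from src.utils.logging import log_frontend_error, log_frontend_warning

    log_frontend_warning(
        "Arrow Conversion",
        "Column 'Memory' mixed int/str; coerced to string for display",
        additional_info={"column": "Memory", "dtype": "object"},
    )
    log_frontend_error("Model Training", "XGBoost not installed", {"fallback": "RandomForest"})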