Spaces:
Sleeping
Sleeping
| # ui/callbacks.py | |
| # -*- coding: utf-8 -*- | |
| # | |
| # PROJECT: CognitiveEDA v5.9 - The QuantumLeap Intelligence Platform | |
| # | |
| # DESCRIPTION: This module is updated with a generic, data-agnostic | |
| # stratification engine. It dynamically identifies candidate | |
| # features for filtering and updates the UI accordingly. | |
| import gradio as gr | |
| import pandas as pd | |
| import logging | |
| from threading import Thread | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from core.analyzer import DataAnalyzer, engineer_features | |
| from core.llm import GeminiNarrativeGenerator | |
| from core.config import settings | |
| from modules.clustering import perform_clustering | |
| from modules.profiling import profile_clusters | |
| # --- Primary Analysis Chain --- | |
def run_initial_analysis(file_obj, progress=gr.Progress(track_tqdm=True)):
    """Load an uploaded file and build the analysis engine for it.

    The file is read with ``pandas.read_csv`` when its name ends in ``.csv``
    (case-insensitively) and with ``pandas.read_excel`` otherwise. Frames
    larger than ``settings.MAX_UI_ROWS`` are down-sampled with a fixed seed
    before feature engineering.

    Args:
        file_obj: Uploaded-file object whose ``name`` attribute is the path
            of the file to read.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A ``DataAnalyzer`` built from the engineered DataFrame.

    Raises:
        gr.Error: If no file was supplied, if ``settings.GOOGLE_API_KEY`` is
            not set, or if any loading/engineering step fails.
    """
    if file_obj is None:
        raise gr.Error("No file uploaded.")

    progress(0, desc="Validating configuration...")
    if not settings.GOOGLE_API_KEY:
        raise gr.Error("CRITICAL: GOOGLE_API_KEY is not configured.")

    try:
        progress(0.1, desc="Loading raw data...")
        path = file_obj.name
        # Compare the extension case-insensitively so "DATA.CSV" is not
        # mistakenly routed to the Excel reader.
        if path.lower().endswith('.csv'):
            df_raw = pd.read_csv(path)
        else:
            df_raw = pd.read_excel(path)

        # Cap the row count; the fixed seed keeps the sample reproducible.
        if len(df_raw) > settings.MAX_UI_ROWS:
            df_raw = df_raw.sample(n=settings.MAX_UI_ROWS, random_state=42)

        progress(0.5, desc="Applying strategic feature engineering...")
        df_engineered = engineer_features(df_raw)

        progress(0.8, desc="Instantiating analysis engine...")
        analyzer = DataAnalyzer(df_engineered)

        progress(1.0, desc="Analysis complete. Generating reports...")
        return analyzer
    except Exception as e:
        # UI boundary: log the full traceback, then surface a readable error
        # while keeping the original exception chained as the cause.
        logging.error(f"Error in initial analysis: {e}", exc_info=True)
        raise gr.Error(f"Analysis Failed: {str(e)}") from e
def generate_reports_and_visuals(analyzer, progress=gr.Progress(track_tqdm=True)):
    """Yield UI updates for the profiling reports, visuals and AI narrative.

    Phase 2 of the analysis chain. This is a generator that yields a
    15-element tuple twice: first with every report/visual filled in and a
    placeholder for the AI narrative, then again once the narrative thread
    has finished, with only the first element replaced by the narrative.
    The last element populates the generic 'Stratify By' dropdown with
    candidate columns.

    Args:
        analyzer: The ``DataAnalyzer`` produced by the initial analysis. Any
            other value yields a single tuple of 15 ``None`` values.
        progress: Gradio progress tracker used to report each stage.

    Yields:
        Tuples of 15 ``gr.update`` objects (or 15 ``None`` values when
        ``analyzer`` is not a ``DataAnalyzer``).
    """
    if not isinstance(analyzer, DataAnalyzer):
        yield (None,) * 15
        return

    progress(0, desc="Spawning AI report thread...")
    # One-slot list used as a mutable cell the worker thread writes into.
    ai_report_queue = [""]

    def generate_ai_report_threaded(a):
        # Without this guard an exception in the worker leaves the slot as
        # "" and the UI ends up with a blank report and no explanation.
        try:
            ai_report_queue[0] = GeminiNarrativeGenerator(
                settings.GOOGLE_API_KEY
            ).generate_narrative(a)
        except Exception as e:
            logging.error(f"AI report generation failed: {e}", exc_info=True)
            ai_report_queue[0] = f"⚠️ AI report generation failed: {e}"

    thread = Thread(target=generate_ai_report_threaded, args=(analyzer,))
    thread.start()

    progress(0.4, desc="Generating reports and visuals...")
    meta = analyzer.metadata
    missing_df, num_df, cat_df = analyzer.get_profiling_reports()
    fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

    # --- Dynamically identify candidate columns for stratification ---
    candidate_cols = ["(Do not stratify)"]
    for col in meta.get('categorical_cols', []):
        series = analyzer.df[col]
        # Heuristic: object columns qualify with more than 1 but fewer than
        # 50 unique values; non-object columns are always accepted.
        # NOTE(review): the `or` admits non-object columns regardless of
        # cardinality - confirm this is intended.
        if series.dtype.name != 'object' or (1 < series.nunique() < 50):
            candidate_cols.append(col)

    numeric_cols = meta.get('numeric_cols', [])
    initial_updates = (
        gr.update(value="⏳ Generating AI report..."),
        gr.update(value=missing_df),
        gr.update(value=num_df),
        gr.update(value=cat_df),
        gr.update(value=fig_types),
        gr.update(value=fig_missing),
        gr.update(value=fig_corr),
        gr.update(choices=numeric_cols),
        gr.update(choices=numeric_cols),
        gr.update(choices=numeric_cols),
        gr.update(choices=meta.get('columns', [])),
        gr.update(visible=bool(meta.get('datetime_cols'))),
        gr.update(visible=bool(meta.get('text_cols'))),
        gr.update(visible=len(numeric_cols) > 1),
        gr.update(choices=candidate_cols, value="(Do not stratify)"),  # dd_stratify_by_col
    )
    yield initial_updates

    # Block until the narrative (or its failure message) is available.
    thread.join()
    progress(1.0, desc="AI Report complete!")

    final_updates_list = list(initial_updates)
    final_updates_list[0] = gr.update(value=ai_report_queue[0])
    yield tuple(final_updates_list)
| # --- Stratification Callbacks --- | |
def update_filter_dropdown(analyzer, stratify_col):
    """Populate the filter dropdown with the values of the chosen feature.

    When the user selects a feature to stratify by, this returns an update
    that fills the second dropdown with that feature's unique non-missing
    values, preceded by the "(Global Analysis)" option.

    Args:
        analyzer: The current ``DataAnalyzer``; anything else disables the
            dropdown.
        stratify_col: Name of the column selected in the 'Stratify By'
            dropdown, or "(Do not stratify)" / empty for no stratification.

    Returns:
        A ``gr.update`` for the filter dropdown: disabled and empty when
        there is nothing to stratify by (or the column is not in the frame),
        otherwise interactive with "(Global Analysis)" selected.
    """
    disabled = gr.update(choices=[], value=None, interactive=False)

    if not isinstance(analyzer, DataAnalyzer) or not stratify_col or stratify_col == "(Do not stratify)":
        return disabled
    # A stale selection from a previous dataset would otherwise raise KeyError.
    if stratify_col not in analyzer.df.columns:
        return disabled

    # Missing values are dropped: they cannot be matched by an equality
    # filter, and NaN mixed with strings makes the sort raise TypeError.
    unique_values = analyzer.df[stratify_col].dropna().unique().tolist()
    try:
        ordered = sorted(unique_values)
    except TypeError:
        # Mixed, mutually non-comparable types: order by text form instead.
        ordered = sorted(unique_values, key=str)

    values = ["(Global Analysis)"] + ordered
    return gr.update(choices=values, value="(Global Analysis)", interactive=True)
def update_stratified_clustering(analyzer, stratify_col, filter_value, k):
    """Run clustering and cluster profiling on an optionally filtered frame.

    Orchestrates the full clustering workflow on a dataset that is
    generically filtered based on user selections: filter the rows, cluster
    on the numeric columns, then profile the resulting clusters.

    Args:
        analyzer: The current ``DataAnalyzer``; anything else returns empty
            figures and empty text.
        stratify_col: Column to filter on, or "(Do not stratify)" / empty.
        filter_value: Value of ``stratify_col`` to keep, or
            "(Global Analysis)" / ``None`` / "" for no filtering.
        k: Requested number of clusters.

    Returns:
        A 5-tuple ``(fig_cluster, fig_elbow, summary, md_personas,
        fig_profile)``. When there are fewer rows than ``k`` both text slots
        carry the same error message and the figures are empty.
    """
    if not isinstance(analyzer, DataAnalyzer):
        return go.Figure(), go.Figure(), "", "", go.Figure()

    logging.info(f"Updating clustering. Stratify by: '{stratify_col}', Filter: '{filter_value}', K={k}")

    # Step 1: Stratify the DataFrame based on user selection
    analysis_df = analyzer.df
    report_title_prefix = "Global Analysis: "

    stratify_selected = bool(stratify_col) and stratify_col != "(Do not stratify)"
    # Compare against the "nothing selected" sentinels explicitly: a plain
    # truthiness test would treat legitimate values such as 0 or False as
    # "no filter" and silently run the global analysis instead.
    filter_selected = (
        filter_value is not None
        and filter_value != ""
        and filter_value != "(Global Analysis)"
    )
    if stratify_selected and filter_selected:
        analysis_df = analyzer.df[analyzer.df[stratify_col] == filter_value]
        report_title_prefix = f"Analysis for '{stratify_col}' = '{filter_value}': "

    if len(analysis_df) < k:
        error_msg = f"Not enough data ({len(analysis_df)} rows) to form {k} clusters for the selected filter."
        return go.Figure(), go.Figure(), error_msg, error_msg, go.Figure()

    # Step 2: Perform Clustering
    # Read metadata with .get, as the other callbacks do, so a missing key
    # means "no such columns" rather than a KeyError.
    available = analysis_df.columns
    numeric_cols = [c for c in analyzer.metadata.get('numeric_cols', []) if c in available]
    fig_cluster, fig_elbow, summary, cluster_labels = perform_clustering(
        analysis_df, numeric_cols, k
    )
    if cluster_labels.empty:
        return fig_cluster, fig_elbow, summary, "Clustering failed.", go.Figure()

    # Step 3: Profile the resulting clusters
    cats_to_profile = [c for c in analyzer.metadata.get('categorical_cols', []) if c in available]
    # These four columns are left out of the numeric profile.
    excluded_from_profile = ('Month', 'Day_of_Week', 'Is_Weekend', 'Hour')
    numeric_to_profile = [c for c in numeric_cols if c not in excluded_from_profile]
    md_personas, fig_profile = profile_clusters(
        analysis_df, cluster_labels, numeric_to_profile, cats_to_profile
    )

    summary = f"**{report_title_prefix}**" + summary
    md_personas = f"**{report_title_prefix}**" + md_personas

    # Step 4: Return all results
    return fig_cluster, fig_elbow, summary, md_personas, fig_profile
| # --- Other Callbacks --- | |
def create_histogram(analyzer, col):
    """Build a histogram (with a marginal box plot) for one column.

    Args:
        analyzer: The current ``DataAnalyzer``.
        col: Name of the column to plot.

    Returns:
        The Plotly Express histogram, or an empty ``go.Figure`` when
        ``analyzer`` is not a ``DataAnalyzer`` or no column is given.
    """
    if isinstance(analyzer, DataAnalyzer) and col:
        chart_title = f"<b>Distribution of {col}</b>"
        return px.histogram(analyzer.df, x=col, title=chart_title, marginal="box")
    return go.Figure()
def create_scatterplot(analyzer, x_col, y_col, color_col):
    """Build a scatter plot of two columns, optionally coloured by a third.

    At most 10,000 rows are plotted. Larger frames are down-sampled with a
    fixed seed so the same data always produces the same plot; smaller
    frames are plotted as-is.

    Args:
        analyzer: The current ``DataAnalyzer``.
        x_col: Column for the x axis.
        y_col: Column for the y axis.
        color_col: Optional column used to colour the points; any falsy
            value disables colouring.

    Returns:
        The Plotly Express scatter figure, or an empty ``go.Figure`` when
        ``analyzer`` is not a ``DataAnalyzer`` or either axis is missing.
    """
    if not isinstance(analyzer, DataAnalyzer) or not x_col or not y_col:
        return go.Figure()

    max_points = 10000
    df_sample = analyzer.df
    if len(df_sample) > max_points:
        # Seeded like the row cap applied at load time, so re-rendering
        # (e.g. after changing the colour column) keeps the same points.
        df_sample = df_sample.sample(n=max_points, random_state=42)

    return px.scatter(df_sample, x=x_col, y=y_col, color=color_col or None)