# SparkNova — Gradio data-analysis Space (GROQ LLM backend).
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| import json | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from langchain_groq import ChatGroq | |
| from prompts import ENHANCED_SYSTEM_PROMPT, SAMPLE_QUESTIONS, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt | |
# SECURITY: read the GROQ API key from the environment instead of
# hard-coding a live secret in source control (the original embedded one).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

# Module-level application state, populated at runtime.
llm = None           # ChatGroq client, set by initialize_llm()
uploaded_df = None   # currently loaded pandas DataFrame
dataset_name = None  # basename of the uploaded file
def initialize_llm():
    """Create the global ChatGroq client.

    Returns:
        bool: True when the client was constructed, False otherwise.
    """
    global llm
    try:
        # Propagate the key so downstream langchain code can also find it.
        os.environ["GROQ_API_KEY"] = GROQ_API_KEY
        llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=GROQ_API_KEY, temperature=0.0)
        return True
    except Exception:
        # Narrowed from a bare except: so SystemExit/KeyboardInterrupt
        # are no longer silently swallowed.
        return False
def upload_dataset(file):
    """Load an uploaded CSV/Excel file into the module-level DataFrame.

    Args:
        file: Gradio file object (or None when nothing was uploaded).

    Returns:
        tuple: (status markdown, dataset-preview visibility update,
                column-selector choices update, column-selector visibility update)
    """
    global uploaded_df, dataset_name
    if file is None:
        return "No file uploaded", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
    try:
        dataset_name = os.path.basename(file.name)
        # Case-insensitive extension check so ".CSV"/".XLSX" uploads work too.
        lower_name = file.name.lower()
        if lower_name.endswith('.csv'):
            uploaded_df = pd.read_csv(file.name)
        elif lower_name.endswith(('.xlsx', '.xls')):
            uploaded_df = pd.read_excel(file.name)
        else:
            return "Unsupported file format. Please upload CSV or Excel files.", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
        # Coerce currency/percentage text columns to numeric where possible.
        uploaded_df = clean_numeric(uploaded_df)
        info_text = f"**Dataset Loaded:** {dataset_name} ({uploaded_df.shape[0]} rows × {uploaded_df.shape[1]} columns)"
        return info_text, gr.update(visible=False), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(visible=True)
    except Exception as e:
        return f"Error loading file: {str(e)}", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
def clean_numeric(df):
    """Convert percentage/currency-formatted text columns to numbers.

    A text column becomes numeric when more than half of its values parse:
    "25%" -> 0.25, "1,000" / "$1,000" / "₹1,000" -> 1000.0. Other columns
    are left untouched. Operates on (and returns) a copy.

    Args:
        df: input DataFrame.

    Returns:
        pd.DataFrame: copy with parseable text columns converted.
    """
    df = df.copy()
    if len(df) == 0:
        # Guard: the notna()/len(df) ratios below would raise ZeroDivisionError.
        return df
    for col in df.columns:
        if pd.api.types.is_string_dtype(df[col]) or df[col].dtype == object:
            s = df[col].astype(str).str.strip()
            if s.str.contains("%", na=False).any():
                numeric_vals = pd.to_numeric(s.str.replace("%", "", regex=False), errors="coerce")
                if numeric_vals.notna().sum() / len(df) > 0.5:
                    # Store percentages as fractions (e.g. "25%" -> 0.25).
                    df[col] = numeric_vals / 100.0
                    continue
            cleaned = s.str.replace(",", "", regex=False).str.replace("₹", "", regex=False).str.replace("$", "", regex=False)
            numeric_vals = pd.to_numeric(cleaned, errors="coerce")
            if numeric_vals.notna().sum() / len(df) > 0.5:
                df[col] = numeric_vals
    return df
def display_data_format(format_type, selected_columns):
    """Return up to 100 preview rows when a DataFrame display is requested."""
    global uploaded_df
    if uploaded_df is None or format_type == "None":
        return None
    if format_type != "DataFrame":
        return None
    subset = uploaded_df[selected_columns] if selected_columns else uploaded_df
    return subset.head(100)
def display_text_format(format_type, selected_columns):
    """Render a JSON or dict text preview of the first 20 rows."""
    global uploaded_df
    if uploaded_df is None or format_type == "None":
        return ""
    subset = uploaded_df[selected_columns] if selected_columns else uploaded_df
    preview = subset.head(20)
    if format_type == "JSON":
        return preview.to_json(orient='records', indent=2)
    if format_type == "Dictionary":
        return str(preview.to_dict(orient='records'))
    return ""
def run_analysis(analysis_type, selected_columns):
    """Run one of the built-in analyses and return UI-ready results.

    Args:
        analysis_type: dropdown value naming the analysis to run.
        selected_columns: columns picked in the CheckboxGroup (may be empty).

    Returns:
        tuple: (result text, gr.update for the results heading visibility,
                optional DataFrame for the data table).
    """
    global uploaded_df
    if uploaded_df is None:
        return "Please upload a dataset first.", gr.update(visible=True), None
    if analysis_type == "None" or analysis_type is None:
        return "", gr.update(visible=False), None
    # These analyses always run on the full dataset; the rest require a selection.
    whole_dataset_analyses = ["Summary", "Top 5 Rows", "Bottom 5 Rows", "Missing Values"]
    if analysis_type in whole_dataset_analyses:
        df_to_analyze = uploaded_df
    else:
        if not selected_columns:
            return f"Please select columns for {analysis_type} analysis.", gr.update(visible=True), None
        df_to_analyze = uploaded_df[selected_columns]
    try:
        if analysis_type == "Summary":
            numeric_cols = uploaded_df.select_dtypes(include=[np.number]).columns
            categorical_cols = uploaded_df.select_dtypes(include=['object', 'category']).columns
            result = f"Dataset Summary:\nRows: {len(uploaded_df):,}\nColumns: {len(uploaded_df.columns)}\nNumeric Columns: {len(numeric_cols)}\nText Columns: {len(categorical_cols)}\n\n"
            if len(numeric_cols) > 0:
                result += "Numeric Columns: " + ", ".join(numeric_cols.tolist()) + "\n"
            if len(categorical_cols) > 0:
                result += "Text Columns: " + ", ".join(categorical_cols.tolist())
            return result, gr.update(visible=True), None
        elif analysis_type == "Describe":
            result = "Column Description:\n" + "=" * 30 + "\n\n"
            for col in selected_columns:
                if col in df_to_analyze.columns:
                    result += f"Column: {col}\n"
                    if pd.api.types.is_numeric_dtype(df_to_analyze[col]):
                        stats = df_to_analyze[col].describe()
                        result += f" Type: Numeric\n Count: {stats['count']:.0f}\n Mean: {stats['mean']:.3f}\n Std: {stats['std']:.3f}\n Min: {stats['min']:.3f}\n 25%: {stats['25%']:.3f}\n 50%: {stats['50%']:.3f}\n 75%: {stats['75%']:.3f}\n Max: {stats['max']:.3f}\n\n"
                    else:
                        unique_count = df_to_analyze[col].nunique()
                        null_count = df_to_analyze[col].isnull().sum()
                        # mode() can be empty (e.g. all-NaN column), hence the guard.
                        most_common = df_to_analyze[col].mode().iloc[0] if len(df_to_analyze[col].mode()) > 0 else "N/A"
                        result += f" Type: Categorical/Text\n Unique Values: {unique_count}\n Missing Values: {null_count}\n Most Common: {most_common}\n"
                        top_values = df_to_analyze[col].value_counts().head(5)
                        result += " Top Values:\n"
                        for val, count in top_values.items():
                            result += f" {val}: {count} times\n"
                        result += "\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Top 5 Rows":
            return "Top 5 Rows - See data table below", gr.update(visible=True), df_to_analyze.head(5)
        elif analysis_type == "Bottom 5 Rows":
            return "Bottom 5 Rows - See data table below", gr.update(visible=True), df_to_analyze.tail(5)
        elif analysis_type == "Missing Values":
            missing = df_to_analyze.isnull().sum()
            result = "Missing Values Analysis:\n" + "=" * 30 + "\n\n"
            for col in df_to_analyze.columns:
                missing_count = missing[col]
                missing_percent = (missing_count / len(df_to_analyze)) * 100
                result += f"{col}: {missing_count} missing ({missing_percent:.2f}%)\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Highest Correlation":
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) < 2:
                return "Need at least 2 numeric columns for correlation analysis.", gr.update(visible=True), None
            corr_matrix = df_to_analyze[numeric_cols].corr()
            result = "Highest Correlations:\n" + "=" * 25 + "\n\n"
            correlations = []
            # Collect each unique pair (upper triangle) keyed by |r| for ranking.
            for i in range(len(corr_matrix.columns)):
                for j in range(i+1, len(corr_matrix.columns)):
                    col1, col2 = corr_matrix.columns[i], corr_matrix.columns[j]
                    corr_val = corr_matrix.iloc[i, j]
                    correlations.append((abs(corr_val), col1, col2, corr_val))
            correlations.sort(reverse=True)
            # Report the 10 strongest correlations with their signed values.
            for _, col1, col2, corr_val in correlations[:10]:
                result += f"{col1} ↔ {col2}: {corr_val:.3f}\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Group & Aggregate":
            if not selected_columns:
                result = "Please select columns for grouping and aggregation."
            else:
                categorical_cols = [col for col in selected_columns if not pd.api.types.is_numeric_dtype(df_to_analyze[col])]
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_analyze[col])]
                if categorical_cols and numeric_cols:
                    # Group by the first categorical column, aggregate the first numeric.
                    group_col = categorical_cols[0]
                    agg_col = numeric_cols[0]
                    grouped = df_to_analyze.groupby(group_col)[agg_col].agg(['count', 'mean', 'sum']).round(2)
                    result = f"Group & Aggregate Analysis:\n" + "=" * 35 + "\n\n"
                    result += f"Grouped by: {group_col}\nAggregated: {agg_col}\n\n"
                    result += grouped.to_string()
                elif categorical_cols:
                    # No numeric column selected: fall back to value counts.
                    group_col = categorical_cols[0]
                    grouped = df_to_analyze[group_col].value_counts()
                    result = f"Group Count Analysis:\n" + "=" * 25 + "\n\n"
                    result += grouped.to_string()
                else:
                    result = "Please select at least one categorical column for grouping."
            return result, gr.update(visible=True), None
        elif analysis_type == "Calculate Expressions":
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) >= 2:
                # Demo calculation over the first two numeric columns only.
                col1, col2 = numeric_cols[0], numeric_cols[1]
                df_calc = df_to_analyze.copy()
                df_calc['Sum'] = df_calc[col1] + df_calc[col2]
                df_calc['Difference'] = df_calc[col1] - df_calc[col2]
                result = f"Calculated Expressions:\n" + "=" * 30 + "\n\n"
                result += f"Using columns: {col1} and {col2}\n\n"
                result += f"New calculated columns:\nSum = {col1} + {col2}\nDifference = {col1} - {col2}\n\n"
                result += "Sample results:\n"
                result += df_calc[['Sum', 'Difference']].head().to_string()
            else:
                result = "Need at least 2 numeric columns for calculations."
            return result, gr.update(visible=True), None
        else:
            return f"Analysis type '{analysis_type}' is under development.", gr.update(visible=True), None
    except Exception as e:
        return f"Error in analysis: {str(e)}", gr.update(visible=True), None
def create_chart_explanation(viz_type, df_to_plot, selected_columns, fig_data=None):
    """Build a plain-text narrative describing a generated chart."""
    if viz_type == "Bar Chart" and len(selected_columns) >= 2:
        x_col, y_col = selected_columns[0], selected_columns[1]
        if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
            # Numeric y: report the row holding the maximum value.
            peak_idx = df_to_plot[y_col].idxmax()
            max_category = df_to_plot.loc[peak_idx, x_col]
            max_value = df_to_plot[y_col].max()
            y_mean = df_to_plot[y_col].mean()
        else:
            # Non-numeric y: fall back to per-category record counts.
            counts = df_to_plot.groupby(x_col)[y_col].count()
            max_category = counts.idxmax()
            max_value = counts.max()
            y_mean = counts.mean()
        return (
            f"BAR CHART ANALYSIS: {y_col} by {x_col}\n\nKEY INSIGHTS:\n"
            f"• Highest bar: {max_category} with value {max_value:.2f}\n"
            f"• Average value: {y_mean:.2f}\n"
            f"• Categories analyzed: {df_to_plot[x_col].nunique()}\n\n"
            "INTERPRETATION:\n"
            f"• Each bar represents a {x_col} category\n"
            f"• Bar height shows {y_col} value for that category\n"
            "• Compare bars to identify highest/lowest performing categories"
        )
    if viz_type == "Line Chart" and fig_data is not None:
        max_combo = fig_data.loc[fig_data['Count'].idxmax()]
        min_combo = fig_data.loc[fig_data['Count'].idxmin()]
        return (
            "LINE CHART ANALYSIS: Distribution\n\nKEY INSIGHTS:\n"
            f"• Highest point: {max_combo[selected_columns[1]]} in {max_combo[selected_columns[0]]} ({max_combo['Count']} records)\n"
            f"• Lowest point: {min_combo[selected_columns[1]]} in {min_combo[selected_columns[0]]} ({min_combo['Count']} records)\n"
            f"• Total records analyzed: {len(df_to_plot)}\n\n"
            "INTERPRETATION:\n• Each line represents a category\n"
            "• Line height shows count of records\n"
            "• Compare lines to see patterns between variables"
        )
    return f"• {viz_type} showing data visualization\n• Use to understand data patterns and relationships"
def create_visualization(viz_type, selected_columns):
    """Build the requested Plotly chart from the uploaded dataset.

    Args:
        viz_type: chart type name from the visualization dropdown.
        selected_columns: columns chosen in the CheckboxGroup.

    Returns:
        tuple: (Plotly figure or None, explanation text,
                gr.update toggling chart visibility).
    """
    global uploaded_df
    if uploaded_df is None or viz_type == "None":
        return None, "", gr.update(visible=False)
    if not selected_columns:
        return None, "Please select columns for visualization.", gr.update(visible=False)
    df_to_plot = uploaded_df[selected_columns]
    try:
        if viz_type == "Bar Chart":
            if len(selected_columns) >= 2:
                # Two+ columns: x/y bar chart, optional third column as color.
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None
                fig = px.bar(df_to_plot.head(50), x=x_col, y=y_col, color=color_col, title=f"{y_col} by {x_col}")
                explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns)
            else:
                # Single column: histogram for numeric, top-10 counts otherwise.
                col = selected_columns[0]
                if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                    fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}")
                else:
                    value_counts = df_to_plot[col].value_counts().head(10)
                    fig = px.bar(x=value_counts.index, y=value_counts.values, title=f"Top Values in {col}")
                explanation = f"• Chart showing distribution of {col}"
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Pie Chart":
            col = selected_columns[0]
            if len(selected_columns) >= 2 and pd.api.types.is_numeric_dtype(df_to_plot[selected_columns[1]]):
                # Second numeric column present: slice sizes are its per-category sums.
                grouped_data = df_to_plot.groupby(col)[selected_columns[1]].sum().reset_index()
                fig = px.pie(grouped_data, values=selected_columns[1], names=col, title=f"Total {selected_columns[1]} by {col}")
            else:
                value_counts = df_to_plot[col].value_counts().head(10)
                fig = px.pie(values=value_counts.values, names=value_counts.index, title=f"Distribution of {col}")
            explanation = f"PIE CHART ANALYSIS: {col} Distribution\n\nKEY INSIGHTS:\n• Shows proportion of each category\n• Use to understand category distribution patterns"
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Scatter Plot":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None
                fig = px.scatter(df_to_plot, x=x_col, y=y_col, color=color_col, title=f"{y_col} vs {x_col}")
                explanation = f"• Scatter plot showing relationship between {x_col} and {y_col}"
            else:
                return None, "Scatter plot requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Line Chart":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                if not pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                    # Categorical y: plot per-category counts across x instead.
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    melted = pd.melt(crosstab.reset_index(), id_vars=[x_col], var_name=y_col, value_name='Count')
                    fig = px.line(melted, x=x_col, y='Count', color=y_col, title=f"Distribution of {y_col} across {x_col}", markers=True)
                    explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns, melted)
                else:
                    fig = px.line(df_to_plot.sort_values(x_col), x=x_col, y=y_col, title=f"Trend of {y_col} over {x_col}", markers=True)
                    explanation = f"• Line chart showing trend of {y_col} over {x_col}"
            else:
                return None, "Line chart requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Histogram":
            col = selected_columns[0]
            if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}", nbins=30)
                explanation = f"• Histogram showing distribution of {col}"
            else:
                return None, f"Histogram requires numeric data. Try Bar Chart instead.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Heat Map":
            if len(selected_columns) >= 2:
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_plot[col])]
                if len(numeric_cols) >= 2:
                    # Two+ numeric columns: correlation matrix heatmap.
                    corr_matrix = df_to_plot[numeric_cols].corr()
                    fig = px.imshow(corr_matrix, text_auto=True, aspect="auto", title="Correlation Heatmap", color_continuous_scale='RdBu')
                    explanation = f"• Heatmap showing correlations between numeric columns"
                else:
                    # Otherwise: cross-tabulation of the first two selections.
                    x_col, y_col = selected_columns[0], selected_columns[1]
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    fig = px.imshow(crosstab.values, x=crosstab.columns, y=crosstab.index, text_auto=True, aspect="auto", title=f"Cross-tabulation: {y_col} vs {x_col}")
                    explanation = f"• Heatmap showing cross-tabulation between {x_col} and {y_col}"
            else:
                return None, "Heat map requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Box Plot":
            y_col = selected_columns[0]
            if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                # Optional second column groups the boxes along x.
                x_col = selected_columns[1] if len(selected_columns) > 1 else None
                fig = px.box(df_to_plot, x=x_col, y=y_col, title=f"Box Plot of {y_col}")
                explanation = f"• Box plot showing distribution of {y_col}"
            else:
                return None, f"Box plot requires numeric data.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        else:
            return None, f"Visualization type '{viz_type}' is under development.", gr.update(visible=False)
    except Exception as e:
        return None, f"Error creating visualization: {str(e)}", gr.update(visible=False)
def parse_plan(raw_text):
    """Extract the JSON plan object from an LLM response string.

    Strips Markdown code fences, parses the outermost {...} span, and
    fills in defaults for any keys the model omitted. On any failure an
    error-typed plan is returned instead of raising.
    """
    txt = raw_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        # Take the outermost {...} span, ignoring any prose around it.
        body = txt[txt.index("{"):txt.rindex("}") + 1]
        plan = json.loads(body)
        for key, default in (("type", "analysis"), ("operations", []),
                             ("plot", None), ("narrative", ""),
                             ("insights_needed", False)):
            plan.setdefault(key, default)
        return plan
    except Exception as e:
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing response: {str(e)}",
            "insights_needed": False
        }
def execute_plan(df, plan):
    """Run the plan's operations against a copy of *df*.

    Supported ops: "describe", "groupby", "filter" (DataFrame.query),
    "calculate" (DataFrame.eval with a statistic fallback).

    Args:
        df: source DataFrame (never mutated).
        plan: dict produced by parse_plan() with an "operations" list.

    Returns:
        tuple: (working DataFrame after all operations,
                dict mapping column -> describe() Series for "describe" ops).

    Raises:
        Exception: wraps any failure with an "Execution error" message,
            chained to the original cause.
    """
    dfw = df.copy()
    describe_stats = {}
    try:
        for op in plan.get("operations", []):
            optype = op.get("op", "").lower()
            if optype == "describe":
                for col in op.get("columns", []):
                    if col in dfw.columns:
                        describe_stats[col] = dfw[col].describe()
            elif optype == "groupby":
                cols = op.get("columns", [])
                agg_col = op.get("agg_col")
                agg_func = op.get("agg_func", "count")
                if cols and all(c in dfw.columns for c in cols):
                    if agg_func == "count" or not agg_col:
                        dfw = dfw.groupby(cols).size().reset_index(name="count")
                    elif agg_col in dfw.columns:
                        result_col = f"{agg_func}_{agg_col}"
                        dfw = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
            elif optype == "filter":
                expr = op.get("expr", "")
                if expr:
                    # NOTE(review): expr originates from the LLM; query() is
                    # restricted but upstream validation is still advisable.
                    dfw = dfw.query(expr)
            elif optype == "calculate":
                expr = op.get("expr", "")
                new_col = op.get("new_col", "Calculated")
                if expr:
                    try:
                        dfw[new_col] = dfw.eval(expr)
                    except Exception:
                        # Narrowed from a bare except:. eval() cannot handle
                        # aggregations like std/mean; broadcast the column
                        # statistic to the named column instead.
                        if "std" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].std()
                                    break
                        elif "mean" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].mean()
                                    break
        return dfw, describe_stats
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise Exception(f"Execution error: {str(e)}") from e
def create_multi_column_plot(df, selected_columns, chart_type="bar", title="Multi-Column Chart"):
    """Plot the first two selected columns as a bar/scatter/line figure.

    Returns None when fewer than two valid columns are given or on any error.
    """
    try:
        if len(selected_columns) < 2:
            return None
        x_col, y_col = selected_columns[0], selected_columns[1]
        if x_col not in df.columns or y_col not in df.columns:
            return None
        if chart_type == "bar":
            fig = px.bar(df.head(50), x=x_col, y=y_col, title=title)
        elif chart_type == "scatter":
            fig = px.scatter(df.head(100), x=x_col, y=y_col, title=title)
        else:
            fig = px.line(df.head(50), x=x_col, y=y_col, title=title, markers=True)
        fig.update_layout(width=900, height=500)
        return fig
    except:
        return None
def create_simple_chart(df, selected_columns=None):
    """Best-effort fallback chart when no explicit plot spec is available.

    Preference order: selected x/y bar chart, single-selection histogram or
    top-values bar, then auto-detected scatter/bar/pie/histogram from dtypes.
    Returns None when nothing can be plotted or on any error.
    """
    try:
        if selected_columns and len(selected_columns) >= 2:
            x_col, y_col = selected_columns[0], selected_columns[1]
            if x_col in df.columns and y_col in df.columns:
                fig = px.bar(df.head(50), x=x_col, y=y_col, title=f"{y_col} by {x_col}")
                fig.update_layout(width=900, height=500)
                return fig
        if selected_columns and len(selected_columns) == 1:
            col = selected_columns[0]
            if col in df.columns:
                if pd.api.types.is_numeric_dtype(df[col]):
                    fig = px.histogram(df, x=col, title=f"Distribution of {col}")
                else:
                    counts = df[col].value_counts().head(10)
                    fig = px.bar(x=counts.index, y=counts.values, title=f"Top Values in {col}")
                fig.update_layout(width=900, height=500)
                return fig
        # No usable selection: infer a sensible default from column dtypes.
        text_cols = df.select_dtypes(include=['object', 'category']).columns
        num_cols = df.select_dtypes(include=[np.number]).columns
        if len(num_cols) > 1:
            fig = px.scatter(df.head(100), x=num_cols[0], y=num_cols[1], title=f"{num_cols[1]} vs {num_cols[0]}")
        elif len(text_cols) > 0 and len(num_cols) > 0:
            fig = px.bar(df.head(50), x=text_cols[0], y=num_cols[0], title=f"{num_cols[0]} by {text_cols[0]}")
        elif len(text_cols) > 0:
            counts = df[text_cols[0]].value_counts().head(10)
            fig = px.pie(values=counts.values, names=counts.index, title=f"Distribution of {text_cols[0]}")
        elif len(num_cols) > 0:
            fig = px.histogram(df, x=num_cols[0], title=f"Distribution of {num_cols[0]}")
        else:
            return None
        fig.update_layout(width=900, height=500)
        return fig
    except:
        return None
def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
    """Build a Plotly figure from the LLM plan's "plot" spec.

    Args:
        df: original (column-filtered) DataFrame.
        dfw: working DataFrame after execute_plan().
        plan: parsed plan dict; its "plot" key holds type/title/x/y.
        describe_stats: describe results; when non-empty, plotting uses the
            original df rather than the worked one.
        selected_columns: accepted for signature parity; not used here.

    Returns:
        Plotly figure, or None when the spec is missing/unusable or
        figure construction fails.
    """
    plot_spec = plan.get("plot")
    if not plot_spec:
        return None
    ptype = plot_spec.get("type", "bar")
    title = plot_spec.get("title", "Chart")
    # Describe ops leave dfw unsuited for plotting, so use the source df then.
    plot_df = df if describe_stats else dfw
    x = plot_spec.get("x")
    y = plot_spec.get("y")
    if not x and len(plot_df.columns) > 0:
        # Default x: first categorical column, else the first column overall.
        categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
        x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
    if not y:
        # Default y: first numeric column, if any.
        numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
        y = numeric_cols[0] if len(numeric_cols) > 0 else None
    try:
        if ptype == "pie" and x and x in plot_df.columns:
            value_counts = plot_df[x].value_counts()
            fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
            fig.update_layout(title=title, width=900, height=500)
            return fig
        elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.bar(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.line(plot_df, x=x, y=y, title=title, markers=True)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "hist" and y and y in plot_df.columns:
            # Histograms bin the y column along the x axis.
            fig = px.histogram(plot_df, x=y, title=title, nbins=30)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.scatter(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
    except:
        # Any plotting failure falls through to the caller's fallback charts.
        pass
    return None
def generate_insights(df, dfw, plan):
    """Ask the LLM for narrative insights about the executed plan's results.

    Builds a context string from describe statistics and grouped results,
    then sends it with INSIGHTS_SYSTEM_PROMPT to the global llm client.

    Returns:
        str: the model's insights text, or an error message on failure.
    """
    try:
        context_parts = []
        for op in plan.get("operations", []):
            if op.get("op") == "describe":
                cols = op.get("columns", [])
                for col in cols:
                    if col in df.columns:
                        desc = df[col].describe()
                        context_parts.append(f"\n{col} Statistics:\n{desc.to_string()}")
            elif op.get("op") == "groupby":
                # Only a 10-row sample of the grouped output is sent to the model.
                context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")
        insights_prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))
        response = llm.invoke([
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": insights_prompt}
        ])
        # langchain message objects expose .content; fall back to str() otherwise.
        return response.content if hasattr(response, 'content') else str(response)
    except Exception as e:
        return f"Error generating insights: {str(e)}"
def analyze_question(question, selected_columns):
    """Answer a free-form question about the dataset via the GROQ LLM.

    Builds a prompt from the (optionally column-filtered) dataset, asks the
    model for a JSON "plan", executes it, and renders text/chart/table output.

    Args:
        question: user's natural-language question.
        selected_columns: columns chosen in the CheckboxGroup (may be empty).

    Returns:
        tuple: (response text, optional Plotly figure, optional result DataFrame)
    """
    global uploaded_df, llm
    if llm is None:
        return "API not initialized. Please restart.", None, None
    if uploaded_df is None:
        return "Please upload a dataset first.", None, None
    if not question.strip():
        return "Please enter a question.", None, None
    try:
        df_to_analyze = uploaded_df[selected_columns] if selected_columns else uploaded_df
        # Keep the sample small so the prompt stays within token limits.
        sample_data = df_to_analyze.head(3).to_string(max_cols=10, max_colwidth=20)
        if selected_columns:
            column_context = f"Selected columns for analysis: {', '.join(selected_columns)}\n"
        else:
            column_context = ""
        # NOTE(review): data_ctx is built but never used below; enhanced_prompt
        # carries the dataset context instead — confirm before removing.
        data_ctx = f"""{column_context}Dataset: {len(df_to_analyze)} rows, {len(df_to_analyze.columns)} columns
Columns: {', '.join(df_to_analyze.columns)}
Sample data:
{sample_data}"""
        enhanced_prompt = get_chart_prompt(question, df_to_analyze.columns.tolist(), sample_data)
        messages = [
            {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
            {"role": "user", "content": enhanced_prompt}
        ]
        response = llm.invoke(messages)
        raw_text = response.content if hasattr(response, 'content') else str(response)
        plan = parse_plan(raw_text)
        # "explain" and "error" plans short-circuit with text only.
        if plan.get("type") == "explain":
            return plan.get("narrative", "No explanation provided"), None, None
        if plan.get("type") == "error":
            return plan.get("narrative", "Error occurred"), None, None
        if plan.get("plot"):
            # Sanitize the plot spec against the actual column names.
            plan["plot"] = validate_plot_spec(plan["plot"], df_to_analyze.columns.tolist())
        dfw, describe_stats = execute_plan(df_to_analyze, plan)
        response_text = plan.get("narrative", "Analysis complete")
        if describe_stats:
            response_text += "\n\nStatistical Summary:\n"
            for col, stats in describe_stats.items():
                response_text += f"\n{col}:\n{stats.to_string()}\n"
        # Chart generation with progressively simpler fallbacks.
        fig = None
        if plan.get("plot"):
            fig = create_plot(df_to_analyze, dfw, plan, describe_stats, selected_columns)
        if fig is None and selected_columns and len(selected_columns) >= 2:
            fig = create_multi_column_plot(df_to_analyze, selected_columns, "bar", "Multi-Column Chart")
        if fig is None:
            fig = create_simple_chart(df_to_analyze, selected_columns)
        if fig:
            response_text += "\n\nChart generated successfully!"
            if selected_columns and len(selected_columns) >= 1:
                response_text += f"\nUsing selected columns: {', '.join(selected_columns)}"
        if plan.get("insights_needed") and fig:
            insights = generate_insights(df_to_analyze, dfw, plan)
            response_text += f"\n\nKey Insights:\n{insights}"
        # Only show the worked table when operations actually changed the row count.
        result_table = None
        if not describe_stats and len(dfw) != len(df_to_analyze):
            result_table = dfw.head(50)
        return response_text, fig, result_table
    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None
def clear_dataset():
    """Reset the module-level dataset state and hide dataset-bound widgets."""
    global uploaded_df, dataset_name
    uploaded_df, dataset_name = None, None
    return (
        "Dataset cleared. Please upload a new file.",
        gr.update(visible=False),
        gr.update(choices=[], value=[]),
        gr.update(visible=False),
    )
# Page-level CSS injected into the Gradio Blocks app: centers the layout,
# styles the gradient header card, and boxes each section.
custom_css = """
.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
}
.header-box {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 15px;
    padding: 25px;
    margin: 20px auto;
    text-align: center;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
}
.header-title {
    font-size: 36px;
    font-weight: bold;
    color: white;
    margin: 0;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.section-box {
    background-color: #f8f9fa;
    padding: 20px;
    border-radius: 12px;
    margin: 15px 0;
    border: 1px solid #e9ecef;
}
"""
# Build the Gradio UI: upload/controls column on the left, previews and
# analysis outputs on the right, plus an LLM question box with tabbed results.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.HTML("""
    <div class="header-box">
        <h1 class="header-title">SparkNova</h1>
        <p style="color: white; font-size: 18px; margin: 10px 0 0 0;">Advanced Data Analysis Platform</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: upload, column selection, and analysis controls.
            gr.Markdown("### Upload Dataset")
            file_input = gr.File(label="Choose CSV or Excel File", file_types=[".csv", ".xlsx", ".xls"])
            dataset_info = gr.Markdown()
            with gr.Row():
                clear_btn = gr.Button("Clear Dataset", variant="secondary", size="sm")
            column_selector = gr.CheckboxGroup(
                label="Select Columns (optional - for multi-column charts)",
                choices=[],
                visible=False
            )
            format_selector = gr.Dropdown(
                choices=["None", "DataFrame", "JSON", "Dictionary"],
                value="None",
                label="Display Format"
            )
            gr.Markdown("### Choose an Analysis Type")
            analysis_selector = gr.Dropdown(
                choices=["None", "Summary", "Describe", "Top 5 Rows", "Bottom 5 Rows", "Missing Values", "Group & Aggregate", "Calculate Expressions", "Highest Correlation"],
                value="None",
                label="Analysis Type"
            )
            gr.Markdown("### Visualization Types")
            viz_selector = gr.Dropdown(
                choices=["None", "Bar Chart", "Line Chart", "Scatter Plot", "Pie Chart", "Histogram", "Box Plot", "Heat Map"],
                value="None",
                label="Chart Type"
            )
        with gr.Column(scale=2):
            # Right column: dataset previews and built-in analysis outputs.
            preview_heading = gr.Markdown("### Dataset Preview", visible=False)
            dataset_preview = gr.Dataframe(wrap=True, visible=False)
            text_preview = gr.Textbox(label="Text Preview", lines=15, visible=False)
            analysis_heading = gr.Markdown("### Analysis Results", visible=False)
            analysis_output = gr.Textbox(label="Analysis Output", lines=10, visible=False, interactive=False)
            analysis_data_table = gr.Dataframe(label="Data Table", visible=False, wrap=True)
            chart_output_new = gr.Plot(label="Chart", visible=False)
            chart_explanation = gr.Textbox(label="Chart Analysis", lines=5, visible=False, interactive=False)
    # Lay out the sample questions three per column.
    gr.Markdown("### Sample Questions")
    with gr.Row():
        for i in range(0, len(SAMPLE_QUESTIONS), 3):
            with gr.Column():
                for j in range(3):
                    if i + j < len(SAMPLE_QUESTIONS):
                        gr.Markdown(f"• {SAMPLE_QUESTIONS[i + j]}")
    gr.Markdown("### Ask Your Question")
    user_question = gr.Textbox(
        label="Enter your question",
        placeholder="Ask anything about your data...",
        lines=3
    )
    submit_btn = gr.Button("Analyze", variant="primary", size="lg")
    gr.Markdown("### Analysis Results")
    with gr.Tabs():
        with gr.Tab("Response"):
            output_text = gr.Textbox(
                label="Analysis Response",
                interactive=False,
                lines=15,
                show_copy_button=True
            )
        with gr.Tab("Visualization"):
            chart_output = gr.Plot(label="Generated Chart")
        with gr.Tab("Data"):
            result_table = gr.Dataframe(label="Result Data", wrap=True)
    # NOTE(review): several listeners below list the same component twice in
    # outputs (choices + visibility updates both targeting it). Some Gradio
    # versions reject duplicate output components — verify against the pinned
    # gradio version.
    file_input.change(
        upload_dataset,
        inputs=file_input,
        outputs=[dataset_info, dataset_preview, column_selector, column_selector]
    )
    clear_btn.click(
        clear_dataset,
        outputs=[dataset_info, dataset_preview, column_selector, column_selector]
    )
    def update_preview(format_type, selected_columns):
        # Route to table or text preview and toggle both widgets' visibility.
        if format_type == "None":
            return None, "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
        elif format_type == "DataFrame":
            return display_data_format(format_type, selected_columns), "", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
        else:
            return None, display_text_format(format_type, selected_columns), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
    def handle_analysis_change(analysis_type, selected_columns):
        # Run the selected analysis and show/hide the output widgets accordingly.
        result_text, heading_update, data_table = run_analysis(analysis_type, selected_columns)
        if result_text and result_text.strip() and analysis_type != "None":
            if data_table is not None:
                return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(value=data_table, visible=True)
            else:
                return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(value="", visible=False), gr.update(visible=False), gr.update(visible=False)
    analysis_selector.change(
        handle_analysis_change,
        inputs=[analysis_selector, column_selector],
        outputs=[analysis_output, analysis_heading, analysis_data_table]
    )
    def update_viz_display(viz_type, selected_columns):
        # Defensive unwrap of create_visualization's 3-tuple result.
        result = create_visualization(viz_type, selected_columns)
        if result and len(result) == 3:
            fig, explanation, chart_visible = result
            return fig, explanation, chart_visible
        else:
            return None, "Error in visualization", gr.update(visible=False)
    def handle_viz_change(viz_type, selected_columns):
        # Show the explanation box only when there is explanation text.
        fig, explanation, chart_visible = update_viz_display(viz_type, selected_columns)
        if explanation:
            return fig, chart_visible, explanation, gr.update(visible=True)
        else:
            return fig, chart_visible, "", gr.update(visible=False)
    viz_selector.change(
        handle_viz_change,
        inputs=[viz_selector, column_selector],
        outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation]
    )
    format_selector.change(
        update_preview,
        inputs=[format_selector, column_selector],
        outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading]
    )
    # Re-run preview, analysis, and chart whenever the column selection changes.
    column_selector.change(
        update_preview,
        inputs=[format_selector, column_selector],
        outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading]
    )
    column_selector.change(
        handle_analysis_change,
        inputs=[analysis_selector, column_selector],
        outputs=[analysis_output, analysis_heading, analysis_data_table]
    )
    column_selector.change(
        handle_viz_change,
        inputs=[viz_selector, column_selector],
        outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation]
    )
    submit_btn.click(
        analyze_question,
        inputs=[user_question, column_selector],
        outputs=[output_text, chart_output, result_table]
    )
    gr.HTML("<div style='text-align: center; margin-top: 20px; color: #666;'>Powered by GROQ LLM & Gradio</div>")
if __name__ == "__main__":
    # Best-effort LLM setup; the UI still launches for the non-LLM features.
    if not initialize_llm():
        print("Warning: Failed to initialize GROQ API")
    # share=True creates a public tunnel link in addition to the local server.
    demo.launch(show_error=True, share=True)