# Data_Analyzer / app.py — Hugging Face Space by Tamannathakur
# (non-code page chrome from the raw/history/blame web view converted to this
# comment header so the file is valid Python)
import gradio as gr
import pandas as pd
import numpy as np
import os
import json
import plotly.express as px
import plotly.graph_objects as go
from langchain_groq import ChatGroq
from prompts import ENHANCED_SYSTEM_PROMPT, SAMPLE_QUESTIONS, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt
# SECURITY: the original source hard-coded a live Groq API key here. A key
# committed to a public repo must be considered leaked and rotated; read it
# from the environment instead (set GROQ_API_KEY in the Space's secrets).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
# Lazily-created ChatGroq client (see initialize_llm()).
llm = None
# Currently loaded dataset (pandas DataFrame) and its source file name.
uploaded_df = None
dataset_name = None
def initialize_llm():
    """Create the global ChatGroq client.

    Returns:
        True if the client was constructed, False otherwise (e.g. missing
        or invalid API key). Never raises.
    """
    global llm
    try:
        # Some langchain integrations read the key from the environment,
        # so mirror it there as well as passing it explicitly.
        os.environ["GROQ_API_KEY"] = GROQ_API_KEY
        llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=GROQ_API_KEY, temperature=0.0)
        return True
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        return False
def upload_dataset(file):
    """Load a CSV/Excel upload into the global DataFrame.

    Returns a 4-tuple of Gradio updates:
    (info markdown text, preview visibility, column-selector choices,
    column-selector visibility).
    """
    global uploaded_df, dataset_name
    if file is None:
        return "No file uploaded", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
    try:
        dataset_name = os.path.basename(file.name)
        # Fix: compare extensions case-insensitively so .CSV / .XLSX uploads
        # are not rejected.
        lower_name = file.name.lower()
        if lower_name.endswith('.csv'):
            uploaded_df = pd.read_csv(file.name)
        elif lower_name.endswith(('.xlsx', '.xls')):
            uploaded_df = pd.read_excel(file.name)
        else:
            return "Unsupported file format. Please upload CSV or Excel files.", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
        # Best-effort conversion of "1,000" / "$5" / "50%" strings to numbers.
        uploaded_df = clean_numeric(uploaded_df)
        info_text = f"**Dataset Loaded:** {dataset_name} ({uploaded_df.shape[0]} rows × {uploaded_df.shape[1]} columns)"
        return info_text, gr.update(visible=False), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(visible=True)
    except Exception as e:
        return f"Error loading file: {str(e)}", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
def clean_numeric(df):
    """Best-effort conversion of string columns to numeric dtype.

    Percent columns ("50%") become fractions (0.5); currency/grouped
    numbers ("$1,000", "₹5,000") lose their symbols. A column is only
    converted when more than half of its values parse as numbers.
    Returns a new DataFrame; the input is not mutated.
    """
    df = df.copy()
    if len(df) == 0:
        # Fix: the >50% heuristics below divide by len(df); an empty
        # frame previously raised ZeroDivisionError.
        return df
    for col in df.columns:
        if pd.api.types.is_string_dtype(df[col]) or df[col].dtype == object:
            s = df[col].astype(str).str.strip()
            if s.str.contains("%", na=False).any():
                numeric_vals = pd.to_numeric(s.str.replace("%", "", regex=False), errors="coerce")
                if numeric_vals.notna().sum() / len(df) > 0.5:
                    # "50%" -> 0.5 so downstream math works on fractions.
                    df[col] = numeric_vals / 100.0
                    continue
            cleaned = s.str.replace(",", "", regex=False).str.replace("₹", "", regex=False).str.replace("$", "", regex=False)
            numeric_vals = pd.to_numeric(cleaned, errors="coerce")
            if numeric_vals.notna().sum() / len(df) > 0.5:
                df[col] = numeric_vals
    return df
def display_data_format(format_type, selected_columns):
    """Return the first 100 rows for the 'DataFrame' preview format.

    Any other format (or no dataset) yields None so the table stays empty.
    """
    global uploaded_df
    if uploaded_df is None or format_type == "None":
        return None
    if selected_columns:
        subset = uploaded_df[selected_columns]
    else:
        subset = uploaded_df
    if format_type != "DataFrame":
        return None
    return subset.head(100)
def display_text_format(format_type, selected_columns):
    """Render the first 20 rows as JSON or a Python-dict string.

    Returns "" for any other format or when no dataset is loaded.
    """
    global uploaded_df
    if uploaded_df is None or format_type == "None":
        return ""
    subset = uploaded_df[selected_columns] if selected_columns else uploaded_df
    head = subset.head(20)
    if format_type == "JSON":
        return head.to_json(orient='records', indent=2)
    if format_type == "Dictionary":
        return str(head.to_dict(orient='records'))
    return ""
def run_analysis(analysis_type, selected_columns):
    """Run the dropdown-selected analysis against the uploaded dataset.

    Returns a 3-tuple consumed by the Gradio handler:
    (result_text, heading-visibility gr.update, optional result DataFrame).
    """
    global uploaded_df
    if uploaded_df is None:
        return "Please upload a dataset first.", gr.update(visible=True), None
    if analysis_type == "None" or analysis_type is None:
        return "", gr.update(visible=False), None
    # These analyses always run on the full dataset; all others require a
    # user-selected column subset.
    whole_dataset_analyses = ["Summary", "Top 5 Rows", "Bottom 5 Rows", "Missing Values"]
    if analysis_type in whole_dataset_analyses:
        df_to_analyze = uploaded_df
    else:
        if not selected_columns:
            return f"Please select columns for {analysis_type} analysis.", gr.update(visible=True), None
        df_to_analyze = uploaded_df[selected_columns]
    try:
        if analysis_type == "Summary":
            # High-level shape/type overview of the whole dataset.
            numeric_cols = uploaded_df.select_dtypes(include=[np.number]).columns
            categorical_cols = uploaded_df.select_dtypes(include=['object', 'category']).columns
            result = f"Dataset Summary:\nRows: {len(uploaded_df):,}\nColumns: {len(uploaded_df.columns)}\nNumeric Columns: {len(numeric_cols)}\nText Columns: {len(categorical_cols)}\n\n"
            if len(numeric_cols) > 0:
                result += "Numeric Columns: " + ", ".join(numeric_cols.tolist()) + "\n"
            if len(categorical_cols) > 0:
                result += "Text Columns: " + ", ".join(categorical_cols.tolist())
            return result, gr.update(visible=True), None
        elif analysis_type == "Describe":
            # Per-column stats: describe() for numeric columns, value counts
            # and mode for text columns.
            result = "Column Description:\n" + "=" * 30 + "\n\n"
            for col in selected_columns:
                if col in df_to_analyze.columns:
                    result += f"Column: {col}\n"
                    if pd.api.types.is_numeric_dtype(df_to_analyze[col]):
                        stats = df_to_analyze[col].describe()
                        result += f" Type: Numeric\n Count: {stats['count']:.0f}\n Mean: {stats['mean']:.3f}\n Std: {stats['std']:.3f}\n Min: {stats['min']:.3f}\n 25%: {stats['25%']:.3f}\n 50%: {stats['50%']:.3f}\n 75%: {stats['75%']:.3f}\n Max: {stats['max']:.3f}\n\n"
                    else:
                        unique_count = df_to_analyze[col].nunique()
                        null_count = df_to_analyze[col].isnull().sum()
                        # mode() can return an empty Series (all-NaN column).
                        most_common = df_to_analyze[col].mode().iloc[0] if len(df_to_analyze[col].mode()) > 0 else "N/A"
                        result += f" Type: Categorical/Text\n Unique Values: {unique_count}\n Missing Values: {null_count}\n Most Common: {most_common}\n"
                        top_values = df_to_analyze[col].value_counts().head(5)
                        result += " Top Values:\n"
                        for val, count in top_values.items():
                            result += f" {val}: {count} times\n"
                        result += "\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Top 5 Rows":
            return "Top 5 Rows - See data table below", gr.update(visible=True), df_to_analyze.head(5)
        elif analysis_type == "Bottom 5 Rows":
            return "Bottom 5 Rows - See data table below", gr.update(visible=True), df_to_analyze.tail(5)
        elif analysis_type == "Missing Values":
            missing = df_to_analyze.isnull().sum()
            result = "Missing Values Analysis:\n" + "=" * 30 + "\n\n"
            for col in df_to_analyze.columns:
                missing_count = missing[col]
                missing_percent = (missing_count / len(df_to_analyze)) * 100
                result += f"{col}: {missing_count} missing ({missing_percent:.2f}%)\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Highest Correlation":
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) < 2:
                return "Need at least 2 numeric columns for correlation analysis.", gr.update(visible=True), None
            corr_matrix = df_to_analyze[numeric_cols].corr()
            result = "Highest Correlations:\n" + "=" * 25 + "\n\n"
            # Collect each unique pair once (upper triangle), rank by |r|.
            correlations = []
            for i in range(len(corr_matrix.columns)):
                for j in range(i+1, len(corr_matrix.columns)):
                    col1, col2 = corr_matrix.columns[i], corr_matrix.columns[j]
                    corr_val = corr_matrix.iloc[i, j]
                    correlations.append((abs(corr_val), col1, col2, corr_val))
            correlations.sort(reverse=True)
            for _, col1, col2, corr_val in correlations[:10]:
                # NOTE(review): the two column names run together here
                # ("col1col2"); a separator (e.g. " <-> ") appears to have
                # been lost in transit — confirm against the original file.
                result += f"{col1}{col2}: {corr_val:.3f}\n"
            return result, gr.update(visible=True), None
        elif analysis_type == "Group & Aggregate":
            if not selected_columns:
                result = "Please select columns for grouping and aggregation."
            else:
                # Group by the first categorical column; aggregate the first
                # numeric one (count/mean/sum).
                categorical_cols = [col for col in selected_columns if not pd.api.types.is_numeric_dtype(df_to_analyze[col])]
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_analyze[col])]
                if categorical_cols and numeric_cols:
                    group_col = categorical_cols[0]
                    agg_col = numeric_cols[0]
                    grouped = df_to_analyze.groupby(group_col)[agg_col].agg(['count', 'mean', 'sum']).round(2)
                    result = f"Group & Aggregate Analysis:\n" + "=" * 35 + "\n\n"
                    result += f"Grouped by: {group_col}\nAggregated: {agg_col}\n\n"
                    result += grouped.to_string()
                elif categorical_cols:
                    # No numeric column selected: fall back to value counts.
                    group_col = categorical_cols[0]
                    grouped = df_to_analyze[group_col].value_counts()
                    result = f"Group Count Analysis:\n" + "=" * 25 + "\n\n"
                    result += grouped.to_string()
                else:
                    result = "Please select at least one categorical column for grouping."
            return result, gr.update(visible=True), None
        elif analysis_type == "Calculate Expressions":
            # Demo feature: derives Sum/Difference of the first two numeric
            # columns on a copy (uploaded_df itself is not modified).
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) >= 2:
                col1, col2 = numeric_cols[0], numeric_cols[1]
                df_calc = df_to_analyze.copy()
                df_calc['Sum'] = df_calc[col1] + df_calc[col2]
                df_calc['Difference'] = df_calc[col1] - df_calc[col2]
                result = f"Calculated Expressions:\n" + "=" * 30 + "\n\n"
                result += f"Using columns: {col1} and {col2}\n\n"
                result += f"New calculated columns:\nSum = {col1} + {col2}\nDifference = {col1} - {col2}\n\n"
                result += "Sample results:\n"
                result += df_calc[['Sum', 'Difference']].head().to_string()
            else:
                result = "Need at least 2 numeric columns for calculations."
            return result, gr.update(visible=True), None
        else:
            return f"Analysis type '{analysis_type}' is under development.", gr.update(visible=True), None
    except Exception as e:
        return f"Error in analysis: {str(e)}", gr.update(visible=True), None
def create_chart_explanation(viz_type, df_to_plot, selected_columns, fig_data=None):
if viz_type == "Bar Chart" and len(selected_columns) >= 2:
x_col, y_col = selected_columns[0], selected_columns[1]
if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
max_val_idx = df_to_plot[y_col].idxmax()
max_category = df_to_plot.loc[max_val_idx, x_col]
max_value = df_to_plot[y_col].max()
y_mean = df_to_plot[y_col].mean()
else:
grouped = df_to_plot.groupby(x_col)[y_col].count()
max_category = grouped.idxmax()
max_value = grouped.max()
y_mean = grouped.mean()
return f"BAR CHART ANALYSIS: {y_col} by {x_col}\n\nKEY INSIGHTS:\n• Highest bar: {max_category} with value {max_value:.2f}\n• Average value: {y_mean:.2f}\n• Categories analyzed: {df_to_plot[x_col].nunique()}\n\nINTERPRETATION:\n• Each bar represents a {x_col} category\n• Bar height shows {y_col} value for that category\n• Compare bars to identify highest/lowest performing categories"
elif viz_type == "Line Chart" and fig_data is not None:
max_combo = fig_data.loc[fig_data['Count'].idxmax()]
min_combo = fig_data.loc[fig_data['Count'].idxmin()]
return f"LINE CHART ANALYSIS: Distribution\n\nKEY INSIGHTS:\n• Highest point: {max_combo[selected_columns[1]]} in {max_combo[selected_columns[0]]} ({max_combo['Count']} records)\n• Lowest point: {min_combo[selected_columns[1]]} in {min_combo[selected_columns[0]]} ({min_combo['Count']} records)\n• Total records analyzed: {len(df_to_plot)}\n\nINTERPRETATION:\n• Each line represents a category\n• Line height shows count of records\n• Compare lines to see patterns between variables"
return f"• {viz_type} showing data visualization\n• Use to understand data patterns and relationships"
def create_visualization(viz_type, selected_columns):
    """Build a Plotly figure for the dropdown-selected chart type.

    Returns (figure_or_None, explanation_text, chart-visibility gr.update).
    Column roles are positional: selected_columns[0] is x, [1] is y,
    [2] (when present) is the color dimension.
    """
    global uploaded_df
    if uploaded_df is None or viz_type == "None":
        return None, "", gr.update(visible=False)
    if not selected_columns:
        return None, "Please select columns for visualization.", gr.update(visible=False)
    df_to_plot = uploaded_df[selected_columns]
    try:
        if viz_type == "Bar Chart":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None
                # head(50) keeps the bar chart readable on wide datasets.
                fig = px.bar(df_to_plot.head(50), x=x_col, y=y_col, color=color_col, title=f"{y_col} by {x_col}")
                explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns)
            else:
                # Single column: histogram for numeric, top-10 counts otherwise.
                col = selected_columns[0]
                if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                    fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}")
                else:
                    value_counts = df_to_plot[col].value_counts().head(10)
                    fig = px.bar(x=value_counts.index, y=value_counts.values, title=f"Top Values in {col}")
                explanation = f"• Chart showing distribution of {col}"
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Pie Chart":
            col = selected_columns[0]
            if len(selected_columns) >= 2 and pd.api.types.is_numeric_dtype(df_to_plot[selected_columns[1]]):
                # Slice sizes = sum of the numeric second column per category.
                grouped_data = df_to_plot.groupby(col)[selected_columns[1]].sum().reset_index()
                fig = px.pie(grouped_data, values=selected_columns[1], names=col, title=f"Total {selected_columns[1]} by {col}")
            else:
                value_counts = df_to_plot[col].value_counts().head(10)
                fig = px.pie(values=value_counts.values, names=value_counts.index, title=f"Distribution of {col}")
            explanation = f"PIE CHART ANALYSIS: {col} Distribution\n\nKEY INSIGHTS:\n• Shows proportion of each category\n• Use to understand category distribution patterns"
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Scatter Plot":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None
                fig = px.scatter(df_to_plot, x=x_col, y=y_col, color=color_col, title=f"{y_col} vs {x_col}")
                explanation = f"• Scatter plot showing relationship between {x_col} and {y_col}"
            else:
                return None, "Scatter plot requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Line Chart":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                if not pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                    # Categorical y: plot one line per y-category of counts
                    # per x value (crosstab melted to long form).
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    melted = pd.melt(crosstab.reset_index(), id_vars=[x_col], var_name=y_col, value_name='Count')
                    fig = px.line(melted, x=x_col, y='Count', color=y_col, title=f"Distribution of {y_col} across {x_col}", markers=True)
                    explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns, melted)
                else:
                    # Numeric y: sort by x so the line reads as a trend.
                    fig = px.line(df_to_plot.sort_values(x_col), x=x_col, y=y_col, title=f"Trend of {y_col} over {x_col}", markers=True)
                    explanation = f"• Line chart showing trend of {y_col} over {x_col}"
            else:
                return None, "Line chart requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Histogram":
            col = selected_columns[0]
            if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}", nbins=30)
                explanation = f"• Histogram showing distribution of {col}"
            else:
                return None, f"Histogram requires numeric data. Try Bar Chart instead.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Heat Map":
            if len(selected_columns) >= 2:
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_plot[col])]
                if len(numeric_cols) >= 2:
                    # 2+ numeric columns: correlation heatmap.
                    corr_matrix = df_to_plot[numeric_cols].corr()
                    fig = px.imshow(corr_matrix, text_auto=True, aspect="auto", title="Correlation Heatmap", color_continuous_scale='RdBu')
                    explanation = f"• Heatmap showing correlations between numeric columns"
                else:
                    # Otherwise: category-vs-category count heatmap.
                    x_col, y_col = selected_columns[0], selected_columns[1]
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    fig = px.imshow(crosstab.values, x=crosstab.columns, y=crosstab.index, text_auto=True, aspect="auto", title=f"Cross-tabulation: {y_col} vs {x_col}")
                    explanation = f"• Heatmap showing cross-tabulation between {x_col} and {y_col}"
            else:
                return None, "Heat map requires at least 2 columns.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        elif viz_type == "Box Plot":
            y_col = selected_columns[0]
            if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                # Optional second column splits the boxes by category.
                x_col = selected_columns[1] if len(selected_columns) > 1 else None
                fig = px.box(df_to_plot, x=x_col, y=y_col, title=f"Box Plot of {y_col}")
                explanation = f"• Box plot showing distribution of {y_col}"
            else:
                return None, f"Box plot requires numeric data.", gr.update(visible=False)
            fig.update_layout(width=800, height=500)
            return fig, explanation, gr.update(visible=True)
        else:
            return None, f"Visualization type '{viz_type}' is under development.", gr.update(visible=False)
    except Exception as e:
        return None, f"Error creating visualization: {str(e)}", gr.update(visible=False)
def parse_plan(raw_text):
    """Parse the LLM reply into a plan dict, filling in missing defaults.

    Strips markdown code fences, extracts the outermost {...} span, and
    loads it as JSON. On any failure returns an "error"-typed plan whose
    narrative carries the exception message.
    """
    cleaned = raw_text.strip().replace("```json", "").replace("```", "").strip()
    defaults = {
        "type": "analysis",
        "operations": [],
        "plot": None,
        "narrative": "",
        "insights_needed": False,
    }
    try:
        # index/rindex raise ValueError when no braces are present.
        body = cleaned[cleaned.index("{"):cleaned.rindex("}") + 1]
        plan = json.loads(body)
        for key, fallback in defaults.items():
            plan.setdefault(key, fallback)
        return plan
    except Exception as e:
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing response: {str(e)}",
            "insights_needed": False
        }
def execute_plan(df, plan):
    """Apply the plan's operations to a copy of df.

    Supported ops: "describe", "groupby", "filter", "calculate".
    Returns (working_df, describe_stats) where describe_stats maps column
    name -> describe() Series for every described column.
    Raises Exception with a wrapped message on failure.
    """
    dfw = df.copy()
    describe_stats = {}
    try:
        for op in plan.get("operations", []):
            optype = op.get("op", "").lower()
            if optype == "describe":
                for col in op.get("columns", []):
                    if col in dfw.columns:
                        describe_stats[col] = dfw[col].describe()
            elif optype == "groupby":
                cols = op.get("columns", [])
                agg_col = op.get("agg_col")
                agg_func = op.get("agg_func", "count")
                if cols and all(c in dfw.columns for c in cols):
                    if agg_func == "count" or not agg_col:
                        dfw = dfw.groupby(cols).size().reset_index(name="count")
                    elif agg_col in dfw.columns:
                        result_col = f"{agg_func}_{agg_col}"
                        dfw = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
            elif optype == "filter":
                expr = op.get("expr", "")
                if expr:
                    # NOTE(security): expr comes from the LLM and is
                    # evaluated by DataFrame.query(); acceptable only because
                    # it is this app's own model output, not arbitrary user
                    # input. Do not feed untrusted strings here.
                    dfw = dfw.query(expr)
            elif optype == "calculate":
                expr = op.get("expr", "")
                new_col = op.get("new_col", "Calculated")
                if expr:
                    try:
                        dfw[new_col] = dfw.eval(expr)
                    except Exception:  # was a bare except
                        # DataFrame.eval() cannot handle aggregations such as
                        # std/mean; broadcast the aggregate of the first
                        # numeric column mentioned in the expression instead.
                        if "std" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].std()
                                    break
                        elif "mean" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].mean()
                                    break
                        # Otherwise skip silently: calculation is best-effort.
        return dfw, describe_stats
    except Exception as e:
        # Chain the original exception for debuggability.
        raise Exception(f"Execution error: {str(e)}") from e
def create_multi_column_plot(df, selected_columns, chart_type="bar", title="Multi-Column Chart"):
    """Plot the first two selected columns against each other.

    chart_type: "bar" (first 50 rows), "scatter" (first 100 rows); anything
    else falls back to a markers line chart (first 50 rows).
    Returns a Plotly figure, or None when columns are missing/insufficient
    or plotting fails.
    """
    try:
        if len(selected_columns) >= 2:
            x_col, y_col = selected_columns[0], selected_columns[1]
            if x_col in df.columns and y_col in df.columns:
                if chart_type == "bar":
                    fig = px.bar(df.head(50), x=x_col, y=y_col, title=title)
                elif chart_type == "scatter":
                    fig = px.scatter(df.head(100), x=x_col, y=y_col, title=title)
                else:
                    fig = px.line(df.head(50), x=x_col, y=y_col, title=title, markers=True)
                fig.update_layout(width=900, height=500)
                return fig
        return None
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        return None
def create_simple_chart(df, selected_columns=None):
    """Fallback chart when the LLM plan produced no usable plot spec.

    Preference order: two selected columns (bar) > one selected column
    (histogram or top-10 bar) > auto-detected column pairs from dtypes.
    Returns a Plotly figure or None when nothing can be plotted.
    """
    try:
        if selected_columns and len(selected_columns) >= 2:
            x_col, y_col = selected_columns[0], selected_columns[1]
            if x_col in df.columns and y_col in df.columns:
                fig = px.bar(df.head(50), x=x_col, y=y_col, title=f"{y_col} by {x_col}")
                fig.update_layout(width=900, height=500)
                return fig
        if selected_columns and len(selected_columns) == 1:
            col = selected_columns[0]
            if col in df.columns:
                if pd.api.types.is_numeric_dtype(df[col]):
                    fig = px.histogram(df, x=col, title=f"Distribution of {col}")
                else:
                    value_counts = df[col].value_counts().head(10)
                    fig = px.bar(x=value_counts.index, y=value_counts.values, title=f"Top Values in {col}")
                fig.update_layout(width=900, height=500)
                return fig
        # No usable selection: pick columns automatically by dtype.
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) > 1:
            fig = px.scatter(df.head(100), x=numeric_cols[0], y=numeric_cols[1], title=f"{numeric_cols[1]} vs {numeric_cols[0]}")
        elif len(categorical_cols) > 0 and len(numeric_cols) > 0:
            fig = px.bar(df.head(50), x=categorical_cols[0], y=numeric_cols[0], title=f"{numeric_cols[0]} by {categorical_cols[0]}")
        elif len(categorical_cols) > 0:
            value_counts = df[categorical_cols[0]].value_counts().head(10)
            fig = px.pie(values=value_counts.values, names=value_counts.index, title=f"Distribution of {categorical_cols[0]}")
        elif len(numeric_cols) > 0:
            fig = px.histogram(df, x=numeric_cols[0], title=f"Distribution of {numeric_cols[0]}")
        else:
            return None
        fig.update_layout(width=900, height=500)
        return fig
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        return None
def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
    """Build a Plotly figure from the LLM plan's "plot" spec.

    df is the pre-operation frame, dfw the post-operation one. When
    describe stats were produced the operations did not reshape the data,
    so the original frame is plotted; otherwise the transformed one.
    Returns a figure, or None when the spec is missing or unusable.
    """
    plot_spec = plan.get("plot")
    if not plot_spec:
        return None
    ptype = plot_spec.get("type", "bar")
    title = plot_spec.get("title", "Chart")
    plot_df = df if describe_stats else dfw
    x = plot_spec.get("x")
    y = plot_spec.get("y")
    # Axis fallbacks: first categorical column for x, first numeric for y.
    if not x and len(plot_df.columns) > 0:
        categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
        x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
    if not y:
        numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
        y = numeric_cols[0] if len(numeric_cols) > 0 else None
    try:
        if ptype == "pie" and x and x in plot_df.columns:
            value_counts = plot_df[x].value_counts()
            fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
            fig.update_layout(title=title, width=900, height=500)
            return fig
        elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.bar(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.line(plot_df, x=x, y=y, title=title, markers=True)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "hist" and y and y in plot_df.columns:
            # Histograms put the numeric column (y) on the x-axis.
            fig = px.histogram(plot_df, x=y, title=title, nbins=30)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.scatter(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
    except:
        # Best-effort: any plotting failure falls through to None so the
        # caller can try the generic fallback charts.
        pass
    return None
def generate_insights(df, dfw, plan):
    """Ask the LLM for narrative insights about an executed plan.

    Builds a text context from describe() stats and grouped results,
    then sends it with the plan's narrative. Returns the LLM reply text,
    or an error string on any failure.
    """
    try:
        context_parts = []
        for op in plan.get("operations", []):
            kind = op.get("op")
            if kind == "describe":
                for col in op.get("columns", []):
                    if col in df.columns:
                        summary = df[col].describe()
                        context_parts.append(f"\n{col} Statistics:\n{summary.to_string()}")
            elif kind == "groupby":
                context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")
        prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))
        messages = [
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        reply = llm.invoke(messages)
        if hasattr(reply, 'content'):
            return reply.content
        return str(reply)
    except Exception as e:
        return f"Error generating insights: {str(e)}"
def analyze_question(question, selected_columns):
    """End-to-end LLM-driven analysis of a free-text question.

    Prompts the LLM for a JSON plan, executes it against the (optionally
    column-filtered) dataset, and returns the tuple the submit handler
    expects: (response_text, figure_or_None, result_table_or_None).
    """
    global uploaded_df, llm
    if llm is None:
        return "API not initialized. Please restart.", None, None
    if uploaded_df is None:
        return "Please upload a dataset first.", None, None
    if not question.strip():
        return "Please enter a question.", None, None
    try:
        df_to_analyze = uploaded_df[selected_columns] if selected_columns else uploaded_df
        # Small sample keeps the prompt short while showing real values.
        sample_data = df_to_analyze.head(3).to_string(max_cols=10, max_colwidth=20)
        if selected_columns:
            column_context = f"Selected columns for analysis: {', '.join(selected_columns)}\n"
        else:
            column_context = ""
        # NOTE(review): data_ctx is assembled but never used below —
        # get_chart_prompt() receives the columns/sample directly. Looks
        # like dead code; confirm before removing.
        data_ctx = f"""{column_context}Dataset: {len(df_to_analyze)} rows, {len(df_to_analyze.columns)} columns
Columns: {', '.join(df_to_analyze.columns)}
Sample data:
{sample_data}"""
        enhanced_prompt = get_chart_prompt(question, df_to_analyze.columns.tolist(), sample_data)
        messages = [
            {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
            {"role": "user", "content": enhanced_prompt}
        ]
        response = llm.invoke(messages)
        raw_text = response.content if hasattr(response, 'content') else str(response)
        plan = parse_plan(raw_text)
        # "explain" and "error" plans short-circuit with text only.
        if plan.get("type") == "explain":
            return plan.get("narrative", "No explanation provided"), None, None
        if plan.get("type") == "error":
            return plan.get("narrative", "Error occurred"), None, None
        if plan.get("plot"):
            # Sanitize the plot spec against the actual column names
            # (validate_plot_spec is defined in prompts.py).
            plan["plot"] = validate_plot_spec(plan["plot"], df_to_analyze.columns.tolist())
        dfw, describe_stats = execute_plan(df_to_analyze, plan)
        response_text = plan.get("narrative", "Analysis complete")
        if describe_stats:
            response_text += "\n\nStatistical Summary:\n"
            for col, stats in describe_stats.items():
                response_text += f"\n{col}:\n{stats.to_string()}\n"
        # Chart fallback chain: plan spec -> two-column bar -> auto chart.
        fig = None
        if plan.get("plot"):
            fig = create_plot(df_to_analyze, dfw, plan, describe_stats, selected_columns)
        if fig is None and selected_columns and len(selected_columns) >= 2:
            fig = create_multi_column_plot(df_to_analyze, selected_columns, "bar", "Multi-Column Chart")
        if fig is None:
            fig = create_simple_chart(df_to_analyze, selected_columns)
        if fig:
            response_text += "\n\nChart generated successfully!"
            if selected_columns and len(selected_columns) >= 1:
                response_text += f"\nUsing selected columns: {', '.join(selected_columns)}"
        if plan.get("insights_needed") and fig:
            insights = generate_insights(df_to_analyze, dfw, plan)
            response_text += f"\n\nKey Insights:\n{insights}"
        # Only show a table when the operations actually reshaped the data.
        result_table = None
        if not describe_stats and len(dfw) != len(df_to_analyze):
            result_table = dfw.head(50)
        return response_text, fig, result_table
    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None
def clear_dataset():
    """Drop the loaded dataset and reset the dependent UI components.

    Returns the same 4-tuple shape as upload_dataset so both can share
    one outputs list.
    """
    global uploaded_df, dataset_name
    uploaded_df = None
    dataset_name = None
    return (
        "Dataset cleared. Please upload a new file.",
        gr.update(visible=False),
        gr.update(choices=[], value=[]),
        gr.update(visible=False),
    )
custom_css = """
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
}
.header-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 15px;
padding: 25px;
margin: 20px auto;
text-align: center;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
}
.header-title {
font-size: 36px;
font-weight: bold;
color: white;
margin: 0;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.section-box {
background-color: #f8f9fa;
padding: 20px;
border-radius: 12px;
margin: 15px 0;
border: 1px solid #e9ecef;
}
"""
# ---------------------------------------------------------------------------
# UI definition and event wiring.
# NOTE(review): the nesting below was reconstructed from a whitespace-mangled
# source copy — confirm the layout against the deployed Space.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Gradient page header.
    gr.HTML("""
    <div class="header-box">
        <h1 class="header-title">SparkNova</h1>
        <p style="color: white; font-size: 18px; margin: 10px 0 0 0;">Advanced Data Analysis Platform</p>
    </div>
    """)
    with gr.Row():
        # Left column: upload controls and analysis/visualization selectors.
        with gr.Column(scale=1):
            gr.Markdown("### Upload Dataset")
            file_input = gr.File(label="Choose CSV or Excel File", file_types=[".csv", ".xlsx", ".xls"])
            dataset_info = gr.Markdown()
            with gr.Row():
                clear_btn = gr.Button("Clear Dataset", variant="secondary", size="sm")
            # Hidden until a dataset is loaded (see upload_dataset()).
            column_selector = gr.CheckboxGroup(
                label="Select Columns (optional - for multi-column charts)",
                choices=[],
                visible=False
            )
            format_selector = gr.Dropdown(
                choices=["None", "DataFrame", "JSON", "Dictionary"],
                value="None",
                label="Display Format"
            )
            gr.Markdown("### Choose an Analysis Type")
            analysis_selector = gr.Dropdown(
                choices=["None", "Summary", "Describe", "Top 5 Rows", "Bottom 5 Rows", "Missing Values", "Group & Aggregate", "Calculate Expressions", "Highest Correlation"],
                value="None",
                label="Analysis Type"
            )
            gr.Markdown("### Visualization Types")
            viz_selector = gr.Dropdown(
                choices=["None", "Bar Chart", "Line Chart", "Scatter Plot", "Pie Chart", "Histogram", "Box Plot", "Heat Map"],
                value="None",
                label="Chart Type"
            )
        # Right column: previews and analysis/visualization outputs.
        with gr.Column(scale=2):
            preview_heading = gr.Markdown("### Dataset Preview", visible=False)
            dataset_preview = gr.Dataframe(wrap=True, visible=False)
            text_preview = gr.Textbox(label="Text Preview", lines=15, visible=False)
            analysis_heading = gr.Markdown("### Analysis Results", visible=False)
            analysis_output = gr.Textbox(label="Analysis Output", lines=10, visible=False, interactive=False)
            analysis_data_table = gr.Dataframe(label="Data Table", visible=False, wrap=True)
            chart_output_new = gr.Plot(label="Chart", visible=False)
            chart_explanation = gr.Textbox(label="Chart Analysis", lines=5, visible=False, interactive=False)
    gr.Markdown("### Sample Questions")
    with gr.Row():
        # Lay the sample questions out three-per-column.
        for i in range(0, len(SAMPLE_QUESTIONS), 3):
            with gr.Column():
                for j in range(3):
                    if i + j < len(SAMPLE_QUESTIONS):
                        gr.Markdown(f"• {SAMPLE_QUESTIONS[i + j]}")
    gr.Markdown("### Ask Your Question")
    user_question = gr.Textbox(
        label="Enter your question",
        placeholder="Ask anything about your data...",
        lines=3
    )
    submit_btn = gr.Button("Analyze", variant="primary", size="lg")
    gr.Markdown("### Analysis Results")
    with gr.Tabs():
        with gr.Tab("Response"):
            output_text = gr.Textbox(
                label="Analysis Response",
                interactive=False,
                lines=15,
                show_copy_button=True
            )
        with gr.Tab("Visualization"):
            chart_output = gr.Plot(label="Generated Chart")
        with gr.Tab("Data"):
            result_table = gr.Dataframe(label="Result Data", wrap=True)
    # --- Event wiring ------------------------------------------------------
    # NOTE(review): several outputs lists below repeat the same component
    # (column_selector / chart_output_new / chart_explanation twice, and
    # dataset_preview / text_preview in update_preview) so one callback
    # return can drive two updates of the same component; some Gradio
    # versions reject duplicate outputs — TODO confirm on the pinned version.
    file_input.change(
        upload_dataset,
        inputs=file_input,
        outputs=[dataset_info, dataset_preview, column_selector, column_selector]
    )
    clear_btn.click(
        clear_dataset,
        outputs=[dataset_info, dataset_preview, column_selector, column_selector]
    )
    def update_preview(format_type, selected_columns):
        # Route the chosen display format to either the table or text preview,
        # toggling visibility of both plus the heading.
        if format_type == "None":
            return None, "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
        elif format_type == "DataFrame":
            return display_data_format(format_type, selected_columns), "", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
        else:
            return None, display_text_format(format_type, selected_columns), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
    def handle_analysis_change(analysis_type, selected_columns):
        # Show/hide the analysis outputs depending on whether run_analysis
        # produced text and/or a data table.
        result_text, heading_update, data_table = run_analysis(analysis_type, selected_columns)
        if result_text and result_text.strip() and analysis_type != "None":
            if data_table is not None:
                return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(value=data_table, visible=True)
            else:
                return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(value="", visible=False), gr.update(visible=False), gr.update(visible=False)
    analysis_selector.change(
        handle_analysis_change,
        inputs=[analysis_selector, column_selector],
        outputs=[analysis_output, analysis_heading, analysis_data_table]
    )
    def update_viz_display(viz_type, selected_columns):
        # Defensive unpack of create_visualization's 3-tuple.
        result = create_visualization(viz_type, selected_columns)
        if result and len(result) == 3:
            fig, explanation, chart_visible = result
            return fig, explanation, chart_visible
        else:
            return None, "Error in visualization", gr.update(visible=False)
    def handle_viz_change(viz_type, selected_columns):
        # Returns (figure, chart visibility, explanation text, explanation
        # visibility) matching the duplicated outputs list below.
        fig, explanation, chart_visible = update_viz_display(viz_type, selected_columns)
        if explanation:
            return fig, chart_visible, explanation, gr.update(visible=True)
        else:
            return fig, chart_visible, "", gr.update(visible=False)
    viz_selector.change(
        handle_viz_change,
        inputs=[viz_selector, column_selector],
        outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation]
    )
    format_selector.change(
        update_preview,
        inputs=[format_selector, column_selector],
        outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading]
    )
    # Changing the column selection refreshes all three result areas.
    column_selector.change(
        update_preview,
        inputs=[format_selector, column_selector],
        outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading]
    )
    column_selector.change(
        handle_analysis_change,
        inputs=[analysis_selector, column_selector],
        outputs=[analysis_output, analysis_heading, analysis_data_table]
    )
    column_selector.change(
        handle_viz_change,
        inputs=[viz_selector, column_selector],
        outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation]
    )
    submit_btn.click(
        analyze_question,
        inputs=[user_question, column_selector],
        outputs=[output_text, chart_output, result_table]
    )
    gr.HTML("<div style='text-align: center; margin-top: 20px; color: #666;'>Powered by GROQ LLM & Gradio</div>")
if __name__ == "__main__":
    # The app still launches without the LLM; only the free-text question
    # feature will report "API not initialized".
    if not initialize_llm():
        print("Warning: Failed to initialize GROQ API")
    # share=True exposes a public Gradio tunnel URL; show_error surfaces
    # Python tracebacks in the browser UI.
    demo.launch(show_error=True, share=True)