Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import textwrap | |
| import tempfile | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import gradio as gr | |
| from openai import OpenAI | |
| # --------- OpenAI client helper --------- | |
def get_client(api_key: str = None):
    """Build an OpenAI client from an explicit key or the environment.

    Args:
        api_key: Key supplied by the caller. When it is falsy (``None`` or an
            empty string) the ``OPENAI_API_KEY`` environment variable is used
            instead.

    Returns:
        An ``OpenAI`` client constructed with the resolved key.

    Raises:
        ValueError: If neither the argument nor the environment provides a
            non-empty key.
    """
    resolved_key = api_key if api_key else os.getenv("OPENAI_API_KEY")
    if resolved_key:
        return OpenAI(api_key=resolved_key)
    raise ValueError(
        "OpenAI API key not provided. "
        "Either set OPENAI_API_KEY env var or pass it in the UI."
    )
| # --------- Data summarisation helpers --------- | |
def summarize_dataframe(df: pd.DataFrame, max_cols=15, max_rows=5) -> str:
    """Render a markdown-style text summary of a DataFrame for the LLM prompt.

    The summary has up to five sections: basic shape, dtype counts, a
    per-column summary, the strongest absolute correlations between numeric
    columns (only when there are at least two), and the first rows.

    Args:
        df: The dataset to describe.
        max_cols: Maximum number of leading columns to describe individually.
        max_rows: Number of leading rows to include as a sample table.

    Returns:
        The summary as a single newline-joined string.
    """
    buf = [
        "### 1. Basic Structure",
        f"- Number of rows: {df.shape[0]}",
        f"- Number of columns: {df.shape[1]}",
        "",
        "### 2. Column Types",
    ]
    for dtype_name, count in df.dtypes.astype(str).value_counts().items():
        buf.append(f"- {dtype_name}: {count} columns")
    buf.append("")

    buf.append("### 3. Column-wise Summary")
    for col in df.columns[:max_cols]:
        buf.append("\n".join(_summarize_column(col, df[col])))
        buf.append("")
    if df.shape[1] > max_cols:
        buf.append(f"... ({df.shape[1] - max_cols} more columns not listed here)")
        buf.append("")

    buf.extend(_top_correlation_lines(df))

    buf.append("### 5. Sample Rows")
    sample = df.head(max_rows)
    try:
        table = sample.to_markdown(index=False)
    except ImportError:
        # `to_markdown` relies on the optional `tabulate` package; fall back to
        # a plain-text table rather than failing the whole summary.
        table = sample.to_string(index=False)
    buf.append(table)
    return "\n".join(buf)


def _summarize_column(col, series):
    """Return the bullet lines describing one column (dtype, missingness, stats)."""
    is_missing = series.isna()
    lines = [
        f"**Column:** {col}",
        f"- dtype: {series.dtype}",
        f"- Missing values: {is_missing.sum()} ({is_missing.mean():.2%} of rows)",
    ]
    # Booleans count as numeric for pandas, but `describe()` on a bool column
    # yields count/unique/top/freq (no min/mean/quantiles), so they are
    # summarised like categorical data instead.
    has_numeric_stats = (
        pd.api.types.is_numeric_dtype(series)
        and not pd.api.types.is_bool_dtype(series)
    )
    if has_numeric_stats:
        stats = series.describe()
        rendered = ", ".join(
            f"{label}={stats[label]:.4g}"
            for label in ("min", "25%", "mean", "50%", "75%", "max")
        )
        lines.append(f"- Stats: {rendered}")
    else:
        top_values = series.value_counts(dropna=True).head(5)
        lines.append(f"- Unique values (non-null): {series.nunique(dropna=True)}")
        listing = ", ".join(f"{value} ({count})" for value, count in top_values.items())
        lines.append(f"- Top values: {listing}")
    return lines


def _top_correlation_lines(df, limit=10):
    """Return the correlation section lines, or [] with fewer than two numeric columns."""
    numeric = df.select_dtypes(include=[np.number])
    names = numeric.columns
    if len(names) < 2:
        return []

    corr = numeric.corr().abs()
    # Upper-triangle pairs only. Undefined (NaN) correlations, e.g. from
    # constant columns, are dropped: NaN breaks the ordering that sort relies
    # on and carries no information for the report.
    pairs = [
        (names[i], names[j], corr.iloc[i, j])
        for i in range(len(names))
        for j in range(i + 1, len(names))
        if not pd.isna(corr.iloc[i, j])
    ]
    pairs.sort(key=lambda pair: pair[2], reverse=True)

    lines = ["### 4. Numeric Correlations (Top pairs)"]
    if pairs:
        lines.extend(
            f"- {left} vs {right}: correlation={value:.3f}"
            for left, right, value in pairs[:limit]
        )
    else:
        lines.append("- No well-defined correlations (constant or empty numeric columns)")
    lines.append("")
    return lines
| # --------- Plotting helpers --------- | |
def make_distribution_plots(df: pd.DataFrame, max_numeric=4, max_categorical=4):
    """Create overview figures for a DataFrame.

    Produces, in this order: a histogram with KDE for each of the first
    ``max_numeric`` numeric columns, a horizontal bar chart of the 15 most
    frequent values for each of the first ``max_categorical`` non-numeric
    columns, and a correlation heatmap when at least two numeric columns exist.

    Args:
        df: The dataset to plot.
        max_numeric: Maximum number of numeric columns to plot.
        max_categorical: Maximum number of non-numeric columns to plot.

    Returns:
        A list of Matplotlib figures. The figures are created through pyplot
        and are not closed here.
    """
    figures = []
    numeric_frame = df.select_dtypes(include=[np.number])
    other_frame = df.select_dtypes(exclude=[np.number])

    # One histogram (with KDE overlay) per selected numeric column.
    for name in numeric_frame.columns[:max_numeric]:
        fig, ax = plt.subplots()
        sns.histplot(df[name].dropna(), kde=True, ax=ax)
        ax.set(title=f"Distribution of {name}", xlabel=name, ylabel="Count")
        fig.tight_layout()
        figures.append(fig)

    # One bar chart of the most frequent values per selected non-numeric column.
    for name in other_frame.columns[:max_categorical]:
        fig, ax = plt.subplots()
        top_counts = df[name].value_counts().head(15)
        sns.barplot(x=top_counts.values, y=top_counts.index, ax=ax)
        ax.set(title=f"Top categories in {name}", xlabel="Count", ylabel=name)
        fig.tight_layout()
        figures.append(fig)

    # A heatmap is only meaningful with two or more numeric columns.
    if numeric_frame.shape[1] >= 2:
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.heatmap(numeric_frame.corr(), annot=False, cmap="coolwarm", ax=ax)
        ax.set_title("Correlation Heatmap (Numeric Features)")
        fig.tight_layout()
        figures.append(fig)

    return figures
| # --------- OpenAI analysis --------- | |
def generate_ai_report(df_summary: str, api_key: str = None, model: str = "gpt-4o-mini") -> str:
    """Ask an OpenAI chat model for a structured analysis of a dataset summary.

    Args:
        df_summary: Text summary of the dataset, embedded in the user message.
        api_key: Optional OpenAI key, forwarded to ``get_client``.
        model: Name of the chat model to call.

    Returns:
        The text of the first completion choice, or an empty string when the
        model returned no text content.

    Raises:
        ValueError: Propagated from ``get_client`` when no API key is available.
    """
    client = get_client(api_key)
    system_msg = (
        "You are a senior data analyst. You receive a structured summary of a dataset. "
        "Your job is to produce a VERY detailed, structured analysis report.\n\n"
        "Your report MUST include at least these sections:\n"
        "1. Dataset Overview\n"
        "2. Data Quality & Missing Values\n"
        "3. Univariate Analysis\n"
        "4. Bivariate & Correlation Insights\n"
        "5. Target Variables & Use Cases\n"
        "6. Feature Engineering Ideas\n"
        "7. Recommended Visualizations\n"
        "8. Risks, Biases & Limitations\n"
        "9. Next Steps for Modelling\n"
    )
    user_msg = (
        "Here is a detailed summary of the dataset. Use ONLY this information while reasoning:\n\n"
        f"{df_summary}"
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        max_tokens=2000,
        temperature=0.7,
    )
    # `content` may be None (for example when the model refuses); normalise to
    # an empty string so callers always receive a str.
    return response.choices[0].message.content or ""
| # --------- Main Gradio function --------- | |
def analyze_dataset(file, api_key, model_name, sample_rows, max_cols_summary):
    """Run the full pipeline for one uploaded CSV: summary, AI report, plots.

    Args:
        file: The upload from the UI. Either a filesystem path or an object
            whose ``name`` attribute holds the path; ``None`` when nothing was
            uploaded.
        api_key: Optional OpenAI key forwarded to ``generate_ai_report``.
        model_name: Chat model name forwarded to ``generate_ai_report``.
        sample_rows: When truthy and smaller than the row count, the data is
            down-sampled to this many rows (fixed seed) before analysis.
        max_cols_summary: Column limit forwarded to ``summarize_dataframe``.

    Returns:
        A ``(report, images)`` tuple. ``images`` is a list of PNG file paths,
        one per generated figure. On a missing upload or any processing error,
        ``report`` carries the message and ``images`` is ``None``.
    """
    if file is None:
        return "Please upload a CSV file.", None
    try:
        # Accept both a plain path string and a file-like wrapper with `.name`.
        csv_path = getattr(file, "name", file)
        df = pd.read_csv(csv_path)

        # Optional sampling for very large datasets.
        if sample_rows and df.shape[0] > sample_rows:
            df = df.sample(sample_rows, random_state=42)

        df_summary = summarize_dataframe(df, max_cols=max_cols_summary)
        ai_report = generate_ai_report(df_summary, api_key=api_key, model=model_name)

        figs = make_distribution_plots(df)
        return ai_report, _figures_to_image_paths(figs)
    except Exception as e:
        # UI boundary: surface the failure as text instead of raising.
        return f"❌ Error while processing file: {e}", None


def _figures_to_image_paths(figs):
    """Save each Matplotlib figure to a temporary PNG and return the file paths."""
    paths = []
    for fig in figs:
        try:
            # delete=False: the file must outlive this function so the UI can
            # read it afterwards.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
                fig.savefig(handle, format="png", bbox_inches="tight")
            paths.append(handle.name)
        finally:
            # Figures created via pyplot stay registered until closed; release
            # them so repeated analyses do not accumulate memory.
            plt.close(fig)
    return paths
| # --------- Build Gradio UI --------- | |
def build_interface():
    """Assemble the Gradio Blocks UI for the dataset analyser.

    The left column holds the inputs (CSV upload, optional API key, model
    choice, row-sampling limit, summary column limit, and the run button); the
    right column shows the AI report and the generated plots.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # Blank lines separate the heading, paragraph, list and note so the note
    # renders as its own paragraph instead of continuing the last bullet.
    intro_markdown = textwrap.dedent(
        """
        # 📊 AI Data Analyst — Dataset Explorer

        Upload a CSV dataset and let an OpenAI model act as your **senior data analyst**.

        - ✅ Automatic structural summary (rows, columns, types, missingness)
        - ✅ AI-generated **very detailed** analysis report
        - ✅ Auto-generated plots (distributions & correlation heatmap)

        **Note:** For security, prefer setting your `OPENAI_API_KEY` as an environment variable
        instead of typing it in the UI.
        """
    ).strip()

    with gr.Blocks(title="AI Data Analyst", theme=gr.themes.Soft()) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
                api_key_input = gr.Textbox(
                    label="OpenAI API Key (optional, leave blank to use environment variable)",
                    type="password",
                    placeholder="sk-...",
                )
                model_dropdown = gr.Dropdown(
                    label="OpenAI Model",
                    choices=["gpt-4o-mini", "gpt-4o", "gpt-4.1-mini", "gpt-4.1"],
                    value="gpt-4o-mini",
                )
                sample_rows = gr.Slider(
                    minimum=0,
                    maximum=5000,
                    value=2000,
                    step=100,
                    label="Max rows to sample for analysis (0 = use all rows)",
                )
                max_cols_summary = gr.Slider(
                    minimum=5,
                    maximum=40,
                    value=15,
                    step=1,
                    label="Max columns to include in text summary",
                )
                analyze_button = gr.Button("🔍 Analyze Dataset", variant="primary")

            with gr.Column(scale=2):
                report_output = gr.Markdown(label="AI Analysis Report")
                plots_output = gr.Gallery(
                    label="Auto-generated Plots",
                    columns=2,
                    height="auto",
                    preview=True,
                )

        def _wrapped_analyze(file, api_key, model_name, sample_rows_val, max_cols_val):
            # Slider values are coerced to int here; a sampling limit of 0
            # means "use all rows" and is passed on as None.
            sr = int(sample_rows_val) if sample_rows_val and sample_rows_val > 0 else None
            return analyze_dataset(file, api_key, model_name, sr, int(max_cols_val))

        analyze_button.click(
            _wrapped_analyze,
            inputs=[file_input, api_key_input, model_dropdown, sample_rows, max_cols_summary],
            outputs=[report_output, plots_output],
        )

    return demo
# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()