DataWizard9742 committed on
Commit
3c16f2f
Β·
verified Β·
1 Parent(s): dae00bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +296 -0
app.py CHANGED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import textwrap
4
+ import tempfile
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ import gradio as gr
11
+
12
+ from openai import OpenAI
13
+
14
+ # --------- OpenAI client helper ---------
15
def get_client(api_key: "str | None" = None):
    """Build an OpenAI client from an explicit key or the environment.

    Args:
        api_key: Key supplied via the UI; when empty/None the
            ``OPENAI_API_KEY`` environment variable is used instead.

    Returns:
        An ``OpenAI`` client instance.

    Raises:
        ValueError: if no non-empty key is available from either source.
    """
    # Strip whitespace so a stray space typed into the UI field neither
    # masks the env-var fallback nor gets sent as an invalid key.
    key = (api_key or "").strip() or os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OpenAI API key not provided. "
                         "Either set OPENAI_API_KEY env var or pass it in the UI.")
    return OpenAI(api_key=key)
21
+
22
+
23
+ # --------- Data summarisation helpers ---------
24
def summarize_dataframe(df: pd.DataFrame, max_cols=15, max_rows=5) -> str:
    """Render a structured Markdown summary of *df* for the LLM prompt.

    Covers shape, dtype mix, per-column statistics (numeric or
    categorical), the strongest numeric correlations, and a small
    sample of rows.

    Args:
        df: The dataset to describe.
        max_cols: Cap on columns given a per-column breakdown.
        max_rows: Number of sample rows appended at the end.

    Returns:
        A Markdown string with five numbered sections.
    """
    buf = []

    # --- 1. Shape ---
    buf.append("### 1. Basic Structure")
    buf.append(f"- Number of rows: {df.shape[0]}")
    buf.append(f"- Number of columns: {df.shape[1]}")
    buf.append("")

    # --- 2. Dtype mix ---
    buf.append("### 2. Column Types")
    dtypes_summary = df.dtypes.astype(str).value_counts()
    for t, c in dtypes_summary.items():
        buf.append(f"- {t}: {c} columns")
    buf.append("")

    # --- 3. Per-column detail (capped at max_cols) ---
    buf.append("### 3. Column-wise Summary")
    for col in df.columns[:max_cols]:
        series = df[col]
        col_info = [f"**Column:** {col}"]
        col_info.append(f"- dtype: {series.dtype}")
        col_info.append(f"- Missing values: {series.isna().sum()} "
                        f"({series.isna().mean():.2%} of rows)")

        if pd.api.types.is_numeric_dtype(series):
            if series.notna().any():
                desc = series.describe()
                col_info.append(
                    "- Stats: "
                    f"min={desc['min']:.4g}, "
                    f"25%={desc['25%']:.4g}, "
                    f"mean={desc['mean']:.4g}, "
                    f"50%={desc['50%']:.4g}, "
                    f"75%={desc['75%']:.4g}, "
                    f"max={desc['max']:.4g}"
                )
            else:
                # describe() on an all-NaN column yields NaN everywhere;
                # say so explicitly instead of printing "nan" six times.
                col_info.append("- Stats: no non-null values")
        else:
            # Categorical/text summary
            nunique = series.nunique(dropna=True)
            top_vals = series.value_counts(dropna=True).head(5)
            col_info.append(f"- Unique values (non-null): {nunique}")
            tv_str = ", ".join(f"{idx} ({val})" for idx, val in top_vals.items())
            col_info.append(f"- Top values: {tv_str}")

        buf.append("\n".join(col_info))
        buf.append("")

    if df.shape[1] > max_cols:
        buf.append(f"... ({df.shape[1] - max_cols} more columns not listed here)")
        buf.append("")

    # --- 4. Strongest numeric correlations ---
    num_cols = df.select_dtypes(include=[np.number]).columns
    if len(num_cols) >= 2:
        buf.append("### 4. Numeric Correlations (Top pairs)")
        corr = df[num_cols].corr()
        pairs = []
        for i in range(len(num_cols)):
            for j in range(i + 1, len(num_cols)):
                v = corr.iloc[i, j]
                # Constant columns produce NaN correlations; drop them so
                # the sort below is well-defined.
                if pd.notna(v):
                    pairs.append((num_cols[i], num_cols[j], v))
        # Rank by magnitude but report the SIGNED coefficient, so inverse
        # relationships are not misrepresented as positive ones.
        pairs.sort(key=lambda p: abs(p[2]), reverse=True)
        for a, b, v in pairs[:10]:
            buf.append(f"- {a} vs {b}: correlation={v:.3f}")
        buf.append("")

    # --- 5. Sample rows ---
    buf.append("### 5. Sample Rows")
    sample = df.head(max_rows)
    try:
        buf.append(sample.to_markdown(index=False))
    except ImportError:
        # to_markdown() requires the optional 'tabulate' package; fall
        # back to a plain-text table rather than failing the summary.
        buf.append(sample.to_string(index=False))

    return "\n".join(buf)
98
+
99
+
100
+ # --------- Plotting helpers ---------
101
def make_distribution_plots(df: pd.DataFrame, max_numeric=4, max_categorical=4):
    """Create matplotlib figures describing *df*'s distributions.

    Produces up to ``max_numeric`` histograms (with KDE), up to
    ``max_categorical`` horizontal bar charts of the top categories,
    and — when at least two numeric columns exist — one correlation
    heatmap.

    Args:
        df: The dataset to plot.
        max_numeric: Cap on numeric columns plotted.
        max_categorical: Cap on non-numeric columns plotted.

    Returns:
        A list of matplotlib Figure objects, in the order built.
    """
    plots = []

    # Compute the numeric column set once (the original recomputed
    # select_dtypes three times over the same frame).
    numeric_cols = df.select_dtypes(include=[np.number]).columns

    # Numeric distributions: histogram + KDE per column.
    for col in numeric_cols[:max_numeric]:
        fig, ax = plt.subplots()
        sns.histplot(df[col].dropna(), kde=True, ax=ax)
        ax.set_title(f"Distribution of {col}")
        ax.set_xlabel(col)
        ax.set_ylabel("Count")
        plt.tight_layout()
        plots.append(fig)

    # Categorical distributions: top 15 categories per column.
    cat_cols = df.select_dtypes(exclude=[np.number]).columns[:max_categorical]
    for col in cat_cols:
        fig, ax = plt.subplots()
        value_counts = df[col].value_counts().head(15)
        sns.barplot(x=value_counts.values, y=value_counts.index, ax=ax)
        ax.set_title(f"Top categories in {col}")
        ax.set_xlabel("Count")
        ax.set_ylabel(col)
        plt.tight_layout()
        plots.append(fig)

    # Correlation heatmap across numeric features.
    if len(numeric_cols) >= 2:
        fig, ax = plt.subplots(figsize=(6, 5))
        corr = df[numeric_cols].corr()
        sns.heatmap(corr, annot=False, cmap="coolwarm", ax=ax)
        ax.set_title("Correlation Heatmap (Numeric Features)")
        plt.tight_layout()
        plots.append(fig)

    return plots
137
+
138
+
139
+ # --------- OpenAI analysis ---------
140
def generate_ai_report(df_summary: str, api_key: str = None, model: str = "gpt-4.1-mini") -> str:
    """Ask an OpenAI model for a detailed analyst report on the summary.

    Args:
        df_summary: Markdown dataset summary from summarize_dataframe().
        api_key: Optional explicit API key (else the env var is used).
        model: Responses-API model identifier.

    Returns:
        The report text, stripped of surrounding whitespace.

    Raises:
        ValueError: via get_client() when no API key is available.
    """
    client = get_client(api_key)

    system_msg = (
        "You are a senior data analyst. You receive a structured summary of a dataset. "
        "Your job is to produce a VERY detailed, structured analysis report.\n\n"
        "Your report MUST include at least these sections with clear headings:\n"
        "1. Dataset Overview (rows, columns, column types, what this might be about)\n"
        "2. Data Quality & Missing Values (what is good/bad, issues, suggestions)\n"
        "3. Univariate Analysis (patterns in individual columns: numeric & categorical)\n"
        "4. Bivariate & Correlation Insights (relationships between key columns)\n"
        "5. Potential Target Variables & Use Cases (what could be predicted or modelled)\n"
        "6. Feature Engineering Ideas (new variables or transformations to create)\n"
        "7. Potential Visualizations (suggest specific plots and what they would reveal)\n"
        "8. Risks, Biases & Limitations of this dataset\n"
        "9. Recommended Next Steps for deeper analysis or modelling.\n\n"
        "Be concrete and descriptive. Use bullet points and short paragraphs. "
        "Assume the user understands basic data science but wants expert-level insight."
    )

    user_msg = (
        "Here is a detailed summary of the dataset. Use ONLY this information in your reasoning; "
        "do not invent columns that are not mentioned.\n\n"
        f"{df_summary}"
    )

    # NOTE(review): gpt-4.1* are not reasoning models; the API may reject
    # the 'reasoning' parameter for them — confirm against the model list.
    response = client.responses.create(
        model=model,
        reasoning={"effort": "medium"},
        input=[
            {
                "role": "system",
                "content": system_msg,
            },
            {
                "role": "user",
                "content": user_msg,
            },
        ],
        max_output_tokens=1800,
    )

    # Prefer the SDK's aggregated `output_text` convenience property.
    # Indexing response.output[0] is fragile: reasoning-capable models
    # can emit a reasoning item before the assistant message, and the
    # answer may span multiple output items.
    text = getattr(response, "output_text", "")
    if text:
        return text.strip()

    # Fallback: walk every output item and gather its text parts.
    chunks = []
    for item in response.output:
        for part in getattr(item, "content", None) or []:
            if getattr(part, "type", None) == "output_text":
                chunks.append(part.text)
    return "\n".join(chunks).strip()
192
+
193
+
194
+ # --------- Main Gradio function ---------
195
def analyze_dataset(file, api_key, model_name, sample_rows, max_cols_summary):
    """Gradio callback: read a CSV, summarize it, query the LLM, plot.

    Args:
        file: Upload from gr.File — either a tempfile-like object with a
            ``.name`` path or (newer Gradio versions) a plain path string.
        api_key: Optional OpenAI key from the UI.
        model_name: Model identifier to query.
        sample_rows: If truthy and smaller than the row count, analyze a
            random sample of this many rows.
        max_cols_summary: Column cap forwarded to summarize_dataframe().

    Returns:
        (report_markdown, figures) on success, or (error_message, None).
    """
    if file is None:
        return "Please upload a CSV file.", None

    try:
        # gr.File can hand us a path string or a file-like wrapper
        # depending on the Gradio version; accept both forms.
        path = file if isinstance(file, str) else file.name
        df = pd.read_csv(path)

        # Cap very large datasets so the text summary stays bounded.
        # A fixed seed keeps repeated runs on the same file comparable.
        if sample_rows and df.shape[0] > sample_rows:
            df = df.sample(sample_rows, random_state=42)

        # Build the text summary, then ask the model for the report.
        df_summary = summarize_dataframe(df, max_cols=max_cols_summary)
        ai_report = generate_ai_report(df_summary, api_key=api_key, model=model_name)

        # Generate plots from the (possibly sampled) frame.
        figs = make_distribution_plots(df)

        return ai_report, figs

    except Exception as e:
        # UI boundary: surface the failure to the user in the report
        # pane instead of letting the callback crash.
        return f"❌ Error while processing file: {e}", None
218
+
219
+
220
+ # --------- Build Gradio UI ---------
221
def build_interface():
    """Assemble and return the Gradio Blocks UI for the app.

    Layout: a left column holding the upload control and settings, a
    right column holding the AI report and the plot gallery. The click
    handler normalizes slider values before delegating to
    analyze_dataset().

    Returns:
        The gr.Blocks demo, ready for ``.launch()``.
    """
    with gr.Blocks(title="AI Data Analyst", theme=gr.themes.Soft()) as demo:
        # Intro / usage notes rendered at the top of the page.
        gr.Markdown(
            """
            # 📊 AI Data Analyst – Dataset Explorer

            Upload a CSV dataset and let an OpenAI model act as your **senior data analyst**.

            - ✅ Automatic structural summary (rows, columns, types, missingness)
            - ✅ AI-generated **very detailed** analysis report
            - ✅ Auto-generated plots (distributions & correlation heatmap)

            **Note:** For security, prefer setting your `OPENAI_API_KEY` as an environment variable
            instead of typing it in the UI.
            """
        )

        with gr.Row():
            # ---- Left column: inputs & settings ----
            with gr.Column(scale=1):
                file_input = gr.File(label="Upload CSV file", file_types=[".csv"])

                # A key typed here overrides the OPENAI_API_KEY env var.
                api_key_input = gr.Textbox(
                    label="OpenAI API Key (optional, leave blank to use environment variable)",
                    type="password",
                    placeholder="sk-...",
                )

                model_dropdown = gr.Dropdown(
                    label="OpenAI Model",
                    choices=["gpt-4.1-mini", "gpt-4.1"],
                    value="gpt-4.1-mini",
                )

                # 0 disables sampling (all rows are used); see
                # _wrapped_analyze below for the 0 -> None mapping.
                sample_rows = gr.Slider(
                    minimum=0,
                    maximum=5000,
                    value=2000,
                    step=100,
                    label="Max rows to sample for analysis (0 = use all rows)",
                )

                max_cols_summary = gr.Slider(
                    minimum=5,
                    maximum=40,
                    value=15,
                    step=1,
                    label="Max columns to include in text summary",
                )

                analyze_button = gr.Button("🔍 Analyze Dataset", variant="primary")

            # ---- Right column: outputs ----
            with gr.Column(scale=2):
                report_output = gr.Markdown(label="AI Analysis Report")
                # NOTE(review): make_distribution_plots returns matplotlib
                # figures, while gr.Gallery expects image-like values —
                # confirm figures render here; gr.Plot may be required.
                plots_output = gr.Gallery(
                    label="Auto-generated Plots",
                    columns=2,
                    height="auto",
                    preview=True,
                )

        # Adapter: sliders deliver floats; convert to ints and map a
        # zero/empty sample size to None before the real handler runs.
        def _wrapped_analyze(file, api_key, model_name, sample_rows_val, max_cols_val):
            sr = int(sample_rows_val) if sample_rows_val and sample_rows_val > 0 else None
            return analyze_dataset(file, api_key, model_name, sr, int(max_cols_val))

        analyze_button.click(
            _wrapped_analyze,
            inputs=[file_input, api_key_input, model_dropdown, sample_rows, max_cols_summary],
            outputs=[report_output, plots_output],
        )

    return demo
292
+
293
+
294
+ if __name__ == "__main__":
295
+ demo = build_interface()
296
+ demo.launch()