pavanmutha commited on
Commit
44fb1b0
·
verified ·
1 Parent(s): 01fbc32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -336
app.py CHANGED
@@ -1,359 +1,138 @@
1
  import gradio as gr
2
- from smolagents import HfApiModel, CodeAgent
3
- from huggingface_hub import login
4
- import os
5
- import shutil
6
- import wandb
7
- import time
8
- import psutil
9
- import optuna
10
- import ast
11
  import pandas as pd
12
- import numpy as np
13
- from sklearn.model_selection import train_test_split
14
- from sklearn.ensemble import RandomForestClassifier
15
- from sklearn.metrics import (accuracy_score, precision_score,
16
- recall_score, f1_score, classification_report)
17
- from sklearn.preprocessing import LabelEncoder
18
  import shap
19
- import lime
20
  import lime.lime_tabular
 
 
21
  import matplotlib.pyplot as plt
22
  import seaborn as sns
23
- from optuna.visualization import (plot_optimization_history,
24
- plot_param_importances,
25
- plot_parallel_coordinate)
26
- from PIL import Image
27
- import base64
28
- from io import BytesIO
29
 
30
- # Authenticate Hugging Face
31
- hf_token = os.getenv("HF_TOKEN")
32
- login(token=hf_token, add_to_git_credential=True)
33
 
34
- # Initialize Model
35
- model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
36
 
37
- # Initialize W&B
38
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
39
 
40
- def save_figure(fig, filename):
41
- """Helper function to save matplotlib figures"""
42
- os.makedirs('./figures', exist_ok=True)
43
- path = f"./figures/{filename}"
44
- fig.savefig(path, bbox_inches='tight')
45
- plt.close(fig)
46
- return path
 
47
 
48
- def encode_categorical_data(df):
49
- """Encode categorical columns and return encoded df and encoders"""
50
- encoders = {}
51
- for col in df.select_dtypes(include=['object', 'category']).columns:
52
- le = LabelEncoder()
53
- df[col] = le.fit_transform(df[col].astype(str))
54
- encoders[col] = le
55
- return df, encoders
56
 
57
- def generate_data_insights(df):
58
- """Generate insights using smolagent"""
59
- agent = CodeAgent(
60
- tools=[],
61
- model=model,
62
- additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"]
63
- )
64
-
65
- prompt = """
66
- Analyze this dataset and provide:
67
- 1. 5 key statistical insights about the data
68
- 2. 5 suggested visualizations with explanations
69
- 3. Data quality assessment
70
- 4. Recommendations for preprocessing
71
-
72
- For each insight:
73
- - Explain its significance
74
- - Provide the Python code to verify it
75
- - Suggest potential actions
76
-
77
- Return the results as a dictionary with:
78
- - 'insights': List of 5 key insights
79
- - 'visualizations': List of 5 visualization descriptions with code
80
- - 'quality': Data quality assessment
81
- - 'recommendations': Preprocessing recommendations
82
- """
83
-
84
- return agent.run(prompt, additional_args={"df": df})
85
 
86
- def create_visualizations(df, insights):
87
- """Create visualizations based on insights"""
88
- visuals = []
 
89
  try:
90
- # Visualization 1: Missing values heatmap
91
- if df.isnull().any().any():
92
- fig, ax = plt.subplots(figsize=(10, 6))
93
- sns.heatmap(df.isnull(), cbar=False, ax=ax)
94
- plt.title("Missing Values Heatmap")
95
- visuals.append(save_figure(fig, "missing_values.png"))
96
-
97
- # Visualization 2: Correlation heatmap
98
- numeric_cols = df.select_dtypes(include=np.number).columns
99
- if len(numeric_cols) > 1:
100
- fig, ax = plt.subplots(figsize=(10, 8))
101
- sns.heatmap(df[numeric_cols].corr(), annot=True, cmap='coolwarm', ax=ax)
102
- plt.title("Correlation Heatmap")
103
- visuals.append(save_figure(fig, "correlation_heatmap.png"))
104
-
105
- # Visualization 3: Feature distributions
106
- for col in numeric_cols[:3]: # First 3 numeric columns
107
- fig, ax = plt.subplots(figsize=(10, 6))
108
- sns.histplot(df[col], kde=True, ax=ax)
109
- plt.title(f"Distribution of {col}")
110
- visuals.append(save_figure(fig, f"distribution_{col}.png"))
111
-
112
- # Visualization 4: Pairplot (sample if large)
113
- if len(numeric_cols) > 1:
114
- fig = sns.pairplot(df[numeric_cols].sample(min(100, len(df))))
115
- visuals.append(save_figure(fig, "pairplot.png"))
116
-
117
- # Visualization 5: Categorical counts
118
- cat_cols = df.select_dtypes(include=['object', 'category']).columns
119
- for col in cat_cols[:2]: # First 2 categorical columns
120
- fig, ax = plt.subplots(figsize=(10, 6))
121
- df[col].value_counts().plot(kind='bar', ax=ax)
122
- plt.title(f"Count of {col}")
123
- visuals.append(save_figure(fig, f"count_{col}.png"))
124
-
125
  except Exception as e:
126
- print(f"Visualization error: {e}")
127
-
128
- return visuals
129
 
130
- def analyze_data(csv_file, additional_notes=""):
131
- """Main data analysis function"""
132
- start_time = time.time()
133
-
134
- # Initialize W&B run
135
- run = wandb.init(project="data-analysis", config={
136
- "model": "Mixtral-8x7B",
137
- "notes": additional_notes,
138
- "file": csv_file.name if csv_file else None
139
- })
140
-
141
- try:
142
- # Load data
143
- df = pd.read_csv(csv_file)
144
-
145
- # Generate insights with smolagent
146
- insights = generate_data_insights(df)
147
-
148
- # Create visualizations
149
- visuals = create_visualizations(df, insights)
150
-
151
- # Log to W&B
152
- for viz in visuals:
153
- wandb.log({"visualizations": wandb.Image(viz)})
154
-
155
- # Format report
156
- report = format_analysis_report(insights, visuals)
157
-
158
- # Track performance
159
- execution_time = time.time() - start_time
160
- wandb.log({"execution_time": execution_time})
161
-
162
- return report, visuals
163
-
164
- except Exception as e:
165
- return f"Error: {str(e)}", []
166
- finally:
167
- run.finish()
168
 
169
- def objective(trial, X, y):
170
- """Optuna objective function for hyperparameter tuning"""
171
- params = {
172
- 'n_estimators': trial.suggest_int('n_estimators', 50, 500),
173
- 'max_depth': trial.suggest_int('max_depth', 3, 15),
174
- 'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
175
- 'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
176
- 'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
177
- 'bootstrap': trial.suggest_categorical('bootstrap', [True, False]),
178
- 'criterion': trial.suggest_categorical('criterion', ['gini', 'entropy'])
179
- }
180
-
181
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
182
-
183
- model = RandomForestClassifier(**params, random_state=42, n_jobs=-1)
184
- model.fit(X_train, y_train)
185
-
186
- y_pred = model.predict(X_test)
187
-
188
- # Track multiple metrics
189
- metrics = {
190
- 'accuracy': accuracy_score(y_test, y_pred),
191
- 'precision': precision_score(y_test, y_pred, average='weighted'),
192
- 'recall': recall_score(y_test, y_pred, average='weighted'),
193
- 'f1': f1_score(y_test, y_pred, average='weighted')
194
  }
195
-
196
- # Log to W&B
197
- wandb.log({**params, **metrics})
198
-
199
- return metrics['accuracy']
200
 
201
- def tune_hyperparameters(csv_file, n_trials=50):
202
- """Hyperparameter tuning with Optuna and W&B"""
203
- run = wandb.init(project="hyperparameter-tuning", config={
204
- "n_trials": n_trials,
205
- "model": "RandomForest"
206
- })
207
-
208
- try:
209
- # Load and prepare data
210
- df = pd.read_csv(csv_file)
211
- df, _ = encode_categorical_data(df)
212
-
213
- y = df.iloc[:, -1] # Assume last column is target
214
- X = df.iloc[:, :-1]
215
-
216
- # Optuna study
217
- study = optuna.create_study(
218
- direction='maximize',
219
- sampler=optuna.samplers.TPESampler(),
220
- pruner=optuna.pruners.MedianPruner()
221
- )
222
-
223
- study.optimize(lambda trial: objective(trial, X, y), n_trials=n_trials)
224
-
225
- # Generate visualizations
226
- visuals = []
227
- fig = plot_optimization_history(study)
228
- visuals.append(save_figure(fig, "optimization_history.png"))
229
-
230
- fig = plot_param_importances(study)
231
- visuals.append(save_figure(fig, "param_importance.png"))
232
-
233
- fig = plot_parallel_coordinate(study)
234
- visuals.append(save_figure(fig, "parallel_coordinate.png"))
235
-
236
- # Train best model
237
- best_model = RandomForestClassifier(**study.best_params, random_state=42)
238
- best_model.fit(X, y)
239
-
240
- # SHAP explainability
241
- explainer = shap.TreeExplainer(best_model)
242
- shap_values = explainer.shap_values(X)
243
-
244
- fig, ax = plt.subplots(figsize=(10, 8))
245
- shap.summary_plot(shap_values, X, show=False)
246
- visuals.append(save_figure(fig, "shap_summary.png"))
247
-
248
- # LIME explainability
249
- explainer = lime.lime_tabular.LimeTabularExplainer(
250
- X.values,
251
- feature_names=X.columns,
252
- class_names=['class_0', 'class_1'], # Modify as needed
253
- mode='classification'
254
- )
255
- exp = explainer.explain_instance(X.iloc[0].values, best_model.predict_proba)
256
- fig = exp.as_pyplot_figure()
257
- visuals.append(save_figure(fig, "lime_explanation.png"))
258
-
259
- # Format results
260
- report = format_tuning_results(study, best_model, X, y)
261
-
262
- return report, visuals
263
-
264
- except Exception as e:
265
- return f"Error: {str(e)}", []
266
- finally:
267
- run.finish()
268
 
269
- def format_analysis_report(insights, visuals):
270
- """Format the analysis report with insights and visuals"""
271
- report = f"""
272
- <div style="font-family: Arial; max-width: 1000px; margin: 0 auto;">
273
- <h1 style="color: #2B547E;">📊 Data Analysis Report</h1>
274
-
275
- <div style="margin-top: 20px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
276
- <h2 style="color: #2B547E;">🔍 Key Insights</h2>
277
- {format_insights_section(insights.get('insights', []))}
278
- </div>
279
-
280
- <div style="margin-top: 30px;">
281
- <h2 style="color: #2B547E;">📈 Visualizations</h2>
282
- {format_visualizations(visuals)}
283
- </div>
284
- </div>
285
- """
286
- return report
287
 
288
- def format_tuning_results(study, model, X, y):
289
- """Format hyperparameter tuning results"""
290
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
291
- y_pred = model.predict(X_test)
292
-
293
- report = f"""
294
- <div style="font-family: Arial; max-width: 1000px; margin: 0 auto;">
295
- <h1 style="color: #2B547E;">⚙️ Hyperparameter Tuning Results</h1>
296
-
297
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin-top: 20px;">
298
- <div style="background: #f8f9fa; padding: 20px; border-radius: 8px;">
299
- <h2 style="color: #2B547E;">📊 Best Parameters</h2>
300
- <pre>{study.best_params}</pre>
301
- </div>
302
-
303
- <div style="background: #f8f9fa; padding: 20px; border-radius: 8px;">
304
- <h2 style="color: #2B547E;">📈 Performance Metrics</h2>
305
- <p>Accuracy: {accuracy_score(y_test, y_pred):.4f}</p>
306
- <p>Precision: {precision_score(y_test, y_pred, average='weighted'):.4f}</p>
307
- <p>Recall: {recall_score(y_test, y_pred, average='weighted'):.4f}</p>
308
- <p>F1 Score: {f1_score(y_test, y_pred, average='weighted'):.4f}</p>
309
- </div>
310
- </div>
311
-
312
- <div style="margin-top: 30px;">
313
- <h2 style="color: #2B547E;">🔍 Classification Report</h2>
314
- <pre>{classification_report(y_test, y_pred)}</pre>
315
- </div>
316
- </div>
317
- """
318
- return report
319
 
320
- # Create Gradio interface
321
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
322
- gr.Markdown("# 🧠 Advanced Data Analysis with AI")
323
-
324
- with gr.Tab("Data Analysis"):
325
- with gr.Row():
326
- with gr.Column():
327
- data_file = gr.File(label="Upload CSV", file_types=[".csv"])
328
- notes = gr.Textbox(label="Analysis Notes (Optional)", lines=3)
329
- analyze_btn = gr.Button("Analyze Data", variant="primary")
330
-
331
- with gr.Column():
332
- analysis_report = gr.HTML(label="Analysis Report")
333
- viz_gallery = gr.Gallery(label="Visualizations")
334
-
335
- with gr.Tab("Model Tuning"):
336
- with gr.Row():
337
- with gr.Column():
338
- tune_file = gr.File(label="Upload CSV for Tuning", file_types=[".csv"])
339
- trials = gr.Slider(10, 200, value=50, label="Number of Trials")
340
- tune_btn = gr.Button("Tune Hyperparameters", variant="primary")
341
-
342
- with gr.Column():
343
- tuning_report = gr.HTML(label="Tuning Results")
344
- tuning_viz = gr.Gallery(label="Tuning Visualizations")
345
-
346
- # Event handlers
347
- analyze_btn.click(
348
- fn=analyze_data,
349
- inputs=[data_file, notes],
350
- outputs=[analysis_report, viz_gallery]
351
- )
352
-
353
- tune_btn.click(
354
- fn=tune_hyperparameters,
355
- inputs=[tune_file, trials],
356
- outputs=[tuning_report, tuning_viz]
357
- )
358
 
359
- demo.launch()
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
2
  import pandas as pd
 
 
 
 
 
 
3
  import shap
 
4
  import lime.lime_tabular
5
+ import wandb
6
+ import optuna
7
  import matplotlib.pyplot as plt
8
  import seaborn as sns
9
+ import tempfile
10
+ import os
 
 
 
 
11
 
12
+ from sklearn.ensemble import RandomForestClassifier
13
+ from sklearn.model_selection import train_test_split, cross_val_score
14
+ from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
15
 
16
+ from huggingface_hub import login
17
+ from smolagents import HfApiModel, CodeAgent
18
 
19
+ # Authenticate with Hugging Face using environment token
20
+ login(os.getenv("HUGGINGFACEHUB_API_TOKEN"))
21
 
22
+ # Initialize LLM model and CodeAgent
23
+ llm_model = HfApiModel("meta-llama/Llama-3.1-70B-Instruct")
24
+ agent = CodeAgent(
25
+ tools=[],
26
+ model=llm_model,
27
+ additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"],
28
+ max_iterations=10,
29
+ )
30
 
31
+ # Global DataFrame
32
+ df_global = None
 
 
 
 
 
 
33
 
34
+ # Load and clean data
35
+ def load_data(file):
36
+ global df_global
37
+ ext = os.path.splitext(file.name)[-1]
38
+ if ext in [".csv"]:
39
+ df = pd.read_csv(file.name)
40
+ else:
41
+ df = pd.read_excel(file.name)
42
+ df = df.dropna(how='all', axis=1).dropna(how='all', axis=0)
43
+ df = df.fillna(df.mean(numeric_only=True))
44
+ df_global = df
45
+ return df.head()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Use SmolAgent to generate insights and visuals
48
+ def get_insights(_):
49
+ if df_global is None:
50
+ return "No data loaded yet."
51
  try:
52
+ result = agent.run(df_global, instructions="Generate 5 data insights and 5 data visualizations.")
53
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
+ return f"Error from SmolAgent: {e}"
 
 
56
 
57
+ # Train model + hyperparameter tuning
58
+ def run_model(_):
59
+ wandb_run = wandb.init(project="huggingface_smol_data_analysis", name="Optuna_Tuning", reinit=True)
60
+ target = df_global.columns[-1]
61
+ X = df_global.drop(target, axis=1)
62
+ y = df_global[target]
63
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ def objective(trial):
66
+ n_estimators = trial.suggest_int("n_estimators", 10, 200)
67
+ max_depth = trial.suggest_int("max_depth", 2, 32, log=True)
68
+ clf = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
69
+ score = cross_val_score(clf, X_train, y_train, n_jobs=-1, cv=3).mean()
70
+ wandb.log({"cv_score": score, "n_estimators": n_estimators, "max_depth": max_depth})
71
+ return score
72
+
73
+ study = optuna.create_study(direction="maximize")
74
+ study.optimize(objective, n_trials=20)
75
+
76
+ best_params = study.best_params
77
+ best_model = RandomForestClassifier(**best_params)
78
+ best_model.fit(X_train, y_train)
79
+ y_pred = best_model.predict(X_test)
80
+
81
+ scores = {
82
+ "accuracy": accuracy_score(y_test, y_pred),
83
+ "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
84
+ "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
85
+ "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0)
 
 
 
 
86
  }
 
 
 
 
 
87
 
88
+ wandb.log(scores)
89
+ wandb_run.finish()
90
+
91
+ top_params_report = pd.DataFrame(study.trials_dataframe().sort_values(by="value", ascending=False).head(7))
92
+
93
+ return scores, top_params_report
94
+
95
+ # SHAP + LIME Explainability
96
+ def explainability(_):
97
+ target = df_global.columns[-1]
98
+ X = df_global.drop(target, axis=1)
99
+ y = df_global[target]
100
+
101
+ model = RandomForestClassifier()
102
+ model.fit(X, y)
103
+
104
+ explainer = shap.Explainer(model, X)
105
+ shap_values = explainer(X)
106
+ shap.plots.beeswarm(shap_values, show=False)
107
+ plt.tight_layout()
108
+ shap_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
109
+ plt.savefig(shap_file.name)
110
+ plt.close()
111
+
112
+ lime_explainer = lime.lime_tabular.LimeTabularExplainer(X.values, feature_names=X.columns, class_names=list(set(y)), discretize_continuous=True)
113
+ exp = lime_explainer.explain_instance(X.iloc[0].values, model.predict_proba)
114
+ lime_html = exp.as_html()
115
+
116
+ wandb.log({"shap": wandb.Image(shap_file.name), "lime": lime_html})
117
+
118
+ return shap_file.name, lime_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # Gradio UI
121
+ with gr.Blocks() as demo:
122
+ with gr.Row():
123
+ upload = gr.File(label="Upload CSV or Excel", type="file")
124
+ load_btn = gr.Button("Load & Analyze Data")
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ data_output = gr.DataFrame()
127
+ insights_output = gr.Textbox(label="Insights & Visuals (SmolAgent)", lines=15)
128
+ model_scores = gr.JSON(label="Model Performance Scores")
129
+ param_table = gr.DataFrame(label="Top 7 Hyperparameters")
130
+ shap_img = gr.Image(label="SHAP Plot")
131
+ lime_out = gr.HTML(label="LIME Explanation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ load_btn.click(fn=load_data, inputs=upload, outputs=data_output)
134
+ load_btn.click(fn=get_insights, inputs=data_output, outputs=insights_output)
135
+ load_btn.click(fn=run_model, inputs=data_output, outputs=[model_scores, param_table])
136
+ load_btn.click(fn=explainability, inputs=data_output, outputs=[shap_img, lime_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ demo.launch()