pavanmutha committed on
Commit
6b42202
·
verified ·
1 Parent(s): 33d02da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -168
app.py CHANGED
@@ -1,204 +1,164 @@
1
- import gradio as gr
2
- from smolagents import HfApiModel, CodeAgent
3
- from huggingface_hub import login
4
  import os
5
- import shutil
6
- import wandb
7
- import time
8
- import psutil
 
 
9
  import optuna
 
10
  import ast
11
- import pandas as pd
12
- from sklearn.model_selection import train_test_split
13
  from sklearn.ensemble import RandomForestClassifier
 
14
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15
- import shap
16
- import lime
17
- import lime.lime_tabular
18
- import matplotlib.pyplot as plt
19
 
20
- # Authenticate Hugging Face
21
  hf_token = os.getenv("HF_TOKEN")
22
  login(token=hf_token)
23
 
24
- # Initialize Model
25
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
26
-
27
def format_analysis_report(raw_output, visuals):
    """Render the agent's raw result as an HTML report.

    Accepts either a dict or anything whose ``str()`` parses with
    ``ast.literal_eval``; on any failure the raw output is returned as plain
    text so the UI always shows something. Returns ``(report_html, visuals)``.
    """
    try:
        if isinstance(raw_output, dict):
            parsed = raw_output
        else:
            try:
                parsed = ast.literal_eval(str(raw_output))
            except (SyntaxError, ValueError) as e:
                print(f"Error parsing CodeAgent output: {e}")
                return str(raw_output), visuals

        # Build the two sections first, then splice them into the page shell.
        observations_html = format_observations(parsed.get('observations', {}))
        insights_html = format_insights(parsed.get('insights', {}), visuals)
        report = f"""
    <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
        <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
        <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
            <h2 style="color: #2B547E;">🔍 Key Observations</h2>
            {observations_html}
        </div>
        <div style="margin-top: 30px;">
            <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
            {insights_html}
        </div>
    </div>
    """
        return report, visuals
    except Exception as e:
        # Last-resort fallback: never let report formatting take down the app.
        print(f"Error in format_analysis_report: {e}")
        return str(raw_output), visuals
56
def format_observations(observations):
    """Build HTML cards for observation entries whose key mentions 'proportions'.

    Other observation keys are silently skipped. Returns the cards joined by
    newlines ('' for no matching entries).
    """
    cards = []
    for name, text in observations.items():
        # Only proportion-style observations are surfaced in the report.
        if 'proportions' not in name:
            continue
        heading = name.replace('_', ' ').title()
        cards.append(
            f"""
        <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <h3 style="margin: 0 0 10px 0; color: #4A708B;">{heading}</h3>
            <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{text}</pre>
        </div>
        """
        )
    return '\n'.join(cards)
66
def format_insights(insights, visuals):
    """Build numbered HTML cards for each insight, attaching a visual when available.

    The i-th insight is paired with ``visuals[i]`` if it exists; insights past
    the end of ``visuals`` render without an image. Returns '' for no insights.
    """
    blocks = []
    for position, insight_text in enumerate(insights.values()):
        # Pair this insight with its figure, if one was generated for it.
        if position < len(visuals):
            image_html = f'<img src="/file={visuals[position]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">'
        else:
            image_html = ''
        blocks.append(
            f"""
        <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
            <div style="display: flex; align-items: center; gap: 10px;">
                <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{position+1}</div>
                <p style="margin: 0; font-size: 16px;">{insight_text}</p>
            </div>
            {image_html}
        </div>
        """
        )
    return '\n'.join(blocks)
79
def analyze_data(csv_file, additional_notes=""):
    """Handle the 'Analyze' button: run the LLM CodeAgent over the uploaded CSV.

    Logs execution time, memory delta, and generated figures to Weights &
    Biases, then returns ``format_analysis_report``'s ``(report, visuals)``
    pair for the Markdown output and the gallery.
    """
    start_time = time.time()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 ** 2  # resident set, MiB

    # Start from an empty ./figures directory so stale plots from a previous
    # run are never displayed or logged.
    if os.path.exists('./figures'):
        shutil.rmtree('./figures')
    os.makedirs('./figures', exist_ok=True)

    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-data-analysis", config={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "additional_notes": additional_notes,
        # NOTE(review): assumes csv_file exposes .name (tempfile-like object);
        # with a filepath-typed File component this would be a plain str — confirm.
        "source_file": csv_file.name if csv_file else None
    })

    # Fresh agent per request; `model` is the module-level HfApiModel. The
    # agent's generated code may import only the listed analysis libraries.
    agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn"])
    # The prompt pins the output contract: a dict literal parseable by
    # ast.literal_eval, which format_analysis_report relies on.
    analysis_result = agent.run("""
You are an expert data analyst. Perform comprehensive analysis including:
1. Basic statistics and data quality checks
2. 3 insightful analytical questions about relationships in the data
3. Visualization of key patterns and correlations
4. Actionable real-world insights derived from findings.
Generate publication-quality visualizations and save to './figures/'.
Return the analysis results as a python dictionary that can be parsed by ast.literal_eval().
The dictionary should have the following structure:
{
    'observations': {
        'observation_1_key': 'observation_1_value',
        ...
    },
    'insights': {
        'insight_1_key': 'insight_1_value',
        ...
    }
}
""", additional_args={"additional_notes": additional_notes, "source_file": csv_file})

    execution_time = time.time() - start_time
    final_memory = process.memory_info().rss / 1024 ** 2
    memory_usage = final_memory - initial_memory
    wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})

    # Mirror every image the agent saved into the W&B run.
    visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
    for viz in visuals:
        wandb.log({os.path.basename(viz): wandb.Image(viz)})

    run.finish()
    return format_analysis_report(analysis_result, visuals)
129
def objective(trial, X_train, y_train, X_test, y_test):
    """Optuna objective: fit a seeded RandomForest with trial-sampled
    hyperparameters and return its held-out test accuracy."""
    # Sample the search space for this trial.
    params = {
        "n_estimators": trial.suggest_int("n_estimators", 50, 200),
        "max_depth": trial.suggest_int("max_depth", 3, 10),
    }
    forest = RandomForestClassifier(random_state=42, **params)
    forest.fit(X_train, y_train)
    return accuracy_score(y_test, forest.predict(X_test))
139
def tune_hyperparameters(csv_file, n_trials: int):
    """Tune RandomForest hyperparameters with Optuna on the uploaded CSV.

    The last CSV column is treated as the classification target. Logs metrics
    and SHAP/LIME figures to a dedicated W&B run and returns an HTML summary
    string for the results panel.
    """
    df = pd.read_csv(csv_file)
    y = df.iloc[:, -1]
    X = df.iloc[:, :-1]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # gr.Number delivers a float; Optuna requires an integer trial count.
    n_trials = int(n_trials)

    # This handler can run before analyze_data ever did, so it must create the
    # figures directory itself and open its own W&B run (wandb.log without an
    # active run fails).
    os.makedirs('./figures', exist_ok=True)
    run = wandb.init(project="huggingface-data-analysis", reinit=True,
                     config={"n_trials": n_trials})

    study = optuna.create_study(direction="maximize")
    study.optimize(lambda trial: objective(trial, X_train, y_train, X_test, y_test), n_trials=n_trials)

    # Refit on the training split with the best parameters found.
    best_params = study.best_params
    model = RandomForestClassifier(**best_params, random_state=42)
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)

    # Weighted averages cope with multi-class / imbalanced targets.
    accuracy = accuracy_score(y_test, predictions)
    precision = precision_score(y_test, predictions, average='weighted', zero_division=0)
    recall = recall_score(y_test, predictions, average='weighted', zero_division=0)
    f1 = f1_score(y_test, predictions, average='weighted', zero_division=0)

    wandb.log({
        "best_params": best_params,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    })

    # SHAP: global feature-importance summary on the test split.
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)
    shap.summary_plot(shap_values, X_test, show=False)
    shap_fig_path = "./figures/shap_summary.png"
    plt.savefig(shap_fig_path)
    wandb.log({"shap_summary": wandb.Image(shap_fig_path)})
    plt.clf()

    # LIME: local explanation for the first test row only.
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(X_train.values, feature_names=X_train.columns, class_names=['target'], mode='classification')
    lime_exp = lime_explainer.explain_instance(X_test.iloc[0].values, model.predict_proba)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_path = "./figures/lime_explanation.png"
    lime_fig.savefig(lime_path)
    wandb.log({"lime_explanation": wandb.Image(lime_path)})
    plt.clf()

    run.finish()
    return f"Best Hyperparameters: {best_params}<br>Accuracy: {accuracy}<br>Precision: {precision}<br>Recall: {recall}<br>F1-score: {f1}"
185
 
186
# Gradio Interface
# Left column: inputs and action buttons; right column: analysis report,
# tuning results, and the figure gallery.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
    with gr.Row():
        with gr.Column():
            # NOTE(review): type="filepath" yields a plain path string, but
            # analyze_data accesses ``csv_file.name`` — confirm these agree.
            file_input = gr.File(label="Upload CSV Dataset", type="filepath")
            notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
            analyze_btn = gr.Button("Analyze", variant="primary")
            optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
            tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
        with gr.Column():
            analysis_output = gr.Markdown("### Analysis results will appear here...")
            optuna_output = gr.HTML(label="Hyperparameter Tuning Results")
            gallery = gr.Gallery(label="Data Visualizations", columns=2)

    # analyze_data returns (report_html, visuals); tune_hyperparameters
    # returns a single HTML string.
    analyze_btn.click(fn=analyze_data, inputs=[file_input, notes_input], outputs=[analysis_output, gallery])
    tune_btn.click(fn=tune_hyperparameters, inputs=[file_input, optuna_trials], outputs=[optuna_output])

# debug=True surfaces handler tracebacks in the UI/console.
demo.launch(debug=True)
 
 
 
 
1
  import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import shap
7
+ import lime.lime_tabular
8
  import optuna
9
+ import wandb
10
  import ast
11
+ from smolagents import HfApiModel, CodeAgent
12
+ from huggingface_hub import login
13
  from sklearn.ensemble import RandomForestClassifier
14
+ from sklearn.model_selection import train_test_split, cross_val_score
15
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
 
 
 
 
16
 
17
# Authenticate Hugging Face Hub
# HF_TOKEN must be provided via the environment (e.g. a Space secret).
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

# Setup SmolAgent with LLM
# Remote Mixtral endpoint wrapped as a smolagents model.
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
# Code-writing agent: no extra tools, but its generated code may import the
# listed analysis libraries; capped at 10 reasoning/execution iterations.
agent = CodeAgent(
    tools=[],
    model=model,
    additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn"],
    max_iterations=10,
)
30
# Data cleaning function
def clean_data(df):
    """Drop fully-empty rows/columns, mean-impute numeric NaNs, and keep only
    numeric columns (non-numeric columns are discarded entirely)."""
    trimmed = df.dropna(axis=1, how='all')
    trimmed = trimmed.dropna(axis=0, how='all')
    # Per-column means only exist for numeric columns, so non-numeric NaNs are
    # left untouched (they are dropped by the final selection anyway).
    column_means = trimmed.mean(numeric_only=True)
    imputed = trimmed.fillna(column_means)
    return imputed.select_dtypes(include=[np.number])
36
+
37
# Global dataframe shared by the upload/train/explain handlers.
df_global = None

# Upload and clean
def upload_file(file):
    """Load the uploaded CSV/Excel file, clean it, cache it, and preview it.

    Stores the cleaned frame in the module-level ``df_global`` for the other
    handlers and returns ``df.head()`` for the preview DataFrame component.
    """
    global df_global
    # Gradio may hand over a tempfile-like object (with ``.name``) or a plain
    # path string depending on the File component's ``type``; accept both.
    path = file if isinstance(file, str) else file.name
    # Lower-case the extension so ``.CSV`` uploads are parsed as CSV instead
    # of falling through to the Excel reader.
    ext = os.path.splitext(path)[-1].lower()
    df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
    df = clean_data(df)
    df_global = df
    return df.head()
48
+
49
# Run SmolAgent for analysis
def run_agent(_):
    """Ask the CodeAgent for insights/visualizations on the uploaded dataset.

    The Gradio input is ignored; the function reads the module-level
    ``df_global`` populated by ``upload_file``. Returns the agent output (or
    an error message) as a string for the insights textbox.
    """
    if df_global is None:
        # No dataset yet: tell the user instead of surfacing an exception.
        return "Please upload a dataset first."
    try:
        # CodeAgent.run expects the task as a string (the original code passed
        # the DataFrame positionally with an unsupported ``instructions=``
        # kwarg, which raised on every call); the dataframe goes through
        # additional_args so the agent's generated code can reference it.
        output = agent.run(
            "Generate 5 data insights and 5 data visualizations. Visualizations should be saved in current working directory.",
            additional_args={"df": df_global},
        )
        return str(output)
    except Exception as e:
        return f"SmolAgent Error: {str(e)}"
59
+
60
# Train model + Optuna + WandB
def train_model(_):
    """Tune a RandomForest with Optuna on the uploaded dataset and log to W&B.

    The last column of ``df_global`` is treated as the classification target.
    Returns ``(metrics_dict, top_trials_dataframe)`` for the JSON and
    DataFrame outputs. The Gradio input argument is ignored.
    """
    if df_global is None:
        # Nothing uploaded yet: report the problem instead of crashing on None.
        return {"error": "Please upload a dataset first."}, pd.DataFrame()

    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    wandb_run = wandb.init(project="huggingface-data-analysis", name="Optuna_Run", reinit=True)

    # Last column is assumed to be the target label.
    target = df_global.columns[-1]
    X = df_global.drop(target, axis=1)
    y = df_global[target]
    # Fixed seed so repeated runs tune against the same split.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    def objective(trial):
        # Sample the search space and score with 3-fold cross-validation.
        params = {
            "n_estimators": trial.suggest_int("n_estimators", 50, 200),
            "max_depth": trial.suggest_int("max_depth", 3, 10),
        }
        model = RandomForestClassifier(random_state=42, **params)
        score = cross_val_score(model, X_train, y_train, cv=3).mean()
        wandb.log(params | {"cv_score": score})
        return score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=15)

    # Refit on the full training split with the best parameters found.
    best_params = study.best_params
    model = RandomForestClassifier(random_state=42, **best_params)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Weighted averages handle multi-class / imbalanced targets.
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
        "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
        "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
    }
    wandb.log(metrics)
    wandb_run.finish()

    # trials_dataframe() already returns a DataFrame; just rank and trim it.
    top_trials = study.trials_dataframe().sort_values(by="value", ascending=False).head(7)
    return metrics, top_trials
99
+
100
# SHAP & LIME
def explainability(_):
    """Fit a fresh RandomForest on the cleaned dataset and produce SHAP + LIME plots.

    Reads the module-level ``df_global`` (the Gradio input is ignored), saves
    both figures to the working directory, logs them to W&B, and returns the
    two image paths for the Gradio Image components.
    """
    # Last column is assumed to be the target label — same convention as
    # train_model.
    target = df_global.columns[-1]
    X = df_global.drop(target, axis=1)
    y = df_global[target]

    # Default-parameter forest trained on ALL rows; this is for explanation
    # only, not the tuned model from train_model.
    model = RandomForestClassifier()
    model.fit(X, y)

    # SHAP: global feature-importance summary over the full feature matrix.
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    shap.summary_plot(shap_values, X, show=False)
    shap_fig_path = "./shap_plot.png"
    plt.savefig(shap_fig_path)
    plt.clf()

    # LIME: local explanation for the first row only.
    # NOTE(review): class_names=['target'] is a placeholder label, not the
    # dataset's real class names — confirm whether actual labels should be passed.
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(X.values, feature_names=X.columns, class_names=['target'], mode="classification")
    lime_exp = lime_explainer.explain_instance(X.iloc[0].values, model.predict_proba)
    lime_fig = lime_exp.as_pyplot_figure()
    lime_fig_path = "./lime_plot.png"
    lime_fig.savefig(lime_fig_path)
    plt.clf()

    # Log to wandb in a short-lived dedicated run.
    wandb.init(project="huggingface-data-analysis", name="Explainability", reinit=True)
    wandb.log({
        "shap_summary": wandb.Image(shap_fig_path),
        "lime_explanation": wandb.Image(lime_fig_path)
    })
    wandb.finish()

    return shap_fig_path, lime_fig_path
134
+
135
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")

    # Row 1: upload/clean on the left, agent insights on the right.
    with gr.Row():
        with gr.Column():
            # NOTE(review): type="file" is the legacy value; newer Gradio
            # releases expect type="filepath"/"binary" — confirm against the
            # pinned Gradio version (upload_file reads ``file.name``).
            file_input = gr.File(label="Upload CSV or Excel", type="file")
            upload_btn = gr.Button("Upload & Clean")
            df_output = gr.DataFrame(label="Cleaned Data Preview")

        with gr.Column():
            insights_output = gr.Textbox(label="Insights from SmolAgent", lines=15)
            agent_btn = gr.Button("Run AI Agent (5 Insights + 5 Visualizations)")

    # Row 2: Optuna tuning controls and results.
    with gr.Row():
        train_btn = gr.Button("Train Model with Optuna + WandB")
        metrics_output = gr.JSON(label="Performance Metrics")
        trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")

    # Row 3: explainability plots.
    with gr.Row():
        explain_btn = gr.Button("SHAP + LIME Explainability")
        shap_img = gr.Image(label="SHAP Summary Plot")
        lime_img = gr.Image(label="LIME Explanation")

    # Wiring: the later handlers read the module-level df_global and ignore
    # the DataFrame component passed as their input.
    upload_btn.click(fn=upload_file, inputs=file_input, outputs=df_output)
    agent_btn.click(fn=run_agent, inputs=df_output, outputs=insights_output)
    train_btn.click(fn=train_model, inputs=df_output, outputs=[metrics_output, trials_output])
    explain_btn.click(fn=explainability, inputs=df_output, outputs=[shap_img, lime_img])

# Launch the app (blocking call).
demo.launch()