pavanmutha commited on
Commit
9ff587a
·
verified ·
1 Parent(s): 94ecd73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +494 -282
app.py CHANGED
@@ -8,169 +8,111 @@ import time
8
  import psutil
9
  import optuna
10
  import ast
11
- import numpy as np
12
  import pandas as pd
 
13
  from sklearn.model_selection import train_test_split
14
- from sklearn.ensemble import RandomForestClassifier
15
- from sklearn.metrics import (accuracy_score, precision_score,
16
- recall_score, f1_score, classification_report)
17
- from lime.lime_text import LimeTextExplainer
18
- from functools import lru_cache
19
- import shap
20
  import matplotlib.pyplot as plt
21
- from sklearn.preprocessing import StandardScaler, PolynomialFeatures
22
- from sklearn.feature_selection import SelectKBest, f_classif
 
 
23
 
24
  # Authenticate Hugging Face
25
  hf_token = os.getenv("HF_TOKEN")
26
- if hf_token:
27
- login(token=hf_token, add_to_git_credential=True)
28
 
29
  # Initialize Model
30
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
31
 
32
- def detect_target_column(df):
33
- """Try to automatically detect the target column"""
34
- # Common target column names
35
- possible_targets = ['target', 'label', 'class', 'y', 'outcome', 'result']
36
-
37
- for col in possible_targets:
38
- if col in df.columns:
39
- return col
40
-
41
- # If none found, return the last column by default
42
- return df.columns[-1]
43
-
44
- @lru_cache(maxsize=10)
45
- def cached_generate_lime_explanation(insight_text: str, class_names: tuple = ("Negative", "Positive")):
46
- """Generate and cache LIME explanations to improve performance"""
47
- explainer = LimeTextExplainer(class_names=class_names)
48
-
49
- def classifier_fn(texts):
50
- responses = []
51
- for text in texts:
52
- prompt = f"""
53
- Analyze the following data insight and classify its sentiment:
54
- Insight: {text}
55
-
56
- Return response as a JSON format with 'positive' and 'negative' scores:
57
- {{"positive": 0.0-1.0, "negative": 0.0-1.0}}
58
- """
59
- response = model.generate(prompt, max_tokens=100)
60
- try:
61
- response_dict = ast.literal_eval(response)
62
- pos = float(response_dict.get("positive", 0))
63
- neg = float(response_dict.get("negative", 0))
64
- total = pos + neg
65
- if total > 0:
66
- pos /= total
67
- neg /= total
68
- responses.append([neg, pos])
69
- except:
70
- responses.append([0.5, 0.5])
71
- return np.array(responses)
72
-
73
- exp = explainer.explain_instance(
74
- insight_text,
75
- classifier_fn,
76
- num_features=10,
77
- top_labels=1,
78
- num_samples=100
79
- )
80
- return exp.as_html()
81
-
82
- def generate_shap_explanation(model, X_train, X_test):
83
- """Generate SHAP explanations for model predictions"""
84
  try:
85
- explainer = shap.TreeExplainer(model)
86
- shap_values = explainer.shap_values(X_test)
87
 
88
- # Save SHAP plots
89
- shap_figures = []
90
- for plot_type in ['summary', 'bar']:
91
- plt.figure()
92
- if plot_type == 'summary':
93
- shap.summary_plot(shap_values, X_test, plot_size=(10, 8), show=False)
94
- elif plot_type == 'bar':
95
- shap.summary_plot(shap_values, X_test, plot_type="bar", show=False)
96
-
97
- fig_path = f'./figures/shap_{plot_type}.png'
98
- plt.savefig(fig_path, bbox_inches='tight')
99
- plt.close()
100
- shap_figures.append(fig_path)
101
-
102
- return shap_figures
103
- except Exception as e:
104
- print(f"Error generating SHAP explanation: {e}")
105
- return []
106
 
107
- def feature_engineering_experiments(X_train, X_test, y_train, y_test):
108
- """Run different feature engineering approaches and compare results"""
109
- results = {}
110
-
111
- try:
112
- # Original features baseline
113
- base_model = RandomForestClassifier(random_state=42)
114
- base_model.fit(X_train, y_train)
115
- y_pred = base_model.predict(X_test)
116
- results['baseline'] = {
117
- 'accuracy': accuracy_score(y_test, y_pred),
118
- 'precision': precision_score(y_test, y_pred, average='weighted'),
119
- 'recall': recall_score(y_test, y_pred, average='weighted'),
120
- 'f1': f1_score(y_test, y_pred, average='weighted')
121
- }
122
-
123
- # Standardized features
124
- scaler = StandardScaler()
125
- X_train_scaled = scaler.fit_transform(X_train)
126
- X_test_scaled = scaler.transform(X_test)
127
-
128
- scaled_model = RandomForestClassifier(random_state=42)
129
- scaled_model.fit(X_train_scaled, y_train)
130
- y_pred = scaled_model.predict(X_test_scaled)
131
- results['scaled'] = {
132
- 'accuracy': accuracy_score(y_test, y_pred),
133
- 'precision': precision_score(y_test, y_pred, average='weighted'),
134
- 'recall': recall_score(y_test, y_pred, average='weighted'),
135
- 'f1': f1_score(y_test, y_pred, average='weighted')
136
- }
137
-
138
- # Polynomial features (only if few features)
139
- if X_train.shape[1] < 10:
140
- poly = PolynomialFeatures(degree=2, interaction_only=True)
141
- X_train_poly = poly.fit_transform(X_train)
142
- X_test_poly = poly.transform(X_test)
143
-
144
- poly_model = RandomForestClassifier(random_state=42)
145
- poly_model.fit(X_train_poly, y_train)
146
- y_pred = poly_model.predict(X_test_poly)
147
- results['polynomial'] = {
148
- 'accuracy': accuracy_score(y_test, y_pred),
149
- 'precision': precision_score(y_test, y_pred, average='weighted'),
150
- 'recall': recall_score(y_test, y_pred, average='weighted'),
151
- 'f1': f1_score(y_test, y_pred, average='weighted')
152
- }
153
 
154
- # Feature selection
155
- k = min(5, X_train.shape[1])
156
- selector = SelectKBest(f_classif, k=k)
157
- X_train_selected = selector.fit_transform(X_train, y_train)
158
- X_test_selected = selector.transform(X_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- selected_model = RandomForestClassifier(random_state=42)
161
- selected_model.fit(X_train_selected, y_train)
162
- y_pred = selected_model.predict(X_test_selected)
163
- results['selected'] = {
164
- 'accuracy': accuracy_score(y_test, y_pred),
165
- 'precision': precision_score(y_test, y_pred, average='weighted'),
166
- 'recall': recall_score(y_test, y_pred, average='weighted'),
167
- 'f1': f1_score(y_test, y_pred, average='weighted')
168
- }
169
 
170
- except Exception as e:
171
- print(f"Error in feature engineering experiments: {e}")
172
-
173
- return results
 
 
 
174
 
175
  def analyze_data(csv_file, additional_notes=""):
176
  start_time = time.time()
@@ -181,160 +123,430 @@ def analyze_data(csv_file, additional_notes=""):
181
  shutil.rmtree('./figures')
182
  os.makedirs('./figures', exist_ok=True)
183
 
184
- try:
185
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
186
- run = wandb.init(project="huggingface-data-analysis", config={
187
- "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
188
- "additional_notes": additional_notes,
189
- "source_file": csv_file.name if csv_file else None
190
- })
191
- except:
192
- run = None
193
 
194
- try:
195
- # Load and preprocess data
196
- data = pd.read_csv(csv_file.name)
197
- target_col = detect_target_column(data)
198
- X = data.drop(target_col, axis=1)
199
- y = data[target_col]
200
-
201
- # Split data
202
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
203
-
204
- # Feature engineering experiments
205
- feat_eng_results = feature_engineering_experiments(X_train, X_test, y_train, y_test)
206
- if run:
207
- wandb.log({"feature_engineering": feat_eng_results})
208
-
209
- # Train final model with best approach (using baseline here for demo)
210
- final_model = RandomForestClassifier(random_state=42)
211
- final_model.fit(X_train, y_train)
212
-
213
- # Generate SHAP explanations
214
- shap_figs = generate_shap_explanation(final_model, X_train, X_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
217
- analysis_result = agent.run(f"""
218
- You are an expert data analyst. Perform comprehensive analysis including:
219
- 1. Basic statistics and data quality checks
220
- 2. Feature engineering experiment results: {feat_eng_results}
221
- 3. Target column used: {target_col}
222
- 4. 3 insightful analytical questions about relationships in the data
223
- 5. Visualization of key patterns and correlations
224
- 6. Actionable real-world insights derived from findings
225
- Generate publication-quality visualizations and save to './figures/'
226
- """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
227
 
228
- execution_time = time.time() - start_time
229
- final_memory = process.memory_info().rss / 1024 ** 2
230
- memory_usage = final_memory - initial_memory
231
- if run:
232
- wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
235
- visuals.extend(shap_figs) # Add SHAP visualizations
 
 
 
 
 
 
 
236
 
237
- if run:
238
- for viz in visuals:
239
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
- if run:
242
- run.finish()
243
- return format_analysis_report(analysis_result, visuals)
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
 
 
 
245
  except Exception as e:
246
- if run:
247
- run.finish()
248
- return f"Error analyzing data: {str(e)}", [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
- def tune_hyperparameters(csv_file, n_trials: int):
251
- """Run hyperparameter optimization with Optuna"""
252
  try:
253
- data = pd.read_csv(csv_file.name)
254
- target_col = detect_target_column(data)
255
- X = data.drop(target_col, axis=1)
256
- y = data[target_col]
257
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
 
258
 
259
- study = optuna.create_study(direction="maximize")
260
- study.optimize(lambda trial: objective(trial, X_train, y_train, X_test, y_test), n_trials=n_trials)
261
 
262
- # Train final model with best params
263
- best_model = RandomForestClassifier(**study.best_params, random_state=42)
264
- best_model.fit(X_train, y_train)
265
- y_pred = best_model.predict(X_test)
266
 
267
- metrics = {
268
- 'accuracy': accuracy_score(y_test, y_pred),
269
- 'precision': precision_score(y_test, y_pred, average='weighted'),
270
- 'recall': recall_score(y_test, y_pred, average='weighted'),
271
- 'f1': f1_score(y_test, y_pred, average='weighted')
272
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- return f"Best Hyperparameters: {study.best_params}\n\nValidation Metrics:\n{metrics}"
 
275
  except Exception as e:
276
- return f"Error tuning hyperparameters: {str(e)}"
 
277
 
278
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
279
- gr.Markdown("## 📊 AI Data Analysis Agent with Explainability")
 
 
280
 
281
- insights_store = gr.State([])
 
 
 
 
 
282
 
283
- with gr.Row():
284
- with gr.Column():
285
- file_input = gr.File(label="Upload CSV Dataset", type="filepath")
286
- notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
287
- analyze_btn = gr.Button("Analyze", variant="primary")
288
- optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
289
- tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
290
-
291
- insight_dropdown = gr.Dropdown(
292
- label="Select Insight to Explain",
293
- interactive=True,
294
- visible=False
295
- )
296
- explain_btn = gr.Button("Generate Explanation", variant="primary", visible=False)
297
-
298
- with gr.Column():
299
- analysis_output = gr.HTML("### Analysis results will appear here...")
300
- optuna_output = gr.Textbox(label="Optimization Results")
301
- gallery = gr.Gallery(label="Data Visualizations", columns=2)
302
- explanation_html = gr.HTML(label="Model Explanation")
303
-
304
- def update_insight_dropdown(insights):
305
- if insights and len(insights) > 0:
306
- return gr.Dropdown(
307
- choices=[(f"Insight {i+1}", insight) for i, insight in enumerate(insights)],
308
- value=insights[0],
309
- visible=True
310
- ), gr.Button(visible=True)
311
- return gr.Dropdown(visible=False), gr.Button(visible=False)
312
-
313
- def generate_explanation(selected_insight):
314
- if not selected_insight:
315
- return "<p>Please select an insight first</p>"
316
- return cached_generate_lime_explanation(selected_insight)
317
-
318
- analyze_btn.click(
319
- fn=analyze_data,
320
- inputs=[file_input, notes_input],
321
- outputs=[analysis_output, gallery, insights_store]
322
- ).then(
323
- fn=update_insight_dropdown,
324
- inputs=insights_store,
325
- outputs=[insight_dropdown, explain_btn]
326
  )
327
 
328
- explain_btn.click(
329
- fn=generate_explanation,
330
- inputs=insight_dropdown,
331
- outputs=explanation_html
332
- )
333
 
334
- tune_btn.click(
335
- fn=tune_hyperparameters,
336
- inputs=[file_input, optuna_trials],
337
- outputs=[optuna_output]
338
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
- demo.launch(debug=True, share=True)
 
8
  import psutil
9
  import optuna
10
  import ast
 
11
  import pandas as pd
12
+ import numpy as np
13
  from sklearn.model_selection import train_test_split
14
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
15
+ from sklearn.compose import ColumnTransformer
16
+ from sklearn.pipeline import Pipeline
17
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
18
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
 
19
  import matplotlib.pyplot as plt
20
+ import seaborn as sns
21
+ import shap
22
+ import lime
23
+ from lime import lime_tabular
24
 
25
  # Authenticate Hugging Face
26
  hf_token = os.getenv("HF_TOKEN")
27
+ login(token=hf_token, add_to_git_credential=True)
 
28
 
29
  # Initialize Model
30
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
31
 
32
+ def format_analysis_report(raw_output, visuals):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  try:
34
+ analysis_dict = raw_output if isinstance(raw_output, dict) else ast.literal_eval(str(raw_output))
 
35
 
36
+ report = f"""
37
+ <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
38
+ <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
39
+ <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
40
+ <h2 style="color: #2B547E;">🔍 Key Observations</h2>
41
+ {format_observations(analysis_dict.get('observations', {}))}
42
+ </div>
43
+ <div style="margin-top: 30px;">
44
+ <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
45
+ {format_insights(analysis_dict.get('insights', {}), visuals)}
46
+ </div>
47
+ </div>
48
+ """
49
+ return report, visuals
50
+ except:
51
+ return raw_output, visuals
 
 
52
 
53
+ def format_observations(observations):
54
+ return '\n'.join([
55
+ f"""
56
+ <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
57
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
58
+ <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
59
+ </div>
60
+ """ for key, value in observations.items() if 'proportions' in key
61
+ ])
62
+
63
+ def format_insights(insights, visuals):
64
+ return '\n'.join([
65
+ f"""
66
+ <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
67
+ <div style="display: flex; align-items: center; gap: 10px;">
68
+ <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
69
+ <p style="margin: 0; font-size: 16px;">{insight}</p>
70
+ </div>
71
+ {f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if idx < len(visuals) else ''}
72
+ </div>
73
+ """ for idx, (key, insight) in enumerate(insights.items())
74
+ ])
75
+
76
+ def format_model_evaluation(metrics_dict, feature_importance_path=None, explainability_path=None):
77
+ report = f"""
78
+ <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
79
+ <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">🧠 Model Evaluation Report</h1>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
82
+ <h2 style="color: #2B547E;">📈 Performance Metrics</h2>
83
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
84
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
85
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
86
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('accuracy', 'N/A'):.4f}</p>
87
+ </div>
88
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
89
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
90
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('precision', 'N/A'):.4f}</p>
91
+ </div>
92
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
93
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
94
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('recall', 'N/A'):.4f}</p>
95
+ </div>
96
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
97
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
98
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('f1', 'N/A'):.4f}</p>
99
+ </div>
100
+ </div>
101
+ </div>
102
 
103
+ <div style="margin-top: 30px;">
104
+ <h2 style="color: #2B547E;">📊 Feature Importance & Explainability</h2>
105
+ {f'<img src="/file={feature_importance_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if feature_importance_path else ''}
106
+ {f'<img src="/file={explainability_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if explainability_path else ''}
107
+ </div>
 
 
 
 
108
 
109
+ <div style="margin-top: 30px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
110
+ <h2 style="color: #2B547E;">🔄 Hyperparameters</h2>
111
+ <pre style="margin: 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">{metrics_dict.get('best_params', 'N/A')}</pre>
112
+ </div>
113
+ </div>
114
+ """
115
+ return report
116
 
117
  def analyze_data(csv_file, additional_notes=""):
118
  start_time = time.time()
 
123
  shutil.rmtree('./figures')
124
  os.makedirs('./figures', exist_ok=True)
125
 
126
+ wandb.login(key=os.environ.get('WANDB_API_KEY'))
127
+ run = wandb.init(project="huggingface-data-analysis", config={
128
+ "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
129
+ "additional_notes": additional_notes,
130
+ "source_file": csv_file.name if csv_file else None
131
+ })
 
 
 
132
 
133
+ agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
134
+ analysis_result = agent.run("""
135
+ You are an expert data analyst. Perform comprehensive analysis including:
136
+ 1. Basic statistics and data quality checks
137
+ 2. 3 insightful analytical questions about relationships in the data
138
+ 3. Visualization of key patterns and correlations
139
+ 4. Actionable real-world insights derived from findings
140
+ Generate publication-quality visualizations and save to './figures/'
141
+ """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
142
+
143
+ execution_time = time.time() - start_time
144
+ final_memory = process.memory_info().rss / 1024 ** 2
145
+ memory_usage = final_memory - initial_memory
146
+ wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
147
+
148
+ visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
149
+ for viz in visuals:
150
+ wandb.log({os.path.basename(viz): wandb.Image(viz)})
151
+
152
+ run.finish()
153
+ return format_analysis_report(analysis_result, visuals)
154
+
155
+ def preprocess_features(data, target_column, feature_engineering=True):
156
+ """
157
+ Preprocess features with optional feature engineering
158
+ """
159
+ # Check if data is loaded
160
+ if data is None or not isinstance(data, pd.DataFrame):
161
+ return None, None, None, None, None
162
+
163
+ # Separate features and target
164
+ if target_column not in data.columns:
165
+ # Try to infer target column if it's not specified
166
+ for col in ['target', 'label', 'class', 'outcome', 'y']:
167
+ if col in data.columns:
168
+ target_column = col
169
+ break
170
+ else:
171
+ return None, None, None, None, None
172
+
173
+ X = data.drop(target_column, axis=1)
174
+ y = data[target_column]
175
+
176
+ # Split data
177
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
178
+
179
+ # Identify numerical and categorical columns
180
+ numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
181
+ categorical_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()
182
+
183
+ # Basic preprocessing
184
+ numeric_transformer = Pipeline(steps=[
185
+ ('scaler', StandardScaler())
186
+ ])
187
+
188
+ categorical_transformer = Pipeline(steps=[
189
+ ('onehot', OneHotEncoder(handle_unknown='ignore'))
190
+ ])
191
+
192
+ preprocessor = ColumnTransformer(
193
+ transformers=[
194
+ ('num', numeric_transformer, numerical_cols),
195
+ ('cat', categorical_transformer, categorical_cols)
196
+ ])
197
+
198
+ # Feature engineering when enabled
199
+ if feature_engineering:
200
+ # Create interaction terms between numerical features
201
+ for i, col1 in enumerate(numerical_cols):
202
+ for col2 in numerical_cols[i+1:]:
203
+ if len(numerical_cols) > 1:
204
+ X_train[f'{col1}_{col2}_interaction'] = X_train[col1] * X_train[col2]
205
+ X_test[f'{col1}_{col2}_interaction'] = X_test[col1] * X_test[col2]
206
 
207
+ # Create polynomial features for numerical columns (quadratic terms)
208
+ for col in numerical_cols:
209
+ X_train[f'{col}_squared'] = X_train[col] ** 2
210
+ X_test[f'{col}_squared'] = X_test[col] ** 2
 
 
 
 
 
 
 
211
 
212
+ # Create aggregate features for categorical columns
213
+ for col in categorical_cols:
214
+ # For each categorical column, calculate mean of numerical columns grouped by categories
215
+ for num_col in numerical_cols:
216
+ if num_col in X_train.columns:
217
+ agg_map = X_train.groupby(col)[num_col].mean().to_dict()
218
+ X_train[f'{col}_{num_col}_agg'] = X_train[col].map(agg_map)
219
+ X_test[f'{col}_{num_col}_agg'] = X_test[col].map(agg_map)
220
+
221
+ return X_train, X_test, y_train, y_test, preprocessor
222
+
223
+ def create_shap_plot(model, X_test, feature_names):
224
+ """Create SHAP summary plot for model explainability"""
225
+ plt.figure(figsize=(12, 8))
226
+
227
+ # For tree-based models
228
+ if hasattr(model, 'feature_importances_'):
229
+ explainer = shap.TreeExplainer(model)
230
+ shap_values = explainer.shap_values(X_test)
231
 
232
+ # Handle multi-class case
233
+ if isinstance(shap_values, list):
234
+ shap_values = shap_values[1] # Use the positive class
235
+
236
+ shap.summary_plot(shap_values, X_test, feature_names=feature_names, show=False)
237
+ else:
238
+ # Fallback for non-tree models
239
+ explainer = shap.KernelExplainer(model.predict_proba, shap.sample(X_test, 50))
240
+ shap_values = explainer.shap_values(X_test[:50])
241
 
242
+ # Handle multi-class case
243
+ if isinstance(shap_values, list):
244
+ shap_values = shap_values[1] # Use the positive class
245
+
246
+ shap.summary_plot(shap_values, X_test[:50], feature_names=feature_names, show=False)
247
+
248
+ plt.tight_layout()
249
+ file_path = './figures/shap_summary.png'
250
+ plt.savefig(file_path)
251
+ plt.close()
252
+ return file_path
253
+
254
+ def create_lime_explanation(model, X_train, X_test, feature_names):
255
+ """Create LIME explanation for a sample instance"""
256
+ # Create LIME explainer
257
+ explainer = lime_tabular.LimeTabularExplainer(
258
+ X_train,
259
+ feature_names=feature_names,
260
+ class_names=["Negative", "Positive"],
261
+ mode="classification"
262
+ )
263
+
264
+ # Explain a sample instance
265
+ instance_idx = 0
266
+ exp = explainer.explain_instance(
267
+ X_test[instance_idx],
268
+ model.predict_proba,
269
+ num_features=10
270
+ )
271
+
272
+ # Plot explanation
273
+ plt.figure(figsize=(10, 6))
274
+ exp.as_pyplot_figure()
275
+ plt.tight_layout()
276
+ file_path = './figures/lime_explanation.png'
277
+ plt.savefig(file_path)
278
+ plt.close()
279
+ return file_path
280
+
281
+ def create_feature_importance_plot(model, feature_names):
282
+ """Create feature importance plot if model supports it"""
283
+ if hasattr(model, 'feature_importances_'):
284
+ importances = model.feature_importances_
285
+ indices = np.argsort(importances)[-20:] # Top 20 features
286
 
287
+ plt.figure(figsize=(12, 8))
288
+ plt.title('Feature Importances')
289
+ plt.barh(range(len(indices)), importances[indices], align='center')
290
+ plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
291
+ plt.xlabel('Relative Importance')
292
+ plt.tight_layout()
293
+ file_path = './figures/feature_importance.png'
294
+ plt.savefig(file_path)
295
+ plt.close()
296
+ return file_path
297
+ return None
298
+
299
+ def train_and_evaluate_model(csv_file, target_column, model_type, feature_eng_enabled=True, explainer_type="shap"):
300
+ """Train, evaluate model with metrics and explainability"""
301
+ if not csv_file:
302
+ return "Please upload a CSV file", None, []
303
 
304
+ # Load data
305
+ try:
306
+ data = pd.read_csv(csv_file)
307
  except Exception as e:
308
+ return f"Error loading data: {str(e)}", None, []
309
+
310
+ # Preprocess data
311
+ X_train, X_test, y_train, y_test, preprocessor = preprocess_features(
312
+ data, target_column, feature_engineering=feature_eng_enabled
313
+ )
314
+
315
+ if X_train is None:
316
+ return f"Error: Could not identify target column '{target_column}'", None, []
317
+
318
+ # Apply preprocessing
319
+ X_train_processed = X_train
320
+ X_test_processed = X_test
321
+
322
+ # Select model
323
+ if model_type == "random_forest":
324
+ model = RandomForestClassifier(random_state=42)
325
+ else: # Default to gradient boosting
326
+ model = GradientBoostingClassifier(random_state=42)
327
+
328
+ # Train model
329
+ model.fit(X_train_processed, y_train)
330
+
331
+ # Make predictions
332
+ y_pred = model.predict(X_test_processed)
333
+
334
+ # Calculate metrics
335
+ metrics = {
336
+ 'accuracy': accuracy_score(y_test, y_pred),
337
+ 'precision': precision_score(y_test, y_pred, average='weighted'),
338
+ 'recall': recall_score(y_test, y_pred, average='weighted'),
339
+ 'f1': f1_score(y_test, y_pred, average='weighted'),
340
+ }
341
+
342
+ # Generate feature names
343
+ feature_names = X_train_processed.columns.tolist()
344
+
345
+ # Create feature importance plot
346
+ feature_importance_path = create_feature_importance_plot(model, feature_names)
347
+
348
+ # Create explainability visualization
349
+ explainability_path = None
350
+ if explainer_type == "shap":
351
+ explainability_path = create_shap_plot(model, X_test_processed, feature_names)
352
+ else: # LIME
353
+ explainability_path = create_lime_explanation(model, X_train_processed.values,
354
+ X_test_processed.values, feature_names)
355
+
356
+ # Log to wandb
357
+ wandb.login(key=os.environ.get('WANDB_API_KEY'))
358
+ run = wandb.init(project="huggingface-model-evaluation", config={
359
+ "model_type": model_type,
360
+ "feature_engineering": feature_eng_enabled,
361
+ "explainer": explainer_type,
362
+ "metrics": metrics
363
+ })
364
+
365
+ wandb.log(metrics)
366
+
367
+ if feature_importance_path:
368
+ wandb.log({"feature_importance": wandb.Image(feature_importance_path)})
369
+
370
+ if explainability_path:
371
+ wandb.log({"explainability": wandb.Image(explainability_path)})
372
+
373
+ run.finish()
374
+
375
+ # Return results
376
+ results = [feature_importance_path, explainability_path] if feature_importance_path and explainability_path else []
377
+ return format_model_evaluation(metrics, feature_importance_path, explainability_path), None, results
378
 
379
+ def objective(trial, csv_file, target_column, model_type, feature_eng_enabled=True):
380
+ """Objective function for Optuna hyperparameter optimization"""
381
  try:
382
+ # Load data
383
+ data = pd.read_csv(csv_file)
384
+
385
+ # Preprocess data
386
+ X_train, X_test, y_train, y_test, preprocessor = preprocess_features(
387
+ data, target_column, feature_engineering=feature_eng_enabled
388
+ )
389
 
390
+ if X_train is None:
391
+ return 0.0
392
 
393
+ # Apply preprocessing
394
+ X_train_processed = X_train
395
+ X_test_processed = X_test
 
396
 
397
+ # Hyperparameters based on model type
398
+ if model_type == "random_forest":
399
+ model = RandomForestClassifier(
400
+ n_estimators=trial.suggest_int("n_estimators", 50, 500),
401
+ max_depth=trial.suggest_int("max_depth", 3, 20),
402
+ min_samples_split=trial.suggest_int("min_samples_split", 2, 10),
403
+ min_samples_leaf=trial.suggest_int("min_samples_leaf", 1, 4),
404
+ bootstrap=trial.suggest_categorical("bootstrap", [True, False]),
405
+ random_state=42
406
+ )
407
+ else: # Gradient Boosting
408
+ model = GradientBoostingClassifier(
409
+ learning_rate=trial.suggest_float("learning_rate", 0.01, 0.3),
410
+ n_estimators=trial.suggest_int("n_estimators", 50, 500),
411
+ max_depth=trial.suggest_int("max_depth", 3, 10),
412
+ min_samples_split=trial.suggest_int("min_samples_split", 2, 10),
413
+ min_samples_leaf=trial.suggest_int("min_samples_leaf", 1, 4),
414
+ subsample=trial.suggest_float("subsample", 0.6, 1.0),
415
+ random_state=42
416
+ )
417
+
418
+ # Train model
419
+ model.fit(X_train_processed, y_train)
420
+
421
+ # Evaluate model
422
+ y_pred = model.predict(X_test_processed)
423
+ f1 = f1_score(y_test, y_pred, average='weighted')
424
 
425
+ return f1
426
+
427
  except Exception as e:
428
+ print(f"Error in objective function: {str(e)}")
429
+ return 0.0
430
 
431
def tune_hyperparameters(csv_file, target_column, model_type, n_trials=10, feature_eng_enabled=True):
    """Run an Optuna hyperparameter search and report the results as HTML.

    Parameters
    ----------
    csv_file : str
        Filesystem path of the uploaded CSV dataset (Gradio ``filepath``).
    target_column : str
        Name of the label column; forwarded to ``objective``.
    model_type : str
        ``"random_forest"`` or ``"gradient_boosting"``; forwarded to ``objective``.
    n_trials : int, optional
        Number of Optuna trials to run (default 10).
    feature_eng_enabled : bool, optional
        Whether ``objective`` applies feature engineering (default True).

    Returns
    -------
    tuple[str, list[str]]
        (HTML report, list of figure paths for the gallery). The error path
        also returns a 2-tuple so both Gradio outputs always unpack.
    """
    if not csv_file:
        # BUGFIX: the click handler binds TWO outputs (HTML, gallery);
        # returning a bare string here broke the UI on missing input.
        return "Please upload a CSV file first", []

    # NOTE(review): assumes WANDB_API_KEY is present in the environment —
    # confirm deployment config; wandb.login(key=None) will prompt/fail.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-hyperparameter-tuning", config={
        "model_type": model_type,
        "feature_engineering": feature_eng_enabled,
        "n_trials": n_trials
    })

    # Maximize the weighted F1 score returned by `objective`.
    study = optuna.create_study(direction="maximize")
    study.optimize(
        lambda trial: objective(trial, csv_file, target_column, model_type, feature_eng_enabled),
        n_trials=n_trials
    )

    # Log best parameters to wandb
    wandb.log({"best_params": study.best_params, "best_value": study.best_value})

    # BUGFIX: make sure the output directory exists before plt.savefig,
    # which does not create intermediate directories.
    os.makedirs('./figures', exist_ok=True)

    # Visualization of optimization history
    plt.figure(figsize=(10, 6))
    optuna.visualization.matplotlib.plot_optimization_history(study)
    plt.tight_layout()
    history_path = './figures/optuna_history.png'
    plt.savefig(history_path)
    plt.close()

    # Visualization of parameter importances
    plt.figure(figsize=(10, 6))
    optuna.visualization.matplotlib.plot_param_importances(study)
    plt.tight_layout()
    importance_path = './figures/optuna_importance.png'
    plt.savefig(importance_path)
    plt.close()

    # Log visualizations
    wandb.log({"optimization_history": wandb.Image(history_path)})
    wandb.log({"parameter_importance": wandb.Image(importance_path)})

    run.finish()

    # Formatted HTML report for the Gradio HTML output component.
    result = f"""
    <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
        <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">⚙️ Hyperparameter Optimization Results</h1>

        <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
            <h2 style="color: #2B547E;">🏆 Best Parameters</h2>
            <pre style="margin: 10px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">{study.best_params}</pre>

            <h3 style="color: #4A708B; margin-top: 20px;">Best F1 Score</h3>
            <p style="font-size: 20px; font-weight: bold;">{study.best_value:.4f}</p>
        </div>

        <div style="margin-top: 30px;">
            <h2 style="color: #2B547E;">📈 Optimization Results</h2>
            <img src="/file={history_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
            <img src="/file={importance_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
        </div>
    </div>
    """

    # Return results and visualization paths for gallery
    return result, [history_path, importance_path]
497
+
498
# Top-level Gradio application: three tabs wired to the analysis,
# training, and hyperparameter-tuning handlers defined above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 📊 AI Data Analysis & ML Experimentation Platform")

    # --- Tab 1: exploratory data analysis ---------------------------------
    with gr.Tab("Data Analysis"):
        with gr.Row():
            with gr.Column():
                analysis_csv = gr.File(label="Upload CSV Dataset", type="filepath")
                analysis_notes = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
                run_analysis_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                analysis_report = gr.HTML("### Analysis results will appear here...")
                analysis_gallery = gr.Gallery(label="Data Visualizations", columns=2)

        run_analysis_btn.click(
            fn=analyze_data,
            inputs=[analysis_csv, analysis_notes],
            outputs=[analysis_report, analysis_gallery],
        )

    # --- Tab 2: model training and evaluation -----------------------------
    with gr.Tab("ML Model Experimentation"):
        with gr.Row():
            with gr.Column():
                training_csv = gr.File(label="Upload CSV Dataset", type="filepath")
                training_target = gr.Textbox(label="Target Column Name", placeholder="e.g., target, class, outcome")
                training_model_choice = gr.Radio(["random_forest", "gradient_boosting"], label="Model Type", value="random_forest")
                training_feature_eng = gr.Checkbox(label="Enable Feature Engineering", value=True)
                training_explainer = gr.Radio(["shap", "lime"], label="Explainability Tool", value="shap")
                run_training_btn = gr.Button("Train & Evaluate Model", variant="primary")
            with gr.Column():
                training_report = gr.HTML("### Model evaluation results will appear here...")
                training_raw_metrics = gr.Textbox(label="Raw Metrics", visible=False)
                training_gallery = gr.Gallery(label="Model Visualizations", columns=2)

        run_training_btn.click(
            fn=train_and_evaluate_model,
            inputs=[training_csv, training_target, training_model_choice, training_feature_eng, training_explainer],
            outputs=[training_report, training_raw_metrics, training_gallery],
        )

    # --- Tab 3: Optuna hyperparameter search ------------------------------
    with gr.Tab("Hyperparameter Tuning"):
        with gr.Row():
            with gr.Column():
                tuning_csv = gr.File(label="Upload CSV Dataset", type="filepath")
                tuning_target = gr.Textbox(label="Target Column Name", placeholder="e.g., target, class, outcome")
                tuning_model_choice = gr.Radio(["random_forest", "gradient_boosting"], label="Model Type", value="random_forest")
                tuning_feature_eng = gr.Checkbox(label="Enable Feature Engineering", value=True)
                tuning_trials = gr.Slider(minimum=5, maximum=50, value=10, step=5, label="Number of Optimization Trials")
                run_tuning_btn = gr.Button("Run Hyperparameter Optimization", variant="primary")
            with gr.Column():
                tuning_report = gr.HTML("### Hyperparameter tuning results will appear here...")
                tuning_gallery = gr.Gallery(label="Optimization Visualizations", columns=2)

        # Input order matches tune_hyperparameters(csv, target, model, n_trials, feature_eng).
        run_tuning_btn.click(
            fn=tune_hyperparameters,
            inputs=[tuning_csv, tuning_target, tuning_model_choice, tuning_trials, tuning_feature_eng],
            outputs=[tuning_report, tuning_gallery],
        )

demo.launch(debug=True)