pavanmutha committed on
Commit
7b22a65
·
verified ·
1 Parent(s): 59412e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -123
app.py CHANGED
@@ -17,6 +17,8 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_sc
17
  from sklearn.model_selection import train_test_split
18
  from sklearn.ensemble import RandomForestClassifier
19
  from sklearn.preprocessing import StandardScaler, PolynomialFeatures
 
 
20
 
21
  # Authenticate Hugging Face
22
  hf_token = os.getenv("HF_TOKEN")
@@ -25,10 +27,49 @@ login(token=hf_token, add_to_git_credential=True)
25
  # Initialize Model
26
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None):
29
  try:
30
- analysis_dict = raw_output if isinstance(raw_output, dict) else ast.literal_eval(str(raw_output))
31
-
 
 
 
 
 
 
 
 
 
 
32
  metrics_section = ""
33
  if metrics:
34
  metrics_section = f"""
@@ -37,24 +78,25 @@ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plo
37
  <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
38
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
39
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
40
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['accuracy']:.2f}</p>
41
  </div>
42
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
43
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
44
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['precision']:.2f}</p>
45
  </div>
46
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
47
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
48
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['recall']:.2f}</p>
49
  </div>
50
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
51
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
52
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['f1']:.2f}</p>
53
  </div>
54
  </div>
55
  </div>
56
  """
57
 
 
58
  explainability_section = ""
59
  if explainability_plots:
60
  explainability_section = f"""
@@ -66,29 +108,61 @@ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plo
66
  </div>
67
  """
68
 
69
- report = f"""
70
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
71
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
72
- {metrics_section}
73
  <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
74
  <h2 style="color: #2B547E;">🔍 Key Observations</h2>
75
- {format_observations(analysis_dict.get('observations', {}))}
76
  </div>
 
 
 
 
 
 
77
  <div style="margin-top: 30px;">
78
  <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
79
  {format_insights(analysis_dict.get('insights', {}), visuals)}
80
  </div>
 
 
 
 
 
 
 
81
  {explainability_section}
 
 
82
  </div>
83
  """
 
84
  return report, visuals
85
- except:
86
- return raw_output, visuals
 
 
 
 
 
 
 
 
 
 
87
 
88
  def preprocess_data(df, feature_engineering=True):
89
  """Handle missing values, categorical encoding, and feature engineering"""
90
- # Basic preprocessing
91
- df = df.dropna()
 
 
 
 
 
 
92
 
93
  # Convert categorical variables if any
94
  categorical_cols = df.select_dtypes(include=['object']).columns
@@ -96,17 +170,17 @@ def preprocess_data(df, feature_engineering=True):
96
  if len(df[col].unique()) <= 10: # One-hot encode if few categories
97
  df = pd.concat([df, pd.get_dummies(df[col], prefix=col)], axis=1)
98
  df = df.drop(col, axis=1)
 
 
99
 
100
  # Feature engineering
101
- if feature_engineering:
102
  # Create polynomial features for numerical columns
103
- num_cols = df.select_dtypes(include=['int64', 'float64']).columns
104
- if len(num_cols) > 0:
105
- poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
106
- poly_features = poly.fit_transform(df[num_cols])
107
- poly_cols = [f"poly_{i}" for i in range(poly_features.shape[1])]
108
- poly_df = pd.DataFrame(poly_features, columns=poly_cols)
109
- df = pd.concat([df, poly_df], axis=1)
110
 
111
  return df
112
 
@@ -134,30 +208,36 @@ def generate_explainability_plots(X, model, feature_names, output_dir='./figures
134
  os.makedirs(output_dir, exist_ok=True)
135
  plot_paths = []
136
 
137
- # SHAP Analysis
138
- explainer = shap.Explainer(model)
139
- shap_values = explainer(X)
140
-
141
- plt = shap.summary_plot(shap_values, X, feature_names=feature_names, show=False)
142
- shap_path = os.path.join(output_dir, 'shap_summary.png')
143
- plt.savefig(shap_path, bbox_inches='tight')
144
- plt.close()
145
- plot_paths.append(shap_path)
146
-
147
- # LIME Analysis
148
- explainer = lime.lime_tabular.LimeTabularExplainer(
149
- X,
150
- feature_names=feature_names,
151
- class_names=['class_0', 'class_1'], # Update based on your classes
152
- verbose=True,
153
- mode='classification'
154
- )
155
-
156
- # Explain a random instance
157
- exp = explainer.explain_instance(X[0], model.predict_proba, num_features=5)
158
- lime_path = os.path.join(output_dir, 'lime_explanation.png')
159
- exp.as_pyplot_figure().savefig(lime_path, bbox_inches='tight')
160
- plot_paths.append(lime_path)
 
 
 
 
 
 
161
 
162
  return plot_paths
163
 
@@ -178,40 +258,52 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
178
  "perform_ml": perform_ml
179
  })
180
 
181
- # Load and preprocess data
182
- df = pd.read_csv(csv_file)
183
- processed_df = preprocess_data(df)
184
-
185
  metrics = None
186
  explainability_plots = None
187
 
188
- if perform_ml and len(processed_df.columns) > 1:
189
- try:
190
- # Assume last column is target for demonstration
191
- X = processed_df.iloc[:, :-1].values
192
- y = processed_df.iloc[:, -1].values
193
-
194
- # Evaluate baseline model
195
- baseline_model = RandomForestClassifier(random_state=42)
196
- metrics = evaluate_model(X, y, baseline_model)
197
-
198
- # Generate explainability plots
199
- feature_names = processed_df.columns[:-1]
200
- explainability_plots = generate_explainability_plots(X[:100], baseline_model, feature_names)
201
-
202
- wandb.log(metrics)
203
- except Exception as e:
204
- print(f"ML analysis failed: {str(e)}")
205
-
206
- agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
207
- analysis_result = agent.run("""
208
- You are an expert data analyst. Perform comprehensive analysis including:
209
- 1. Basic statistics and data quality checks
210
- 2. 3 insightful analytical questions about relationships in the data
211
- 3. Visualization of key patterns and correlations
212
- 4. Actionable real-world insights derived from findings
213
- Generate publication-quality visualizations and save to './figures/'
214
- """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  execution_time = time.time() - start_time
217
  final_memory = process.memory_info().rss / 1024 ** 2
@@ -225,52 +317,77 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
225
  run.finish()
226
  return format_analysis_report(analysis_result, visuals, metrics, explainability_plots)
227
 
228
- def objective(trial):
229
- # Define hyperparameter space
230
- params = {
231
- 'n_estimators': trial.suggest_int('n_estimators', 50, 500),
232
- 'max_depth': trial.suggest_int('max_depth', 3, 15),
233
- 'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
234
- 'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
235
- 'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
236
- 'bootstrap': trial.suggest_categorical('bootstrap', [True, False])
237
- }
238
-
239
- # Load data (you would need to pass this or make it available)
240
- df = pd.read_csv("temp_data.csv") # You'll need to handle this properly
241
- processed_df = preprocess_data(df)
242
- X = processed_df.iloc[:, :-1].values
243
- y = processed_df.iloc[:, -1].values
244
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
245
-
246
- # Standardize features
247
- scaler = StandardScaler()
248
- X_train = scaler.fit_transform(X_train)
249
- X_test = scaler.transform(X_test)
250
-
251
- # Create and evaluate model
252
- model = RandomForestClassifier(**params, random_state=42)
253
- model.fit(X_train, y_train)
254
- y_pred = model.predict(X_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- # Return metric to optimize (F1 score in this case)
257
- return f1_score(y_test, y_pred, average='weighted')
 
258
 
259
  def tune_hyperparameters(n_trials: int, csv_file):
260
  try:
 
 
 
261
  # Save the uploaded file temporarily for Optuna
262
- if csv_file:
263
- temp_path = "temp_data.csv"
264
- with open(temp_path, "wb") as f:
265
- f.write(csv_file.read())
266
-
267
- study = optuna.create_study(direction="maximize")
268
- study.optimize(objective, n_trials=n_trials)
269
-
270
  os.remove(temp_path)
271
- return f"Best Hyperparameters: {study.best_params}\nBest F1 Score: {study.best_value:.4f}"
272
- else:
273
- return "Please upload a CSV file first for hyperparameter tuning."
 
 
 
 
 
 
 
 
274
  except Exception as e:
275
  return f"Hyperparameter tuning failed: {str(e)}"
276
 
@@ -282,11 +399,15 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
282
  notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
283
  perform_ml = gr.Checkbox(label="Perform Machine Learning Analysis", value=True)
284
  analyze_btn = gr.Button("Analyze", variant="primary")
285
- optuna_trials = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
286
- tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
 
287
  with gr.Column():
288
- analysis_output = gr.Markdown("### Analysis results will appear here...")
289
- optuna_output = gr.Textbox(label="Best Hyperparameters")
 
 
 
290
  gallery = gr.Gallery(label="Data Visualizations", columns=2)
291
 
292
  analyze_btn.click(
@@ -300,4 +421,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
300
  outputs=[optuna_output]
301
  )
302
 
303
- demo.launch(debug=True)
 
 
17
  from sklearn.model_selection import train_test_split
18
  from sklearn.ensemble import RandomForestClassifier
19
  from sklearn.preprocessing import StandardScaler, PolynomialFeatures
20
+ from sklearn.impute import SimpleImputer
21
+ import matplotlib.pyplot as plt
22
 
23
  # Authenticate Hugging Face
24
  hf_token = os.getenv("HF_TOKEN")
 
27
  # Initialize Model
28
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
29
 
30
+ def format_observations(observations):
31
+ if not isinstance(observations, dict):
32
+ return f"<pre>{str(observations)}</pre>"
33
+
34
+ return '\n'.join([
35
+ f"""
36
+ <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
37
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
38
+ <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
39
+ </div>
40
+ """ for key, value in observations.items()
41
+ ])
42
+
43
+ def format_insights(insights, visuals):
44
+ if not isinstance(insights, dict):
45
+ return f"<pre>{str(insights)}</pre>"
46
+
47
+ return '\n'.join([
48
+ f"""
49
+ <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
50
+ <div style="display: flex; align-items: center; gap: 10px;">
51
+ <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
52
+ <p style="margin: 0; font-size: 16px;">{insight}</p>
53
+ </div>
54
+ {f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if idx < len(visuals) else ''}
55
+ </div>
56
+ """ for idx, (key, insight) in enumerate(insights.items())
57
+ ])
58
+
59
  def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None):
60
  try:
61
+ # Ensure we have a dictionary to work with
62
+ if isinstance(raw_output, str):
63
+ try:
64
+ analysis_dict = ast.literal_eval(raw_output)
65
+ except:
66
+ analysis_dict = {'observations': {'raw_output': raw_output}, 'insights': {}}
67
+ elif isinstance(raw_output, dict):
68
+ analysis_dict = raw_output
69
+ else:
70
+ analysis_dict = {'observations': {'raw_output': str(raw_output)}, 'insights': {}}
71
+
72
+ # Metrics section
73
  metrics_section = ""
74
  if metrics:
75
  metrics_section = f"""
 
78
  <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
79
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
80
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
81
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('accuracy', 0):.2f}</p>
82
  </div>
83
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
84
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
85
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('precision', 0):.2f}</p>
86
  </div>
87
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
88
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
89
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('recall', 0):.2f}</p>
90
  </div>
91
  <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
92
  <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
93
+ <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics.get('f1', 0):.2f}</p>
94
  </div>
95
  </div>
96
  </div>
97
  """
98
 
99
+ # Explainability section
100
  explainability_section = ""
101
  if explainability_plots:
102
  explainability_section = f"""
 
108
  </div>
109
  """
110
 
111
+ # Observations section
112
+ observations_section = ""
113
+ if 'observations' in analysis_dict:
114
+ observations_section = f"""
115
  <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
116
  <h2 style="color: #2B547E;">🔍 Key Observations</h2>
117
+ {format_observations(analysis_dict['observations'])}
118
  </div>
119
+ """
120
+
121
+ # Insights section
122
+ insights_section = ""
123
+ if 'insights' in analysis_dict:
124
+ insights_section = f"""
125
  <div style="margin-top: 30px;">
126
  <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
127
  {format_insights(analysis_dict.get('insights', {}), visuals)}
128
  </div>
129
+ """
130
+
131
+ # Build the complete report
132
+ report = f"""
133
+ <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
134
+ <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
135
+ {metrics_section}
136
  {explainability_section}
137
+ {observations_section}
138
+ {insights_section}
139
  </div>
140
  """
141
+
142
  return report, visuals
143
+
144
+ except Exception as e:
145
+ error_report = f"""
146
+ <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
147
+ <h1 style="color: #B22222;">⚠️ Error Generating Report</h1>
148
+ <p>An error occurred while generating the report:</p>
149
+ <pre style="background: #f8f9fa; padding: 10px; border-radius: 4px;">{str(e)}</pre>
150
+ <p>Raw output:</p>
151
+ <pre style="background: #f8f9fa; padding: 10px; border-radius: 4px;">{str(raw_output)}</pre>
152
+ </div>
153
+ """
154
+ return error_report, visuals
155
 
156
  def preprocess_data(df, feature_engineering=True):
157
  """Handle missing values, categorical encoding, and feature engineering"""
158
+ # Make a copy to avoid modifying the original
159
+ df = df.copy()
160
+
161
+ # Basic preprocessing - handle missing values
162
+ numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
163
+ if len(numeric_cols) > 0:
164
+ imputer = SimpleImputer(strategy='median')
165
+ df[numeric_cols] = imputer.fit_transform(df[numeric_cols])
166
 
167
  # Convert categorical variables if any
168
  categorical_cols = df.select_dtypes(include=['object']).columns
 
170
  if len(df[col].unique()) <= 10: # One-hot encode if few categories
171
  df = pd.concat([df, pd.get_dummies(df[col], prefix=col)], axis=1)
172
  df = df.drop(col, axis=1)
173
+ else: # Otherwise just drop (or could use target encoding)
174
+ df = df.drop(col, axis=1)
175
 
176
  # Feature engineering
177
+ if feature_engineering and len(numeric_cols) > 0:
178
  # Create polynomial features for numerical columns
179
+ poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
180
+ poly_features = poly.fit_transform(df[numeric_cols])
181
+ poly_cols = [f"poly_{i}" for i in range(poly_features.shape[1])]
182
+ poly_df = pd.DataFrame(poly_features, columns=poly_cols)
183
+ df = pd.concat([df, poly_df], axis=1)
 
 
184
 
185
  return df
186
 
 
208
  os.makedirs(output_dir, exist_ok=True)
209
  plot_paths = []
210
 
211
+ try:
212
+ # SHAP Analysis
213
+ explainer = shap.Explainer(model)
214
+ shap_values = explainer(X[:100]) # Use first 100 samples for speed
215
+
216
+ plt.figure()
217
+ shap.summary_plot(shap_values, X[:100], feature_names=feature_names, show=False)
218
+ shap_path = os.path.join(output_dir, 'shap_summary.png')
219
+ plt.savefig(shap_path, bbox_inches='tight')
220
+ plt.close()
221
+ plot_paths.append(shap_path)
222
+
223
+ # LIME Analysis
224
+ explainer = lime.lime_tabular.LimeTabularExplainer(
225
+ X,
226
+ feature_names=feature_names,
227
+ class_names=[str(x) for x in np.unique(model.classes_)],
228
+ verbose=False,
229
+ mode='classification'
230
+ )
231
+
232
+ # Explain a random instance
233
+ exp = explainer.explain_instance(X[0], model.predict_proba, num_features=5)
234
+ lime_path = os.path.join(output_dir, 'lime_explanation.png')
235
+ exp.as_pyplot_figure().savefig(lime_path, bbox_inches='tight')
236
+ plt.close()
237
+ plot_paths.append(lime_path)
238
+
239
+ except Exception as e:
240
+ print(f"Explainability failed: {str(e)}")
241
 
242
  return plot_paths
243
 
 
258
  "perform_ml": perform_ml
259
  })
260
 
 
 
 
 
261
  metrics = None
262
  explainability_plots = None
263
 
264
+ try:
265
+ # Load and preprocess data
266
+ df = pd.read_csv(csv_file)
267
+
268
+ if perform_ml and len(df.columns) > 1:
269
+ try:
270
+ processed_df = preprocess_data(df)
271
+
272
+ # Assume last column is target for demonstration
273
+ if len(processed_df.columns) > 1: # Ensure we still have features after preprocessing
274
+ X = processed_df.iloc[:, :-1].values
275
+ y = processed_df.iloc[:, -1].values
276
+
277
+ # Convert y to numeric if needed
278
+ if y.dtype == object:
279
+ y = pd.factorize(y)[0]
280
+
281
+ # Evaluate baseline model
282
+ baseline_model = RandomForestClassifier(random_state=42, n_estimators=100)
283
+ metrics = evaluate_model(X, y, baseline_model)
284
+
285
+ # Generate explainability plots
286
+ feature_names = processed_df.columns[:-1]
287
+ explainability_plots = generate_explainability_plots(X, baseline_model, feature_names)
288
+
289
+ wandb.log(metrics)
290
+ except Exception as e:
291
+ print(f"ML analysis failed: {str(e)}")
292
+ wandb.log({"ml_error": str(e)})
293
+
294
+ # Run the main analysis
295
+ agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
296
+ analysis_result = agent.run("""
297
+ You are an expert data analyst. Perform comprehensive analysis including:
298
+ 1. Basic statistics and data quality checks
299
+ 2. 3 insightful analytical questions about relationships in the data
300
+ 3. Visualization of key patterns and correlations
301
+ 4. Actionable real-world insights derived from findings
302
+ Generate publication-quality visualizations and save to './figures/'
303
+ """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
304
+
305
+ except Exception as e:
306
+ analysis_result = f"Analysis failed: {str(e)}"
307
 
308
  execution_time = time.time() - start_time
309
  final_memory = process.memory_info().rss / 1024 ** 2
 
317
  run.finish()
318
  return format_analysis_report(analysis_result, visuals, metrics, explainability_plots)
319
 
320
+ def objective(trial, csv_path):
321
+ try:
322
+ # Load and preprocess data
323
+ df = pd.read_csv(csv_path)
324
+ processed_df = preprocess_data(df)
325
+
326
+ if len(processed_df.columns) <= 1:
327
+ return 0.0 # No features to work with
328
+
329
+ X = processed_df.iloc[:, :-1].values
330
+ y = processed_df.iloc[:, -1].values
331
+
332
+ # Convert y to numeric if needed
333
+ if y.dtype == object:
334
+ y = pd.factorize(y)[0]
335
+
336
+ # Define hyperparameter space
337
+ params = {
338
+ 'n_estimators': trial.suggest_int('n_estimators', 50, 500),
339
+ 'max_depth': trial.suggest_int('max_depth', 3, 15),
340
+ 'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
341
+ 'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
342
+ 'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2']),
343
+ 'bootstrap': trial.suggest_categorical('bootstrap', [True, False])
344
+ }
345
+
346
+ # Split data
347
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
348
+
349
+ # Standardize features
350
+ scaler = StandardScaler()
351
+ X_train = scaler.fit_transform(X_train)
352
+ X_test = scaler.transform(X_test)
353
+
354
+ # Create and evaluate model
355
+ model = RandomForestClassifier(**params, random_state=42)
356
+ model.fit(X_train, y_train)
357
+ y_pred = model.predict(X_test)
358
+
359
+ # Return metric to optimize (F1 score in this case)
360
+ return f1_score(y_test, y_pred, average='weighted')
361
 
362
+ except Exception as e:
363
+ print(f"Trial failed: {str(e)}")
364
+ return 0.0
365
 
366
  def tune_hyperparameters(n_trials: int, csv_file):
367
  try:
368
+ if not csv_file:
369
+ return "Please upload a CSV file first for hyperparameter tuning."
370
+
371
  # Save the uploaded file temporarily for Optuna
372
+ temp_path = "temp_optuna_data.csv"
373
+ with open(temp_path, "wb") as f:
374
+ f.write(csv_file.read())
375
+
376
+ # Verify the data can be loaded
377
+ df = pd.read_csv(temp_path)
378
+ if len(df.columns) <= 1:
 
379
  os.remove(temp_path)
380
+ return "Dataset needs at least one feature and one target column."
381
+
382
+ # Create study and optimize
383
+ study = optuna.create_study(direction="maximize")
384
+ study.optimize(lambda trial: objective(trial, temp_path), n_trials=n_trials)
385
+
386
+ os.remove(temp_path)
387
+ return f"""
388
+ Best Hyperparameters: {study.best_params}
389
+ Best F1 Score: {study.best_value:.4f}
390
+ """
391
  except Exception as e:
392
  return f"Hyperparameter tuning failed: {str(e)}"
393
 
 
399
  notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
400
  perform_ml = gr.Checkbox(label="Perform Machine Learning Analysis", value=True)
401
  analyze_btn = gr.Button("Analyze", variant="primary")
402
+ with gr.Accordion("Hyperparameter Tuning", open=False):
403
+ optuna_trials = gr.Number(label="Number of Trials", value=10, precision=0)
404
+ tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
405
  with gr.Column():
406
+ analysis_output = gr.HTML("""<div style="font-family: Arial, sans-serif; padding: 20px;">
407
+ <h2 style="color: #2B547E;">Analysis results will appear here...</h2>
408
+ <p>Upload a CSV file and click "Analyze" to begin.</p>
409
+ </div>""")
410
+ optuna_output = gr.Textbox(label="Tuning Results", interactive=False)
411
  gallery = gr.Gallery(label="Data Visualizations", columns=2)
412
 
413
  analyze_btn.click(
 
421
  outputs=[optuna_output]
422
  )
423
 
424
+ if __name__ == "__main__":
425
+ demo.launch(debug=True)