szili2011 commited on
Commit
591f159
·
verified ·
1 Parent(s): 881d92c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -225
app.py CHANGED
@@ -37,6 +37,10 @@ import tempfile
37
  import json
38
  import math
39
  import collections.abc # For Gradio issue with new Python versions
 
 
 
 
40
 
41
 
42
  # --- Global Variables / Constants ---
@@ -93,13 +97,14 @@ class SimpleMLP(nn.Module):
93
 
94
  layers.append(nn.Linear(current_dim, output_dim))
95
 
96
- if task_type == "classification": # Changed from task_type.endswith("Classification")
97
- if output_dim == 1: layers.append(nn.Sigmoid()) # Binary
98
- elif output_dim > 1: layers.append(nn.Softmax(dim=-1)) # Multi-class
 
99
  self.network = nn.Sequential(*layers)
100
  def forward(self, x): return self.network(x)
101
 
102
- class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
103
  def __init__(self, input_channels, img_size_wh, num_classes=10, task_type="classification",
104
  c_out1=16, k1=3, s1=1, p1=1, pool1_k=2, pool1_s=2,
105
  c_out2=32, k2=3, s2=1, p2=1, pool2_k=2, pool2_s=2,
@@ -107,7 +112,7 @@ class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
107
  super(SimpleCNN, self).__init__()
108
  self.input_channels = input_channels
109
  self.img_h, self.img_w = img_size_wh
110
- self.num_classes = num_classes # This is the direct output dimension from the last linear layer
111
 
112
  self.conv1 = nn.Conv2d(self.input_channels, c_out1, kernel_size=k1, stride=s1, padding=p1)
113
  self.relu1 = nn.ReLU()
@@ -115,9 +120,8 @@ class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
115
 
116
  h_out_conv1 = (self.img_h - k1 + 2 * p1) // s1 + 1
117
  w_out_conv1 = (self.img_w - k1 + 2 * p1) // s1 + 1
118
- h_pool1 = (h_out_conv1 - pool1_k) // pool1_s + 1 if pool1_k > 0 else h_out_conv1 # handle no pooling
119
- w_pool1 = (w_out_conv1 - pool1_k) // pool1_s + 1 if pool1_k > 0 else w_out_conv1
120
-
121
 
122
  self.conv2 = nn.Conv2d(c_out1, c_out2, kernel_size=k2, stride=s2, padding=p2)
123
  self.relu2 = nn.ReLU()
@@ -125,24 +129,23 @@ class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
125
 
126
  h_out_conv2 = (h_pool1 - k2 + 2 * p2) // s2 + 1
127
  w_out_conv2 = (w_pool1 - k2 + 2 * p2) // s2 + 1
128
- h_pool2 = (h_out_conv2 - pool2_k) // pool2_s + 1 if pool2_k > 0 else h_out_conv2
129
- w_pool2 = (w_out_conv2 - pool2_k) // pool2_s + 1 if pool2_k > 0 else w_out_conv2
130
 
131
  self.flattened_size = c_out2 * h_pool2 * w_pool2
132
  if self.flattened_size <= 0:
133
- raise ValueError(f"Calculated flattened size is {self.flattened_size}. Check CNN params and image size. Current (h_pool2, w_pool2): ({h_pool2},{w_pool2}) from img ({self.img_h},{self.img_w})")
134
 
135
  self.fc1 = nn.Linear(self.flattened_size, fc_hidden)
136
  self.relu3 = nn.ReLU()
137
- self.fc2 = nn.Linear(fc_hidden, self.num_classes) # Output layer before final activation
138
 
 
139
  if task_type == "classification":
140
- if self.num_classes == 1: # Binary classification
141
  self.final_activation = nn.Sigmoid()
142
- elif self.num_classes > 1: # Multi-class classification
143
- self.final_activation = nn.Softmax(dim=1)
144
- else: # Should not happen for classification if num_classes is properly set
145
- self.final_activation = nn.Identity()
146
  else: # Regression
147
  self.final_activation = nn.Identity()
148
 
@@ -155,7 +158,7 @@ class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
155
  x = self.final_activation(x)
156
  return x
157
 
158
- # --- Parameter Target Helpers (PARAM_RANGES, suggest_mlp_layers_for_range, estimate_current_mlp_params, estimate_cnn_params) ---
159
  PARAM_RANGES = collections.OrderedDict([
160
  ("Tiny (<10k)", (0, 10000)),
161
  ("Small (10k-50k)", (10000, 50000)),
@@ -194,34 +197,32 @@ def suggest_mlp_layers_for_range(input_dim, output_dim, target_range_str, curren
194
  if not suggested_layers_str: suggested_layers_str = "64"; logs += "Defaulting to '64'.\n"
195
  return suggested_layers_str, logs
196
 
197
- def estimate_current_mlp_params(input_dim_str, hidden_layers_str, output_dim_str, task_type, current_logs=""): # Added task_type
198
  logs = current_logs
199
  try:
200
  input_dim = int(input_dim_str); output_dim = int(output_dim_str)
201
  if input_dim <= 0 or output_dim <= 0: return "Input/Output dims must be > 0", logs
202
 
203
- # Determine task_type for MLP constructor
204
  mlp_task_type = "classification" if task_type.endswith("Classification") else "regression"
205
  temp_mlp = SimpleMLP(input_dim, hidden_layers_str, output_dim, task_type=mlp_task_type)
206
  params = count_pytorch_parameters(temp_mlp); del temp_mlp
207
  return f"{params:,}", logs
208
  except Exception as e: logs += f"Error estimating MLP params: {e}\n"; return "Error", logs
209
 
210
- def estimate_cnn_params(img_h_str, img_w_str, num_classes_str, task_type, current_logs=""): # Added task_type
211
  logs = current_logs
212
  try:
213
  img_h, img_w, num_classes_parsed = int(img_h_str), int(img_w_str), int(num_classes_str)
214
  if not (img_h > 0 and img_w > 0 and num_classes_parsed > 0): return "Image dims/classes must be > 0", logs
215
 
216
- # Determine task_type for CNN constructor
217
- cnn_task_type = "classification" if task_type.endswith("Classification") else "regression" # Assuming CNN for image is classification
218
  temp_cnn = SimpleCNN(input_channels=1, img_size_wh=(img_h, img_w), num_classes=num_classes_parsed, task_type=cnn_task_type)
219
  params = count_pytorch_parameters(temp_cnn); del temp_cnn
220
  return f"{params:,}", logs
221
  except Exception as e: logs += f"Error estimating CNN params: {traceback.format_exc()}\n"; return "Error", logs
222
 
223
 
224
- # --- Dataset and Preprocessing (generate_dataset_backend, preprocess_tabular_data) ---
225
  def generate_dataset_backend(task_type, n_samples_str, n_features_str,
226
  n_classes_or_informative_str, dataset_format,
227
  ai_suggest_ds_shape, target_param_range_str, model_type_selection,
@@ -253,7 +254,7 @@ def generate_dataset_backend(task_type, n_samples_str, n_features_str,
253
  try:
254
  if task_type == "Tabular Classification":
255
  n_cls = max(2, n_classes_or_informative)
256
- n_inf = max(1, min(n_features, n_classes_or_informative if n_classes_or_informative >= n_cls else n_features // 2)) # make_classification expects n_informative <= n_features
257
  if n_inf > n_features: n_inf = n_features
258
  X_data, y_data = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_inf,
259
  n_redundant=max(0,n_features - n_inf)//2, n_classes=n_cls, flip_y=0.05, random_state=42)
@@ -280,15 +281,11 @@ def generate_dataset_backend(task_type, n_samples_str, n_features_str,
280
  elif dataset_format == ".parquet": df.to_parquet(file_path, index=False)
281
  else: logs += f"Unsupported format {dataset_format}. Defaulting to CSV.\n"; file_path=get_temp_filepath("generated_dataset","csv"); df.to_csv(file_path, index=False)
282
  logs += f"Dataset saved to {file_path}\n"
283
- # For consistency, data_obj returned should be what train functions expect
284
- # Sklearn train func can take df or path. PyTorch train func can take df, path, or (X,y) tuple.
285
- # Returning df here for generated data is fine.
286
  return df.head(), df, logs, file_path
287
  else:
288
  logs += "Dataset generated as numpy arrays. Not saving to file from this function directly.\n"
289
  return pd.DataFrame(X_data[:5] if X_data is not None else None), (X_data, y_data), logs, None
290
 
291
-
292
  except Exception as e: error_msg=f"Error generating dataset: {traceback.format_exc()}"; logs+=error_msg+"\n"; return None, error_msg, logs, None
293
 
294
  def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_type, current_logs=""):
@@ -313,27 +310,6 @@ def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_t
313
  ], remainder='passthrough')
314
 
315
  X_processed_np = preprocessor.fit_transform(X_df)
316
-
317
- feature_names_out_list = []
318
- try: feature_names_out_list = list(preprocessor.get_feature_names_out())
319
- except AttributeError:
320
- current_pos = 0
321
- if numerical_features: feature_names_out_list.extend(numerical_features); current_pos += len(numerical_features)
322
- if categorical_features:
323
- cat_encoder = preprocessor.named_transformers_['cat'].named_steps['onehot']
324
- if hasattr(cat_encoder, 'get_feature_names_out'):
325
- cat_feature_names = cat_encoder.get_feature_names_out(categorical_features)
326
- elif hasattr(cat_encoder, 'get_feature_names'):
327
- cat_feature_names = cat_encoder.get_feature_names(categorical_features)
328
- else: # Estimate number of one-hot features
329
- num_onehot_cols = X_processed_np.shape[1] - len(numerical_features) # Assuming only num and cat
330
- cat_feature_names = [f"cat_feat_{i}" for i in range(num_onehot_cols)]
331
- feature_names_out_list.extend(cat_feature_names); current_pos += len(cat_feature_names)
332
- # Handle remainder='passthrough' if necessary, X_processed_np.shape[1] would be total
333
- if X_processed_np.shape[1] > current_pos:
334
- feature_names_out_list.extend([f"other_feat_{i}" for i in range(X_processed_np.shape[1] - current_pos)])
335
-
336
-
337
  processed_input_dim = X_processed_np.shape[1]
338
  logs += f"Tabular data preprocessed. X shape: {X_processed_np.shape}, Processed input dim: {processed_input_dim}\n"
339
 
@@ -342,16 +318,16 @@ def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_t
342
  y_processed_np = le.fit_transform(y_series)
343
  num_classes = len(le.classes_)
344
  logs += f"Target encoded. Classes: {num_classes} ({le.classes_})\n"
345
- output_dim_nn = 1 if num_classes == 2 else num_classes # For NN output layer
346
  else: # Regression
347
  y_processed_np = y_series.astype(float).values
348
  num_classes = 1
349
  output_dim_nn = 1
350
 
351
- return X_processed_np, y_processed_np, preprocessor, logs, processed_input_dim, output_dim_nn, feature_names_out_list
352
 
353
 
354
- # --- Training Functions (train_model_sklearn, train_model_pytorch) ---
355
  def train_model_sklearn(data_input_obj, target_column, task_type, model_name, model_output_format, current_logs=""):
356
  logs = current_logs + f"\n--- Training Scikit-learn Model: {model_name} ---\n"
357
  model_path_out, metrics_out, model_params_out = None, "Training failed.", "N/A"
@@ -367,11 +343,11 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
367
  elif isinstance(data_input_obj, pd.DataFrame): df = data_input_obj
368
  else: logs += "Invalid data for training.\n"; return logs, "Error: Invalid data.", None, "N/A"
369
 
370
- if not target_column or target_column not in df.columns: # check if target_column is empty
371
  logs += f"Target column '{target_column}' not provided or not found.\n"; return logs, f"Error: Target '{target_column}' not found/provided.", None, "N/A"
372
 
373
  try:
374
- X_processed_np, y_processed_np, preprocessor, logs, _, _, feature_names_original = preprocess_tabular_data(df, None, target_column, task_type, logs)
375
  except ValueError as e: logs += f"Preprocessing error: {e}\n"; return logs, f"Error: {e}", None, "N/A"
376
 
377
  X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
@@ -414,21 +390,15 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
414
  model_path_out = get_temp_filepath(model_filename_base, "onnx")
415
  raw_X_for_types_df = df.drop(target_column, axis=1).infer_objects()
416
  onnx_initial_types = []
417
- for col_idx, col_name in enumerate(raw_X_for_types_df.columns):
418
  col_dtype = raw_X_for_types_df[col_name].dtype
419
- # Forcing float32 for numeric inputs to ONNX for broader compatibility
420
- # ONNX is stricter about types than scikit-learn sometimes.
421
  if pd.api.types.is_numeric_dtype(col_dtype):
422
- # Create a sample of the correct type for skl2onnx to infer shape and type
423
- # Shape [None, 1] implies one feature at a time for this column.
424
- # If a feature is multi-dimensional (e.g. embeddings), this needs adjustment.
425
- # For typical tabular, each column is one feature.
426
  onnx_initial_types.append((col_name, FloatTensorType([None, 1])))
427
  elif pd.api.types.is_string_dtype(col_dtype) or col_dtype == 'object':
428
  onnx_initial_types.append((col_name, StringTensorType([None, 1])))
429
  else:
430
- logs += f"Warning: Unsupported dtype {col_dtype} for column {col_name} in ONNX. Defaulting to FloatTensorType.\n"
431
- onnx_initial_types.append((col_name, FloatTensorType([None, 1]))) # Fallback
432
 
433
  if not onnx_initial_types: raise ValueError("ONNX initial types failed: No valid columns found.")
434
  try:
@@ -437,9 +407,9 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
437
  with open(model_path_out, "wb") as f: f.write(onnx_model.SerializeToString())
438
  logs += f"Model saved to {model_path_out} as ONNX.\n"
439
  sess = rt.InferenceSession(model_path_out, providers=rt.get_available_providers())
440
- logs += f"ONNX model loaded with ONNX Runtime. Inputs: {[inp.name for inp in sess.get_inputs()]}\n"
441
  except Exception as onnx_e: logs += f"ONNX Error: {traceback.format_exc()}\n"; model_path_out=None; metrics_out+="\nONNX EXPORT FAILED."
442
- else: # Fallback to PKL
443
  logs += f"Unsupported format '{model_output_format}'. Saving as .pkl\n"
444
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
445
  joblib.dump(full_pipeline_for_saving, model_path_out)
@@ -462,8 +432,7 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
462
  else: logs += f"Unsupported file: {data_input_obj}\n"; return logs, "Error", None, "N/A", None
463
  except Exception as e: logs += f"Error reading {data_input_obj}: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
464
  elif isinstance(data_input_obj, pd.DataFrame): df_for_pytorch = data_input_obj
465
- elif isinstance(data_input_obj, tuple) and len(data_input_obj) == 2 and isinstance(data_input_obj[0], np.ndarray) and isinstance(data_input_obj[1], np.ndarray):
466
- X_numpy_for_pytorch, y_numpy_for_pytorch = data_input_obj
467
  else: logs += "Invalid data for PyTorch training.\n"; return logs, "Error", None, "N/A", None
468
 
469
  try:
@@ -475,36 +444,21 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
475
  X_processed_np, y_processed_np = None, None
476
 
477
  if model_type_pt == "Simple Neural Network (MLP)":
478
- if not task_type.startswith("Tabular"):
479
- logs += "MLP requires Tabular task.\n"; return logs, "MLP Task Error", None, "N/A", None
480
- if not target_column and df_for_pytorch is not None: # Check if target column is provided for DataFrame
481
- logs += "Target column needed for MLP with DataFrame input.\n"; return logs, "MLP Target Error", None, "N/A", None
482
-
483
  try:
484
  data_arg1 = df_for_pytorch if df_for_pytorch is not None else X_numpy_for_pytorch
485
  data_arg2 = y_numpy_for_pytorch if df_for_pytorch is None else None
486
- # target_column is only relevant if data_arg1 is a DataFrame
487
- current_target_col = target_column if df_for_pytorch is not None else "target" # Placeholder if from numpy
488
-
489
- X_processed_np, y_processed_np, preprocessor_pipeline, logs, processed_input_dim_actual, nn_output_dim_actual, _ = \
490
  preprocess_tabular_data(data_arg1, data_arg2, current_target_col, task_type, logs)
491
  except ValueError as e: logs+=f"MLP Preprocessing error: {e}\n"; return logs,f"Error: {e}",None,"N/A",None
492
  elif model_type_pt == "Simple Convolutional Network (CNN)":
493
- if task_type != "Basic Image Classification": logs += "Warning: CNN selected, but task is not Basic Image Classification.\n"
494
-
495
- X_raw, y_raw = None, None
496
- if df_for_pytorch is not None:
497
- if not target_column or target_column not in df_for_pytorch.columns:
498
- logs += f"Target '{target_column}' not found/provided for CNN.\n"; return logs, "CNN Target Error", None, "N/A", None
499
- X_raw = df_for_pytorch.drop(target_column, axis=1).values
500
- y_raw = df_for_pytorch[target_column].values
501
- elif X_numpy_for_pytorch is not None and y_numpy_for_pytorch is not None:
502
- X_raw = X_numpy_for_pytorch; y_raw = y_numpy_for_pytorch
503
- else: logs += "No valid data for CNN.\n"; return logs, "CNN Data Error", None, "N/A", None
504
-
505
  le = LabelEncoder(); y_processed_np = le.fit_transform(y_raw)
506
- nn_output_dim_actual = len(le.classes_)
507
- if nn_output_dim_actual == 2: nn_output_dim_actual = 1
508
 
509
  pixels_per_sample = X_raw.shape[1]; img_h, img_w, input_channels = 28,28,1
510
  img_dim_approx = int(math.sqrt(pixels_per_sample))
@@ -513,86 +467,75 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
513
 
514
  X_processed_np = X_raw.reshape(-1, input_channels, img_h, img_w).astype(np.float32) / 255.0
515
  processed_input_dim_actual = (input_channels, img_h, img_w)
516
- logs += f"CNN Data: X reshaped to {X_processed_np.shape}, y: {y_processed_np.shape}, NN Output Dim: {nn_output_dim_actual}\n"
517
  else: logs += f"Unknown PyTorch model: {model_type_pt}\n"; return logs, "Unknown PyTorch model", None, "N/A", None
518
 
519
- y_dtype = torch.float32 if (nn_output_dim_actual == 1 and task_type.endswith("Regression")) or \
520
- (nn_output_dim_actual == 1 and task_type.endswith("Classification")) \
521
- else torch.long
522
- X_tensor = torch.tensor(X_processed_np, dtype=torch.float32)
523
- y_tensor = torch.tensor(y_processed_np, dtype=y_dtype)
524
- if nn_output_dim_actual == 1 and task_type.endswith("Classification"): y_tensor = y_tensor.unsqueeze(1)
525
- if task_type.endswith("Regression"): y_tensor = y_tensor.unsqueeze(1)
 
526
 
527
- dataset = TensorDataset(X_tensor, y_tensor)
528
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
 
 
 
 
529
 
530
  pytorch_model = None
531
  try:
 
532
  if model_type_pt == "Simple Neural Network (MLP)":
533
  pytorch_model = SimpleMLP(input_dim=processed_input_dim_actual, hidden_layers_str=mlp_hidden_layers_str,
534
  output_dim=nn_output_dim_actual, activation_fn_str=mlp_activation,
535
- task_type="classification" if task_type.endswith("Classification") else "regression")
536
  elif model_type_pt == "Simple Convolutional Network (CNN)":
537
  channels, h, w = processed_input_dim_actual
538
  pytorch_model = SimpleCNN(input_channels=channels, img_size_wh=(h,w), num_classes=nn_output_dim_actual,
539
- task_type="classification" if task_type.endswith("Classification") else "regression") # Pass task_type
540
  except Exception as model_e: logs += f"Error creating PyTorch model: {traceback.format_exc()}\n"; return logs, f"Model Creation Error: {model_e}", None, "N/A", None
541
 
542
  if pytorch_model is None: logs += "Failed to instantiate PyTorch model.\n"; return logs, "Model instantiate fail", None, "N/A", None
543
  model_params_val = count_pytorch_parameters(pytorch_model); model_params_out = f"{model_params_val:,}"
544
  logs += f"PyTorch Model: {model_params_out} params.\n"
545
- if model_params_val > 500000: logs += "Warning: >500k params on CPU will be SLOW.\n"
546
 
547
- is_classification_task = task_type.endswith("Classification") # Simplified condition
548
- if is_classification_task:
549
- criterion = nn.BCELoss() if nn_output_dim_actual == 1 else nn.CrossEntropyLoss()
550
- else: criterion = nn.MSELoss()
551
  optimizer = optim.Adam(pytorch_model.parameters(), lr=lr)
552
 
553
  logs += f"Starting PyTorch training for {epochs} epochs...\n"; start_time = time.time()
554
  epoch_losses = []
555
  pytorch_model.train()
556
  for epoch in range(epochs):
557
- epoch_loss_sum = 0.0; num_batches = 0
558
- for i, (batch_X, batch_y) in enumerate(dataloader): # Added enumerate for batch index
559
  optimizer.zero_grad(); outputs = pytorch_model(batch_X)
560
  loss = criterion(outputs, batch_y); loss.backward(); optimizer.step()
561
- epoch_loss_sum += loss.item(); num_batches += 1
562
- avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
563
  epoch_losses.append(avg_epoch_loss)
564
- logs += f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}\n"
 
565
 
566
  logs += f"PyTorch training completed in {time.time() - start_time:.2f}s.\n"
567
  pytorch_model.eval()
568
  with torch.no_grad():
569
- if is_classification_task and dataloader.dataset and len(dataloader.dataset)>0 :
570
- try:
571
- all_preds, all_targets = [], []
572
- for batch_X, batch_y in dataloader: # Evaluate on whole dataset (or a test split ideally)
573
- outputs = pytorch_model(batch_X)
574
- if nn_output_dim_actual == 1: predicted = (outputs > 0.5).float()
575
- else: _, predicted = torch.max(outputs.data, 1)
576
- all_preds.extend(predicted.cpu().numpy())
577
- all_targets.extend(batch_y.cpu().numpy())
578
-
579
- if all_targets and all_preds: # Check if lists are not empty
580
- # Ensure all_targets is 1D for accuracy_score if predicted is also 1D (binary case)
581
- all_targets_np = np.array(all_targets).squeeze()
582
- all_preds_np = np.array(all_preds).squeeze()
583
- acc = accuracy_score(all_targets_np, all_preds_np)
584
- metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\nAccuracy on training data: {acc*100:.2f}%"
585
- else:
586
- metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n (Could not compute accuracy)"
587
-
588
- except Exception as eval_e: metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n Eval Error: {eval_e}"
589
- else: metrics_out = f"Final Training Loss (MSE): {avg_epoch_loss:.4f}"
590
  logs += "\n--- PyTorch Metrics ---\n" + metrics_out + "\n"
591
 
592
  if epoch_losses:
593
- import matplotlib # Use Agg backend for non-interactive environments like Spaces
594
- matplotlib.use('Agg')
595
- import matplotlib.pyplot as plt
596
  fig, ax = plt.subplots(); ax.plot(range(1, epochs + 1), epoch_losses, marker='o')
597
  ax.set_xlabel("Epoch"); ax.set_ylabel("Average Loss"); ax.set_title("Training Loss Curve")
598
  plot_out = fig; logs += "Loss curve generated.\n"
@@ -602,23 +545,63 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
602
  model_path_out = get_temp_filepath(model_filename_base, "pt")
603
  save_obj = {'model_state_dict': pytorch_model.state_dict(), 'output_dim_nn': nn_output_dim_actual, 'task_type': task_type}
604
  if model_type_pt == "Simple Neural Network (MLP)" and preprocessor_pipeline:
605
- save_obj.update({
606
- 'preprocessor': preprocessor_pipeline, 'input_dim_processed': processed_input_dim_actual,
607
- 'hidden_layers_str': mlp_hidden_layers_str, 'activation_fn': mlp_activation,
608
- })
609
- logs += f"PyTorch MLP (model + preprocessor) saved to {model_path_out}\n"
610
  elif model_type_pt == "Simple Convolutional Network (CNN)":
611
- c,h,w = processed_input_dim_actual
612
- save_obj.update({'input_channels':c, 'img_h':h, 'img_w':w}) # Save CNN architecture details
613
- logs += f"PyTorch CNN (model state_dict + arch_details) saved to {model_path_out}\n"
614
- else: logs += f"PyTorch {model_type_pt} (model state_dict) saved to {model_path_out}\n"
615
  torch.save(save_obj, model_path_out)
 
616
  else: # Fallback
617
- logs += f"Unsupported format '{model_output_format}'. Saving as .pt\n"
618
- model_path_out = get_temp_filepath(model_filename_base, "pt")
619
- torch.save(pytorch_model.state_dict(), model_path_out)
620
  return logs, metrics_out, model_path_out, model_params_out, plot_out
621
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
  # --- Gradio UI Definition ---
623
  TASK_CHOICES = ["Tabular Classification", "Tabular Regression", "Basic Image Classification"]
624
  MODEL_FAMILIES = ["Scikit-learn (Classical ML)", "PyTorch (Neural Networks)"]
@@ -631,20 +614,19 @@ MODEL_OUTPUT_FORMATS_PYTORCH = [".pt (PyTorch)"]
631
  MLP_ACTIVATIONS = ["relu", "tanh", "sigmoid"]
632
  CLONE_GUIDE_TEXT = """
633
  ## How to Clone & Upgrade This Space for More Power:
634
- (Instructions as provided in previous response - omitted here for brevity but should be included)
 
 
635
  """
636
 
637
  # Determine initial choices for model_specific_dd based on default task_type and model_family
638
  _initial_task_default = TASK_CHOICES[0]
639
  _initial_family_default = MODEL_FAMILIES[0]
640
- initial_model_choices_for_specific_dd = []
641
  if _initial_family_default == "Scikit-learn (Classical ML)":
642
- if _initial_task_default == "Tabular Classification": initial_model_choices_for_specific_dd = SKLEARN_MODELS_CLASSIFICATION
643
- elif _initial_task_default == "Tabular Regression": initial_model_choices_for_specific_dd = SKLEARN_MODELS_REGRESSION
644
  elif _initial_family_default == "PyTorch (Neural Networks)":
645
- if _initial_task_default.startswith("Tabular"): initial_model_choices_for_specific_dd = [PYTORCH_MODELS[0]]
646
- elif _initial_task_default == "Basic Image Classification": initial_model_choices_for_specific_dd = [PYTORCH_MODELS[1]]
647
- initial_model_value_for_specific_dd = initial_model_choices_for_specific_dd[0] if initial_model_choices_for_specific_dd else None
648
 
649
  def update_model_options(task_choice, model_family_choice):
650
  choices, value = [], None
@@ -657,25 +639,18 @@ def update_model_options(task_choice, model_family_choice):
657
  value = choices[0] if choices else None
658
  return gr.update(choices=choices, value=value, visible=bool(choices))
659
 
660
- def update_param_range_visibility(model_family_choice):
661
- return gr.update(visible=(model_family_choice == "PyTorch (Neural Networks)"))
662
-
663
- def update_pytorch_specific_options_visibility(model_choice_pytorch_family, specific_pytorch_model):
664
- # Only proceed if family is PyTorch
665
- if model_choice_pytorch_family != "PyTorch (Neural Networks)":
666
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # Hide all: param_range, mlp_group, cnn_group
667
-
668
- param_range_visible = True # Always true if PyTorch family
669
- is_mlp = (specific_pytorch_model == "Simple Neural Network (MLP)")
670
- is_cnn = (specific_pytorch_model == "Simple Convolutional Network (CNN)")
671
- return gr.update(visible=param_range_visible), gr.update(visible=is_mlp), gr.update(visible=is_cnn)
672
 
673
  def update_model_output_formats(model_family_choice):
674
  if model_family_choice == "Scikit-learn (Classical ML)": return gr.update(choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
675
- elif model_family_choice == "PyTorch (Neural Networks)": return gr.update(choices=MODEL_OUTPUT_FORMATS_PYTORCH, value=MODEL_OUTPUT_FORMATS_PYTORCH[0])
676
  return gr.update(choices=[], value=None)
677
 
678
- css = """.gradio-container { font-family: 'IBM Plex Sans', sans-serif; } footer {display:none !important}""" # Hide footer too
679
 
680
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), css=css) as demo:
681
  gr.Markdown("# 🧠 TrainAI ⚙️")
@@ -690,27 +665,28 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
690
  model_family_dd = gr.Dropdown(MODEL_FAMILIES, label="Select Model Family", value=_initial_family_default)
691
  model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=initial_model_choices_for_specific_dd, value=initial_model_value_for_specific_dd, interactive=True)
692
 
693
- pytorch_param_range_dd = gr.Dropdown(list(PARAM_RANGES.keys()), label="Target Parameter Range (for NNs)",
694
- info="Guides NN architecture suggestions. Training >250k params on CPU is slow.",
695
- value=list(PARAM_RANGES.keys())[1], visible=(_initial_family_default == "PyTorch (Neural Networks)"))
696
- with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)" and initial_model_value_for_specific_dd == PYTORCH_MODELS[0])) as pt_mlp_specific_group:
697
- gr.Markdown("#### MLP Configuration")
698
- pt_mlp_hidden_layers_txt = gr.Textbox(label="Hidden Layer Sizes (comma-separated, e.g., 128,64)", value="64,32")
699
- pt_mlp_activation_dd = gr.Dropdown(MLP_ACTIVATIONS, label="Activation Function", value="relu")
700
- with gr.Row():
701
- pt_mlp_suggest_btn = gr.Button("Suggest MLP Layers")
702
- pt_mlp_estimate_params_btn = gr.Button("Estimate Current MLP Params")
703
- pt_mlp_param_count_txt = gr.Textbox(label="Estimated MLP Parameters", interactive=False)
704
- with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)" and initial_model_value_for_specific_dd == PYTORCH_MODELS[1])) as pt_cnn_specific_group:
705
- gr.Markdown("#### CNN Configuration (Simplified)")
706
- gr.Markdown("SimpleCNN uses a fixed structure. Params depend on image size/classes from data.")
707
- pt_cnn_estimate_params_btn = gr.Button("Estimate CNN Params (needs Data Info)")
708
- pt_cnn_param_count_txt = gr.Textbox(label="Estimated CNN Parameters", interactive=False)
709
-
710
- with gr.TabItem("2. Configure Dataset"): # This tab should show generate_dataset_group by default
 
711
  dataset_source_rb = gr.Radio(["Generate new dataset", "Upload my own dataset (CSV, JSON, Parquet)"],
712
  label="Dataset Source", value="Generate new dataset")
713
- with gr.Group(visible=True) as generate_dataset_group: # Default visible=True
714
  gr.Markdown("#### Generate Synthetic Dataset")
715
  with gr.Row():
716
  ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
@@ -722,7 +698,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
722
  with gr.Group(visible=False) as upload_dataset_group:
723
  gr.Markdown("#### Upload Dataset")
724
  ds_upload_file = gr.File(label="Upload dataset", file_types=[".csv", ".json", ".parquet"])
725
- target_column_name_txt = gr.Textbox(label="Target Column Name (Case-Sensitive!)", placeholder="e.g., 'target' or 'label'")
726
  dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, height=200)
727
  generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
728
 
@@ -731,7 +707,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
731
  with gr.Row():
732
  train_epochs_num = gr.Number(label="Epochs (NNs)", value=10, minimum=1, step=1)
733
  train_batch_size_num = gr.Number(label="Batch Size (NNs)", value=32, minimum=1, step=1)
734
- train_learning_rate_num = gr.Number(label="Learning Rate (NNs)", value=0.001, minimum=1e-6, step=1e-4, precision=6)
735
  model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
736
  train_model_btn = gr.Button("🚀 Train Model", variant="primary")
737
  gr.Markdown("---")
@@ -743,46 +719,40 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
743
  download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
744
 
745
  with gr.TabItem("ℹ️ Guide & Info"):
746
- # ... Guide content ... (omitted for brevity)
747
  gr.Markdown(CLONE_GUIDE_TEXT)
748
 
749
  # --- Event Handlers ---
750
  task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
751
  model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
752
 
753
- model_family_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
754
- model_specific_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
755
 
756
  def get_data_dims_for_nn_suggestion(preview_df, target_col, task, logs_in):
757
  logs = logs_in
758
  input_dim_est, output_dim_est = 10, (2 if task.endswith("Classification") else 1) # Defaults
759
  img_h_est, img_w_est = 28, 28 # Defaults for CNN
 
760
 
761
  if preview_df is not None and isinstance(preview_df, pd.DataFrame) and not preview_df.empty:
762
- temp_X_cols = [col for col in preview_df.columns if col != target_col] # Features
763
- if not temp_X_cols and task == "Basic Image Classification": # Image data often has no named feature cols in preview
764
- if preview_df.shape[1] == 1 and target_col in preview_df.columns: # Only target col
765
- pass # Cannot estimate image from only target
766
- elif target_col in preview_df.columns:
767
- num_pixels = preview_df.shape[1] -1
768
- else: # no target col (e.g. raw image data)
769
- num_pixels = preview_df.shape[1]
770
-
771
  if num_pixels > 0:
772
  dim_sqrt = int(math.sqrt(num_pixels))
773
  if dim_sqrt * dim_sqrt == num_pixels: img_h_est, img_w_est = dim_sqrt, dim_sqrt
774
- else: logs += f"Non-square image ({num_pixels} pixels) from preview. Using default {img_h_est}x{img_w_est} for suggestion.\n"
775
- input_dim_est = img_h_est * img_w_est # For CNN, this is not input_dim to MLP but for SimpleCNN internal calcs
776
- elif temp_X_cols: # Tabular
777
- num_cols = len([col for col in temp_X_cols if pd.api.types.is_numeric_dtype(preview_df[col])])
778
- cat_cols = [col for col in temp_X_cols if pd.api.types.is_object_dtype(preview_df[col])]
779
- one_hot_est = sum(min(10, preview_df[col].nunique(dropna=False)) for col in cat_cols)
780
  input_dim_est = max(1, num_cols + one_hot_est)
781
 
782
  if target_col and target_col in preview_df.columns:
783
  if task.endswith("Classification"):
784
  output_dim_est = max(1, preview_df[target_col].nunique(dropna=False))
785
- if output_dim_est == 2: output_dim_est = 1
786
  else: logs += "Dataset preview not available for NN dimension estimation. Using defaults.\n"
787
  return input_dim_est, output_dim_est, img_h_est, img_w_est, logs
788
 
@@ -811,9 +781,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
811
  def cnn_estimate_proxy_wrapper(current_logs, preview_df, target_col, task_type):
812
  _, output_dim, img_h, img_w, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
813
  logs += f"Using estimated img_h: {img_h}, img_w: {img_w}, output_dim: {output_dim} for CNN param estimation.\n"
814
- # For CNN, task type for constructor is 'classification' typically
815
- cnn_task_type_for_constructor = "classification" if task_type == "Basic Image Classification" else "regression" # Placeholder
816
- param_count_str, logs = estimate_cnn_params(str(img_h), str(img_w), str(output_dim), cnn_task_type_for_constructor, logs)
817
  return logs, param_count_str
818
 
819
  pt_cnn_estimate_params_btn.click(fn=cnn_estimate_proxy_wrapper,
@@ -826,33 +795,32 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
826
  model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
827
 
828
  generate_dataset_btn.click(
829
- fn=generate_dataset_backend, # Pass the function directly
830
  inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_informative_num,
831
  ds_gen_format_dd, ds_gen_ai_suggest_cb, pytorch_param_range_dd, model_specific_dd, current_logs_state],
832
  outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file])
833
 
834
  def process_uploaded_file(file_obj, logs_in):
835
- logs, df_preview, status_msg, stored_data_path = logs_in, None, "Upload failed or no file.", None
836
- if file_obj is None: logs += "Please upload a file first.\n"; return df_preview, logs, status_msg, stored_data_path
837
  logs += f"Uploaded file: {file_obj.name}\n"; stored_data_path = file_obj.name
838
  try:
839
  if file_obj.name.endswith(".csv"): df_preview = pd.read_csv(file_obj.name, nrows=5)
840
  elif file_obj.name.endswith(".json"): df_preview = pd.read_json(file_obj.name, lines=True, nrows=5)
841
- elif file_obj.name.endswith(".parquet"): temp_df = pd.read_parquet(file_obj.name); df_preview = temp_df.head() # Parquet preview needs full read
842
- status_msg = "Preview generated for uploaded file." if df_preview is not None else "Could not generate preview, but file is noted."
843
- logs += status_msg + "\n"
844
- except Exception as e: logs += f"Error previewing {file_obj.name}: {e}\n"; status_msg = f"Error previewing: {e}"
845
- return df_preview, logs, status_msg, stored_data_path
846
  ds_upload_file.upload(fn=process_uploaded_file, inputs=[ds_upload_file, current_logs_state],
847
- outputs=[dataset_preview_df, training_log_txt, training_log_txt, generated_data_state])
848
 
849
  train_model_btn.click(
850
- fn=train_model_wrapper, # Wrapper function defined before UI
851
  inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd,
852
- model_specific_dd, pt_mlp_hidden_layers_txt, pt_mlp_activation_dd,
853
  train_epochs_num, train_batch_size_num, train_learning_rate_num,
854
  model_output_format_dd, training_log_txt],
855
  outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file,
856
- model_param_count_output_txt, loss_plot_img, download_trained_model_file])
857
 
858
  demo.queue().launch(debug=True, show_error=True)
 
37
  import json
38
  import math
39
  import collections.abc # For Gradio issue with new Python versions
40
+ import collections # Added for OrderedDict if not already covered
41
+ import matplotlib # Use Agg backend for non-interactive environments
42
+ matplotlib.use('Agg')
43
+ import matplotlib.pyplot as plt
44
 
45
 
46
  # --- Global Variables / Constants ---
 
97
 
98
  layers.append(nn.Linear(current_dim, output_dim))
99
 
100
+ if task_type == "classification":
101
+ if output_dim == 1: # For BCELoss (binary classification)
102
+ layers.append(nn.Sigmoid())
103
+ # For multi-class, nn.CrossEntropyLoss expects raw logits, so no final activation here.
104
  self.network = nn.Sequential(*layers)
105
  def forward(self, x): return self.network(x)
106
 
107
+ class SimpleCNN(nn.Module):
108
  def __init__(self, input_channels, img_size_wh, num_classes=10, task_type="classification",
109
  c_out1=16, k1=3, s1=1, p1=1, pool1_k=2, pool1_s=2,
110
  c_out2=32, k2=3, s2=1, p2=1, pool2_k=2, pool2_s=2,
 
112
  super(SimpleCNN, self).__init__()
113
  self.input_channels = input_channels
114
  self.img_h, self.img_w = img_size_wh
115
+ self.num_classes = num_classes
116
 
117
  self.conv1 = nn.Conv2d(self.input_channels, c_out1, kernel_size=k1, stride=s1, padding=p1)
118
  self.relu1 = nn.ReLU()
 
120
 
121
  h_out_conv1 = (self.img_h - k1 + 2 * p1) // s1 + 1
122
  w_out_conv1 = (self.img_w - k1 + 2 * p1) // s1 + 1
123
+ h_pool1 = (h_out_conv1 - pool1_k) // pool1_s + 1
124
+ w_pool1 = (w_out_conv1 - pool1_k) // pool1_s + 1
 
125
 
126
  self.conv2 = nn.Conv2d(c_out1, c_out2, kernel_size=k2, stride=s2, padding=p2)
127
  self.relu2 = nn.ReLU()
 
129
 
130
  h_out_conv2 = (h_pool1 - k2 + 2 * p2) // s2 + 1
131
  w_out_conv2 = (w_pool1 - k2 + 2 * p2) // s2 + 1
132
+ h_pool2 = (h_out_conv2 - pool2_k) // pool2_s + 1
133
+ w_pool2 = (w_out_conv2 - pool2_k) // pool2_s + 1
134
 
135
  self.flattened_size = c_out2 * h_pool2 * w_pool2
136
  if self.flattened_size <= 0:
137
+ raise ValueError(f"Calculated flattened size is {self.flattened_size}. Check CNN params and image size.")
138
 
139
  self.fc1 = nn.Linear(self.flattened_size, fc_hidden)
140
  self.relu3 = nn.ReLU()
141
+ self.fc2 = nn.Linear(fc_hidden, self.num_classes)
142
 
143
+ # The final activation is now a separate attribute for clarity.
144
  if task_type == "classification":
145
+ if self.num_classes == 1: # Binary classification with BCELoss
146
  self.final_activation = nn.Sigmoid()
147
+ else: # Multi-class classification with CrossEntropyLoss
148
+ self.final_activation = nn.Identity() # The loss function combines Softmax and NLLLoss.
 
 
149
  else: # Regression
150
  self.final_activation = nn.Identity()
151
 
 
158
  x = self.final_activation(x)
159
  return x
160
 
161
+ # --- Parameter Target Helpers ---
162
  PARAM_RANGES = collections.OrderedDict([
163
  ("Tiny (<10k)", (0, 10000)),
164
  ("Small (10k-50k)", (10000, 50000)),
 
197
  if not suggested_layers_str: suggested_layers_str = "64"; logs += "Defaulting to '64'.\n"
198
  return suggested_layers_str, logs
199
 
200
def estimate_current_mlp_params(input_dim_str, hidden_layers_str, output_dim_str, task_type, current_logs=""):
    """Estimate the parameter count of an MLP with the given layout.

    Parses the input/output dimensions, instantiates a throwaway
    ``SimpleMLP`` and formats whatever ``count_pytorch_parameters`` reports
    for it with thousands separators.

    Args:
        input_dim_str: Input dimension, anything ``int()`` accepts.
        hidden_layers_str: Hidden layer spec, forwarded to ``SimpleMLP`` as-is.
        output_dim_str: Output dimension, anything ``int()`` accepts.
        task_type: UI task label; a label ending in "Classification" selects
            the classification head, anything else the regression head.
        current_logs: Log text accumulated so far.

    Returns:
        A ``(result, logs)`` tuple. ``result`` is the formatted count, a
        validation message when a dimension is not positive, or ``"Error"``
        when anything raised (the exception text is appended to ``logs``).
    """
    logs = current_logs
    try:
        n_in = int(input_dim_str)
        n_out = int(output_dim_str)
        if min(n_in, n_out) <= 0:
            return "Input/Output dims must be > 0", logs

        if task_type.endswith("Classification"):
            head_kind = "classification"
        else:
            head_kind = "regression"

        # The model exists only long enough to be counted.
        probe = SimpleMLP(n_in, hidden_layers_str, n_out, task_type=head_kind)
        total = count_pytorch_parameters(probe)
        del probe
        return f"{total:,}", logs
    except Exception as err:
        logs += f"Error estimating MLP params: {err}\n"
        return "Error", logs
211
 
212
def estimate_cnn_params(img_h_str, img_w_str, num_classes_str, task_type, current_logs=""):
    """Estimate the parameter count of a SimpleCNN for the given image size.

    Parses the image height/width and class count, instantiates a throwaway
    single-channel ``SimpleCNN`` and formats whatever
    ``count_pytorch_parameters`` reports for it with thousands separators.

    Args:
        img_h_str: Image height, anything ``int()`` accepts.
        img_w_str: Image width, anything ``int()`` accepts.
        num_classes_str: Output size of the network, anything ``int()`` accepts.
        task_type: UI task label; a label ending in "Classification" selects
            the classification head, anything else the regression head.
        current_logs: Log text accumulated so far.

    Returns:
        A ``(result, logs)`` tuple. ``result`` is the formatted count, a
        validation message when any value is not positive, or ``"Error"``
        when anything raised (the full traceback is appended to ``logs``).
    """
    logs = current_logs
    try:
        height = int(img_h_str)
        width = int(img_w_str)
        n_outputs = int(num_classes_str)
        if height <= 0 or width <= 0 or n_outputs <= 0:
            return "Image dims/classes must be > 0", logs

        if task_type.endswith("Classification"):
            head_kind = "classification"
        else:
            head_kind = "regression"

        # The model exists only long enough to be counted.
        probe = SimpleCNN(
            input_channels=1,
            img_size_wh=(height, width),
            num_classes=n_outputs,
            task_type=head_kind,
        )
        total = count_pytorch_parameters(probe)
        del probe
        return f"{total:,}", logs
    except Exception:
        logs += f"Error estimating CNN params: {traceback.format_exc()}\n"
        return "Error", logs
223
 
224
 
225
+ # --- Dataset and Preprocessing ---
226
  def generate_dataset_backend(task_type, n_samples_str, n_features_str,
227
  n_classes_or_informative_str, dataset_format,
228
  ai_suggest_ds_shape, target_param_range_str, model_type_selection,
 
254
  try:
255
  if task_type == "Tabular Classification":
256
  n_cls = max(2, n_classes_or_informative)
257
+ n_inf = max(1, min(n_features, n_classes_or_informative if n_classes_or_informative >= n_cls else n_features // 2))
258
  if n_inf > n_features: n_inf = n_features
259
  X_data, y_data = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_inf,
260
  n_redundant=max(0,n_features - n_inf)//2, n_classes=n_cls, flip_y=0.05, random_state=42)
 
281
  elif dataset_format == ".parquet": df.to_parquet(file_path, index=False)
282
  else: logs += f"Unsupported format {dataset_format}. Defaulting to CSV.\n"; file_path=get_temp_filepath("generated_dataset","csv"); df.to_csv(file_path, index=False)
283
  logs += f"Dataset saved to {file_path}\n"
 
 
 
284
  return df.head(), df, logs, file_path
285
  else:
286
  logs += "Dataset generated as numpy arrays. Not saving to file from this function directly.\n"
287
  return pd.DataFrame(X_data[:5] if X_data is not None else None), (X_data, y_data), logs, None
288
 
 
289
  except Exception as e: error_msg=f"Error generating dataset: {traceback.format_exc()}"; logs+=error_msg+"\n"; return None, error_msg, logs, None
290
 
291
  def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_type, current_logs=""):
 
310
  ], remainder='passthrough')
311
 
312
  X_processed_np = preprocessor.fit_transform(X_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  processed_input_dim = X_processed_np.shape[1]
314
  logs += f"Tabular data preprocessed. X shape: {X_processed_np.shape}, Processed input dim: {processed_input_dim}\n"
315
 
 
318
  y_processed_np = le.fit_transform(y_series)
319
  num_classes = len(le.classes_)
320
  logs += f"Target encoded. Classes: {num_classes} ({le.classes_})\n"
321
+ output_dim_nn = 1 if num_classes == 2 else num_classes
322
  else: # Regression
323
  y_processed_np = y_series.astype(float).values
324
  num_classes = 1
325
  output_dim_nn = 1
326
 
327
+ return X_processed_np, y_processed_np, preprocessor, logs, processed_input_dim, output_dim_nn
328
 
329
 
330
+ # --- Training Functions ---
331
  def train_model_sklearn(data_input_obj, target_column, task_type, model_name, model_output_format, current_logs=""):
332
  logs = current_logs + f"\n--- Training Scikit-learn Model: {model_name} ---\n"
333
  model_path_out, metrics_out, model_params_out = None, "Training failed.", "N/A"
 
343
  elif isinstance(data_input_obj, pd.DataFrame): df = data_input_obj
344
  else: logs += "Invalid data for training.\n"; return logs, "Error: Invalid data.", None, "N/A"
345
 
346
+ if not target_column or target_column not in df.columns:
347
  logs += f"Target column '{target_column}' not provided or not found.\n"; return logs, f"Error: Target '{target_column}' not found/provided.", None, "N/A"
348
 
349
  try:
350
+ X_processed_np, y_processed_np, preprocessor, logs, _, _ = preprocess_tabular_data(df, None, target_column, task_type, logs)
351
  except ValueError as e: logs += f"Preprocessing error: {e}\n"; return logs, f"Error: {e}", None, "N/A"
352
 
353
  X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
 
390
  model_path_out = get_temp_filepath(model_filename_base, "onnx")
391
  raw_X_for_types_df = df.drop(target_column, axis=1).infer_objects()
392
  onnx_initial_types = []
393
+ for col_name in raw_X_for_types_df.columns:
394
  col_dtype = raw_X_for_types_df[col_name].dtype
 
 
395
  if pd.api.types.is_numeric_dtype(col_dtype):
 
 
 
 
396
  onnx_initial_types.append((col_name, FloatTensorType([None, 1])))
397
  elif pd.api.types.is_string_dtype(col_dtype) or col_dtype == 'object':
398
  onnx_initial_types.append((col_name, StringTensorType([None, 1])))
399
  else:
400
+ logs += f"Warning: Unsupported dtype {col_dtype} for {col_name} in ONNX. Defaulting to Float.\n"
401
+ onnx_initial_types.append((col_name, FloatTensorType([None, 1])))
402
 
403
  if not onnx_initial_types: raise ValueError("ONNX initial types failed: No valid columns found.")
404
  try:
 
407
  with open(model_path_out, "wb") as f: f.write(onnx_model.SerializeToString())
408
  logs += f"Model saved to {model_path_out} as ONNX.\n"
409
  sess = rt.InferenceSession(model_path_out, providers=rt.get_available_providers())
410
+ logs += f"ONNX model loaded successfully with ONNX Runtime.\n"
411
  except Exception as onnx_e: logs += f"ONNX Error: {traceback.format_exc()}\n"; model_path_out=None; metrics_out+="\nONNX EXPORT FAILED."
412
+ else:
413
  logs += f"Unsupported format '{model_output_format}'. Saving as .pkl\n"
414
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
415
  joblib.dump(full_pipeline_for_saving, model_path_out)
 
432
  else: logs += f"Unsupported file: {data_input_obj}\n"; return logs, "Error", None, "N/A", None
433
  except Exception as e: logs += f"Error reading {data_input_obj}: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
434
  elif isinstance(data_input_obj, pd.DataFrame): df_for_pytorch = data_input_obj
435
+ elif isinstance(data_input_obj, tuple): X_numpy_for_pytorch, y_numpy_for_pytorch = data_input_obj
 
436
  else: logs += "Invalid data for PyTorch training.\n"; return logs, "Error", None, "N/A", None
437
 
438
  try:
 
444
  X_processed_np, y_processed_np = None, None
445
 
446
  if model_type_pt == "Simple Neural Network (MLP)":
447
+ if not task_type.startswith("Tabular"): logs += "MLP requires Tabular task.\n"; return logs, "MLP Task Error", None, "N/A", None
448
+ if not target_column and df_for_pytorch is not None: logs += "Target column needed for MLP with DataFrame.\n"; return logs, "MLP Target Error", None, "N/A", None
 
 
 
449
  try:
450
  data_arg1 = df_for_pytorch if df_for_pytorch is not None else X_numpy_for_pytorch
451
  data_arg2 = y_numpy_for_pytorch if df_for_pytorch is None else None
452
+ current_target_col = target_column if df_for_pytorch is not None else "target"
453
+ X_processed_np, y_processed_np, preprocessor_pipeline, logs, processed_input_dim_actual, nn_output_dim_actual = \
 
 
454
  preprocess_tabular_data(data_arg1, data_arg2, current_target_col, task_type, logs)
455
  except ValueError as e: logs+=f"MLP Preprocessing error: {e}\n"; return logs,f"Error: {e}",None,"N/A",None
456
  elif model_type_pt == "Simple Convolutional Network (CNN)":
457
+ X_raw, y_raw = (df_for_pytorch.drop(target_column, axis=1).values, df_for_pytorch[target_column].values) if df_for_pytorch is not None else (X_numpy_for_pytorch, y_numpy_for_pytorch)
458
+ if X_raw is None: logs += "No valid data for CNN.\n"; return logs, "CNN Data Error", None, "N/A", None
 
 
 
 
 
 
 
 
 
 
459
  le = LabelEncoder(); y_processed_np = le.fit_transform(y_raw)
460
+ num_classes = len(le.classes_)
461
+ nn_output_dim_actual = 1 if num_classes == 2 else num_classes
462
 
463
  pixels_per_sample = X_raw.shape[1]; img_h, img_w, input_channels = 28,28,1
464
  img_dim_approx = int(math.sqrt(pixels_per_sample))
 
467
 
468
  X_processed_np = X_raw.reshape(-1, input_channels, img_h, img_w).astype(np.float32) / 255.0
469
  processed_input_dim_actual = (input_channels, img_h, img_w)
470
+ logs += f"CNN Data: X reshaped to {X_processed_np.shape}, y: {y_processed_np.shape}\n"
471
  else: logs += f"Unknown PyTorch model: {model_type_pt}\n"; return logs, "Unknown PyTorch model", None, "N/A", None
472
 
473
+ X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
474
+ logs += f"PyTorch Train/Test split. Train: {X_train.shape}, Test: {X_test.shape}\n"
475
+
476
+ y_train_dtype = torch.float32 if (nn_output_dim_actual == 1 and not task_type.endswith("Classification")) else (torch.float32 if nn_output_dim_actual == 1 else torch.long)
477
+ X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
478
+ y_train_tensor = torch.tensor(y_train, dtype=y_train_dtype)
479
+ X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
480
+ y_test_tensor = torch.tensor(y_test, dtype=y_train_dtype)
481
 
482
+ if nn_output_dim_actual == 1: # For BCELoss and MSELoss, target needs to be [N, 1]
483
+ y_train_tensor = y_train_tensor.unsqueeze(1)
484
+ y_test_tensor = y_test_tensor.unsqueeze(1)
485
+
486
+ train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
487
+ train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
488
 
489
  pytorch_model = None
490
  try:
491
+ is_classification = task_type.endswith("Classification")
492
  if model_type_pt == "Simple Neural Network (MLP)":
493
  pytorch_model = SimpleMLP(input_dim=processed_input_dim_actual, hidden_layers_str=mlp_hidden_layers_str,
494
  output_dim=nn_output_dim_actual, activation_fn_str=mlp_activation,
495
+ task_type="classification" if is_classification else "regression")
496
  elif model_type_pt == "Simple Convolutional Network (CNN)":
497
  channels, h, w = processed_input_dim_actual
498
  pytorch_model = SimpleCNN(input_channels=channels, img_size_wh=(h,w), num_classes=nn_output_dim_actual,
499
+ task_type="classification" if is_classification else "regression")
500
  except Exception as model_e: logs += f"Error creating PyTorch model: {traceback.format_exc()}\n"; return logs, f"Model Creation Error: {model_e}", None, "N/A", None
501
 
502
  if pytorch_model is None: logs += "Failed to instantiate PyTorch model.\n"; return logs, "Model instantiate fail", None, "N/A", None
503
  model_params_val = count_pytorch_parameters(pytorch_model); model_params_out = f"{model_params_val:,}"
504
  logs += f"PyTorch Model: {model_params_out} params.\n"
 
505
 
506
+ criterion = nn.BCELoss() if (is_classification and nn_output_dim_actual == 1) else (nn.CrossEntropyLoss() if is_classification else nn.MSELoss())
 
 
 
507
  optimizer = optim.Adam(pytorch_model.parameters(), lr=lr)
508
 
509
  logs += f"Starting PyTorch training for {epochs} epochs...\n"; start_time = time.time()
510
  epoch_losses = []
511
  pytorch_model.train()
512
  for epoch in range(epochs):
513
+ epoch_loss_sum = 0.0
514
+ for batch_X, batch_y in train_dataloader:
515
  optimizer.zero_grad(); outputs = pytorch_model(batch_X)
516
  loss = criterion(outputs, batch_y); loss.backward(); optimizer.step()
517
+ epoch_loss_sum += loss.item()
518
+ avg_epoch_loss = epoch_loss_sum / len(train_dataloader) if len(train_dataloader) > 0 else 0
519
  epoch_losses.append(avg_epoch_loss)
520
+ if (epoch + 1) % max(1, epochs // 10) == 0 or epoch == epochs - 1: # Log ~10 times
521
+ logs += f"Epoch {epoch+1}/{epochs}, Avg Training Loss: {avg_epoch_loss:.4f}\n"
522
 
523
  logs += f"PyTorch training completed in {time.time() - start_time:.2f}s.\n"
524
  pytorch_model.eval()
525
  with torch.no_grad():
526
+ test_outputs = pytorch_model(X_test_tensor)
527
+ if is_classification:
528
+ predicted = (test_outputs > 0.5).float() if nn_output_dim_actual == 1 else torch.max(test_outputs.data, 1)[1]
529
+ acc = accuracy_score(y_test_tensor.cpu().numpy(), predicted.cpu().numpy())
530
+ report = classification_report(y_test_tensor.cpu().numpy(), predicted.cpu().numpy(), zero_division=0)
531
+ metrics_out = f"Final Avg Training Loss: {avg_epoch_loss:.4f}\n\n--- Test Set Evaluation ---\nAccuracy: {acc:.4f}\n\nClassification Report:\n{report}"
532
+ else: # Regression
533
+ mse = mean_squared_error(y_test_tensor.cpu().numpy(), test_outputs.cpu().numpy())
534
+ r2 = r2_score(y_test_tensor.cpu().numpy(), test_outputs.cpu().numpy())
535
+ metrics_out = f"Final Avg Training Loss: {avg_epoch_loss:.4f}\n\n--- Test Set Evaluation ---\nMean Squared Error: {mse:.4f}\nR2 Score: {r2:.4f}"
 
 
 
 
 
 
 
 
 
 
 
536
  logs += "\n--- PyTorch Metrics ---\n" + metrics_out + "\n"
537
 
538
  if epoch_losses:
 
 
 
539
  fig, ax = plt.subplots(); ax.plot(range(1, epochs + 1), epoch_losses, marker='o')
540
  ax.set_xlabel("Epoch"); ax.set_ylabel("Average Loss"); ax.set_title("Training Loss Curve")
541
  plot_out = fig; logs += "Loss curve generated.\n"
 
545
  model_path_out = get_temp_filepath(model_filename_base, "pt")
546
  save_obj = {'model_state_dict': pytorch_model.state_dict(), 'output_dim_nn': nn_output_dim_actual, 'task_type': task_type}
547
  if model_type_pt == "Simple Neural Network (MLP)" and preprocessor_pipeline:
548
+ save_obj.update({'preprocessor': preprocessor_pipeline, 'input_dim_processed': processed_input_dim_actual, 'hidden_layers_str': mlp_hidden_layers_str, 'activation_fn': mlp_activation})
 
 
 
 
549
  elif model_type_pt == "Simple Convolutional Network (CNN)":
550
+ c,h,w = processed_input_dim_actual; save_obj.update({'input_channels':c, 'img_h':h, 'img_w':w})
 
 
 
551
  torch.save(save_obj, model_path_out)
552
+ logs += f"PyTorch model saved to {model_path_out}\n"
553
  else: # Fallback
554
+ logs += f"Unsupported format '{model_output_format}'.\n"
 
 
555
  return logs, metrics_out, model_path_out, model_params_out, plot_out
556
 
557
+
558
+ # --- Main Training Wrapper Function ---
559
def train_model_wrapper(data_input_obj, target_column, task_type, model_family,
                        model_specific_choice,
                        mlp_hidden_layers, mlp_activation,
                        epochs, batch_size, learning_rate,
                        model_output_format, current_logs):
    """Dispatch a training request to the trainer for the chosen model family.

    Always returns a 5-tuple ``(logs, metrics, model_path, param_count,
    loss_plot)``. The scikit-learn trainer yields four values, so its
    ``loss_plot`` slot is filled with ``None``.

    Args:
        data_input_obj: The dataset object held by the UI; ``None`` means
            nothing was generated or uploaded yet.
        target_column: Name of the target column.
        task_type: UI task label, forwarded to the trainer.
        model_family: Either "Scikit-learn (Classical ML)" or
            "PyTorch (Neural Networks)".
        model_specific_choice: Concrete model name within the family.
        mlp_hidden_layers: Hidden layer spec (PyTorch only).
        mlp_activation: Activation name (PyTorch only).
        epochs: Epoch count; passed to the PyTorch trainer as ``str(int(...))``.
        batch_size: Batch size; passed as ``str(int(...))``.
        learning_rate: Learning rate; passed as ``str(...)``.
        model_output_format: Requested serialization format.
        current_logs: Log text accumulated so far.

    Returns:
        The 5-tuple described above. On any exception the traceback is
        appended to the logs and also returned in the metrics slot.
    """
    logs = current_logs + "\n--- Kicking off Training ---\n"

    if data_input_obj is None:
        logs += "ERROR: No dataset has been generated or uploaded. Please go to Tab 2.\n"
        return logs, "Error: No dataset available.", None, "N/A", None

    try:
        if model_family == "Scikit-learn (Classical ML)":
            # Unpack explicitly so an unexpected arity lands in the except below.
            sk_logs, sk_metrics, sk_path, sk_params = train_model_sklearn(
                data_input_obj=data_input_obj,
                target_column=target_column,
                task_type=task_type,
                model_name=model_specific_choice,
                model_output_format=model_output_format,
                current_logs=logs,
            )
            return sk_logs, sk_metrics, sk_path, sk_params, None

        if model_family == "PyTorch (Neural Networks)":
            pt_logs, pt_metrics, pt_path, pt_params, pt_plot = train_model_pytorch(
                data_input_obj=data_input_obj,
                target_column=target_column,
                task_type=task_type,
                model_type_pt=model_specific_choice,
                mlp_hidden_layers_str=mlp_hidden_layers,
                mlp_activation=mlp_activation,
                epochs_str=str(int(epochs)),
                batch_size_str=str(int(batch_size)),
                lr_str=str(learning_rate),
                model_output_format=model_output_format,
                current_logs=logs,
            )
            return pt_logs, pt_metrics, pt_path, pt_params, pt_plot

        logs += f"Unknown model family: {model_family}\n"
        return logs, "Error: Unknown model family.", None, "N/A", None
    except Exception:
        failure = f"An unexpected error occurred in the training wrapper: {traceback.format_exc()}"
        logs += failure + "\n"
        return logs, failure, None, "N/A", None
603
+
604
+
605
  # --- Gradio UI Definition ---
606
  TASK_CHOICES = ["Tabular Classification", "Tabular Regression", "Basic Image Classification"]
607
  MODEL_FAMILIES = ["Scikit-learn (Classical ML)", "PyTorch (Neural Networks)"]
 
614
  MLP_ACTIVATIONS = ["relu", "tanh", "sigmoid"]
615
  CLONE_GUIDE_TEXT = """
616
  ## How to Clone & Upgrade This Space for More Power:
617
+ 1. **Clone this Space:** Click the '...' menu at the top-right and choose 'Duplicate this Space'.
618
+ 2. **Choose Hardware:** On the duplication screen, select a more powerful hardware option, like a "CPU upgrade" or a "T4 Small" GPU.
619
+ 3. **Enjoy Faster Training:** Your private, upgraded version of TrainAI will now train models significantly faster!
620
  """
621
 
622
  # Determine initial choices for model_specific_dd based on default task_type and model_family
623
  _initial_task_default = TASK_CHOICES[0]
624
  _initial_family_default = MODEL_FAMILIES[0]
 
625
  if _initial_family_default == "Scikit-learn (Classical ML)":
626
+ initial_model_choices_for_specific_dd = SKLEARN_MODELS_CLASSIFICATION
 
627
  elif _initial_family_default == "PyTorch (Neural Networks)":
628
+ initial_model_choices_for_specific_dd = [PYTORCH_MODELS[0]]
629
+ initial_model_value_for_specific_dd = initial_model_choices_for_specific_dd[0]
 
630
 
631
  def update_model_options(task_choice, model_family_choice):
632
  choices, value = [], None
 
639
  value = choices[0] if choices else None
640
  return gr.update(choices=choices, value=value, visible=bool(choices))
641
 
642
def update_pytorch_specific_options_visibility(model_family_choice, specific_pytorch_model):
    """Compute visibility updates for the PyTorch option groups.

    Returns three ``gr.update(visible=...)`` objects, in order: the shared
    PyTorch options, the MLP-specific options, the CNN-specific options.
    The model-specific groups are only shown when the PyTorch family is
    selected and the matching model is chosen.
    """
    if model_family_choice == "PyTorch (Neural Networks)":
        flags = (
            True,
            specific_pytorch_model == "Simple Neural Network (MLP)",
            specific_pytorch_model == "Simple Convolutional Network (CNN)",
        )
    else:
        # Any other family hides every PyTorch-related group.
        flags = (False, False, False)
    return tuple(gr.update(visible=flag) for flag in flags)
 
 
 
 
 
 
 
647
 
648
def update_model_output_formats(model_family_choice):
    """Return a dropdown update listing the export formats for a model family.

    The first format of the matching list becomes the selected value. An
    unrecognised family yields an empty choice list with no value.
    """
    if model_family_choice == "Scikit-learn (Classical ML)":
        formats = MODEL_OUTPUT_FORMATS_SKLEARN
    elif model_family_choice == "PyTorch (Neural Networks)":
        formats = MODEL_OUTPUT_FORMATS_PYTORCH
    else:
        return gr.update(choices=[], value=None)
    return gr.update(choices=formats, value=formats[0])
652
 
653
+ css = """.gradio-container { font-family: 'IBM Plex Sans', sans-serif; } footer {display:none !important}"""
654
 
655
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), css=css) as demo:
656
  gr.Markdown("# 🧠 TrainAI ⚙️")
 
665
  model_family_dd = gr.Dropdown(MODEL_FAMILIES, label="Select Model Family", value=_initial_family_default)
666
  model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=initial_model_choices_for_specific_dd, value=initial_model_value_for_specific_dd, interactive=True)
667
 
668
+ with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)")) as pt_options_group:
669
+ pytorch_param_range_dd = gr.Dropdown(list(PARAM_RANGES.keys()), label="Target Parameter Range (for NNs)",
670
+ info="Guides NN architecture suggestions. Training >250k params on CPU is slow.",
671
+ value=list(PARAM_RANGES.keys())[1])
672
+ with gr.Group(visible=(initial_model_value_for_specific_dd == PYTORCH_MODELS[0])) as pt_mlp_specific_group:
673
+ gr.Markdown("#### MLP Configuration")
674
+ pt_mlp_hidden_layers_txt = gr.Textbox(label="Hidden Layer Sizes (comma-separated, e.g., 128,64)", value="64,32")
675
+ pt_mlp_activation_dd = gr.Dropdown(MLP_ACTIVATIONS, label="Activation Function", value="relu")
676
+ with gr.Row():
677
+ pt_mlp_suggest_btn = gr.Button("Suggest MLP Layers")
678
+ pt_mlp_estimate_params_btn = gr.Button("Estimate Current MLP Params")
679
+ pt_mlp_param_count_txt = gr.Textbox(label="Estimated MLP Parameters", interactive=False)
680
+ with gr.Group(visible=(initial_model_value_for_specific_dd == PYTORCH_MODELS[1])) as pt_cnn_specific_group:
681
+ gr.Markdown("#### CNN Configuration (Simplified)")
682
+ gr.Markdown("SimpleCNN uses a fixed structure. Params depend on image size/classes from data.")
683
+ pt_cnn_estimate_params_btn = gr.Button("Estimate CNN Params (needs Data Info)")
684
+ pt_cnn_param_count_txt = gr.Textbox(label="Estimated CNN Parameters", interactive=False)
685
+
686
+ with gr.TabItem("2. Configure Dataset"):
687
  dataset_source_rb = gr.Radio(["Generate new dataset", "Upload my own dataset (CSV, JSON, Parquet)"],
688
  label="Dataset Source", value="Generate new dataset")
689
+ with gr.Group(visible=True) as generate_dataset_group:
690
  gr.Markdown("#### Generate Synthetic Dataset")
691
  with gr.Row():
692
  ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
 
698
  with gr.Group(visible=False) as upload_dataset_group:
699
  gr.Markdown("#### Upload Dataset")
700
  ds_upload_file = gr.File(label="Upload dataset", file_types=[".csv", ".json", ".parquet"])
701
+ target_column_name_txt = gr.Textbox(label="Target Column Name (Case-Sensitive!)", placeholder="e.g., 'target' or 'label'", value="target")
702
  dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, height=200)
703
  generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
704
 
 
707
  with gr.Row():
708
  train_epochs_num = gr.Number(label="Epochs (NNs)", value=10, minimum=1, step=1)
709
  train_batch_size_num = gr.Number(label="Batch Size (NNs)", value=32, minimum=1, step=1)
710
+ train_learning_rate_num = gr.Number(label="Learning Rate (NNs)", value=0.001, minimum=1e-6, format="%.6f")
711
  model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
712
  train_model_btn = gr.Button("🚀 Train Model", variant="primary")
713
  gr.Markdown("---")
 
719
  download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
720
 
721
  with gr.TabItem("ℹ️ Guide & Info"):
 
722
  gr.Markdown(CLONE_GUIDE_TEXT)
723
 
724
  # --- Event Handlers ---
725
  task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
726
  model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
727
 
728
+ model_family_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pt_options_group, pt_mlp_specific_group, pt_cnn_specific_group])
729
+ model_specific_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pt_options_group, pt_mlp_specific_group, pt_cnn_specific_group])
730
 
731
  def get_data_dims_for_nn_suggestion(preview_df, target_col, task, logs_in):
732
  logs = logs_in
733
  input_dim_est, output_dim_est = 10, (2 if task.endswith("Classification") else 1) # Defaults
734
  img_h_est, img_w_est = 28, 28 # Defaults for CNN
735
+ num_pixels = 0
736
 
737
  if preview_df is not None and isinstance(preview_df, pd.DataFrame) and not preview_df.empty:
738
+ cols = list(preview_df.columns)
739
+ if target_col in cols: cols.remove(target_col)
740
+
741
+ if task == "Basic Image Classification":
742
+ num_pixels = len(cols)
 
 
 
 
743
  if num_pixels > 0:
744
  dim_sqrt = int(math.sqrt(num_pixels))
745
  if dim_sqrt * dim_sqrt == num_pixels: img_h_est, img_w_est = dim_sqrt, dim_sqrt
746
+ else: # Tabular
747
+ num_cols = len([c for c in cols if pd.api.types.is_numeric_dtype(preview_df[c])])
748
+ cat_cols = [c for c in cols if pd.api.types.is_object_dtype(preview_df[c])]
749
+ one_hot_est = sum(min(10, preview_df[c].nunique(dropna=False)) for c in cat_cols)
 
 
750
  input_dim_est = max(1, num_cols + one_hot_est)
751
 
752
  if target_col and target_col in preview_df.columns:
753
  if task.endswith("Classification"):
754
  output_dim_est = max(1, preview_df[target_col].nunique(dropna=False))
755
+ if output_dim_est == 2: output_dim_est = 1
756
  else: logs += "Dataset preview not available for NN dimension estimation. Using defaults.\n"
757
  return input_dim_est, output_dim_est, img_h_est, img_w_est, logs
758
 
 
781
  def cnn_estimate_proxy_wrapper(current_logs, preview_df, target_col, task_type):
782
  _, output_dim, img_h, img_w, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
783
  logs += f"Using estimated img_h: {img_h}, img_w: {img_w}, output_dim: {output_dim} for CNN param estimation.\n"
784
+ cnn_task_type = "classification" if task_type == "Basic Image Classification" else "regression"
785
+ param_count_str, logs = estimate_cnn_params(str(img_h), str(img_w), str(output_dim), cnn_task_type, logs)
 
786
  return logs, param_count_str
787
 
788
  pt_cnn_estimate_params_btn.click(fn=cnn_estimate_proxy_wrapper,
 
795
  model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
796
 
797
  generate_dataset_btn.click(
798
+ fn=generate_dataset_backend,
799
  inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_informative_num,
800
  ds_gen_format_dd, ds_gen_ai_suggest_cb, pytorch_param_range_dd, model_specific_dd, current_logs_state],
801
  outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file])
802
 
803
  def process_uploaded_file(file_obj, logs_in):
804
+ logs, df_preview, stored_data_path = logs_in, None, None
805
+ if file_obj is None: logs += "Please upload a file first.\n"; return df_preview, logs, stored_data_path
806
  logs += f"Uploaded file: {file_obj.name}\n"; stored_data_path = file_obj.name
807
  try:
808
  if file_obj.name.endswith(".csv"): df_preview = pd.read_csv(file_obj.name, nrows=5)
809
  elif file_obj.name.endswith(".json"): df_preview = pd.read_json(file_obj.name, lines=True, nrows=5)
810
+ elif file_obj.name.endswith(".parquet"): temp_df = pd.read_parquet(file_obj.name); df_preview = temp_df.head()
811
+ logs += "Preview generated for uploaded file.\n"
812
+ except Exception as e: logs += f"Error previewing {file_obj.name}: {e}\n"
813
+ return df_preview, logs, stored_data_path
 
814
  ds_upload_file.upload(fn=process_uploaded_file, inputs=[ds_upload_file, current_logs_state],
815
+ outputs=[dataset_preview_df, training_log_txt, generated_data_state])
816
 
817
  train_model_btn.click(
818
+ fn=train_model_wrapper,
819
  inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd,
820
+ pt_mlp_hidden_layers_txt, pt_mlp_activation_dd,
821
  train_epochs_num, train_batch_size_num, train_learning_rate_num,
822
  model_output_format_dd, training_log_txt],
823
  outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file,
824
+ model_param_count_output_txt, loss_plot_img])
825
 
826
  demo.queue().launch(debug=True, show_error=True)