szili2011 commited on
Commit
881d92c
·
verified ·
1 Parent(s): 1d651bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -497
app.py CHANGED
@@ -29,7 +29,7 @@ import torchvision.transforms as T
29
  # ONNX specific imports
30
  import skl2onnx
31
  from skl2onnx import convert_sklearn
32
- from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, StringTensorType
33
  import onnxruntime as rt
34
 
35
  import traceback
@@ -38,14 +38,15 @@ import json
38
  import math
39
  import collections.abc # For Gradio issue with new Python versions
40
 
 
41
  # --- Global Variables / Constants ---
42
  TEMP_DIR = "temp_outputs"
43
  os.makedirs(TEMP_DIR, exist_ok=True)
44
- MAX_DATASET_ROWS_WARN = 30000 # Reduced slightly due to increased complexity
45
- MAX_GENERATED_ROWS = 50000 # Max rows for generation
46
- MAX_GENERATED_COLS = 100 # Max cols for generation
47
 
48
- # --- Helper Functions ---
49
  def count_sklearn_parameters(model):
50
  if hasattr(model, 'coef_'):
51
  return model.coef_.size + (model.intercept_.size if hasattr(model, 'intercept_') else 0)
@@ -61,12 +62,10 @@ def count_pytorch_parameters(model):
61
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
62
 
63
  def get_temp_filepath(filename_base, extension):
64
- # Ensure extension does not start with a dot if it's passed with one
65
  clean_extension = extension.lstrip('.')
66
  return os.path.join(TEMP_DIR, f"{filename_base}_{time.strftime('%Y%m%d-%H%M%S')}.{clean_extension}")
67
 
68
-
69
- # --- PyTorch Model Definitions ---
70
  class SimpleMLP(nn.Module):
71
  def __init__(self, input_dim, hidden_layers_str, output_dim, activation_fn_str="relu", task_type="classification"):
72
  super(SimpleMLP, self).__init__()
@@ -94,21 +93,21 @@ class SimpleMLP(nn.Module):
94
 
95
  layers.append(nn.Linear(current_dim, output_dim))
96
 
97
- if task_type == "classification":
98
  if output_dim == 1: layers.append(nn.Sigmoid()) # Binary
99
  elif output_dim > 1: layers.append(nn.Softmax(dim=-1)) # Multi-class
100
  self.network = nn.Sequential(*layers)
101
  def forward(self, x): return self.network(x)
102
 
103
- class SimpleCNN(nn.Module):
104
- def __init__(self, input_channels, img_size_wh, num_classes=10,
105
  c_out1=16, k1=3, s1=1, p1=1, pool1_k=2, pool1_s=2,
106
  c_out2=32, k2=3, s2=1, p2=1, pool2_k=2, pool2_s=2,
107
  fc_hidden=128):
108
  super(SimpleCNN, self).__init__()
109
  self.input_channels = input_channels
110
  self.img_h, self.img_w = img_size_wh
111
- self.num_classes = num_classes
112
 
113
  self.conv1 = nn.Conv2d(self.input_channels, c_out1, kernel_size=k1, stride=s1, padding=p1)
114
  self.relu1 = nn.ReLU()
@@ -116,32 +115,37 @@ class SimpleCNN(nn.Module):
116
 
117
  h_out_conv1 = (self.img_h - k1 + 2 * p1) // s1 + 1
118
  w_out_conv1 = (self.img_w - k1 + 2 * p1) // s1 + 1
119
- h_pool1 = (h_out_conv1 - pool1_k) // pool1_s + 1
120
- w_pool1 = (w_out_conv1 - pool1_k) // pool1_s + 1
121
-
 
122
  self.conv2 = nn.Conv2d(c_out1, c_out2, kernel_size=k2, stride=s2, padding=p2)
123
  self.relu2 = nn.ReLU()
124
  self.pool2 = nn.MaxPool2d(kernel_size=pool2_k, stride=pool2_s)
125
 
126
  h_out_conv2 = (h_pool1 - k2 + 2 * p2) // s2 + 1
127
  w_out_conv2 = (w_pool1 - k2 + 2 * p2) // s2 + 1
128
- h_pool2 = (h_out_conv2 - pool2_k) // pool2_s + 1
129
- w_pool2 = (w_out_conv2 - pool2_k) // pool2_s + 1
130
 
131
  self.flattened_size = c_out2 * h_pool2 * w_pool2
132
  if self.flattened_size <= 0:
133
- raise ValueError(f"Calculated flattened size is {self.flattened_size}. Check CNN params and image size. Conv1_out:({h_out_conv1},{w_out_conv1}), Pool1_out:({h_pool1},{w_pool1}), Conv2_out:({h_out_conv2},{w_out_conv2}), Pool2_out:({h_pool2},{w_pool2})")
134
 
135
  self.fc1 = nn.Linear(self.flattened_size, fc_hidden)
136
  self.relu3 = nn.ReLU()
137
- self.fc2 = nn.Linear(fc_hidden, num_classes)
138
 
139
- if num_classes > 1 or (num_classes == 1 and task_type=="classification"): # Adapt for binary vs regression
140
- self.final_activation = nn.Softmax(dim=1) if num_classes > 1 else nn.Sigmoid()
141
- else: # Regression output from fc2
 
 
 
 
 
142
  self.final_activation = nn.Identity()
143
 
144
-
145
  def forward(self, x):
146
  x = self.pool1(self.relu1(self.conv1(x)))
147
  x = self.pool2(self.relu2(self.conv2(x)))
@@ -151,8 +155,8 @@ class SimpleCNN(nn.Module):
151
  x = self.final_activation(x)
152
  return x
153
 
154
- # --- Parameter Target Helpers ---
155
- PARAM_RANGES = collections.OrderedDict([ # Ordered for consistent UI
156
  ("Tiny (<10k)", (0, 10000)),
157
  ("Small (10k-50k)", (10000, 50000)),
158
  ("Medium (50k-250k)", (50000, 250000)),
@@ -190,28 +194,34 @@ def suggest_mlp_layers_for_range(input_dim, output_dim, target_range_str, curren
190
  if not suggested_layers_str: suggested_layers_str = "64"; logs += "Defaulting to '64'.\n"
191
  return suggested_layers_str, logs
192
 
193
- def estimate_current_mlp_params(input_dim_str, hidden_layers_str, output_dim_str, current_logs=""):
194
  logs = current_logs
195
  try:
196
  input_dim = int(input_dim_str); output_dim = int(output_dim_str)
197
  if input_dim <= 0 or output_dim <= 0: return "Input/Output dims must be > 0", logs
198
- temp_mlp = SimpleMLP(input_dim, hidden_layers_str, output_dim)
 
 
 
199
  params = count_pytorch_parameters(temp_mlp); del temp_mlp
200
  return f"{params:,}", logs
201
  except Exception as e: logs += f"Error estimating MLP params: {e}\n"; return "Error", logs
202
 
203
- def estimate_cnn_params(img_h_str, img_w_str, num_classes_str, current_logs=""):
204
  logs = current_logs
205
  try:
206
- img_h, img_w, num_classes = int(img_h_str), int(img_w_str), int(num_classes_str)
207
- if not (img_h > 0 and img_w > 0 and num_classes > 0): return "Image dims/classes must be > 0", logs
208
- # Using default SimpleCNN params here. A real app would pass them.
209
- temp_cnn = SimpleCNN(input_channels=1, img_size_wh=(img_h, img_w), num_classes=num_classes)
 
 
210
  params = count_pytorch_parameters(temp_cnn); del temp_cnn
211
  return f"{params:,}", logs
212
- except Exception as e: logs += f"Error estimating CNN params: {e}\n"; return "Error", logs
 
213
 
214
- # --- Dataset and Preprocessing ---
215
  def generate_dataset_backend(task_type, n_samples_str, n_features_str,
216
  n_classes_or_informative_str, dataset_format,
217
  ai_suggest_ds_shape, target_param_range_str, model_type_selection,
@@ -223,14 +233,14 @@ def generate_dataset_backend(task_type, n_samples_str, n_features_str,
223
 
224
  if ai_suggest_ds_shape:
225
  n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg = 5000, 10, 2
226
- if task_type == "Tabular Regression": n_classes_or_informative_sugg = min(n_features_sugg // 2, 5)
227
- elif task_type == "Basic Image Classification": n_samples_sugg, n_features_sugg = 500, 0 # features not tabular
228
 
229
  is_nn = "Network" in model_type_selection
230
  if is_nn and target_param_range_str in PARAM_RANGES:
231
  min_p, max_p = PARAM_RANGES[target_param_range_str]; avg_p = (min_p + max_p) / 2
232
- if avg_p > 200000: n_samples_sugg = min(MAX_GENERATED_ROWS, n_samples_sugg * 2); n_features_sugg = min(MAX_GENERATED_COLS, n_features_sugg * 2) if task_type.startswith("Tabular") else n_features_sugg
233
- elif avg_p < 50000: n_samples_sugg = max(100, n_samples_sugg // 2); n_features_sugg = max(3, n_features_sugg // 2) if task_type.startswith("Tabular") else n_features_sugg
234
 
235
  n_samples, n_features, n_classes_or_informative = n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg
236
  logs += f"AI Suggested Dataset: Samples={n_samples}, Feats={n_features}, Classes/Informative={n_classes_or_informative}\n"
@@ -239,20 +249,21 @@ def generate_dataset_backend(task_type, n_samples_str, n_features_str,
239
  if task_type.startswith("Tabular"): n_features = max(1, min(n_features, MAX_GENERATED_COLS))
240
  if n_samples > MAX_DATASET_ROWS_WARN: logs += f"Warning: Generating {n_samples} rows. May be slow.\n"
241
 
242
- df = None; X_data=None; y_data=None # Init X_data, y_data
243
  try:
244
  if task_type == "Tabular Classification":
245
  n_cls = max(2, n_classes_or_informative)
246
- n_inf = max(1, min(n_features, n_classes_or_informative if n_classes_or_informative > n_cls else n_features // 2))
 
247
  X_data, y_data = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_inf,
248
  n_redundant=max(0,n_features - n_inf)//2, n_classes=n_cls, flip_y=0.05, random_state=42)
249
  df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)]); df['target'] = y_data
250
  elif task_type == "Tabular Regression":
251
  n_inf = max(1, min(n_features, n_classes_or_informative))
 
252
  X_data, y_data = make_regression(n_samples=n_samples, n_features=n_features, n_informative=n_inf, noise=10, random_state=42)
253
  df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)]); df['target'] = y_data
254
  elif task_type == "Basic Image Classification":
255
- # For SimpleCNN, let's generate 28x28 "images" (random noise)
256
  img_h, img_w = 28, 28
257
  num_pixels = img_h * img_w
258
  X_data = np.random.randint(0, 256, size=(n_samples, num_pixels), dtype=np.uint8)
@@ -263,17 +274,19 @@ def generate_dataset_backend(task_type, n_samples_str, n_features_str,
263
 
264
  logs += f"Generated data: {df.shape if df is not None else (X_data.shape, y_data.shape)}\n"
265
  file_path = get_temp_filepath("generated_dataset", dataset_format)
266
- if df is not None: # Save if DataFrame was created
267
  if dataset_format == ".csv": df.to_csv(file_path, index=False)
268
  elif dataset_format == ".json": df.to_json(file_path, orient='records', lines=True)
269
  elif dataset_format == ".parquet": df.to_parquet(file_path, index=False)
270
  else: logs += f"Unsupported format {dataset_format}. Defaulting to CSV.\n"; file_path=get_temp_filepath("generated_dataset","csv"); df.to_csv(file_path, index=False)
271
  logs += f"Dataset saved to {file_path}\n"
272
- return df.head(), df, logs, file_path # Return DataFrame for sklearn
273
- else: # Case where df might not be created (though current logic does)
274
- logs += "Dataset generated as numpy arrays. No file saved directly by this part of function.\n"
275
- # This branch needs more thought if we don't always make a df
276
- return pd.DataFrame(X_data[:5]), (X_data, y_data), logs, None # Return numpy arrays for PyTorch image case
 
 
277
 
278
 
279
  except Exception as e: error_msg=f"Error generating dataset: {traceback.format_exc()}"; logs+=error_msg+"\n"; return None, error_msg, logs, None
@@ -285,8 +298,8 @@ def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_t
285
  if target_column_name not in df.columns: raise ValueError(f"Target column '{target_column_name}' not found.")
286
  X_df = df.drop(target_column_name, axis=1)
287
  y_series = df[target_column_name]
288
- elif isinstance(df_or_X, np.ndarray) and y_if_X_is_numpy is not None: # If X,y are numpy
289
- X_df = pd.DataFrame(df_or_X, columns=[f'feature_{i}' for i in range(df_or_X.shape[1])]) # Temp DF for pipeline
290
  y_series = pd.Series(y_if_X_is_numpy)
291
  else: raise ValueError("Invalid input for preprocess_tabular_data.")
292
 
@@ -296,20 +309,30 @@ def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_t
296
 
297
  preprocessor = ColumnTransformer(transformers=[
298
  ('num', Pipeline([('imputer', SimpleImputer(strategy='mean')), ('scaler', StandardScaler())]), numerical_features),
299
- ('cat', Pipeline([('imputer', SimpleImputer(strategy='most_frequent')), ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False))]), categorical_features) # sparse_output=False for easier handling
300
- ], remainder='passthrough') # passthrough to keep unhandled columns if any
301
 
302
  X_processed_np = preprocessor.fit_transform(X_df)
303
 
304
- try: feature_names_out = preprocessor.get_feature_names_out()
305
- except AttributeError: # Older sklearn
306
- cat_encoder = preprocessor.named_transformers_['cat'].named_steps['onehot']
307
- if hasattr(cat_encoder, 'get_feature_names_out'):
308
- cat_feature_names = cat_encoder.get_feature_names_out(categorical_features)
309
- elif hasattr(cat_encoder, 'get_feature_names'): # even older
310
- cat_feature_names = cat_encoder.get_feature_names(categorical_features)
311
- else: cat_feature_names = [f"cat_feat_{i}" for i in range(X_processed_np.shape[1] - len(numerical_features))] # Fallback
312
- feature_names_out = numerical_features + list(cat_feature_names)
 
 
 
 
 
 
 
 
 
 
313
 
314
  processed_input_dim = X_processed_np.shape[1]
315
  logs += f"Tabular data preprocessed. X shape: {X_processed_np.shape}, Processed input dim: {processed_input_dim}\n"
@@ -319,25 +342,22 @@ def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_t
319
  y_processed_np = le.fit_transform(y_series)
320
  num_classes = len(le.classes_)
321
  logs += f"Target encoded. Classes: {num_classes} ({le.classes_})\n"
322
- # For binary classification with PyTorch, often output 1 neuron with Sigmoid or BCEWithLogitsLoss
323
- # If num_classes is 2, some PyTorch setups expect output_dim=1.
324
- # Scikit-learn handles this internally.
325
- output_dim_nn = 1 if num_classes == 2 else num_classes
326
  else: # Regression
327
  y_processed_np = y_series.astype(float).values
328
- num_classes = 1 # Output dim for regression for NN
329
  output_dim_nn = 1
330
 
331
- return X_processed_np, y_processed_np, preprocessor, logs, processed_input_dim, output_dim_nn, feature_names_out
332
 
333
 
334
- # --- Training Functions ---
335
  def train_model_sklearn(data_input_obj, target_column, task_type, model_name, model_output_format, current_logs=""):
336
  logs = current_logs + f"\n--- Training Scikit-learn Model: {model_name} ---\n"
337
  model_path_out, metrics_out, model_params_out = None, "Training failed.", "N/A"
338
 
339
  df = None
340
- if isinstance(data_input_obj, str): # Filepath
341
  try:
342
  if data_input_obj.endswith('.csv'): df = pd.read_csv(data_input_obj)
343
  elif data_input_obj.endswith('.json'): df = pd.read_json(data_input_obj, lines=True)
@@ -347,11 +367,11 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
347
  elif isinstance(data_input_obj, pd.DataFrame): df = data_input_obj
348
  else: logs += "Invalid data for training.\n"; return logs, "Error: Invalid data.", None, "N/A"
349
 
350
- if target_column not in df.columns:
351
- logs += f"Target '{target_column}' not found.\n"; return logs, f"Error: Target '{target_column}' not found.", None, "N/A"
352
 
353
  try:
354
- X_processed_np, y_processed_np, preprocessor, logs, _, _, feature_names = preprocess_tabular_data(df, None, target_column, task_type, logs)
355
  except ValueError as e: logs += f"Preprocessing error: {e}\n"; return logs, f"Error: {e}", None, "N/A"
356
 
357
  X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
@@ -361,7 +381,7 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
361
  if task_type == "Tabular Classification":
362
  if model_name == "Logistic Regression": model = LogisticRegression(max_iter=1000, random_state=42)
363
  elif model_name == "Random Forest Classifier": model = RandomForestClassifier(random_state=42)
364
- elif model_name == "Support Vector Machine (SVM) Classifier": model = SVC(random_state=42, probability=True) # probability=True for ONNX if it needs predict_proba
365
  elif task_type == "Tabular Regression":
366
  if model_name == "Linear Regression": model = LinearRegression()
367
  elif model_name == "Random Forest Regressor": model = RandomForestRegressor(random_state=42)
@@ -377,16 +397,12 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
377
  y_pred = model.predict(X_test)
378
 
379
  if task_type == "Tabular Classification":
380
- acc = accuracy_score(y_test, y_pred)
381
- report = classification_report(y_test, y_pred, zero_division=0)
382
  metrics_out = f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
383
  elif task_type == "Tabular Regression":
384
- mse = mean_squared_error(y_test, y_pred)
385
- r2 = r2_score(y_test, y_pred)
386
  metrics_out = f"Mean Squared Error: {mse:.4f}\nR2 Score: {r2:.4f}"
387
  logs += "\n--- Evaluation Metrics ---\n" + metrics_out + "\n"
388
-
389
- # Full pipeline for inference: preprocessor + model
390
  full_pipeline_for_saving = Pipeline([('preprocessor', preprocessor), ('model', model)])
391
  model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"
392
 
@@ -394,89 +410,60 @@ def train_model_sklearn(data_input_obj, target_column, task_type, model_name, mo
394
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
395
  joblib.dump(full_pipeline_for_saving, model_path_out)
396
  logs += f"Model (with preprocessor) saved to {model_path_out} as PKL.\n"
397
-
398
  elif model_output_format == ".onnx (ONNX)":
399
  model_path_out = get_temp_filepath(model_filename_base, "onnx")
400
-
401
- # Define initial types for ONNX conversion based on preprocessed input
402
- # The preprocessor converts all to numerical. Shape is (batch_size, num_processed_features)
403
- # num_processed_features = X_train.shape[1]
404
- initial_type = [('float_input', FloatTensorType([None, X_train.shape[1]]))] # None for batch size
405
-
406
- # For models with string inputs *before* preprocessing, it's more complex.
407
- # Here, we assume the `full_pipeline_for_saving` takes the raw DataFrame structure as input.
408
- # So, we need to define initial_types based on the *original* DataFrame features.
409
-
410
- # Re-create initial types based on the *original* df structure, before preprocessing
411
- # This is complex because ColumnTransformer input spec is not trivial for skl2onnx for mixed types.
412
- # The EASIEST way for skl2onnx with ColumnTransformer is to convert the *fitted preprocessor separately*
413
- # OR, provide initial types that match the *input to the preprocessor*.
414
-
415
- # Let's try providing initial types for the raw input to the preprocessor
416
- raw_X_for_types = df.drop(target_column, axis=1).infer_objects() # Infer object dtypes to str for ONNX
417
  onnx_initial_types = []
418
- for col_name in raw_X_for_types.columns:
419
- col_dtype = raw_X_for_types[col_name].dtype
 
 
420
  if pd.api.types.is_numeric_dtype(col_dtype):
421
- # Forcing float32 for ONNX compatibility
 
 
 
422
  onnx_initial_types.append((col_name, FloatTensorType([None, 1])))
423
  elif pd.api.types.is_string_dtype(col_dtype) or col_dtype == 'object':
424
  onnx_initial_types.append((col_name, StringTensorType([None, 1])))
425
  else:
426
- logs += f"Warning: Unsupported dtype {col_dtype} for column {col_name} in ONNX conversion. Skipping.\n"
 
427
 
428
- if not onnx_initial_types:
429
- logs += "Error: Could not determine ONNX initial types for raw input. Aborting ONNX export.\n"
430
- raise ValueError("ONNX initial types failed.")
431
-
432
  try:
433
- options = {id(full_pipeline_for_saving): {'zipmap': False}} # Disable zipmap for classifier output
434
- onnx_model = convert_sklearn(full_pipeline_for_saving, initial_types=onnx_initial_types,
435
- target_opset=12, options=options) # Target opset can be important
436
- with open(model_path_out, "wb") as f:
437
- f.write(onnx_model.SerializeToString())
438
- logs += f"Model (with preprocessor) saved to {model_path_out} as ONNX.\n"
439
-
440
- # Optional: Verify ONNX model
441
  sess = rt.InferenceSession(model_path_out, providers=rt.get_available_providers())
442
- logs += f"ONNX model loaded successfully with ONNX Runtime. Input names: {[inp.name for inp in sess.get_inputs()]}\n"
443
- except Exception as onnx_e:
444
- logs += f"Error during ONNX conversion/saving: {traceback.format_exc()}\n"
445
- model_path_out = None # Clear path if saving failed
446
- metrics_out += "\nONNX EXPORT FAILED."
447
-
448
- else:
449
  logs += f"Unsupported format '{model_output_format}'. Saving as .pkl\n"
450
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
451
  joblib.dump(full_pipeline_for_saving, model_path_out)
452
-
453
- except Exception as e:
454
- error_msg = f"Error during sklearn training/eval: {traceback.format_exc()}"; logs += error_msg + "\n"; metrics_out = error_msg
455
  return logs, metrics_out, model_path_out, model_params_out
456
 
457
-
458
  def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
459
  mlp_hidden_layers_str, mlp_activation,
460
- # CNN specific (using defaults in SimpleCNN for now)
461
- # cnn_img_h_str, cnn_img_w_str, # Now derived from data
462
  epochs_str, batch_size_str, lr_str,
463
  model_output_format, current_logs=""):
464
  logs = current_logs + f"\n--- Training PyTorch Model: {model_type_pt} ---\n"
465
  model_path_out, metrics_out, model_params_out, plot_out = None, "Training failed.", "N/A", None
466
 
467
- df_for_pytorch = None; X_numpy_for_pytorch=None; y_numpy_for_pytorch=None # For flexibility
468
- if isinstance(data_input_obj, str): # Filepath
469
  try:
470
- # For PyTorch, we might want to handle data differently, esp images
471
  if data_input_obj.endswith('.csv'): df_for_pytorch = pd.read_csv(data_input_obj)
472
  elif data_input_obj.endswith('.json'): df_for_pytorch = pd.read_json(data_input_obj, lines=True)
473
  elif data_input_obj.endswith('.parquet'): df_for_pytorch = pd.read_parquet(data_input_obj)
474
  else: logs += f"Unsupported file: {data_input_obj}\n"; return logs, "Error", None, "N/A", None
475
  except Exception as e: logs += f"Error reading {data_input_obj}: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
476
  elif isinstance(data_input_obj, pd.DataFrame): df_for_pytorch = data_input_obj
477
- elif isinstance(data_input_obj, tuple) and len(data_input_obj) == 2 and \
478
- isinstance(data_input_obj[0], np.ndarray) and isinstance(data_input_obj[1], np.ndarray):
479
- X_numpy_for_pytorch, y_numpy_for_pytorch = data_input_obj # If data was (X,y) from generation
480
  else: logs += "Invalid data for PyTorch training.\n"; return logs, "Error", None, "N/A", None
481
 
482
  try:
@@ -484,63 +471,60 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
484
  if not (epochs > 0 and batch_size > 0 and lr > 0): raise ValueError("Params must be >0.")
485
  except ValueError as e: logs += f"Invalid training params: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
486
 
487
- processed_input_dim_actual = -1; nn_output_dim_actual = -1; preprocessor_pipeline = None
488
- X_processed_np = None; y_processed_np = None
489
 
490
  if model_type_pt == "Simple Neural Network (MLP)":
491
  if not task_type.startswith("Tabular"):
492
  logs += "MLP requires Tabular task.\n"; return logs, "MLP Task Error", None, "N/A", None
 
 
 
493
  try:
494
- # Pass df_for_pytorch or (X_numpy_for_pytorch, y_numpy_for_pytorch)
495
  data_arg1 = df_for_pytorch if df_for_pytorch is not None else X_numpy_for_pytorch
496
  data_arg2 = y_numpy_for_pytorch if df_for_pytorch is None else None
 
 
 
497
  X_processed_np, y_processed_np, preprocessor_pipeline, logs, processed_input_dim_actual, nn_output_dim_actual, _ = \
498
- preprocess_tabular_data(data_arg1, data_arg2, target_column, task_type, logs)
499
  except ValueError as e: logs+=f"MLP Preprocessing error: {e}\n"; return logs,f"Error: {e}",None,"N/A",None
500
-
501
  elif model_type_pt == "Simple Convolutional Network (CNN)":
502
- if task_type != "Basic Image Classification":
503
- logs += "Warning: CNN selected, but task is not Basic Image Classification. Output may be unexpected.\n"
504
 
 
505
  if df_for_pytorch is not None:
506
- if target_column not in df_for_pytorch.columns:
507
- logs += f"Target '{target_column}' not found for CNN.\n"; return logs, "CNN Target Error", None, "N/A", None
508
  X_raw = df_for_pytorch.drop(target_column, axis=1).values
509
  y_raw = df_for_pytorch[target_column].values
510
  elif X_numpy_for_pytorch is not None and y_numpy_for_pytorch is not None:
511
- X_raw = X_numpy_for_pytorch
512
- y_raw = y_numpy_for_pytorch
513
- else:
514
- logs += "No valid data found for CNN.\n"; return logs, "CNN Data Error", None, "N/A", None
515
 
516
  le = LabelEncoder(); y_processed_np = le.fit_transform(y_raw)
517
  nn_output_dim_actual = len(le.classes_)
518
- if nn_output_dim_actual == 2: nn_output_dim_actual = 1 # Binary output for NN
519
 
520
- pixels_per_sample = X_raw.shape[1]
521
  img_dim_approx = int(math.sqrt(pixels_per_sample))
522
- img_h, img_w, input_channels = (28,28,1) # Default
523
- if img_dim_approx * img_dim_approx == pixels_per_sample:
524
- img_h, img_w = img_dim_approx, img_dim_approx
525
- else: logs += f"Warning: Cannot infer square image from {pixels_per_sample} pixels. Defaulting to 28x28 for CNN.\n"
526
 
527
- # Reshape and normalize (basic)
528
  X_processed_np = X_raw.reshape(-1, input_channels, img_h, img_w).astype(np.float32) / 255.0
529
- processed_input_dim_actual = (input_channels, img_h, img_w) # For CNN constructor
530
  logs += f"CNN Data: X reshaped to {X_processed_np.shape}, y: {y_processed_np.shape}, NN Output Dim: {nn_output_dim_actual}\n"
531
  else: logs += f"Unknown PyTorch model: {model_type_pt}\n"; return logs, "Unknown PyTorch model", None, "N/A", None
532
 
533
- X_tensor = torch.tensor(X_processed_np, dtype=torch.float32)
534
- # Adjust y_tensor dtype based on loss function expectations
535
  y_dtype = torch.float32 if (nn_output_dim_actual == 1 and task_type.endswith("Regression")) or \
536
  (nn_output_dim_actual == 1 and task_type.endswith("Classification")) \
537
- else torch.long # MSELoss/BCELoss with float, CrossEntropy with long
 
538
  y_tensor = torch.tensor(y_processed_np, dtype=y_dtype)
539
- if nn_output_dim_actual == 1 and task_type.endswith("Classification"): y_tensor = y_tensor.unsqueeze(1) # For BCE based loss
540
- if task_type.endswith("Regression"): y_tensor = y_tensor.unsqueeze(1) # MSELoss expects [N,1]
541
 
542
  dataset = TensorDataset(X_tensor, y_tensor)
543
- # Use num_workers=0 on free tier to avoid issues with multiprocessing
544
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
545
 
546
  pytorch_model = None
@@ -551,21 +535,19 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
551
  task_type="classification" if task_type.endswith("Classification") else "regression")
552
  elif model_type_pt == "Simple Convolutional Network (CNN)":
553
  channels, h, w = processed_input_dim_actual
554
- pytorch_model = SimpleCNN(input_channels=channels, img_size_wh=(h,w), num_classes=nn_output_dim_actual)
555
- except Exception as model_e:
556
- logs += f"Error creating PyTorch model: {traceback.format_exc()}\n"; return logs, f"Model Creation Error: {model_e}", None, "N/A", None
557
 
558
  if pytorch_model is None: logs += "Failed to instantiate PyTorch model.\n"; return logs, "Model instantiate fail", None, "N/A", None
559
- model_params_val = count_pytorch_parameters(pytorch_model)
560
- model_params_out = f"{model_params_val:,}"
561
  logs += f"PyTorch Model: {model_params_out} params.\n"
562
  if model_params_val > 500000: logs += "Warning: >500k params on CPU will be SLOW.\n"
563
 
564
- is_classification_task = task_type.endswith("Classification") or model_type_pt == "Simple Convolutional Network (CNN)" # Treat CNN as classification here
565
  if is_classification_task:
566
  criterion = nn.BCELoss() if nn_output_dim_actual == 1 else nn.CrossEntropyLoss()
567
- else: # Regression
568
- criterion = nn.MSELoss()
569
  optimizer = optim.Adam(pytorch_model.parameters(), lr=lr)
570
 
571
  logs += f"Starting PyTorch training for {epochs} epochs...\n"; start_time = time.time()
@@ -573,223 +555,185 @@ def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
573
  pytorch_model.train()
574
  for epoch in range(epochs):
575
  epoch_loss_sum = 0.0; num_batches = 0
576
- for batch_X, batch_y in dataloader:
577
- optimizer.zero_grad()
578
- outputs = pytorch_model(batch_X)
579
- loss = criterion(outputs, batch_y)
580
- loss.backward(); optimizer.step()
581
  epoch_loss_sum += loss.item(); num_batches += 1
582
  avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
583
  epoch_losses.append(avg_epoch_loss)
584
  logs += f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}\n"
585
- # yield logs, metrics_out, model_path_out, model_params_out, None # For streaming, but makes UI complex
586
-
587
- training_time = time.time() - start_time
588
- logs += f"PyTorch training completed in {training_time:.2f} seconds.\n"
589
 
590
- # Basic evaluation (on last batch for simplicity, or could do full test set)
591
- # A proper eval loop on a test set would be better here.
592
  pytorch_model.eval()
593
  with torch.no_grad():
594
- # For simplicity, let's just report final training loss.
595
- # A full evaluation on a test split would be needed for proper metrics.
596
- if is_classification_task:
597
- # This is a very rough accuracy on the last training batch for demo
598
- if dataloader.dataset: # Check if dataset is not empty
599
- try:
600
- last_batch_X, last_batch_y = next(iter(dataloader)) # Get one batch
601
- outputs = pytorch_model(last_batch_X)
602
- if nn_output_dim_actual == 1: # Binary
603
- predicted = (outputs > 0.5).float()
604
- else: # Multi-class
605
- _, predicted = torch.max(outputs.data, 1)
606
- correct = (predicted == last_batch_y.view_as(predicted)).sum().item()
607
- total = last_batch_y.size(0)
608
- acc = correct / total if total > 0 else 0
609
- metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\nApprox. Accuracy on a batch: {acc*100:.2f}% (Note: Proper eval needs a test set)"
610
- except StopIteration: # Dataloader was empty
611
- metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n (Dataloader empty, cannot get batch accuracy)"
612
-
613
- else:
614
- metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n (No data for batch accuracy)"
615
- else: # Regression
616
- metrics_out = f"Final Training Loss (MSE): {avg_epoch_loss:.4f}"
617
- logs += "\n--- PyTorch Metrics (Simplified) ---\n" + metrics_out + "\n"
618
 
619
- # Loss plot
620
  if epoch_losses:
 
 
621
  import matplotlib.pyplot as plt
622
- fig, ax = plt.subplots()
623
- ax.plot(range(1, epochs + 1), epoch_losses, marker='o')
624
- ax.set_xlabel("Epoch")
625
- ax.set_ylabel("Average Loss")
626
- ax.set_title("Training Loss Curve")
627
- plot_out = fig # Gradio can display matplotlib figures
628
- logs += "Loss curve generated.\n"
629
-
630
 
631
- # Save model (and preprocessor if MLP)
632
  model_filename_base = f"pytorch_{model_type_pt.replace(' ', '_').lower()}"
633
  if model_output_format == ".pt (PyTorch)":
634
  model_path_out = get_temp_filepath(model_filename_base, "pt")
 
635
  if model_type_pt == "Simple Neural Network (MLP)" and preprocessor_pipeline:
636
- torch.save({
637
- 'model_state_dict': pytorch_model.state_dict(),
638
- 'preprocessor': preprocessor_pipeline,
639
- 'input_dim': processed_input_dim_actual, # From preprocessing
640
- 'output_dim': nn_output_dim_actual, # From preprocessing
641
- 'hidden_layers_str': mlp_hidden_layers_str,
642
- 'activation_fn': mlp_activation,
643
- 'task_type': task_type
644
- }, model_path_out)
645
  logs += f"PyTorch MLP (model + preprocessor) saved to {model_path_out}\n"
646
- else: # CNN or MLP without preprocessor explicitly bundled (less common)
647
- torch.save(pytorch_model.state_dict(), model_path_out)
648
- logs += f"PyTorch {model_type_pt} (model state_dict) saved to {model_path_out}\n"
649
- # Add ONNX export for PyTorch later if needed (torch.onnx.export)
650
- else:
651
- logs += f"Unsupported format '{model_output_format}' for PyTorch. Saving as .pt\n"
 
 
652
  model_path_out = get_temp_filepath(model_filename_base, "pt")
653
- torch.save(pytorch_model.state_dict(), model_path_out) # Fallback to state_dict
654
-
655
  return logs, metrics_out, model_path_out, model_params_out, plot_out
656
 
657
-
658
  # --- Gradio UI Definition ---
659
- # Define choices
660
- TASK_CHOICES = ["Tabular Classification", "Tabular Regression", "Basic Image Classification"] # Simple Text removed for focus
661
  MODEL_FAMILIES = ["Scikit-learn (Classical ML)", "PyTorch (Neural Networks)"]
662
  SKLEARN_MODELS_CLASSIFICATION = ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]
663
  SKLEARN_MODELS_REGRESSION = ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]
664
  PYTORCH_MODELS = ["Simple Neural Network (MLP)", "Simple Convolutional Network (CNN)"]
665
  DATASET_FORMATS = [".csv", ".json", ".parquet"]
666
  MODEL_OUTPUT_FORMATS_SKLEARN = [".pkl (Scikit-learn)", ".onnx (ONNX)"]
667
- MODEL_OUTPUT_FORMATS_PYTORCH = [".pt (PyTorch)"] # ".onnx (ONNX)" can be added later
668
  MLP_ACTIVATIONS = ["relu", "tanh", "sigmoid"]
669
-
670
  CLONE_GUIDE_TEXT = """
671
  ## How to Clone & Upgrade This Space for More Power:
672
  (Instructions as provided in previous response - omitted here for brevity but should be included)
673
  """
674
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  def update_model_options(task_choice, model_family_choice):
 
676
  if model_family_choice == "Scikit-learn (Classical ML)":
677
- if task_choice == "Tabular Classification": return gr.update(choices=SKLEARN_MODELS_CLASSIFICATION, value=SKLEARN_MODELS_CLASSIFICATION[0], visible=True)
678
- elif task_choice == "Tabular Regression": return gr.update(choices=SKLEARN_MODELS_REGRESSION, value=SKLEARN_MODELS_REGRESSION[0], visible=True)
679
- else: return gr.update(choices=[], value=None, visible=False) # Sklearn not for image task here
680
  elif model_family_choice == "PyTorch (Neural Networks)":
681
- if task_choice.startswith("Tabular"): return gr.update(choices=[PYTORCH_MODELS[0]], value=PYTORCH_MODELS[0], visible=True) # Only MLP for tabular
682
- elif task_choice == "Basic Image Classification": return gr.update(choices=[PYTORCH_MODELS[1]], value=PYTORCH_MODELS[1], visible=True) # Only CNN for image
683
- else: return gr.update(choices=[], value=None, visible=False)
684
- return gr.update(choices=[], value=None, visible=False)
685
 
686
  def update_param_range_visibility(model_family_choice):
687
  return gr.update(visible=(model_family_choice == "PyTorch (Neural Networks)"))
688
 
689
- def update_pytorch_specific_options_visibility(model_choice_pytorch):
690
- is_mlp = model_choice_pytorch == "Simple Neural Network (MLP)"
691
- is_cnn = model_choice_pytorch == "Simple Convolutional Network (CNN)"
692
- return gr.update(visible=is_mlp), gr.update(visible=is_cnn) # MLP Group, CNN Group
 
 
 
 
 
693
 
694
  def update_model_output_formats(model_family_choice):
695
- if model_family_choice == "Scikit-learn (Classical ML)":
696
- return gr.update(choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
697
- elif model_family_choice == "PyTorch (Neural Networks)":
698
- return gr.update(choices=MODEL_OUTPUT_FORMATS_PYTORCH, value=MODEL_OUTPUT_FORMATS_PYTORCH[0])
699
  return gr.update(choices=[], value=None)
700
 
701
-
702
- css = """
703
- .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
704
- .gr-button { color: white; border-color: black; background: black; }
705
- .gr-input { border-radius: 8px; }
706
- .gr-output { border-radius: 8px; }
707
- """
708
 
709
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), css=css) as demo:
710
- gr.Markdown("# 🧠 Universal AI Model Trainer (CPU Edition)")
711
  gr.Markdown("Create, train, and download AI models. Optimized for CPU - expect longer training for complex models.")
712
-
713
- # Global state to store generated data path or df
714
- # This helps pass data between dataset generation and training without re-upload
715
- # For DataFrames, it's better to pass them directly if possible, or save/load paths.
716
  generated_data_state = gr.State(None)
717
- current_logs_state = gr.State("") # To accumulate logs
718
 
719
  with gr.Tabs():
720
  with gr.TabItem("1. Define Task & Model"):
721
  with gr.Row():
722
- task_type_dd = gr.Dropdown(TASK_CHOICES, label="Select Task Type", value=TASK_CHOICES[0])
723
- model_family_dd = gr.Dropdown(MODEL_FAMILIES, label="Select Model Family", value=MODEL_FAMILIES[0])
 
724
 
725
- model_specific_dd = gr.Dropdown(label="Select Specific Model", interactive=True) # Populated by callback
726
-
727
- # PyTorch Parameter Range (only visible for PyTorch)
728
  pytorch_param_range_dd = gr.Dropdown(list(PARAM_RANGES.keys()), label="Target Parameter Range (for NNs)",
729
  info="Guides NN architecture suggestions. Training >250k params on CPU is slow.",
730
- value=list(PARAM_RANGES.keys())[1], visible=False)
731
-
732
- # PyTorch MLP Specifics (only visible for MLP)
733
- with gr.Group(visible=False) as pt_mlp_specific_group:
734
  gr.Markdown("#### MLP Configuration")
735
- # Input dim will be determined after data preprocessing for MLP. User doesn't set it here.
736
- # Output dim also determined by data (num_classes or 1 for regression)
737
  pt_mlp_hidden_layers_txt = gr.Textbox(label="Hidden Layer Sizes (comma-separated, e.g., 128,64)", value="64,32")
738
  pt_mlp_activation_dd = gr.Dropdown(MLP_ACTIVATIONS, label="Activation Function", value="relu")
739
- pt_mlp_suggest_btn = gr.Button("Suggest MLP Layers for Target Range")
 
 
740
  pt_mlp_param_count_txt = gr.Textbox(label="Estimated MLP Parameters", interactive=False)
741
- # For MLP param estimation, we'd need #input_features and #output_classes from data step
742
- # This means estimation might be better placed *after* dataset is defined.
743
- # For now, placeholder or user has to guess input/output dims.
744
- # Simplified: we'll show actual params *after* training or with a dedicated button post-data.
745
-
746
- # PyTorch CNN Specifics (Placeholder - visible for CNN)
747
- with gr.Group(visible=False) as pt_cnn_specific_group:
748
- gr.Markdown("#### CNN Configuration (Simplified for Demo)")
749
- gr.Markdown("SimpleCNN uses fixed architecture for now (2 conv layers, 1 FC). Parameters mainly come from image size/classes.")
750
- # For CNN param estimation, we need image H, W, num_classes from data step.
751
- # cnn_img_h_param_est = gr.Number(label="Est. Image Height (for param count)", value=28, visible=False) # Hidden, used by callback
752
- # cnn_img_w_param_est = gr.Number(label="Est. Image Width (for param count)", value=28, visible=False)
753
- # cnn_num_classes_param_est = gr.Number(label="Est. Num Classes (for param count)", value=10, visible=False)
754
  pt_cnn_param_count_txt = gr.Textbox(label="Estimated CNN Parameters", interactive=False)
755
- # Actual CNN param count shown after training or with dedicated button post-data.
756
 
757
-
758
- with gr.TabItem("2. Configure Dataset"):
759
  dataset_source_rb = gr.Radio(["Generate new dataset", "Upload my own dataset (CSV, JSON, Parquet)"],
760
  label="Dataset Source", value="Generate new dataset")
761
-
762
- with gr.Group(visible=True) as generate_dataset_group: # Visible by default
763
  gr.Markdown("#### Generate Synthetic Dataset")
764
  with gr.Row():
765
- ds_gen_samples_num = gr.Number(label="Number of Rows (Samples)", value=1000)
766
- ds_gen_features_num = gr.Number(label="Number of Features (Columns, if tabular)", value=10)
767
- ds_gen_classes_informative_num = gr.Number(label="Num Classes (for Classification) / Num Informative Features (for Regression)", value=2)
768
- ds_gen_ai_suggest_cb = gr.Checkbox(label="Let AI suggest optimal rows/columns based on model type & param range?", value=False)
769
- ds_gen_format_dd = gr.Dropdown(DATASET_FORMATS, label="Generated Dataset Download Format", value=".csv")
770
  generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")
771
-
772
  with gr.Group(visible=False) as upload_dataset_group:
773
  gr.Markdown("#### Upload Dataset")
774
- ds_upload_file = gr.File(label="Upload your dataset file", file_types=[".csv", ".json", ".parquet"])
775
-
776
- target_column_name_txt = gr.Textbox(label="Target Column Name (Case-Sensitive)", placeholder="e.g., 'target' or 'label'")
777
- dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False)
778
  generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
779
 
780
  with gr.TabItem("3. Train Model & Get Results"):
781
- gr.Markdown("Ensure Model and Dataset are configured before training.")
782
  with gr.Row():
783
- # Training Hyperparameters (Common for PyTorch)
784
- # For Scikit-learn, HPs are mostly defaults or need more complex UI
785
- # These are mainly for PyTorch NNs
786
- train_epochs_num = gr.Number(label="Epochs (for NNs)", value=10)
787
- train_batch_size_num = gr.Number(label="Batch Size (for NNs)", value=32)
788
- train_learning_rate_num = gr.Number(label="Learning Rate (for NNs)", value=0.001)
789
-
790
- model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0]) # Default to sklearn
791
  train_model_btn = gr.Button("🚀 Train Model", variant="primary")
792
-
793
  gr.Markdown("---")
794
  gr.Markdown("### Training Progress & Results")
795
  training_log_txt = gr.Textbox(label="Training Log & Status", lines=15, interactive=False, max_lines=50)
@@ -799,212 +743,116 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"),
799
  download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
800
 
801
  with gr.TabItem("ℹ️ Guide & Info"):
802
- gr.Markdown("### Using This Space")
803
- gr.Markdown("- **Free CPU Tier:** Training large or complex models will be slow. Memory is also limited (around 15GB RAM).")
804
- gr.Markdown("- **Workflow:** 1. Define Task/Model -> 2. Configure Dataset -> 3. Train.")
805
- gr.Markdown("- **Dataset Generation:** For 'Basic Image Classification', random pixel data is generated (not real images).")
806
- gr.Markdown("- **Parameters:** For Neural Networks, the 'Target Parameter Range' helps suggest architectures. 1M params is already large for CPU training.")
807
- gr.Markdown("- **ONNX Export (Scikit-learn):** Converts Scikit-learn pipelines (preprocessor + model) to ONNX. Input to the ONNX model should be raw data matching the original training DataFrame structure.")
808
  gr.Markdown(CLONE_GUIDE_TEXT)
809
 
810
  # --- Event Handlers ---
811
- # Update model choices based on task and family
812
  task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
813
  model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
814
 
815
- # Show/hide PyTorch parameter range dropdown
816
- model_family_dd.change(fn=update_param_range_visibility, inputs=model_family_dd, outputs=pytorch_param_range_dd)
817
-
818
- # Show/hide PyTorch MLP/CNN specific groups
819
- # This needs model_specific_dd as input, which is tricky if it's dynamically populated.
820
- # Let's assume model_specific_dd is the PyTorch model dropdown for this context.
821
- # This means model_specific_dd must *only* be active/relevant when model_family_dd is PyTorch.
822
- def combined_pytorch_ui_update(model_family_choice, pytorch_model_choice):
823
- param_range_visible = (model_family_choice == "PyTorch (Neural Networks)")
824
- if not param_range_visible: # If not PyTorch, hide all PyTorch specific groups
825
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
826
-
827
- is_mlp = (pytorch_model_choice == "Simple Neural Network (MLP)")
828
- is_cnn = (pytorch_model_choice == "Simple Convolutional Network (CNN)")
829
- return gr.update(visible=param_range_visible), gr.update(visible=is_mlp), gr.update(visible=is_cnn)
830
-
831
- model_family_dd.change(fn=combined_pytorch_ui_update,
832
- inputs=[model_family_dd, model_specific_dd],
833
- outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
834
- model_specific_dd.change(fn=combined_pytorch_ui_update, # Also trigger when specific PyTorch model changes
835
- inputs=[model_family_dd, model_specific_dd],
836
- outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
837
-
838
- # Suggest MLP Layers
839
- def mlp_suggest_proxy(target_range_str, current_logs, dataset_preview_df, target_col_name, task_type):
840
- logs = current_logs
841
- input_dim_est = 10 # default if no data
842
- output_dim_est = 2 if task_type.endswith("Classification") else 1 # default
843
-
844
- if dataset_preview_df is not None and isinstance(dataset_preview_df, pd.DataFrame) and not dataset_preview_df.empty and target_col_name:
845
- try:
846
- # Attempt to get processed input dim. This is a simplified estimation.
847
- # A full preprocessing run is too heavy here.
848
- temp_X = dataset_preview_df.drop(target_col_name, axis=1, errors='ignore')
849
- num_cols = len(temp_X.select_dtypes(include=np.number).columns)
850
- cat_cols = temp_X.select_dtypes(include='object').columns
851
- # Rough estimate of one-hot encoded features
852
- one_hot_est = sum(min(10, dataset_preview_df[col].nunique()) for col in cat_cols) # cap nunique
853
- input_dim_est = num_cols + one_hot_est
854
- input_dim_est = max(1, input_dim_est) # Ensure > 0
855
-
856
- if task_type.endswith("Classification"):
857
- output_dim_est = max(1, dataset_preview_df[target_col_name].nunique())
858
- if output_dim_est == 2: output_dim_est = 1 # For binary an output of 1 is common in NNs
859
- logs += f"Estimated input_dim: {input_dim_est}, output_dim: {output_dim_est} for MLP suggestion.\n"
860
- except Exception as e:
861
- logs += f"Could not estimate dims from preview for MLP suggestion: {e}. Using defaults.\n"
862
- else:
863
- logs += "Dataset preview not available for MLP dimension estimation. Using defaults.\n"
864
 
865
- suggested_str, logs = suggest_mlp_layers_for_range(input_dim_est, output_dim_est, target_range_str, logs)
866
-
867
- # Also estimate params for the suggestion
 
 
 
 
 
 
 
 
868
  param_count_str = "Error"
869
- if suggested_str:
870
- param_count_str, logs = estimate_current_mlp_params(str(input_dim_est), suggested_str, str(output_dim_est), logs)
871
-
872
  return suggested_str, logs, param_count_str
873
-
874
- pt_mlp_suggest_btn.click(
875
- fn=mlp_suggest_proxy,
876
  inputs=[pytorch_param_range_dd, current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
877
- outputs=[pt_mlp_hidden_layers_txt, training_log_txt, pt_mlp_param_count_txt] # Use training_log_txt for logs from suggestion
878
- )
 
 
 
 
 
879
 
880
- # Estimate MLP params when hidden layers text changes (might be too slow if hooked to .change)
881
- # A button is safer for this. For now, rely on suggestion button or post-training report.
882
- # We can add an "Estimate Current MLP Params" button if needed.
 
 
 
 
 
 
 
 
 
 
 
 
883
 
884
- # Show/hide dataset generation/upload groups
885
  def toggle_dataset_source_groups(source_choice):
886
- return gr.update(visible=(source_choice == "Generate new dataset")), \
887
- gr.update(visible=(source_choice == "Upload my own dataset (CSV, JSON, Parquet)"))
888
- dataset_source_rb.change(fn=toggle_dataset_source_groups, inputs=dataset_source_rb,
889
- outputs=[generate_dataset_group, upload_dataset_group])
890
-
891
- # Update model output formats based on family
892
  model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
893
 
894
- # Dataset Generation Button
895
- def generate_dataset_wrapper(task_type, n_samples, n_features, n_classes_info, ds_format, ai_sugg, param_range, model_type, logs_in):
896
- preview, data_obj, logs_out, file_out = generate_dataset_backend(
897
- task_type, n_samples, n_features, n_classes_info, ds_format, ai_sugg, param_range, model_type, logs_in
898
- )
899
- # Store the actual data (DataFrame or (X,y) tuple) in state if generation was successful
900
- # If it's a filepath (from upload), store the path.
901
- # For generated data, store the df or (X,y) tuple to avoid disk I/O if not necessary before training.
902
- stored_data = data_obj if data_obj is not None else None
903
- return preview, stored_data, logs_out, file_out
904
-
905
  generate_dataset_btn.click(
906
- fn=generate_dataset_wrapper,
907
  inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_informative_num,
908
  ds_gen_format_dd, ds_gen_ai_suggest_cb, pytorch_param_range_dd, model_specific_dd, current_logs_state],
909
- outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file]
910
- )
911
 
912
- # Handle dataset upload
913
  def process_uploaded_file(file_obj, logs_in):
914
- logs = logs_in
915
- if file_obj is None:
916
- return None, logs, "Please upload a file first.", None
917
- logs += f"Uploaded file: {file_obj.name}\n"
918
-
919
- # For preview, try to read a few lines
920
- df_preview = None
921
  try:
922
- if file_obj.name.endswith(".csv"):
923
- df_preview = pd.read_csv(file_obj.name, nrows=5)
924
- elif file_obj.name.endswith(".json"): # Assuming JSONL
925
- df_preview = pd.read_json(file_obj.name, lines=True, nrows=5)
926
- elif file_obj.name.endswith(".parquet"):
927
- # Reading only 5 rows from parquet is not straightforward without loading more.
928
- # For simplicity, load full and take head, or skip preview.
929
- temp_df = pd.read_parquet(file_obj.name)
930
- df_preview = temp_df.head()
931
- logs += "Preview generated for uploaded file.\n"
932
- except Exception as e:
933
- logs += f"Could not generate preview for {file_obj.name}: {e}\n"
934
- return None, logs, f"Error previewing: {e}", file_obj.name # Return path even if preview fails
935
-
936
- return df_preview, logs, "File ready for training.", file_obj.name # Store path in generated_data_state
937
-
938
- ds_upload_file.upload(
939
- fn=process_uploaded_file,
940
- inputs=[ds_upload_file, current_logs_state],
941
- outputs=[dataset_preview_df, training_log_txt, training_log_txt, generated_data_state] # Use training_log for status, then store path
942
- )
943
-
944
-
945
- # Train Model Button
946
- def train_model_wrapper(data_state_val, # This will be DataFrame, (X,y) tuple, or filepath string
947
- target_col, task_type, model_family, model_name, # Common params
948
- # Sklearn specific (none for now beyond model_name)
949
- # PyTorch specific
950
- pt_model_type, pt_mlp_hidden, pt_mlp_activ, #pt_cnn_params (later)
951
- epochs, batch_size, lr,
952
- model_out_format,
953
- logs_in): # Accumulate logs
954
-
955
- current_logs = logs_in + "\n--- Initiating Training ---\n"
956
- current_logs += f"Data state type: {type(data_state_val)}\n"
957
-
958
- if data_state_val is None:
959
- current_logs += "Error: No dataset loaded or generated. Please go to Tab 2.\n"
960
- return current_logs, "No data available.", None, "N/A", None, None # logs, metrics, model_file, params, plot, download_btn_update
961
-
962
- if not target_col and (task_type.startswith("Tabular") or (isinstance(data_state_val, pd.DataFrame) and model_type_pt != "Simple Convolutional Network (CNN)")) : # Target col needed for tabular
963
- current_logs += "Error: Target column name is required for this task/data.\n"
964
- return current_logs, "Target column needed.", None, "N/A", None, None
965
-
966
- # Ensure logs are passed and returned correctly by train functions
967
- if model_family == "Scikit-learn (Classical ML)":
968
- logs, metrics, model_file, params = train_model_sklearn(
969
- data_state_val, target_col, task_type, model_name, model_out_format, current_logs
970
- )
971
- return logs, metrics, model_file, params, None, model_file # No plot for sklearn here
972
-
973
- elif model_family == "PyTorch (Neural Networks)":
974
- # model_name here is the PyTorch model type (MLP or CNN)
975
- logs, metrics, model_file, params, plot = train_model_pytorch(
976
- data_state_val, target_col, task_type, model_name,
977
- pt_mlp_hidden, pt_mlp_activ,
978
- epochs, batch_size, lr,
979
- model_out_format, current_logs
980
- )
981
- return logs, metrics, model_file, params, plot, model_file
982
- else:
983
- current_logs += f"Unknown model family: {model_family}\n"
984
- return current_logs, "Unknown model family.", None, "N/A", None, None
985
 
986
  train_model_btn.click(
987
- fn=train_model_wrapper,
988
- inputs=[
989
- generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd,
990
- # PyTorch specific inputs (will be None if not PyTorch family, but passed)
991
- model_specific_dd, # This is pt_model_type if family is PyTorch
992
- pt_mlp_hidden_layers_txt, pt_mlp_activation_dd,
993
- train_epochs_num, train_batch_size_num, train_learning_rate_num,
994
- model_output_format_dd,
995
- training_log_txt # Pass current log content to append
996
- ],
997
- outputs=[
998
- training_log_txt, evaluation_metrics_txt, download_trained_model_file,
999
- model_param_count_output_txt, loss_plot_img,
1000
- download_trained_model_file # This seems redundant, download_trained_model_file is already an output
1001
- ]
1002
- )
1003
-
1004
- # Clear logs button (optional)
1005
- # clear_logs_btn = gr.Button("Clear Logs")
1006
- # def clear_logs_func(): return "", "" # Clears current_logs_state and training_log_txt
1007
- # clear_logs_btn.click(clear_logs_func, [], [current_logs_state, training_log_txt])
1008
-
1009
-
1010
- demo.queue().launch(debug=True, show_error=True) # Enable queue for longer tasks, debug for local testing
 
29
  # ONNX specific imports
30
  import skl2onnx
31
  from skl2onnx import convert_sklearn
32
+ from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, StringTensorType # Ensure this import if used for ONNX initial types
33
  import onnxruntime as rt
34
 
35
  import traceback
 
38
  import math
39
  import collections.abc # For Gradio issue with new Python versions
40
 
41
+
42
  # --- Global Variables / Constants ---
43
  TEMP_DIR = "temp_outputs"
44
  os.makedirs(TEMP_DIR, exist_ok=True)
45
+ MAX_DATASET_ROWS_WARN = 30000
46
+ MAX_GENERATED_ROWS = 50000
47
+ MAX_GENERATED_COLS = 100
48
 
49
+ # --- Helper Functions (count_parameters, get_temp_filepath) ---
50
  def count_sklearn_parameters(model):
51
  if hasattr(model, 'coef_'):
52
  return model.coef_.size + (model.intercept_.size if hasattr(model, 'intercept_') else 0)
 
62
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
63
 
64
  def get_temp_filepath(filename_base, extension):
 
65
  clean_extension = extension.lstrip('.')
66
  return os.path.join(TEMP_DIR, f"{filename_base}_{time.strftime('%Y%m%d-%H%M%S')}.{clean_extension}")
67
 
68
+ # --- PyTorch Model Definitions (SimpleMLP, SimpleCNN) ---
 
69
  class SimpleMLP(nn.Module):
70
  def __init__(self, input_dim, hidden_layers_str, output_dim, activation_fn_str="relu", task_type="classification"):
71
  super(SimpleMLP, self).__init__()
 
93
 
94
  layers.append(nn.Linear(current_dim, output_dim))
95
 
96
+ if task_type == "classification": # Changed from task_type.endswith("Classification")
97
  if output_dim == 1: layers.append(nn.Sigmoid()) # Binary
98
  elif output_dim > 1: layers.append(nn.Softmax(dim=-1)) # Multi-class
99
  self.network = nn.Sequential(*layers)
100
  def forward(self, x): return self.network(x)
101
 
102
+ class SimpleCNN(nn.Module): # Added task_type to constructor for clarity
103
+ def __init__(self, input_channels, img_size_wh, num_classes=10, task_type="classification",
104
  c_out1=16, k1=3, s1=1, p1=1, pool1_k=2, pool1_s=2,
105
  c_out2=32, k2=3, s2=1, p2=1, pool2_k=2, pool2_s=2,
106
  fc_hidden=128):
107
  super(SimpleCNN, self).__init__()
108
  self.input_channels = input_channels
109
  self.img_h, self.img_w = img_size_wh
110
+ self.num_classes = num_classes # This is the direct output dimension from the last linear layer
111
 
112
  self.conv1 = nn.Conv2d(self.input_channels, c_out1, kernel_size=k1, stride=s1, padding=p1)
113
  self.relu1 = nn.ReLU()
 
115
 
116
  h_out_conv1 = (self.img_h - k1 + 2 * p1) // s1 + 1
117
  w_out_conv1 = (self.img_w - k1 + 2 * p1) // s1 + 1
118
+ h_pool1 = (h_out_conv1 - pool1_k) // pool1_s + 1 if pool1_k > 0 else h_out_conv1 # handle no pooling
119
+ w_pool1 = (w_out_conv1 - pool1_k) // pool1_s + 1 if pool1_k > 0 else w_out_conv1
120
+
121
+
122
  self.conv2 = nn.Conv2d(c_out1, c_out2, kernel_size=k2, stride=s2, padding=p2)
123
  self.relu2 = nn.ReLU()
124
  self.pool2 = nn.MaxPool2d(kernel_size=pool2_k, stride=pool2_s)
125
 
126
  h_out_conv2 = (h_pool1 - k2 + 2 * p2) // s2 + 1
127
  w_out_conv2 = (w_pool1 - k2 + 2 * p2) // s2 + 1
128
+ h_pool2 = (h_out_conv2 - pool2_k) // pool2_s + 1 if pool2_k > 0 else h_out_conv2
129
+ w_pool2 = (w_out_conv2 - pool2_k) // pool2_s + 1 if pool2_k > 0 else w_out_conv2
130
 
131
  self.flattened_size = c_out2 * h_pool2 * w_pool2
132
  if self.flattened_size <= 0:
133
+ raise ValueError(f"Calculated flattened size is {self.flattened_size}. Check CNN params and image size. Current (h_pool2, w_pool2): ({h_pool2},{w_pool2}) from img ({self.img_h},{self.img_w})")
134
 
135
  self.fc1 = nn.Linear(self.flattened_size, fc_hidden)
136
  self.relu3 = nn.ReLU()
137
+ self.fc2 = nn.Linear(fc_hidden, self.num_classes) # Output layer before final activation
138
 
139
+ if task_type == "classification":
140
+ if self.num_classes == 1: # Binary classification
141
+ self.final_activation = nn.Sigmoid()
142
+ elif self.num_classes > 1: # Multi-class classification
143
+ self.final_activation = nn.Softmax(dim=1)
144
+ else: # Should not happen for classification if num_classes is properly set
145
+ self.final_activation = nn.Identity()
146
+ else: # Regression
147
  self.final_activation = nn.Identity()
148
 
 
149
  def forward(self, x):
150
  x = self.pool1(self.relu1(self.conv1(x)))
151
  x = self.pool2(self.relu2(self.conv2(x)))
 
155
  x = self.final_activation(x)
156
  return x
157
 
158
+ # --- Parameter Target Helpers (PARAM_RANGES, suggest_mlp_layers_for_range, estimate_current_mlp_params, estimate_cnn_params) ---
159
+ PARAM_RANGES = collections.OrderedDict([
160
  ("Tiny (<10k)", (0, 10000)),
161
  ("Small (10k-50k)", (10000, 50000)),
162
  ("Medium (50k-250k)", (50000, 250000)),
 
194
  if not suggested_layers_str: suggested_layers_str = "64"; logs += "Defaulting to '64'.\n"
195
  return suggested_layers_str, logs
196
 
197
+ def estimate_current_mlp_params(input_dim_str, hidden_layers_str, output_dim_str, task_type, current_logs=""): # Added task_type
198
  logs = current_logs
199
  try:
200
  input_dim = int(input_dim_str); output_dim = int(output_dim_str)
201
  if input_dim <= 0 or output_dim <= 0: return "Input/Output dims must be > 0", logs
202
+
203
+ # Determine task_type for MLP constructor
204
+ mlp_task_type = "classification" if task_type.endswith("Classification") else "regression"
205
+ temp_mlp = SimpleMLP(input_dim, hidden_layers_str, output_dim, task_type=mlp_task_type)
206
  params = count_pytorch_parameters(temp_mlp); del temp_mlp
207
  return f"{params:,}", logs
208
  except Exception as e: logs += f"Error estimating MLP params: {e}\n"; return "Error", logs
209
 
210
+ def estimate_cnn_params(img_h_str, img_w_str, num_classes_str, task_type, current_logs=""): # Added task_type
211
  logs = current_logs
212
  try:
213
+ img_h, img_w, num_classes_parsed = int(img_h_str), int(img_w_str), int(num_classes_str)
214
+ if not (img_h > 0 and img_w > 0 and num_classes_parsed > 0): return "Image dims/classes must be > 0", logs
215
+
216
+ # Determine task_type for CNN constructor
217
+ cnn_task_type = "classification" if task_type.endswith("Classification") else "regression" # Assuming CNN for image is classification
218
+ temp_cnn = SimpleCNN(input_channels=1, img_size_wh=(img_h, img_w), num_classes=num_classes_parsed, task_type=cnn_task_type)
219
  params = count_pytorch_parameters(temp_cnn); del temp_cnn
220
  return f"{params:,}", logs
221
+ except Exception as e: logs += f"Error estimating CNN params: {traceback.format_exc()}\n"; return "Error", logs
222
+
223
 
224
+ # --- Dataset and Preprocessing (generate_dataset_backend, preprocess_tabular_data) ---
225
  def generate_dataset_backend(task_type, n_samples_str, n_features_str,
226
  n_classes_or_informative_str, dataset_format,
227
  ai_suggest_ds_shape, target_param_range_str, model_type_selection,
 
233
 
234
  if ai_suggest_ds_shape:
235
  n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg = 5000, 10, 2
236
+ if task_type == "Tabular Regression": n_classes_or_informative_sugg = min(n_features_sugg // 2, 5) if n_features_sugg > 0 else 1
237
+ elif task_type == "Basic Image Classification": n_samples_sugg, n_features_sugg = 500, 0
238
 
239
  is_nn = "Network" in model_type_selection
240
  if is_nn and target_param_range_str in PARAM_RANGES:
241
  min_p, max_p = PARAM_RANGES[target_param_range_str]; avg_p = (min_p + max_p) / 2
242
+ if avg_p > 200000: n_samples_sugg = min(MAX_GENERATED_ROWS, n_samples_sugg * 3); n_features_sugg = min(MAX_GENERATED_COLS, n_features_sugg * 2) if task_type.startswith("Tabular") else n_features_sugg
243
+ elif avg_p < 50000: n_samples_sugg = max(200, n_samples_sugg // 2); n_features_sugg = max(3, n_features_sugg // 2) if task_type.startswith("Tabular") else n_features_sugg
244
 
245
  n_samples, n_features, n_classes_or_informative = n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg
246
  logs += f"AI Suggested Dataset: Samples={n_samples}, Feats={n_features}, Classes/Informative={n_classes_or_informative}\n"
 
249
  if task_type.startswith("Tabular"): n_features = max(1, min(n_features, MAX_GENERATED_COLS))
250
  if n_samples > MAX_DATASET_ROWS_WARN: logs += f"Warning: Generating {n_samples} rows. May be slow.\n"
251
 
252
+ df = None; X_data=None; y_data=None
253
  try:
254
  if task_type == "Tabular Classification":
255
  n_cls = max(2, n_classes_or_informative)
256
+ n_inf = max(1, min(n_features, n_classes_or_informative if n_classes_or_informative >= n_cls else n_features // 2)) # make_classification expects n_informative <= n_features
257
+ if n_inf > n_features: n_inf = n_features
258
  X_data, y_data = make_classification(n_samples=n_samples, n_features=n_features, n_informative=n_inf,
259
  n_redundant=max(0,n_features - n_inf)//2, n_classes=n_cls, flip_y=0.05, random_state=42)
260
  df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)]); df['target'] = y_data
261
  elif task_type == "Tabular Regression":
262
  n_inf = max(1, min(n_features, n_classes_or_informative))
263
+ if n_inf > n_features: n_inf = n_features
264
  X_data, y_data = make_regression(n_samples=n_samples, n_features=n_features, n_informative=n_inf, noise=10, random_state=42)
265
  df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)]); df['target'] = y_data
266
  elif task_type == "Basic Image Classification":
 
267
  img_h, img_w = 28, 28
268
  num_pixels = img_h * img_w
269
  X_data = np.random.randint(0, 256, size=(n_samples, num_pixels), dtype=np.uint8)
 
274
 
275
  logs += f"Generated data: {df.shape if df is not None else (X_data.shape, y_data.shape)}\n"
276
  file_path = get_temp_filepath("generated_dataset", dataset_format)
277
+ if df is not None:
278
  if dataset_format == ".csv": df.to_csv(file_path, index=False)
279
  elif dataset_format == ".json": df.to_json(file_path, orient='records', lines=True)
280
  elif dataset_format == ".parquet": df.to_parquet(file_path, index=False)
281
  else: logs += f"Unsupported format {dataset_format}. Defaulting to CSV.\n"; file_path=get_temp_filepath("generated_dataset","csv"); df.to_csv(file_path, index=False)
282
  logs += f"Dataset saved to {file_path}\n"
283
+ # For consistency, data_obj returned should be what train functions expect
284
+ # Sklearn train func can take df or path. PyTorch train func can take df, path, or (X,y) tuple.
285
+ # Returning df here for generated data is fine.
286
+ return df.head(), df, logs, file_path
287
+ else:
288
+ logs += "Dataset generated as numpy arrays. Not saving to file from this function directly.\n"
289
+ return pd.DataFrame(X_data[:5] if X_data is not None else None), (X_data, y_data), logs, None
290
 
291
 
292
  except Exception as e: error_msg=f"Error generating dataset: {traceback.format_exc()}"; logs+=error_msg+"\n"; return None, error_msg, logs, None
 
298
  if target_column_name not in df.columns: raise ValueError(f"Target column '{target_column_name}' not found.")
299
  X_df = df.drop(target_column_name, axis=1)
300
  y_series = df[target_column_name]
301
+ elif isinstance(df_or_X, np.ndarray) and y_if_X_is_numpy is not None:
302
+ X_df = pd.DataFrame(df_or_X, columns=[f'feature_{i}' for i in range(df_or_X.shape[1])])
303
  y_series = pd.Series(y_if_X_is_numpy)
304
  else: raise ValueError("Invalid input for preprocess_tabular_data.")
305
 
 
309
 
310
  preprocessor = ColumnTransformer(transformers=[
311
  ('num', Pipeline([('imputer', SimpleImputer(strategy='mean')), ('scaler', StandardScaler())]), numerical_features),
312
+ ('cat', Pipeline([('imputer', SimpleImputer(strategy='most_frequent')), ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False))]), categorical_features)
313
+ ], remainder='passthrough')
314
 
315
  X_processed_np = preprocessor.fit_transform(X_df)
316
 
317
+ feature_names_out_list = []
318
+ try: feature_names_out_list = list(preprocessor.get_feature_names_out())
319
+ except AttributeError:
320
+ current_pos = 0
321
+ if numerical_features: feature_names_out_list.extend(numerical_features); current_pos += len(numerical_features)
322
+ if categorical_features:
323
+ cat_encoder = preprocessor.named_transformers_['cat'].named_steps['onehot']
324
+ if hasattr(cat_encoder, 'get_feature_names_out'):
325
+ cat_feature_names = cat_encoder.get_feature_names_out(categorical_features)
326
+ elif hasattr(cat_encoder, 'get_feature_names'):
327
+ cat_feature_names = cat_encoder.get_feature_names(categorical_features)
328
+ else: # Estimate number of one-hot features
329
+ num_onehot_cols = X_processed_np.shape[1] - len(numerical_features) # Assuming only num and cat
330
+ cat_feature_names = [f"cat_feat_{i}" for i in range(num_onehot_cols)]
331
+ feature_names_out_list.extend(cat_feature_names); current_pos += len(cat_feature_names)
332
+ # Handle remainder='passthrough' if necessary, X_processed_np.shape[1] would be total
333
+ if X_processed_np.shape[1] > current_pos:
334
+ feature_names_out_list.extend([f"other_feat_{i}" for i in range(X_processed_np.shape[1] - current_pos)])
335
+
336
 
337
  processed_input_dim = X_processed_np.shape[1]
338
  logs += f"Tabular data preprocessed. X shape: {X_processed_np.shape}, Processed input dim: {processed_input_dim}\n"
 
342
  y_processed_np = le.fit_transform(y_series)
343
  num_classes = len(le.classes_)
344
  logs += f"Target encoded. Classes: {num_classes} ({le.classes_})\n"
345
+ output_dim_nn = 1 if num_classes == 2 else num_classes # For NN output layer
 
 
 
346
  else: # Regression
347
  y_processed_np = y_series.astype(float).values
348
+ num_classes = 1
349
  output_dim_nn = 1
350
 
351
+ return X_processed_np, y_processed_np, preprocessor, logs, processed_input_dim, output_dim_nn, feature_names_out_list
352
 
353
 
354
+ # --- Training Functions (train_model_sklearn, train_model_pytorch) ---
355
  def train_model_sklearn(data_input_obj, target_column, task_type, model_name, model_output_format, current_logs=""):
356
  logs = current_logs + f"\n--- Training Scikit-learn Model: {model_name} ---\n"
357
  model_path_out, metrics_out, model_params_out = None, "Training failed.", "N/A"
358
 
359
  df = None
360
+ if isinstance(data_input_obj, str):
361
  try:
362
  if data_input_obj.endswith('.csv'): df = pd.read_csv(data_input_obj)
363
  elif data_input_obj.endswith('.json'): df = pd.read_json(data_input_obj, lines=True)
 
367
  elif isinstance(data_input_obj, pd.DataFrame): df = data_input_obj
368
  else: logs += "Invalid data for training.\n"; return logs, "Error: Invalid data.", None, "N/A"
369
 
370
+ if not target_column or target_column not in df.columns: # check if target_column is empty
371
+ logs += f"Target column '{target_column}' not provided or not found.\n"; return logs, f"Error: Target '{target_column}' not found/provided.", None, "N/A"
372
 
373
  try:
374
+ X_processed_np, y_processed_np, preprocessor, logs, _, _, feature_names_original = preprocess_tabular_data(df, None, target_column, task_type, logs)
375
  except ValueError as e: logs += f"Preprocessing error: {e}\n"; return logs, f"Error: {e}", None, "N/A"
376
 
377
  X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
 
381
  if task_type == "Tabular Classification":
382
  if model_name == "Logistic Regression": model = LogisticRegression(max_iter=1000, random_state=42)
383
  elif model_name == "Random Forest Classifier": model = RandomForestClassifier(random_state=42)
384
+ elif model_name == "Support Vector Machine (SVM) Classifier": model = SVC(random_state=42, probability=True)
385
  elif task_type == "Tabular Regression":
386
  if model_name == "Linear Regression": model = LinearRegression()
387
  elif model_name == "Random Forest Regressor": model = RandomForestRegressor(random_state=42)
 
397
  y_pred = model.predict(X_test)
398
 
399
  if task_type == "Tabular Classification":
400
+ acc = accuracy_score(y_test, y_pred); report = classification_report(y_test, y_pred, zero_division=0)
 
401
  metrics_out = f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
402
  elif task_type == "Tabular Regression":
403
+ mse = mean_squared_error(y_test, y_pred); r2 = r2_score(y_test, y_pred)
 
404
  metrics_out = f"Mean Squared Error: {mse:.4f}\nR2 Score: {r2:.4f}"
405
  logs += "\n--- Evaluation Metrics ---\n" + metrics_out + "\n"
 
 
406
  full_pipeline_for_saving = Pipeline([('preprocessor', preprocessor), ('model', model)])
407
  model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"
408
 
 
410
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
411
  joblib.dump(full_pipeline_for_saving, model_path_out)
412
  logs += f"Model (with preprocessor) saved to {model_path_out} as PKL.\n"
 
413
  elif model_output_format == ".onnx (ONNX)":
414
  model_path_out = get_temp_filepath(model_filename_base, "onnx")
415
+ raw_X_for_types_df = df.drop(target_column, axis=1).infer_objects()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  onnx_initial_types = []
417
+ for col_idx, col_name in enumerate(raw_X_for_types_df.columns):
418
+ col_dtype = raw_X_for_types_df[col_name].dtype
419
+ # Forcing float32 for numeric inputs to ONNX for broader compatibility
420
+ # ONNX is stricter about types than scikit-learn sometimes.
421
  if pd.api.types.is_numeric_dtype(col_dtype):
422
+ # Create a sample of the correct type for skl2onnx to infer shape and type
423
+ # Shape [None, 1] implies one feature at a time for this column.
424
+ # If a feature is multi-dimensional (e.g. embeddings), this needs adjustment.
425
+ # For typical tabular, each column is one feature.
426
  onnx_initial_types.append((col_name, FloatTensorType([None, 1])))
427
  elif pd.api.types.is_string_dtype(col_dtype) or col_dtype == 'object':
428
  onnx_initial_types.append((col_name, StringTensorType([None, 1])))
429
  else:
430
+ logs += f"Warning: Unsupported dtype {col_dtype} for column {col_name} in ONNX. Defaulting to FloatTensorType.\n"
431
+ onnx_initial_types.append((col_name, FloatTensorType([None, 1]))) # Fallback
432
 
433
+ if not onnx_initial_types: raise ValueError("ONNX initial types failed: No valid columns found.")
 
 
 
434
  try:
435
+ options = {id(full_pipeline_for_saving): {'zipmap': False}} if task_type.endswith("Classification") else {}
436
+ onnx_model = convert_sklearn(full_pipeline_for_saving, initial_types=onnx_initial_types, target_opset=12, options=options)
437
+ with open(model_path_out, "wb") as f: f.write(onnx_model.SerializeToString())
438
+ logs += f"Model saved to {model_path_out} as ONNX.\n"
 
 
 
 
439
  sess = rt.InferenceSession(model_path_out, providers=rt.get_available_providers())
440
+ logs += f"ONNX model loaded with ONNX Runtime. Inputs: {[inp.name for inp in sess.get_inputs()]}\n"
441
+ except Exception as onnx_e: logs += f"ONNX Error: {traceback.format_exc()}\n"; model_path_out=None; metrics_out+="\nONNX EXPORT FAILED."
442
+ else: # Fallback to PKL
 
 
 
 
443
  logs += f"Unsupported format '{model_output_format}'. Saving as .pkl\n"
444
  model_path_out = get_temp_filepath(model_filename_base, "pkl")
445
  joblib.dump(full_pipeline_for_saving, model_path_out)
446
+ except Exception as e: error_msg=f"Sklearn training error: {traceback.format_exc()}"; logs+=error_msg+"\n"; metrics_out=error_msg
 
 
447
  return logs, metrics_out, model_path_out, model_params_out
448
 
 
449
  def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
450
  mlp_hidden_layers_str, mlp_activation,
 
 
451
  epochs_str, batch_size_str, lr_str,
452
  model_output_format, current_logs=""):
453
  logs = current_logs + f"\n--- Training PyTorch Model: {model_type_pt} ---\n"
454
  model_path_out, metrics_out, model_params_out, plot_out = None, "Training failed.", "N/A", None
455
 
456
+ df_for_pytorch, X_numpy_for_pytorch, y_numpy_for_pytorch = None, None, None
457
+ if isinstance(data_input_obj, str):
458
  try:
 
459
  if data_input_obj.endswith('.csv'): df_for_pytorch = pd.read_csv(data_input_obj)
460
  elif data_input_obj.endswith('.json'): df_for_pytorch = pd.read_json(data_input_obj, lines=True)
461
  elif data_input_obj.endswith('.parquet'): df_for_pytorch = pd.read_parquet(data_input_obj)
462
  else: logs += f"Unsupported file: {data_input_obj}\n"; return logs, "Error", None, "N/A", None
463
  except Exception as e: logs += f"Error reading {data_input_obj}: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
464
  elif isinstance(data_input_obj, pd.DataFrame): df_for_pytorch = data_input_obj
465
+ elif isinstance(data_input_obj, tuple) and len(data_input_obj) == 2 and isinstance(data_input_obj[0], np.ndarray) and isinstance(data_input_obj[1], np.ndarray):
466
+ X_numpy_for_pytorch, y_numpy_for_pytorch = data_input_obj
 
467
  else: logs += "Invalid data for PyTorch training.\n"; return logs, "Error", None, "N/A", None
468
 
469
  try:
 
471
  if not (epochs > 0 and batch_size > 0 and lr > 0): raise ValueError("Params must be >0.")
472
  except ValueError as e: logs += f"Invalid training params: {e}\n"; return logs, f"Error: {e}", None, "N/A", None
473
 
474
+ processed_input_dim_actual, nn_output_dim_actual, preprocessor_pipeline = -1, -1, None
475
+ X_processed_np, y_processed_np = None, None
476
 
477
  if model_type_pt == "Simple Neural Network (MLP)":
478
  if not task_type.startswith("Tabular"):
479
  logs += "MLP requires Tabular task.\n"; return logs, "MLP Task Error", None, "N/A", None
480
+ if not target_column and df_for_pytorch is not None: # Check if target column is provided for DataFrame
481
+ logs += "Target column needed for MLP with DataFrame input.\n"; return logs, "MLP Target Error", None, "N/A", None
482
+
483
  try:
 
484
  data_arg1 = df_for_pytorch if df_for_pytorch is not None else X_numpy_for_pytorch
485
  data_arg2 = y_numpy_for_pytorch if df_for_pytorch is None else None
486
+ # target_column is only relevant if data_arg1 is a DataFrame
487
+ current_target_col = target_column if df_for_pytorch is not None else "target" # Placeholder if from numpy
488
+
489
  X_processed_np, y_processed_np, preprocessor_pipeline, logs, processed_input_dim_actual, nn_output_dim_actual, _ = \
490
+ preprocess_tabular_data(data_arg1, data_arg2, current_target_col, task_type, logs)
491
  except ValueError as e: logs+=f"MLP Preprocessing error: {e}\n"; return logs,f"Error: {e}",None,"N/A",None
 
492
  elif model_type_pt == "Simple Convolutional Network (CNN)":
493
+ if task_type != "Basic Image Classification": logs += "Warning: CNN selected, but task is not Basic Image Classification.\n"
 
494
 
495
+ X_raw, y_raw = None, None
496
  if df_for_pytorch is not None:
497
+ if not target_column or target_column not in df_for_pytorch.columns:
498
+ logs += f"Target '{target_column}' not found/provided for CNN.\n"; return logs, "CNN Target Error", None, "N/A", None
499
  X_raw = df_for_pytorch.drop(target_column, axis=1).values
500
  y_raw = df_for_pytorch[target_column].values
501
  elif X_numpy_for_pytorch is not None and y_numpy_for_pytorch is not None:
502
+ X_raw = X_numpy_for_pytorch; y_raw = y_numpy_for_pytorch
503
+ else: logs += "No valid data for CNN.\n"; return logs, "CNN Data Error", None, "N/A", None
 
 
504
 
505
  le = LabelEncoder(); y_processed_np = le.fit_transform(y_raw)
506
  nn_output_dim_actual = len(le.classes_)
507
+ if nn_output_dim_actual == 2: nn_output_dim_actual = 1
508
 
509
+ pixels_per_sample = X_raw.shape[1]; img_h, img_w, input_channels = 28,28,1
510
  img_dim_approx = int(math.sqrt(pixels_per_sample))
511
+ if img_dim_approx * img_dim_approx == pixels_per_sample: img_h, img_w = img_dim_approx, img_dim_approx
512
+ else: logs += f"Warning: Cannot infer square image from {pixels_per_sample} pixels. Defaulting to 28x28.\n"
 
 
513
 
 
514
  X_processed_np = X_raw.reshape(-1, input_channels, img_h, img_w).astype(np.float32) / 255.0
515
+ processed_input_dim_actual = (input_channels, img_h, img_w)
516
  logs += f"CNN Data: X reshaped to {X_processed_np.shape}, y: {y_processed_np.shape}, NN Output Dim: {nn_output_dim_actual}\n"
517
  else: logs += f"Unknown PyTorch model: {model_type_pt}\n"; return logs, "Unknown PyTorch model", None, "N/A", None
518
 
 
 
519
  y_dtype = torch.float32 if (nn_output_dim_actual == 1 and task_type.endswith("Regression")) or \
520
  (nn_output_dim_actual == 1 and task_type.endswith("Classification")) \
521
+ else torch.long
522
+ X_tensor = torch.tensor(X_processed_np, dtype=torch.float32)
523
  y_tensor = torch.tensor(y_processed_np, dtype=y_dtype)
524
+ if nn_output_dim_actual == 1 and task_type.endswith("Classification"): y_tensor = y_tensor.unsqueeze(1)
525
+ if task_type.endswith("Regression"): y_tensor = y_tensor.unsqueeze(1)
526
 
527
  dataset = TensorDataset(X_tensor, y_tensor)
 
528
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
529
 
530
  pytorch_model = None
 
535
  task_type="classification" if task_type.endswith("Classification") else "regression")
536
  elif model_type_pt == "Simple Convolutional Network (CNN)":
537
  channels, h, w = processed_input_dim_actual
538
+ pytorch_model = SimpleCNN(input_channels=channels, img_size_wh=(h,w), num_classes=nn_output_dim_actual,
539
+ task_type="classification" if task_type.endswith("Classification") else "regression") # Pass task_type
540
+ except Exception as model_e: logs += f"Error creating PyTorch model: {traceback.format_exc()}\n"; return logs, f"Model Creation Error: {model_e}", None, "N/A", None
541
 
542
  if pytorch_model is None: logs += "Failed to instantiate PyTorch model.\n"; return logs, "Model instantiate fail", None, "N/A", None
543
+ model_params_val = count_pytorch_parameters(pytorch_model); model_params_out = f"{model_params_val:,}"
 
544
  logs += f"PyTorch Model: {model_params_out} params.\n"
545
  if model_params_val > 500000: logs += "Warning: >500k params on CPU will be SLOW.\n"
546
 
547
+ is_classification_task = task_type.endswith("Classification") # Simplified condition
548
  if is_classification_task:
549
  criterion = nn.BCELoss() if nn_output_dim_actual == 1 else nn.CrossEntropyLoss()
550
+ else: criterion = nn.MSELoss()
 
551
  optimizer = optim.Adam(pytorch_model.parameters(), lr=lr)
552
 
553
  logs += f"Starting PyTorch training for {epochs} epochs...\n"; start_time = time.time()
 
555
  pytorch_model.train()
556
  for epoch in range(epochs):
557
  epoch_loss_sum = 0.0; num_batches = 0
558
+ for i, (batch_X, batch_y) in enumerate(dataloader): # Added enumerate for batch index
559
+ optimizer.zero_grad(); outputs = pytorch_model(batch_X)
560
+ loss = criterion(outputs, batch_y); loss.backward(); optimizer.step()
 
 
561
  epoch_loss_sum += loss.item(); num_batches += 1
562
  avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
563
  epoch_losses.append(avg_epoch_loss)
564
  logs += f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}\n"
 
 
 
 
565
 
566
+ logs += f"PyTorch training completed in {time.time() - start_time:.2f}s.\n"
 
567
  pytorch_model.eval()
568
  with torch.no_grad():
569
+ if is_classification_task and dataloader.dataset and len(dataloader.dataset)>0 :
570
+ try:
571
+ all_preds, all_targets = [], []
572
+ for batch_X, batch_y in dataloader: # Evaluate on whole dataset (or a test split ideally)
573
+ outputs = pytorch_model(batch_X)
574
+ if nn_output_dim_actual == 1: predicted = (outputs > 0.5).float()
575
+ else: _, predicted = torch.max(outputs.data, 1)
576
+ all_preds.extend(predicted.cpu().numpy())
577
+ all_targets.extend(batch_y.cpu().numpy())
578
+
579
+ if all_targets and all_preds: # Check if lists are not empty
580
+ # Ensure all_targets is 1D for accuracy_score if predicted is also 1D (binary case)
581
+ all_targets_np = np.array(all_targets).squeeze()
582
+ all_preds_np = np.array(all_preds).squeeze()
583
+ acc = accuracy_score(all_targets_np, all_preds_np)
584
+ metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\nAccuracy on training data: {acc*100:.2f}%"
585
+ else:
586
+ metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n (Could not compute accuracy)"
587
+
588
+ except Exception as eval_e: metrics_out = f"Final Training Loss: {avg_epoch_loss:.4f}\n Eval Error: {eval_e}"
589
+ else: metrics_out = f"Final Training Loss (MSE): {avg_epoch_loss:.4f}"
590
+ logs += "\n--- PyTorch Metrics ---\n" + metrics_out + "\n"
 
 
591
 
 
592
  if epoch_losses:
593
+ import matplotlib # Use Agg backend for non-interactive environments like Spaces
594
+ matplotlib.use('Agg')
595
  import matplotlib.pyplot as plt
596
+ fig, ax = plt.subplots(); ax.plot(range(1, epochs + 1), epoch_losses, marker='o')
597
+ ax.set_xlabel("Epoch"); ax.set_ylabel("Average Loss"); ax.set_title("Training Loss Curve")
598
+ plot_out = fig; logs += "Loss curve generated.\n"
 
 
 
 
 
599
 
 
600
  model_filename_base = f"pytorch_{model_type_pt.replace(' ', '_').lower()}"
601
  if model_output_format == ".pt (PyTorch)":
602
  model_path_out = get_temp_filepath(model_filename_base, "pt")
603
+ save_obj = {'model_state_dict': pytorch_model.state_dict(), 'output_dim_nn': nn_output_dim_actual, 'task_type': task_type}
604
  if model_type_pt == "Simple Neural Network (MLP)" and preprocessor_pipeline:
605
+ save_obj.update({
606
+ 'preprocessor': preprocessor_pipeline, 'input_dim_processed': processed_input_dim_actual,
607
+ 'hidden_layers_str': mlp_hidden_layers_str, 'activation_fn': mlp_activation,
608
+ })
 
 
 
 
 
609
  logs += f"PyTorch MLP (model + preprocessor) saved to {model_path_out}\n"
610
+ elif model_type_pt == "Simple Convolutional Network (CNN)":
611
+ c,h,w = processed_input_dim_actual
612
+ save_obj.update({'input_channels':c, 'img_h':h, 'img_w':w}) # Save CNN architecture details
613
+ logs += f"PyTorch CNN (model state_dict + arch_details) saved to {model_path_out}\n"
614
+ else: logs += f"PyTorch {model_type_pt} (model state_dict) saved to {model_path_out}\n"
615
+ torch.save(save_obj, model_path_out)
616
+ else: # Fallback
617
+ logs += f"Unsupported format '{model_output_format}'. Saving as .pt\n"
618
  model_path_out = get_temp_filepath(model_filename_base, "pt")
619
+ torch.save(pytorch_model.state_dict(), model_path_out)
 
620
  return logs, metrics_out, model_path_out, model_params_out, plot_out
621
 
 
622
  # --- Gradio UI Definition ---
623
+ TASK_CHOICES = ["Tabular Classification", "Tabular Regression", "Basic Image Classification"]
 
624
  MODEL_FAMILIES = ["Scikit-learn (Classical ML)", "PyTorch (Neural Networks)"]
625
  SKLEARN_MODELS_CLASSIFICATION = ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]
626
  SKLEARN_MODELS_REGRESSION = ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]
627
  PYTORCH_MODELS = ["Simple Neural Network (MLP)", "Simple Convolutional Network (CNN)"]
628
  DATASET_FORMATS = [".csv", ".json", ".parquet"]
629
  MODEL_OUTPUT_FORMATS_SKLEARN = [".pkl (Scikit-learn)", ".onnx (ONNX)"]
630
+ MODEL_OUTPUT_FORMATS_PYTORCH = [".pt (PyTorch)"]
631
  MLP_ACTIVATIONS = ["relu", "tanh", "sigmoid"]
 
632
  CLONE_GUIDE_TEXT = """
633
  ## How to Clone & Upgrade This Space for More Power:
634
  (Instructions as provided in previous response - omitted here for brevity but should be included)
635
  """
636
 
637
+ # Determine initial choices for model_specific_dd based on default task_type and model_family
638
+ _initial_task_default = TASK_CHOICES[0]
639
+ _initial_family_default = MODEL_FAMILIES[0]
640
+ initial_model_choices_for_specific_dd = []
641
+ if _initial_family_default == "Scikit-learn (Classical ML)":
642
+ if _initial_task_default == "Tabular Classification": initial_model_choices_for_specific_dd = SKLEARN_MODELS_CLASSIFICATION
643
+ elif _initial_task_default == "Tabular Regression": initial_model_choices_for_specific_dd = SKLEARN_MODELS_REGRESSION
644
+ elif _initial_family_default == "PyTorch (Neural Networks)":
645
+ if _initial_task_default.startswith("Tabular"): initial_model_choices_for_specific_dd = [PYTORCH_MODELS[0]]
646
+ elif _initial_task_default == "Basic Image Classification": initial_model_choices_for_specific_dd = [PYTORCH_MODELS[1]]
647
+ initial_model_value_for_specific_dd = initial_model_choices_for_specific_dd[0] if initial_model_choices_for_specific_dd else None
648
+
649
  def update_model_options(task_choice, model_family_choice):
650
+ choices, value = [], None
651
  if model_family_choice == "Scikit-learn (Classical ML)":
652
+ if task_choice == "Tabular Classification": choices = SKLEARN_MODELS_CLASSIFICATION
653
+ elif task_choice == "Tabular Regression": choices = SKLEARN_MODELS_REGRESSION
 
654
  elif model_family_choice == "PyTorch (Neural Networks)":
655
+ if task_choice.startswith("Tabular"): choices = [PYTORCH_MODELS[0]] # MLP
656
+ elif task_choice == "Basic Image Classification": choices = [PYTORCH_MODELS[1]] # CNN
657
+ value = choices[0] if choices else None
658
+ return gr.update(choices=choices, value=value, visible=bool(choices))
659
 
660
  def update_param_range_visibility(model_family_choice):
661
  return gr.update(visible=(model_family_choice == "PyTorch (Neural Networks)"))
662
 
663
+ def update_pytorch_specific_options_visibility(model_choice_pytorch_family, specific_pytorch_model):
664
+ # Only proceed if family is PyTorch
665
+ if model_choice_pytorch_family != "PyTorch (Neural Networks)":
666
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) # Hide all: param_range, mlp_group, cnn_group
667
+
668
+ param_range_visible = True # Always true if PyTorch family
669
+ is_mlp = (specific_pytorch_model == "Simple Neural Network (MLP)")
670
+ is_cnn = (specific_pytorch_model == "Simple Convolutional Network (CNN)")
671
+ return gr.update(visible=param_range_visible), gr.update(visible=is_mlp), gr.update(visible=is_cnn)
672
 
673
  def update_model_output_formats(model_family_choice):
674
+ if model_family_choice == "Scikit-learn (Classical ML)": return gr.update(choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
675
+ elif model_family_choice == "PyTorch (Neural Networks)": return gr.update(choices=MODEL_OUTPUT_FORMATS_PYTORCH, value=MODEL_OUTPUT_FORMATS_PYTORCH[0])
 
 
676
  return gr.update(choices=[], value=None)
677
 
678
+ css = """.gradio-container { font-family: 'IBM Plex Sans', sans-serif; } footer {display:none !important}""" # Hide footer too
 
 
 
 
 
 
679
 
680
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), css=css) as demo:
681
+ gr.Markdown("# 🧠 TrainAI ⚙️")
682
  gr.Markdown("Create, train, and download AI models. Optimized for CPU - expect longer training for complex models.")
 
 
 
 
683
  generated_data_state = gr.State(None)
684
+ current_logs_state = gr.State("")
685
 
686
  with gr.Tabs():
687
  with gr.TabItem("1. Define Task & Model"):
688
  with gr.Row():
689
+ task_type_dd = gr.Dropdown(TASK_CHOICES, label="Select Task Type", value=_initial_task_default)
690
+ model_family_dd = gr.Dropdown(MODEL_FAMILIES, label="Select Model Family", value=_initial_family_default)
691
+ model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=initial_model_choices_for_specific_dd, value=initial_model_value_for_specific_dd, interactive=True)
692
 
 
 
 
693
  pytorch_param_range_dd = gr.Dropdown(list(PARAM_RANGES.keys()), label="Target Parameter Range (for NNs)",
694
  info="Guides NN architecture suggestions. Training >250k params on CPU is slow.",
695
+ value=list(PARAM_RANGES.keys())[1], visible=(_initial_family_default == "PyTorch (Neural Networks)"))
696
+ with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)" and initial_model_value_for_specific_dd == PYTORCH_MODELS[0])) as pt_mlp_specific_group:
 
 
697
  gr.Markdown("#### MLP Configuration")
 
 
698
  pt_mlp_hidden_layers_txt = gr.Textbox(label="Hidden Layer Sizes (comma-separated, e.g., 128,64)", value="64,32")
699
  pt_mlp_activation_dd = gr.Dropdown(MLP_ACTIVATIONS, label="Activation Function", value="relu")
700
+ with gr.Row():
701
+ pt_mlp_suggest_btn = gr.Button("Suggest MLP Layers")
702
+ pt_mlp_estimate_params_btn = gr.Button("Estimate Current MLP Params")
703
  pt_mlp_param_count_txt = gr.Textbox(label="Estimated MLP Parameters", interactive=False)
704
+ with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)" and initial_model_value_for_specific_dd == PYTORCH_MODELS[1])) as pt_cnn_specific_group:
705
+ gr.Markdown("#### CNN Configuration (Simplified)")
706
+ gr.Markdown("SimpleCNN uses a fixed structure. Params depend on image size/classes from data.")
707
+ pt_cnn_estimate_params_btn = gr.Button("Estimate CNN Params (needs Data Info)")
 
 
 
 
 
 
 
 
 
708
  pt_cnn_param_count_txt = gr.Textbox(label="Estimated CNN Parameters", interactive=False)
 
709
 
710
+ with gr.TabItem("2. Configure Dataset"): # This tab should show generate_dataset_group by default
 
711
  dataset_source_rb = gr.Radio(["Generate new dataset", "Upload my own dataset (CSV, JSON, Parquet)"],
712
  label="Dataset Source", value="Generate new dataset")
713
+ with gr.Group(visible=True) as generate_dataset_group: # Default visible=True
 
714
  gr.Markdown("#### Generate Synthetic Dataset")
715
  with gr.Row():
716
+ ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
717
+ ds_gen_features_num = gr.Number(label="# Features (Tabular)", value=10, minimum=1, step=1)
718
+ ds_gen_classes_informative_num = gr.Number(label="Classes (Classif) / Informative Feats (Regr)", value=2, minimum=1, step=1)
719
+ ds_gen_ai_suggest_cb = gr.Checkbox(label="Let AI suggest dataset shape?", value=False)
720
+ ds_gen_format_dd = gr.Dropdown(DATASET_FORMATS, label="Generated Dataset Format", value=".csv")
721
  generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")
 
722
  with gr.Group(visible=False) as upload_dataset_group:
723
  gr.Markdown("#### Upload Dataset")
724
+ ds_upload_file = gr.File(label="Upload dataset", file_types=[".csv", ".json", ".parquet"])
725
+ target_column_name_txt = gr.Textbox(label="Target Column Name (Case-Sensitive!)", placeholder="e.g., 'target' or 'label'")
726
+ dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, height=200)
 
727
  generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
728
 
729
  with gr.TabItem("3. Train Model & Get Results"):
730
+ gr.Markdown("Ensure Model and Dataset are configured. For NNs, set Epochs/Batch/LR.")
731
  with gr.Row():
732
+ train_epochs_num = gr.Number(label="Epochs (NNs)", value=10, minimum=1, step=1)
733
+ train_batch_size_num = gr.Number(label="Batch Size (NNs)", value=32, minimum=1, step=1)
734
+ train_learning_rate_num = gr.Number(label="Learning Rate (NNs)", value=0.001, minimum=1e-6, step=1e-4, precision=6)
735
+ model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
 
 
 
 
736
  train_model_btn = gr.Button("🚀 Train Model", variant="primary")
 
737
  gr.Markdown("---")
738
  gr.Markdown("### Training Progress & Results")
739
  training_log_txt = gr.Textbox(label="Training Log & Status", lines=15, interactive=False, max_lines=50)
 
743
  download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
744
 
745
  with gr.TabItem("ℹ️ Guide & Info"):
746
+ # ... Guide content ... (omitted for brevity)
 
 
 
 
 
747
  gr.Markdown(CLONE_GUIDE_TEXT)
748
 
749
  # --- Event Handlers ---
 
750
  task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
751
  model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
752
 
753
+ model_family_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
754
+ model_specific_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pytorch_param_range_dd, pt_mlp_specific_group, pt_cnn_specific_group])
755
+
756
+ def get_data_dims_for_nn_suggestion(preview_df, target_col, task, logs_in):
757
+ logs = logs_in
758
+ input_dim_est, output_dim_est = 10, (2 if task.endswith("Classification") else 1) # Defaults
759
+ img_h_est, img_w_est = 28, 28 # Defaults for CNN
760
+
761
+ if preview_df is not None and isinstance(preview_df, pd.DataFrame) and not preview_df.empty:
762
+ temp_X_cols = [col for col in preview_df.columns if col != target_col] # Features
763
+ if not temp_X_cols and task == "Basic Image Classification": # Image data often has no named feature cols in preview
764
+ if preview_df.shape[1] == 1 and target_col in preview_df.columns: # Only target col
765
+ pass # Cannot estimate image from only target
766
+ elif target_col in preview_df.columns:
767
+ num_pixels = preview_df.shape[1] -1
768
+ else: # no target col (e.g. raw image data)
769
+ num_pixels = preview_df.shape[1]
770
+
771
+ if num_pixels > 0:
772
+ dim_sqrt = int(math.sqrt(num_pixels))
773
+ if dim_sqrt * dim_sqrt == num_pixels: img_h_est, img_w_est = dim_sqrt, dim_sqrt
774
+ else: logs += f"Non-square image ({num_pixels} pixels) from preview. Using default {img_h_est}x{img_w_est} for suggestion.\n"
775
+ input_dim_est = img_h_est * img_w_est # For CNN, this is not input_dim to MLP but for SimpleCNN internal calcs
776
+ elif temp_X_cols: # Tabular
777
+ num_cols = len([col for col in temp_X_cols if pd.api.types.is_numeric_dtype(preview_df[col])])
778
+ cat_cols = [col for col in temp_X_cols if pd.api.types.is_object_dtype(preview_df[col])]
779
+ one_hot_est = sum(min(10, preview_df[col].nunique(dropna=False)) for col in cat_cols)
780
+ input_dim_est = max(1, num_cols + one_hot_est)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
 
782
+ if target_col and target_col in preview_df.columns:
783
+ if task.endswith("Classification"):
784
+ output_dim_est = max(1, preview_df[target_col].nunique(dropna=False))
785
+ if output_dim_est == 2: output_dim_est = 1
786
+ else: logs += "Dataset preview not available for NN dimension estimation. Using defaults.\n"
787
+ return input_dim_est, output_dim_est, img_h_est, img_w_est, logs
788
+
789
+ def mlp_suggest_proxy_wrapper(target_range_str, current_logs, preview_df, target_col, task_type):
790
+ input_dim, output_dim, _, _, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
791
+ logs += f"Using estimated input_dim: {input_dim}, output_dim: {output_dim} for MLP suggestion.\n"
792
+ suggested_str, logs = suggest_mlp_layers_for_range(input_dim, output_dim, target_range_str, logs)
793
  param_count_str = "Error"
794
+ if suggested_str: param_count_str, logs = estimate_current_mlp_params(str(input_dim), suggested_str, str(output_dim), task_type, logs)
 
 
795
  return suggested_str, logs, param_count_str
796
+
797
+ pt_mlp_suggest_btn.click(fn=mlp_suggest_proxy_wrapper,
 
798
  inputs=[pytorch_param_range_dd, current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
799
+ outputs=[pt_mlp_hidden_layers_txt, training_log_txt, pt_mlp_param_count_txt])
800
+
801
+ def mlp_estimate_proxy_wrapper(hidden_layers, current_logs, preview_df, target_col, task_type):
802
+ input_dim, output_dim, _, _, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
803
+ logs += f"Using estimated input_dim: {input_dim}, output_dim: {output_dim} for MLP param estimation.\n"
804
+ param_count_str, logs = estimate_current_mlp_params(str(input_dim), hidden_layers, str(output_dim), task_type, logs)
805
+ return logs, param_count_str
806
 
807
+ pt_mlp_estimate_params_btn.click(fn=mlp_estimate_proxy_wrapper,
808
+ inputs=[pt_mlp_hidden_layers_txt, current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
809
+ outputs=[training_log_txt, pt_mlp_param_count_txt])
810
+
811
+ def cnn_estimate_proxy_wrapper(current_logs, preview_df, target_col, task_type):
812
+ _, output_dim, img_h, img_w, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
813
+ logs += f"Using estimated img_h: {img_h}, img_w: {img_w}, output_dim: {output_dim} for CNN param estimation.\n"
814
+ # For CNN, task type for constructor is 'classification' typically
815
+ cnn_task_type_for_constructor = "classification" if task_type == "Basic Image Classification" else "regression" # Placeholder
816
+ param_count_str, logs = estimate_cnn_params(str(img_h), str(img_w), str(output_dim), cnn_task_type_for_constructor, logs)
817
+ return logs, param_count_str
818
+
819
+ pt_cnn_estimate_params_btn.click(fn=cnn_estimate_proxy_wrapper,
820
+ inputs=[current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
821
+ outputs=[training_log_txt, pt_cnn_param_count_txt])
822
 
 
823
  def toggle_dataset_source_groups(source_choice):
824
+ return gr.update(visible=(source_choice == "Generate new dataset")), gr.update(visible=(source_choice == "Upload my own dataset (CSV, JSON, Parquet)"))
825
+ dataset_source_rb.change(fn=toggle_dataset_source_groups, inputs=dataset_source_rb, outputs=[generate_dataset_group, upload_dataset_group])
 
 
 
 
826
  model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
827
 
 
 
 
 
 
 
 
 
 
 
 
828
  generate_dataset_btn.click(
829
+ fn=generate_dataset_backend, # Pass the function directly
830
  inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_informative_num,
831
  ds_gen_format_dd, ds_gen_ai_suggest_cb, pytorch_param_range_dd, model_specific_dd, current_logs_state],
832
+ outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file])
 
833
 
 
834
  def process_uploaded_file(file_obj, logs_in):
835
+ logs, df_preview, status_msg, stored_data_path = logs_in, None, "Upload failed or no file.", None
836
+ if file_obj is None: logs += "Please upload a file first.\n"; return df_preview, logs, status_msg, stored_data_path
837
+ logs += f"Uploaded file: {file_obj.name}\n"; stored_data_path = file_obj.name
 
 
 
 
838
  try:
839
+ if file_obj.name.endswith(".csv"): df_preview = pd.read_csv(file_obj.name, nrows=5)
840
+ elif file_obj.name.endswith(".json"): df_preview = pd.read_json(file_obj.name, lines=True, nrows=5)
841
+ elif file_obj.name.endswith(".parquet"): temp_df = pd.read_parquet(file_obj.name); df_preview = temp_df.head() # Parquet preview needs full read
842
+ status_msg = "Preview generated for uploaded file." if df_preview is not None else "Could not generate preview, but file is noted."
843
+ logs += status_msg + "\n"
844
+ except Exception as e: logs += f"Error previewing {file_obj.name}: {e}\n"; status_msg = f"Error previewing: {e}"
845
+ return df_preview, logs, status_msg, stored_data_path
846
+ ds_upload_file.upload(fn=process_uploaded_file, inputs=[ds_upload_file, current_logs_state],
847
+ outputs=[dataset_preview_df, training_log_txt, training_log_txt, generated_data_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
848
 
849
  train_model_btn.click(
850
+ fn=train_model_wrapper, # Wrapper function defined before UI
851
+ inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd,
852
+ model_specific_dd, pt_mlp_hidden_layers_txt, pt_mlp_activation_dd,
853
+ train_epochs_num, train_batch_size_num, train_learning_rate_num,
854
+ model_output_format_dd, training_log_txt],
855
+ outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file,
856
+ model_param_count_output_txt, loss_plot_img, download_trained_model_file])
857
+
858
+ demo.queue().launch(debug=True, show_error=True)