szili2011 committed on
Commit
fbb3355
·
verified ·
1 Parent(s): 591f159

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -715
app.py CHANGED
@@ -1,826 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
2
  import pandas as pd
3
  import numpy as np
4
- import sklearn
 
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
7
  from sklearn.impute import SimpleImputer
8
  from sklearn.compose import ColumnTransformer
9
  from sklearn.pipeline import Pipeline
10
- # Scikit-learn Models
11
  from sklearn.linear_model import LogisticRegression, LinearRegression
12
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
13
  from sklearn.svm import SVC, SVR
14
- # Metrics
15
  from sklearn.metrics import accuracy_score, classification_report, mean_squared_error, r2_score
16
- # Dataset generators
17
  from sklearn.datasets import make_classification, make_regression
18
-
19
  import joblib
20
- import os
21
- import time
22
  import torch
23
  import torch.nn as nn
24
  import torch.optim as optim
25
  from torch.utils.data import TensorDataset, DataLoader
26
- import torchvision # For transforms, even if data is basic
27
- import torchvision.transforms as T
28
 
29
- # ONNX specific imports
30
  import skl2onnx
31
  from skl2onnx import convert_sklearn
32
- from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, StringTensorType # Ensure this import if used for ONNX initial types
33
- import onnxruntime as rt
34
 
35
- import traceback
36
- import tempfile
37
- import json
38
- import math
39
- import collections.abc # For Gradio issue with new Python versions
40
- import collections # Added for OrderedDict if not already covered
41
- import matplotlib # Use Agg backend for non-interactive environments
42
- matplotlib.use('Agg')
43
  import matplotlib.pyplot as plt
44
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # --- Global Variables / Constants ---
# Directory that receives generated datasets and exported models; it is
# created when the module is imported.
TEMP_DIR = "temp_outputs"
os.makedirs(TEMP_DIR, exist_ok=True)

# Row count above which dataset generation logs a "may be slow" warning.
MAX_DATASET_ROWS_WARN = 30_000
# Upper bounds used to clamp the size of generated datasets.
MAX_GENERATED_ROWS = 50_000
MAX_GENERATED_COLS = 100
 
 
 
 
 
 
52
 
53
- # --- Helper Functions (count_parameters, get_temp_filepath) ---
54
def count_sklearn_parameters(model):
    """Estimate the number of learned parameters of a fitted scikit-learn model.

    The estimate depends on which fitted attributes the model exposes,
    checked in this order:

    * ``coef_``: number of coefficients, plus the number of intercept terms
      when ``intercept_`` is present.
    * ``support_vectors_``: number of entries in the support-vector array.
    * random forests: total node count summed over all trees.

    Returns:
        An ``int`` when an estimate is available, ``"N/A (Complex Ensemble)"``
        when the trees of a random forest cannot be inspected, and ``"N/A"``
        for any other model.
    """
    if hasattr(model, 'coef_'):
        # np.size also copes with an intercept stored as a plain Python float,
        # which has no ``.size`` attribute.
        n_intercept = np.size(model.intercept_) if hasattr(model, 'intercept_') else 0
        return int(np.size(model.coef_)) + int(n_intercept)
    if hasattr(model, 'support_vectors_'):
        return int(np.size(model.support_vectors_))
    if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
        try:
            return sum(tree.tree_.node_count for tree in model.estimators_)
        except (AttributeError, TypeError):
            # e.g. the forest has not been fitted, so ``estimators_`` is missing.
            return "N/A (Complex Ensemble)"
    return "N/A"
64
-
65
def count_pytorch_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Only tensors yielded by ``model.parameters()`` whose ``requires_grad``
    flag is set contribute to the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
67
-
def get_temp_filepath(filename_base, extension):
    """Build a timestamped file path inside ``TEMP_DIR``.

    ``extension`` may be given with or without leading dots. The result has
    the form ``<TEMP_DIR>/<filename_base>_<YYYYmmdd-HHMMSS>.<extension>``,
    using the local time at the moment of the call.
    """
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    suffix = extension.lstrip('.')
    return os.path.join(TEMP_DIR, f"{filename_base}_{timestamp}.{suffix}")
71
 
72
- # --- PyTorch Model Definitions (SimpleMLP, SimpleCNN) ---
class SimpleMLP(nn.Module):
    """Fully connected network built from a comma-separated layer spec.

    Args:
        input_dim: Number of input features; must be a positive ``int``.
        hidden_layers_str: Comma-separated hidden layer widths such as
            ``"64,32"``. An empty, blank or non-string value yields a network
            with no hidden layers.
        output_dim: Width of the final linear layer.
        activation_fn_str: ``"relu"``, ``"tanh"`` or ``"sigmoid"``
            (case-insensitive); any other name falls back to ReLU.
        task_type: With ``"classification"`` and ``output_dim == 1`` a final
            ``Sigmoid`` is appended; in every other case the last linear
            layer's raw output is returned.

    Raises:
        ValueError: If ``input_dim`` is not a positive integer, or the hidden
            layer string does not parse into positive integers.
    """

    def __init__(self, input_dim, hidden_layers_str, output_dim, activation_fn_str="relu", task_type="classification"):
        super().__init__()
        if not (isinstance(input_dim, int) and input_dim > 0):
            raise ValueError(f"Input dimension must be a positive integer, got {input_dim}")

        widths = self._parse_hidden_layers(hidden_layers_str)
        activations = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}

        layers = []
        current_dim = input_dim
        for width in widths:
            layers.append(nn.Linear(current_dim, width))
            # Unknown activation names deliberately fall back to ReLU.
            layers.append(activations.get(activation_fn_str.lower(), nn.ReLU)())
            current_dim = width

        layers.append(nn.Linear(current_dim, output_dim))

        # A single classification output is squashed to a probability (for
        # BCELoss); multi-class outputs stay as raw logits because
        # nn.CrossEntropyLoss expects them.
        if task_type == "classification" and output_dim == 1:
            layers.append(nn.Sigmoid())
        self.network = nn.Sequential(*layers)

    @staticmethod
    def _parse_hidden_layers(spec):
        """Turn ``"64, 32"`` into ``[64, 32]``; blank or non-string input gives ``[]``."""
        if not (spec and isinstance(spec, str) and spec.strip()):
            return []
        try:
            widths = [int(token.strip()) for token in spec.split(',') if token.strip()]
            if any(width <= 0 for width in widths):
                raise ValueError("Hidden layer units must be positive integers.")
        except ValueError as e:
            raise ValueError(f"Invalid hidden layer string '{spec}'. Error: {e}")
        return widths

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.network(x)
106
-
107
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by two fully connected layers.

    Args:
        input_channels: Number of channels of the input images.
        img_size_wh: Image size tuple. NOTE: despite the name, it is unpacked
            as ``(height, width)``.
        num_classes: Width of the final linear layer.
        task_type: With ``"classification"`` and ``num_classes == 1`` the
            output passes through a ``Sigmoid``; otherwise it is returned
            unchanged.
        c_out1, k1, s1, p1: Output channels, kernel, stride and padding of the
            first convolution.
        pool1_k, pool1_s: Kernel and stride of the first max-pool.
        c_out2, k2, s2, p2: Same as above for the second convolution.
        pool2_k, pool2_s: Kernel and stride of the second max-pool.
        fc_hidden: Width of the hidden fully connected layer.

    Raises:
        ValueError: If the configured layers shrink the image to a
            non-positive height or width.
    """

    def __init__(self, input_channels, img_size_wh, num_classes=10, task_type="classification",
                 c_out1=16, k1=3, s1=1, p1=1, pool1_k=2, pool1_s=2,
                 c_out2=32, k2=3, s2=1, p2=1, pool2_k=2, pool2_s=2,
                 fc_hidden=128):
        super().__init__()
        self.input_channels = input_channels
        self.img_h, self.img_w = img_size_wh
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(self.input_channels, c_out1, kernel_size=k1, stride=s1, padding=p1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=pool1_k, stride=pool1_s)

        h_pool1 = self._out_len(self._out_len(self.img_h, k1, s1, p1), pool1_k, pool1_s)
        w_pool1 = self._out_len(self._out_len(self.img_w, k1, s1, p1), pool1_k, pool1_s)

        self.conv2 = nn.Conv2d(c_out1, c_out2, kernel_size=k2, stride=s2, padding=p2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=pool2_k, stride=pool2_s)

        h_pool2 = self._out_len(self._out_len(h_pool1, k2, s2, p2), pool2_k, pool2_s)
        w_pool2 = self._out_len(self._out_len(w_pool1, k2, s2, p2), pool2_k, pool2_s)

        self.flattened_size = c_out2 * h_pool2 * w_pool2
        # Check each dimension separately: two negative dimensions multiply to
        # a positive product and would otherwise slip through.
        if h_pool2 <= 0 or w_pool2 <= 0 or self.flattened_size <= 0:
            raise ValueError(
                f"Calculated flattened size is {self.flattened_size} "
                f"(final feature map {h_pool2}x{w_pool2}). Check CNN params and image size."
            )

        self.fc1 = nn.Linear(self.flattened_size, fc_hidden)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(fc_hidden, self.num_classes)

        # Single-output classification emits a probability (for BCELoss).
        # Multi-class classification and regression return the raw fc2 output;
        # nn.CrossEntropyLoss applies the softmax itself.
        if task_type == "classification" and self.num_classes == 1:
            self.final_activation = nn.Sigmoid()
        else:
            self.final_activation = nn.Identity()

    @staticmethod
    def _out_len(size, kernel, stride, padding=0):
        """Output length of a conv/pool layer along one spatial axis."""
        return (size - kernel + 2 * padding) // stride + 1

    def forward(self, x):
        """Apply both conv stages, flatten, then the two linear layers."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, self.flattened_size)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return self.final_activation(x)
160
-
161
- # --- Parameter Target Helpers ---
162
# Named parameter-count buckets, smallest first. Each label maps to the
# (min_params, max_params) bounds that the suggestion helpers aim for.
PARAM_RANGES = collections.OrderedDict((
    ("Tiny (<10k)", (0, 10_000)),
    ("Small (10k-50k)", (10_000, 50_000)),
    ("Medium (50k-250k)", (50_000, 250_000)),
    ("Large (250k-1M)", (250_000, 1_000_000)),
))
168
-
169
def suggest_mlp_layers_for_range(input_dim, output_dim, target_range_str, current_logs=""):
    """Suggest a hidden-layer spec whose parameter count fits a named range.

    Three strategies are tried in order: a single hidden layer, two hidden
    layers, and finally a rough fallback that is not checked against the range.

    Args:
        input_dim: Number of input features (must be positive).
        output_dim: Number of outputs (must be positive).
        target_range_str: A key of ``PARAM_RANGES``.
        current_logs: Log text to append to.

    Returns:
        ``(suggestion, logs)`` where ``suggestion`` is a comma-separated layer
        string such as ``"128,64"``, or ``""`` when the inputs are invalid.
    """
    logs = current_logs
    if not target_range_str or target_range_str not in PARAM_RANGES:
        return "", logs + "Invalid parameter range selected for MLP suggestion.\n"
    if input_dim <= 0 or output_dim <= 0:
        return "", logs + "Input/Output dims must be positive for MLP suggestion.\n"

    min_p, max_p = PARAM_RANGES[target_range_str]
    target_p_avg = (min_p + max_p) // 2
    # The tiny epsilon affects integer truncation below, so it is kept to leave
    # existing suggestions unchanged.
    io_dim = input_dim + output_dim + 1e-6

    # Strategy 1: one hidden layer sized so weights + biases land near the target.
    h1_candidate = max(1, int(target_p_avg / io_dim))
    params_1_layer = (input_dim * h1_candidate + h1_candidate) + (h1_candidate * output_dim + output_dim)
    if min_p <= params_1_layer <= max_p:
        logs += f"Suggested 1 hidden layer: {h1_candidate} units (Est. Params: {params_1_layer})\n"
        return str(h1_candidate), logs

    # Strategy 2: two hidden layers (each capped at 2048 units), skewed by the
    # input/output share of the total dimension.
    h_base = max(1, int(math.sqrt(target_p_avg / 2.0)))
    h1 = min(2048, max(1, int(h_base * (input_dim / io_dim) * 2 + h_base / 2)))
    h2 = min(2048, max(1, int(h_base * (output_dim / io_dim) * 2 + h_base / 2)))
    params_2_layers = (input_dim * h1 + h1) + (h1 * h2 + h2) + (h2 * output_dim + output_dim)
    if min_p <= params_2_layers <= max_p:
        logs += f"Suggested 2 hidden layers: {h1},{h2} units (Est. Params: {params_2_layers})\n"
        return f"{h1},{h2}", logs

    # Strategy 3: heuristic fallback; the result is not verified against the range.
    if target_p_avg < 50000:
        suggested_layers_str = str(max(1, int(target_p_avg / (input_dim + output_dim + 100))))
    elif target_p_avg < 250000:
        h = max(1, int(math.sqrt(target_p_avg / 1.5)))
        suggested_layers_str = f"{h},{h // 2}"
    else:
        h = max(1, int(math.sqrt(target_p_avg / 2.0)))
        suggested_layers_str = f"{h},{h},{h // 2}"
    logs += f"Fallback suggestion: {suggested_layers_str} (Verify params).\n"
    return suggested_layers_str, logs
199
-
200
def estimate_current_mlp_params(input_dim_str, hidden_layers_str, output_dim_str, task_type, current_logs=""):
    """Report the trainable parameter count of an MLP with the given shape.

    A throw-away ``SimpleMLP`` is built and its parameters counted.

    Returns:
        ``(text, logs)``. ``text`` is the count formatted with thousands
        separators, a validation message when a dimension is not positive, or
        ``"Error"`` when anything raises (the error is appended to ``logs``).
    """
    logs = current_logs
    try:
        input_dim = int(input_dim_str)
        output_dim = int(output_dim_str)
        if input_dim <= 0 or output_dim <= 0:
            return "Input/Output dims must be > 0", logs

        if task_type.endswith("Classification"):
            mlp_task_type = "classification"
        else:
            mlp_task_type = "regression"
        probe = SimpleMLP(input_dim, hidden_layers_str, output_dim, task_type=mlp_task_type)
        param_count = count_pytorch_parameters(probe)
        return f"{param_count:,}", logs
    except Exception as e:
        logs += f"Error estimating MLP params: {e}\n"
        return "Error", logs
211
-
212
def estimate_cnn_params(img_h_str, img_w_str, num_classes_str, task_type, current_logs=""):
    """Report the trainable parameter count of a single-channel ``SimpleCNN``.

    A throw-away ``SimpleCNN`` with one input channel and default layer
    settings is built for the given image size and class count.

    Returns:
        ``(text, logs)``. ``text`` is the count formatted with thousands
        separators, a validation message when a value is not positive, or
        ``"Error"`` when anything raises (the traceback is appended to ``logs``).
    """
    logs = current_logs
    try:
        img_h = int(img_h_str)
        img_w = int(img_w_str)
        num_classes_parsed = int(num_classes_str)
        if min(img_h, img_w, num_classes_parsed) <= 0:
            return "Image dims/classes must be > 0", logs

        if task_type.endswith("Classification"):
            cnn_task_type = "classification"
        else:
            cnn_task_type = "regression"
        probe = SimpleCNN(
            input_channels=1,
            img_size_wh=(img_h, img_w),
            num_classes=num_classes_parsed,
            task_type=cnn_task_type,
        )
        param_count = count_pytorch_parameters(probe)
        return f"{param_count:,}", logs
    except Exception:
        logs += f"Error estimating CNN params: {traceback.format_exc()}\n"
        return "Error", logs
223
-
224
-
225
- # --- Dataset and Preprocessing ---
226
def generate_dataset_backend(task_type, n_samples_str, n_features_str,
                             n_classes_or_informative_str, dataset_format,
                             ai_suggest_ds_shape, target_param_range_str, model_type_selection,
                             current_logs=""):
    """Generate a synthetic dataset and save it under ``TEMP_DIR``.

    Args:
        task_type: ``"Tabular Classification"``, ``"Tabular Regression"`` or
            ``"Basic Image Classification"``.
        n_samples_str, n_features_str, n_classes_or_informative_str: Requested
            sizes; each must be convertible with ``int()``. The third value is
            the class count for classification and the informative-feature
            count for regression.
        dataset_format: ``".csv"``, ``".json"`` (JSON lines) or ``".parquet"``;
            any other string falls back to CSV.
        ai_suggest_ds_shape: When truthy, the requested sizes are replaced by
            values derived from the task and the target parameter range.
        target_param_range_str: Key of ``PARAM_RANGES`` used by the suggestion.
        model_type_selection: Model label; the parameter range is only used
            for the suggestion when the label contains ``"Network"``.
        current_logs: Log text to append to.

    Returns:
        ``(preview, data, logs, file_path)``. On success ``preview`` is
        ``df.head()`` and ``data`` is the full DataFrame. On failure
        ``preview`` and ``file_path`` are ``None`` and ``data`` holds an error
        string.
    """
    logs = current_logs + "\n--- Generating Dataset ---\n"
    try:
        n_samples = int(n_samples_str)
        n_features = int(n_features_str)
        n_classes_or_informative = int(n_classes_or_informative_str)
    except (TypeError, ValueError):
        # TypeError covers values int() cannot accept at all, such as None.
        logs += "Invalid numbers for dataset generation.\n"
        return None, "Error", logs, None

    if ai_suggest_ds_shape:
        n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg = 5000, 10, 2
        if task_type == "Tabular Regression":
            n_classes_or_informative_sugg = min(n_features_sugg // 2, 5) if n_features_sugg > 0 else 1
        elif task_type == "Basic Image Classification":
            n_samples_sugg, n_features_sugg = 500, 0

        # Treat a missing model selection as "not a neural network".
        is_nn = "Network" in (model_type_selection or "")
        if is_nn and target_param_range_str in PARAM_RANGES:
            min_p, max_p = PARAM_RANGES[target_param_range_str]
            avg_p = (min_p + max_p) / 2
            is_tabular = task_type.startswith("Tabular")
            if avg_p > 200000:
                # Larger models get more rows and, for tabular data, more columns.
                n_samples_sugg = min(MAX_GENERATED_ROWS, n_samples_sugg * 3)
                if is_tabular:
                    n_features_sugg = min(MAX_GENERATED_COLS, n_features_sugg * 2)
            elif avg_p < 50000:
                n_samples_sugg = max(200, n_samples_sugg // 2)
                if is_tabular:
                    n_features_sugg = max(3, n_features_sugg // 2)

        n_samples, n_features, n_classes_or_informative = n_samples_sugg, n_features_sugg, n_classes_or_informative_sugg
        logs += f"AI Suggested Dataset: Samples={n_samples}, Feats={n_features}, Classes/Informative={n_classes_or_informative}\n"

    n_samples = max(10, min(n_samples, MAX_GENERATED_ROWS))
    if task_type.startswith("Tabular"):
        n_features = max(1, min(n_features, MAX_GENERATED_COLS))
    if n_samples > MAX_DATASET_ROWS_WARN:
        logs += f"Warning: Generating {n_samples} rows. May be slow.\n"

    try:
        if task_type == "Tabular Classification":
            n_cls = max(2, n_classes_or_informative)
            wanted_inf = n_classes_or_informative if n_classes_or_informative >= n_cls else n_features // 2
            n_inf = max(1, min(n_features, wanted_inf))
            # make_classification requires
            # n_classes * n_clusters_per_class <= 2 ** n_informative; drop to one
            # cluster per class when the default of two would violate that.
            n_clusters = 2 if n_cls * 2 <= 2 ** n_inf else 1
            X_data, y_data = make_classification(
                n_samples=n_samples, n_features=n_features, n_informative=n_inf,
                n_redundant=max(0, n_features - n_inf) // 2, n_classes=n_cls,
                n_clusters_per_class=n_clusters, flip_y=0.05, random_state=42)
            df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)])
            df['target'] = y_data
        elif task_type == "Tabular Regression":
            n_inf = max(1, min(n_features, n_classes_or_informative))
            X_data, y_data = make_regression(n_samples=n_samples, n_features=n_features,
                                             n_informative=n_inf, noise=10, random_state=42)
            df = pd.DataFrame(X_data, columns=[f'feature_{i}' for i in range(n_features)])
            df['target'] = y_data
        elif task_type == "Basic Image Classification":
            # Placeholder data: random pixel values and random labels.
            img_h, img_w = 28, 28
            num_pixels = img_h * img_w
            X_data = np.random.randint(0, 256, size=(n_samples, num_pixels), dtype=np.uint8)
            y_data = np.random.randint(0, max(2, n_classes_or_informative), n_samples)
            df = pd.DataFrame(X_data, columns=[f'pixel_{i}' for i in range(num_pixels)])
            df['target'] = y_data
            logs += f"Generated {img_h}x{img_w} Image placeholder data.\n"
        else:
            logs += f"Dataset generation for '{task_type}' not fully implemented.\n"
            return None, "Task not implemented", logs, None

        logs += f"Generated data: {df.shape}\n"

        file_path = get_temp_filepath("generated_dataset", dataset_format)
        if dataset_format == ".csv":
            df.to_csv(file_path, index=False)
        elif dataset_format == ".json":
            df.to_json(file_path, orient='records', lines=True)
        elif dataset_format == ".parquet":
            df.to_parquet(file_path, index=False)
        else:
            logs += f"Unsupported format {dataset_format}. Defaulting to CSV.\n"
            file_path = get_temp_filepath("generated_dataset", "csv")
            df.to_csv(file_path, index=False)
        logs += f"Dataset saved to {file_path}\n"
        return df.head(), df, logs, file_path
    except Exception:
        # Top-level boundary for the UI: report the failure instead of raising.
        error_msg = f"Error generating dataset: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return None, error_msg, logs, None
290
-
291
def preprocess_tabular_data(df_or_X, y_if_X_is_numpy, target_column_name, task_type, current_logs=""):
    """Fit a preprocessing pipeline on tabular data and encode the target.

    Numeric columns are mean-imputed and standardised. ``object`` and
    ``category`` columns are imputed with the most frequent value and one-hot
    encoded. Any remaining column is passed through unchanged.

    Args:
        df_or_X: A DataFrame that contains the target column, or a 2-D NumPy
            feature array.
        y_if_X_is_numpy: Target values; required when ``df_or_X`` is an array.
        target_column_name: Name of the target column (DataFrame input only).
        task_type: Task label; a value ending in ``"Classification"`` makes
            the target label-encoded, anything else casts it to float.
        current_logs: Log text to append to.

    Returns:
        ``(X_processed, y_processed, preprocessor, logs, input_dim, output_dim)``
        where ``output_dim`` is 1 for regression and two-class problems and
        the number of classes otherwise.

    Raises:
        ValueError: If the target column is missing or the input types are
            not supported.
    """
    logs = current_logs
    if isinstance(df_or_X, pd.DataFrame):
        if target_column_name not in df_or_X.columns:
            raise ValueError(f"Target column '{target_column_name}' not found.")
        X_df = df_or_X.drop(target_column_name, axis=1)
        y_series = df_or_X[target_column_name]
    elif isinstance(df_or_X, np.ndarray) and y_if_X_is_numpy is not None:
        X_df = pd.DataFrame(df_or_X, columns=[f'feature_{i}' for i in range(df_or_X.shape[1])])
        y_series = pd.Series(y_if_X_is_numpy)
    else:
        raise ValueError("Invalid input for preprocess_tabular_data.")

    numerical_features = X_df.select_dtypes(include=np.number).columns.tolist()
    # 'category' columns need one-hot encoding too; otherwise they would reach
    # the model as raw labels through the passthrough remainder.
    categorical_features = X_df.select_dtypes(include=['object', 'category']).columns.tolist()
    logs += f"Numerical: {numerical_features}, Categorical: {categorical_features}\n"

    numeric_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
    ])
    categorical_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False)),
    ])
    preprocessor = ColumnTransformer(transformers=[
        ('num', numeric_pipeline, numerical_features),
        ('cat', categorical_pipeline, categorical_features),
    ], remainder='passthrough')

    X_processed_np = preprocessor.fit_transform(X_df)
    processed_input_dim = X_processed_np.shape[1]
    logs += f"Tabular data preprocessed. X shape: {X_processed_np.shape}, Processed input dim: {processed_input_dim}\n"

    if task_type.endswith("Classification"):
        le = LabelEncoder()
        y_processed_np = le.fit_transform(y_series)
        num_classes = len(le.classes_)
        logs += f"Target encoded. Classes: {num_classes} ({le.classes_})\n"
        # Two classes are modelled with a single output unit.
        output_dim_nn = 1 if num_classes == 2 else num_classes
    else:  # Regression
        y_processed_np = y_series.astype(float).values
        output_dim_nn = 1

    return X_processed_np, y_processed_np, preprocessor, logs, processed_input_dim, output_dim_nn
328
-
329
-
330
- # --- Training Functions ---
331
def _sklearn_estimator_for(task_type, model_name):
    """Return an unfitted estimator for the task/model pair, or ``None`` if unsupported."""
    factories = {
        ("Tabular Classification", "Logistic Regression"):
            lambda: LogisticRegression(max_iter=1000, random_state=42),
        ("Tabular Classification", "Random Forest Classifier"):
            lambda: RandomForestClassifier(random_state=42),
        ("Tabular Classification", "Support Vector Machine (SVM) Classifier"):
            lambda: SVC(random_state=42, probability=True),
        ("Tabular Regression", "Linear Regression"):
            lambda: LinearRegression(),
        ("Tabular Regression", "Random Forest Regressor"):
            lambda: RandomForestRegressor(random_state=42),
        ("Tabular Regression", "Support Vector Machine (SVR) Regressor"):
            lambda: SVR(),
    }
    factory = factories.get((task_type, model_name))
    return None if factory is None else factory()


def train_model_sklearn(data_input_obj, target_column, task_type, model_name, model_output_format, current_logs=""):
    """Train, evaluate and export a scikit-learn model on tabular data.

    Args:
        data_input_obj: Path to a ``.csv``, ``.json`` (JSON lines) or
            ``.parquet`` file, or an in-memory DataFrame.
        target_column: Name of the target column.
        task_type: ``"Tabular Classification"`` or ``"Tabular Regression"``.
        model_name: Display name of the estimator to train.
        model_output_format: ``".pkl (Scikit-learn)"`` or ``".onnx (ONNX)"``;
            any other value is logged and saved as a pickle.
        current_logs: Log text to append to.

    Returns:
        ``(logs, metrics_text, model_path, param_estimate)``. ``model_path``
        is ``None`` when nothing was saved; failures are reported through
        ``logs`` and ``metrics_text`` instead of being raised.
    """
    logs = current_logs + f"\n--- Training Scikit-learn Model: {model_name} ---\n"
    model_path_out, metrics_out, model_params_out = None, "Training failed.", "N/A"

    # --- Load the data ---
    if isinstance(data_input_obj, pd.DataFrame):
        df = data_input_obj
    elif isinstance(data_input_obj, str):
        readers = {
            '.csv': pd.read_csv,
            '.json': lambda path: pd.read_json(path, lines=True),
            '.parquet': pd.read_parquet,
        }
        reader = next((fn for ext, fn in readers.items() if data_input_obj.endswith(ext)), None)
        if reader is None:
            logs += f"Unsupported file: {data_input_obj}\n"
            return logs, "Error: Unsupported file.", None, "N/A"
        try:
            df = reader(data_input_obj)
        except Exception as e:
            logs += f"Error reading {data_input_obj}: {e}\n"
            return logs, f"Error reading: {e}", None, "N/A"
    else:
        logs += "Invalid data for training.\n"
        return logs, "Error: Invalid data.", None, "N/A"

    if not target_column or target_column not in df.columns:
        logs += f"Target column '{target_column}' not provided or not found.\n"
        return logs, f"Error: Target '{target_column}' not found/provided.", None, "N/A"

    # --- Preprocess and split ---
    try:
        X_processed_np, y_processed_np, preprocessor, logs, _, _ = preprocess_tabular_data(
            df, None, target_column, task_type, logs)
    except ValueError as e:
        logs += f"Preprocessing error: {e}\n"
        return logs, f"Error: {e}", None, "N/A"

    X_train, X_test, y_train, y_test = train_test_split(
        X_processed_np, y_processed_np, test_size=0.2, random_state=42)
    logs += f"Train/Test split. Train: {X_train.shape}, Test: {X_test.shape}\n"

    model = _sklearn_estimator_for(task_type, model_name)
    if model is None:
        logs += f"Model {model_name} or task {task_type} not supported.\n"
        return logs, "Model/Task Error", None, "N/A"

    # --- Train, evaluate, export ---
    try:
        logs += f"Starting training for {model_name}...\n"
        start_time = time.time()
        model.fit(X_train, y_train)
        logs += f"Training completed in {time.time() - start_time:.2f}s.\n"
        model_params_out = str(count_sklearn_parameters(model))
        logs += f"Est. Model Params: {model_params_out}\n"
        y_pred = model.predict(X_test)

        if task_type == "Tabular Classification":
            acc = accuracy_score(y_test, y_pred)
            report = classification_report(y_test, y_pred, zero_division=0)
            metrics_out = f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
        elif task_type == "Tabular Regression":
            mse = mean_squared_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
            metrics_out = f"Mean Squared Error: {mse:.4f}\nR2 Score: {r2:.4f}"
        logs += "\n--- Evaluation Metrics ---\n" + metrics_out + "\n"

        # The saved artifact bundles the fitted preprocessor with the model.
        full_pipeline_for_saving = Pipeline([('preprocessor', preprocessor), ('model', model)])
        model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"

        if model_output_format == ".pkl (Scikit-learn)":
            model_path_out = get_temp_filepath(model_filename_base, "pkl")
            joblib.dump(full_pipeline_for_saving, model_path_out)
            logs += f"Model (with preprocessor) saved to {model_path_out} as PKL.\n"
        elif model_output_format == ".onnx (ONNX)":
            model_path_out = get_temp_filepath(model_filename_base, "onnx")
            # One ONNX input per raw feature column, typed from its dtype.
            raw_X_for_types_df = df.drop(target_column, axis=1).infer_objects()
            onnx_initial_types = []
            for col_name in raw_X_for_types_df.columns:
                col_dtype = raw_X_for_types_df[col_name].dtype
                if pd.api.types.is_numeric_dtype(col_dtype):
                    tensor_type = FloatTensorType([None, 1])
                elif pd.api.types.is_string_dtype(col_dtype) or col_dtype == 'object':
                    tensor_type = StringTensorType([None, 1])
                else:
                    logs += f"Warning: Unsupported dtype {col_dtype} for {col_name} in ONNX. Defaulting to Float.\n"
                    tensor_type = FloatTensorType([None, 1])
                onnx_initial_types.append((col_name, tensor_type))

            if not onnx_initial_types:
                raise ValueError("ONNX initial types failed: No valid columns found.")
            try:
                if task_type.endswith("Classification"):
                    options = {id(full_pipeline_for_saving): {'zipmap': False}}
                else:
                    options = {}
                onnx_model = convert_sklearn(full_pipeline_for_saving, initial_types=onnx_initial_types,
                                             target_opset=12, options=options)
                with open(model_path_out, "wb") as f:
                    f.write(onnx_model.SerializeToString())
                logs += f"Model saved to {model_path_out} as ONNX.\n"
                # Loading the file back checks that ONNX Runtime accepts it.
                rt.InferenceSession(model_path_out, providers=rt.get_available_providers())
                logs += f"ONNX model loaded successfully with ONNX Runtime.\n"
            except Exception:
                logs += f"ONNX Error: {traceback.format_exc()}\n"
                model_path_out = None
                metrics_out += "\nONNX EXPORT FAILED."
        else:
            logs += f"Unsupported format '{model_output_format}'. Saving as .pkl\n"
            model_path_out = get_temp_filepath(model_filename_base, "pkl")
            joblib.dump(full_pipeline_for_saving, model_path_out)
    except Exception:
        error_msg = f"Sklearn training error: {traceback.format_exc()}"
        logs += error_msg + "\n"
        metrics_out = error_msg
    return logs, metrics_out, model_path_out, model_params_out
418
-
419
def train_model_pytorch(data_input_obj, target_column, task_type, model_type_pt,
                        mlp_hidden_layers_str, mlp_activation,
                        epochs_str, batch_size_str, lr_str,
                        model_output_format, current_logs=""):
    """Train, evaluate and save a ``SimpleMLP`` or ``SimpleCNN``.

    Args:
        data_input_obj: Path to a ``.csv``, ``.json`` (JSON lines) or
            ``.parquet`` file, a DataFrame, or an ``(X, y)`` tuple.
        target_column: Name of the target column for DataFrame/file input.
        task_type: Task label; a value ending in ``"Classification"`` selects
            classification losses and metrics, otherwise regression.
        model_type_pt: ``"Simple Neural Network (MLP)"`` or
            ``"Simple Convolutional Network (CNN)"``.
        mlp_hidden_layers_str, mlp_activation: Forwarded to ``SimpleMLP``.
        epochs_str, batch_size_str, lr_str: Hyper-parameters; each must parse
            to a positive number.
        model_output_format: Only ``".pt (PyTorch)"`` is saved; other values
            are logged and nothing is written.
        current_logs: Log text to append to.

    Returns:
        ``(logs, metrics_text, model_path, param_count_text, loss_figure)``.
        Input problems are reported through ``logs`` and ``metrics_text`` with
        the remaining fields set to ``None`` / ``"N/A"``.
    """
    logs = current_logs + f"\n--- Training PyTorch Model: {model_type_pt} ---\n"
    model_path_out, metrics_out, model_params_out, plot_out = None, "Training failed.", "N/A", None

    # --- Load the data ---
    df_for_pytorch, X_numpy_for_pytorch, y_numpy_for_pytorch = None, None, None
    if isinstance(data_input_obj, str):
        try:
            if data_input_obj.endswith('.csv'):
                df_for_pytorch = pd.read_csv(data_input_obj)
            elif data_input_obj.endswith('.json'):
                df_for_pytorch = pd.read_json(data_input_obj, lines=True)
            elif data_input_obj.endswith('.parquet'):
                df_for_pytorch = pd.read_parquet(data_input_obj)
            else:
                logs += f"Unsupported file: {data_input_obj}\n"
                return logs, "Error", None, "N/A", None
        except Exception as e:
            logs += f"Error reading {data_input_obj}: {e}\n"
            return logs, f"Error: {e}", None, "N/A", None
    elif isinstance(data_input_obj, pd.DataFrame):
        df_for_pytorch = data_input_obj
    elif isinstance(data_input_obj, tuple):
        X_numpy_for_pytorch, y_numpy_for_pytorch = data_input_obj
    else:
        logs += "Invalid data for PyTorch training.\n"
        return logs, "Error", None, "N/A", None

    # --- Parse hyper-parameters ---
    try:
        epochs = int(epochs_str)
        batch_size = int(batch_size_str)
        lr = float(lr_str)
        if not (epochs > 0 and batch_size > 0 and lr > 0):
            raise ValueError("Params must be >0.")
    except (TypeError, ValueError) as e:
        # TypeError covers values that cannot be converted at all, such as None.
        logs += f"Invalid training params: {e}\n"
        return logs, f"Error: {e}", None, "N/A", None

    processed_input_dim_actual, nn_output_dim_actual, preprocessor_pipeline = -1, -1, None
    X_processed_np, y_processed_np = None, None

    # --- Prepare features and targets ---
    if model_type_pt == "Simple Neural Network (MLP)":
        if not task_type.startswith("Tabular"):
            logs += "MLP requires Tabular task.\n"
            return logs, "MLP Task Error", None, "N/A", None
        if not target_column and df_for_pytorch is not None:
            logs += "Target column needed for MLP with DataFrame.\n"
            return logs, "MLP Target Error", None, "N/A", None
        try:
            data_arg1 = df_for_pytorch if df_for_pytorch is not None else X_numpy_for_pytorch
            data_arg2 = y_numpy_for_pytorch if df_for_pytorch is None else None
            current_target_col = target_column if df_for_pytorch is not None else "target"
            X_processed_np, y_processed_np, preprocessor_pipeline, logs, processed_input_dim_actual, nn_output_dim_actual = \
                preprocess_tabular_data(data_arg1, data_arg2, current_target_col, task_type, logs)
        except ValueError as e:
            logs += f"MLP Preprocessing error: {e}\n"
            return logs, f"Error: {e}", None, "N/A", None
    elif model_type_pt == "Simple Convolutional Network (CNN)":
        if df_for_pytorch is not None:
            # Validate the target first so a bad name is reported, not raised.
            if not target_column or target_column not in df_for_pytorch.columns:
                logs += f"Target column '{target_column}' not provided or not found.\n"
                return logs, "CNN Target Error", None, "N/A", None
            X_raw = df_for_pytorch.drop(target_column, axis=1).values
            y_raw = df_for_pytorch[target_column].values
        else:
            X_raw, y_raw = X_numpy_for_pytorch, y_numpy_for_pytorch
        if X_raw is None:
            logs += "No valid data for CNN.\n"
            return logs, "CNN Data Error", None, "N/A", None

        # Each row is a flattened single-channel square image. A row length
        # that is not a perfect square cannot be reshaped, so stop here.
        pixels_per_sample = X_raw.shape[1]
        input_channels = 1
        img_dim_approx = int(math.sqrt(pixels_per_sample))
        if img_dim_approx * img_dim_approx != pixels_per_sample:
            logs += (f"Error: Cannot infer square image from {pixels_per_sample} pixels; "
                     "CNN input must be a flattened square image.\n")
            return logs, "CNN Data Error", None, "N/A", None
        img_h = img_w = img_dim_approx

        le = LabelEncoder()
        y_processed_np = le.fit_transform(y_raw)
        num_classes = len(le.classes_)
        nn_output_dim_actual = 1 if num_classes == 2 else num_classes

        X_processed_np = X_raw.reshape(-1, input_channels, img_h, img_w).astype(np.float32) / 255.0
        processed_input_dim_actual = (input_channels, img_h, img_w)
        logs += f"CNN Data: X reshaped to {X_processed_np.shape}, y: {y_processed_np.shape}\n"
    else:
        logs += f"Unknown PyTorch model: {model_type_pt}\n"
        return logs, "Unknown PyTorch model", None, "N/A", None

    X_train, X_test, y_train, y_test = train_test_split(X_processed_np, y_processed_np, test_size=0.2, random_state=42)
    logs += f"PyTorch Train/Test split. Train: {X_train.shape}, Test: {X_test.shape}\n"

    # Single-output models (BCELoss / MSELoss) need float targets of shape
    # [N, 1]; multi-class models (CrossEntropyLoss) need integer class indices.
    single_output = nn_output_dim_actual == 1
    target_dtype = torch.float32 if single_output else torch.long
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=target_dtype)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=target_dtype)
    if single_output:
        y_train_tensor = y_train_tensor.unsqueeze(1)
        y_test_tensor = y_test_tensor.unsqueeze(1)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # --- Build the model ---
    is_classification = task_type.endswith("Classification")
    model_task = "classification" if is_classification else "regression"
    pytorch_model = None
    try:
        if model_type_pt == "Simple Neural Network (MLP)":
            pytorch_model = SimpleMLP(input_dim=processed_input_dim_actual, hidden_layers_str=mlp_hidden_layers_str,
                                      output_dim=nn_output_dim_actual, activation_fn_str=mlp_activation,
                                      task_type=model_task)
        elif model_type_pt == "Simple Convolutional Network (CNN)":
            channels, h, w = processed_input_dim_actual
            pytorch_model = SimpleCNN(input_channels=channels, img_size_wh=(h, w), num_classes=nn_output_dim_actual,
                                      task_type=model_task)
    except Exception as model_e:
        logs += f"Error creating PyTorch model: {traceback.format_exc()}\n"
        return logs, f"Model Creation Error: {model_e}", None, "N/A", None

    if pytorch_model is None:
        logs += "Failed to instantiate PyTorch model.\n"
        return logs, "Model instantiate fail", None, "N/A", None
    model_params_val = count_pytorch_parameters(pytorch_model)
    model_params_out = f"{model_params_val:,}"
    logs += f"PyTorch Model: {model_params_out} params.\n"

    if is_classification and single_output:
        criterion = nn.BCELoss()
    elif is_classification:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.MSELoss()
    optimizer = optim.Adam(pytorch_model.parameters(), lr=lr)

    # --- Train ---
    logs += f"Starting PyTorch training for {epochs} epochs...\n"
    start_time = time.time()
    epoch_losses = []
    log_every = max(1, epochs // 10)  # log roughly ten times per run
    pytorch_model.train()
    for epoch in range(epochs):
        epoch_loss_sum = 0.0
        for batch_X, batch_y in train_dataloader:
            optimizer.zero_grad()
            outputs = pytorch_model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss_sum += loss.item()
        avg_epoch_loss = epoch_loss_sum / len(train_dataloader) if len(train_dataloader) > 0 else 0
        epoch_losses.append(avg_epoch_loss)
        if (epoch + 1) % log_every == 0 or epoch == epochs - 1:
            logs += f"Epoch {epoch+1}/{epochs}, Avg Training Loss: {avg_epoch_loss:.4f}\n"

    logs += f"PyTorch training completed in {time.time() - start_time:.2f}s.\n"

    # --- Evaluate on the held-out split ---
    pytorch_model.eval()
    with torch.no_grad():
        test_outputs = pytorch_model(X_test_tensor)
        y_true_np = y_test_tensor.cpu().numpy()
        if is_classification:
            if single_output:
                predicted = (test_outputs > 0.5).float()
            else:
                predicted = test_outputs.argmax(dim=1)
            predicted_np = predicted.cpu().numpy()
            acc = accuracy_score(y_true_np, predicted_np)
            report = classification_report(y_true_np, predicted_np, zero_division=0)
            metrics_out = f"Final Avg Training Loss: {avg_epoch_loss:.4f}\n\n--- Test Set Evaluation ---\nAccuracy: {acc:.4f}\n\nClassification Report:\n{report}"
        else:  # Regression
            outputs_np = test_outputs.cpu().numpy()
            mse = mean_squared_error(y_true_np, outputs_np)
            r2 = r2_score(y_true_np, outputs_np)
            metrics_out = f"Final Avg Training Loss: {avg_epoch_loss:.4f}\n\n--- Test Set Evaluation ---\nMean Squared Error: {mse:.4f}\nR2 Score: {r2:.4f}"
    logs += "\n--- PyTorch Metrics ---\n" + metrics_out + "\n"

    if epoch_losses:
        fig, ax = plt.subplots()
        ax.plot(range(1, epochs + 1), epoch_losses, marker='o')
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Average Loss")
        ax.set_title("Training Loss Curve")
        plot_out = fig
        logs += "Loss curve generated.\n"

    # --- Save ---
    model_filename_base = f"pytorch_{model_type_pt.replace(' ', '_').lower()}"
    if model_output_format == ".pt (PyTorch)":
        model_path_out = get_temp_filepath(model_filename_base, "pt")
        # Store what is needed to rebuild the architecture next to the weights.
        save_obj = {'model_state_dict': pytorch_model.state_dict(), 'output_dim_nn': nn_output_dim_actual, 'task_type': task_type}
        if model_type_pt == "Simple Neural Network (MLP)" and preprocessor_pipeline:
            save_obj.update({'preprocessor': preprocessor_pipeline, 'input_dim_processed': processed_input_dim_actual,
                             'hidden_layers_str': mlp_hidden_layers_str, 'activation_fn': mlp_activation})
        elif model_type_pt == "Simple Convolutional Network (CNN)":
            c, h, w = processed_input_dim_actual
            save_obj.update({'input_channels': c, 'img_h': h, 'img_w': w})
        torch.save(save_obj, model_path_out)
        logs += f"PyTorch model saved to {model_path_out}\n"
    else:  # Fallback: nothing is written for unknown formats.
        logs += f"Unsupported format '{model_output_format}'.\n"
    return logs, metrics_out, model_path_out, model_params_out, plot_out
556
-
557
-
558
- # --- Main Training Wrapper Function ---
559
def train_model_wrapper(data_input_obj, target_column, task_type, model_family,
                        model_specific_choice,
                        mlp_hidden_layers, mlp_activation,
                        epochs, batch_size, learning_rate,
                        model_output_format, current_logs):
    """Route a training request to the Scikit-learn or PyTorch trainer.

    Returns a 5-tuple ``(logs, metrics, model_path, param_count, loss_plot)``.
    ``loss_plot`` is ``None`` for Scikit-learn runs and for every error path;
    ``param_count`` is ``"N/A"`` on every error path.
    """
    not_available = "N/A"
    logs = current_logs + "\n--- Kicking off Training ---\n"

    # Guard clause: nothing to train on.
    if data_input_obj is None:
        logs += "ERROR: No dataset has been generated or uploaded. Please go to Tab 2.\n"
        return logs, "Error: No dataset available.", None, not_available, None

    try:
        if model_family == "Scikit-learn (Classical ML)":
            logs, metrics, model_path, param_count = train_model_sklearn(
                data_input_obj=data_input_obj,
                target_column=target_column,
                task_type=task_type,
                model_name=model_specific_choice,
                model_output_format=model_output_format,
                current_logs=logs,
            )
            return logs, metrics, model_path, param_count, None

        if model_family == "PyTorch (Neural Networks)":
            # The PyTorch trainer takes its numeric hyper-parameters as strings.
            pytorch_kwargs = dict(
                data_input_obj=data_input_obj,
                target_column=target_column,
                task_type=task_type,
                model_type_pt=model_specific_choice,
                mlp_hidden_layers_str=mlp_hidden_layers,
                mlp_activation=mlp_activation,
                epochs_str=str(int(epochs)),
                batch_size_str=str(int(batch_size)),
                lr_str=str(learning_rate),
                model_output_format=model_output_format,
                current_logs=logs,
            )
            logs, metrics, model_path, param_count, loss_plot = train_model_pytorch(**pytorch_kwargs)
            return logs, metrics, model_path, param_count, loss_plot

        logs += f"Unknown model family: {model_family}\n"
        return logs, "Error: Unknown model family.", None, not_available, None
    except Exception:
        # Top-level UI boundary: surface the traceback in the log instead of crashing.
        error_msg = f"An unexpected error occurred in the training wrapper: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return logs, error_msg, None, not_available, None
603
-
604
 
605
- # --- Gradio UI Definition ---
606
- TASK_CHOICES = ["Tabular Classification", "Tabular Regression", "Basic Image Classification"]
607
- MODEL_FAMILIES = ["Scikit-learn (Classical ML)", "PyTorch (Neural Networks)"]
608
- SKLEARN_MODELS_CLASSIFICATION = ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]
609
- SKLEARN_MODELS_REGRESSION = ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]
610
- PYTORCH_MODELS = ["Simple Neural Network (MLP)", "Simple Convolutional Network (CNN)"]
611
- DATASET_FORMATS = [".csv", ".json", ".parquet"]
612
- MODEL_OUTPUT_FORMATS_SKLEARN = [".pkl (Scikit-learn)", ".onnx (ONNX)"]
613
- MODEL_OUTPUT_FORMATS_PYTORCH = [".pt (PyTorch)"]
614
- MLP_ACTIVATIONS = ["relu", "tanh", "sigmoid"]
615
- CLONE_GUIDE_TEXT = """
616
- ## How to Clone & Upgrade This Space for More Power:
617
- 1. **Clone this Space:** Click the '...' menu at the top-right and choose 'Duplicate this Space'.
618
- 2. **Choose Hardware:** On the duplication screen, select a more powerful hardware option, like a "CPU upgrade" or a "T4 Small" GPU.
619
- 3. **Enjoy Faster Training:** Your private, upgraded version of TrainAI will now train models significantly faster!
620
- """
621
 
622
- # Determine initial choices for model_specific_dd based on default task_type and model_family
623
- _initial_task_default = TASK_CHOICES[0]
624
- _initial_family_default = MODEL_FAMILIES[0]
625
- if _initial_family_default == "Scikit-learn (Classical ML)":
626
- initial_model_choices_for_specific_dd = SKLEARN_MODELS_CLASSIFICATION
627
- elif _initial_family_default == "PyTorch (Neural Networks)":
628
- initial_model_choices_for_specific_dd = [PYTORCH_MODELS[0]]
629
- initial_model_value_for_specific_dd = initial_model_choices_for_specific_dd[0]
 
 
 
 
630
 
 
631
def update_model_options(task_choice, model_family_choice):
    """Return a Gradio update for the specific-model dropdown.

    The dropdown is hidden when the task/family combination has no models.
    """
    available = []
    if model_family_choice == "Scikit-learn (Classical ML)":
        if task_choice == "Tabular Classification":
            available = SKLEARN_MODELS_CLASSIFICATION
        elif task_choice == "Tabular Regression":
            available = SKLEARN_MODELS_REGRESSION
    elif model_family_choice == "PyTorch (Neural Networks)":
        if task_choice.startswith("Tabular"):
            available = [PYTORCH_MODELS[0]]  # MLP
        elif task_choice == "Basic Image Classification":
            available = [PYTORCH_MODELS[1]]  # CNN

    if available:
        return gr.update(choices=available, value=available[0], visible=True)
    return gr.update(choices=available, value=None, visible=False)
641
 
642
def update_pytorch_specific_options_visibility(model_family_choice, specific_pytorch_model):
    """Return visibility updates for the PyTorch, MLP and CNN option groups, in that order."""
    pytorch_selected = model_family_choice == "PyTorch (Neural Networks)"
    show_mlp = pytorch_selected and (specific_pytorch_model == "Simple Neural Network (MLP)")
    show_cnn = pytorch_selected and (specific_pytorch_model == "Simple Convolutional Network (CNN)")
    return tuple(gr.update(visible=flag) for flag in (pytorch_selected, show_mlp, show_cnn))
647
-
648
def update_model_output_formats(model_family_choice):
    """Return a Gradio update listing the export formats for the chosen model family."""
    formats_by_family = (
        ("Scikit-learn (Classical ML)", MODEL_OUTPUT_FORMATS_SKLEARN),
        ("PyTorch (Neural Networks)", MODEL_OUTPUT_FORMATS_PYTORCH),
    )
    for family, formats in formats_by_family:
        if model_family_choice == family:
            return gr.update(choices=formats, value=formats[0])
    # Unknown family: offer nothing.
    return gr.update(choices=[], value=None)
652
-
653
- css = """.gradio-container { font-family: 'IBM Plex Sans', sans-serif; } footer {display:none !important}"""
 
 
 
654
 
655
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange"), css=css) as demo:
 
656
  gr.Markdown("# 🧠 TrainAI ⚙️")
657
- gr.Markdown("Create, train, and download AI models. Optimized for CPU - expect longer training for complex models.")
658
- generated_data_state = gr.State(None)
659
- current_logs_state = gr.State("")
660
-
 
661
  with gr.Tabs():
662
  with gr.TabItem("1. Define Task & Model"):
663
  with gr.Row():
664
- task_type_dd = gr.Dropdown(TASK_CHOICES, label="Select Task Type", value=_initial_task_default)
665
- model_family_dd = gr.Dropdown(MODEL_FAMILIES, label="Select Model Family", value=_initial_family_default)
666
- model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=initial_model_choices_for_specific_dd, value=initial_model_value_for_specific_dd, interactive=True)
667
 
668
- with gr.Group(visible=(_initial_family_default == "PyTorch (Neural Networks)")) as pt_options_group:
669
- pytorch_param_range_dd = gr.Dropdown(list(PARAM_RANGES.keys()), label="Target Parameter Range (for NNs)",
670
- info="Guides NN architecture suggestions. Training >250k params on CPU is slow.",
671
- value=list(PARAM_RANGES.keys())[1])
672
- with gr.Group(visible=(initial_model_value_for_specific_dd == PYTORCH_MODELS[0])) as pt_mlp_specific_group:
673
- gr.Markdown("#### MLP Configuration")
674
- pt_mlp_hidden_layers_txt = gr.Textbox(label="Hidden Layer Sizes (comma-separated, e.g., 128,64)", value="64,32")
675
- pt_mlp_activation_dd = gr.Dropdown(MLP_ACTIVATIONS, label="Activation Function", value="relu")
676
- with gr.Row():
677
- pt_mlp_suggest_btn = gr.Button("Suggest MLP Layers")
678
- pt_mlp_estimate_params_btn = gr.Button("Estimate Current MLP Params")
679
- pt_mlp_param_count_txt = gr.Textbox(label="Estimated MLP Parameters", interactive=False)
680
- with gr.Group(visible=(initial_model_value_for_specific_dd == PYTORCH_MODELS[1])) as pt_cnn_specific_group:
681
- gr.Markdown("#### CNN Configuration (Simplified)")
682
- gr.Markdown("SimpleCNN uses a fixed structure. Params depend on image size/classes from data.")
683
- pt_cnn_estimate_params_btn = gr.Button("Estimate CNN Params (needs Data Info)")
684
- pt_cnn_param_count_txt = gr.Textbox(label="Estimated CNN Parameters", interactive=False)
685
 
686
  with gr.TabItem("2. Configure Dataset"):
687
- dataset_source_rb = gr.Radio(["Generate new dataset", "Upload my own dataset (CSV, JSON, Parquet)"],
688
- label="Dataset Source", value="Generate new dataset")
689
- with gr.Group(visible=True) as generate_dataset_group:
690
- gr.Markdown("#### Generate Synthetic Dataset")
691
- with gr.Row():
692
- ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
693
- ds_gen_features_num = gr.Number(label="# Features (Tabular)", value=10, minimum=1, step=1)
694
- ds_gen_classes_informative_num = gr.Number(label="Classes (Classif) / Informative Feats (Regr)", value=2, minimum=1, step=1)
695
- ds_gen_ai_suggest_cb = gr.Checkbox(label="Let AI suggest dataset shape?", value=False)
696
- ds_gen_format_dd = gr.Dropdown(DATASET_FORMATS, label="Generated Dataset Format", value=".csv")
697
- generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")
698
- with gr.Group(visible=False) as upload_dataset_group:
699
- gr.Markdown("#### Upload Dataset")
700
- ds_upload_file = gr.File(label="Upload dataset", file_types=[".csv", ".json", ".parquet"])
701
- target_column_name_txt = gr.Textbox(label="Target Column Name (Case-Sensitive!)", placeholder="e.g., 'target' or 'label'", value="target")
702
  dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, height=200)
703
  generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)
704
 
705
  with gr.TabItem("3. Train Model & Get Results"):
706
- gr.Markdown("Ensure Model and Dataset are configured. For NNs, set Epochs/Batch/LR.")
707
- with gr.Row():
708
- train_epochs_num = gr.Number(label="Epochs (NNs)", value=10, minimum=1, step=1)
709
- train_batch_size_num = gr.Number(label="Batch Size (NNs)", value=32, minimum=1, step=1)
710
- train_learning_rate_num = gr.Number(label="Learning Rate (NNs)", value=0.001, minimum=1e-6, format="%.6f")
711
- model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=MODEL_OUTPUT_FORMATS_SKLEARN, value=MODEL_OUTPUT_FORMATS_SKLEARN[0])
712
  train_model_btn = gr.Button("🚀 Train Model", variant="primary")
713
  gr.Markdown("---")
714
  gr.Markdown("### Training Progress & Results")
715
  training_log_txt = gr.Textbox(label="Training Log & Status", lines=15, interactive=False, max_lines=50)
716
- model_param_count_output_txt = gr.Textbox(label="Actual Trained Model Parameters", interactive=False)
717
  evaluation_metrics_txt = gr.Textbox(label="Evaluation Metrics", lines=7, interactive=False)
718
- loss_plot_img = gr.Plot(label="Training Loss Curve (PyTorch NNs)")
719
  download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
 
720
 
721
- with gr.TabItem("ℹ️ Guide & Info"):
722
- gr.Markdown(CLONE_GUIDE_TEXT)
723
-
724
  # --- Event Handlers ---
 
 
725
  task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
726
  model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
727
 
728
- model_family_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pt_options_group, pt_mlp_specific_group, pt_cnn_specific_group])
729
- model_specific_dd.change(fn=update_pytorch_specific_options_visibility, inputs=[model_family_dd, model_specific_dd], outputs=[pt_options_group, pt_mlp_specific_group, pt_cnn_specific_group])
730
-
731
- def get_data_dims_for_nn_suggestion(preview_df, target_col, task, logs_in):
732
- logs = logs_in
733
- input_dim_est, output_dim_est = 10, (2 if task.endswith("Classification") else 1) # Defaults
734
- img_h_est, img_w_est = 28, 28 # Defaults for CNN
735
- num_pixels = 0
736
-
737
- if preview_df is not None and isinstance(preview_df, pd.DataFrame) and not preview_df.empty:
738
- cols = list(preview_df.columns)
739
- if target_col in cols: cols.remove(target_col)
740
-
741
- if task == "Basic Image Classification":
742
- num_pixels = len(cols)
743
- if num_pixels > 0:
744
- dim_sqrt = int(math.sqrt(num_pixels))
745
- if dim_sqrt * dim_sqrt == num_pixels: img_h_est, img_w_est = dim_sqrt, dim_sqrt
746
- else: # Tabular
747
- num_cols = len([c for c in cols if pd.api.types.is_numeric_dtype(preview_df[c])])
748
- cat_cols = [c for c in cols if pd.api.types.is_object_dtype(preview_df[c])]
749
- one_hot_est = sum(min(10, preview_df[c].nunique(dropna=False)) for c in cat_cols)
750
- input_dim_est = max(1, num_cols + one_hot_est)
751
-
752
- if target_col and target_col in preview_df.columns:
753
- if task.endswith("Classification"):
754
- output_dim_est = max(1, preview_df[target_col].nunique(dropna=False))
755
- if output_dim_est == 2: output_dim_est = 1
756
- else: logs += "Dataset preview not available for NN dimension estimation. Using defaults.\n"
757
- return input_dim_est, output_dim_est, img_h_est, img_w_est, logs
758
-
759
- def mlp_suggest_proxy_wrapper(target_range_str, current_logs, preview_df, target_col, task_type):
760
- input_dim, output_dim, _, _, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
761
- logs += f"Using estimated input_dim: {input_dim}, output_dim: {output_dim} for MLP suggestion.\n"
762
- suggested_str, logs = suggest_mlp_layers_for_range(input_dim, output_dim, target_range_str, logs)
763
- param_count_str = "Error"
764
- if suggested_str: param_count_str, logs = estimate_current_mlp_params(str(input_dim), suggested_str, str(output_dim), task_type, logs)
765
- return suggested_str, logs, param_count_str
766
-
767
- pt_mlp_suggest_btn.click(fn=mlp_suggest_proxy_wrapper,
768
- inputs=[pytorch_param_range_dd, current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
769
- outputs=[pt_mlp_hidden_layers_txt, training_log_txt, pt_mlp_param_count_txt])
770
-
771
- def mlp_estimate_proxy_wrapper(hidden_layers, current_logs, preview_df, target_col, task_type):
772
- input_dim, output_dim, _, _, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
773
- logs += f"Using estimated input_dim: {input_dim}, output_dim: {output_dim} for MLP param estimation.\n"
774
- param_count_str, logs = estimate_current_mlp_params(str(input_dim), hidden_layers, str(output_dim), task_type, logs)
775
- return logs, param_count_str
776
-
777
- pt_mlp_estimate_params_btn.click(fn=mlp_estimate_proxy_wrapper,
778
- inputs=[pt_mlp_hidden_layers_txt, current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
779
- outputs=[training_log_txt, pt_mlp_param_count_txt])
780
-
781
- def cnn_estimate_proxy_wrapper(current_logs, preview_df, target_col, task_type):
782
- _, output_dim, img_h, img_w, logs = get_data_dims_for_nn_suggestion(preview_df, target_col, task_type, current_logs)
783
- logs += f"Using estimated img_h: {img_h}, img_w: {img_w}, output_dim: {output_dim} for CNN param estimation.\n"
784
- cnn_task_type = "classification" if task_type == "Basic Image Classification" else "regression"
785
- param_count_str, logs = estimate_cnn_params(str(img_h), str(img_w), str(output_dim), cnn_task_type, logs)
786
- return logs, param_count_str
787
-
788
- pt_cnn_estimate_params_btn.click(fn=cnn_estimate_proxy_wrapper,
789
- inputs=[current_logs_state, dataset_preview_df, target_column_name_txt, task_type_dd],
790
- outputs=[training_log_txt, pt_cnn_param_count_txt])
791
-
792
- def toggle_dataset_source_groups(source_choice):
793
- return gr.update(visible=(source_choice == "Generate new dataset")), gr.update(visible=(source_choice == "Upload my own dataset (CSV, JSON, Parquet)"))
794
- dataset_source_rb.change(fn=toggle_dataset_source_groups, inputs=dataset_source_rb, outputs=[generate_dataset_group, upload_dataset_group])
795
  model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)
796
 
 
797
  generate_dataset_btn.click(
798
  fn=generate_dataset_backend,
799
- inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_informative_num,
800
- ds_gen_format_dd, ds_gen_ai_suggest_cb, pytorch_param_range_dd, model_specific_dd, current_logs_state],
801
- outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file])
802
-
803
- def process_uploaded_file(file_obj, logs_in):
804
- logs, df_preview, stored_data_path = logs_in, None, None
805
- if file_obj is None: logs += "Please upload a file first.\n"; return df_preview, logs, stored_data_path
806
- logs += f"Uploaded file: {file_obj.name}\n"; stored_data_path = file_obj.name
807
- try:
808
- if file_obj.name.endswith(".csv"): df_preview = pd.read_csv(file_obj.name, nrows=5)
809
- elif file_obj.name.endswith(".json"): df_preview = pd.read_json(file_obj.name, lines=True, nrows=5)
810
- elif file_obj.name.endswith(".parquet"): temp_df = pd.read_parquet(file_obj.name); df_preview = temp_df.head()
811
- logs += "Preview generated for uploaded file.\n"
812
- except Exception as e: logs += f"Error previewing {file_obj.name}: {e}\n"
813
- return df_preview, logs, stored_data_path
814
- ds_upload_file.upload(fn=process_uploaded_file, inputs=[ds_upload_file, current_logs_state],
815
- outputs=[dataset_preview_df, training_log_txt, generated_data_state])
816
 
 
817
  train_model_btn.click(
818
  fn=train_model_wrapper,
819
- inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd,
820
- pt_mlp_hidden_layers_txt, pt_mlp_activation_dd,
821
- train_epochs_num, train_batch_size_num, train_learning_rate_num,
822
- model_output_format_dd, training_log_txt],
823
- outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file,
824
- model_param_count_output_txt, loss_plot_img])
825
 
 
826
  demo.queue().launch(debug=True, show_error=True)
 
1
+ # --- Standard Library Imports ---
2
+ import os
3
+ import time
4
+ import traceback
5
+ import tempfile
6
+ import json
7
+ import math
8
+ import collections
9
+ import collections.abc # For Gradio compatibility with newer Python versions
10
+
11
+ # --- UI Framework ---
12
  import gradio as gr
13
+
14
+ # --- Data Handling & Numerical Ops ---
15
  import pandas as pd
16
  import numpy as np
17
+
18
+ # --- Core Machine Learning (Scikit-learn) ---
19
  from sklearn.model_selection import train_test_split
20
  from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
21
  from sklearn.impute import SimpleImputer
22
  from sklearn.compose import ColumnTransformer
23
  from sklearn.pipeline import Pipeline
 
24
  from sklearn.linear_model import LogisticRegression, LinearRegression
25
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
26
  from sklearn.svm import SVC, SVR
 
27
  from sklearn.metrics import accuracy_score, classification_report, mean_squared_error, r2_score
 
28
  from sklearn.datasets import make_classification, make_regression
 
29
  import joblib
30
+
31
+ # --- Core Machine Learning (PyTorch) ---
32
  import torch
33
  import torch.nn as nn
34
  import torch.optim as optim
35
  from torch.utils.data import TensorDataset, DataLoader
 
 
36
 
37
+ # --- ONNX Support for Model Interoperability ---
38
  import skl2onnx
39
  from skl2onnx import convert_sklearn
40
+ from skl2onnx.common.data_types import FloatTensorType, StringTensorType
 
41
 
42
+ # --- Visualization ---
43
+ import matplotlib
44
+ matplotlib.use('Agg') # Use non-interactive backend for server environments
 
 
 
 
 
45
  import matplotlib.pyplot as plt
46
 
47
+ # --- Graceful ONNX Runtime Handling ---
48
+ # This addresses the system-level ImportError on platforms like Hugging Face Spaces.
49
+ try:
50
+ import onnxruntime as rt
51
+ ONNX_RUNTIME_AVAILABLE = True
52
+ except ImportError:
53
+ ONNX_RUNTIME_AVAILABLE = False
54
+ print("Warning: onnxruntime could not be imported. ONNX model validation will be skipped.")
55
+ # --- End of Imports ---
56
+
57
 
58
+ # --- Global Variables & Constants ---
59
  TEMP_DIR = "temp_outputs"
60
  os.makedirs(TEMP_DIR, exist_ok=True)
 
61
  MAX_GENERATED_ROWS = 50000
62
  MAX_GENERATED_COLS = 100
63
+ PARAM_RANGES = collections.OrderedDict([
64
+ ("Tiny (<10k)", (0, 10000)),
65
+ ("Small (10k-50k)", (10000, 50000)),
66
+ ("Medium (50k-250k)", (50000, 250000)),
67
+ ("Large (250k-1M)", (250000, 1000000)),
68
+ ])
69
 
70
+ # --- Helper Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
def get_temp_filepath(filename_base, extension):
    """Return a unique path inside ``TEMP_DIR`` for a new output file.

    Args:
        filename_base: Leading part of the file name.
        extension: File extension, with or without a leading dot.

    Returns:
        A path of the form ``<TEMP_DIR>/<base>_<timestamp>_<token>.<ext>``.
    """
    # Function-scope import keeps this block self-contained.
    import uuid

    clean_extension = extension.lstrip('.')
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    # The timestamp only has one-second resolution, so two requests in the
    # same second would otherwise share (and overwrite) one file.
    token = uuid.uuid4().hex[:8]
    return os.path.join(TEMP_DIR, f"{filename_base}_{timestamp}_{token}.{clean_extension}")
75
 
76
+ # --- PyTorch Model Definitions ---
77
class SimpleMLP(nn.Module):
    """A simple Multi-Layer Perceptron built from a comma-separated layer spec.

    Args:
        input_dim: Number of input features.
        hidden_layers_str: Comma-separated hidden layer sizes, e.g. ``"64,32"``.
            An empty string yields a single linear layer.
        output_dim: Number of output units.
        activation_fn_str: ``"relu"``, ``"tanh"`` or ``"sigmoid"`` (case-insensitive).
        task_type: When ``"classification"`` with ``output_dim == 1`` a final
            sigmoid is appended (for use with BCELoss).

    Raises:
        ValueError: If the activation name is not supported, or a hidden
            layer size is not a positive integer.
    """

    # Supported hidden-layer activations, keyed by lower-case name.
    _ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}

    def __init__(self, input_dim, hidden_layers_str, output_dim, activation_fn_str="relu", task_type="classification"):
        super().__init__()

        # An unknown name used to be skipped silently, which collapsed the
        # network into a purely linear model; fail loudly instead.
        activation_cls = self._ACTIVATIONS.get(activation_fn_str.strip().lower())
        if activation_cls is None:
            supported = ", ".join(sorted(self._ACTIVATIONS))
            raise ValueError(
                f"Unsupported activation '{activation_fn_str}'. Choose one of: {supported}."
            )

        hidden_units = [int(x.strip()) for x in hidden_layers_str.split(',') if x.strip()]
        if any(units <= 0 for units in hidden_units):
            raise ValueError(f"Hidden layer sizes must be positive integers, got '{hidden_layers_str}'.")

        layers = []
        current_dim = input_dim
        for h_units in hidden_units:
            layers.append(nn.Linear(current_dim, h_units))
            layers.append(activation_cls())
            current_dim = h_units

        layers.append(nn.Linear(current_dim, output_dim))

        if task_type == "classification" and output_dim == 1:
            layers.append(nn.Sigmoid())  # For BCELoss
        # For multi-class, CrossEntropyLoss expects raw logits, so no final activation.

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return its output."""
        return self.network(x)
102
+
103
+ # --- Dataset and Preprocessing Logic ---
104
def generate_dataset_backend(task_type, n_samples, n_features, n_classes_or_informative, dataset_format):
    """Generate a synthetic tabular dataset and save it to a temporary file.

    Args:
        task_type: ``"Tabular Classification"`` or ``"Tabular Regression"``.
        n_samples: Requested row count; clamped to ``[10, MAX_GENERATED_ROWS]``.
        n_features: Requested feature count; clamped to ``[1, MAX_GENERATED_COLS]``.
        n_classes_or_informative: Number of classes (classification) or of
            informative features (regression).
        dataset_format: ``".csv"``, ``".json"`` or ``".parquet"``.

    Returns:
        ``(preview, dataframe, logs, file_path)`` on success, or
        ``(None, None, logs, None)`` with the traceback appended to ``logs``
        on failure.
    """
    logs = "\n--- Generating Dataset ---\n"
    try:
        # Conversions happen inside the try so a cleared numeric field (None)
        # is reported in the log rather than escaping as an unhandled error.
        n_samples = max(10, min(int(n_samples), MAX_GENERATED_ROWS))
        n_features = max(1, min(int(n_features), MAX_GENERATED_COLS))
        n_classes_or_informative = int(n_classes_or_informative)

        writers = {
            ".csv": lambda frame, path: frame.to_csv(path, index=False),
            ".json": lambda frame, path: frame.to_json(path, orient='records', lines=True),
            ".parquet": lambda frame, path: frame.to_parquet(path, index=False),
        }
        # Reject unknown formats up front; previously no file was written but
        # a path to it was still returned.
        if dataset_format not in writers:
            raise ValueError(f"Unsupported dataset format: '{dataset_format}'.")

        if task_type == "Tabular Classification":
            n_classes = max(2, n_classes_or_informative)
            # make_classification requires
            #   n_classes * n_clusters_per_class <= 2 ** n_informative.
            # Keep the previous defaults when they satisfy it; otherwise use
            # more informative features and, as a last resort, one cluster.
            n_clusters = 2
            n_informative = max(1, n_features // 2)
            required_bits = (n_classes * n_clusters - 1).bit_length()
            if n_informative < required_bits:
                n_informative = min(n_features, required_bits)
            if n_classes * n_clusters > 2 ** n_informative:
                n_clusters = 1
            if n_classes * n_clusters > 2 ** n_informative:
                raise ValueError(
                    f"Cannot generate {n_classes} classes from {n_features} feature(s); "
                    "increase the number of features or reduce the number of classes."
                )
            X, y = make_classification(n_samples=n_samples, n_features=n_features,
                                       n_informative=n_informative, n_redundant=0,
                                       n_classes=n_classes, n_clusters_per_class=n_clusters,
                                       random_state=42)
        elif task_type == "Tabular Regression":
            X, y = make_regression(n_samples=n_samples, n_features=n_features,
                                   n_informative=max(1, min(n_features, n_classes_or_informative)),
                                   noise=10, random_state=42)
        else:
            raise NotImplementedError(f"Dataset generation for '{task_type}' is not implemented.")

        df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
        df['target'] = y
        logs += f"Generated data with shape: {df.shape}\n"

        file_path = get_temp_filepath("generated_dataset", dataset_format)
        writers[dataset_format](df, file_path)

        logs += f"Dataset saved to temporary file: {os.path.basename(file_path)}\n"
        return df.head(), df, logs, file_path

    except Exception:
        # UI boundary: report the failure through the log outputs.
        error_msg = f"Error generating dataset: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return None, None, logs, None
141
 
142
+ # --- Core Training Functions ---
143
def train_model_sklearn(data_input, target_column, task_type, model_name, model_output_format, logs=""):
    """Train, evaluate and export a Scikit-learn pipeline.

    Args:
        data_input: A DataFrame, or a path to a ``.csv``, ``.json``
            (line-delimited records) or ``.parquet`` file.
        target_column: Name of the column to predict.
        task_type: ``"Tabular Classification"``; any other value is treated
            as regression.
        model_name: Display name of the estimator to train.
        model_output_format: ``".pkl (Scikit-learn)"`` or ``".onnx (ONNX)"``.
        logs: Log text accumulated so far.

    Returns:
        ``(logs, metrics, model_path)`` on success, or
        ``(logs, error_message, None)`` when any step fails.
    """
    logs += f"\n--- Training Scikit-learn Model: {model_name} ---\n"

    try:
        # Validate the export format before training; previously an unknown
        # format failed only afterwards, with an unbound ``model_path``.
        supported_formats = (".pkl (Scikit-learn)", ".onnx (ONNX)")
        if model_output_format not in supported_formats:
            raise ValueError(f"Unsupported model output format: '{model_output_format}'.")

        if isinstance(data_input, str):  # Is a filepath
            lowered = data_input.lower()
            if lowered.endswith('.csv'):
                df = pd.read_csv(data_input)
            elif lowered.endswith('.json'):
                # Matches the line-delimited records layout written by the generator.
                df = pd.read_json(data_input, lines=True)
            elif lowered.endswith('.parquet'):
                df = pd.read_parquet(data_input)
            else:
                raise ValueError("Unsupported file type for upload.")
        else:  # Is a DataFrame from generation
            df = data_input

        if target_column not in df.columns:
            raise ValueError(f"Target column '{target_column}' not found.")

        # Preprocessing
        X = df.drop(columns=[target_column])
        y = df[target_column]
        numeric_features = X.select_dtypes(include=np.number).columns
        categorical_features = X.select_dtypes(include='object').columns

        preprocessor = ColumnTransformer(transformers=[
            ('num', Pipeline([('imputer', SimpleImputer(strategy='mean')), ('scaler', StandardScaler())]), numeric_features),
            ('cat', Pipeline([('imputer', SimpleImputer(strategy='most_frequent')), ('onehot', OneHotEncoder(handle_unknown='ignore'))]), categorical_features)
        ])

        # Model selection
        is_classification = task_type == "Tabular Classification"
        if is_classification:
            y = LabelEncoder().fit_transform(y)
            models = {
                "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
                "Random Forest Classifier": RandomForestClassifier(random_state=42),
                "Support Vector Machine (SVM) Classifier": SVC(random_state=42, probability=True)
            }
        else:  # Regression
            models = {
                "Linear Regression": LinearRegression(),
                "Random Forest Regressor": RandomForestRegressor(random_state=42),
                "Support Vector Machine (SVR) Regressor": SVR()
            }
        if model_name not in models:
            # Happens when the model dropdown is out of sync with the task type.
            raise ValueError(f"Model '{model_name}' is not available for task '{task_type}'.")
        model = models[model_name]

        # Full pipeline: preprocessing followed by the estimator.
        pipeline = Pipeline([('preprocessor', preprocessor), ('model', model)])

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        logs += f"Data split into training ({X_train.shape}) and testing ({X_test.shape}) sets.\n"

        # Training
        start_time = time.time()
        pipeline.fit(X_train, y_train)
        logs += f"Training completed in {time.time() - start_time:.2f}s.\n"

        # Evaluation
        y_pred = pipeline.predict(X_test)
        if is_classification:
            acc = accuracy_score(y_test, y_pred)
            report = classification_report(y_test, y_pred, zero_division=0)
            metrics = f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
        else:
            mse = mean_squared_error(y_test, y_pred)
            r2 = r2_score(y_test, y_pred)
            metrics = f"Mean Squared Error: {mse:.4f}\nR² Score: {r2:.4f}"
        logs += "\n--- Evaluation Metrics ---\n" + metrics + "\n"

        # Model saving
        model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"
        if model_output_format == ".pkl (Scikit-learn)":
            model_path = get_temp_filepath(model_filename_base, "pkl")
            joblib.dump(pipeline, model_path)
            logs += f"Model pipeline saved to {os.path.basename(model_path)} as PKL.\n"
        else:  # ".onnx (ONNX)" -- validated above
            model_path = get_temp_filepath(model_filename_base, "onnx")
            # One named ONNX input per DataFrame column.
            initial_types = [
                (
                    col_name,
                    FloatTensorType([None, 1])
                    if pd.api.types.is_numeric_dtype(X[col_name].dtype)
                    else StringTensorType([None, 1]),
                )
                for col_name in X.columns
            ]

            onnx_model = convert_sklearn(pipeline, initial_types=initial_types, target_opset=12)
            with open(model_path, "wb") as f:
                f.write(onnx_model.SerializeToString())
            logs += f"Model pipeline saved to {os.path.basename(model_path)} as ONNX.\n"

            if ONNX_RUNTIME_AVAILABLE:
                # Creating a session is the validation; the session itself is not kept.
                rt.InferenceSession(model_path)
                logs += "ONNX model successfully loaded and validated with onnxruntime.\n"
            else:
                logs += "ONNX model validation skipped because onnxruntime is not available in this environment.\n"

        return logs, metrics, model_path

    except Exception:
        # UI boundary: return the traceback through the log and metrics outputs.
        error_msg = f"Scikit-learn training failed: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return logs, error_msg, None
 
238
 
239
+ # --- Main Training Dispatcher ---
240
def train_model_wrapper(data_input, target_column, task_type, model_family, model_specific,
                        model_output_format, logs):
    """Call the training function that matches the selected model family.

    Args:
        data_input: Dataset (DataFrame or file path), or None if none exists yet.
        target_column: Name of the column to predict.
        task_type: Selected task type.
        model_family: Selected model family.
        model_specific: Selected model within the family.
        model_output_format: Selected export format.
        logs: Log text accumulated so far; None is treated as empty.

    Returns:
        ``(logs, metrics, model_path, plot)``. ``plot`` is always None here,
        and ``model_path`` is None whenever training did not run.
    """
    # An unset log value arrives as None; "+=" on it would raise TypeError.
    logs = logs or ""

    if data_input is None:
        logs += "ERROR: No dataset has been generated or uploaded. Please go to Tab 2.\n"
        return logs, "Error: No dataset available.", None, None

    if model_family == "Scikit-learn (Classical ML)":
        logs, metrics, model_path = train_model_sklearn(
            data_input, target_column, task_type, model_specific, model_output_format, logs
        )
        return logs, metrics, model_path, None  # No plot for sklearn

    # Placeholder for PyTorch integration if added back
    if model_family == "PyTorch (Neural Networks)":
        logs += "PyTorch training is not fully integrated in this version yet.\n"
        return logs, "PyTorch not available.", None, None

    logs += f"Unknown model family: {model_family}\n"
    return logs, "Error: Unknown model family.", None, None
259
 
260
+ # --- Gradio UI Definition ---
261
def update_model_options(task_choice, model_family_choice):
    """Return a Gradio update with the models valid for the task and family.

    The dropdown is hidden when no model matches.
    """
    sklearn_models_by_task = (
        ("Tabular Classification",
         ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]),
        ("Tabular Regression",
         ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]),
    )

    matched = []
    if model_family_choice == "Scikit-learn (Classical ML)":
        for task_name, model_names in sklearn_models_by_task:
            if task_choice == task_name:
                matched = model_names
                break
    # Add PyTorch options here if needed

    selected = matched[0] if matched else None
    return gr.update(choices=matched, value=selected, visible=bool(matched))
273
 
 
 
 
 
 
 
274
def update_model_output_formats(model_family_choice):
    """Build the update for the output-format dropdown of the given model family.

    Families without any known export format produce an empty choice list with
    no selected value.
    """
    available = (
        [".pkl (Scikit-learn)", ".onnx (ONNX)"]
        if model_family_choice == "Scikit-learn (Classical ML)"
        else []
    )
    # Add PyTorch formats here

    if not available:
        return gr.update(choices=available, value=None)
    return gr.update(choices=available, value=available[0])
283
 
284
+ # The Gradio App Layout
285
# NOTE(review): this block was reconstructed from a flattened diff view, so the
# exact nesting of components inside each gr.Row() is a best guess — confirm
# the layout against the rendered app.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange")) as demo:
    gr.Markdown("# 🧠 TrainAI ⚙️")
    gr.Markdown("A simple interface to create, train, and download machine learning models.")

    # Dataset shared between the "generate" and "train" callbacks.
    generated_data_state = gr.State(None)

    with gr.Tabs():
        # ---- Tab 1: pick the task, the model family and the concrete model ----
        with gr.TabItem("1. Define Task & Model"):
            with gr.Row():
                task_type_dd = gr.Dropdown(
                    ["Tabular Classification", "Tabular Regression"],
                    label="Select Task Type",
                    value="Tabular Classification",
                )
                model_family_dd = gr.Dropdown(
                    ["Scikit-learn (Classical ML)"],
                    label="Select Model Family",
                    value="Scikit-learn (Classical ML)",
                )

            model_specific_dd = gr.Dropdown(
                label="Select Specific Model",
                choices=[
                    "Logistic Regression",
                    "Random Forest Classifier",
                    "Support Vector Machine (SVM) Classifier",
                ],
                value="Logistic Regression",
                interactive=True,
            )

        # ---- Tab 2: synthetic dataset parameters, preview and download ----
        with gr.TabItem("2. Configure Dataset"):
            with gr.Row():
                ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
                ds_gen_features_num = gr.Number(label="# Features", value=10, minimum=1, step=1)
                ds_gen_classes_num = gr.Number(
                    label="Classes (Classif) / Informative (Regr)",
                    value=2,
                    minimum=1,
                    step=1,
                )
            # NOTE(review): unclear whether the next two components belong to the row above.
            ds_gen_format_dd = gr.Dropdown(
                [".csv", ".json", ".parquet"],
                label="Generated Dataset Format",
                value=".csv",
            )
            generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")

            target_column_name_txt = gr.Textbox(
                label="Target Column Name",
                value="target",
                interactive=True,
            )
            dataset_preview_df = gr.DataFrame(
                label="Dataset Preview (First 5 Rows)",
                interactive=False,
                height=200,
            )
            generated_dataset_download_file = gr.File(
                label="Download Generated Dataset",
                interactive=False,
            )

        # ---- Tab 3: export format, training trigger and result widgets ----
        with gr.TabItem("3. Train Model & Get Results"):
            model_output_format_dd = gr.Dropdown(
                label="Select Model Output Format",
                choices=[".pkl (Scikit-learn)", ".onnx (ONNX)"],
                value=".pkl (Scikit-learn)",
            )
            train_model_btn = gr.Button("🚀 Train Model", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Training Progress & Results")
            training_log_txt = gr.Textbox(
                label="Training Log & Status",
                lines=15,
                interactive=False,
                max_lines=50,
            )
            evaluation_metrics_txt = gr.Textbox(label="Evaluation Metrics", lines=7, interactive=False)
            download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
            # Hidden for now: only a PyTorch run would produce a loss curve.
            loss_plot_img = gr.Plot(label="Training Loss Curve (PyTorch only)", visible=False)

    # --- Event Handlers ---

    # Changing either the task or the family refreshes the specific-model list.
    for _selector in (task_type_dd, model_family_dd):
        _selector.change(
            fn=update_model_options,
            inputs=[task_type_dd, model_family_dd],
            outputs=model_specific_dd,
        )

    # Changing the family also refreshes the available export formats.
    model_family_dd.change(
        fn=update_model_output_formats,
        inputs=model_family_dd,
        outputs=model_output_format_dd,
    )

    # Dataset generation button
    generate_dataset_btn.click(
        fn=generate_dataset_backend,
        inputs=[
            task_type_dd,
            ds_gen_samples_num,
            ds_gen_features_num,
            ds_gen_classes_num,
            ds_gen_format_dd,
        ],
        outputs=[
            dataset_preview_df,
            generated_data_state,
            training_log_txt,
            generated_dataset_download_file,
        ],
    )

    # Main training button; the current log text is passed in so new messages append to it.
    train_model_btn.click(
        fn=train_model_wrapper,
        inputs=[
            generated_data_state,
            target_column_name_txt,
            task_type_dd,
            model_family_dd,
            model_specific_dd,
            model_output_format_dd,
            training_log_txt,
        ],
        outputs=[
            training_log_txt,
            evaluation_metrics_txt,
            download_trained_model_file,
            loss_plot_img,
        ],
    )

# Launch the application
demo.queue().launch(debug=True, show_error=True)