|
|
|
|
|
import os |
|
|
import time |
|
|
import traceback |
|
|
import tempfile |
|
|
import json |
|
|
import math |
|
|
import collections |
|
|
import collections.abc |
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder |
|
|
from sklearn.impute import SimpleImputer |
|
|
from sklearn.compose import ColumnTransformer |
|
|
from sklearn.pipeline import Pipeline |
|
|
from sklearn.linear_model import LogisticRegression, LinearRegression |
|
|
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
|
|
from sklearn.svm import SVC, SVR |
|
|
from sklearn.metrics import accuracy_score, classification_report, mean_squared_error, r2_score |
|
|
from sklearn.datasets import make_classification, make_regression |
|
|
import joblib |
|
|
|
|
|
|
|
|
import skl2onnx |
|
|
from skl2onnx import convert_sklearn |
|
|
from skl2onnx.common.data_types import FloatTensorType, StringTensorType |
|
|
|
|
|
|
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
import onnxruntime as rt |
|
|
ONNX_RUNTIME_AVAILABLE = True |
|
|
except ImportError: |
|
|
ONNX_RUNTIME_AVAILABLE = False |
|
|
print("Warning: onnxruntime could not be imported. ONNX model validation will be skipped.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Folder that receives every generated dataset and exported model file.
TEMP_DIR = "temp_outputs"

# Hard caps applied to the user-requested size of synthetic datasets.
MAX_GENERATED_ROWS = 50_000
MAX_GENERATED_COLS = 100

# Create the output folder eagerly so later writers can assume it exists.
os.makedirs(TEMP_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
def get_temp_filepath(filename_base, extension, directory=None):
    """Build a collision-resistant output path for a new file.

    The file name is ``<filename_base>_<YYYYmmdd-HHMMSS>_<token>.<ext>``.
    The short random token guarantees that two calls made within the same
    second (e.g. two quick button clicks) get different paths and do not
    silently overwrite each other's output.

    Args:
        filename_base: Leading part of the file name, e.g. ``"generated_dataset"``.
        extension: File extension with or without a leading dot
            (``".csv"`` and ``"csv"`` are equivalent).
        directory: Folder to place the file in; defaults to ``TEMP_DIR``.

    Returns:
        The joined path as a string. The file itself is not created.
    """
    import uuid  # only this helper needs it; keeps the top-level imports untouched

    target_dir = TEMP_DIR if directory is None else directory
    clean_extension = extension.lstrip('.')
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    token = uuid.uuid4().hex[:8]
    return os.path.join(target_dir, f"{filename_base}_{timestamp}_{token}.{clean_extension}")
|
|
|
|
|
|
|
|
def _classification_generator_params(n_features, requested_classes):
    """Choose ``make_classification`` settings that satisfy its cluster constraint.

    ``make_classification`` rejects inputs where
    ``n_classes * n_clusters_per_class > 2 ** n_informative``. A plain
    ``n_informative = n_features // 2`` with two clusters per class breaks for
    a single feature or for many classes. The ``n_features // 2`` heuristic is
    kept whenever it is already valid, so default settings produce the same data.

    Returns:
        ``(n_classes, n_informative, n_clusters_per_class)``.

    Raises:
        ValueError: if ``n_features`` is too small to separate that many classes.
    """
    n_classes = max(2, requested_classes)
    # Informative features needed for two clusters per class, capped by the feature count.
    needed = math.ceil(math.log2(n_classes * 2))
    n_informative = min(n_features, max(1, n_features // 2, needed))
    n_clusters_per_class = 2
    if n_classes * n_clusters_per_class > 2 ** n_informative:
        # Fall back to a single cluster per class before giving up.
        n_clusters_per_class = 1
    if n_classes > 2 ** n_informative:
        raise ValueError(
            f"{n_classes} classes need at least {math.ceil(math.log2(n_classes))} features; "
            f"got {n_features}."
        )
    return n_classes, n_informative, n_clusters_per_class


def generate_dataset_backend(task_type, n_samples, n_features, n_classes_or_informative, dataset_format):
    """Generate a synthetic tabular dataset, save it to disk and return a preview.

    Args:
        task_type: ``"Tabular Classification"`` or ``"Tabular Regression"``.
        n_samples: Requested row count, clamped to ``[10, MAX_GENERATED_ROWS]``.
        n_features: Requested feature count, clamped to ``[1, MAX_GENERATED_COLS]``.
        n_classes_or_informative: Class count for classification (at least 2),
            or number of informative features for regression.
        dataset_format: One of ``".csv"``, ``".json"`` (JSON lines) or ``".parquet"``.

    Returns:
        ``(preview_df, full_df, logs, file_path)``. On any failure the result is
        ``(None, None, logs, None)`` with the traceback appended to ``logs``.
    """
    logs = "\n--- Generating Dataset ---\n"
    try:
        # Coercion lives inside the try: a cleared gr.Number may arrive as None,
        # and int(None) would otherwise escape as an unhandled TypeError.
        n_samples = max(10, min(int(n_samples), MAX_GENERATED_ROWS))
        n_features = max(1, min(int(n_features), MAX_GENERATED_COLS))
        n_classes_or_informative = int(n_classes_or_informative)

        if task_type == "Tabular Classification":
            n_classes, n_informative, n_clusters = _classification_generator_params(
                n_features, n_classes_or_informative)
            X, y = make_classification(n_samples=n_samples, n_features=n_features,
                                       n_informative=n_informative, n_redundant=0,
                                       n_classes=n_classes, n_clusters_per_class=n_clusters,
                                       random_state=42)
        elif task_type == "Tabular Regression":
            X, y = make_regression(n_samples=n_samples, n_features=n_features,
                                   n_informative=max(1, min(n_features, n_classes_or_informative)),
                                   noise=10, random_state=42)
        else:
            raise NotImplementedError(f"Dataset generation for '{task_type}' is not implemented.")

        df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
        df['target'] = y
        logs += f"Generated data with shape: {df.shape}\n"

        file_path = get_temp_filepath("generated_dataset", dataset_format)
        if dataset_format == ".csv":
            df.to_csv(file_path, index=False)
        elif dataset_format == ".json":
            df.to_json(file_path, orient='records', lines=True)
        elif dataset_format == ".parquet":
            df.to_parquet(file_path, index=False)
        else:
            # Previously a path to a never-written file was returned here.
            raise ValueError(f"Unsupported dataset format '{dataset_format}'.")

        logs += f"Dataset saved to temporary file: {os.path.basename(file_path)}\n"
        return df.head(), df, logs, file_path

    except Exception:
        error_msg = f"Error generating dataset: {traceback.format_exc()}"
        logs += error_msg + "\n"
        return None, None, logs, None
|
|
|
|
|
|
|
|
def _load_training_frame(data_input):
    """Return the training data as a DataFrame, reading it from disk when given a path.

    Paths are dispatched on their extension (case-insensitive). Any non-string
    input is assumed to already be a DataFrame and is returned unchanged.
    """
    if not isinstance(data_input, str):
        return data_input
    lowered = data_input.lower()
    if lowered.endswith('.csv'):
        return pd.read_csv(data_input)
    if lowered.endswith('.json'):
        return pd.read_json(data_input, lines=True)
    if lowered.endswith('.parquet'):
        return pd.read_parquet(data_input)
    raise ValueError("Unsupported file type for upload.")


def _build_sklearn_estimator(task_type, model_name):
    """Instantiate only the estimator named ``model_name`` for ``task_type``.

    Raises:
        ValueError: if ``model_name`` is not offered for ``task_type`` (e.g. a
            stale dropdown value after the task was switched).
    """
    if task_type == "Tabular Classification":
        factories = {
            "Logistic Regression": lambda: LogisticRegression(max_iter=1000, random_state=42),
            "Random Forest Classifier": lambda: RandomForestClassifier(random_state=42),
            "Support Vector Machine (SVM) Classifier": lambda: SVC(random_state=42, probability=True),
        }
    else:
        factories = {
            "Linear Regression": lambda: LinearRegression(),
            "Random Forest Regressor": lambda: RandomForestRegressor(random_state=42),
            "Support Vector Machine (SVR) Regressor": lambda: SVR(),
        }
    try:
        factory = factories[model_name]
    except KeyError:
        raise ValueError(
            f"Model '{model_name}' is not available for task '{task_type}'. "
            f"Choose one of: {', '.join(factories)}."
        ) from None
    return factory()


def _format_metrics(task_type, y_test, y_pred):
    """Render evaluation metrics for the held-out split as display text."""
    if task_type == "Tabular Classification":
        acc = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred, zero_division=0)
        return f"Accuracy: {acc:.4f}\n\nClassification Report:\n{report}"
    mse = mean_squared_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)
    return f"Mean Squared Error: {mse:.4f}\nR² Score: {r2:.4f}"


def _export_sklearn_pipeline(pipeline, X, task_type, model_name, model_output_format, log):
    """Save the fitted pipeline in the requested format and return its path.

    Progress messages are passed to ``log`` as they happen, so they survive a
    failure in a later step. Returns ``None`` (and logs a notice) for an
    unknown ``model_output_format``.
    """
    model_filename_base = f"sklearn_{model_name.replace(' ', '_').lower()}"

    if model_output_format == ".pkl (Scikit-learn)":
        model_path = get_temp_filepath(model_filename_base, "pkl")
        joblib.dump(pipeline, model_path)
        log(f"Model pipeline saved to {os.path.basename(model_path)} as PKL.\n")
        return model_path

    if model_output_format == ".onnx (ONNX)":
        model_path = get_temp_filepath(model_filename_base, "onnx")
        # One ONNX input per DataFrame column, named after it, so the
        # ColumnTransformer's column selection maps onto graph inputs.
        initial_types = [
            (col_name,
             FloatTensorType([None, 1]) if pd.api.types.is_numeric_dtype(X[col_name].dtype)
             else StringTensorType([None, 1]))
            for col_name in X.columns
        ]
        # skl2onnx converter options are keyed by estimator id (or type); key
        # the final estimator explicitly so its ZipMap output is reliably disabled.
        options = {}
        if task_type == "Tabular Classification":
            options = {id(pipeline.named_steps['model']): {'zipmap': False}}
        onnx_model = convert_sklearn(pipeline, initial_types=initial_types, target_opset=12, options=options)
        with open(model_path, "wb") as f:
            f.write(onnx_model.SerializeToString())
        log(f"Model pipeline saved to {os.path.basename(model_path)} as ONNX.\n")

        if ONNX_RUNTIME_AVAILABLE:
            # An explicit provider avoids the ValueError that GPU builds of
            # onnxruntime raise when `providers` is omitted.
            rt.InferenceSession(model_path, providers=["CPUExecutionProvider"])
            log("ONNX model successfully loaded and validated with onnxruntime.\n")
        else:
            log("ONNX model validation skipped because onnxruntime is not available in this environment.\n")
        return model_path

    log(f"Unsupported model output format '{model_output_format}'; model was not saved.\n")
    return None


def train_model_sklearn(data_input, target_column, task_type, model_name, model_output_format, logs=""):
    """Train, evaluate and export a Scikit-learn pipeline.

    Args:
        data_input: A DataFrame, or a path to a ``.csv``/``.json`` (JSON lines)/
            ``.parquet`` file.
        target_column: Name of the column to predict.
        task_type: ``"Tabular Classification"`` or anything else (treated as regression).
        model_name: Display name of the estimator, as offered by ``update_model_options``.
        model_output_format: ``".pkl (Scikit-learn)"`` or ``".onnx (ONNX)"``.
        logs: Existing log text to append to (``None`` is treated as empty).

    Returns:
        ``(logs, metrics_text, model_path)``. On failure ``metrics_text`` holds
        the error with its traceback and ``model_path`` is ``None``.
    """
    log_parts = [logs or "", f"\n--- Training Scikit-learn Model: {model_name} ---\n"]
    log = log_parts.append

    try:
        df = _load_training_frame(data_input)
        if target_column not in df.columns:
            raise ValueError(f"Target column '{target_column}' not found.")

        X = df.drop(columns=[target_column])
        y = df[target_column]
        numeric_features = X.select_dtypes(include=np.number).columns
        categorical_features = X.select_dtypes(include='object').columns
        # ColumnTransformer drops unlisted columns by default; surface them
        # instead of losing them silently (e.g. bool or category dtypes).
        handled = set(numeric_features) | set(categorical_features)
        ignored = [col for col in X.columns if col not in handled]
        if ignored:
            log(f"Warning: columns with unsupported dtypes are ignored by the model: {ignored}\n")

        preprocessor = ColumnTransformer(transformers=[
            ('num', Pipeline([('imputer', SimpleImputer(strategy='mean')), ('scaler', StandardScaler())]), numeric_features),
            ('cat', Pipeline([('imputer', SimpleImputer(strategy='most_frequent')), ('onehot', OneHotEncoder(handle_unknown='ignore'))]), categorical_features),
        ])

        if task_type == "Tabular Classification":
            # NOTE(review): the fitted encoder is not kept, so the exported model
            # predicts integer codes rather than the original labels.
            y = LabelEncoder().fit_transform(y)

        model = _build_sklearn_estimator(task_type, model_name)
        pipeline = Pipeline([('preprocessor', preprocessor), ('model', model)])

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        log(f"Data split into training ({X_train.shape}) and testing ({X_test.shape}) sets.\n")

        start_time = time.time()
        pipeline.fit(X_train, y_train)
        log(f"Training completed in {time.time() - start_time:.2f}s.\n")

        metrics = _format_metrics(task_type, y_test, pipeline.predict(X_test))
        log("\n--- Evaluation Metrics ---\n" + metrics + "\n")

        model_path = _export_sklearn_pipeline(pipeline, X, task_type, model_name, model_output_format, log)
        return "".join(log_parts), metrics, model_path

    except Exception:
        error_msg = f"Scikit-learn training failed: {traceback.format_exc()}"
        log(error_msg + "\n")
        return "".join(log_parts), error_msg, None
|
|
|
|
|
|
|
|
def train_model_wrapper(data_input, target_column, task_type, model_family, model_specific,
                        model_output_format, logs):
    """Route a training request from the UI to the backend for ``model_family``.

    Args:
        data_input: Dataset held in the UI state (DataFrame or file path), or
            ``None`` when nothing has been generated yet.
        target_column, task_type, model_specific, model_output_format:
            Forwarded to the backend training function.
        model_family: Selected model family; only
            ``"Scikit-learn (Classical ML)"`` is supported.
        logs: Current contents of the log textbox (``None`` is treated as empty).

    Returns:
        ``(logs, metrics_text, model_path, loss_plot)``. ``loss_plot`` is always
        ``None`` because no supported family produces a loss curve.
    """
    # The log textbox value may be None; normalise so the string
    # concatenation below cannot raise a TypeError.
    logs = logs or ""

    if data_input is None:
        logs += "ERROR: No dataset has been generated or uploaded. Please go to Tab 2.\n"
        return logs, "Error: No dataset available.", None, None

    if model_family != "Scikit-learn (Classical ML)":
        logs += f"The selected model family '{model_family}' is not supported yet.\n"
        return logs, "Error: Model family not supported.", None, None

    logs, metrics, model_path = train_model_sklearn(data_input, target_column, task_type, model_specific, model_output_format, logs)
    return logs, metrics, model_path, None
|
|
|
|
|
|
|
|
def update_model_options(task_choice, model_family_choice):
    """Build a Gradio update listing the estimators valid for a task/family pair.

    The first listed estimator becomes the selected value. When no estimator
    matches, the choices are emptied, the value is cleared and the dropdown
    is hidden.
    """
    catalogue = (
        ("Tabular Classification",
         ["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"]),
        ("Tabular Regression",
         ["Linear Regression", "Random Forest Regressor", "Support Vector Machine (SVR) Regressor"]),
    )

    available = []
    if model_family_choice == "Scikit-learn (Classical ML)":
        available = next((names for task, names in catalogue if task == task_choice), [])

    has_models = len(available) > 0
    selected = available[0] if has_models else None
    return gr.update(choices=available, value=selected, visible=has_models)
|
|
|
|
|
def update_model_output_formats(model_family_choice):
    """Build a Gradio update with the export formats offered for a model family.

    Scikit-learn models can be exported as a pickled pipeline or as ONNX, with
    the pickle format preselected. Any other family gets an empty list and no
    selection.
    """
    if model_family_choice != "Scikit-learn (Classical ML)":
        return gr.update(choices=[], value=None)

    export_formats = [".pkl (Scikit-learn)", ".onnx (ONNX)"]
    return gr.update(choices=export_formats, value=export_formats[0])
|
|
|
|
|
|
|
|
# Gradio UI: three tabs (define task & model, configure dataset, train & get
# results) wired to the backend functions defined above.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="orange")) as demo:
    gr.Markdown("# 🧠 TrainAI ⚙️")
    gr.Markdown("A simple interface to create, train, and download machine learning models.")

    # Per-session state holding the full generated DataFrame: written by the
    # "Generate" click handler, read by the "Train" click handler as data_input.
    generated_data_state = gr.State(None)

    with gr.Tabs():
        # Tab 1: task type, model family and concrete estimator.
        with gr.TabItem("1. Define Task & Model"):
            with gr.Row():
                task_type_dd = gr.Dropdown(["Tabular Classification", "Tabular Regression"], label="Select Task Type", value="Tabular Classification")
                model_family_dd = gr.Dropdown(["Scikit-learn (Classical ML)"], label="Select Model Family", value="Scikit-learn (Classical ML)")

            # Initial choices match the default task (classification); the
            # .change handlers below replace them when task or family changes.
            model_specific_dd = gr.Dropdown(label="Select Specific Model", choices=["Logistic Regression", "Random Forest Classifier", "Support Vector Machine (SVM) Classifier"], value="Logistic Regression", interactive=True)

        # Tab 2: synthetic dataset generation, preview and download.
        with gr.TabItem("2. Configure Dataset"):
            with gr.Row():
                ds_gen_samples_num = gr.Number(label="# Samples", value=1000, minimum=10, step=100)
                ds_gen_features_num = gr.Number(label="# Features", value=10, minimum=1, step=1)
                # Class count for classification; informative-feature count for regression.
                ds_gen_classes_num = gr.Number(label="Classes (Classif) / Informative (Regr)", value=2, minimum=1, step=1)
                ds_gen_format_dd = gr.Dropdown([".csv", ".json", ".parquet"], label="Generated Dataset Format", value=".csv")
                generate_dataset_btn = gr.Button("Generate & Preview Dataset", variant="secondary")

            # Must name a column of the dataset; generated datasets call it "target".
            target_column_name_txt = gr.Textbox(label="Target Column Name", value="target", interactive=True)

            dataset_preview_df = gr.DataFrame(label="Dataset Preview (First 5 Rows)", interactive=False, row_count=5)

            generated_dataset_download_file = gr.File(label="Download Generated Dataset", interactive=False)

        # Tab 3: export format, training trigger, logs, metrics and model download.
        with gr.TabItem("3. Train Model & Get Results"):
            model_output_format_dd = gr.Dropdown(label="Select Model Output Format", choices=[".pkl (Scikit-learn)", ".onnx (ONNX)"], value=".pkl (Scikit-learn)")
            train_model_btn = gr.Button("🚀 Train Model", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Training Progress & Results")
            training_log_txt = gr.Textbox(label="Training Log & Status", lines=15, interactive=False, max_lines=50)
            evaluation_metrics_txt = gr.Textbox(label="Evaluation Metrics", lines=7, interactive=False)
            download_trained_model_file = gr.File(label="Download Trained Model", interactive=False)
            # Hidden placeholder output; train_model_wrapper always returns None for it.
            loss_plot_img = gr.Plot(label="Training Loss Curve (PyTorch only)", visible=False)

    # --- Event wiring ---
    # Valid estimators depend on both the task and the family, so either change
    # refreshes the model dropdown.
    task_type_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)
    model_family_dd.change(fn=update_model_options, inputs=[task_type_dd, model_family_dd], outputs=model_specific_dd)

    # Export formats depend only on the model family.
    model_family_dd.change(fn=update_model_output_formats, inputs=model_family_dd, outputs=model_output_format_dd)

    # NOTE(review): generation replaces the log textbox contents with a fresh
    # log, whereas training (below) appends to the current textbox value.
    generate_dataset_btn.click(
        fn=generate_dataset_backend,
        inputs=[task_type_dd, ds_gen_samples_num, ds_gen_features_num, ds_gen_classes_num, ds_gen_format_dd],
        outputs=[dataset_preview_df, generated_data_state, training_log_txt, generated_dataset_download_file]
    )

    # The current log text is passed in as `logs` so new output is appended to it.
    train_model_btn.click(
        fn=train_model_wrapper,
        inputs=[generated_data_state, target_column_name_txt, task_type_dd, model_family_dd, model_specific_dd, model_output_format_dd, training_log_txt],
        outputs=[training_log_txt, evaluation_metrics_txt, download_trained_model_file, loss_plot_img]
    )

# Starts the server as soon as this module runs (there is no __main__ guard).
demo.queue().launch(debug=True, show_error=True)