# Spaces status (captured from the hosting page): Runtime error
# app.py
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from matplotlib.figure import Figure
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

# In some remote environments, Matplotlib needs to be set to 'Agg' backend
matplotlib.use('Agg')
################################################################################
# SUGGESTED_DATASETS: dataset IDs offered in the "suggested dataset" dropdown.
# Each real entry must actually exist on huggingface.co/datasets and expose a
# "train" split, because the app always loads split="train".
#
#   "scikit-learn/iris" -> tabular Iris dataset ("train" split, 150 rows).
#   "uci/wine"          -> tabular Wine dataset ("train" split, 178 rows).
#   NOTE(review): row counts and availability are taken from the original
#   comment and are not checked at runtime -- confirm both IDs still resolve.
#
# The last entry is not a dataset: it is a sentinel meaning "ignore the
# dropdown and use the ID typed into the custom-dataset textbox instead".
################################################################################
SUGGESTED_DATASETS = [
    "scikit-learn/iris",
    "uci/wine",
    "SKIP/ENTER_CUSTOM",  # a placeholder meaning "use custom_dataset_id"
]
def update_columns(dataset_id, custom_dataset_id):
    """Load the chosen dataset's "train" split and report its column names.

    The column names are used to populate the Label Column dropdown and the
    Feature Columns checkbox group.

    Args:
        dataset_id: Value of the suggested-dataset dropdown. The sentinel
            ``"SKIP/ENTER_CUSTOM"`` means "use ``custom_dataset_id``".
        custom_dataset_id: Free-text dataset ID; surrounding whitespace is
            stripped. Only consulted when ``dataset_id`` is the sentinel.

    Returns:
        A 3-tuple ``(label_update, feature_update, message)`` where the first
        two items are ``gr.update(...)`` payloads for the label dropdown and
        the feature checkbox group, and ``message`` is a Markdown status
        string. On any failure the choices are reset to empty lists and the
        message describes the error; this function does not raise for load
        failures.
    """
    # If user picked a suggested dataset (not SKIP), use that; otherwise fall
    # back to the user-supplied ID. `or ""` guards against a None textbox value.
    if dataset_id != "SKIP/ENTER_CUSTOM":
        final_id = dataset_id
    else:
        final_id = (custom_dataset_id or "").strip()

    # Fail fast with a readable message instead of passing an empty ID to the
    # loader and surfacing whatever low-level error it produces.
    if not final_id:
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            "**Error**: no dataset selected. Choose a suggested dataset or "
            "enter a custom dataset ID (e.g. `user/dataset`).",
        )

    try:
        # Load just the "train" split; many HF datasets have train/test/validation.
        ds = load_dataset(final_id, split="train")
        # Column names are available on the dataset itself, so there is no
        # need to materialise every row into a DataFrame just to list them.
        cols = list(ds.column_names)
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any load failure
        # (missing dataset, missing split, network error) should be shown to
        # the user rather than crash the callback.
        err_msg = f"**Error loading** `{final_id}`: {e}"
        return (
            gr.update(choices=[], value=None),
            gr.update(choices=[], value=[]),
            err_msg,
        )

    message = f"**Loaded dataset**: {final_id}\n\n**Columns found**: {cols}"
    # Offer the same column list for both label & features, clearing any
    # previous selection.
    return (
        gr.update(choices=cols, value=None),  # label_col dropdown
        gr.update(choices=cols, value=[]),  # feature_cols checkbox group
        message,
    )
def _plot_training_results(importances, feature_columns, cm, class_labels):
    """Draw feature importances (left) and a confusion-matrix heatmap (right).

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots`` so it is never registered in pyplot's global figure
    registry; otherwise every call would keep one more figure alive.
    """
    fig = Figure(figsize=(10, 4))
    ax_imp, ax_cm = fig.subplots(1, 2)

    # Subplot 1: Feature Importances
    positions = list(range(len(feature_columns)))
    ax_imp.barh(positions, importances, color='skyblue')
    ax_imp.set_yticks(positions)
    ax_imp.set_yticklabels(feature_columns)
    ax_imp.set_xlabel("Importance")
    ax_imp.set_title("Feature Importances")

    # Subplot 2: Confusion Matrix Heatmap
    im = ax_cm.imshow(cm, interpolation='nearest', cmap="Blues")
    ax_cm.set_title("Confusion Matrix")
    fig.colorbar(im, ax=ax_cm)
    ax_cm.set_xlabel("Predicted")
    ax_cm.set_ylabel("True")

    # Label rows/columns with the class each index stands for.
    ticks = list(range(len(class_labels)))
    tick_text = [str(label) for label in class_labels]
    ax_cm.set_xticks(ticks)
    ax_cm.set_xticklabels(tick_text, rotation=45, ha="right")
    ax_cm.set_yticks(ticks)
    ax_cm.set_yticklabels(tick_text)

    # Annotate each cell with its count; switch to white text on dark cells.
    thresh = cm.max() / 2.0
    for (i, j), count in np.ndenumerate(cm):
        color = "white" if count > thresh else "black"
        ax_cm.text(j, i, str(count), ha="center", va="center", color=color)

    fig.tight_layout()
    return fig


def train_model(dataset_id, custom_dataset_id, label_column, feature_columns,
                learning_rate, n_estimators, max_depth, test_size):
    """Train a GradientBoostingClassifier on a Hub dataset and report results.

    Steps:
        1. Determine the final dataset ID (from dropdown or custom text).
        2. Load the "train" split -> DataFrame -> X, y.
        3. Split with ``train_test_split`` (``random_state=42``) and fit the
           classifier.
        4. Return the accuracy summary and a figure with a feature-importance
           bar chart and a confusion-matrix heatmap.

    Args:
        dataset_id: Suggested-dataset dropdown value; ``"SKIP/ENTER_CUSTOM"``
            means "use ``custom_dataset_id``".
        custom_dataset_id: Free-text dataset ID (whitespace is stripped).
        label_column: Name of the target column.
        feature_columns: Names of the columns used as model inputs.
        learning_rate: Passed to ``GradientBoostingClassifier``.
        n_estimators: Cast to ``int`` and passed to the classifier.
        max_depth: Cast to ``int`` and passed to the classifier.
        test_size: Passed to ``train_test_split`` as ``test_size``.

    Returns:
        ``(text_summary, fig)``: a Markdown string and a Matplotlib figure.

    Raises:
        ValueError: If no dataset, label, or feature is selected; if the label
            is also selected as a feature; if a selected column is missing
            from the dataset; or if a feature column is not numeric.
        Exception: Whatever ``load_dataset`` or scikit-learn raise is
            propagated unchanged.
    """
    if dataset_id != "SKIP/ENTER_CUSTOM":
        final_id = dataset_id
    else:
        # `or ""` guards against a None textbox value.
        final_id = (custom_dataset_id or "").strip()

    # Validate the selections before the (slow, networked) dataset load.
    if not final_id:
        raise ValueError(
            "No dataset selected: choose a suggested dataset or enter a custom dataset ID."
        )
    if not label_column:
        raise ValueError("No label column selected.")
    feature_columns = list(feature_columns or [])
    if not feature_columns:
        raise ValueError("Select at least one feature column.")
    if label_column in feature_columns:
        # Training on the target itself leaks the answer and makes the
        # reported accuracy meaningless.
        raise ValueError(
            f"Label column '{label_column}' cannot also be used as a feature column."
        )

    # Load dataset
    ds = load_dataset(final_id, split="train")
    df = pd.DataFrame(ds)

    # Basic validation: the selections may be stale if the dataset choice
    # changed after the columns were loaded.
    if label_column not in df.columns:
        raise ValueError(f"Label column '{label_column}' not found in dataset columns.")
    for fc in feature_columns:
        if fc not in df.columns:
            raise ValueError(f"Feature column '{fc}' not found in dataset columns.")

    # The classifier needs numeric inputs; report offending columns by name
    # instead of letting the fit fail with a conversion error.
    non_numeric = [
        fc for fc in feature_columns
        if not pd.api.types.is_numeric_dtype(df[fc])
    ]
    if non_numeric:
        raise ValueError(f"Feature columns must be numeric; non-numeric: {non_numeric}")

    # Build X, y arrays
    X = df[feature_columns].values
    y = df[label_column].values

    # Split (fixed seed so repeated runs are comparable)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Train model
    clf = GradientBoostingClassifier(
        learning_rate=learning_rate,
        n_estimators=int(n_estimators),  # sliders may deliver floats
        max_depth=int(max_depth),
        random_state=42,
    )
    clf.fit(X_train, y_train)

    # Predictions & metrics
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Pin the row/column order to the classifier's classes so the matrix
    # indices always match the tick labels drawn on the heatmap.
    cm = confusion_matrix(y_test, y_pred, labels=clf.classes_)

    fig = _plot_training_results(
        clf.feature_importances_, feature_columns, cm, clf.classes_
    )

    # Build textual summary
    text_summary = (
        f"**Dataset used**: `{final_id}`\n\n"
        f"**Label column**: `{label_column}`\n\n"
        f"**Feature columns**: `{feature_columns}`\n\n"
        f"**Accuracy**: {accuracy:.3f}\n\n"
    )
    return text_summary, fig
# Build the Gradio Blocks UI.
# NOTE(review): the original indentation was lost, so the exact nesting of
# components inside each gr.Row() is reconstructed from the section comments
# (rows hold the paired selectors; buttons and outputs sit below them).
with gr.Blocks() as demo:
    gr.Markdown("# Train a GradientBoostingClassifier on any HF Dataset\n")
    gr.Markdown(
        "1. Choose a suggested dataset from the dropdown **or** enter a custom dataset ID in the format `user/dataset`.\n"
        "2. Click **Load Columns** to inspect the columns.\n"
        "3. Pick a **Label column** and **Feature columns**.\n"
        "4. Adjust hyperparameters and click **Train & Evaluate**.\n"
        "5. Observe accuracy, feature importances, and a confusion matrix heatmap.\n\n"
        "*(Note: the dataset must have a `train` split!)*"
    )

    # Row 1: Dataset selection
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            label="Choose suggested dataset",
            choices=SUGGESTED_DATASETS,
            value=SUGGESTED_DATASETS[0],  # default
        )
        custom_dataset_id = gr.Textbox(
            label="Or enter a custom dataset ID",
            placeholder="e.g. username/my_custom_dataset",
        )
    load_cols_btn = gr.Button("Load Columns")
    load_cols_info = gr.Markdown()

    # Row 2: label & feature columns. Both start empty and are filled in by
    # update_columns once the user clicks "Load Columns".
    with gr.Row():
        label_col = gr.Dropdown(choices=[], label="Label column (choose 1)")
        feature_cols = gr.CheckboxGroup(choices=[], label="Feature columns (choose 1 or more)")

    # Hyperparameters
    learning_rate_slider = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="learning_rate")
    n_estimators_slider = gr.Slider(50, 300, value=100, step=50, label="n_estimators")
    max_depth_slider = gr.Slider(1, 10, value=3, step=1, label="max_depth")
    test_size_slider = gr.Slider(0.1, 0.9, value=0.3, step=0.1, label="test_size fraction (0.1-0.9)")

    train_button = gr.Button("Train & Evaluate")
    output_text = gr.Markdown()
    output_plot = gr.Plot()

    # Link the "Load Columns" button -> update_columns function
    load_cols_btn.click(
        fn=update_columns,
        inputs=[dataset_dropdown, custom_dataset_id],
        outputs=[label_col, feature_cols, load_cols_info],
    )

    # Link "Train & Evaluate" -> train_model function. The input order must
    # match train_model's positional parameters.
    train_button.click(
        fn=train_model,
        inputs=[
            dataset_dropdown,
            custom_dataset_id,
            label_col,
            feature_cols,
            learning_rate_slider,
            n_estimators_slider,
            max_depth_slider,
            test_size_slider,
        ],
        outputs=[output_text, output_plot],
    )

# Start the app when this file is executed.
demo.launch()