# NOTE: "Spaces: Running" page-header artifact removed (Hugging Face Spaces UI residue, not code).
| # app.py | |
| # ===== 1) Install deps (Colab) ===== | |
| # !pip -q install kagglehub[pandas-datasets] scikit-learn matplotlib gradio pillow | |
| """ | |
| Feature/Depth/Sample Explorer | |
| Dataset: Customer Shopping Trends (Kaggle) | |
| URL: https://www.kaggle.com/datasets/iamsouravbanerjee/customer-shopping-trends-dataset | |
| Purpose: Educational tool to visualize how model complexity (tree depth), | |
| training sample size, and data dimensionality affect generalization | |
| (under/overfitting) via F1 on a held-out test set. | |
| """ | |
| # ===== 2) App (launch inline) ===== | |
| # import io, re | |
| from typing import List, Sequence #, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| # import matplotlib.pyplot as plt | |
| # from matplotlib.ticker import MaxNLocator | |
| import plotly.graph_objects as go | |
| # from plotly.subplots import make_subplots | |
| from sklearn.compose import ColumnTransformer | |
| from sklearn.preprocessing import OneHotEncoder | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.model_selection import train_test_split, GridSearchCV | |
| from sklearn.tree import DecisionTreeClassifier | |
| from sklearn.metrics import f1_score | |
| from sklearn.model_selection import StratifiedKFold | |
| # import kagglehub | |
| # from kagglehub import KaggleDatasetAdapter | |
| import gradio as gr | |
# Apply global styling
# CSS injected into the Gradio Blocks app (see gr.Blocks(css=custom_css) below):
# bolds text and normalizes wrapping for readability in the embedded UI.
custom_css = """
/* === Base font and readability === */
.gradio-container label,
.gradio-container h1,
.gradio-container h2,
.gradio-container h3,
.gradio-container p,
.gradio-container button,
.gradio-container span,
.gradio-container div {
font-weight: 600 !important;
line-height: 1.2 !important;
word-break: normal !important;
overflow-wrap: normal !important;
white-space: normal !important;
}
"""
# ---- App metadata ----
APP_NAME = "Feature/Depth/Sample Explorer"
# Previous dataset (kept for reference):
# DATASET_NAME = "Customer Shopping Trends (Kaggle)"
# DATASET_URL = "https://www.kaggle.com/datasets/iamsouravbanerjee/customer-shopping-trends-dataset"
DATASET_NAME = "UCI Irvine - Predict Students' Dropout and Academic Success"
DATASET_URL = "https://archive.ics.uci.edu/dataset/697/predict+students+dropout+and+academic+success"
# Short description shown in the "About" accordion.
TOOL_DESC = (
    "Explore how decision tree depth (model complexity), training sample size, "
    "and feature count affect performance (e.g., F1)."
)
# Markdown body for the "About this tool" accordion.
ABOUT_MD = f"""
### {APP_NAME}
{TOOL_DESC}
**Dataset:** [{DATASET_NAME}]({DATASET_URL})
**Target:** Student is a Dropout vs. Non-dropout.
This tool is for education only.
"""
# **Target:** `Discount Applied` (binary)
# ---------------- Config ----------------
TARGET_COL = "Target"           # binary target built in load_data(): 1 = Dropout, 0 = otherwise
TEST_SIZE = 700                 # rows held out once as the fixed test set
TRAIN_FOLD_STEP = 300           # step size for sample-size slider / learning-curve buckets
MIN_TRAIN_SIZE = 800            # smallest training size offered in the UI
N_SPLITS_K_FOLD = 4             # CV folds used by GridSearchCV in auto-depth mode
RANDOM_SEEDS = [42, 43, 44, 45, 46]          # repeated runs drive the 10-90% bands
DEFAULT_DEPTH_GRID = list(range(1, 51, 5))   # candidate max_depth values: 1, 6, ..., 46
# Columns treated as numeric (passed through the preprocessor unencoded; all
# other selected columns get one-hot encoded).
# FIX: the original literal was missing two commas (after "...1st sem (credited)"
# and "...2nd sem (approved)"), so adjacent strings were implicitly concatenated
# into two bogus column names and two real numeric columns were lost.
NUMERIC_CANDIDATES = [
    "Application order",
    "Previous qualification (grade)",
    "Admission grade",
    "Age",
    "Curricular units 1st sem (credited)",
    "Curricular units 1st sem (enrolled)",
    "Curricular units 1st sem (evaluations)",
    "Curricular units 1st sem (approved)",
    "Curricular units 1st sem (grade)",
    "Curricular units 1st sem (without evaluations)",
    "Curricular units 2nd sem (credited)",
    "Curricular units 2nd sem (enrolled)",
    "Curricular units 2nd sem (evaluations)",
    "Curricular units 2nd sem (approved)",
    "Curricular units 2nd sem (grade)",
    "Curricular units 2nd sem (without evaluations)",
    "Unemployment rate",
    "Inflation rate",
    "GDP",
]
# Growing training set: fixed seed keeps the nested subsets reproducible.
FIXED_SEED = 7
| # ---------------- Data loading ---------------- | |
def load_data() -> pd.DataFrame:
    """Fetch the UCI 'Predict Students' Dropout' dataset and attach a binary target.

    Returns:
        DataFrame of the raw feature columns plus TARGET_COL, where
        1 = "Dropout" and 0 = "Enrolled" or "Graduate".
    """
    # Imported lazily so the module can be inspected without the dependency.
    from ucimlrepo import fetch_ucirepo

    repo = fetch_ucirepo(id=697)
    frame = repo.data.features.copy()
    labels = repo.data.targets[TARGET_COL].str.strip()
    frame[TARGET_COL] = labels.map({"Dropout": 1, "Enrolled": 0, "Graduate": 0}).astype(int)
    return frame
| # file_path = "shopping_trends_updated.csv" | |
| # df = kagglehub.dataset_load( | |
| # KaggleDatasetAdapter.PANDAS, | |
| # "iamsouravbanerjee/customer-shopping-trends-dataset", | |
| # file_path, | |
| # ).copy() | |
| # df[TARGET_COL] = df[TARGET_COL].astype(str).str.strip().str.lower().map({"yes": 1, "no": 0}).astype(int) | |
| # return df | |
# Load once at import time; the whole app reuses this single DataFrame.
DF = load_data()
# Candidate features = every column except the target.
# ALL_FEATURES = DF.columns.drop([TARGET_COL, "Promo Code Used", "Customer ID"]).to_list()
ALL_FEATURES = DF.columns.drop([TARGET_COL]).to_list()
# Initial CheckboxGroup selection shown in the UI.
DEFAULT_SELECTED = ALL_FEATURES[:4]
def make_kfold_buckets(df: pd.DataFrame, target_col: str, k: int, seed: int = FIXED_SEED):
    """Split *df* into k stratified, shuffled index buckets.

    Returns:
        A list of k numpy index arrays (the StratifiedKFold held-out folds).
        Callers merge consecutive buckets to build nested training subsets
        of growing size.
    """
    labels = df[target_col].to_numpy()
    splitter = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
    # Only the fold assignment matters, so a dummy X of matching length suffices.
    return [held_out for _, held_out in splitter.split(np.zeros_like(labels), labels)]
# ---------------- Helpers ----------------
def stratified_fixed_sample(df: pd.DataFrame, train_total: int, seed: int = FIXED_SEED):
    """Split *df* into a stratified sample of *train_total* rows plus the remainder.

    Because *seed* defaults to a fixed value, the same *train_total* always
    yields the same split (reproducible train/test pools).

    Returns:
        (sample, remainder): both DataFrames with a fresh integer index.
        When *train_total* >= len(df) the whole frame is returned and
        *remainder* is None.

    FIX: the original unconditionally called ``test_sample.reset_index()``,
    which raised AttributeError whenever train_total >= len(df) left
    ``test_sample`` as None.
    """
    if train_total < len(df):
        sample, remainder = train_test_split(
            df,
            train_size=train_total,
            stratify=df[TARGET_COL],
            random_state=seed,  # fixed seed -> deterministic split
        )
        return sample.reset_index(), remainder.reset_index()
    # Degenerate case: nothing left over to form a remainder.
    return df.reset_index(), None
# Fixed once at import: everything except TEST_SIZE rows becomes the training
# pool; TEST_DF is the single held-out test set scored by every run.
TRAIN_DF, TEST_DF = stratified_fixed_sample(DF, train_total=len(DF) - TEST_SIZE)  # fixed once
# Number of stratified buckets of roughly TRAIN_FOLD_STEP rows each.
num_folds = int(np.floor(len(TRAIN_DF)/TRAIN_FOLD_STEP))
TRAIN_FOLDS = make_kfold_buckets(TRAIN_DF, TARGET_COL, num_folds)
TRAIN_POOL = TRAIN_DF.copy()
BUCKET_SIZES = [len(b) for b in TRAIN_FOLDS]
# Cumulative sizes let get_train_indices_for_n() count how many whole buckets fit.
CUM_BUCKET_SIZES = np.cumsum(BUCKET_SIZES)
TOTAL_TRAIN = len(TRAIN_POOL)
def get_train_indices_for_n(n_total: int) -> np.ndarray:
    """Return nested training indices of size min(n_total, TOTAL_TRAIN).

    Whole buckets from TRAIN_FOLDS are merged in order, then the next bucket
    is partially consumed so the result matches the requested length exactly.
    Because the buckets are fixed, the subset for a smaller n is always
    contained in the subset for a larger n (nested learning-curve samples).
    """
    wanted = min(n_total, TOTAL_TRAIN)
    # How many complete buckets fit inside `wanted`.
    n_full = int(np.searchsorted(CUM_BUCKET_SIZES, wanted, side='right'))
    if n_full == 0:
        # Requested size fits inside the first bucket: take its prefix.
        return TRAIN_FOLDS[0][:wanted]
    merged = np.concatenate(TRAIN_FOLDS[:n_full])
    shortfall = wanted - len(merged)
    if shortfall > 0:
        # Top up with a prefix of the next (partially used) bucket.
        merged = np.concatenate([merged, TRAIN_FOLDS[n_full][:shortfall]])
    return merged
def get_train_df_for_n(n_total: int) -> pd.DataFrame:
    """Materialize the nested training subset of size *n_total* as a DataFrame."""
    return TRAIN_POOL.loc[get_train_indices_for_n(n_total)]
def split_features(feats: Sequence[str]):
    """Partition *feats* into (numeric, categorical) lists, preserving order."""
    numeric = [name for name in feats if name in NUMERIC_CANDIDATES]
    categorical = [name for name in feats if name not in NUMERIC_CANDIDATES]
    return numeric, categorical
def build_preprocessor(feats: Sequence[str]) -> ColumnTransformer:
    """One-hot encode categorical features; pass numeric ones through unchanged."""
    numeric, categorical = split_features(feats)
    transformers = [
        ("cat", OneHotEncoder(handle_unknown="ignore"), categorical),
        ("num", "passthrough", numeric),
    ]
    return ColumnTransformer(transformers)
def one_run(feats, max_depth, n_total, seed, auto_depth, depth_grid=DEFAULT_DEPTH_GRID):
    """Train one decision tree and score it on train and the fixed test set.

    Uses the nested training subset of size *n_total* (see get_train_df_for_n)
    and the global TEST_DF. When *auto_depth* is True, max_depth is chosen by
    stratified GridSearchCV over *depth_grid* (F1-scored); otherwise
    *max_depth* is used directly.

    Returns:
        (train_f1, test_f1, actual_tree_depth)
    """
    assert len(feats) > 0, "Select at least one feature."
    train_df = get_train_df_for_n(n_total)
    X_train, y_train = train_df[feats], train_df[TARGET_COL]
    X_test, y_test = TEST_DF[feats], TEST_DF[TARGET_COL]

    prep = build_preprocessor(feats)
    if auto_depth:
        pipe = Pipeline([
            ("prep", prep),
            ("clf", DecisionTreeClassifier(random_state=seed, class_weight="balanced")),
        ])
        cv = StratifiedKFold(n_splits=N_SPLITS_K_FOLD, shuffle=True, random_state=seed)
        search = GridSearchCV(pipe, {"clf__max_depth": list(depth_grid)},
                              scoring="f1", cv=cv, refit=True, verbose=0)
        search.fit(X_train, y_train)
        model = search.best_estimator_
    else:
        tree = DecisionTreeClassifier(random_state=seed, class_weight="balanced",
                                      max_depth=max_depth)
        model = Pipeline([("prep", prep), ("clf", tree)]).fit(X_train, y_train)
    # Depth actually grown (may be smaller than the max_depth cap).
    chosen_depth = int(model.named_steps["clf"].get_depth())
    train_f1 = f1_score(y_train, model.predict(X_train))
    test_f1 = f1_score(y_test, model.predict(X_test))
    return train_f1, test_f1, chosen_depth
def percentile_band(arr: np.ndarray):
    """Aggregate per-seed runs into a mean curve with a 10-90% band.

    Args:
        arr: array of shape (n_seeds, n_points).

    Returns:
        (means, p10, p90), each of shape (n_points,), computed across seeds
        (axis 0).
    """
    lo, hi = np.percentile(arr, [10, 90], axis=0)
    return arr.mean(axis=0), lo, hi
def line_and_band(fig, x, mean, lo, hi, name, color, dash="solid"):
    """Add a mean line plus a shaded lo-hi band to *fig* (mutated in place).

    *color* must be an "rgba(...,1)" string; the band reuses it at 15% opacity.
    """
    band_fill = color.replace("1)", "0.15)")
    fig.add_trace(go.Scatter(x=x, y=mean, mode="lines+markers", name=name,
                             line=dict(color=color, dash=dash)))
    # Invisible lower edge; the following trace fills down to it ("tonexty").
    fig.add_trace(go.Scatter(x=x, y=lo, mode="lines", line=dict(width=0),
                             showlegend=False, hoverinfo="skip"))
    fig.add_trace(go.Scatter(x=x, y=hi, mode="lines", line=dict(width=0),
                             fill="tonexty", fillcolor=band_fill,
                             name=f"{name} 10–90%", hoverinfo="skip"))
# ---------------- Plots ----------------
def plot_f1_vs_features(selected_feats: List[str], max_depth: int, n_total: int, auto_depth: bool):
    """Plot train/test F1 as the first k selected features are added one at a time.

    For each prefix of *selected_feats* (k = 1..len), runs one_run() once per
    seed in RANDOM_SEEDS and aggregates into mean and 10-90% bands.

    Returns:
        (f1_figure, gr.update(...)): the second output controls the depth
        figure, which is only visible when auto_depth is True.
    """
    if not selected_feats:
        raise gr.Error("Please select at least one feature.")
    ks = list(range(1, len(selected_feats) + 1))
    tr_runs, te_runs, depth_runs = [], [], []
    for k in ks:
        tr_scores, te_scores, depths = [], [], []
        feats_k = selected_feats[:k]  # selection order preserved -> nested prefixes
        for s in RANDOM_SEEDS:
            tr, te, d = one_run(feats_k, max_depth, n_total, s, auto_depth)
            tr_scores.append(tr); te_scores.append(te); depths.append(d)
        tr_runs.append(tr_scores); te_runs.append(te_scores); depth_runs.append(depths)
    # Transpose to (n_seeds, n_ks) so percentile_band aggregates across seeds.
    tr_arr, te_arr = np.array(tr_runs).T, np.array(te_runs).T
    tr_m, tr_lo, tr_hi = percentile_band(tr_arr)
    te_m, te_lo, te_hi = percentile_band(te_arr)
    # X tick labels name the feature added at each step.
    x_labels = [selected_feats[i-1] for i in ks]
    # --- Figure 1: F1 vs #features ---
    fig_f1 = go.Figure()
    line_and_band(fig_f1, ks, tr_m, tr_lo, tr_hi, "Train F1", "rgba(31,119,180,1)")
    line_and_band(fig_f1, ks, te_m, te_lo, te_hi, "Test F1", "rgba(255,127,14,1)")
    mode = "auto-depth (grid search)" if auto_depth else f"max_depth={max_depth}"
    fig_f1.update_layout(
        title=f"F1 vs Features ({mode}; n={n_total})",
        template="plotly_white",
        height=600,
        margin=dict(l=40, r=10, t=60, b=60),
        legend=dict(orientation="h", y=-0.2),
        uirevision="keep-zoom"  # preserve user zoom/pan across re-runs
    )
    fig_f1.update_xaxes(tickmode="array", tickvals=ks, ticktext=x_labels, tickangle=-30)
    fig_f1.update_yaxes(title_text="F1 Score", range=[0, 1])
    # --- Figure 2: Depth vs #features (only when auto_depth) ---
    if auto_depth:
        depth_arr = np.array(depth_runs).T
        d_m, d_lo, d_hi = percentile_band(depth_arr)
        fig_depth = go.Figure()
        line_and_band(fig_depth, ks, d_m, d_lo, d_hi, "Depth", "rgba(44,160,44,1)", dash="dot")
        fig_depth.update_layout(
            title=f"Depth vs Features (n={n_total})",
            template="plotly_white",
            height=600,
            margin=dict(l=40, r=10, t=60, b=60),
            legend=dict(orientation="h", y=-0.2),
            uirevision="keep-zoom"
        )
        fig_depth.update_xaxes(tickmode="array",
                               tickvals=ks, ticktext=x_labels, tickangle=-30)
        # Guard against non-finite band edges before fixing the axis range.
        y_min = max(0, np.nanmin(d_lo) if np.isfinite(np.nanmin(d_lo)) else 0)
        y_max = np.nanmax(d_hi) if np.isfinite(np.nanmax(d_hi)) else None
        fig_depth.update_yaxes(title_text="Depth", dtick=1, range=[y_min, y_max])
    else:
        # Return a valid (empty) figure so Gradio Plot doesn't choke
        fig_depth = go.Figure()
    return fig_f1, gr.update(value=fig_depth, visible=auto_depth)
def plot_f1_vs_depth(selected_feats: List[str], n_total: int):
    """Plot train/test F1 versus max_depth for a fixed feature set and sample size.

    Sweeps max_depth over 1, 6, ..., 46, running one_run() once per seed in
    RANDOM_SEEDS and aggregating into mean and 10-90% bands.

    Returns:
        A single plotly figure (the click handler wires exactly one output).
    """
    if not selected_feats:
        raise gr.Error("Please select at least one feature.")
    depths = list(range(1, 51, 5))  # same grid as DEFAULT_DEPTH_GRID
    tr_runs, te_runs = [], []
    for d in depths:
        tr_scores, te_scores = [], []
        for s in RANDOM_SEEDS:
            tr, te, _ = one_run(selected_feats, d, n_total, s, auto_depth=False)
            tr_scores.append(tr); te_scores.append(te)
        tr_runs.append(tr_scores); te_runs.append(te_scores)
    # Transpose to (n_seeds, n_depths) so percentile_band aggregates across seeds.
    tr_arr, te_arr = np.array(tr_runs).T, np.array(te_runs).T
    tr_m, tr_lo, tr_hi = percentile_band(tr_arr)
    te_m, te_lo, te_hi = percentile_band(te_arr)
    fig_f1 = go.Figure()
    line_and_band(fig_f1, depths, tr_m, tr_lo, tr_hi, "Train F1", "rgba(31,119,180,1)")
    line_and_band(fig_f1, depths, te_m, te_lo, te_hi, "Test F1", "rgba(255,127,14,1)")
    fig_f1.update_layout(
        title=f"F1 vs Tree Depth (n={n_total}; #features={len(selected_feats)})",
        template="plotly_white",
        height=600,
        margin=dict(l=40, r=10, t=60, b=60),
        legend=dict(orientation="h", y=-0.2),
        uirevision="keep-zoom"  # preserve user zoom/pan across re-runs
    )
    fig_f1.update_yaxes(title_text="F1 Score", range=[0, 1])
    fig_f1.update_xaxes(title_text="max_depth", dtick=5)
    # IMPORTANT: return a single figure (not a tuple)
    return fig_f1
def plot_f1_vs_samplesize(selected_feats: List[str], max_depth: int, auto_depth: bool):
    """Learning curve: train/test F1 versus training-set size.

    Sample sizes step from MIN_TRAIN_SIZE up to the full training pool
    (len(DF) - TEST_SIZE) in TRAIN_FOLD_STEP increments; each size is run once
    per seed in RANDOM_SEEDS.

    Returns:
        (f1_figure, gr.update(...)): the second output controls the depth
        figure, which is only visible when auto_depth is True.
    """
    if not selected_feats:
        raise gr.Error("Please select at least one feature.")
    # +1 so the final, full-pool size is included in the sweep.
    sample_sizes = list(range(MIN_TRAIN_SIZE, len(DF) - TEST_SIZE + 1, TRAIN_FOLD_STEP))
    tr_runs, te_runs, depth_runs = [], [], []
    for n_total in sample_sizes:
        tr_scores, te_scores, depths = [], [], []
        for s in RANDOM_SEEDS:
            tr, te, d = one_run(selected_feats, max_depth, n_total, s, auto_depth)
            tr_scores.append(tr); te_scores.append(te); depths.append(d)
        tr_runs.append(tr_scores); te_runs.append(te_scores); depth_runs.append(depths)
    # Transpose to (n_seeds, n_sizes) so percentile_band aggregates across seeds.
    tr_arr, te_arr, d_arr = np.array(tr_runs).T, np.array(te_runs).T, np.array(depth_runs).T
    tr_m, tr_lo, tr_hi = percentile_band(tr_arr)
    te_m, te_lo, te_hi = percentile_band(te_arr)
    # ---- Figure 1: F1 vs Sample Size ----
    fig_f1 = go.Figure()
    line_and_band(fig_f1, sample_sizes, tr_m, tr_lo, tr_hi, "Train F1", "rgba(31,119,180,1)")
    line_and_band(fig_f1, sample_sizes, te_m, te_lo, te_hi, "Test F1", "rgba(255,127,14,1)")
    mode = "auto-depth (grid search)" if auto_depth else f"max_depth={max_depth}"
    fig_f1.update_layout(
        title=f"F1 vs Sample Size ({mode}; #features={len(selected_feats)})",
        template="plotly_white",
        height=600,
        margin=dict(l=40, r=10, t=60, b=60),
        legend=dict(orientation="h", y=-0.2),
        uirevision="keep-zoom"  # preserve user zoom/pan across re-runs
    )
    fig_f1.update_xaxes(title_text="Number of samples (n)")
    fig_f1.update_yaxes(title_text="F1 Score", range=[0, 1])
    # ---- Figure 2: Depth vs Sample Size ----
    if auto_depth:
        d_m, d_lo, d_hi = percentile_band(d_arr)
        fig_depth = go.Figure()
        fig_depth.add_trace(go.Scatter(x=sample_sizes, y=d_m, mode="lines+markers",
                                       name="Depth (mean)", line=dict(dash="dot")))
        # Invisible lower edge; the next trace fills down to it ("tonexty").
        fig_depth.add_trace(go.Scatter(x=sample_sizes, y=d_lo, mode="lines", line=dict(width=0),
                                       showlegend=False, hoverinfo="skip"))
        fig_depth.add_trace(go.Scatter(x=sample_sizes, y=d_hi, mode="lines", line=dict(width=0),
                                       fill="tonexty", name="Depth 10–90%", hoverinfo="skip"))
        fig_depth.update_layout(
            title=f"Depth vs Sample Size",
            template="plotly_white",
            height=600,
            margin=dict(l=40, r=10, t=60, b=60),
            legend=dict(orientation="h", y=-0.2),
            uirevision="keep-zoom"
        )
        fig_depth.update_xaxes(title_text="Number of samples (n)")
        # Guard against non-finite band edges before fixing the axis range.
        y_min = max(0, np.nanmin(d_lo) if np.isfinite(np.nanmin(d_lo)) else 0)
        y_max = np.nanmax(d_hi) if np.isfinite(np.nanmax(d_hi)) else None
        fig_depth.update_yaxes(title_text="Depth", dtick=1, range=[y_min, y_max])
    else:
        # Return a valid (empty) figure so Gradio Plot doesn't choke.
        fig_depth = go.Figure()
    return fig_f1, gr.update(value=fig_depth, visible=auto_depth)
| # ---------------- Gradio UI ---------------- | |
| with gr.Blocks(title="Feature/Depth/Sample Explorer", css=custom_css) as demo: | |
| with gr.Accordion("About this tool", open=False): | |
| gr.Markdown(ABOUT_MD) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| feat_choices = gr.CheckboxGroup( | |
| label="Select features (order is preserved):", | |
| choices=ALL_FEATURES, | |
| value=DEFAULT_SELECTED, | |
| ) | |
| gr.Markdown( | |
| f"**Dataset size:** {len(DF):,} rows • **Test size/run:** {TEST_SIZE} • **Seeds:** {len(RANDOM_SEEDS)}" | |
| ) | |
| with gr.Column(scale=2): | |
| # -------- Tab: F1 vs Features -------- | |
| with gr.Tab("F1 vs Features"): | |
| with gr.Row(): | |
| auto_depth_feat = gr.Checkbox(value=False, label="Auto-depth (grid 1..50 step 5)") | |
| depth_feat = gr.Slider(1, 50, value=5, step=1, label="max_depth (used when auto-depth is OFF)") | |
| n_total_feat = gr.Slider(minimum=MIN_TRAIN_SIZE, maximum=len(DF)-TEST_SIZE, value=min(MIN_TRAIN_SIZE, len(DF)), | |
| step=TRAIN_FOLD_STEP, label="Sample size (n)") | |
| btn_feat = gr.Button("Run") | |
| # Two plots: main F1 + depth | |
| plt_feat_main = gr.Plot(label="F1 vs Features", visible=True) | |
| plt_feat_depth = gr.Plot(label="Depth vs #Features", visible=False) | |
| # -------- Tab: F1 vs Depth -------- | |
| with gr.Tab("F1 vs Depth"): | |
| n_total_depth = gr.Slider(minimum=MIN_TRAIN_SIZE, maximum=len(DF)-TEST_SIZE, value=min(MIN_TRAIN_SIZE, len(DF)), | |
| step=TRAIN_FOLD_STEP, label="Sample size (n)") | |
| btn_depth = gr.Button("Run") | |
| plt_depth = gr.Plot(label="F1 vs Depth") | |
| # -------- Tab: F1 vs Sample Size -------- | |
| with gr.Tab("F1 vs Sample Size"): | |
| with gr.Row(): | |
| auto_depth_samp = gr.Checkbox(value=False, label="Auto-depth (grid 1..50 step 5)") | |
| depth_samp = gr.Slider(1, 50, value=5, step=1, label="max_depth (used when auto-depth is OFF)") | |
| btn_size = gr.Button("Run") | |
| # Two plots: main F1 + depth | |
| plt_size_main = gr.Plot(label="F1 vs Sample Size") | |
| plt_size_depth = gr.Plot(label="Depth vs Sample Size") | |
| def toggle_depth_and_plot(checked: bool): | |
| return gr.update(visible=not checked) | |
| auto_depth_feat.change( | |
| fn=toggle_depth_and_plot, | |
| inputs=auto_depth_feat, | |
| outputs=[depth_feat], | |
| ) | |
| auto_depth_samp.change( | |
| fn=toggle_depth_and_plot, | |
| inputs=auto_depth_samp, | |
| outputs=[depth_samp], | |
| ) | |
| # Wiring | |
| btn_feat.click( | |
| fn=plot_f1_vs_features, | |
| inputs=[feat_choices, depth_feat, n_total_feat, auto_depth_feat], | |
| outputs=[plt_feat_main, plt_feat_depth], | |
| ) | |
| btn_depth.click( | |
| fn=plot_f1_vs_depth, | |
| inputs=[feat_choices, n_total_depth], | |
| outputs=plt_depth, # single figure | |
| ) | |
| btn_size.click( | |
| fn=plot_f1_vs_samplesize, | |
| inputs=[feat_choices, depth_samp, auto_depth_samp], | |
| outputs=[plt_size_main, plt_size_depth], | |
| ) | |
| # AUTO-RUN on load with default values (return exactly 5 figures) | |
| demo.load( | |
| fn=lambda feats, d_feat, n_feat, auto_feat, n_depth, d_samp, auto_samp: ( | |
| *plot_f1_vs_features(feats, d_feat, n_feat, auto_feat), # -> 2 figs | |
| plot_f1_vs_depth(feats, n_depth), # -> 1 fig | |
| *plot_f1_vs_samplesize(feats, d_samp, auto_samp), # -> 2 figs | |
| ), | |
| inputs=[feat_choices, depth_feat, n_total_feat, auto_depth_feat, n_total_depth, depth_samp, auto_depth_samp], | |
| outputs=[plt_feat_main, plt_feat_depth, plt_depth, plt_size_main, plt_size_depth], | |
| ) | |
# Number of concurrent event handlers; overridable via the WORKERS env var.
workers = int(os.getenv("WORKERS", "4"))
# set a global default concurrency for all events
demo.queue(
    default_concurrency_limit=workers,
    max_size=100,  # cap on queued requests before new ones are rejected
    status_update_rate="auto"  # or a number of seconds
)
demo.launch(
    server_name="0.0.0.0",  # bind all interfaces (container/Spaces friendly)
    server_port=int(os.getenv("PORT", "7860")),
    show_error=True,  # surface handler exceptions in the UI
)