|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from sklearn.preprocessing import LabelEncoder |
|
|
from src import decision_tree_core |
|
|
import vlai_template |
|
|
|
|
|
|
|
|
# Most recently loaded dataset. The loader callback assigns it on success and
# resets it to None on failure; the other callbacks in this module read it.
current_dataframe = None

# Built-in sample datasets: display name -> target column and problem type.
# Every bundled dataset uses "target" as its target column.
SAMPLE_DATA_CONFIG = {
    dataset_name: {"target_column": "target", "problem_type": task}
    for dataset_name, task in (
        ("Iris", "classification"),
        ("Wine", "classification"),
        ("Breast Cancer", "classification"),
        ("Diabetes", "regression"),
    )
}

# JavaScript passed to gr.Blocks(js=...). When the page URL carries no
# `__theme` query parameter it adds `__theme=light` and reassigns
# window.location.search with the updated query string.
force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""
|
|
|
|
|
def validate_config(df, target_col):
    """Validate the target column and infer the problem type.

    A target is treated as classification when its dtype is object, string,
    categorical or boolean, or when it has at most ``min(20, 10% of rows)``
    distinct values. Anything else is treated as regression.

    Args:
        df: DataFrame holding the loaded dataset.
        target_col: Name of the column chosen as the prediction target.

    Returns:
        A ``(is_valid, message, problem_type)`` tuple. ``problem_type`` is
        ``"classification"`` or ``"regression"`` when the configuration is
        valid and ``None`` otherwise. ``message`` is user-facing Markdown and
        always starts with a status icon (the success message has a leading
        newline before the icon).
    """
    if not target_col or target_col not in df.columns:
        return False, "❌ Please select a valid target column from the dropdown.", None

    target_series = df[target_col]
    if target_series.empty:
        # Without this guard an empty column passes as "0 classes".
        return False, "⚠️ Target column has no values. Consider another target.", None

    unique_vals = target_series.nunique()

    # Label-like dtypes are classification regardless of cardinality. Checking
    # only for "object" misses categorical, boolean and dedicated string dtypes.
    dtype = target_series.dtype
    is_label_like = (
        pd.api.types.is_object_dtype(dtype)
        or pd.api.types.is_string_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
        or pd.api.types.is_bool_dtype(dtype)
    )

    if is_label_like or unique_vals <= min(20, len(target_series) * 0.1):
        problem_type = "classification"
        if unique_vals > 50:
            return False, f"⚠️ Too many classes ({unique_vals}). Consider another target.", None
        if target_series.isnull().any():
            return False, "⚠️ Target column has missing values. Please clean your data.", None
    else:
        problem_type = "regression"
        if unique_vals < 5:
            return False, f"⚠️ Too few unique values ({unique_vals}). Consider another target.", None

    unit = "classes" if problem_type == "classification" else "values"
    return True, f"\n✅ Configuration is valid! Ready for {unique_vals} {unit}.", problem_type
|
|
|
|
|
|
|
|
def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
    """Build the Markdown status line describing the loaded dataset.

    Args:
        is_sample: True when a built-in sample dataset is loaded.
        dataset_choice: Display name of the sample dataset.
        target_col: Selected target column, or None when none is chosen yet.
        problem_type: ``"classification"``/``"regression"``, or None if unknown.
        is_valid: Whether the target configuration passed validation. Only
            affects the icon used for custom data.
        validation_msg: Validation message appended for custom data.

    Returns:
        A Markdown string.
    """
    has_config = bool(target_col and problem_type)

    if is_sample:
        if not has_config:
            # A sample without a configured target used to reach
            # problem_type.title() with None and raise AttributeError.
            return f"⚠️ **Selected Dataset**: {dataset_choice} | 👆 Please select target column above to continue."
        return f"✅ **Selected Dataset**: {dataset_choice} | **Target**: {target_col} | **Type**: {problem_type.title()}"

    if has_config:
        status_icon = "✅" if is_valid else "⚠️"
        return f"{status_icon} **Custom Data** | **Target**: {target_col} | **Type**: {problem_type.title()} | {validation_msg}"

    return "📁 **Custom data uploaded!** 👆 Please select target column above to continue."
|
|
|
|
|
|
|
|
# Number of feature slots the UI provides. Each slot is a (Number, Dropdown)
# pair, so callbacks return twice this many widget updates.
_MAX_FEATURE_SLOTS = 20

# Label of the dataset currently held in ``current_dataframe``; used in the
# status line rebuilt when the target column changes.
_current_dataset_label = "Custom Data"


def _hidden_input_updates():
    """Return one hide-update per feature input widget."""
    return [gr.update(visible=False) for _ in range(_MAX_FEATURE_SLOTS * 2)]


def _build_input_updates(components_info):
    """Turn feature descriptors into updates for the paired input widgets.

    Each descriptor is a mapping with ``type``, ``name`` and ``value`` keys,
    plus ``minimum``/``maximum`` for numeric features and ``choices`` for the
    others. Feature ``i`` owns positions ``2*i`` (Number) and ``2*i + 1``
    (Dropdown); the widget not in use stays hidden. Descriptors beyond the
    available slots are ignored.
    """
    updates = _hidden_input_updates()
    for slot, comp in zip(range(_MAX_FEATURE_SLOTS), components_info):
        number_idx, dropdown_idx = slot * 2, slot * 2 + 1
        if comp["type"] == "number":
            number_update = {"visible": True, "label": comp["name"], "value": comp["value"]}
            # Only forward bounds that are actually set.
            for bound in ("minimum", "maximum"):
                if comp[bound] is not None:
                    number_update[bound] = comp[bound]
            updates[number_idx] = gr.update(**number_update)
        else:
            updates[dropdown_idx] = gr.update(
                visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
            )
    return updates


def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
    """Load a dataset and prepare target, status and feature-input updates.

    The loaded frame is stored in the module-level ``current_dataframe``,
    which is reset to None if loading fails.
    NOTE(review): this state is module-level, so it is presumably shared by
    all sessions served by the process — confirm whether per-session state
    is needed.

    Args:
        file_obj: Uploaded file, or None to load a sample dataset.
        dataset_choice: Sample dataset name, used when ``file_obj`` is None.

    Returns:
        A list: data preview (first 5 rows, rounded to 2 decimals), target
        dropdown, status Markdown, 40 feature-widget updates, the visibility
        update for the inputs group, and the input status Markdown.
    """
    global current_dataframe, _current_dataset_label
    try:
        df = decision_tree_core.load_data(file_obj, dataset_choice)
        current_dataframe = df

        target_options = df.columns.tolist()
        is_sample = file_obj is None
        _current_dataset_label = dataset_choice if is_sample else "Custom Data"

        if is_sample:
            cfg = SAMPLE_DATA_CONFIG.get(dataset_choice, {})
            target_col = cfg.get("target_column")
            problem_type = cfg.get("problem_type")
        else:
            # Uploaded data has no preset target; the user picks one afterwards.
            target_col, problem_type = None, None

        is_valid, validation_msg = False, ""
        if target_col:
            is_valid, validation_msg, detected = validate_config(df, target_col)
            if detected:
                # The type detected from the data wins over the preset one.
                problem_type = detected
        status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg)

        input_updates = _hidden_input_updates()
        inputs_visible = gr.update(visible=False)
        input_status = "⚙️ Configure target column above to enable feature inputs."

        if target_col and problem_type and (not is_sample or is_valid):
            try:
                components_info = decision_tree_core.create_input_components(df, target_col)
                # Assigned only on success, so a failure leaves every widget hidden.
                input_updates = _build_input_updates(components_info)
                inputs_visible = gr.update(visible=True)
                input_status = f"📝 **Ready!** Enter values for {len(components_info)} features below, then click Run prediction. | {validation_msg}"
            except Exception as e:
                input_status = f"❌ Error generating inputs: {str(e)}"

        return [df.head(5).round(2), gr.Dropdown(choices=target_options, value=target_col), status_msg] + input_updates + [inputs_visible, input_status]

    except Exception as e:
        # UI boundary: report the failure in the status area instead of raising.
        current_dataframe = None
        _current_dataset_label = "Custom Data"
        empty = [pd.DataFrame(), gr.Dropdown(choices=[], value=None), f"❌ **Error loading data**: {str(e)} | Please try a different file or dataset."]
        return empty + _hidden_input_updates() + [gr.update(visible=False), "No data loaded."]


def update_configuration(df_preview, target_col):
    """Rebuild the feature widgets when the target column changes.

    Args:
        df_preview: Value of the preview table. Unused; the full dataset is
            read from ``current_dataframe``.
        target_col: Newly selected target column.

    Returns:
        A list: 40 feature-widget updates, the visibility update for the
        inputs group, the input status Markdown and the dataset status
        Markdown.
    """
    df = current_dataframe

    def _all_hidden(message):
        # Same message goes to both the input status and the dataset status.
        return _hidden_input_updates() + [gr.update(visible=False), message, message]

    if df is None or df.empty:
        return _all_hidden("No data available.")
    if not target_col:
        return _all_hidden("Select target column.")

    try:
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            # The validation message already starts with its own icon.
            return _all_hidden(validation_msg)

        components_info = decision_tree_core.create_input_components(df, target_col)
        input_updates = _build_input_updates(components_info)
        input_status = f"📝 Enter values for {len(components_info)} features | {validation_msg}"
        status_msg = f"✅ **Selected Dataset**: {_current_dataset_label} | **Target**: {target_col} | **Type**: {problem_type.title()}"
        return input_updates + [gr.update(visible=True), input_status, status_msg]

    except Exception as e:
        return _all_hidden(f"❌ Error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
# Split criteria accepted when the problem type is classification.
CLASS_CRITS = {
    "entropy",
    "gini",
    "log_loss",
}

# Split criteria accepted when the problem type is regression.
REGR_CRITS = {
    "absolute_error",
    "friedman_mse",
    "poisson",
    "squared_error",
}
|
|
|
|
|
def update_criterion_choices(problem_type):
    """Return the criterion dropdown configured for ``problem_type``.

    ``"classification"`` yields the sorted classification criteria with
    ``"gini"`` selected; any other value yields the sorted regression
    criteria with ``"squared_error"`` selected.
    """
    is_classification = problem_type == "classification"
    options = CLASS_CRITS if is_classification else REGR_CRITS
    selected = "gini" if is_classification else "squared_error"
    return gr.Dropdown(choices=sorted(options), value=selected)
|
|
|
|
|
|
|
|
def update_criterion_on_target_change(df_preview, target_col):
    """Return the criterion dropdown matching the current data and target.

    The problem type is re-detected from ``current_dataframe`` and
    ``target_col``. Whenever it cannot be determined (no target, no data,
    invalid target, or a validation error) the classification criteria are
    returned. Previously an invalid target fell through to the regression
    criteria, unlike every other fallback path.

    Args:
        df_preview: Value of the preview table. Unused; the full dataset is
            read from ``current_dataframe``.
        target_col: Currently selected target column.

    Returns:
        A ``gr.Dropdown`` carrying the choices and default value to apply.
    """
    problem_type = "classification"  # default when the type is unknown
    df = current_dataframe

    if target_col and df is not None and not df.empty:
        try:
            is_valid, _, detected = validate_config(df, target_col)
        except Exception:
            is_valid, detected = False, None
        if is_valid and detected:
            problem_type = detected

    return update_criterion_choices(problem_type)
|
|
|
|
|
|
|
|
def _collect_new_point(components_info, input_values):
    """Map the flat widget values back to a ``{feature name: value}`` dict.

    Feature ``i`` reads position ``2*i`` (Number widget) when its type is
    ``"number"`` and position ``2*i + 1`` (Dropdown widget) otherwise. A
    missing or None widget value falls back to the descriptor's default.
    """
    new_point = {}
    for i, comp in enumerate(components_info):
        idx = i * 2 if comp["type"] == "number" else i * 2 + 1
        value = input_values[idx] if idx < len(input_values) else None
        new_point[comp["name"]] = comp["value"] if value is None else value
    return new_point


def _decode_class_label(y_series, prediction):
    """Map an integer-like class code back to the original label.

    Only applies to object, string or categorical targets. Codes are looked up
    in a LabelEncoder fitted on the non-null target values cast to ``str``.
    NOTE(review): this presumably mirrors the encoding applied inside
    decision_tree_core — confirm.
    """
    dtype = y_series.dtype
    is_label_like = (
        pd.api.types.is_object_dtype(dtype)
        or pd.api.types.is_string_dtype(dtype)
        # isinstance check replaces the deprecated is_categorical_dtype().
        or isinstance(dtype, pd.CategoricalDtype)
    )
    if not is_label_like:
        return prediction

    looks_like_code = isinstance(prediction, (int, np.integer, float, np.floating)) or str(prediction).isdigit()
    if not looks_like_code:
        return prediction

    le = LabelEncoder()
    le.fit(y_series.dropna().astype(str))
    code = int(prediction)
    if 0 <= code < len(le.classes_):
        return le.inverse_transform([code])[0]
    return prediction


def execute_prediction(df_preview, target_col, max_depth, min_samples_split, min_samples_leaf, criterion, *input_values):
    """Fit the tree, predict the new point and produce all outputs.

    Always returns 5 values, including on errors.

    Args:
        df_preview: Value of the preview table. Unused; the full dataset is
            read from ``current_dataframe``.
        target_col: Selected target column.
        max_depth, min_samples_split, min_samples_leaf: Tree parameters,
            forwarded unchanged to decision_tree_core.
        criterion: Split criterion. Replaced by ``"gini"`` or
            ``"squared_error"`` when it does not fit the detected problem type.
        *input_values: Flat values of the feature widgets, two per feature
            slot (Number, Dropdown).

    Returns:
        ``(tree_figure, importance_figure, result_markdown, details_markdown,
        summary_markdown)``.
    """
    df = current_dataframe

    EMPTY_PLOT = None
    EMPTY_MD = " "

    if df is None or df.empty:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **No data loaded!** 📁 Please select a sample dataset or upload a file first.", EMPTY_MD, "Load data to get started.")
    if not target_col:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **Configuration incomplete!** 🎯 Please select target column above.", EMPTY_MD, "Complete configuration to proceed.")

    try:
        # Inside the try so a validation failure still yields 5 outputs.
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Configuration issue**: {validation_msg}", EMPTY_MD, "Fix the configuration and try again.")

        # The dropdown may still hold a criterion from the other problem type.
        if problem_type == "classification":
            allowed, fallback = CLASS_CRITS, "gini"
        else:
            allowed, fallback = REGR_CRITS, "squared_error"
        if criterion not in allowed:
            criterion = fallback

        components_info = decision_tree_core.create_input_components(df, target_col)
        new_point_dict = _collect_new_point(components_info, input_values)

        tree_fig, importance_fig, prediction, pred_details, summary, error = decision_tree_core.run_decision_tree_and_visualize(
            df, target_col, new_point_dict, max_depth, min_samples_split, min_samples_leaf, criterion, problem_type
        )

        if error:
            return (tree_fig or EMPTY_PLOT, importance_fig or EMPTY_PLOT, f"❌ **Prediction failed**: {error}", pred_details or EMPTY_MD, summary or "Adjust inputs and retry.")

        if problem_type == "classification":
            try:
                display_pred = _decode_class_label(df[target_col], prediction)
            except Exception:
                # Decoding is cosmetic; fall back to the raw prediction.
                display_pred = prediction
            header = f"## 🎯 **Classification Result**: {display_pred}\n*Based on decision tree with `{criterion}`*"
        else:
            header = f"## 🎯 **Regression Result**: {float(prediction):.3f}\n*Based on decision tree with `{criterion}`*"

        return (tree_fig, importance_fig, header, pred_details, summary)

    except Exception as e:
        # UI boundary: surface the failure in the result panel instead of raising.
        return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Execution error**: {str(e)}", EMPTY_MD, "Check inputs and try again.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# UI layout and event wiring.
# NOTE(review): the source reached review without indentation, so the container
# nesting below is reconstructed from statement order — confirm against the
# rendered layout (in particular where the Run button and Tips text sit).
# NOTE(review): emoji in labels were mis-encoded; those shown here are restored
# where the bytes were unambiguous and chosen to fit the label otherwise.
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.Markdown("### 🌳 **How to Use**: Select data → Configure target → Set tree parameters → Enter new point → Run prediction!")

    with gr.Row(equal_height=False, variant="panel"):
        # Left column: data selection, tree parameters and the new data point.
        with gr.Column(scale=45):
            with gr.Accordion("📊 Data & Configuration", open=True):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
                        file_upload = gr.File(label="📁 Upload Your Data", file_types=[".csv", ".xlsx", ".xls"])
                    with gr.Column(scale=3):
                        sample_dataset = gr.Dropdown(choices=list(SAMPLE_DATA_CONFIG.keys()), value="Iris", label="🗂️ Sample Datasets")

                with gr.Row():
                    target_column = gr.Dropdown(choices=[], label="🎯 Target Column", interactive=True)

                status_message = gr.Markdown("🔄 Loading sample data...")
                data_preview = gr.DataFrame(label="📋 Data Preview (First 5 Rows)", row_count=5, interactive=False, max_height=250)

            with gr.Accordion("⚙️ Parameters & Input", open=True):
                gr.Markdown("**🌳 Decision Tree Parameters**")
                with gr.Row():
                    max_depth = gr.Number(
                        label="Max Depth",
                        value=5, minimum=0, maximum=20, precision=0,
                        info="Set to 0 for unlimited depth"
                    )
                    min_samples_split = gr.Number(
                        label="Min Samples Split",
                        value=2, minimum=2, maximum=100, precision=0,
                        info="Minimum samples required to split an internal node"
                    )
                    min_samples_leaf = gr.Number(
                        label="Min Samples Leaf",
                        value=1, minimum=1, maximum=50, precision=0,
                        info="Minimum samples required to be at a leaf node"
                    )
                with gr.Row():
                    criterion = gr.Dropdown(
                        choices=sorted(CLASS_CRITS), value="gini", label="🎯 Criterion",
                        info="Objective to measure split quality (auto-switched for regression)"
                    )

                inputs_group = gr.Group(visible=False)
                with inputs_group:
                    input_status = gr.Markdown("Configure inputs above.")
                    gr.Markdown("**📝 New Data Point** - Enter feature values for prediction:")
                    # 5 rows x 4 columns = 20 feature slots. Each slot holds a
                    # hidden Number and a hidden Dropdown, appended in that
                    # order, so slot i sits at positions 2*i and 2*i + 1.
                    input_components = []
                    for row in range(5):
                        with gr.Row():
                            for col in range(4):
                                idx = row * 4 + col
                                number_comp = gr.Number(label=f"Feature {idx+1}", visible=False)
                                dropdown_comp = gr.Dropdown(label=f"Feature {idx+1}", visible=False)
                                input_components.extend([number_comp, dropdown_comp])

                run_prediction_btn = gr.Button("🚀 Run Prediction", variant="primary", size="lg")

        # Right column: plots and textual results.
        with gr.Column(scale=55):
            gr.Markdown("### 🌳 **Decision Tree Results & Visualization**")
            with gr.Tabs():
                with gr.TabItem("Decision Tree"):
                    tree_visualization = gr.Plot(label="Decision Tree (Interactive)", visible=True)
                with gr.TabItem("Feature Importance"):
                    feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)

            prediction_result = gr.Markdown("## 🎯 Prediction Result\n**Run prediction to see the result.**", label="📈 Final Prediction")
            prediction_details = gr.Markdown("**📊 Prediction Details**\n\nDetailed prediction information will appear here.", label="📊 Prediction Details")
            algorithm_summary = gr.Markdown("**📈 Algorithm Summary**\n\nAlgorithm details will appear here after prediction.", label="📋 Technical Details")

    gr.Markdown(
        "💡 **Tips**:\n"
        "- **Tree visualization** shows the complete decision tree structure with decision paths.\n"
        "- **Feature importance** shows which features matter most.\n"
        "- Try different **max depth** and **criterion** to see structure changes!\n"
        "- **Min samples split/leaf** control complexity and reduce overfitting.\n"
    )

    vlai_template.create_footer()

    # Outputs shared by the three events that (re)load a dataset.
    _load_outputs = [data_preview, target_column, status_message] + input_components + [inputs_group, input_status]
    _criterion_inputs = [data_preview, target_column]

    # Each loader is followed by a criterion refresh for the new target.
    load_evt = demo.load(
        fn=lambda: load_and_configure_data(None, "Iris"),
        outputs=_load_outputs,
    )
    load_evt.then(fn=update_criterion_on_target_change, inputs=_criterion_inputs, outputs=[criterion])

    upload_evt = file_upload.upload(
        fn=lambda file: load_and_configure_data(file, "Iris"),
        inputs=[file_upload],
        outputs=_load_outputs,
    )
    upload_evt.then(fn=update_criterion_on_target_change, inputs=_criterion_inputs, outputs=[criterion])

    sample_evt = sample_dataset.change(
        fn=lambda choice: load_and_configure_data(None, choice),
        inputs=[sample_dataset],
        outputs=_load_outputs,
    )
    sample_evt.then(fn=update_criterion_on_target_change, inputs=_criterion_inputs, outputs=[criterion])

    # A target change rebuilds the feature widgets and refreshes the criterion.
    target_column.change(
        fn=update_configuration, inputs=[data_preview, target_column],
        outputs=input_components + [inputs_group, input_status, status_message],
    )
    target_column.change(
        fn=update_criterion_on_target_change, inputs=_criterion_inputs,
        outputs=[criterion],
    )

    run_prediction_btn.click(
        fn=execute_prediction,
        inputs=[data_preview, target_column, max_depth, min_samples_split, min_samples_leaf, criterion] + input_components,
        outputs=[tree_visualization, feature_importance_plot, prediction_result, prediction_details, algorithm_summary],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch only when run as a script; the two logo files and the static
    # directory are passed to Gradio as allowed paths.
    demo.launch(
        allowed_paths=[
            "static/aivn_logo.png",
            "static/vlai_logo.png",
            "static",
        ]
    )
|
|
|