# Exported from the Hugging Face Space file viewer; original header metadata:
#   wjnwjn59 — "update parameter" (commit 28c5210)
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from src import decision_tree_core
import vlai_template
# Global state
# Most recently loaded full dataset; event handlers read this instead of the
# DataFrame preview Gradio passes around (which is only 5 rounded rows).
current_dataframe = None
# Dataset configurations
# Built-in sample datasets with their known target column and problem type.
SAMPLE_DATA_CONFIG = {
    "Iris": {"target_column": "target", "problem_type": "classification"},
    "Wine": {"target_column": "target", "problem_type": "classification"},
    "Breast Cancer": {"target_column": "target", "problem_type": "classification"},
    "Diabetes": {"target_column": "target", "problem_type": "regression"},
}
# Client-side JS run on page load: pin Gradio's light theme via the
# `__theme` query parameter unless the URL already sets one.
force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}
"""
def validate_config(df, target_col):
    """Validate target column and determine problem type.

    Parameters
    ----------
    df : pd.DataFrame
        Full dataset including the candidate target column.
    target_col : str
        Name of the column to use as the prediction target.

    Returns
    -------
    tuple[bool, str, str | None]
        (is_valid, user-facing message, "classification"/"regression"/None).
    """
    if not target_col or target_col not in df.columns:
        return False, "❌ Please select a valid target column from the dropdown.", None
    target_series = df[target_col]
    # Fix: reject missing target values for BOTH problem types. Previously
    # only the classification branch checked, so a NaN-containing numeric
    # target slipped through and broke the regression fit downstream.
    if target_series.isnull().any():
        return False, "⚠️ Target column has missing values. Please clean your data.", None
    unique_vals = target_series.nunique()
    # Auto-detect: string dtype, or few unique values relative to dataset
    # size (<= min(20, 10% of rows)), is treated as classification.
    if target_series.dtype == "object" or unique_vals <= min(20, len(target_series) * 0.1):
        problem_type = "classification"
        if unique_vals > 50:
            return False, f"⚠️ Too many classes ({unique_vals}). Consider another target.", None
    else:
        problem_type = "regression"
        if unique_vals < 5:
            return False, f"⚠️ Too few unique values ({unique_vals}). Consider another target.", None
    return True, f"\nβœ… Configuration is valid! Ready for {unique_vals} {'classes' if problem_type=='classification' else 'values'}.", problem_type
def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
    """Build the markdown status line shown beneath the data controls.

    Three cases: a pre-configured sample dataset, a fully configured custom
    upload, or a custom upload still awaiting target selection.
    """
    if is_sample:
        return (
            f"βœ… **Selected Dataset**: {dataset_choice} | "
            f"**Target**: {target_col} | **Type**: {problem_type.title()}"
        )
    if not target_col or not problem_type:
        return "πŸ“ **Custom data uploaded!** πŸ‘† Please select target column above to continue."
    icon = "βœ…" if is_valid else "⚠️"
    return (
        f"{icon} **Custom Data** | **Target**: {target_col} | "
        f"**Type**: {problem_type.title()} | {validation_msg}"
    )
def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
    """Load data and prepare target/problem type + feature inputs.

    Parameters
    ----------
    file_obj : uploaded file or None
        When None, the named sample dataset is loaded instead.
    dataset_choice : str
        Key into SAMPLE_DATA_CONFIG, used only when ``file_obj`` is None.

    Returns
    -------
    list
        [preview DataFrame, target Dropdown, status markdown]
        + 40 feature-widget updates + [inputs-group visibility, input status].
        Ordering must match the positional ``outputs=`` wiring in the UI.
    """
    global current_dataframe
    try:
        df = decision_tree_core.load_data(file_obj, dataset_choice)
        current_dataframe = df  # cache full df; other handlers only get the preview
        target_options = df.columns.tolist()
        is_sample = file_obj is None
        if is_sample:
            # Sample datasets ship with a known target column and problem type.
            cfg = SAMPLE_DATA_CONFIG.get(dataset_choice, {})
            target_col = cfg.get("target_column")
            problem_type = cfg.get("problem_type")
        else:
            # Custom upload: user must pick the target afterwards.
            target_col, problem_type = None, None
        # Validate & status
        if target_col:
            is_valid, validation_msg, detected = validate_config(df, target_col)
            if detected:
                problem_type = detected
            status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg)
        else:
            status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, False, "")
        # Build feature input widgets
        input_updates = [gr.update(visible=False)] * 40  # 20 features * (number + dropdown)
        inputs_visible = gr.update(visible=False)
        input_status = "βš™οΈ Configure target column above to enable feature inputs."
        if target_col and problem_type and (not is_sample or is_valid):
            try:
                components_info = decision_tree_core.create_input_components(df, target_col)
                # Each feature owns two widget slots (number at 2*i, dropdown
                # at 2*i+1); only the slot matching the feature type is shown.
                for i in range(min(20, len(components_info))):
                    comp = components_info[i]
                    number_idx, dropdown_idx = i * 2, i * 2 + 1
                    if comp["type"] == "number":
                        upd = {"visible": True, "label": comp["name"], "value": comp["value"]}
                        if comp["minimum"] is not None:
                            upd["minimum"] = comp["minimum"]
                        if comp["maximum"] is not None:
                            upd["maximum"] = comp["maximum"]
                        input_updates[number_idx] = gr.update(**upd)
                        input_updates[dropdown_idx] = gr.update(visible=False)
                    else:
                        input_updates[number_idx] = gr.update(visible=False)
                        input_updates[dropdown_idx] = gr.update(
                            visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
                        )
                inputs_visible = gr.update(visible=True)
                input_status = f"πŸ“ **Ready!** Enter values for {len(components_info)} features below, then click Run prediction. | {validation_msg}"
            except Exception as e:
                input_status = f"❌ Error generating inputs: {str(e)}"
        return [df.head(5).round(2), gr.Dropdown(choices=target_options, value=target_col), status_msg] + input_updates + [inputs_visible, input_status]
    except Exception as e:
        # Any load failure: reset cached state and hide all feature widgets.
        current_dataframe = None
        empty = [pd.DataFrame(), gr.Dropdown(choices=[], value=None), f"❌ **Error loading data**: {str(e)} | Please try a different file or dataset."]
        return empty + [gr.update(visible=False)] * 40 + [gr.update(visible=False), "No data loaded."]
def update_configuration(df_preview, target_col):
    """Rebuild feature widgets when target changes.

    Returns 40 feature-widget updates + [inputs-group visibility, input
    status, main status] — ordering must match the ``outputs=`` wiring on
    ``target_column.change``.
    """
    global current_dataframe
    # The preview argument is lossy (5 rounded rows); use the cached full df.
    df = current_dataframe
    if df is None or df.empty:
        return [gr.update(visible=False)] * 40 + [gr.update(visible=False), "No data available.", "No data available."]
    if not target_col:
        return [gr.update(visible=False)] * 40 + [gr.update(visible=False), "Select target column.", "Select target column."]
    try:
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            return [gr.update(visible=False)] * 40 + [gr.update(visible=False), f"⚠️ {validation_msg}", f"⚠️ {validation_msg}"]
        components_info = decision_tree_core.create_input_components(df, target_col)
        input_updates = [gr.update(visible=False)] * 40
        # Two widget slots per feature (number, dropdown); show only the one
        # matching the feature's type.
        for i in range(min(20, len(components_info))):
            comp = components_info[i]
            number_idx, dropdown_idx = i * 2, i * 2 + 1
            if comp["type"] == "number":
                upd = {"visible": True, "label": comp["name"], "value": comp["value"]}
                if comp["minimum"] is not None:
                    upd["minimum"] = comp["minimum"]
                if comp["maximum"] is not None:
                    upd["maximum"] = comp["maximum"]
                input_updates[number_idx] = gr.update(**upd)
                input_updates[dropdown_idx] = gr.update(visible=False)
            else:
                input_updates[number_idx] = gr.update(visible=False)
                input_updates[dropdown_idx] = gr.update(
                    visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
                )
        input_status = f"πŸ“ Enter values for {len(components_info)} features | {validation_msg}"
        status_msg = f"βœ… **Selected Dataset**: Custom Data | **Target**: {target_col} | **Type**: {problem_type.title()}"
        return input_updates + [gr.update(visible=True), input_status, status_msg]
    except Exception as e:
        return [gr.update(visible=False)] * 40 + [gr.update(visible=False), f"❌ Error: {str(e)}", f"❌ Error: {str(e)}"]
# ---- criterion helpers ----
# Valid sklearn DecisionTree split criteria, per problem type.
CLASS_CRITS = {"gini", "entropy", "log_loss"}
REGR_CRITS = {"squared_error", "absolute_error", "friedman_mse", "poisson"}
def update_criterion_choices(problem_type):
    """Return the criterion dropdown configured for the given problem type."""
    if problem_type == "classification":
        choices, default = CLASS_CRITS, "gini"
    else:
        choices, default = REGR_CRITS, "squared_error"
    return gr.Dropdown(choices=sorted(choices), value=default)
def update_criterion_on_target_change(df_preview, target_col):
    """Recompute problem type from current df + target and return the right dropdown config.

    Falls back to the classification criteria (default "gini") whenever the
    target is unset, no data is loaded, or validation fails/raises — the same
    classification-first default used everywhere else in this handler.
    """
    def _classification_dropdown():
        # Shared fallback/default: classification criteria with "gini" selected.
        return gr.Dropdown(choices=sorted(CLASS_CRITS), value="gini")

    df = current_dataframe
    if not target_col or df is None or df.empty:
        return _classification_dropdown()
    try:
        # is_valid is unused here; an invalid config yields problem_type=None.
        _, _, problem_type = validate_config(df, target_col)
    except Exception:
        return _classification_dropdown()
    if problem_type == "regression":
        return gr.Dropdown(choices=sorted(REGR_CRITS), value="squared_error")
    # Fix: previously a failed validation (problem_type is None) fell into the
    # *regression* branch; treat anything but explicit "regression" as
    # classification, consistent with the other fallbacks.
    return _classification_dropdown()
def execute_prediction(df_preview, target_col, max_depth, min_samples_split, min_samples_leaf, criterion, *input_values):
    """Run the tree and produce all outputs. Always return 5 values.

    Parameters
    ----------
    df_preview : unused Gradio passthrough; the real data comes from the
        module-level ``current_dataframe``.
    target_col : str
        Target column selected in the UI.
    max_depth, min_samples_split, min_samples_leaf : numeric
        Tree hyperparameters forwarded to the core runner.
    criterion : str
        Split criterion; normalized to a value valid for the detected
        problem type before running.
    *input_values
        Flattened widget values, two slots per feature (number, dropdown).

    Returns
    -------
    tuple
        (tree_fig, importance_fig, result_md, details_md, summary_md).
    """
    df = current_dataframe
    EMPTY_PLOT = None  # placeholder for gr.Plot outputs
    EMPTY_MD = " "  # placeholder for gr.Markdown outputs
    if df is None or df.empty:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **No data loaded!** πŸ“Š Please select a sample dataset or upload a file first.", EMPTY_MD, "Load data to get started.")
    if not target_col:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **Configuration incomplete!** 🎯 Please select target column above.", EMPTY_MD, "Complete configuration to proceed.")
    is_valid, validation_msg, problem_type = validate_config(df, target_col)
    if not is_valid:
        return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Configuration issue**: {validation_msg}", EMPTY_MD, "Fix the configuration and try again.")
    # Normalize criterion defensively: the dropdown can still hold a value
    # from the other problem type when the user just switched targets.
    if problem_type == "classification":
        if criterion not in CLASS_CRITS:
            criterion = "gini"
    elif criterion not in REGR_CRITS:
        criterion = "squared_error"
    try:
        components_info = decision_tree_core.create_input_components(df, target_col)
        # Rebuild the new data point: each feature occupies two widget slots
        # (number at 2*i, dropdown at 2*i+1); fall back to the component's
        # default when the corresponding widget value is missing/None.
        new_point_dict = {}
        for i, comp in enumerate(components_info):
            number_idx, dropdown_idx = i * 2, i * 2 + 1
            if comp["type"] == "number":
                v = input_values[number_idx] if number_idx < len(input_values) and input_values[number_idx] is not None else comp["value"]
            else:
                v = input_values[dropdown_idx] if dropdown_idx < len(input_values) and input_values[dropdown_idx] is not None else comp["value"]
            new_point_dict[comp["name"]] = v
        tree_fig, importance_fig, prediction, pred_details, summary, error = decision_tree_core.run_decision_tree_and_visualize(
            df, target_col, new_point_dict, max_depth, min_samples_split, min_samples_leaf, criterion, problem_type
        )
        if error:
            return (tree_fig or EMPTY_PLOT, importance_fig or EMPTY_PLOT, f"❌ **Prediction failed**: {error}", pred_details or EMPTY_MD, summary or "Adjust inputs and retry.")
        # Decode an integer class id back to the original label when the
        # target is a string/categorical column (the core encodes labels).
        display_pred = prediction
        if problem_type == "classification":
            try:
                y_series = df[target_col]
                # isinstance(..., pd.CategoricalDtype) replaces the deprecated
                # pd.api.types.is_categorical_dtype (removed in pandas 3.x).
                if pd.api.types.is_object_dtype(y_series) or isinstance(y_series.dtype, pd.CategoricalDtype):
                    le = LabelEncoder()
                    le.fit(y_series.dropna().astype(str))
                    # Predictions are 0..K-1 codes; map back via inverse_transform.
                    if isinstance(prediction, (int, np.integer, float, np.floating)) or str(prediction).isdigit():
                        code = int(prediction)
                        if 0 <= code < len(le.classes_):
                            display_pred = le.inverse_transform([code])[0]
            except Exception:
                # Best-effort decode only; fall back to the raw prediction.
                display_pred = prediction
        if problem_type == "classification":
            header = f"## 🎯 **Classification Result**: {display_pred}\n*Based on decision tree with `{criterion}`*"
        else:
            header = f"## 🎯 **Regression Result**: {float(prediction):.3f}\n*Based on decision tree with `{criterion}`*"
        return (tree_fig, importance_fig, header, pred_details, summary)
    except Exception as e:
        return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Execution error**: {str(e)}", EMPTY_MD, "Check inputs and try again.")
# ==========================
# Gradio UI
# ==========================
# NOTE(review): the source's indentation was lost in export; the nesting
# below is a conventional reconstruction — verify layout against the app.
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.Markdown("### 🌳 **How to Use**: Select data β†’ Configure target β†’ Set tree parameters β†’ Enter new point β†’ Run prediction!")
    with gr.Row(equal_height=False, variant="panel"):
        # Left column: data selection, configuration, and tree parameters.
        with gr.Column(scale=45):
            with gr.Accordion("πŸ“Š Data & Configuration", open=True):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
                        file_upload = gr.File(label="πŸ“ Upload Your Data", file_types=[".csv", ".xlsx", ".xls"])
                    with gr.Column(scale=3):
                        sample_dataset = gr.Dropdown(choices=list(SAMPLE_DATA_CONFIG.keys()), value="Iris", label="πŸ—‚οΈ Sample Datasets")
                with gr.Row():
                    target_column = gr.Dropdown(choices=[], label="🎯 Target Column", interactive=True)
                    status_message = gr.Markdown("πŸ”„ Loading sample data...")
                data_preview = gr.DataFrame(label="πŸ“‹ Data Preview (First 5 Rows)", row_count=5, interactive=False, max_height=250)
            with gr.Accordion("βš™οΈ Parameters & Input", open=True):
                gr.Markdown("**🌳 Decision Tree Parameters**")
                with gr.Row():
                    max_depth = gr.Number(
                        label="Max Depth",
                        value=5, minimum=0, maximum=20, precision=0,
                        info="Set to 0 for unlimited depth"
                    )
                    min_samples_split = gr.Number(
                        label="Min Samples Split",
                        value=2, minimum=2, maximum=100, precision=0,
                        info="Minimum samples required to split an internal node"
                    )
                    min_samples_leaf = gr.Number(
                        label="Min Samples Leaf",
                        value=1, minimum=1, maximum=50, precision=0,
                        info="Minimum samples required to be at a leaf node"
                    )
                with gr.Row():
                    criterion = gr.Dropdown(
                        choices=sorted(CLASS_CRITS), value="gini", label="🎯 Criterion",
                        info="Objective to measure split quality (auto-switched for regression)"
                    )
                # Hidden until a valid target is configured; holds 20 feature
                # slots, each a (number, dropdown) widget pair.
                inputs_group = gr.Group(visible=False)
                with inputs_group:
                    input_status = gr.Markdown("Configure inputs above.")
                    gr.Markdown("**πŸ“ New Data Point** - Enter feature values for prediction:")
                    input_components = []
                    # 5 rows x 4 columns of feature slots (20 features max).
                    for row in range(5):
                        with gr.Row():
                            for col in range(4):
                                idx = row * 4 + col
                                if idx < 20:
                                    number_comp = gr.Number(label=f"Feature {idx+1}", visible=False)
                                    dropdown_comp = gr.Dropdown(label=f"Feature {idx+1}", visible=False)
                                    input_components.extend([number_comp, dropdown_comp])
                run_prediction_btn = gr.Button("πŸš€ Run Prediction", variant="primary", size="lg")
        # Right column: visualizations and prediction outputs.
        with gr.Column(scale=55):
            gr.Markdown("### 🌳 **Decision Tree Results & Visualization**")
            with gr.Tabs():
                with gr.TabItem("Decision Tree"):
                    tree_visualization = gr.Plot(label="Decision Tree (Interactive)", visible=True)
                with gr.TabItem("Feature Importance"):
                    feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)
            prediction_result = gr.Markdown("## 🎯 Prediction Result\n**Run prediction to see the result.**", label="πŸ“ˆ Final Prediction")
            prediction_details = gr.Markdown("**πŸ“ Prediction Details**\n\nDetailed prediction information will appear here.", label="πŸ” Prediction Details")
            algorithm_summary = gr.Markdown("**πŸ“‹ Algorithm Summary**\n\nAlgorithm details will appear here after prediction.", label="πŸ” Technical Details")
            gr.Markdown("""πŸ’‘ **Tips**:
- **Tree visualization** shows the complete decision tree structure with decision paths.
- **Feature importance** shows which features matter most.
- Try different **max depth** and **criterion** to see structure changes!
- **Min samples split/leaf** control complexity and reduce overfitting.
""")
    vlai_template.create_footer()
    # ---- Event bindings ----
    # Each data-loading event re-renders preview/target/feature widgets, then
    # chains a criterion refresh for the (possibly changed) problem type.
    load_evt = demo.load(
        fn=lambda: load_and_configure_data(None, "Iris"),
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    load_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])
    upload_evt = file_upload.upload(
        fn=lambda file: load_and_configure_data(file, "Iris"),
        inputs=[file_upload],
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    upload_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])
    sample_evt = sample_dataset.change(
        fn=lambda choice: load_and_configure_data(None, choice),
        inputs=[sample_dataset],
        outputs=[data_preview, target_column, status_message] + input_components + [inputs_group, input_status],
    )
    sample_evt.then(fn=update_criterion_on_target_change, inputs=[data_preview, target_column], outputs=[criterion])
    # Target changes rebuild the feature widgets and refresh the criterion.
    target_column.change(
        fn=update_configuration, inputs=[data_preview, target_column],
        outputs=input_components + [inputs_group, input_status, status_message],
    )
    target_column.change(
        fn=update_criterion_on_target_change, inputs=[data_preview, target_column],
        outputs=[criterion],
    )
    run_prediction_btn.click(
        fn=execute_prediction,
        inputs=[data_preview, target_column, max_depth, min_samples_split, min_samples_leaf, criterion] + input_components,
        outputs=[tree_visualization, feature_importance_plot, prediction_result, prediction_details, algorithm_summary],
    )

if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])