"""Gradio app for fitting a decision tree on sample or uploaded tabular data and predicting a new point."""

import gradio as gr
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from src import decision_tree_core
import vlai_template

# Most recently loaded dataset; written by the data loader and read by the
# other callbacks in this module.
current_dataframe = None

# Built-in sample datasets mapped to their target column and learning task.
SAMPLE_DATA_CONFIG = {
    dataset_name: {"target_column": "target", "problem_type": task}
    for dataset_name, task in (
        ("Iris", "classification"),
        ("Wine", "classification"),
        ("Breast Cancer", "classification"),
        ("Diabetes", "regression"),
    )
}

# Browser-side snippet: when the URL query has no `__theme` parameter it sets
# `__theme=light` and assigns the new query string back to the location.
force_light_theme_js = """
() => {
  const params = new URLSearchParams(window.location.search);
  if (!params.has('__theme')) {
    params.set('__theme', 'light');
    window.location.search = params.toString();
  }
}
"""

def validate_config(df, target_col):
    """Validate the target column and infer the problem type.

    Args:
        df: DataFrame holding the loaded dataset.
        target_col: Name of the column to predict.

    Returns:
        A ``(is_valid, message, problem_type)`` tuple. ``problem_type`` is
        ``"classification"`` or ``"regression"`` when validation succeeds and
        ``None`` whenever it fails.
    """
    if not target_col or target_col not in df.columns:
        return False, "❌ Please select a valid target column from the dropdown.", None

    target_series = df[target_col]
    unique_vals = target_series.nunique()
    dtype = target_series.dtype

    # Label-like dtypes are always classification targets. Comparing only
    # against "object" misses categorical, boolean and the dedicated pandas
    # string dtypes, which would otherwise fall through to regression.
    label_like = (
        pd.api.types.is_object_dtype(dtype)
        or pd.api.types.is_string_dtype(dtype)
        or pd.api.types.is_bool_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
    )
    # Numeric targets with few distinct values are treated as class labels too.
    few_distinct = unique_vals <= min(20, len(target_series) * 0.1)

    if label_like or few_distinct:
        if unique_vals > 50:
            return False, f"⚠️ Too many classes ({unique_vals}). Consider another target.", None
        if target_series.isnull().any():
            return False, "⚠️ Target column has missing values. Please clean your data.", None
        problem_type = "classification"
    else:
        if unique_vals < 5:
            return False, f"⚠️ Too few unique values ({unique_vals}). Consider another target.", None
        problem_type = "regression"

    noun = "classes" if problem_type == "classification" else "values"
    return True, f"\n✅ Configuration is valid! Ready for {unique_vals} {noun}.", problem_type


def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
    """Build the Markdown status line shown under the target selector.

    Args:
        is_sample: True when a built-in sample dataset is in use.
        dataset_choice: Name of the selected sample dataset.
        target_col: Selected target column, or None when not chosen yet.
        problem_type: "classification" / "regression", or None when unknown.
        is_valid: Whether the current configuration passed validation.
        validation_msg: Validation text appended for custom data.

    Returns:
        A Markdown string describing the current selection, or a prompt to
        pick a target column when the configuration is incomplete.
    """
    if target_col and problem_type:
        type_label = problem_type.title()
        if is_sample:
            return f"✅ **Selected Dataset**: {dataset_choice} | **Target**: {target_col} | **Type**: {type_label}"
        status_icon = "✅" if is_valid else "⚠️"
        return f"{status_icon} **Custom Data** | **Target**: {target_col} | **Type**: {type_label} | {validation_msg}"

    # Incomplete configuration: never call .title() on a missing problem type.
    if is_sample:
        return f"⚠️ **Selected Dataset**: {dataset_choice} | 👆 Please select target column above to continue."
    return "📁 **Custom data uploaded!** 👆 Please select target column above to continue."


# Number of feature slots the UI pre-allocates; every slot is a
# (Number, Dropdown) pair, so callbacks emit twice this many widget updates.
_MAX_FEATURE_INPUTS = 20


def _hidden_input_updates():
    """Return one 'hide' update for every pre-allocated feature widget."""
    return [gr.update(visible=False) for _ in range(2 * _MAX_FEATURE_INPUTS)]


def _build_input_updates(components_info):
    """Turn feature descriptors into widget updates, one (Number, Dropdown) pair per feature.

    Only the first ``_MAX_FEATURE_INPUTS`` descriptors are mapped; every other
    widget stays hidden.
    """
    updates = _hidden_input_updates()
    for slot, comp in zip(range(_MAX_FEATURE_INPUTS), components_info):
        number_idx, dropdown_idx = slot * 2, slot * 2 + 1
        if comp["type"] == "number":
            number_kwargs = {"visible": True, "label": comp["name"], "value": comp["value"]}
            # Bounds are optional; only forward the ones that are set.
            for bound in ("minimum", "maximum"):
                if comp[bound] is not None:
                    number_kwargs[bound] = comp[bound]
            updates[number_idx] = gr.update(**number_kwargs)
        else:
            updates[dropdown_idx] = gr.update(
                visible=True, label=comp["name"], choices=comp["choices"], value=comp["value"]
            )
    return updates


def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
    """Load a dataset and prepare the target selector, status line and feature inputs.

    Args:
        file_obj: Uploaded file, or None to load the sample named by
            ``dataset_choice``.
        dataset_choice: Name of the sample dataset.

    Returns:
        A list laid out as ``[preview, target dropdown, status message]``,
        then one update per feature widget, then
        ``[inputs group update, input status text]``. On failure the same
        layout is returned with an empty preview and everything hidden.
    """
    global current_dataframe
    try:
        df = decision_tree_core.load_data(file_obj, dataset_choice)
        current_dataframe = df

        target_options = df.columns.tolist()
        is_sample = file_obj is None

        # Only sample datasets come with a preconfigured target.
        target_col, problem_type = None, None
        if is_sample:
            cfg = SAMPLE_DATA_CONFIG.get(dataset_choice, {})
            target_col = cfg.get("target_column")
            problem_type = cfg.get("problem_type")

        # Validate & status
        is_valid, validation_msg = False, ""
        if target_col:
            is_valid, validation_msg, detected = validate_config(df, target_col)
            if detected:
                problem_type = detected
        status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg)

        # Build feature input widgets
        input_updates = _hidden_input_updates()
        inputs_visible = gr.update(visible=False)
        input_status = "⚙️ Configure target column above to enable feature inputs."

        if target_col and problem_type and (not is_sample or is_valid):
            try:
                components_info = decision_tree_core.create_input_components(df, target_col)
                input_updates = _build_input_updates(components_info)
                inputs_visible = gr.update(visible=True)
                input_status = f"📝 **Ready!** Enter values for {len(components_info)} features below, then click Run prediction. | {validation_msg}"
            except Exception as e:
                # Best effort: keep the data preview usable even if inputs fail.
                input_updates = _hidden_input_updates()
                input_status = f"❌ Error generating inputs: {str(e)}"

        return [df.head(5).round(2), gr.Dropdown(choices=target_options, value=target_col), status_msg] + input_updates + [inputs_visible, input_status]

    except Exception as e:
        current_dataframe = None
        empty = [pd.DataFrame(), gr.Dropdown(choices=[], value=None), f"❌ **Error loading data**: {str(e)} | Please try a different file or dataset."]
        return empty + _hidden_input_updates() + [gr.update(visible=False), "No data loaded."]


def update_configuration(df_preview, target_col):
    """Rebuild feature widgets when the target column changes.

    Args:
        df_preview: Value of the preview table; unused, the full dataset is
            read from the module-level ``current_dataframe`` instead.
        target_col: Newly selected target column.

    Returns:
        A list with one update per feature widget followed by
        ``[inputs group update, input status text, status message]``.
    """
    df = current_dataframe

    def failure(message):
        # Hide every input and show the same message in both status areas.
        return _hidden_input_updates() + [gr.update(visible=False), message, message]

    if df is None or df.empty:
        return failure("No data available.")
    if not target_col:
        return failure("Select target column.")

    try:
        is_valid, validation_msg, problem_type = validate_config(df, target_col)
        if not is_valid:
            # The validation message already starts with its own icon.
            return failure(validation_msg)

        components_info = decision_tree_core.create_input_components(df, target_col)
        input_updates = _build_input_updates(components_info)
        input_status = f"📝 Enter values for {len(components_info)} features | {validation_msg}"
        status_msg = f"✅ **Selected Dataset**: Custom Data | **Target**: {target_col} | **Type**: {problem_type.title()}"
        return input_updates + [gr.update(visible=True), input_status, status_msg]

    except Exception as e:
        return failure(f"❌ Error: {str(e)}")


# ---- criterion helpers ----
# Split-quality criteria offered for each problem type.
CLASS_CRITS = {"gini", "entropy", "log_loss"}
REGR_CRITS = {"squared_error", "absolute_error", "friedman_mse", "poisson"}

def update_criterion_choices(problem_type):
    """Return a criterion dropdown matching ``problem_type``.

    Classification gets the classification criteria with "gini" selected;
    any other value gets the regression criteria with "squared_error".
    """
    is_classification = problem_type == "classification"
    options = CLASS_CRITS if is_classification else REGR_CRITS
    default_choice = "gini" if is_classification else "squared_error"
    return gr.Dropdown(choices=sorted(options), value=default_choice)


def update_criterion_on_target_change(df_preview, target_col):
    """Recompute the problem type for the current target and return the matching criterion dropdown.

    Args:
        df_preview: Value of the preview table; unused, the full dataset is
            read from the module-level ``current_dataframe`` instead.
        target_col: Currently selected target column.

    Returns:
        The dropdown produced by ``update_criterion_choices``. The
        classification choices are used when no target or data is available,
        or when detection raises.
    """
    df = current_dataframe
    if not target_col or df is None or df.empty:
        return update_criterion_choices("classification")
    try:
        _, _, problem_type = validate_config(df, target_col)
        # Anything other than "classification" (including None for an invalid
        # target) selects the regression criteria, as update_criterion_choices does.
        return update_criterion_choices(problem_type)
    except Exception:
        return update_criterion_choices("classification")


def _decode_class_label(y_series, prediction):
    """Map an integer class code back to the original label of a string/categorical target.

    Returns ``prediction`` unchanged when the target is not label-like, when
    the prediction is not an integer-like code, or when decoding fails.
    """
    try:
        dtype = y_series.dtype
        label_like = (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        )
        if not label_like:
            return prediction
        numeric_like = isinstance(prediction, (int, np.integer, float, np.floating)) or str(prediction).isdigit()
        if not numeric_like:
            return prediction

        # NOTE(review): assumes the model was trained on LabelEncoder codes of
        # the stringified, non-null target values - confirm against decision_tree_core.
        encoder = LabelEncoder()
        encoder.fit(y_series.dropna().astype(str))
        code = int(prediction)
        if 0 <= code < len(encoder.classes_):
            return encoder.inverse_transform([code])[0]
    except Exception:
        # Decoding is cosmetic only; fall back to the raw prediction.
        pass
    return prediction


def execute_prediction(df_preview, target_col, max_depth, min_samples_split, min_samples_leaf, criterion, *input_values):
    """Fit the decision tree and produce all outputs for the results panel.

    Args:
        df_preview: Value of the preview table; unused, the full dataset is
            read from the module-level ``current_dataframe`` instead.
        target_col: Selected target column.
        max_depth: Tree depth setting from the UI.
        min_samples_split: Minimum samples required to split a node.
        min_samples_leaf: Minimum samples required at a leaf.
        criterion: Requested split criterion; replaced by a default when it
            does not fit the detected problem type.
        *input_values: Feature widget values, laid out as
            (number, dropdown) pairs per feature.

    Returns:
        Always a 5-tuple:
        ``(tree figure, importance figure, result header, prediction details, summary)``.
    """
    df = current_dataframe

    EMPTY_PLOT = None  # for gr.Plot
    EMPTY_MD = " "     # for gr.Markdown

    if df is None or df.empty:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **No data loaded!** 📊 Please select a sample dataset or upload a file first.", EMPTY_MD, "Load data to get started.")
    if not target_col:
        return (EMPTY_PLOT, EMPTY_PLOT, "❌ **Configuration incomplete!** 🎯 Please select target column above.", EMPTY_MD, "Complete configuration to proceed.")

    is_valid, validation_msg, problem_type = validate_config(df, target_col)
    if not is_valid:
        return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Configuration issue**: {validation_msg}", EMPTY_MD, "Fix the configuration and try again.")

    # Normalize the criterion defensively: the dropdown may still hold a value
    # that belongs to the other problem type.
    is_classification = problem_type == "classification"
    if is_classification:
        if criterion not in CLASS_CRITS:
            criterion = "gini"
    elif criterion not in REGR_CRITS:
        criterion = "squared_error"

    try:
        components_info = decision_tree_core.create_input_components(df, target_col)
        new_point_dict = {}
        for i, comp in enumerate(components_info):
            # Each feature owns two widget slots: even = number, odd = dropdown.
            slot = i * 2 if comp["type"] == "number" else i * 2 + 1
            supplied = input_values[slot] if slot < len(input_values) else None
            # Fall back to the feature's default when nothing was entered.
            new_point_dict[comp["name"]] = comp["value"] if supplied is None else supplied

        tree_fig, importance_fig, prediction, pred_details, summary, error = decision_tree_core.run_decision_tree_and_visualize(
            df, target_col, new_point_dict, max_depth, min_samples_split, min_samples_leaf, criterion, problem_type
        )

        if error:
            return (tree_fig or EMPTY_PLOT, importance_fig or EMPTY_PLOT, f"❌ **Prediction failed**: {error}", pred_details or EMPTY_MD, summary or "Adjust inputs and retry.")

        if is_classification:
            display_pred = _decode_class_label(df[target_col], prediction)
            header = f"## 🎯 **Classification Result**: {display_pred}\n*Based on decision tree with `{criterion}`*"
        else:
            header = f"## 🎯 **Regression Result**: {float(prediction):.3f}\n*Based on decision tree with `{criterion}`*"

        return (tree_fig, importance_fig, header, pred_details, summary)

    except Exception as e:
        return (EMPTY_PLOT, EMPTY_PLOT, f"❌ **Execution error**: {str(e)}", EMPTY_MD, "Check inputs and try again.")


# ==========================
# Gradio UI
# ==========================
with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
    vlai_template.create_header()
    gr.Markdown("### 🌳 **How to Use**: Select data → Configure target → Set tree parameters → Enter new point → Run prediction!")

    with gr.Row(equal_height=False, variant="panel"):
        # Left column: data selection, tree parameters and the new data point.
        with gr.Column(scale=45):
            with gr.Accordion("📊 Data & Configuration", open=True):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
                        file_upload = gr.File(label="📁 Upload Your Data", file_types=[".csv", ".xlsx", ".xls"])
                    with gr.Column(scale=3):
                        sample_dataset = gr.Dropdown(choices=list(SAMPLE_DATA_CONFIG.keys()), value="Iris", label="🗂️ Sample Datasets")

                with gr.Row():
                    target_column = gr.Dropdown(choices=[], label="🎯 Target Column", interactive=True)

                status_message = gr.Markdown("🔄 Loading sample data...")
                data_preview = gr.DataFrame(label="📋 Data Preview (First 5 Rows)", row_count=5, interactive=False, max_height=250)

            with gr.Accordion("⚙️ Parameters & Input", open=True):
                gr.Markdown("**🌳 Decision Tree Parameters**")
                with gr.Row():
                    max_depth = gr.Number(
                        label="Max Depth",
                        value=5, minimum=0, maximum=20, precision=0,
                        info="Set to 0 for unlimited depth"
                    )
                    min_samples_split = gr.Number(
                        label="Min Samples Split",
                        value=2, minimum=2, maximum=100, precision=0,
                        info="Minimum samples required to split an internal node"
                    )
                    min_samples_leaf = gr.Number(
                        label="Min Samples Leaf",
                        value=1, minimum=1, maximum=50, precision=0,
                        info="Minimum samples required to be at a leaf node"
                    )
                with gr.Row():
                    criterion = gr.Dropdown(
                        choices=sorted(CLASS_CRITS), value="gini", label="🎯 Criterion",
                        info="Objective to measure split quality (auto-switched for regression)"
                    )

                inputs_group = gr.Group(visible=False)
                with inputs_group:
                    input_status = gr.Markdown("Configure inputs above.")
                    gr.Markdown("**📝 New Data Point** - Enter feature values for prediction:")
                    # 5 rows x 4 columns = 20 feature slots. Each slot holds a
                    # Number and a Dropdown; the callbacks reveal whichever one
                    # matches the feature type, so the flat list has 40 entries
                    # ordered (number, dropdown) per feature.
                    input_components = []
                    for row in range(5):
                        with gr.Row():
                            for col in range(4):
                                idx = row * 4 + col
                                number_comp = gr.Number(label=f"Feature {idx+1}", visible=False)
                                dropdown_comp = gr.Dropdown(label=f"Feature {idx+1}", visible=False)
                                input_components.extend([number_comp, dropdown_comp])

                run_prediction_btn = gr.Button("🚀 Run Prediction", variant="primary", size="lg")

        # Right column: plots and textual results.
        with gr.Column(scale=55):
            gr.Markdown("### 🌳 **Decision Tree Results & Visualization**")
            with gr.Tabs():
                with gr.TabItem("Decision Tree"):
                    tree_visualization = gr.Plot(label="Decision Tree (Interactive)", visible=True)
                with gr.TabItem("Feature Importance"):
                    feature_importance_plot = gr.Plot(label="Feature Importance", visible=True)

            prediction_result = gr.Markdown("## 🎯 Prediction Result\n**Run prediction to see the result.**", label="📈 Final Prediction")
            prediction_details = gr.Markdown("**📝 Prediction Details**\n\nDetailed prediction information will appear here.", label="🔍 Prediction Details")
            algorithm_summary = gr.Markdown("**📋 Algorithm Summary**\n\nAlgorithm details will appear here after prediction.", label="🔍 Technical Details")

    gr.Markdown("""💡 **Tips**:
- **Tree visualization** shows the complete decision tree structure with decision paths.
- **Feature importance** shows which features matter most.
- Try different **max depth** and **criterion** to see structure changes!
- **Min samples split/leaf** control complexity and reduce overfitting.
""")

    vlai_template.create_footer()

    # ---- Event bindings ----
    # Every data-loading event fills the same outputs, in the order returned
    # by load_and_configure_data.
    data_outputs = [data_preview, target_column, status_message] + input_components + [inputs_group, input_status]
    criterion_inputs = [data_preview, target_column]

    # Initial page load: start with the Iris sample.
    load_evt = demo.load(
        fn=lambda: load_and_configure_data(None, "Iris"),
        outputs=data_outputs,
    )
    load_evt.then(fn=update_criterion_on_target_change, inputs=criterion_inputs, outputs=[criterion])

    # File upload: the dataset name is passed along with the uploaded file.
    upload_evt = file_upload.upload(
        fn=lambda file: load_and_configure_data(file, "Iris"),
        inputs=[file_upload],
        outputs=data_outputs,
    )
    upload_evt.then(fn=update_criterion_on_target_change, inputs=criterion_inputs, outputs=[criterion])

    # Sample dataset switch.
    sample_evt = sample_dataset.change(
        fn=lambda choice: load_and_configure_data(None, choice),
        inputs=[sample_dataset],
        outputs=data_outputs,
    )
    sample_evt.then(fn=update_criterion_on_target_change, inputs=criterion_inputs, outputs=[criterion])

    # Target change: rebuild the feature inputs and refresh the criterion list.
    target_column.change(
        fn=update_configuration, inputs=[data_preview, target_column],
        outputs=input_components + [inputs_group, input_status, status_message],
    )
    target_column.change(
        fn=update_criterion_on_target_change, inputs=criterion_inputs,
        outputs=[criterion],
    )

    run_prediction_btn.click(
        fn=execute_prediction,
        inputs=[data_preview, target_column, max_depth, min_samples_split, min_samples_leaf, criterion] + input_components,
        outputs=[tree_visualization, feature_importance_plot, prediction_result, prediction_details, algorithm_summary],
    )

if __name__ == "__main__":
    # Launch the app only when run as a script; the logo files and the
    # static folder are passed to Gradio as allowed paths.
    static_paths = ["static/aivn_logo.png", "static/vlai_logo.png", "static"]
    demo.launch(allowed_paths=static_paths)