wjnwjn59 committed · Commit dbb8268 · 1 Parent(s): 8a68df4

first init

README.md CHANGED
@@ -1,12 +1,72 @@
1
  ---
2
- title: AIO2025M03 DEMO DECISION TREE
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.40.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: AIO2025M03 DEMO Decision Tree
3
+ emoji: 🌳
4
+ colorFrom: green
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.38.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # 🌳 Decision Tree Interactive Demo
13
+
14
+ An interactive web application demonstrating Decision Tree algorithms with real-time visualization and educational features.
15
+
16
+ ## ✨ Features
17
+
18
+ - **📊 Multiple Datasets**: 4 built-in datasets (Iris, Wine, Breast Cancer, Diabetes)
19
+ - **🎮 Interactive Interface**: Real-time parameter adjustment and prediction
20
+ - **🌳 Tree Visualization**: Interactive decision tree structure with zoom capabilities
21
+ - **📊 Feature Importance**: Visual representation of feature importance scores
22
+ - **🎛️ Flexible Parameters**: Adjustable max depth, split criteria, and leaf constraints
23
+ - **📱 Responsive Design**: Works on desktop and mobile devices
24
+
25
+ ## 🚀 Quick Start
26
+
27
+ ### Local Installation
28
+ ```bash
29
+ git clone <repository-url>
30
+ cd AIO2025M03_DEMO_DECISION_TREE
31
+ pip install -r requirements.txt
32
+ python app.py
33
+ ```
34
+
35
+ ### Usage
36
+ 1. **Select Dataset**: Choose from pre-loaded datasets or upload your own CSV/Excel file
37
+ 2. **Configure Target**: Select target column and problem type (classification/regression)
38
+ 3. **Set Parameters**: Adjust max depth, split criteria, and leaf constraints
39
+ 4. **Input New Point**: Enter feature values for prediction
40
+ 5. **Run Prediction**: Get results with interactive tree visualization
41
+
42
+ ## 🧠 Technical Highlights
43
+
44
+ - **Tree Structure**: Interactive visualization of decision tree nodes and splits
45
+ - **Feature Importance**: Automatic calculation and visualization of feature importance scores
46
+ - **Auto-Detection**: Automatically determines classification vs regression problems
47
+ - **Error Handling**: Robust validation and user-friendly error messages
48
+
49
+ ## 📋 Requirements
50
+
51
+ - Python 3.8+
52
+ - Gradio 5.38+
53
+ - Scikit-learn
54
+ - Pandas
55
+ - NumPy
56
+ - Plotly
57
+
58
+ ## 🎓 Educational Value
59
+
60
+ Perfect for:
61
+ - Understanding Decision Tree algorithm mechanics
62
+ - Learning about tree-based splitting criteria
63
+ - Exploring feature importance and tree pruning
64
+ - Comparing classification vs regression approaches
65
+
66
+ ## 📄 License
67
+
68
+ Educational use for AIO2025 course materials.
69
+
70
+ ---
71
+
72
+ **Live Demo**: [Decision Tree Demo](https://huggingface.co/spaces/VLAI-AIVN/AIO2025M03_DEMO_DECISION_TREE)
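
The README lists "tree-based splitting criteria" as a learning goal, and the app exposes `gini` and `entropy` as criterion choices. As a minimal sketch of what those two criteria measure (the function names below are illustrative and are not part of this commit or of scikit-learn's API), the impurity of a node and the gain from a candidate split can be computed by hand:

```python
from collections import Counter
from math import log2


def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over class proportions p_k."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Shannon entropy in bits: -sum(p_k * log2(p_k))."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def impurity_decrease(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted


parent = ["a", "a", "b", "b"]
print(gini(parent))     # 0.5
print(entropy(parent))  # 1.0
print(impurity_decrease(parent, ["a", "a"], ["b", "b"]))  # 0.5
```

A perfectly mixed two-class node has the highest impurity under both measures, and a split that separates the classes completely recovers all of it. A decision tree greedily picks the split with the largest decrease at each node.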
__pycache__/vlai_template.cpython-312.pyc ADDED
Binary file (5.12 kB).
 
app.py ADDED
@@ -0,0 +1,400 @@
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from src import decision_tree_core
4
+ import vlai_template
5
+
6
+ # Global state
7
+ current_dataframe = None
8
+
9
+ # Dataset configurations
10
+ SAMPLE_DATA_CONFIG = {
11
+ "Iris": {"target_column": "target", "problem_type": "classification"},
12
+ "Wine": {"target_column": "target", "problem_type": "classification"},
13
+ "Breast Cancer": {"target_column": "target", "problem_type": "classification"},
14
+ "Diabetes": {"target_column": "target", "problem_type": "regression"},
15
+ }
16
+
17
+ force_light_theme_js = """
18
+ () => {
19
+ const params = new URLSearchParams(window.location.search);
20
+ if (!params.has('__theme')) {
21
+ params.set('__theme', 'light');
22
+ window.location.search = params.toString();
23
+ }
24
+ }
25
+ """
26
+
27
+ def validate_config(df, target_col, problem_type):
28
+ """Validate target column and problem type compatibility"""
29
+ if not target_col or target_col not in df.columns:
30
+ return False, "❌ Please select a valid target column from the dropdown."
31
+
32
+ if not problem_type:
33
+ return False, "❌ Please select either Classification or Regression as problem type."
34
+
35
+ target_series = df[target_col]
36
+ unique_vals = target_series.nunique()
37
+
38
+ if problem_type == "classification":
39
+ if unique_vals > 50:
40
+ return False, f"⚠️ Too many classes ({unique_vals}). Consider using Regression instead."
41
+ if target_series.isnull().any():
42
+ return False, "⚠️ Target column contains missing values. Please clean your data."
43
+ elif problem_type == "regression":
44
+ if target_series.dtype == 'object':
45
+ return False, "⚠️ Text values detected in target. Use Classification for categories."
46
+ if unique_vals < 5:
47
+ return False, f"⚠️ Too few unique values ({unique_vals}). Consider using Classification."
48
+
49
+ return True, f"\n✅ Configuration is valid! Ready for {unique_vals} {'classes' if problem_type == 'classification' else 'values'}."
50
+
51
+ def get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg):
52
+ """Generate status message"""
53
+ if is_sample:
54
+ return f"✅ **Sample Dataset**: {dataset_choice} | **Target**: {target_col} | **Type**: {problem_type.title()}"
55
+ elif target_col and problem_type:
56
+ status_icon = "✅" if is_valid else "⚠️"
57
+ return f"{status_icon} **Custom Data** | **Target**: {target_col} | **Type**: {problem_type.title()} | {validation_msg}"
58
+ else:
59
+ return "📁 **Custom data uploaded!** 👆 Please select target column and problem type above to continue."
60
+
61
+ def load_and_configure_data(file_obj=None, dataset_choice="Iris"):
62
+ """Load data and configure target/problem type"""
63
+ global current_dataframe
64
+
65
+ try:
66
+ df = decision_tree_core.load_data(file_obj, dataset_choice)
67
+ current_dataframe = df
68
+
69
+ target_options = df.columns.tolist()
70
+ is_sample = file_obj is None
71
+
72
+ if is_sample:
73
+ config = SAMPLE_DATA_CONFIG.get(dataset_choice, {})
74
+ target_col = config.get("target_column")
75
+ problem_type = config.get("problem_type")
76
+ else:
77
+ target_col = None
78
+ problem_type = None
79
+
80
+ # Validate and generate status
81
+ if target_col and problem_type:
82
+ is_valid, validation_msg = validate_config(df, target_col, problem_type)
83
+ status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, is_valid, validation_msg)
84
+ else:
85
+ status_msg = get_status_message(is_sample, dataset_choice, target_col, problem_type, False, "")
86
+
87
+ # Generate input components
88
+ input_updates = [gr.update(visible=False)] * 16
89
+ inputs_visible = gr.update(visible=False)
90
+ input_status = "⚙️ Configure target and problem type above to enable feature inputs."
91
+
92
+ if target_col and problem_type and (not is_sample or is_valid):
93
+ try:
94
+ components_info = decision_tree_core.create_input_components(df, target_col)
95
+ for i in range(min(16, len(components_info))):
96
+ comp_info = components_info[i]
97
+ if comp_info['type'] == 'number':
98
+ update_params = {
99
+ 'visible': True, 'label': comp_info['name'], 'value': comp_info['value']
100
+ }
101
+ if comp_info['minimum'] is not None:
102
+ update_params['minimum'] = comp_info['minimum']
103
+ if comp_info['maximum'] is not None:
104
+ update_params['maximum'] = comp_info['maximum']
105
+ input_updates[i] = gr.update(**update_params)
106
+ else:
107
+ input_updates[i] = gr.update(
108
+ visible=True, label=comp_info['name'],
109
+ choices=comp_info['choices'], value=comp_info['value']
110
+ )
111
+ inputs_visible = gr.update(visible=True)
112
+ input_status = f"📝 **Ready!** Enter values for {len(components_info)} features below, then click Run Prediction! | {validation_msg}"
113
+ except Exception as e:
114
+ input_status = f"❌ Error generating inputs: {str(e)}"
115
+
116
+ return [df.head(5).round(2), gr.Dropdown(choices=target_options, value=target_col),
117
+ gr.Dropdown(value=problem_type), status_msg] + input_updates + [inputs_visible, input_status]
118
+
119
+ except Exception as e:
120
+ current_dataframe = None
121
+ empty_updates = [pd.DataFrame(), gr.Dropdown(choices=[], value=None),
122
+ gr.Dropdown(value=None), f"❌ **Error loading data**: {str(e)} | Please try a different file or dataset."]
123
+ return empty_updates + [gr.update(visible=False)] * 16 + [gr.update(visible=False), "No data loaded."]
124
+
125
+ def update_criterion_choices(problem_type):
126
+ """Update criterion choices based on problem type"""
127
+ if problem_type == "classification":
128
+ return gr.Dropdown(choices=["gini", "entropy", "log_loss"], value="gini")
129
+ else:
130
+ return gr.Dropdown(choices=["squared_error", "absolute_error", "friedman_mse", "poisson"], value="squared_error")
131
+
132
+ def update_configuration(df_preview, target_col, problem_type):
133
+ """Update configuration when target or problem type changes"""
134
+ global current_dataframe
135
+ df = current_dataframe
136
+
137
+ if df is None or df.empty:
138
+ return [gr.update(visible=False)] * 16 + [gr.update(visible=False), "No data available."]
139
+
140
+ if not target_col or not problem_type:
141
+ return [gr.update(visible=False)] * 16 + [gr.update(visible=False), "Select target column and problem type."]
142
+
143
+ try:
144
+ is_valid, validation_msg = validate_config(df, target_col, problem_type)
145
+
146
+ if not is_valid:
147
+ return [gr.update(visible=False)] * 16 + [gr.update(visible=False), f"⚠️ {validation_msg}"]
148
+
149
+ # Generate input components
150
+ components_info = decision_tree_core.create_input_components(df, target_col)
151
+ input_updates = [gr.update(visible=False)] * 16
152
+
153
+ for i in range(min(16, len(components_info))):
154
+ comp_info = components_info[i]
155
+ if comp_info['type'] == 'number':
156
+ # No min/max limits, so the user can enter values outside the training-data range
157
+ update_params = {
158
+ 'visible': True, 'label': comp_info['name'], 'value': comp_info['value']
159
+ }
160
+ if comp_info['minimum'] is not None:
161
+ update_params['minimum'] = comp_info['minimum']
162
+ if comp_info['maximum'] is not None:
163
+ update_params['maximum'] = comp_info['maximum']
164
+ input_updates[i] = gr.update(**update_params)
165
+ else:
166
+ input_updates[i] = gr.update(
167
+ visible=True, label=comp_info['name'],
168
+ choices=comp_info['choices'], value=comp_info['value']
169
+ )
170
+
171
+ input_status = f"📝 Enter values for {len(components_info)} features | {validation_msg}"
172
+ return input_updates + [gr.update(visible=True), input_status]
173
+
174
+ except Exception as e:
175
+ return [gr.update(visible=False)] * 16 + [gr.update(visible=False), f"❌ Error: {str(e)}"]
176
+
177
+ def execute_prediction(df_preview, target_col, problem_type, max_depth, min_samples_split, min_samples_leaf, criterion, *input_values):
178
+ """Execute Decision Tree prediction"""
179
+ global current_dataframe
180
+ df = current_dataframe
181
+
182
+ # Validation checks
183
+ if df is None or df.empty:
184
+ return None, None, "❌ **No data loaded!** 📊 Please select a sample dataset or upload a file first.", None, "Load data to get started."
185
+
186
+ if not target_col or not problem_type:
187
+ return None, None, "❌ **Configuration incomplete!** 🎯 Please select target column and problem type above.", None, "Complete configuration to proceed."
188
+
189
+ is_valid, validation_msg = validate_config(df, target_col, problem_type)
190
+ if not is_valid:
191
+ return None, None, f"❌ **Configuration issue**: {validation_msg}", None, "Fix the configuration and try again."
192
+
193
+ try:
194
+ components_info = decision_tree_core.create_input_components(df, target_col)
195
+ new_point_dict = {}
196
+
197
+ for i, comp_info in enumerate(components_info):
198
+ if i < len(input_values) and input_values[i] is not None:
199
+ new_point_dict[comp_info['name']] = input_values[i]
200
+ else:
201
+ new_point_dict[comp_info['name']] = comp_info['value']
202
+
203
+ tree_fig, importance_fig, prediction, prediction_details, summary, error = decision_tree_core.run_decision_tree_and_visualize(
204
+ df, target_col, new_point_dict, max_depth, min_samples_split, min_samples_leaf, criterion, problem_type
205
+ )
206
+
207
+ if error:
208
+ return None, None, f"❌ **Prediction failed**: {error} | Please check your input values and try again.", None, "Adjust inputs and retry."
209
+
210
+ if problem_type == "classification":
211
+ result_header = f"## 🎯 **Classification Result**: {prediction}\n*Based on decision tree with {criterion} criterion*"
212
+ else:
213
+ result_header = f"## 🎯 **Regression Result**: {prediction:.3f}\n*Based on decision tree with {criterion} criterion*"
214
+
215
+ return tree_fig, importance_fig, result_header, prediction_details, summary
216
+
217
+ except Exception as e:
218
+ return None, None, f"❌ **Execution error**: {str(e)} | Please verify your input values are correct.", None, "Check inputs and try again."
219
+
220
+ # Main Application
221
+ with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, fill_width=True, js=force_light_theme_js) as demo:
222
+ vlai_template.create_header()
223
+
224
+ # Main guidance text
225
+ gr.Markdown("### 🌳 **How to Use**: Select data → Configure target → Set tree parameters → Enter new point → Run prediction!")
226
+
227
+ with gr.Row(equal_height=False, variant="panel"):
228
+ with gr.Column(scale=45):
229
+ with gr.Accordion("📊 Data & Configuration", open=True):
230
+ with gr.Row():
231
+ with gr.Column(scale=1):
232
+ gr.Markdown("Start with sample datasets or upload your own CSV/Excel files.")
233
+ file_upload = gr.File(
234
+ label="📁 Upload Your Data",
235
+ file_types=[".csv", ".xlsx", ".xls"],
236
+ )
237
+ with gr.Column(scale=3):
238
+ sample_dataset = gr.Dropdown(
239
+ choices=list(SAMPLE_DATA_CONFIG.keys()),
240
+ value="Iris",
241
+ label="🗂️ Sample Datasets",
242
+ )
243
+ problem_type_selector = gr.Dropdown(
244
+ choices=["classification", "regression"],
245
+ label="🎲 Problem Type",
246
+ interactive=True,
247
+
248
+ )
249
+ target_column = gr.Dropdown(
250
+ choices=[],
251
+ label="🎯 Target Column",
252
+ interactive=True,
253
+ )
254
+
255
+ status_message = gr.Markdown("🔄 Loading sample data...")
256
+ data_preview = gr.DataFrame(
257
+ label="📋 Data Preview (First 5 Rows)",
258
+ row_count=5,
259
+ interactive=False,
260
+ max_height=250
261
+ )
262
+
263
+ with gr.Accordion("⚙️ Parameters & Input", open=True):
264
+ gr.Markdown("**🌳 Decision Tree Parameters**")
265
+ with gr.Row():
266
+ max_depth = gr.Number(
267
+ label="Max Depth",
268
+ value=5,
269
+ minimum=0,
270
+ maximum=20,
271
+ precision=0,
272
+ info="Set to 0 for unlimited depth"
273
+ )
274
+ min_samples_split = gr.Number(
275
+ label="Min Samples Split",
276
+ value=2,
277
+ minimum=2,
278
+ maximum=100,
279
+ precision=0,
280
+ )
281
+ min_samples_leaf = gr.Number(
282
+ label="Min Samples Leaf",
283
+ value=1,
284
+ minimum=1,
285
+ maximum=50,
286
+ precision=0,
287
+ )
288
+ with gr.Row():
289
+ criterion = gr.Dropdown(
290
+ choices=["gini", "entropy", "log_loss"],
291
+ value="gini",
292
+ label="🎯 Criterion",
293
+ )
294
+
295
+ inputs_group = gr.Group(visible=False)
296
+ with inputs_group:
297
+ input_status = gr.Markdown("Configure inputs above.")
298
+ gr.Markdown("**📝 New Data Point** - Enter feature values for prediction:")
299
+
300
+ input_components = []
301
+ for row in range(4):
302
+ with gr.Row():
303
+ for col in range(4):
304
+ idx = row * 4 + col
305
+ if idx < 16:
306
+ input_components.append(
307
+ gr.Number(label=f"Feature {idx+1}", visible=False)
308
+ )
309
+
310
+ run_prediction_btn = gr.Button(
311
+ "🚀 Run Prediction",
312
+ variant="primary",
313
+ size="lg",
314
+ )
315
+
316
+ with gr.Column(scale=55):
317
+ gr.Markdown("### 🌳 **Decision Tree Results & Visualization**")
318
+
319
+ with gr.Tabs():
320
+ with gr.TabItem("Decision Tree"):
321
+ tree_visualization = gr.Plot(
322
+ label="Interactive Decision Tree",
323
+ visible=True,
324
+ )
325
+
326
+ with gr.TabItem("Feature Importance"):
327
+ feature_importance_plot = gr.Plot(
328
+ label="Feature Importance",
329
+ visible=True,
330
+ )
331
+
332
+ prediction_result = gr.Markdown(
333
+ "## 🎯 Prediction Result\n**Run prediction to see the result.**",
334
+ label="📈 Final Prediction"
335
+ )
336
+
337
+ prediction_details = gr.Markdown(
338
+ "**📝 Prediction Details**\n\nDetailed prediction information will appear here.",
339
+ label="🔍 Prediction Details"
340
+ )
341
+
342
+ algorithm_summary = gr.Markdown(
343
+ "**📋 Algorithm Summary**\n\nAlgorithm details will appear here after prediction.",
344
+ label="🔍 Technical Details"
345
+ )
346
+
347
+ # Bottom guidance
348
+ gr.Markdown("""💡 **Tips**:
349
+ - **Interactive tree visualization** allows you to zoom and explore the decision tree structure.
350
+ - **Feature importance** shows which features are most critical for making decisions.
351
+ - Try different **max depth** and **criterion** values to see how the tree structure changes!
352
+ - **Min samples split/leaf** help control tree complexity and prevent overfitting.
353
+ """)
354
+
355
+ vlai_template.create_footer()
356
+
357
+ # Event Bindings
358
+ demo.load(
359
+ fn=lambda: load_and_configure_data(None, "Iris"),
360
+ outputs=[data_preview, target_column, problem_type_selector, status_message] + input_components + [inputs_group, input_status]
361
+ )
362
+
363
+ file_upload.upload(
364
+ fn=lambda file: load_and_configure_data(file, "Iris"),
365
+ inputs=[file_upload],
366
+ outputs=[data_preview, target_column, problem_type_selector, status_message] + input_components + [inputs_group, input_status]
367
+ )
368
+
369
+ sample_dataset.change(
370
+ fn=lambda choice: load_and_configure_data(None, choice),
371
+ inputs=[sample_dataset],
372
+ outputs=[data_preview, target_column, problem_type_selector, status_message] + input_components + [inputs_group, input_status]
373
+ )
374
+
375
+ target_column.change(
376
+ fn=update_configuration,
377
+ inputs=[data_preview, target_column, problem_type_selector],
378
+ outputs=input_components + [inputs_group, input_status]
379
+ )
380
+
381
+ problem_type_selector.change(
382
+ fn=update_configuration,
383
+ inputs=[data_preview, target_column, problem_type_selector],
384
+ outputs=input_components + [inputs_group, input_status]
385
+ )
386
+
387
+ problem_type_selector.change(
388
+ fn=update_criterion_choices,
389
+ inputs=[problem_type_selector],
390
+ outputs=[criterion]
391
+ )
392
+
393
+ run_prediction_btn.click(
394
+ fn=execute_prediction,
395
+ inputs=[data_preview, target_column, problem_type_selector, max_depth, min_samples_split, min_samples_leaf, criterion] + input_components,
396
+ outputs=[tree_visualization, feature_importance_plot, prediction_result, prediction_details, algorithm_summary]
397
+ )
398
+
399
+ if __name__ == "__main__":
400
+ demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static"])
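
`run_prediction_btn.click` binds `execute_prediction` to five output components (two plots and three Markdown blocks), so every return path of the handler should yield five values in that order. One way to keep the early-exit paths consistent is a small helper that always builds the full-width result. The sketch below is a standalone illustration under that assumption: `error_outputs`, `N_OUTPUTS`, and the toy `run` handler are hypothetical names, not part of this commit.

```python
# Hypothetical helper, not part of the commit: builds a fixed-width result
# so every return path matches the number of bound output components.
N_OUTPUTS = 5  # tree plot, importance plot, result, details, summary


def error_outputs(message, hint):
    """Return a 5-tuple: no plots, the error message, no details, a hint."""
    outputs = (None, None, message, None, hint)
    assert len(outputs) == N_OUTPUTS
    return outputs


def run(data_loaded):
    """Toy handler showing that both paths return the same number of values."""
    if not data_loaded:
        return error_outputs("No data loaded.", "Load data to get started.")
    return ("tree", "importance", "result", "details", "summary")


print(len(run(False)), len(run(True)))  # 5 5
```

Routing every failure through one helper means a later change to the output list only has to be made in one place.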
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ gradio==5.38.0
2
+ pandas>=1.5.0
3
+ scikit-learn>=1.3.0
4
+ numpy>=1.24.0
5
+ supertree==0.5.5
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (205 Bytes).
 
src/__pycache__/decision_tree_core.cpython-312.pyc ADDED
Binary file (16.2 kB).
 
src/decision_tree_core.py ADDED
@@ -0,0 +1,364 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
4
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
5
+ from sklearn.datasets import (
6
+ load_iris, load_wine, load_diabetes, load_breast_cancer
7
+ )
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+
11
+ def load_data(file_obj=None, dataset_choice="Iris"):
12
+ """Load data from file or sample dataset"""
13
+ if file_obj is not None:
14
+ if file_obj.name.endswith('.csv'):
15
+ return pd.read_csv(file_obj.name)
16
+ elif file_obj.name.endswith(('.xlsx', '.xls')):
17
+ return pd.read_excel(file_obj.name)
18
+ else:
19
+ raise ValueError("Unsupported format. Upload CSV or Excel files.")
20
+
21
+ # Sample datasets
22
+ datasets = {
23
+ "Iris": lambda: _sklearn_to_df(load_iris()),
24
+ "Wine": lambda: _sklearn_to_df(load_wine()),
25
+ "Breast Cancer": lambda: _sklearn_to_df(load_breast_cancer()),
26
+ "Diabetes": lambda: _sklearn_to_df(load_diabetes()),
27
+ }
28
+
29
+ if dataset_choice not in datasets:
30
+ raise ValueError(f"Unknown dataset: {dataset_choice}")
31
+
32
+ return datasets[dataset_choice]()
33
+
34
+ def _sklearn_to_df(data):
35
+ """Convert sklearn dataset to DataFrame"""
36
+ df = pd.DataFrame(data.data, columns=data.feature_names)
37
+ df['target'] = data.target
38
+ return df
39
+
40
+ def analyze_dataframe(df):
41
+ """Analyze DataFrame and return target options"""
42
+ return df.columns.tolist(), df.columns[-1]
43
+
44
+ def determine_problem_type(df, target_col):
45
+ """Auto-detect classification or regression"""
46
+ if target_col not in df.columns:
47
+ return "classification"
48
+
49
+ target = df[target_col]
50
+ unique_vals = target.nunique()
51
+
52
+ if target.dtype == 'object' or unique_vals <= min(20, len(target) * 0.1):
53
+ return "classification"
54
+ return "regression"
55
+
56
+ def create_input_components(df, target_col):
57
+ """Generate UI component specifications for features"""
58
+ feature_cols = [col for col in df.columns if col != target_col]
59
+ components = []
60
+
61
+ for col in feature_cols:
62
+ data = df[col]
63
+ if data.dtype == 'object':
64
+ unique_vals = sorted(data.unique())
65
+ components.append({
66
+ 'name': col, 'type': 'dropdown',
67
+ 'choices': unique_vals, 'value': unique_vals[0]
68
+ })
69
+ else:
70
+ components.append({
71
+ 'name': col, 'type': 'number',
72
+ 'value': round(float(data.mean()), 2),
73
+ 'minimum': None,
74
+ 'maximum': None
75
+ })
76
+
77
+ return components
78
+
79
+ def preprocess_data(df, target_col, new_point_dict):
80
+ """Preprocess data for decision tree training"""
81
+ feature_cols = [col for col in df.columns if col != target_col]
82
+ X = df[feature_cols].copy()
83
+ y = df[target_col].copy()
84
+
85
+ # Encode categorical variables
86
+ encoders = {}
87
+ for col in feature_cols:
88
+ if X[col].dtype == 'object':
89
+ le = LabelEncoder()
90
+ X[col] = le.fit_transform(X[col].astype(str))
91
+ encoders[col] = le
92
+
93
+ # Process new point
94
+ new_point = []
95
+ for col in feature_cols:
96
+ if col in encoders:
97
+ try:
98
+ val = encoders[col].transform([str(new_point_dict[col])])[0]
99
+ except ValueError:
100
+ available_categories = list(encoders[col].classes_)
101
+ raise ValueError(f"Unknown category '{new_point_dict[col]}' for column '{col}'. Available options: {available_categories}")
102
+ new_point.append(val)
103
+ else:
104
+ new_point.append(float(new_point_dict[col]))
105
+
106
+ new_point = np.array(new_point).reshape(1, -1)
107
+
108
+ return X.values, y, new_point, feature_cols, encoders
109
+
110
+ def run_decision_tree_and_visualize(df, target_col, new_point_dict, max_depth, min_samples_split, min_samples_leaf, criterion, problem_type=None):
111
+ """Execute Decision Tree algorithm and generate visualization"""
112
+ X, y, new_point, feature_cols, encoders = preprocess_data(df, target_col, new_point_dict)
113
+
114
+ if problem_type is None:
115
+ problem_type = determine_problem_type(df, target_col)
116
+
117
+ # Validate parameters
118
+ if max_depth is not None and max_depth < 0:
119
+ return None, None, None, None, None, "Max depth must be at least 0 (unlimited) or 1+ for specific depth."
120
+
121
+ if min_samples_split < 2:
122
+ return None, None, None, None, None, "Min samples split must be at least 2."
123
+
124
+ if min_samples_leaf < 1:
125
+ return None, None, None, None, None, "Min samples leaf must be at least 1."
126
+
127
+ # Train decision tree
128
+ ModelClass = DecisionTreeClassifier if problem_type == "classification" else DecisionTreeRegressor
129
+ model = ModelClass(
130
+ max_depth=None if max_depth == 0 else max_depth,
131
+ min_samples_split=min_samples_split,
132
+ min_samples_leaf=min_samples_leaf,
133
+ criterion=criterion,
134
+ random_state=42
135
+ )
136
+
137
+ model.fit(X, y)
138
+ prediction = model.predict(new_point)[0]
139
+
140
+ # Get prediction path
141
+ path = model.decision_path(new_point)
142
+ node_indices = path.indices
143
+
144
+ # Create tree visualization
145
+ tree_fig = create_tree_visualization(model, feature_cols, target_col, problem_type, new_point_dict, prediction)
146
+
147
+ # Create feature importance plot
148
+ importance_fig = create_feature_importance_plot(model, feature_cols)
149
+
150
+ # Create prediction details
151
+ prediction_details = create_prediction_details(model, new_point[0], feature_cols, target_col, prediction, problem_type)
152
+
153
+ # Generate algorithm summary
154
+ summary = create_algorithm_summary(model, problem_type, max_depth, min_samples_split, min_samples_leaf, criterion, feature_cols)
155
+
156
+ return tree_fig, importance_fig, prediction, prediction_details, summary, None
157
+
158
+ def create_tree_visualization(model, feature_cols, target_col, problem_type, new_point_dict, prediction):
159
+ """Create interactive decision tree visualization using plotly"""
160
+ # Create a hierarchical tree visualization
161
+ fig = go.Figure()
162
+
163
+ # Get tree structure
164
+ tree_data = get_tree_structure(model, feature_cols, target_col, problem_type)
165
+
166
+ # Create tree layout
167
+ positions = calculate_tree_positions(tree_data)
168
+
169
+ # Add nodes
170
+ for node_id, pos in positions.items():
171
+ node_info = tree_data[node_id]
172
+
173
+ if node_info['is_leaf']:
174
+ color = 'lightgreen'
175
+ text = f"Leaf: {node_info['prediction']}"
176
+ else:
177
+ color = 'lightblue'
178
+ text = f"{node_info['feature']} ≤ {node_info['threshold']:.3f}"
179
+
180
+ fig.add_trace(go.Scatter(
181
+ x=[pos['x']], y=[pos['y']],
182
+ mode='markers+text',
183
+ marker=dict(size=15, color=color),
184
+ text=[text],
185
+ textposition='middle center',
186
+ textfont=dict(size=10),
187
+ showlegend=False,
188
+ hovertemplate=f"<b>{text}</b><br>Samples: {node_info['samples']}<extra></extra>"
189
+ ))
190
+
191
+ # Add edges
192
+ for node_id, pos in positions.items():
193
+ node_info = tree_data[node_id]
194
+ if not node_info['is_leaf']:
195
+ # Left child
196
+ if node_info['left_child'] in positions:
197
+ left_pos = positions[node_info['left_child']]
198
+ fig.add_trace(go.Scatter(
199
+ x=[pos['x'], left_pos['x']], y=[pos['y'], left_pos['y']],
200
+ mode='lines',
201
+ line=dict(color='gray', width=1),
202
+ showlegend=False,
203
+ hoverinfo='skip'
204
+ ))
205
+
206
+ # Right child
207
+ if node_info['right_child'] in positions:
208
+ right_pos = positions[node_info['right_child']]
209
+ fig.add_trace(go.Scatter(
210
+ x=[pos['x'], right_pos['x']], y=[pos['y'], right_pos['y']],
211
+ mode='lines',
212
+ line=dict(color='gray', width=1),
213
+ showlegend=False,
214
+ hoverinfo='skip'
215
+ ))
216
+
217
+ fig.update_layout(
218
+ title="Decision Tree Structure",
219
+ xaxis_title="",
220
+ yaxis_title="",
221
+ showlegend=False,
222
+ height=600,
223
+ width=800,
224
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
225
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
226
+ )
227
+
228
+ return fig
229
+
230
+ def get_tree_structure(model, feature_cols, target_col, problem_type):
231
+ """Extract tree structure from sklearn model"""
232
+ tree_data = {}
233
+
234
+ def process_node(node_id):
235
+ if model.tree_.children_left[node_id] == -1: # Leaf node
236
+ if problem_type == "classification":
237
+ class_counts = model.tree_.value[node_id][0]
238
+ predicted_class = np.argmax(class_counts)
239
+ else:
240
+ predicted_value = model.tree_.value[node_id][0][0]
241
+ predicted_class = predicted_value
242
+
243
+ tree_data[node_id] = {
244
+ 'is_leaf': True,
245
+ 'samples': int(model.tree_.n_node_samples[node_id]),
246
+ 'prediction': predicted_class
247
+ }
248
+ else: # Internal node
249
+ feature_idx = model.tree_.feature[node_id]
250
+ threshold = model.tree_.threshold[node_id]
251
+ feature_name = feature_cols[feature_idx] if feature_idx < len(feature_cols) else f'Feature_{feature_idx}'
252
+
253
+ tree_data[node_id] = {
254
+ 'is_leaf': False,
255
+ 'feature': feature_name,
256
+ 'threshold': threshold,
257
+ 'samples': int(model.tree_.n_node_samples[node_id]),
258
+ 'left_child': model.tree_.children_left[node_id],
259
+ 'right_child': model.tree_.children_right[node_id]
260
+ }
261
+
262
+ # Process children
263
+ process_node(model.tree_.children_left[node_id])
264
+ process_node(model.tree_.children_right[node_id])
265
+
266
+ process_node(0)
267
+ return tree_data
268
+
269
+ def calculate_tree_positions(tree_data):
270
+ """Calculate positions for tree nodes"""
271
+ positions = {}
272
+
273
+ def calculate_positions_recursive(node_id, x, y, level_width):
274
+ if node_id not in tree_data:
275
+ return
276
+
277
+ positions[node_id] = {'x': x, 'y': y}
278
+
279
+ if not tree_data[node_id]['is_leaf']:
280
+ # Calculate positions for children
281
+ left_child = tree_data[node_id]['left_child']
282
+ right_child = tree_data[node_id]['right_child']
283
+
284
+ child_width = level_width / 2
285
+ calculate_positions_recursive(left_child, x - child_width/2, y - 1, child_width)
286
+ calculate_positions_recursive(right_child, x + child_width/2, y - 1, child_width)
287
+
288
+ # Start from root
289
+ calculate_positions_recursive(0, 0, 0, 4)
290
+ return positions
291
+
292
+
293
+
+def create_feature_importance_plot(model, feature_cols):
+    """Create feature importance visualization"""
+    importances = model.feature_importances_
+    indices = np.argsort(importances)[::-1]
+
+    fig = go.Figure()
+
+    fig.add_trace(go.Bar(
+        x=[feature_cols[i] for i in indices],
+        y=importances[indices],
+        marker_color='lightblue',
+        text=[f'{importances[i]:.3f}' for i in indices],
+        textposition='auto',
+    ))
+
+    fig.update_layout(
+        title="Feature Importance",
+        xaxis_title="Features",
+        yaxis_title="Importance Score",
+        showlegend=False,
+        height=400
+    )
+
+    return fig
+
+def create_prediction_details(model, new_point, feature_cols, target_col, prediction, problem_type):
+    """Create detailed prediction information"""
+    details = []
+
+    # Add input features
+    details.append("## 📝 **Input Features**")
+    for col, val in zip(feature_cols, new_point):
+        details.append(f"- **{col}**: {val}")
+
+    details.append("\n## 🎯 **Prediction**")
+    if problem_type == "classification":
+        details.append(f"- **Predicted Class**: {prediction}")
+        # Get prediction probabilities if available
+        if hasattr(model, 'predict_proba'):
+            # np.asarray lets new_point be a plain list as well as an ndarray
+            proba = model.predict_proba(np.asarray(new_point).reshape(1, -1))[0]
+            details.append(f"- **Confidence**: {max(proba):.3f}")
+    else:
+        details.append(f"- **Predicted Value**: {prediction:.3f}")
+
+    # Add tree statistics
+    details.append("\n## 🌳 **Tree Statistics**")
+    details.append(f"- **Total Nodes**: {model.tree_.node_count}")
+    details.append(f"- **Leaf Nodes**: {model.get_n_leaves()}")
+    details.append(f"- **Max Depth**: {model.get_depth()}")
+
+    return "\n".join(details)
+
+def create_algorithm_summary(model, problem_type, max_depth, min_samples_split, min_samples_leaf, criterion, feature_cols):
+    """Generate algorithm summary"""
+    # Treat 0 or None as unlimited depth
+    max_depth_str = "Unlimited" if not max_depth else str(max_depth)
+
+    summary = f"""## Algorithm Summary
+**Criterion:** {criterion} | **Max Depth:** {max_depth_str} | **Min Samples Split:** {min_samples_split} | **Min Samples Leaf:** {min_samples_leaf}
+**Features:** {len(feature_cols)} | **Total Nodes:** {model.tree_.node_count} | **Leaf Nodes:** {model.get_n_leaves()}
+**Tree Depth:** {model.get_depth()} | **Problem Type:** {problem_type.title()}
+
+**Top 3 Most Important Features:**
+"""
+
+    # Rank features from most to least important
+    importances = model.feature_importances_
+    indices = np.argsort(importances)[::-1]
+
+    for i in range(min(3, len(feature_cols))):
+        summary += f"- {feature_cols[indices[i]]}: {importances[indices[i]]:.3f}\n"
+
+    return summary
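As a rough illustration of the layout that `calculate_tree_positions` in this file produces, the sketch below reimplements the same width-halving scheme and runs it on a hand-built five-node tree. The `sample_tree` dictionary is hypothetical sample data rather than output from a fitted model, and `layout_positions` is an illustrative name that is not part of the commit.

```python
def layout_positions(tree_data, root_width=4):
    """Place each node so that its children sit half a level-width apart."""
    positions = {}

    def place(node_id, x, y, level_width):
        # Child ids that are not in tree_data (e.g. -1 for "no child") are skipped.
        if node_id not in tree_data:
            return
        positions[node_id] = {'x': x, 'y': y}
        node = tree_data[node_id]
        if not node['is_leaf']:
            child_width = level_width / 2
            place(node['left_child'], x - child_width / 2, y - 1, child_width)
            place(node['right_child'], x + child_width / 2, y - 1, child_width)

    place(0, 0, 0, root_width)
    return positions


# Hypothetical tree: node 0 splits into 1 and 2; node 1 splits into 3 and 4.
sample_tree = {
    0: {'is_leaf': False, 'left_child': 1, 'right_child': 2},
    1: {'is_leaf': False, 'left_child': 3, 'right_child': 4},
    2: {'is_leaf': True},
    3: {'is_leaf': True},
    4: {'is_leaf': True},
}

positions = layout_positions(sample_tree)
for node_id in sorted(positions):
    print(node_id, positions[node_id])
```

Because the horizontal offset halves at every level, sibling subtrees do not overlap, but the spacing between neighbouring nodes also halves with each level, so deep trees become crowded near the leaves.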
static/aivn_logo.png ADDED
static/vlai_logo.png ADDED
vlai_template.py ADDED
@@ -0,0 +1,142 @@
+import base64
+import os
+
+import gradio as gr
+
+
+PROJECT_NAME = "Decision Tree Demo"
+AIO_YEAR = "2025"
+AIO_MODULE = "03"
+# END
+
+
+def image_to_base64(image_path: str):
+    # Construct the absolute path to the image
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    full_image_path = os.path.join(current_dir, image_path)
+    with open(full_image_path, "rb") as f:
+        return base64.b64encode(f.read()).decode("utf-8")
+
+def create_header():
+    with gr.Row():
+        with gr.Column(scale=2):
+            logo_base64 = image_to_base64("static/aivn_logo.png")
+            gr.HTML(
+                f"""<img src="data:image/png;base64,{logo_base64}"
+                        alt="Logo"
+                        style="height:120px;width:auto;margin:0 auto;margin-bottom:16px; display:block;">"""
+            )
+        with gr.Column(scale=2):
+            gr.HTML(f"""
+            <div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
+                <div>
+                    <h1 style="margin-bottom:0; color: #2E7D32; font-size: 2.5em; font-weight: bold;"> {PROJECT_NAME} </h1>
+                    <h3 style="color: #888; font-style: italic"> AIO{AIO_YEAR}: Module {AIO_MODULE}. </h3>
+                </div>
+            </div>
+            """)
+
+def create_footer():
+    logo_base64_vlai = image_to_base64("static/vlai_logo.png")
+    footer_html = """
+    <style>
+        .sticky-footer{position:fixed;bottom:0px;left:0;width:100%;background:#E8F5E8;
+                       padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
+        .content-wrap{padding-bottom:60px;}
+    </style>""" + f"""
+    <div class="sticky-footer">
+        <div style="text-align:center;font-size:18px; color: #888">
+            Created by
+            <a href="https://vlai.work" target="_blank" style="color:#465C88;text-decoration:none;font-weight:bold; display:inline-flex; align-items:center;"> VLAI
+                <img src="data:image/png;base64,{logo_base64_vlai}" alt="Logo" style="height:20px; width:auto;">
+            </a> from <a href="https://aivietnam.edu.vn/" target="_blank" style="color:#355724;text-decoration:none;font-weight:bold">AI VIET NAM</a>
+        </div>
+    </div>
+    """
+    return gr.HTML(footer_html)
+
+custom_css = """
+
+.gradio-container {
+    min-height: 100vh !important;
+    width: 100vw !important;
+    margin: 0 !important;
+    padding: 0px !important;
+    background: linear-gradient(135deg, #E8F5E8 0%, #D4E6D4 50%, #A8D8A8 100%);
+    background-size: 600% 600%;
+    animation: gradientBG 7s ease infinite;
+}
+
+@keyframes gradientBG {
+    0% {background-position: 0% 50%;}
+    50% {background-position: 100% 50%;}
+    100% {background-position: 0% 50%;}
+}
+
+/* Minimize spacing and padding */
+.content-wrap {
+    padding: 2px !important;
+    margin: 0 !important;
+}
+
+/* Reduce component spacing */
+.gr-row {
+    gap: 5px !important;
+    margin: 2px 0 !important;
+}
+
+.gr-column {
+    gap: 4px !important;
+    padding: 4px !important;
+}
+
+/* Accordion optimization */
+.gr-accordion {
+    margin: 4px 0 !important;
+}
+
+.gr-accordion .gr-accordion-content {
+    padding: 2px !important;
+}
+
+/* Form elements spacing */
+.gr-form {
+    gap: 2px !important;
+}
+
+/* Button styling */
+.gr-button {
+    margin: 2px 0 !important;
+}
+
+/* DataFrame optimization */
+.gr-dataframe {
+    margin: 4px 0 !important;
+}
+
+/* Keep the data preview within its container; scroll horizontally only on overflow */
+.gr-dataframe .wrap {
+    overflow-x: auto !important;
+    max-width: 100% !important;
+}
+
+/* Plot optimization */
+.gr-plot {
+    margin: 4px 0 !important;
+}
+
+/* Reduce markdown margins */
+.gr-markdown {
+    margin: 2px 0 !important;
+}
+
+/* Footer positioning */
+.sticky-footer {
+    position: fixed;
+    bottom: 0px;
+    left: 0;
+    width: 100%;
+    background: #E8F5E8;
+    padding: 6px !important;
+    box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
+    z-index: 1000;
+}
+"""
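The header and footer in `vlai_template.py` embed their logos inline by base64-encoding the image bytes into a `data:` URI. A minimal, standard-library-only sketch of that idea follows. The helper name `image_to_data_uri`, the temporary file, and the placeholder bytes are assumptions made for illustration and are not part of the commit.

```python
import base64
import os
import tempfile


def image_to_data_uri(image_path, mime="image/png"):
    """Read a file and return it as a data URI usable in an <img src=...> attribute."""
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


# Placeholder content: only the 8-byte PNG signature, not a complete image.
png_signature = b"\x89PNG\r\n\x1a\n"

with tempfile.TemporaryDirectory() as tmp_dir:
    logo_path = os.path.join(tmp_dir, "logo.png")
    with open(logo_path, "wb") as f:
        f.write(png_signature)
    uri = image_to_data_uri(logo_path)

img_tag = f'<img src="{uri}" alt="Logo" style="height:20px; width:auto;">'
print(img_tag)
```

Inlining the image this way avoids having to serve static files from the app, at the cost of a larger HTML payload, since base64 output is roughly a third bigger than the raw bytes.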