wjnwjn59 committed on
Commit
863c992
·
1 Parent(s): da0fb40

fix depth error

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +3 -3
  3. src/adaboost_core.py +87 -33
README.md CHANGED
@@ -26,7 +26,7 @@ This interactive demo showcases AdaBoost (Adaptive Boosting) algorithms for both
26
  ### AdaBoost Parameters
27
  - **Number of Estimators**: Control sequential learning steps (limited to 1000 for performance)
28
  - **Learning Rate**: Step size shrinkage for adaptive learning (0.0001-2.0)
29
- - **Max Depth**: Individual weak learner depth (default: 0, decision stumps work best)
30
  - **Base Estimator**: Decision tree with limited depth (weak learner)
31
 
32
  ### Visualizations
@@ -100,7 +100,7 @@ This interactive demo showcases AdaBoost (Adaptive Boosting) algorithms for both
100
 
101
  - **Number of Estimators**: Limited to 1000 for optimal performance in this demo
102
  - **Learning Rate**: Default 1.0 works well; lower rates (0.0001-0.1) create more conservative models, higher rates (1.0-2.0) for faster learning
103
- - **Max Depth**: Decision stumps (depth 0) typically optimal for AdaBoost
104
  - **Weak Learners**: Simple estimators work best to avoid overfitting
105
 
106
  ## 🎯 Use Cases
 
26
  ### AdaBoost Parameters
27
  - **Number of Estimators**: Control sequential learning steps (limited to 1000 for performance)
28
  - **Learning Rate**: Step size shrinkage for adaptive learning (0.0001-2.0)
29
+ - **Max Depth**: Individual weak learner depth (default: 1, decision stumps work best)
30
  - **Base Estimator**: Decision tree with limited depth (weak learner)
31
 
32
  ### Visualizations
 
100
 
101
  - **Number of Estimators**: Limited to 1000 for optimal performance in this demo
102
  - **Learning Rate**: Default 1.0 works well; lower rates (0.0001-0.1) create more conservative models, higher rates (1.0-2.0) for faster learning
103
+ - **Max Depth**: Decision stumps (depth 1) typically optimal for AdaBoost
104
  - **Weak Learners**: Simple estimators work best to avoid overfitting
105
 
106
  ## 🎯 Use Cases
app.py CHANGED
@@ -426,8 +426,8 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
426
  with gr.Row():
427
  max_depth = gr.Number(
428
  label="Max Depth (Base Estimator)",
429
- value=0, minimum=0, maximum=10, precision=0,
430
- info="Maximum depth of individual decision trees (0 = decision stumps with 1 split, ideal for AdaBoost)"
431
  )
432
 
433
  gr.Markdown("**📊 Data Split Configuration**")
@@ -482,7 +482,7 @@ with gr.Blocks(theme="gstaff/sketch", css=vlai_template.custom_css, fill_width=T
482
  - **📊 Feature Importance**: Displays which features are most influential across all estimators.
483
  - **🎯 Parameter Tuning**: Try different **number of estimators** (up to 1000) and **learning rate** (0.0001-2.0).
484
  - **⚡ Learning Rate**: Default 1.0 works well; lower values create more conservative models with better generalization.
485
- - **🌲 Decision Stumps**: Max depth 0 creates decision stumps (one split), which are ideal weak learners for AdaBoost.
486
  - **🎯 Adaptive Reweighting**: AdaBoost focuses on misclassified examples by increasing their weights.
487
  - **πŸ” Estimator Analysis**: Use the estimator selector to understand how each decision stump contributes to predictions.
488
  """)
 
426
  with gr.Row():
427
  max_depth = gr.Number(
428
  label="Max Depth (Base Estimator)",
429
+ value=1, minimum=1, maximum=10, precision=0,
430
+ info="Maximum depth of individual decision trees (1 = decision stumps, 2+ = deeper trees)"
431
  )
432
 
433
  gr.Markdown("**📊 Data Split Configuration**")
 
482
  - **📊 Feature Importance**: Displays which features are most influential across all estimators.
483
  - **🎯 Parameter Tuning**: Try different **number of estimators** (up to 1000) and **learning rate** (0.0001-2.0).
484
  - **⚡ Learning Rate**: Default 1.0 works well; lower values create more conservative models with better generalization.
485
+ - **🌲 Decision Stumps**: Max depth 1 creates decision stumps (one split), which are ideal weak learners for AdaBoost.
486
  - **🎯 Adaptive Reweighting**: AdaBoost focuses on misclassified examples by increasing their weights.
487
  - **πŸ” Estimator Analysis**: Use the estimator selector to understand how each decision stump contributes to predictions.
488
  """)
src/adaboost_core.py CHANGED
@@ -160,8 +160,8 @@ def run_adaboost_and_visualize(df, target_col, new_point_dict,
160
 
161
  if n_estimators < 1:
162
  return None, None, None, None, "Number of estimators must be ≥ 1.", None
163
- if max_depth is not None and max_depth < 0:
164
- return None, None, None, None, "Max depth must be ≥ 0.", None
165
  if learning_rate <= 0 or learning_rate > 2:
166
  return None, None, None, None, "Learning rate must be between 0 and 2.", None
167
 
@@ -173,7 +173,9 @@ def run_adaboost_and_visualize(df, target_col, new_point_dict,
173
 
174
  if problem_type == "classification":
175
  # For binary/multiclass classification
176
- base_estimator = DecisionTreeClassifier(max_depth=1 if max_depth == 0 else int(max_depth))
 
 
177
  try:
178
  # Try the new parameter name first (scikit-learn >= 1.2)
179
  model = AdaBoostClassifier(
@@ -193,7 +195,9 @@ def run_adaboost_and_visualize(df, target_col, new_point_dict,
193
  random_state=42
194
  )
195
  else:
196
- base_estimator = DecisionTreeRegressor(max_depth=1 if max_depth == 0 else int(max_depth))
 
 
197
  try:
198
  # Try the new parameter name first (scikit-learn >= 1.2)
199
  model = AdaBoostRegressor(
@@ -389,34 +393,78 @@ def create_manual_tree_plot(tree_index, feature_cols, problem_type, model_type,
389
  import random
390
  random.seed(tree_index) # Consistent trees for same index
391
 
392
- # Root node (decision stump - only one split)
 
 
 
 
 
 
 
 
 
 
 
393
  root_feature = random.choice(feature_cols) if feature_cols else "feature_0"
394
  root_threshold = round(random.uniform(0.1, 5.0), 2)
395
 
396
- # Positions for a decision stump (depth 0 - only root and two leaves)
397
- positions = {
398
- 'root': (0, 1),
399
- 'left': (-1, 0),
400
- 'right': (1, 0)
401
- }
402
-
403
- # Labels and colors for decision stump
404
- labels = {
405
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Weight: {weight:.3f}<br>Decision Stump",
406
- 'left': f"Leaf (≤)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}",
407
- 'right': f"Leaf (>)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}"
408
- }
409
-
410
- colors = {
411
- 'root': '#81C784', # Green for split node
412
- 'left': '#FFB74D', # Orange for left leaf
413
- 'right': '#FFB74D' # Orange for right leaf
414
- }
415
-
416
- # Draw edges for decision stump
417
- edges = [
418
- ('root', 'left'), ('root', 'right')
419
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  edge_x, edge_y = [], []
422
  for parent, child in edges:
@@ -452,12 +500,18 @@ def create_manual_tree_plot(tree_index, feature_cols, problem_type, model_type,
452
  hovertext=labels[node_id]
453
  ))
454
 
 
 
 
 
 
 
455
  fig.update_layout(
456
- title=f"{model_type} Estimator {tree_index + 1} Structure - Decision Stump ({problem_type.title()})",
457
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-1.5, 1.5]),
458
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.5, 1.5]),
459
  plot_bgcolor="white",
460
- height=400,
461
  margin=dict(l=40, r=40, t=60, b=40),
462
  showlegend=False
463
  )
 
160
 
161
  if n_estimators < 1:
162
  return None, None, None, None, "Number of estimators must be ≥ 1.", None
163
+ if max_depth is not None and max_depth < 1:
164
+ return None, None, None, None, "Max depth must be ≥ 1.", None
165
  if learning_rate <= 0 or learning_rate > 2:
166
  return None, None, None, None, "Learning rate must be between 0 and 2.", None
167
 
 
173
 
174
  if problem_type == "classification":
175
  # For binary/multiclass classification
176
+ # Direct mapping: UI depth = actual depth, with minimum depth of 1 for AdaBoost
177
+ actual_depth = max(1, int(max_depth)) if max_depth >= 1 else 1
178
+ base_estimator = DecisionTreeClassifier(max_depth=actual_depth)
179
  try:
180
  # Try the new parameter name first (scikit-learn >= 1.2)
181
  model = AdaBoostClassifier(
 
195
  random_state=42
196
  )
197
  else:
198
+ # Direct mapping: UI depth = actual depth, with minimum depth of 1 for AdaBoost
199
+ actual_depth = max(1, int(max_depth)) if max_depth >= 1 else 1
200
+ base_estimator = DecisionTreeRegressor(max_depth=actual_depth)
201
  try:
202
  # Try the new parameter name first (scikit-learn >= 1.2)
203
  model = AdaBoostRegressor(
 
393
  import random
394
  random.seed(tree_index) # Consistent trees for same index
395
 
396
+ # Get the current model to determine actual depth
397
+ current_model = _get_current_model()
398
+ if current_model and hasattr(current_model, 'estimators_') and len(current_model.estimators_) > tree_index:
399
+ try:
400
+ actual_estimator = current_model.estimators_[tree_index]
401
+ actual_depth = actual_estimator.max_depth
402
+ except:
403
+ actual_depth = 1 # fallback to stump
404
+ else:
405
+ actual_depth = 1 # fallback to stump
406
+
407
+ # Root node
408
  root_feature = random.choice(feature_cols) if feature_cols else "feature_0"
409
  root_threshold = round(random.uniform(0.1, 5.0), 2)
410
 
411
+ # Create tree structure based on actual depth
412
+ if actual_depth == 1:
413
+ # Decision stump (depth 1 - only root and two leaves)
414
+ positions = {
415
+ 'root': (0, 1),
416
+ 'left': (-1, 0),
417
+ 'right': (1, 0)
418
+ }
419
+
420
+ labels = {
421
+ 'root': f"{root_feature}<br>≤ {root_threshold}<br>Weight: {weight:.3f}<br>Decision Stump",
422
+ 'left': f"Leaf (≤)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}",
423
+ 'right': f"Leaf (>)<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: {random.randint(20, 80)}"
424
+ }
425
+
426
+ colors = {
427
+ 'root': '#81C784', # Green for split node
428
+ 'left': '#FFB74D', # Orange for left leaf
429
+ 'right': '#FFB74D' # Orange for right leaf
430
+ }
431
+
432
+ edges = [('root', 'left'), ('root', 'right')]
433
+ title_suffix = "Decision Stump"
434
+
435
+ else:
436
+ # Deeper tree (depth 2+)
437
+ positions = {
438
+ 'root': (0, 2),
439
+ 'left': (-1.5, 1),
440
+ 'right': (1.5, 1),
441
+ 'left_left': (-2.5, 0),
442
+ 'left_right': (-0.5, 0),
443
+ 'right_left': (0.5, 0),
444
+ 'right_right': (2.5, 0)
445
+ }
446
+
447
+ labels = {
448
+ 'root': f"{root_feature}<br>≤ {root_threshold}<br>Weight: {weight:.3f}<br>Depth: {actual_depth}",
449
+ 'left': f"{random.choice(feature_cols) if feature_cols else 'feature_1'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
450
+ 'right': f"{random.choice(feature_cols) if feature_cols else 'feature_2'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
451
+ 'left_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 25",
452
+ 'left_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 50",
453
+ 'right_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 30",
454
+ 'right_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 45"
455
+ }
456
+
457
+ colors = {
458
+ 'root': '#81C784', 'left': '#81C784', 'right': '#81C784', # Green for split nodes
459
+ 'left_left': '#FFB74D', 'left_right': '#FFB74D', 'right_left': '#FFB74D', 'right_right': '#FFB74D' # Orange for leaves
460
+ }
461
+
462
+ edges = [
463
+ ('root', 'left'), ('root', 'right'),
464
+ ('left', 'left_left'), ('left', 'left_right'),
465
+ ('right', 'right_left'), ('right', 'right_right')
466
+ ]
467
+ title_suffix = f"Depth {actual_depth} Tree"
468
 
469
  edge_x, edge_y = [], []
470
  for parent, child in edges:
 
500
  hovertext=labels[node_id]
501
  ))
502
 
503
+ # Adjust layout based on tree depth
504
+ if actual_depth == 1:
505
+ x_range, y_range, height = [-1.5, 1.5], [-0.5, 1.5], 400
506
+ else:
507
+ x_range, y_range, height = [-3, 3], [-0.5, 2.5], 600
508
+
509
  fig.update_layout(
510
+ title=f"{model_type} Estimator {tree_index + 1} Structure - {title_suffix} ({problem_type.title()})",
511
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=x_range),
512
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=y_range),
513
  plot_bgcolor="white",
514
+ height=height,
515
  margin=dict(l=40, r=40, t=60, b=40),
516
  showlegend=False
517
  )