wjnwjn59 commited on
Commit
677dc84
·
1 Parent(s): 863c992

remove redundant

Browse files
Files changed (1) hide show
  1. src/xgboost_core.py +0 -938
src/xgboost_core.py DELETED
@@ -1,938 +0,0 @@
1
- import pandas as pd
2
- import numpy as np
3
-
4
- # XGBoost is required for this demo
5
- try:
6
- import xgboost as xgb
7
- XGBOOST_AVAILABLE = True
8
- print("✅ XGBoost loaded successfully!")
9
- except ImportError:
10
- print("❌ XGBoost is required for this demo!")
11
- print("Please install XGBoost using: pip install xgboost>=2.0.0")
12
- raise ImportError("XGBoost is required for this XGBoost demo. Please install it using: pip install xgboost>=2.0.0")
13
-
14
- from sklearn.preprocessing import LabelEncoder
15
- from sklearn.datasets import (
16
- load_iris, load_wine, load_diabetes, load_breast_cancer
17
- )
18
- from sklearn.model_selection import train_test_split
19
- import plotly.graph_objects as go
20
- import plotly.express as px
21
-
22
- _current_model = None
23
-
24
- def _get_current_model():
25
- return _current_model
26
-
27
- def _set_current_model(model):
28
- global _current_model
29
- _current_model = model
30
-
31
-
32
- def load_data(file_obj=None, dataset_choice="Iris"):
33
- if file_obj is not None:
34
- if file_obj.name.endswith(".csv"):
35
- encodings = ["utf-8", "latin-1", "iso-8859-1", "cp1252"]
36
- for encoding in encodings:
37
- try:
38
- return pd.read_csv(file_obj.name, encoding=encoding)
39
- except UnicodeDecodeError:
40
- continue
41
- return pd.read_csv(file_obj.name, encoding="utf-8", errors="replace")
42
- elif file_obj.name.endswith((".xlsx", ".xls")):
43
- return pd.read_excel(file_obj.name)
44
- else:
45
- raise ValueError("Unsupported format. Upload CSV or Excel files.")
46
-
47
- datasets = {
48
- "Iris": lambda: _sklearn_to_df(load_iris()),
49
- "Wine": lambda: _sklearn_to_df(load_wine()),
50
- "Breast Cancer": lambda: _sklearn_to_df(load_breast_cancer()),
51
- "Diabetes": lambda: _sklearn_to_df(load_diabetes()),
52
- "Titanic": lambda: _load_titanic_data(),
53
- }
54
- if dataset_choice not in datasets:
55
- raise ValueError(f"Unknown dataset: {dataset_choice}")
56
- return datasets[dataset_choice]()
57
-
58
-
59
- def _sklearn_to_df(data):
60
- df = pd.DataFrame(data.data, columns=getattr(data, "feature_names", None))
61
- if df.columns.isnull().any():
62
- df.columns = [f"f{i}" for i in range(df.shape[1])]
63
- df["target"] = data.target
64
- return df
65
-
66
- def _load_titanic_data():
67
- try:
68
- df = pd.read_csv("data/titanic_dataset.csv")
69
- df = df.dropna()
70
- df['sex'] = df['sex'].map({'male': 0, 'female': 1})
71
- df['embarked'] = df['embarked'].map({'S': 0, 'C': 1, 'Q': 2})
72
- return df
73
- except FileNotFoundError:
74
- raise ValueError("Titanic dataset not found. Please ensure 'data/titanic_dataset.csv' exists.")
75
-
76
-
77
- def determine_problem_type(df, target_col):
78
- if target_col not in df.columns:
79
- return "classification"
80
- target = df[target_col]
81
- unique_vals = target.nunique()
82
- if target.dtype == "object" or unique_vals <= min(20, len(target) * 0.1):
83
- return "classification"
84
- return "regression"
85
-
86
-
87
- def create_input_components(df, target_col):
88
- feature_cols = [c for c in df.columns if c != target_col]
89
- components = []
90
- for col in feature_cols:
91
- data = df[col]
92
- if data.dtype == "object":
93
- uniq = sorted(map(str, data.dropna().unique()))
94
- if not uniq:
95
- uniq = ["N/A"]
96
- components.append(
97
- {"name": col, "type": "dropdown", "choices": uniq, "value": uniq[0]}
98
- )
99
- else:
100
- val = pd.to_numeric(data, errors="coerce").dropna().mean()
101
- val = 0.0 if pd.isna(val) else float(val)
102
- components.append(
103
- {
104
- "name": col,
105
- "type": "number",
106
- "value": round(val, 3),
107
- "minimum": None,
108
- "maximum": None,
109
- }
110
- )
111
- return components
112
-
113
-
114
- def preprocess_data(df, target_col, new_point_dict):
115
- feature_cols = [c for c in df.columns if c != target_col]
116
- X = df[feature_cols].copy()
117
- y = df[target_col].copy()
118
-
119
- encoders = {}
120
- for col in feature_cols:
121
- if X[col].dtype == "object":
122
- le = LabelEncoder()
123
- X[col] = le.fit_transform(X[col].astype(str))
124
- encoders[col] = le
125
- elif X[col].dtype == "bool":
126
- X[col] = X[col].astype(int)
127
- else:
128
- X[col] = pd.to_numeric(X[col], errors="coerce").fillna(0.0)
129
-
130
- if y.dtype == "object":
131
- y = pd.Categorical(y).codes
132
- elif y.dtype == "bool":
133
- y = y.astype(int)
134
-
135
- new_point = []
136
- for col in feature_cols:
137
- if col in new_point_dict:
138
- if col in encoders:
139
- val = str(new_point_dict[col])
140
- try:
141
- enc_val = encoders[col].transform([val])[0]
142
- except ValueError:
143
- enc_val = 0
144
- new_point.append(enc_val)
145
- else:
146
- v = new_point_dict[col]
147
- try:
148
- new_point.append(float(v))
149
- except Exception:
150
- new_point.append(0.0)
151
- else:
152
- if col in encoders:
153
- new_point.append(0)
154
- else:
155
- new_point.append(0.0)
156
- new_point = np.array(new_point, dtype=float).reshape(1, -1)
157
-
158
- return X, np.array(y), new_point, feature_cols, encoders
159
-
160
-
161
- def run_xgboost_and_visualize(df, target_col, new_point_dict,
162
- n_estimators, max_depth, min_child_weight,
163
- subsample, colsample_bytree, learning_rate, train_test_split_ratio=0.8, problem_type=None):
164
- X, y, new_point, feature_cols, _ = preprocess_data(df, target_col, new_point_dict)
165
-
166
- if problem_type is None:
167
- problem_type = determine_problem_type(df, target_col)
168
-
169
- if n_estimators < 1:
170
- return None, None, None, None, "Number of estimators must be ≥ 1.", None
171
- if max_depth is not None and max_depth < 0:
172
- return None, None, None, None, "Max depth must be ≥ 0.", None
173
- if min_child_weight < 1:
174
- return None, None, None, None, "Min child weight must be ≥ 1.", None
175
- if learning_rate <= 0 or learning_rate > 1:
176
- return None, None, None, None, "Learning rate must be between 0 and 1.", None
177
-
178
- n_estimators = min(int(n_estimators), 100) # Limit to 100 trees
179
-
180
- # Split data for loss tracking with user-defined ratio
181
- test_size = 1.0 - train_test_split_ratio
182
- X_train, X_val, y_train, y_val = train_test_split(X.values, y, test_size=test_size, random_state=42)
183
-
184
- if problem_type == "classification":
185
- # For binary/multiclass classification
186
- model = xgb.XGBClassifier(
187
- n_estimators=n_estimators,
188
- max_depth=int(max_depth) if max_depth > 0 else 3,
189
- min_child_weight=int(min_child_weight),
190
- subsample=float(subsample),
191
- colsample_bytree=float(colsample_bytree),
192
- learning_rate=float(learning_rate),
193
- random_state=42,
194
- verbosity=0
195
- )
196
- else:
197
- model = xgb.XGBRegressor(
198
- n_estimators=n_estimators,
199
- max_depth=int(max_depth) if max_depth > 0 else 3,
200
- min_child_weight=int(min_child_weight),
201
- subsample=float(subsample),
202
- colsample_bytree=float(colsample_bytree),
203
- learning_rate=float(learning_rate),
204
- random_state=42,
205
- verbosity=0
206
- )
207
-
208
- # Fit with early stopping to capture loss evolution
209
- eval_set = [(X_train, y_train), (X_val, y_val)]
210
- model.fit(X_train, y_train, eval_set=eval_set, verbose=False)
211
-
212
- prediction = model.predict(new_point)[0]
213
- _set_current_model(model)
214
-
215
- # Calculate performance metrics
216
- train_pred = model.predict(X_train)
217
- val_pred = model.predict(X_val)
218
-
219
- if problem_type == "classification":
220
- from sklearn.metrics import accuracy_score
221
- train_performance = accuracy_score(y_train, train_pred)
222
- val_performance = accuracy_score(y_val, val_pred)
223
- performance_metric = "Accuracy"
224
- else:
225
- from sklearn.metrics import mean_squared_error
226
- train_performance = mean_squared_error(y_train, train_pred)
227
- val_performance = mean_squared_error(y_val, val_pred)
228
- performance_metric = "MSE"
229
-
230
- # Store split info for aggregation display
231
- split_info = {
232
- "train_size": len(X_train),
233
- "val_size": len(X_val),
234
- "train_ratio": train_test_split_ratio,
235
- "val_ratio": 1.0 - train_test_split_ratio,
236
- "train_performance": train_performance,
237
- "val_performance": val_performance,
238
- "performance_metric": performance_metric
239
- }
240
-
241
- boosting_progress_fig = create_xgboost_progress_chart(model, new_point[0], problem_type, target_col, df)
242
- loss_chart_fig = create_loss_chart(model)
243
- importance_fig = create_feature_importance_plot(model, feature_cols)
244
- prediction_details = create_prediction_details(model, new_point[0], feature_cols, target_col, prediction, problem_type)
245
- summary = create_algorithm_summary(model, problem_type, n_estimators, max_depth, min_child_weight, subsample, colsample_bytree, learning_rate, feature_cols)
246
- aggregation_display = create_xgboost_aggregation_display(model, new_point[0], problem_type, target_col, df, split_info)
247
-
248
- return boosting_progress_fig, loss_chart_fig, importance_fig, prediction, prediction_details, summary, aggregation_display
249
-
250
-
251
- def create_loss_chart(model):
252
- """Create a loss chart showing training and validation loss evolution"""
253
- try:
254
- # Get the evaluation results for XGBoost
255
- evals_result = model.evals_result()
256
-
257
- fig = go.Figure()
258
-
259
- # Plot training loss
260
- if 'validation_0' in evals_result:
261
- train_metric = list(evals_result['validation_0'].keys())[0]
262
- train_loss = evals_result['validation_0'][train_metric]
263
- epochs = list(range(1, len(train_loss) + 1))
264
-
265
- fig.add_trace(go.Scatter(
266
- x=epochs,
267
- y=train_loss,
268
- mode='lines+markers',
269
- name='Training Loss',
270
- line=dict(color='#FF6B6B', width=2),
271
- marker=dict(size=6)
272
- ))
273
-
274
- # Plot validation loss
275
- if 'validation_1' in evals_result:
276
- val_metric = list(evals_result['validation_1'].keys())[0]
277
- val_loss = evals_result['validation_1'][val_metric]
278
- epochs = list(range(1, len(val_loss) + 1))
279
-
280
- fig.add_trace(go.Scatter(
281
- x=epochs,
282
- y=val_loss,
283
- mode='lines+markers',
284
- name='Validation Loss',
285
- line=dict(color='#4ECDC4', width=2),
286
- marker=dict(size=6)
287
- ))
288
-
289
- fig.update_layout(
290
- title="XGBoost Training Progress - Loss Evolution",
291
- xaxis_title="Boosting Round (Tree)",
292
- yaxis_title="Loss",
293
- plot_bgcolor="white",
294
- height=400,
295
- legend=dict(
296
- yanchor="top",
297
- y=0.99,
298
- xanchor="right",
299
- x=0.99
300
- ),
301
- margin=dict(l=40, r=40, t=60, b=40)
302
- )
303
-
304
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
305
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
306
-
307
- return fig
308
- except Exception as e:
309
- # Fallback if no loss data is available
310
- fig = go.Figure()
311
- fig.add_annotation(
312
- text=f"Loss tracking not available<br>Error: {str(e)}<br>Run training to see loss evolution",
313
- xref="paper", yref="paper",
314
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
315
- showarrow=False,
316
- font=dict(size=14)
317
- )
318
- fig.update_layout(
319
- title="XGBoost Training Progress - Loss Evolution",
320
- height=400,
321
- plot_bgcolor="white"
322
- )
323
- return fig
324
-
325
-
326
- def create_xgboost_progress_chart(model, new_point, problem_type, target_col=None, df=None):
327
- """Create a chart showing how XGBoost prediction evolves with each tree"""
328
-
329
- if problem_type == "classification":
330
- # For classification, show probability evolution
331
- try:
332
- # Get number of trees
333
- n_trees = model.n_estimators
334
-
335
- # Create a temporary model with varying n_estimators to see progression
336
- iteration_data = []
337
-
338
- # We'll use the model's predict_proba method with ntree_limit
339
- # Sample every few trees for visualization if more than 50 trees
340
- if n_trees <= 50:
341
- tree_indices = list(range(1, n_trees + 1))
342
- else:
343
- # Sample 50 evenly spaced trees for visualization
344
- tree_indices = [int(i) for i in np.linspace(1, n_trees, min(50, n_trees))]
345
-
346
- for i in tree_indices:
347
- try:
348
- # For XGBoost, we can't easily get staged predictions like sklearn
349
- # So we'll create new models with fewer estimators
350
- temp_model = type(model)(
351
- **{k: v for k, v in model.get_params().items() if k != 'n_estimators'},
352
- n_estimators=i,
353
- random_state=42
354
- )
355
- # We need the original training data for this approach
356
- # For simplicity, we'll approximate using the full model
357
- if hasattr(model, 'predict_proba'):
358
- proba = model.predict_proba(new_point.reshape(1, -1), ntree_limit=i)[0]
359
- pred = model.predict(new_point.reshape(1, -1), ntree_limit=i)[0]
360
- max_prob = np.max(proba)
361
- predicted_class = int(pred)
362
- else:
363
- # Fallback
364
- proba = model.predict_proba(new_point.reshape(1, -1))[0]
365
- pred = model.predict(new_point.reshape(1, -1))[0]
366
- max_prob = np.max(proba)
367
- predicted_class = int(pred)
368
-
369
- iteration_data.append({
370
- 'iteration': i,
371
- 'prediction_class': predicted_class,
372
- 'confidence': max_prob
373
- })
374
- except:
375
- # If ntree_limit doesn't work, use full prediction
376
- proba = model.predict_proba(new_point.reshape(1, -1))[0]
377
- pred = model.predict(new_point.reshape(1, -1))[0]
378
- max_prob = np.max(proba)
379
- predicted_class = int(pred)
380
-
381
- iteration_data.append({
382
- 'iteration': i,
383
- 'prediction_class': predicted_class,
384
- 'confidence': max_prob
385
- })
386
-
387
- # Create line chart
388
- fig = go.Figure()
389
-
390
- iterations = [data['iteration'] for data in iteration_data]
391
- confidences = [data['confidence'] for data in iteration_data]
392
- predictions = [data['prediction_class'] for data in iteration_data]
393
-
394
- # Color mapping for different classes
395
- colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', '#FF9FF3', '#54A0FF', '#5F27CD', '#00D2D3', '#FF9F43']
396
-
397
- # Group points by class for better visualization
398
- class_data = {}
399
- for iter_num, conf, pred_class in zip(iterations, confidences, predictions):
400
- if pred_class not in class_data:
401
- class_data[pred_class] = {'iterations': [], 'confidences': []}
402
- class_data[pred_class]['iterations'].append(iter_num)
403
- class_data[pred_class]['confidences'].append(conf)
404
-
405
- # Plot lines for each class
406
- for class_idx, data in class_data.items():
407
- color = colors[class_idx % len(colors)]
408
- fig.add_trace(go.Scatter(
409
- x=data['iterations'],
410
- y=data['confidences'],
411
- mode='lines+markers',
412
- name=f'Class {class_idx}',
413
- line=dict(color=color, width=3),
414
- marker=dict(size=8, symbol='circle'),
415
- hovertemplate=f'<b>Tree %{{x}}</b><br>Class {class_idx}<br>Confidence: %{{y:.3f}}<extra></extra>'
416
- ))
417
-
418
- fig.update_layout(
419
- title="XGBoost Progress: How Prediction Confidence Evolves",
420
- xaxis_title="Tree Number",
421
- yaxis_title="Prediction Confidence",
422
- plot_bgcolor="white",
423
- height=450,
424
- legend=dict(
425
- yanchor="top",
426
- y=0.99,
427
- xanchor="right",
428
- x=0.99
429
- ),
430
- margin=dict(l=40, r=40, t=60, b=40)
431
- )
432
-
433
- except Exception as e:
434
- # Fallback to simple visualization
435
- fig = go.Figure()
436
- fig.add_annotation(
437
- text=f"Classification Progress Visualization<br>Final Prediction: {model.predict(new_point.reshape(1, -1))[0]}<br>Model trained with {model.n_estimators} trees",
438
- xref="paper", yref="paper",
439
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
440
- showarrow=False,
441
- font=dict(size=14)
442
- )
443
- fig.update_layout(
444
- title="XGBoost Progress: Classification Results",
445
- height=450,
446
- plot_bgcolor="white"
447
- )
448
-
449
- else: # Regression
450
- try:
451
- # For regression, show prediction value evolution
452
- n_trees = model.n_estimators
453
- iteration_data = []
454
-
455
- # Sample trees for visualization efficiency
456
- if n_trees <= 50:
457
- tree_indices = list(range(1, n_trees + 1))
458
- else:
459
- # Sample 50 evenly spaced trees for visualization
460
- tree_indices = [int(i) for i in np.linspace(1, n_trees, min(50, n_trees))]
461
-
462
- for i in tree_indices:
463
- try:
464
- pred = model.predict(new_point.reshape(1, -1), ntree_limit=i)[0]
465
- except:
466
- pred = model.predict(new_point.reshape(1, -1))[0]
467
-
468
- iteration_data.append({
469
- 'iteration': i,
470
- 'prediction': pred
471
- })
472
-
473
- iterations = [data['iteration'] for data in iteration_data]
474
- predictions = [data['prediction'] for data in iteration_data]
475
-
476
- fig = go.Figure()
477
- fig.add_trace(go.Scatter(
478
- x=iterations,
479
- y=predictions,
480
- mode='lines+markers',
481
- name='Prediction Value',
482
- line=dict(color='#FF6B6B', width=3),
483
- marker=dict(size=8, symbol='circle'),
484
- hovertemplate='<b>Tree %{x}</b><br>Prediction: %{y:.3f}<extra></extra>'
485
- ))
486
-
487
- # Add final prediction line
488
- final_pred = predictions[-1] if predictions else 0
489
- fig.add_hline(
490
- y=final_pred,
491
- line_dash="dash",
492
- line_color="gray",
493
- annotation_text=f"Final: {final_pred:.3f}",
494
- annotation_position="right"
495
- )
496
-
497
- fig.update_layout(
498
- title="XGBoost Progress: How Prediction Value Evolves",
499
- xaxis_title="Tree Number",
500
- yaxis_title="Prediction Value",
501
- plot_bgcolor="white",
502
- height=450,
503
- margin=dict(l=40, r=40, t=60, b=40)
504
- )
505
-
506
- except Exception as e:
507
- # Fallback
508
- fig = go.Figure()
509
- final_pred = model.predict(new_point.reshape(1, -1))[0]
510
- fig.add_annotation(
511
- text=f"Regression Progress Visualization<br>Final Prediction: {final_pred:.3f}<br>Model trained with {model.n_estimators} trees",
512
- xref="paper", yref="paper",
513
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
514
- showarrow=False,
515
- font=dict(size=14)
516
- )
517
- fig.update_layout(
518
- title="XGBoost Progress: Regression Results",
519
- height=450,
520
- plot_bgcolor="white"
521
- )
522
-
523
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
524
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
525
-
526
- return fig
527
-
528
-
529
- def create_individual_tree_visualization(model, tree_index, feature_cols, problem_type):
530
- """Create visualization of individual XGBoost tree"""
531
- try:
532
- # Get actual XGBoost tree structure
533
- return create_xgboost_tree_plot(model, tree_index, feature_cols, problem_type)
534
-
535
- except Exception as e:
536
- # Fallback visualization
537
- fig = go.Figure()
538
- fig.add_annotation(
539
- text=f"XGBoost Tree {tree_index + 1} Visualization<br>Unable to extract tree structure<br>Error: {str(e)}",
540
- xref="paper", yref="paper",
541
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
542
- showarrow=False,
543
- font=dict(size=14)
544
- )
545
- fig.update_layout(
546
- title=f"XGBoost Tree {tree_index + 1} Structure",
547
- height=500,
548
- plot_bgcolor="white"
549
- )
550
- return fig
551
-
552
-
553
- def create_xgboost_tree_plot(model, tree_index, feature_cols, problem_type):
554
- """Create tree visualization for XGBoost models"""
555
- try:
556
- # Try to use XGBoost's built-in tree structure if available
557
- booster = model.get_booster()
558
- tree_dump = booster.get_dump(dump_format='json')[tree_index]
559
-
560
- import json
561
- tree_dict = json.loads(tree_dump)
562
-
563
- return create_tree_plot_from_dict(tree_dict, tree_index, feature_cols, problem_type, "XGBoost")
564
-
565
- except Exception as e:
566
- # Fallback to manual tree creation
567
- return create_manual_tree_plot(tree_index, feature_cols, problem_type, "XGBoost")
568
-
569
-
570
- # Removed sklearn tree plotting functions - XGBoost only
571
-
572
-
573
- def create_tree_plot_from_dict(tree_dict, tree_index, feature_cols, problem_type, model_type):
574
- """Create tree plot from tree dictionary structure"""
575
- fig = go.Figure()
576
-
577
- # Calculate node positions
578
- positions = {}
579
- labels = {}
580
- colors = {}
581
-
582
- def assign_positions(node, node_id, x, y, width, level=0):
583
- positions[node_id] = (x, y)
584
-
585
- if "leaf" in node:
586
- # Leaf node
587
- if problem_type == "classification":
588
- labels[node_id] = f"Leaf<br>Value: {node['leaf']:.3f}<br>Samples: {node.get('samples', 'N/A')}"
589
- else:
590
- labels[node_id] = f"Leaf<br>Prediction: {node['leaf']:.3f}<br>Samples: {node.get('samples', 'N/A')}"
591
- colors[node_id] = "#FFB74D" # Orange for leaves
592
- else:
593
- # Split node
594
- split_name = node.get("split", "feature")
595
- threshold = node.get("split_condition", 0)
596
- samples = node.get("samples", "N/A")
597
-
598
- labels[node_id] = f"{split_name}<br>≤ {threshold:.3f}<br>Samples: {samples}"
599
- colors[node_id] = "#81C784" # Green for split nodes
600
-
601
- # Process children
602
- if "children" in node and len(node["children"]) == 2:
603
- child_width = width / 2
604
- left_child_id = f"{node_id}_L"
605
- right_child_id = f"{node_id}_R"
606
-
607
- assign_positions(node["children"][0], left_child_id, x - child_width/2, y - 1, child_width, level + 1)
608
- assign_positions(node["children"][1], right_child_id, x + child_width/2, y - 1, child_width, level + 1)
609
-
610
- # Start positioning from root
611
- assign_positions(tree_dict, "root", 0, 0, 4)
612
-
613
- # Create edges first (so they appear behind nodes)
614
- edge_x, edge_y = [], []
615
- for node_id, (x, y) in positions.items():
616
- if node_id.endswith("_L") or node_id.endswith("_R"):
617
- # This is a child node, draw edge to parent
618
- parent_id = node_id.rsplit("_", 1)[0]
619
- if parent_id in positions:
620
- parent_x, parent_y = positions[parent_id]
621
- edge_x.extend([parent_x, x, None])
622
- edge_y.extend([parent_y, y, None])
623
-
624
- # Add edges
625
- if edge_x:
626
- fig.add_trace(go.Scatter(
627
- x=edge_x, y=edge_y,
628
- mode='lines',
629
- line=dict(color='gray', width=2),
630
- showlegend=False,
631
- hoverinfo='none'
632
- ))
633
-
634
- # Add nodes
635
- node_x = [pos[0] for pos in positions.values()]
636
- node_y = [pos[1] for pos in positions.values()]
637
- node_colors = [colors[node_id] for node_id in positions.keys()]
638
- node_labels = [labels[node_id] for node_id in positions.keys()]
639
-
640
- fig.add_trace(go.Scatter(
641
- x=node_x, y=node_y,
642
- mode='markers+text',
643
- marker=dict(
644
- size=30,
645
- color=node_colors,
646
- line=dict(width=2, color='darkblue'),
647
- symbol='circle'
648
- ),
649
- text=node_labels,
650
- textposition='middle center',
651
- textfont=dict(size=10, color='black'),
652
- showlegend=False,
653
- hoverinfo='text',
654
- hovertext=node_labels
655
- ))
656
-
657
- fig.update_layout(
658
- title=f"{model_type} Tree {tree_index + 1} Structure ({problem_type.title()})",
659
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
660
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
661
- plot_bgcolor="white",
662
- height=600,
663
- margin=dict(l=40, r=40, t=60, b=40),
664
- showlegend=False
665
- )
666
-
667
- return fig
668
-
669
-
670
- def create_manual_tree_plot(tree_index, feature_cols, problem_type, model_type):
671
- """Create a manual tree visualization when tree structure is not accessible"""
672
- fig = go.Figure()
673
-
674
- # Create a sample tree structure for demonstration
675
- import random
676
- random.seed(tree_index) # Consistent trees for same index
677
-
678
- # Root node
679
- root_feature = random.choice(feature_cols) if feature_cols else "feature_0"
680
- root_threshold = round(random.uniform(0.1, 5.0), 2)
681
-
682
- # Positions for a simple 3-level tree
683
- positions = {
684
- 'root': (0, 2),
685
- 'left': (-1.5, 1),
686
- 'right': (1.5, 1),
687
- 'left_left': (-2.5, 0),
688
- 'left_right': (-0.5, 0),
689
- 'right_left': (0.5, 0),
690
- 'right_right': (2.5, 0)
691
- }
692
-
693
- # Labels and colors
694
- labels = {
695
- 'root': f"{root_feature}<br>≤ {root_threshold}<br>Samples: 150",
696
- 'left': f"{random.choice(feature_cols) if feature_cols else 'feature_1'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
697
- 'right': f"{random.choice(feature_cols) if feature_cols else 'feature_2'}<br>≤ {round(random.uniform(0.1, 3.0), 2)}<br>Samples: 75",
698
- 'left_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 25",
699
- 'left_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 50",
700
- 'right_left': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 30",
701
- 'right_right': f"Leaf<br>Value: {round(random.uniform(-1, 1), 3)}<br>Samples: 45"
702
- }
703
-
704
- colors = {
705
- 'root': '#81C784', 'left': '#81C784', 'right': '#81C784', # Green for split nodes
706
- 'left_left': '#FFB74D', 'left_right': '#FFB74D', 'right_left': '#FFB74D', 'right_right': '#FFB74D' # Orange for leaves
707
- }
708
-
709
- # Draw edges
710
- edges = [
711
- ('root', 'left'), ('root', 'right'),
712
- ('left', 'left_left'), ('left', 'left_right'),
713
- ('right', 'right_left'), ('right', 'right_right')
714
- ]
715
-
716
- edge_x, edge_y = [], []
717
- for parent, child in edges:
718
- parent_pos = positions[parent]
719
- child_pos = positions[child]
720
- edge_x.extend([parent_pos[0], child_pos[0], None])
721
- edge_y.extend([parent_pos[1], child_pos[1], None])
722
-
723
- fig.add_trace(go.Scatter(
724
- x=edge_x, y=edge_y,
725
- mode='lines',
726
- line=dict(color='gray', width=2),
727
- showlegend=False,
728
- hoverinfo='none'
729
- ))
730
-
731
- # Draw nodes
732
- for node_id, (x, y) in positions.items():
733
- fig.add_trace(go.Scatter(
734
- x=[x], y=[y],
735
- mode='markers+text',
736
- marker=dict(
737
- size=35,
738
- color=colors[node_id],
739
- line=dict(width=2, color='darkblue'),
740
- symbol='circle'
741
- ),
742
- text=labels[node_id],
743
- textposition='middle center',
744
- textfont=dict(size=9, color='black'),
745
- showlegend=False,
746
- hoverinfo='text',
747
- hovertext=labels[node_id]
748
- ))
749
-
750
- fig.update_layout(
751
- title=f"{model_type} Tree {tree_index + 1} Structure ({problem_type.title()})",
752
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-3, 3]),
753
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-0.5, 2.5]),
754
- plot_bgcolor="white",
755
- height=600,
756
- margin=dict(l=40, r=40, t=60, b=40),
757
- showlegend=False
758
- )
759
-
760
- return fig
761
-
762
-
763
- def get_individual_tree_visualization(model, tree_index, feature_cols, problem_type):
764
- return create_individual_tree_visualization(model, tree_index, feature_cols, problem_type)
765
-
766
-
767
- def create_feature_importance_plot(model, feature_cols):
768
- try:
769
- importances = model.feature_importances_
770
- order = np.argsort(importances)[::-1]
771
-
772
- fig = go.Figure()
773
- fig.add_trace(
774
- go.Bar(
775
- x=[feature_cols[i] for i in order],
776
- y=importances[order],
777
- text=[f"{importances[i]:.3f}" for i in order],
778
- textposition="auto",
779
- marker_color="lightcoral",
780
- hovertemplate="<b>%{x}</b><br>Importance: %{y:.3f}<extra></extra>",
781
- )
782
- )
783
- fig.update_layout(
784
- title="XGBoost Feature Importance",
785
- xaxis_title="Features",
786
- yaxis_title="Importance",
787
- plot_bgcolor="white",
788
- height=400,
789
- margin=dict(l=40, r=40, t=60, b=40),
790
- )
791
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
792
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="lightgray")
793
- return fig
794
- except:
795
- fig = go.Figure()
796
- fig.add_annotation(
797
- text="Feature importance not available",
798
- xref="paper", yref="paper",
799
- x=0.5, y=0.5, xanchor='center', yanchor='middle',
800
- showarrow=False,
801
- font=dict(size=14)
802
- )
803
- fig.update_layout(
804
- title="XGBoost Feature Importance",
805
- height=400,
806
- plot_bgcolor="white"
807
- )
808
- return fig
809
-
810
-
811
- def create_prediction_details(model, new_point, feature_cols, target_col, prediction, problem_type):
812
- if problem_type == "classification":
813
- try:
814
- probabilities = model.predict_proba(new_point.reshape(1, -1))[0]
815
- classes = model.classes_
816
- return f"Predicted Class: {int(prediction)} | Probabilities: {dict(zip(classes, probabilities))}"
817
- except:
818
- return f"Predicted Class: {int(prediction)}"
819
- else:
820
- return f"Predicted Value: {prediction:.3f}"
821
-
822
-
823
- def create_algorithm_summary(model, problem_type, n_estimators, max_depth, min_child_weight, subsample, colsample_bytree, learning_rate, feature_cols):
824
- return f"""
825
- **XGBoost {problem_type.title()} Model Summary:**
826
- - Trees: {n_estimators}
827
- - Max Depth: {max_depth}
828
- - Min Child Weight: {min_child_weight}
829
- - Subsample: {subsample}
830
- - Column Sample by Tree: {colsample_bytree}
831
- - Learning Rate: {learning_rate}
832
- - Features: {len(feature_cols)}
833
- """
834
-
835
-
836
- def create_xgboost_aggregation_display(model, new_point, problem_type, target_col=None, df=None, split_info=None):
837
- """Create HTML display showing XGBoost ensemble aggregation process"""
838
-
839
- try:
840
- if problem_type == "classification":
841
- prediction = model.predict(new_point.reshape(1, -1))[0]
842
- probabilities = model.predict_proba(new_point.reshape(1, -1))[0]
843
-
844
- # Build the aggregation display with split info
845
- html_content = f"""
846
- <div style='background:#F0F8FF;border-left:6px solid #4ECDC4;padding:14px 16px;border-radius:10px;'>
847
- <strong>🚀 XGBoost Ensemble Process</strong><br><br>
848
-
849
- <div style='margin:8px 0;'>
850
- <strong>📊 Model Configuration:</strong><br>
851
- • {model.n_estimators} trees in ensemble<br>
852
- • Max depth: {model.max_depth}<br>
853
- • Learning rate: {model.learning_rate}<br>
854
- </div>"""
855
-
856
- if split_info:
857
- html_content += f"""
858
- <div style='margin:8px 0;'>
859
- <strong>📊 Data Split Information:</strong><br>
860
- • Training Set: {split_info['train_size']} samples ({split_info['train_ratio']:.1%})<br>
861
- • Validation Set: {split_info['val_size']} samples ({split_info['val_ratio']:.1%})<br>
862
- </div>
863
-
864
- <div style='margin:8px 0;'>
865
- <strong>📈 Model Performance:</strong><br>
866
- • Training {split_info['performance_metric']}: <span style='background:#E8F5E8;padding:2px 6px;border-radius:4px;'><strong>{split_info['train_performance']:.4f}</strong></span><br>
867
- • Validation {split_info['performance_metric']}: <span style='background:#E8F5E8;padding:2px 6px;border-radius:4px;'><strong>{split_info['val_performance']:.4f}</strong></span><br>
868
- </div>"""
869
-
870
- html_content += f"""
871
- <div style='margin:8px 0;'>
872
- <strong>🎯 Final Prediction:</strong><br>
873
- • Predicted Class: <span style='background:#FFE5B4;padding:2px 6px;border-radius:4px;'><strong>{int(prediction)}</strong></span><br>
874
- • Class Probabilities: {dict(zip(range(len(probabilities)), [f'{p:.3f}' for p in probabilities]))}<br>
875
- </div>
876
-
877
- <div style='margin:8px 0;'>
878
- <strong>⚡ XGBoost Process:</strong><br>
879
- 1. Each tree corrects errors from previous trees<br>
880
- 2. Gradient-based optimization for efficient learning<br>
881
- 3. Regularization prevents overfitting<br>
882
- 4. Final prediction combines all {model.n_estimators} trees<br>
883
- </div>
884
- </div>
885
- """
886
- else:
887
- prediction = model.predict(new_point.reshape(1, -1))[0]
888
-
889
- html_content = f"""
890
- <div style='background:#F0F8FF;border-left:6px solid #4ECDC4;padding:14px 16px;border-radius:10px;'>
891
- <strong>🚀 XGBoost Ensemble Process</strong><br><br>
892
-
893
- <div style='margin:8px 0;'>
894
- <strong>📊 Model Configuration:</strong><br>
895
- • {model.n_estimators} trees in ensemble<br>
896
- • Max depth: {model.max_depth}<br>
897
- • Learning rate: {model.learning_rate}<br>
898
- </div>"""
899
-
900
- if split_info:
901
- html_content += f"""
902
- <div style='margin:8px 0;'>
903
- <strong>📊 Data Split Information:</strong><br>
904
- • Training Set: {split_info['train_size']} samples ({split_info['train_ratio']:.1%})<br>
905
- • Validation Set: {split_info['val_size']} samples ({split_info['val_ratio']:.1%})<br>
906
- </div>
907
-
908
- <div style='margin:8px 0;'>
909
- <strong>📈 Model Performance:</strong><br>
910
- • Training {split_info['performance_metric']}: <span style='background:#E8F5E8;padding:2px 6px;border-radius:4px;'><strong>{split_info['train_performance']:.4f}</strong></span><br>
911
- • Validation {split_info['performance_metric']}: <span style='background:#E8F5E8;padding:2px 6px;border-radius:4px;'><strong>{split_info['val_performance']:.4f}</strong></span><br>
912
- </div>"""
913
-
914
- html_content += f"""
915
- <div style='margin:8px 0;'>
916
- <strong>🎯 Final Prediction:</strong><br>
917
- • Predicted Value: <span style='background:#FFE5B4;padding:2px 6px;border-radius:4px;'><strong>{prediction:.3f}</strong></span><br>
918
- </div>
919
-
920
- <div style='margin:8px 0;'>
921
- <strong>⚡ XGBoost Process:</strong><br>
922
- 1. Each tree corrects errors from previous trees<br>
923
- 2. Gradient-based optimization for efficient learning<br>
924
- 3. Advanced regularization techniques<br>
925
- 4. Final prediction aggregates all {model.n_estimators} trees<br>
926
- </div>
927
- </div>
928
- """
929
-
930
- return html_content
931
-
932
- except Exception as e:
933
- return f"""
934
- <div style='background:#FFF4F4;border-left:6px solid #C4314B;padding:14px 16px;border-radius:10px;'>
935
- <strong>🚀 XGBoost Process</strong><br><br>
936
- Error generating aggregation display: {str(e)}
937
- </div>
938
- """