| import numpy as np |
| import pandas as pd |
| import skops.io as sio |
| import shap |
| import plotly.graph_objects as go |
| import os |
| import sys |
| import warnings |
|
|
# Silence UserWarnings emitted from sklearn modules (the `module` argument is a
# regex matched against the start of the warning's module name).
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

# Compatibility shim: the saved .skops artifacts reference
# sklearn.compose._column_transformer._RemainderColsList, which does not exist in
# every sklearn version. When it is missing, register a minimal list subclass
# under both the private module and the sklearn.compose namespace so loading
# can resolve it. Both aliases live INSIDE the guard: outside it, the name
# `_RemainderColsList` would be undefined (NameError) on sklearn versions that
# already ship the class.
import sklearn.compose._column_transformer as _ct

if not hasattr(_ct, "_RemainderColsList"):

    class _RemainderColsList(list):
        """Minimal shim for sklearn._RemainderColsList (missing in this env)."""

        def __init__(self, lst=None, future_dtype=None):
            # Accept the same optional payload as a list; keep future_dtype as a
            # plain attribute so restored objects carry it.
            super().__init__(lst or [])
            self.future_dtype = future_dtype

    _ct._RemainderColsList = _RemainderColsList

    import sklearn.compose

    sklearn.compose._RemainderColsList = _RemainderColsList
|
|
|
|
|
|
# Input schema expected by the preprocessor.
# Numeric inputs (passed through the numeric branch of the preprocessor).
NUM_COLUMNS = ["AGE", "NACS2YR"]

# Categorical inputs (one-hot encoded by the preprocessor's "cat" transformer).
CATEG_COLUMNS = [
    "AGEGPFF", "SEX", "KPS", "DONORF", "GRAFTYPE",
    "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
    "RCMVPR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN",
]

# Column order used to build every input DataFrame: numeric first, then categorical.
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
|
|
# Outcomes with a trained ensemble (one ensemble_model_<OUTCOME>.skops each).
# DWOGF is modelled but not reported directly: predict_all_outcomes combines it
# with GF to derive EFS.
OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
# Alias (same list object) used by the loading/prediction code.
CLASSIFICATION_OUTCOMES = OUTCOMES

# Outcomes shown to the user; OS and EFS are derived, not modelled directly.
REPORTING_OUTCOMES = [
    "OS", "EFS", "GF", "DEAD",
    "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI",
]

# Human-readable outcome names (used e.g. as SHAP plot titles).
OUTCOME_DESCRIPTIONS = dict(
    OS="Overall Survival",
    EFS="Event-Free Survival",
    DEAD="Death",
    GF="Graft Failure",
    AGVHD="Acute Graft-versus-Host Disease",
    CGVHD="Chronic Graft-versus-Host Disease",
    VOCPSHI="Vaso-Occlusive Crisis Post-HCT",
    STROKEHI="Stroke Post-HCT",
)

# Outcomes for which SHAP explanation plots are produced, in display order.
SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]
|
|
# Directory containing the .skops artifacts (resolved against the CWD).
MODEL_DIR = "."
# Fallback consensus threshold when a model bundle does not provide one.
CONSENSUS_THRESHOLD = 0.5
# Default number of bootstrap resamples for confidence intervals.
DEFAULT_N_BOOT_CI = 500
|
|
|
|
|
|
|
|
def _load_skops_model(fname):
    """Load a skops artifact, terminating the process if it cannot be loaded.

    Args:
        fname: Path to the ``.skops`` file.

    Returns:
        The deserialized object.

    Raises:
        SystemExit: (exit code 1) on any loading error; the message is written
            to stderr first.
    """
    try:
        # SECURITY NOTE(review): every type declared in the file is trusted,
        # which makes this as permissive as unpickling — arbitrary code in a
        # tampered artifact could run. Only load artifacts from a trusted
        # source, or replace `untrusted` with an audited allow-list.
        untrusted = sio.get_untrusted_types(file=fname)
        return sio.load(fname, trusted=untrusted)
    except Exception as e:
        # This helper is called while the module is imported, i.e. at start-up;
        # failing fast is preferable to serving predictions with missing models.
        print(f"Error loading '{fname}': {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
# Shared preprocessing pipeline applied to every input before the ensembles.
preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))

# One bundle per outcome. A missing artifact is tolerated: a warning is printed
# and that outcome is simply absent from every dict derived below.
classification_model_data = {}
for _o in CLASSIFICATION_OUTCOMES:
    _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
    if not os.path.exists(_path):
        print(f"Warning: Model for {_o} not found at {_path}. Skipping.")
        continue
    classification_model_data[_o] = _load_skops_model(_path)

# Unpack the bundle fields. "models", "beta" and "prior" are required keys
# (KeyError otherwise); the consensus threshold falls back to the global default.
classification_models = {
    outcome: bundle["models"] for outcome, bundle in classification_model_data.items()
}
betas = {
    outcome: bundle["beta"] for outcome, bundle in classification_model_data.items()
}
priors = {
    outcome: bundle["prior"] for outcome, bundle in classification_model_data.items()
}
consensus_thresholds = {
    outcome: bundle.get("consensus_threshold", CONSENSUS_THRESHOLD)
    for outcome, bundle in classification_model_data.items()
}
|
|
|
|
def _select_calibrator(outcome, bundle):
    """Return the isotonic calibrator for one outcome bundle, or None.

    Isotonic-only policy:
      * ``bundle["calibrator"]`` is used when ``calibrator_type`` is absent/None
        or "isotonic";
      * a non-isotonic ``calibrator`` is rejected with a warning;
      * in every other case (including after a rejection) the dedicated
        ``bundle["isotonic_calibrator"]`` slot is used if present.
    """
    cal_type = bundle.get("calibrator_type", None)
    primary = bundle.get("calibrator")
    if primary is not None:
        if cal_type is None or cal_type == "isotonic":
            return primary
        print(
            f"Warning: outcome '{outcome}' has calibrator_type='{cal_type}'. "
            "Skipping non-isotonic calibrator (isotonic-only policy)."
        )
    # Previously a rejected non-isotonic calibrator short-circuited this branch,
    # discarding an available isotonic calibrator; fall back to it instead.
    return bundle.get("isotonic_calibrator")


# outcome -> fitted isotonic calibrator (or None → no post-hoc calibration).
calibrators = {
    outcome: _select_calibrator(outcome, bundle)
    for outcome, bundle in classification_model_data.items()
}

# Backward-compatible alias (same dict object).
isotonic_calibrators = calibrators
|
|
# Calibrated out-of-fold probabilities per outcome (None when a bundle has none);
# they are the resampling pool for the bootstrap confidence intervals.
oof_probs_calibrated = {
    outcome: bundle.get("oof_probs_calibrated")
    for outcome, bundle in classification_model_data.items()
}

# Names of the preprocessed columns, used to map SHAP values back to inputs.
# NOTE(review): assumes the ColumnTransformer emits the numeric columns first and
# the "cat" one-hot columns second, with nothing else — confirm against the fit.
ohe = preprocessor.named_transformers_["cat"]
ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
processed_feature_names = np.concatenate((NUM_COLUMNS, ohe_feature_names))
|
|
|
|
|
|
|
|
# Synthetic background sample for the SHAP explainers.
# A private RandomState(23) produces exactly the draws the previous global
# np.random.seed(23) did (same calls, same order) without reseeding numpy's
# global RNG as an import-time side effect.
# NOTE(review): features are drawn independently and uniformly (e.g. AGE is not
# consistent with AGEGPFF), so this is not the training distribution; category
# strings must match those seen when the preprocessor was fitted. Also,
# shap.maskers.Independent summarises data to max_samples (default 100 rows) —
# verify against the installed shap version.
_background_rng = np.random.RandomState(23)
_n_background = 500

_background_data = {
    "AGE": _background_rng.uniform(5, 50, _n_background),
    "NACS2YR": _background_rng.randint(0, 5, _n_background),
    "AGEGPFF": _background_rng.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
    "SEX": _background_rng.choice(["Male", "Female"], _n_background),
    "KPS": _background_rng.choice(["<90", "≥ 90"], _n_background),
    "DONORF": _background_rng.choice([
        "HLA identical sibling", "HLA mismatch relative",
        "Matched unrelated donor",
        "Mismatched unrelated donor or cord blood",
    ], _n_background),
    "GRAFTYPE": _background_rng.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
    "CONDGRPF": _background_rng.choice(["MAC", "RIC", "NMA"], _n_background),
    "CONDGRP_FINAL": _background_rng.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
    "ATGF": _background_rng.choice(["ATG", "Alemtuzumab", "None"], _n_background),
    "GVHD_FINAL": _background_rng.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
    "HLA_FINAL": _background_rng.choice(["8/8", "7/8", "≤ 6/8"], _n_background),
    "RCMVPR": _background_rng.choice(["Negative", "Positive"], _n_background),
    "EXCHTFPR": _background_rng.choice(["No", "Yes"], _n_background),
    "VOC2YPR": _background_rng.choice(["No", "Yes"], _n_background),
    "VOCFRQPR": _background_rng.choice(["< 3/yr", "≥ 3/yr"], _n_background),
    "SCATXRSN": _background_rng.choice([
        "CNS event", "Acute chest Syndrome",
        "Recurrent vaso-occlusive pain", "Recurrent priapism",
        "Excessive transfusion requirements/iron overload",
        "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
        "Renal insufficiency", "Splenic sequestration",
        "Avascular necrosis", "Hodgkin lymphoma",
    ], _n_background),
}

# Reorder to the training column order, preprocess, and wrap as a SHAP masker.
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
_X_background = preprocessor.transform(_background_df)
shap_background = shap.maskers.Independent(_X_background)
|
|
|
|
|
|
def calibrate_probabilities_undersampling(p_s, beta):
    """Undo the probability bias introduced by training on an undersampled set.

    Applies p = beta * p_s / ((beta - 1) * p_s + 1) element-wise.

    Args:
        p_s: Scalar or array-like of probabilities from the undersampled model.
        beta: Undersampling ratio the model was trained with.

    Returns:
        ndarray of corrected probabilities clipped to [0, 1]. The denominator is
        floored at 1e-10 to avoid division by zero.
    """
    probs = np.asarray(p_s, dtype=float)
    denom = np.maximum(1.0 + (beta - 1.0) * probs, 1e-10)
    corrected = (beta * probs) / denom
    return np.clip(corrected, 0.0, 1.0)
|
|
|
|
def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
    """Combine an ensemble by signed (+1/-1) voting at a probability threshold.

    Each member votes +1 when its positive-class probability is >= threshold and
    -1 otherwise; a sample is predicted positive when the mean vote is strictly
    greater than 0 (ties are negative).

    Returns:
        (consensus_pred, avg_proba, avg_signed_vote, flattened member probas),
        where the last item is the (n_models, n_samples) matrix flattened
        model-by-model.
    """
    member_probas = np.array([
        member.predict_proba(X_test)[:, 1] for member in ensemble_models
    ])
    votes = np.where(member_probas >= threshold, 1, -1)
    mean_vote = votes.mean(axis=0)
    return (
        (mean_vote > 0).astype(int),
        member_probas.mean(axis=0),
        mean_vote,
        member_probas.flatten(),
    )
|
|
|
|
def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
    """Average the ensemble members' positive-class probabilities.

    Despite the name, no vote is taken and ``threshold`` is not used; it is kept
    so the signature mirrors predict_consensus_signed_voting.

    Returns:
        (avg_proba, flattened member probas), the latter being the
        (n_models, n_samples) matrix flattened model-by-model.
    """
    per_member = np.array([
        member.predict_proba(X_test)[:, 1] for member in ensemble_models
    ])
    return per_member.mean(axis=0), per_member.flatten()
|
|
|
|
|
|
|
|
def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Percentile-bootstrap interval centred on ``point_estimate``.

    The mean of the out-of-fold probabilities is bootstrapped, and every
    bootstrap mean is shifted by (point_estimate - mean(oof_probs)) so the
    distribution is centred on the point estimate. NOTE(review): this reflects
    sampling uncertainty of the cohort mean, not per-patient uncertainty.

    Args:
        point_estimate: Probability to centre the interval on.
        oof_probs: Pool of out-of-fold probabilities; None/empty disables
            resampling.
        n_boot: Number of bootstrap resamples; <= 0 disables resampling.
        confidence: Two-sided coverage level (0.95 -> 2.5th/97.5th percentiles).
        random_state: Seed for the private RandomState (deterministic output).

    Returns:
        (lo, hi) floats clipped to [0, 1]; (point, point) when resampling is
        disabled.
    """
    if oof_probs is None or len(oof_probs) == 0 or n_boot <= 0:
        # np.percentile on zero resamples would raise; degrade to a point interval.
        return float(point_estimate), float(point_estimate)

    oof_probs = np.asarray(oof_probs, dtype=float)
    rng = np.random.RandomState(random_state)
    grand_mean = np.mean(oof_probs)
    n = len(oof_probs)

    boot_means = np.array([
        np.mean(rng.choice(oof_probs, size=n, replace=True))
        for _ in range(n_boot)
    ])

    # Re-centre the bootstrap distribution on the requested point estimate.
    boot_means = boot_means + (point_estimate - grand_mean)

    alpha = 1.0 - confidence
    lo = float(np.clip(np.percentile(boot_means, 100 * alpha / 2), 0.0, 1.0))
    hi = float(np.clip(np.percentile(boot_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
    return lo, hi
|
|
|
|
|
|
|
|
def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
    """Turn an ensemble's raw probability into a reported probability.

    Always applies the undersampling correction with ``betas[outcome]``. When
    ``use_calibration`` is true and an isotonic calibrator is registered for
    the outcome, its ``transform`` is applied on top.
    """
    corrected = float(
        calibrate_probabilities_undersampling([raw_prob], betas[outcome])[0]
    )
    calibrator = calibrators.get(outcome) if use_calibration else None
    if calibrator is None:
        return corrected
    return float(calibrator.transform([corrected])[0])
|
|
|
|
|
|
|
|
def _efs_bootstrap_interval(p_dwogf, p_gf, n_boot):
    """Bootstrap a 95% interval for EFS = 1 - P(DWOGF) - P(GF), or None if impossible."""
    oof_dwogf = oof_probs_calibrated.get("DWOGF")
    oof_gf = oof_probs_calibrated.get("GF")
    if oof_dwogf is None or oof_gf is None or n_boot <= 0:
        return None

    oof_dwogf = np.asarray(oof_dwogf, dtype=float)
    oof_gf = np.asarray(oof_gf, dtype=float)
    n_min = min(len(oof_dwogf), len(oof_gf))
    if n_min == 0:
        # Empty pools would yield NaN bounds.
        return None
    oof_dwogf = oof_dwogf[:n_min]
    oof_gf = oof_gf[:n_min]

    rng = np.random.RandomState(42)
    # Shift each pool's bootstrap mean so it is centred on the point estimate.
    shift_dwogf = p_dwogf - np.mean(oof_dwogf)
    shift_gf = p_gf - np.mean(oof_gf)

    # NOTE(review): the two pools are resampled independently even though they
    # are truncated to a common length; if rows are aligned per patient, a
    # paired resample (shared indices) would account for their correlation.
    efs_boot = np.array([
        np.clip(
            1.0
            - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
            - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
            0.0, 1.0,
        )
        for _ in range(n_boot)
    ])
    return float(np.percentile(efs_boot, 2.5)), float(np.percentile(efs_boot, 97.5))


def predict_all_outcomes(
    user_inputs,
    use_calibration: bool = True,
    use_signed_voting: bool = True,
    n_boot_ci: int = DEFAULT_N_BOOT_CI,
):
    """Predict every loaded outcome plus the derived OS and EFS.

    Args:
        user_inputs: Dict keyed by feature name, or a sequence of values in
            FEATURE_NAMES order.
        use_calibration: Apply the isotonic calibrator after the undersampling
            correction (see _calibrate_point).
        use_signed_voting: Choose the signed-voting or averaging combiner. Both
            paths feed the members' mean probability downstream, so this flag
            does not change the returned probabilities.
        n_boot_ci: Bootstrap resamples for every interval (including EFS);
            <= 0 yields degenerate (p, p) intervals.

    Returns:
        (probs, intervals): ``probs`` maps outcome -> probability;
        ``intervals`` maps outcome -> (lo, hi). Outcomes without a loaded model
        are absent. OS = 1 - P(DEAD); EFS = clip(1 - P(DWOGF) - P(GF), 0, 1).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    # Enforce the training column order before preprocessing.
    X = preprocessor.transform(input_df[FEATURE_NAMES])

    probs, intervals = {}, {}
    for o in CLASSIFICATION_OUTCOMES:
        if o not in classification_models:
            continue

        threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
        if use_signed_voting:
            _, uncalib_arr, _, _ = predict_consensus_signed_voting(
                classification_models[o], X, threshold
            )
        else:
            uncalib_arr, _ = predict_consensus_majority(
                classification_models[o], X, threshold
            )

        event_prob = _calibrate_point(o, float(uncalib_arr[0]), use_calibration)

        if n_boot_ci > 0:
            # NOTE(review): the pool is the *calibrated* OOF distribution even
            # when use_calibration is False.
            lo, hi = bootstrap_ci_from_oof(
                point_estimate=event_prob,
                oof_probs=oof_probs_calibrated.get(o),
                n_boot=n_boot_ci,
            )
        else:
            lo = hi = float(event_prob)

        probs[o] = event_prob
        intervals[o] = (lo, hi)

    # Overall survival is the complement of death; its bounds swap accordingly.
    if "DEAD" in probs:
        probs["OS"] = float(1.0 - probs["DEAD"])
        dead_lo, dead_hi = intervals["DEAD"]
        intervals["OS"] = (
            float(np.clip(1.0 - dead_hi, 0, 1)),
            float(np.clip(1.0 - dead_lo, 0, 1)),
        )

    # Event-free survival: neither death without graft failure nor graft failure.
    if "DWOGF" in probs and "GF" in probs:
        p_dwogf = probs["DWOGF"]
        p_gf = probs["GF"]
        probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
        # Honour n_boot_ci here too (previously DEFAULT_N_BOOT_CI was hard-coded).
        efs_interval = _efs_bootstrap_interval(p_dwogf, p_gf, n_boot_ci)
        if efs_interval is None:
            efs_interval = (probs["EFS"], probs["EFS"])
        intervals["EFS"] = efs_interval

    return probs, intervals
|
|
|
|
def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
    """Run predict_all_outcomes with and without isotonic calibration.

    Both runs use signed voting and the same ``n_boot_ci``.

    Returns:
        ((cal_probs, cal_intervals), (uncal_probs, uncal_intervals)).
    """
    calibrated = predict_all_outcomes(
        user_inputs, use_calibration=True, use_signed_voting=True, n_boot_ci=n_boot_ci
    )
    uncalibrated = predict_all_outcomes(
        user_inputs, use_calibration=False, use_signed_voting=True, n_boot_ci=n_boot_ci
    )
    return calibrated, uncalibrated
|
|
|
|
|
|
|
|
def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
    """Return per-model SHAP values (shape: n_models × n_processed_features).

    One TreeExplainer (probability output, synthetic background) is built per
    ensemble member of ``model_outcome``. Only the positive class and the first
    row of ``X_proc`` are kept. With ``invert`` the values are negated, which
    gives the attribution for 1 - P(outcome). ``user_inputs`` is unused.
    """
    per_member = []
    for member in classification_models[model_outcome]:
        explainer = shap.TreeExplainer(
            member, model_output="probability", data=shap_background
        )
        values = explainer.shap_values(X_proc)

        # Normalise the two binary-classifier output layouts to the positive class.
        if isinstance(values, list):
            values = values[1]
        elif values.ndim == 3 and values.shape[2] == 2:
            values = values[:, :, 1]

        first_row = values[0]
        per_member.append(-first_row if invert else first_row)

    return np.array(per_member)
|
|
|
|
def _shap_orig_feature(processed_name):
    """Map a processed (one-hot) column name back to its original input column."""
    # Longest-prefix match: several columns contain "_" themselves (CONDGRP_FINAL,
    # GVHD_FINAL, HLA_FINAL), so splitting on the first "_" would yield
    # non-existent names such as "CONDGRP".
    for col in sorted(CATEG_COLUMNS, key=len, reverse=True):
        if processed_name.startswith(col + "_"):
            return col
    # Fallback: the original first-underscore heuristic for unknown names.
    return processed_name.split("_", 1)[0]


def _shap_feature_label(feat_name, input_values):
    """Build the y-axis label "FEATURE = value" for one SHAP bar."""
    if feat_name not in input_values:
        return feat_name
    val = input_values[feat_name]
    # is_integer() is False for NaN/inf, which int(val) would have crashed on.
    if isinstance(val, float) and not val.is_integer():
        return f"{feat_name} = {val:.2f}"
    if isinstance(val, (int, float)):
        return f"{feat_name} = {int(val)}"
    val_str = str(val)
    if len(val_str) > 20:
        val_str = val_str[:17] + "..."
    return f"{feat_name} = {val_str}"


def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
    """Per-feature SHAP attributions for one outcome, aggregated to input columns.

    One-hot SHAP values are summed back to their original column. Each modelled
    component contributes the mean over its ensemble members:
      * OS  = 1 - P(DEAD)            -> inverted DEAD attribution;
      * EFS = 1 - P(DWOGF) - P(GF)   -> SUM of the inverted DWOGF and GF
        attributions (previously both ensembles were pooled and averaged,
        which halved the EFS attribution);
      * other outcomes               -> their own attribution.
    A 95% interval comes from resampling ensemble members (each component
    independently, RandomState(42), DEFAULT_N_BOOT_CI resamples).

    Returns:
        (labels, values, ci_low, ci_high) for the ``max_display`` largest
        |values|, ordered least-to-most important (presumably so the most
        important bar is drawn at the top of a horizontal bar chart).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
        input_values = user_inputs
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
        # Positional inputs have no feature keys; label from the built frame.
        input_values = input_df.iloc[0].to_dict()

    # Same column order as predict_all_outcomes; sklearn rejects reordered names.
    X_proc = preprocessor.transform(input_df[FEATURE_NAMES])

    processed_to_orig = {f: f for f in NUM_COLUMNS}
    for pf in ohe_feature_names:
        processed_to_orig[pf] = _shap_orig_feature(pf)
    unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
    orig_index = {name: j for j, name in enumerate(unique_orig_features)}
    target_cols = [orig_index[processed_to_orig[pf]] for pf in processed_feature_names]

    if outcome == "OS":
        components = [("DEAD", True)]
    elif outcome == "EFS":
        components = [("DWOGF", True), ("GF", True)]
    else:
        components = [(outcome, False)]

    # One (n_models, n_orig_features) matrix per component.
    per_component = []
    for model_outcome, invert in components:
        raw = _get_shap_values_for_model_outcome(
            user_inputs, model_outcome, invert=invert, X_proc=X_proc
        )
        by_orig = np.zeros((len(raw), len(unique_orig_features)))
        if len(raw):
            for i, j in enumerate(target_cols):
                by_orig[:, j] += raw[:, i]
        per_component.append(by_orig)

    mean_shap_vals = sum(c.mean(axis=0) for c in per_component)

    rng = np.random.RandomState(42)
    bootstrap_shap_means = np.array([
        sum(
            c[rng.choice(len(c), size=len(c), replace=True)].mean(axis=0)
            for c in per_component
        )
        for _ in range(DEFAULT_N_BOOT_CI)
    ])
    shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
    shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)

    top = np.argsort(-np.abs(mean_shap_vals))[:max_display]
    top_feat_names = [
        _shap_feature_label(unique_orig_features[i], input_values) for i in top
    ][::-1]
    top_shap_vals = mean_shap_vals[top][::-1]
    top_ci_low = shap_ci_low[top][::-1]
    top_ci_high = shap_ci_high[top][::-1]

    return top_feat_names, top_shap_vals, top_ci_low, top_ci_high
|
|
|
|
def create_shap_plot(user_inputs, outcome, max_display=10):
    """Horizontal SHAP bar chart with bootstrap error bars for one outcome.

    Bars with non-negative values are blue, others red; asymmetric error bars
    span the interval returned by compute_shap_values_with_direction. The
    title is the outcome's OUTCOME_DESCRIPTIONS entry (or the raw code).
    """
    labels, values, lower, upper = compute_shap_values_with_direction(
        user_inputs, outcome, max_display
    )

    bar = go.Bar(
        y=labels,
        x=values,
        orientation="h",
        marker=dict(color=["blue" if v >= 0 else "red" for v in values]),
        showlegend=False,
        error_x=dict(
            type="data",
            symmetric=False,
            array=upper - values,
            arrayminus=values - lower,
            color="gray",
            thickness=1.5,
            width=4,
        ),
    )
    figure = go.Figure(data=[bar])
    figure.add_vline(x=0, line_width=1, line_color="black")

    layout = dict(
        title=dict(
            text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
            x=0.5,
            xanchor="center",
            font=dict(size=14, color="black"),
        ),
        xaxis_title="SHAP value",
        yaxis_title="",
        height=400,
        margin=dict(l=120, r=60, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(
            showgrid=True,
            gridcolor="lightgray",
            zeroline=True,
            zerolinecolor="black",
            zerolinewidth=1,
        ),
        yaxis=dict(showgrid=False),
    )
    figure.update_layout(**layout)
    return figure
|
|
|
|
def create_all_shap_plots(user_inputs, max_display=10):
    """SHAP plots for every SHAP_OUTCOMES entry whose models are loaded.

    Model loading tolerates missing artifacts, so outcomes whose underlying
    ensemble(s) are absent are skipped instead of raising KeyError. OS needs
    DEAD; EFS needs DWOGF and GF; every other outcome needs itself.

    Returns:
        dict outcome -> plotly Figure, in SHAP_OUTCOMES order.
    """
    required = {"OS": ("DEAD",), "EFS": ("DWOGF", "GF")}
    return {
        o: create_shap_plot(user_inputs, o, max_display)
        for o in SHAP_OUTCOMES
        if all(m in classification_models for m in required.get(o, (o,)))
    }
|
|
|
|
|
|
|
|
# Icon-array colours: red for patients with the event, green for those without.
EVENT_COLOR = "#e53935"
NO_EVENT_COLOR = "#43a047"

# Short card titles for the icon arrays (fixes the former "TDeath" typo).
OUTCOME_TITLES = {
    "DEAD": "Death",
    "GF": "Graft Failure",
    "AGVHD": "Acute GvHD",
    "CGVHD": "Chronic GvHD",
    "VOCPSHI": "Vaso-Occlusive Crisis",
    "STROKEHI": "Stroke Post-HCT",
}

# (event label, no-event label) used in each card's legend.
OUTCOME_LABELS = {
    "DEAD": ("Death", "No Death"),
    "GF": ("Graft Failure", "No Graft Failure"),
    "AGVHD": ("Acute GVHD", "No Acute GVHD"),
    "CGVHD": ("Chronic GVHD", "No Chronic GVHD"),
    "VOCPSHI": ("VOC", "No VOC"),
    "STROKEHI": ("Stroke", "No Stroke"),
}
|
|
|
|
| def _stick_figure_svg(color: str, size: int = 16) -> str: |
| """Inline SVG stick figure. ViewBox 0 0 20 32 (portrait).""" |
| h = round(size * 1.6) |
| return ( |
| f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{h}" ' |
| f'viewBox="0 0 20 32" style="display:block;flex-shrink:0;" ' |
| f'stroke="{color}" stroke-width="2.2" stroke-linecap="round" fill="none">' |
| f'<circle cx="10" cy="5" r="3.8" fill="{color}" stroke="none"/>' |
| f'<line x1="10" y1="9" x2="10" y2="20"/>' |
| f'<line x1="3" y1="13" x2="17" y2="13"/>' |
| f'<line x1="10" y1="20" x2="4" y2="30"/>' |
| f'<line x1="10" y1="20" x2="16" y2="30"/>' |
| f'</svg>' |
| ) |
|
|
|
|
def create_icon_array_html(probability: float, outcome: str) -> str:
    """
    Single outcome card: 10×10 stick-figure grid, red=event, green=no event.
    All cards have identical fixed-height sections so they align in the grid.

    The first round(probability * 100) figures (clamped to 0..100) are drawn in
    EVENT_COLOR. The legend uses the outcome-specific labels from
    OUTCOME_LABELS (default "Event"/"No Event") with "(N/100)" counts. The title
    comes from OUTCOME_TITLES, then OUTCOME_DESCRIPTIONS, then the raw code.

    Returns:
        An HTML string (inline styles only).
    """
    title = OUTCOME_TITLES.get(outcome, OUTCOME_DESCRIPTIONS.get(outcome, outcome))
    event_label, no_event_label = OUTCOME_LABELS.get(outcome, ("Event", "No Event"))
    # int() keeps the "(N/100)" counts integral for numpy inputs; clamping keeps
    # the grid and counts consistent for out-of-range probabilities.
    n_event = min(max(int(round(float(probability) * 100)), 0), 100)
    n_no_event = 100 - n_event
    pct_str = f"{probability * 100:.1f}%"

    # 10 rows of 10 figures, filled row-major: event figures first.
    rows_parts = []
    for row in range(10):
        cells = "".join(
            _stick_figure_svg(
                EVENT_COLOR if row * 10 + col < n_event else NO_EVENT_COLOR,
                size=16,
            )
            for col in range(10)
        )
        rows_parts.append(
            f'<div style="display:flex;justify-content:center;gap:2px;margin-bottom:2px;">{cells}</div>'
        )
    grid_html = "\n".join(rows_parts)

    # Smaller figures used as legend markers.
    fig_event = _stick_figure_svg(EVENT_COLOR, size=13)
    fig_no_event = _stick_figure_svg(NO_EVENT_COLOR, size=13)

    legend_html = (
        f'<div style="display:inline-grid;grid-template-columns:16px 130px 44px;'
        f'align-items:center;gap:4px;row-gap:4px;">'
        f'{fig_event}'
        f'<span style="color:{EVENT_COLOR};font-weight:700;font-size:11px;white-space:nowrap;'
        f'overflow:hidden;text-overflow:ellipsis;">{event_label}</span>'
        f'<span style="color:#888;font-size:10px;white-space:nowrap;">({n_event}/100)</span>'
        f'{fig_no_event}'
        f'<span style="color:{NO_EVENT_COLOR};font-weight:700;font-size:11px;white-space:nowrap;'
        f'overflow:hidden;text-overflow:ellipsis;">{no_event_label}</span>'
        f'<span style="color:#888;font-size:10px;white-space:nowrap;">({n_no_event}/100)</span>'
        f'</div>'
    )

    return (
        f'<div style="background:#fff;border:1px solid #e0e0e0;border-radius:10px;'
        f'padding:10px 8px;text-align:center;font-family:\'Segoe UI\',Arial,sans-serif;'
        f'box-shadow:0 2px 6px rgba(0,0,0,0.07);box-sizing:border-box;'
        f'display:flex;flex-direction:column;align-items:center;">'
        # Title row (fixed min-height so cards align).
        f'<div style="min-height:34px;display:flex;align-items:center;justify-content:center;'
        f'font-size:12px;font-weight:700;color:#222;line-height:1.3;margin-bottom:2px;">'
        f'{title}</div>'
        # Headline percentage.
        f'<div style="font-size:22px;font-weight:800;color:{EVENT_COLOR};'
        f'line-height:1;margin-bottom:6px;">{pct_str}</div>'
        # Icon grid.
        f'<div style="margin-bottom:6px;">{grid_html}</div>'
        # Legend.
        f'<div style="margin-top:2px;">{legend_html}</div>'
        f'</div>'
    )
|
|
|
|
def create_all_icon_arrays(calibrated_probs: dict) -> dict:
    """
    Returns individual cards + a combined '__grid__' key with the 4×2 layout.

    Cards are laid out 4 per row. A short final row is padded with empty
    cells that have the same flex basis, so every card gets the same width
    (without padding, the 2 cards of the second row stretch to half the row
    each). Outcomes missing from ``calibrated_probs`` (e.g. a model that was
    not loaded) are skipped rather than raising KeyError.
    """
    cards_per_row = 4
    pie_outcomes = [
        o for o in ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
        if o in calibrated_probs
    ]
    cards = {o: create_icon_array_html(calibrated_probs[o], o) for o in pie_outcomes}

    row_chunks = []
    for row_start in range(0, len(pie_outcomes), cards_per_row):
        row_outcomes = pie_outcomes[row_start: row_start + cards_per_row]
        cells = [
            f'<div style="flex:1 1 0%;min-width:0;">{cards[o]}</div>'
            for o in row_outcomes
        ]
        cells.extend(
            '<div style="flex:1 1 0%;min-width:0;"></div>'
            for _ in range(cards_per_row - len(row_outcomes))
        )
        cols = "".join(cells)
        row_chunks.append(
            f'<div style="display:flex;gap:10px;margin-bottom:10px;">{cols}</div>'
        )
    rows_html = "".join(row_chunks)

    footnote = (
        f'<div style="font-size:10.5px;color:#888;text-align:center;margin-top:4px;">'
        f'Each figure = 1 patient out of 100 with similar characteristics. '
        f'<span style="color:{EVENT_COLOR};font-weight:600;">■ Red = Event</span>'
        f' '
        f'<span style="color:{NO_EVENT_COLOR};font-weight:600;">■ Green = No Event</span>'
        f'</div>'
    )

    cards["__grid__"] = (
        f'<div style="font-family:\'Segoe UI\',Arial,sans-serif;padding:4px 0;">'
        f'{rows_html}{footnote}</div>'
    )
    return cards
|
|
|
|
|
|
def create_pie_chart(probability, outcome):
    """Despite the name, render the icon-array HTML card for one outcome.

    Thin wrapper over create_icon_array_html, presumably kept for callers that
    still use the old pie-chart entry point.
    """
    card_html = create_icon_array_html(probability, outcome)
    return card_html
|
|
|
|
def create_all_pie_charts(calibrated_probs):
    """Despite the name, build all icon-array cards plus the '__grid__' layout.

    Thin wrapper over create_all_icon_arrays, presumably kept for callers that
    still use the old pie-chart entry point.
    """
    all_cards = create_all_icon_arrays(calibrated_probs)
    return all_cards