Spaces:
Build error
Build error
| import numpy as np | |
| import pandas as pd | |
| import skops.io as sio | |
| import shap | |
| import plotly.graph_objects as go | |
| import os | |
| import sys | |
| import warnings | |
# Silence sklearn UserWarnings (e.g. version-mismatch notices when unpickling).
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

# Startup banner: log the working directory and its contents so deployment
# logs show whether the expected .skops model files are actually present.
print("===== Application Startup =====")
print(f"Working directory: {os.getcwd()}")
print(f"Files present: {os.listdir('.')}")
# ---------------------------------------------------------------------------
# Compatibility patch
# ---------------------------------------------------------------------------
# Serialized ColumnTransformer objects may reference sklearn's private
# `_RemainderColsList` class, which does not exist in every sklearn version.
# If it is missing, install a minimal list subclass stand-in so skops
# deserialization of the preprocessor succeeds.
import sklearn.compose._column_transformer as _ct

if not hasattr(_ct, "_RemainderColsList"):
    class _RemainderColsList(list):
        # Minimal shim: behaves like a list and carries the `future_dtype`
        # attribute that the real class exposes.
        def __init__(self, lst=None, future_dtype=None):
            super().__init__(lst or [])
            self.future_dtype = future_dtype

    _ct._RemainderColsList = _RemainderColsList
    import sklearn.compose
    sklearn.compose._RemainderColsList = _RemainderColsList
    print("Patched _RemainderColsList into sklearn.compose")
# ---------------------------------------------------------------------------
# Column / feature definitions
# ---------------------------------------------------------------------------
# Numeric features, routed to the preprocessor's numeric branch.
NUM_COLUMNS = ["AGE", "NACS2YR"]
# Categorical features, routed to the one-hot-encoder branch.
CATEG_COLUMNS = [
    "AGEGPFF", "SEX", "KPS", "DONORF", "GRAFTYPE", "CONDGRPF",
    "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
    "RCMVPR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN",
]
# Column order the preprocessor expects for its input DataFrame.
FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
# Outcomes for which an ensemble model file is expected on disk.
OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
CLASSIFICATION_OUTCOMES = OUTCOMES
# Outcomes surfaced in reports; OS and EFS are derived (from DEAD / DWOGF+GF).
REPORTING_OUTCOMES = [
    "OS", "EFS", "GF", "DEAD",
    "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI",
]
# Human-readable outcome names used in plot titles.
OUTCOME_DESCRIPTIONS = {
    "OS": "Overall Survival",
    "EFS": "Event-Free Survival",
    "DEAD": "Total Mortality",
    "GF": "Graft Failure",
    "AGVHD": "Acute Graft-versus-Host Disease",
    "CGVHD": "Chronic Graft-versus-Host Disease",
    "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
    "STROKEHI": "Stroke Post-HCT",
}
# Outcomes for which SHAP explanation plots are generated.
SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]
MODEL_DIR = "."            # model files are expected in the working directory
CONSENSUS_THRESHOLD = 0.5  # default per-model positive-vote threshold
DEFAULT_N_BOOT_CI = 500    # bootstrap resamples for confidence intervals
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
def _load_skops_model(fname):
    """Deserialize a skops artifact from *fname*.

    Raises RuntimeError if the file is missing or fails to load; the
    original exception is chained for debugging.
    """
    if not os.path.exists(fname):
        raise RuntimeError(f"Model file not found: {fname}")
    try:
        # NOTE(review): trusting every type skops flags as untrusted disables
        # its safety check — acceptable only because these artifacts ship
        # with the application, never from user input.
        trusted_types = sio.get_untrusted_types(file=fname)
        loaded = sio.load(fname, trusted=trusted_types)
        print(f" Loaded: {fname}")
        return loaded
    except Exception as exc:
        raise RuntimeError(
            f"Failed to load '{fname}': {type(exc).__name__}: {exc}"
        ) from exc
| print("Loading preprocessor...") | |
| preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops")) | |
| print("Loading ensemble models...") | |
| classification_model_data = {} | |
| for _o in CLASSIFICATION_OUTCOMES: | |
| _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops") | |
| if os.path.exists(_path): | |
| classification_model_data[_o] = _load_skops_model(_path) | |
| else: | |
| print(f" Warning: Model for {_o} not found at {_path}. Skipping.") | |
| classification_models = {o: d["models"] for o, d in classification_model_data.items()} | |
| betas = {o: d["beta"] for o, d in classification_model_data.items()} | |
| priors = {o: d["prior"] for o, d in classification_model_data.items()} | |
| consensus_thresholds = { | |
| o: d.get("consensus_threshold", CONSENSUS_THRESHOLD) | |
| for o, d in classification_model_data.items() | |
| } | |
| calibrators = {} | |
| for _o, _d in classification_model_data.items(): | |
| _cal = None | |
| _cal_type = _d.get("calibrator_type", None) | |
| if "calibrator" in _d and _d["calibrator"] is not None: | |
| if _cal_type is None or _cal_type == "isotonic": | |
| _cal = _d["calibrator"] | |
| else: | |
| print(f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. Skipping.") | |
| elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None: | |
| _cal = _d["isotonic_calibrator"] | |
| calibrators[_o] = _cal | |
| isotonic_calibrators = calibrators | |
| oof_probs_calibrated = { | |
| o: d.get("oof_probs_calibrated") for o, d in classification_model_data.items() | |
| } | |
| ohe = preprocessor.named_transformers_["cat"] | |
| ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS) | |
| processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names]) | |
| print(f"Models loaded: {list(classification_models.keys())}") | |
# ---------------------------------------------------------------------------
# SHAP background data
# ---------------------------------------------------------------------------
# Synthetic background sample for SHAP's Independent masker. Category strings
# must match the vocabulary the OneHotEncoder was fitted with.
# NOTE(review): several category strings below ("β₯ 90", "β€ 6/8", "β₯ 3/yr")
# look like mis-encoded "≥"/"≤" glyphs — confirm they match the categories of
# the fitted preprocessor before changing anything.
print("Building SHAP background...")
np.random.seed(23)  # fixed seed: background (and hence SHAP values) reproducible
_n_background = 500
_background_data = {
    "AGE": np.random.uniform(5, 50, _n_background),
    "NACS2YR": np.random.randint(0, 5, _n_background),
    "AGEGPFF": np.random.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
    "SEX": np.random.choice(["Male", "Female"], _n_background),
    "KPS": np.random.choice(["<90", "β₯ 90"], _n_background),
    "DONORF": np.random.choice([
        "HLA identical sibling", "HLA mismatch relative",
        "Matched unrelated donor",
        "Mismatched unrelated donor or cord blood",
    ], _n_background),
    "GRAFTYPE": np.random.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
    "CONDGRPF": np.random.choice(["MAC", "RIC", "NMA"], _n_background),
    "CONDGRP_FINAL": np.random.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
    "ATGF": np.random.choice(["ATG", "Alemtuzumab", "None"], _n_background),
    "GVHD_FINAL": np.random.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
    "HLA_FINAL": np.random.choice(["8/8", "7/8", "β€ 6/8"], _n_background),
    "RCMVPR": np.random.choice(["Negative", "Positive"], _n_background),
    "EXCHTFPR": np.random.choice(["No", "Yes"], _n_background),
    "VOC2YPR": np.random.choice(["No", "Yes"], _n_background),
    "VOCFRQPR": np.random.choice(["< 3/yr", "β₯ 3/yr"], _n_background),
    "SCATXRSN": np.random.choice([
        "CNS event", "Acute chest Syndrome",
        "Recurrent vaso-occlusive pain", "Recurrent priapism",
        "Excessive transfusion requirements/iron overload",
        "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
        "Renal insufficiency", "Splenic sequestration",
        "Avascular necrosis", "Hodgkin lymphoma",
    ], _n_background),
}
_background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]  # enforce column order
_X_background = preprocessor.transform(_background_df)
shap_background = shap.maskers.Independent(_X_background)
print("SHAP background ready.")
| # --------------------------------------------------------------------------- | |
| # Calibration helpers | |
| # --------------------------------------------------------------------------- | |
| def calibrate_probabilities_undersampling(p_s, beta): | |
| p_s = np.asarray(p_s, dtype=float) | |
| numerator = beta * p_s | |
| denominator = np.maximum((beta - 1.0) * p_s + 1.0, 1e-10) | |
| return np.clip(numerator / denominator, 0.0, 1.0) | |
def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
    """Aggregate an ensemble by signed voting.

    Each model votes +1 when its positive-class probability reaches
    *threshold*, otherwise -1. Returns a 4-tuple: (0/1 consensus prediction,
    mean probability, mean signed vote, flattened per-model probabilities).
    """
    per_model = np.array(
        [mdl.predict_proba(X_test)[:, 1] for mdl in ensemble_models]
    )
    votes = np.where(per_model >= threshold, 1, -1)
    mean_vote = votes.mean(axis=0)
    consensus = (mean_vote > 0).astype(int)
    mean_proba = per_model.mean(axis=0)
    return consensus, mean_proba, mean_vote, per_model.flatten()
def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
    """Average the ensemble's positive-class probabilities.

    *threshold* is accepted for signature parity with the signed-voting
    variant but does not influence the result. Returns (mean probability,
    flattened per-model probabilities).
    """
    stacked = np.stack(
        [mdl.predict_proba(X_test)[:, 1] for mdl in ensemble_models]
    )
    return stacked.mean(axis=0), stacked.flatten()
| # --------------------------------------------------------------------------- | |
| # Bootstrap CI | |
| # --------------------------------------------------------------------------- | |
def bootstrap_ci_from_oof(
    point_estimate: float,
    oof_probs: np.ndarray,
    n_boot: int = DEFAULT_N_BOOT_CI,
    confidence: float = 0.95,
    random_state: int = 42,
) -> tuple:
    """Percentile-bootstrap confidence interval centred on *point_estimate*.

    Resamples the out-of-fold probabilities with replacement, then shifts the
    bootstrap distribution of means so it is centred on the point estimate.
    Returns (lo, hi) clipped to [0, 1]; degenerates to (point, point) when no
    OOF probabilities are available.
    """
    if oof_probs is None or len(oof_probs) == 0:
        return float(point_estimate), float(point_estimate)
    samples = np.asarray(oof_probs, dtype=float)
    n = len(samples)
    rng = np.random.RandomState(random_state)
    means = np.empty(n_boot)
    for b in range(n_boot):
        means[b] = np.mean(rng.choice(samples, size=n, replace=True))
    # Re-centre the bootstrap distribution on the point estimate.
    means = means + (point_estimate - np.mean(samples))
    tail = (1.0 - confidence) / 2
    lo = float(np.clip(np.percentile(means, 100 * tail), 0.0, 1.0))
    hi = float(np.clip(np.percentile(means, 100 * (1 - tail)), 0.0, 1.0))
    return lo, hi
| # --------------------------------------------------------------------------- | |
| # Calibration dispatch | |
| # --------------------------------------------------------------------------- | |
def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
    """Apply undersampling correction, then (optionally) the fitted calibrator.

    Falls back to the beta-corrected probability when calibration is disabled
    or no calibrator was loaded for this outcome.
    """
    corrected = float(
        calibrate_probabilities_undersampling([raw_prob], betas[outcome])[0]
    )
    if not use_calibration:
        return corrected
    calibrator = calibrators.get(outcome)
    if calibrator is None:
        return corrected
    return float(calibrator.transform([corrected])[0])
| # --------------------------------------------------------------------------- | |
| # Main prediction functions | |
| # --------------------------------------------------------------------------- | |
def predict_all_outcomes(
    user_inputs,
    use_calibration: bool = True,
    use_signed_voting: bool = True,
    n_boot_ci: int = DEFAULT_N_BOOT_CI,
):
    """Predict every loaded outcome for one patient.

    Parameters
    ----------
    user_inputs : dict or sequence
        Feature values; a dict is keyed by feature name, a sequence must be
        ordered like FEATURE_NAMES.
    use_calibration : bool
        Apply the fitted calibrator after the undersampling correction.
    use_signed_voting : bool
        Select the signed-voting ensemble combiner; both combiners yield the
        mean probability, which is what this function uses.
    n_boot_ci : int
        Number of bootstrap resamples for the confidence intervals.

    Returns
    -------
    (probs, intervals)
        Two dicts keyed by outcome code; ``intervals[o]`` is a (lo, hi)
        tuple. Derived outcomes OS and EFS are added when their source
        outcomes are present.
    """
    # Normalise input to a single-row DataFrame in FEATURE_NAMES order.
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    input_df = input_df[FEATURE_NAMES]
    X = preprocessor.transform(input_df)
    probs, intervals = {}, {}
    for o in CLASSIFICATION_OUTCOMES:
        if o not in classification_models:
            continue  # model file was missing at startup
        threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
        if use_signed_voting:
            _, uncalib_arr, _, _ = predict_consensus_signed_voting(
                classification_models[o], X, threshold
            )
        else:
            uncalib_arr, _ = predict_consensus_majority(
                classification_models[o], X, threshold
            )
        raw_prob = float(uncalib_arr[0])
        event_prob = _calibrate_point(o, raw_prob, use_calibration)
        lo, hi = bootstrap_ci_from_oof(
            point_estimate=event_prob,
            oof_probs=oof_probs_calibrated.get(o),
            n_boot=n_boot_ci,
        )
        probs[o] = event_prob
        intervals[o] = (lo, hi)
    # OS = 1 - P(DEAD); its CI is the DEAD CI mirrored around 1.
    if "DEAD" in probs:
        p_dead = probs["DEAD"]
        probs["OS"] = float(1.0 - p_dead)
        dead_lo, dead_hi = intervals["DEAD"]
        intervals["OS"] = (
            float(np.clip(1.0 - dead_hi, 0, 1)),
            float(np.clip(1.0 - dead_lo, 0, 1)),
        )
    # EFS = 1 - P(DWOGF) - P(GF), clipped to [0, 1].
    if "DWOGF" in probs and "GF" in probs:
        p_dwogf = probs["DWOGF"]
        p_gf = probs["GF"]
        probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
        oof_dwogf = oof_probs_calibrated.get("DWOGF")
        oof_gf = oof_probs_calibrated.get("GF")
        if oof_dwogf is not None and oof_gf is not None:
            # Bootstrap the EFS CI: resample both OOF vectors (truncated to a
            # common length), shift each bootstrap mean onto this patient's
            # point estimate, and take the 2.5/97.5 percentiles.
            # NOTE(review): DWOGF and GF are resampled independently, which
            # ignores any correlation between the two outcomes.
            oof_dwogf = np.asarray(oof_dwogf, dtype=float)
            oof_gf = np.asarray(oof_gf, dtype=float)
            n_min = min(len(oof_dwogf), len(oof_gf))
            oof_dwogf = oof_dwogf[:n_min]
            oof_gf = oof_gf[:n_min]
            rng = np.random.RandomState(42)
            grand_dwogf = np.mean(oof_dwogf)
            grand_gf = np.mean(oof_gf)
            shift_dwogf = p_dwogf - grand_dwogf
            shift_gf = p_gf - grand_gf
            efs_boot = np.array([
                np.clip(
                    1.0
                    - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
                    - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
                    0.0, 1.0,
                )
                for _ in range(n_boot_ci)
            ])
            intervals["EFS"] = (
                float(np.percentile(efs_boot, 2.5)),
                float(np.percentile(efs_boot, 97.5)),
            )
        else:
            # No OOF data: degenerate zero-width interval.
            intervals["EFS"] = (probs["EFS"], probs["EFS"])
    return probs, intervals
def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
    """Run predictions twice, with and without calibration.

    Returns ((calibrated_probs, calibrated_intervals),
             (uncalibrated_probs, uncalibrated_intervals)).
    """
    calibrated = predict_all_outcomes(user_inputs, True, True, n_boot_ci)
    uncalibrated = predict_all_outcomes(user_inputs, False, True, n_boot_ci)
    return calibrated, uncalibrated
| # --------------------------------------------------------------------------- | |
| # SHAP helpers | |
| # --------------------------------------------------------------------------- | |
def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
    """Per-model SHAP values for one outcome's ensemble.

    *user_inputs* is unused here and kept only for call-site symmetry.
    Returns an (n_models, n_processed_features) array of SHAP values for the
    positive class; values are negated when *invert* is True (used for the
    derived OS/EFS outcomes).
    """
    collected = []
    for model in classification_models[model_outcome]:
        explainer = shap.TreeExplainer(model, model_output="probability", data=shap_background)
        values = explainer.shap_values(X_proc)
        # Normalise the shapes shap may return down to the positive class.
        if isinstance(values, list):
            values = values[1]
        elif values.ndim == 3 and values.shape[2] == 2:
            values = values[:, :, 1]
        row = values[0]
        collected.append(-row if invert else row)
    return np.array(collected)
def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
    """Aggregate per-model SHAP values into original-feature attributions.

    One-hot dummy contributions are summed back onto their source categorical
    feature; variation across ensemble members gives a bootstrap CI on each
    attribution. OS reuses the DEAD models with the sign flipped; EFS pools
    sign-flipped DWOGF and GF models.

    Returns (names, values, ci_low, ci_high) for the top `max_display`
    features, ordered least-to-most important (for horizontal bar charts).
    """
    if isinstance(user_inputs, dict):
        input_df = pd.DataFrame([user_inputs])
    else:
        input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
    X_proc = preprocessor.transform(input_df)
    # Map each processed column back to its original feature: numeric columns
    # map to themselves; OHE columns are named "<feature>_<category>".
    processed_to_orig = {f: f for f in NUM_COLUMNS}
    for pf in ohe_feature_names:
        processed_to_orig[pf] = pf.split("_", 1)[0]
    if outcome == "OS":
        raw_shap = _get_shap_values_for_model_outcome(user_inputs, "DEAD", invert=True, X_proc=X_proc)
    elif outcome == "EFS":
        shap_dwogf = _get_shap_values_for_model_outcome(user_inputs, "DWOGF", invert=True, X_proc=X_proc)
        shap_gf = _get_shap_values_for_model_outcome(user_inputs, "GF", invert=True, X_proc=X_proc)
        raw_shap = np.concatenate([shap_dwogf, shap_gf], axis=0)
    else:
        raw_shap = _get_shap_values_for_model_outcome(user_inputs, outcome, invert=False, X_proc=X_proc)
    # Sum dummy-column SHAP values per original feature, per ensemble member.
    unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
    n_models = len(raw_shap)
    model_shap_by_orig = np.zeros((n_models, len(unique_orig_features)))
    for model_idx in range(n_models):
        agg_by_orig = {}
        for i, pf in enumerate(processed_feature_names):
            orig = processed_to_orig[pf]
            agg_by_orig.setdefault(orig, 0.0)
            agg_by_orig[orig] += raw_shap[model_idx, i]
        for feat_idx, feat_name in enumerate(unique_orig_features):
            model_shap_by_orig[model_idx, feat_idx] = agg_by_orig.get(feat_name, 0.0)
    mean_shap_vals = np.mean(model_shap_by_orig, axis=0)
    # Bootstrap over ensemble members for a CI on each mean attribution.
    rng = np.random.RandomState(42)
    bootstrap_shap_means = np.array([
        np.mean(model_shap_by_orig[rng.choice(n_models, size=n_models, replace=True)], axis=0)
        for _ in range(DEFAULT_N_BOOT_CI)
    ])
    shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
    shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)
    order = np.argsort(-np.abs(mean_shap_vals))  # descending |SHAP|
    # Build "FEATURE = value" display labels for the top features.
    # NOTE(review): `feat_name in user_inputs` assumes dict input; for a
    # sequence this membership test checks values, not feature names.
    top_feat_names = []
    for i in order[:max_display]:
        feat_name = unique_orig_features[i]
        if feat_name in user_inputs:
            val = user_inputs[feat_name]
            if isinstance(val, float) and val != int(val):
                display_name = f"{feat_name} = {val:.2f}"
            elif isinstance(val, (int, float)):
                display_name = f"{feat_name} = {int(val)}"
            else:
                val_str = str(val)
                if len(val_str) > 20:
                    val_str = val_str[:17] + "..."  # truncate long category labels
                display_name = f"{feat_name} = {val_str}"
        else:
            display_name = feat_name
        top_feat_names.append(display_name)
    # Reverse so the most important feature lands at the top of the bar chart.
    top_feat_names = top_feat_names[::-1]
    top_shap_vals = mean_shap_vals[order][:max_display][::-1]
    top_ci_low = shap_ci_low[order][:max_display][::-1]
    top_ci_high = shap_ci_high[order][:max_display][::-1]
    return top_feat_names, top_shap_vals, top_ci_low, top_ci_high
def create_shap_plot(user_inputs, outcome, max_display=10):
    """Horizontal bar chart of the top SHAP attributions for one outcome.

    Bars are blue for non-negative and red for negative SHAP values; gray
    whiskers show the bootstrap CI around each attribution.
    """
    feat_names, shap_vals, ci_low, ci_high = compute_shap_values_with_direction(
        user_inputs, outcome, max_display
    )
    bar_colors = ["blue" if value >= 0 else "red" for value in shap_vals]
    whiskers = dict(
        type="data",
        symmetric=False,
        array=ci_high - shap_vals,      # upper arm length
        arrayminus=shap_vals - ci_low,  # lower arm length
        color="gray",
        thickness=1.5,
        width=4,
    )
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=feat_names,
        x=shap_vals,
        orientation="h",
        marker=dict(color=bar_colors),
        showlegend=False,
        error_x=whiskers,
    ))
    fig.add_vline(x=0, line_width=1, line_color="black")
    fig.update_layout(
        title=dict(
            text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
            x=0.5, xanchor="center",
            font=dict(size=14, color="black"),
        ),
        xaxis_title="SHAP value",
        yaxis_title="",
        height=400,
        margin=dict(l=120, r=60, t=50, b=50),
        plot_bgcolor="white",
        paper_bgcolor="white",
        xaxis=dict(showgrid=True, gridcolor="lightgray", zeroline=True,
                   zerolinecolor="black", zerolinewidth=1),
        yaxis=dict(showgrid=False),
    )
    return fig
def create_all_shap_plots(user_inputs, max_display=10):
    """Build one SHAP plot per outcome in SHAP_OUTCOMES, keyed by outcome."""
    plots = {}
    for outcome in SHAP_OUTCOMES:
        plots[outcome] = create_shap_plot(user_inputs, outcome, max_display)
    return plots
| # --------------------------------------------------------------------------- | |
| # Icon array | |
| # --------------------------------------------------------------------------- | |
| # Root cause of previous gaps / distortion: | |
| # Plotly shape coords are in DATA units. If px-per-data-unit differs on | |
| # x vs y axes the circle head becomes an ellipse and spacing looks uneven. | |
| # | |
| # Fix: | |
| # β’ Use EQUAL axis spans on x and y (both = cols + 2*pad = 10.3) | |
| # β’ Set width and height so that usable pixels are EQUAL on both axes: | |
| # usable_w = W - margin_l - margin_r = W - 20 | |
| # usable_h = H - margin_t - margin_b = H - 100 | |
| # usable_w == usable_h β H = W + 80 | |
| # β’ This guarantees 1 data-unit = same number of pixels on both axes, | |
| # so circles are round and spacing is perfectly uniform. | |
| # --------------------------------------------------------------------------- | |
| def _stick_figure(cx, cy, color, s): | |
| """ | |
| Returns Plotly shape dicts for a stick figure centred at (cx, cy). | |
| s = scale (data units). With a cell size of 1.0, s β 0.46 gives | |
| a figure that fills ~75 % of the cell vertically. | |
| Anatomy (all offsets relative to cy): | |
| head centre : cy + s*0.55 radius s*0.18 | |
| neck top : cy + s*0.35 | |
| hip : cy - s*0.15 | |
| arm branch : cy + s*0.18 | |
| foot : cy - s*0.55 | |
| """ | |
| shapes = [] | |
| lw = dict(color=color, width=1.8) # fixed pixel width β looks consistent | |
| # head | |
| hr = s * 0.18 | |
| hy = cy + s * 0.55 | |
| shapes.append(dict( | |
| type="circle", xref="x", yref="y", | |
| x0=cx - hr, y0=hy - hr, | |
| x1=cx + hr, y1=hy + hr, | |
| fillcolor=color, | |
| line=dict(color=color, width=0), | |
| )) | |
| neck_y = cy + s * 0.35 | |
| hip_y = cy - s * 0.15 | |
| arm_y = cy + s * 0.18 | |
| foot_y = cy - s * 0.55 | |
| # spine | |
| shapes.append(dict(type="line", xref="x", yref="y", | |
| x0=cx, y0=neck_y, x1=cx, y1=hip_y, line=lw)) | |
| # arms | |
| adx = s * 0.32 | |
| ady = s * 0.15 | |
| shapes.append(dict(type="line", xref="x", yref="y", | |
| x0=cx, y0=arm_y, x1=cx - adx, y1=arm_y - ady, line=lw)) | |
| shapes.append(dict(type="line", xref="x", yref="y", | |
| x0=cx, y0=arm_y, x1=cx + adx, y1=arm_y - ady, line=lw)) | |
| # legs | |
| ldx = s * 0.26 | |
| shapes.append(dict(type="line", xref="x", yref="y", | |
| x0=cx, y0=hip_y, x1=cx - ldx, y1=foot_y, line=lw)) | |
| shapes.append(dict(type="line", xref="x", yref="y", | |
| x0=cx, y0=hip_y, x1=cx + ldx, y1=foot_y, line=lw)) | |
| return shapes | |
def icon_array(probability, outcome):
    """100-person icon array (10x10 stick-figure grid) for one outcome.

    *probability* is rounded to a count out of 100; the first `n_event`
    figures (left-to-right, top-to-bottom) are red, the remainder teal.
    NOTE(review): the "β " prefixes in the title string look like a
    mis-encoded legend glyph (possibly a colored square) — confirm against
    the original source encoding before editing.
    """
    outcome_labels = {
        "DEAD": ("Death", "Overall Survival"),
        "GF": ("Graft Failure", "No Graft Failure"),
        "AGVHD": ("AGVHD", "No AGVHD"),
        "CGVHD": ("CGVHD", "No CGVHD"),
        "VOCPSHI": ("VOC Post-HCT", "No VOC Post-HCT"),
        "STROKEHI": ("Stroke Post-HCT", "No Stroke Post-HCT"),
    }
    event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
    n_event = round(probability * 100)
    n_no_event = 100 - n_event
    cols, rows = 10, 10
    # -- Layout constants ---------------------------------------------------
    # Icons sit on an integer grid 0..9 x 0..9.
    # Padding of 0.65 on each side -> axis span = 9 + 2*0.65 = 10.30
    # Margins: left=10, right=10, top=95, bottom=10
    # usable_w = W - 20 ; usable_h = H - 105
    # To ensure px_per_unit identical on both axes: usable_w == usable_h
    # -> H = W + 85
    # We also enforce equal axis spans (both 10.30).
    PAD = 0.65
    W = 400
    H = W + 85  # = 485 -> usable = 380 px on both axes
    S = 0.46    # figure scale (~75% vertical fill per cell)
    x_lo, x_hi = -PAD, (cols - 1) + PAD  # -0.65 .. 9.65, span 10.30
    y_lo, y_hi = -PAD, (rows - 1) + PAD  # -0.65 .. 9.65, span 10.30
    all_shapes = []
    icon_idx = 0
    for row in range(rows):      # row 0 -> top of grid
        for col in range(cols):  # col 0 -> left
            color = "#e05555" if icon_idx < n_event else "#3bbfad"
            cx = col
            cy = (rows - 1) - row  # invert: row 0 -> cy=9 (top)
            all_shapes.extend(_stick_figure(cx, cy, color, S))
            icon_idx += 1
    fig = go.Figure()
    fig.update_layout(
        title=dict(
            text=(
                f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
                f"<span style='font-size:12px;color:#e05555'>"
                f"β {event_label}: {n_event}%</span>"
                f" "
                f"<span style='font-size:12px;color:#3bbfad'>"
                f"β {no_event_label}: {n_no_event}%</span>"
            ),
            x=0.5, xanchor="center",
            font=dict(size=14, color="black"),
        ),
        shapes=all_shapes,
        xaxis=dict(
            range=[x_lo, x_hi],
            showgrid=False, zeroline=False, showticklabels=False,
            fixedrange=True,
        ),
        yaxis=dict(
            range=[y_lo, y_hi],
            showgrid=False, zeroline=False, showticklabels=False,
            fixedrange=True,
            # scaleanchor / scaleratio intentionally OMITTED —
            # equal spans + equal usable pixels already guarantee
            # identical px/unit on both axes without distortion.
        ),
        width=W,
        height=H,
        margin=dict(l=10, r=10, t=95, b=10),
        plot_bgcolor="white",
        paper_bgcolor="white",
    )
    return fig
| print("===== inference.py loaded successfully =====") |