# gvhd-intel-pro / src/scenario_utils.py
# Author: Synav
# Last commit: ab6248d (verified) - "Update src/scenario_utils.py"
import pandas as pd
import itertools
from typing import Dict, List, Optional
def build_scenarios(
    baseline_row: pd.Series,
    variable_options: Dict[str, List],
    max_scenarios: Optional[int] = 2000,
) -> pd.DataFrame:
    """
    Create a scenario DataFrame by varying selected variables around a baseline patient row.

    Every combination of candidate values (Cartesian product, iterated in the order
    the variables appear in ``variable_options``) is written onto a copy of
    ``baseline_row``; all other fields keep their baseline values.

    Args:
        baseline_row: one patient's input row (Series), indexed by feature name.
        variable_options: {variable_name: [candidate_values]}. A variable that is not
            in ``baseline_row`` is added as a new column.
        max_scenarios: maximum number of scenarios to return; the first
            ``max_scenarios`` combinations in product order are kept. ``None``
            means no limit.

    Returns:
        DataFrame with one row per scenario and a fresh 0..n-1 RangeIndex. If no
        combination is produced (``max_scenarios == 0`` or an empty candidate list),
        an empty DataFrame that still carries the expected columns is returned.

    Raises:
        ValueError: if ``baseline_row`` is missing/empty or ``max_scenarios`` is negative.
    """
    if baseline_row is None or baseline_row.empty:
        raise ValueError("baseline_row is required (single patient input row).")
    if max_scenarios is not None and max_scenarios < 0:
        raise ValueError("max_scenarios must be non-negative (or None for no limit).")

    variables = list(variable_options.keys())
    choices = [variable_options[v] for v in variables]
    # Take only the first max_scenarios combinations lazily: materialising the full
    # product first can exhaust memory/time when many variables are varied.
    combos = itertools.islice(itertools.product(*choices), max_scenarios)

    rows = []
    for combo in combos:
        r = baseline_row.copy()
        for v, val in zip(variables, combo):
            r[v] = val
        rows.append(r)

    if not rows:
        # pd.DataFrame([]) would have no columns at all; keep the scenario schema instead.
        columns = list(baseline_row.index) + [
            v for v in variables if v not in baseline_row.index
        ]
        return pd.DataFrame(columns=columns)

    # Each copied row keeps baseline_row.name, which would give every scenario the
    # same (duplicated) index label; renumber so each scenario is addressable.
    return pd.DataFrame(rows).reset_index(drop=True)
def rank_scenarios(
    df_pred: pd.DataFrame,
    gvhd_col: str = "pred_aGVHD",
    surv_col: Optional[str] = "surv_1y",
    objective: str = "min_gvhd_max_survival",
) -> pd.DataFrame:
    """
    Rank scenarios by their predicted outcomes.

    Objectives:
      - "min_gvhd": sort by ``gvhd_col`` ascending.
      - "max_survival": sort by ``surv_col`` descending. If ``surv_col`` is falsy or
        not a column of ``df_pred``, this falls back to the combined ranking below.
      - "min_gvhd_max_survival" (also used for any unrecognised objective):
        ``gvhd_col`` ascending, ties broken by ``surv_col`` descending when that
        column is present.

    Rows still tied after all sort keys keep their input order (stable sort), so
    the ranking is deterministic. Missing values are placed last.

    Returns:
        A new, sorted DataFrame with a fresh 0..n-1 index; ``df_pred`` is not modified.

    Raises:
        KeyError: if a sort column required by the chosen objective is missing
            (e.g. ``gvhd_col`` not in ``df_pred``).
    """
    has_surv = bool(surv_col) and surv_col in df_pred.columns

    if objective == "min_gvhd":
        sort_cols, ascending = [gvhd_col], [True]
    elif objective == "max_survival" and has_surv:
        sort_cols, ascending = [surv_col], [False]
    else:
        sort_cols, ascending = [gvhd_col], [True]
        if has_surv:
            sort_cols.append(surv_col)
            ascending.append(False)

    # mergesort is stable, so equally ranked scenarios keep their input order; the
    # default quicksort gives no such guarantee for a single sort key.
    return df_pred.sort_values(
        sort_cols, ascending=ascending, kind="mergesort", ignore_index=True
    )