File size: 1,924 Bytes
ab6248d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import pandas as pd
import itertools
from typing import Dict, List, Optional

def build_scenarios(
    baseline_row: pd.Series,
    variable_options: Dict[str, List],
    max_scenarios: int = 2000
) -> pd.DataFrame:
    """
    Create a scenario DataFrame by varying selected variables around a baseline patient row.

    Each combination of candidate values (Cartesian product, iterated in the
    key order of ``variable_options``) yields one row: the baseline values with
    the varied variables overwritten. Variables absent from the baseline are
    added as new columns. An empty ``variable_options`` yields a single row
    equal to the baseline.

    Args:
        baseline_row: one patient's input row (Series).
        variable_options: {variable_name: [candidate_values]}.
        max_scenarios: upper bound on the number of rows. Combinations beyond
            this bound are dropped; the first ``max_scenarios`` in product
            order are kept.

    Returns:
        DataFrame with one row per scenario and a fresh 0..n-1 index. Columns
        are the baseline labels followed by any new variables; dtypes are
        inferred per column.

    Raises:
        ValueError: if ``baseline_row`` is None/empty or ``max_scenarios`` is negative.
    """
    if baseline_row is None or baseline_row.empty:
        raise ValueError("baseline_row is required (single patient input row).")
    if max_scenarios < 0:
        raise ValueError("max_scenarios must be non-negative.")

    variables = list(variable_options.keys())
    choices = [variable_options[v] for v in variables]

    # islice stops after max_scenarios combinations, so the full product (whose
    # size multiplies with every added variable) is never materialized.
    combos = itertools.islice(itertools.product(*choices), max_scenarios)

    # Building plain dicts (instead of mutating Series copies) lets pandas infer
    # each column's dtype and avoids incompatible-dtype assignment into the row.
    base = baseline_row.to_dict()
    rows = [{**base, **dict(zip(variables, combo))} for combo in combos]

    # Fixed column order keeps the layout stable even when no row is produced.
    columns = list(base) + [v for v in variables if v not in base]
    return pd.DataFrame(rows, columns=columns)


def rank_scenarios(
    df_pred: pd.DataFrame,
    gvhd_col: str = "pred_aGVHD",
    surv_col: Optional[str] = "surv_1y",
    objective: str = "min_gvhd_max_survival"
) -> pd.DataFrame:
    """
    Rank scenarios according to an objective.

    Objectives:
    - min_gvhd: sort by gvhd_col ascending.
    - max_survival: sort by surv_col descending. If surv_col is None or not a
      column of df_pred, falls back to the min_gvhd_max_survival ordering
      (which then reduces to gvhd_col ascending).
    - min_gvhd_max_survival: gvhd_col ascending, ties broken by surv_col
      descending when that column is present.

    Sorting is stable, so tied scenarios keep their input order. NaN values
    are placed last (pandas ``sort_values`` default).

    Args:
        df_pred: scenarios with prediction columns; not modified.
        gvhd_col: column holding the GVHD prediction.
        surv_col: optional column holding the survival prediction.
        objective: one of the objectives listed above.

    Returns:
        A new, sorted DataFrame with a fresh 0..n-1 index.

    Raises:
        ValueError: if ``objective`` is not a recognised objective.
        KeyError: if a required sort column is missing from df_pred.
    """
    valid_objectives = ("min_gvhd", "max_survival", "min_gvhd_max_survival")
    if objective not in valid_objectives:
        raise ValueError(
            f"Unknown objective {objective!r}; expected one of {valid_objectives}."
        )

    has_surv = bool(surv_col) and surv_col in df_pred.columns

    if objective == "min_gvhd":
        sort_cols, ascending = [gvhd_col], [True]
    elif objective == "max_survival" and has_surv:
        sort_cols, ascending = [surv_col], [False]
    else:
        sort_cols, ascending = [gvhd_col], [True]
        if has_surv:
            sort_cols.append(surv_col)
            ascending.append(False)

    # mergesort is stable, so ties keep their original (generation) order and
    # the ranking is deterministic; multi-column sorts are already stable.
    ranked = df_pred.sort_values(sort_cols, ascending=ascending, kind="mergesort")
    return ranked.reset_index(drop=True)