mfarnas commited on
Commit
caa7bf6
·
1 Parent(s): 6df7f66

fix tgt col based on model selected

Browse files
src/inference_utils.py CHANGED
@@ -16,10 +16,11 @@ def compute_metrics(y_true, y_pred_proba, threshold=0.5):
16
 
17
  def add_predictions(df, probs):
18
  df['Predicted Probability'] = probs
19
- df['GVHD Prediction'] = ['POSITIVE' if p > 0.5 else 'NEGATIVE' for p in probs]
 
 
 
20
 
21
- df_with_gt = df[['Predicted Probability', 'GVHD Prediction']].join(st.session_state.targets_df)
22
-
23
  # Define cell-level styling
24
  def highlight_prediction(val):
25
  if val == "POSITIVE":
@@ -31,7 +32,7 @@ def add_predictions(df, probs):
31
  # Apply color and alignment
32
  df_styled = (
33
  df_with_gt.style
34
- .applymap(highlight_prediction, subset=["GVHD Prediction"])
35
  .set_properties(**{'text-align': 'center'}) # Apply center alignment to all cells
36
  )
37
 
 
16
 
17
  def add_predictions(df, probs):
18
  df['Predicted Probability'] = probs
19
+ pred_col = f"{st.session_state.target_col} Prediction"
20
+ df[pred_col] = ['POSITIVE' if p > 0.5 else 'NEGATIVE' for p in probs]
21
+
22
+ df_with_gt = df[['Predicted Probability', pred_col]].join(st.session_state.targets_df)
23
 
 
 
24
  # Define cell-level styling
25
  def highlight_prediction(val):
26
  if val == "POSITIVE":
 
32
  # Apply color and alignment
33
  df_styled = (
34
  df_with_gt.style
35
+ .applymap(highlight_prediction, subset=[pred_col])
36
  .set_properties(**{'text-align': 'center'}) # Apply center alignment to all cells
37
  )
38
 
src/pages/1_Individual_Predictions.py CHANGED
@@ -190,14 +190,17 @@ if submitted:
190
  if "ensemble" in st.session_state.selected_model:
191
  # ensemble prediction
192
  models = load_model_ensemble(st.session_state.selected_model)
 
193
  models = models["model"]
194
  pred = ensemble_predict(models, X, cat_features)
195
  else:
196
  # single model prediction
197
  model = load_model(st.session_state.selected_model)
 
198
  model = model["model"]
199
  pred = model.predict_proba(X)[0][1]
200
 
 
201
  result_df = pd.DataFrame()
202
  result_df = add_predictions(result_df, [pred])
203
 
 
190
  if "ensemble" in st.session_state.selected_model:
191
  # ensemble prediction
192
  models = load_model_ensemble(st.session_state.selected_model)
193
+ st.session_state.target_col = models.get("target_col", "UNKNOWN")
194
  models = models["model"]
195
  pred = ensemble_predict(models, X, cat_features)
196
  else:
197
  # single model prediction
198
  model = load_model(st.session_state.selected_model)
199
+ st.session_state.target_col = model.get("target_col", "UNKNOWN")
200
  model = model["model"]
201
  pred = model.predict_proba(X)[0][1]
202
 
203
+ st.warning(f"The model selected will only predict the target \"{st.session_state.target_col}\". Please choose a different model if you want to predict a different target.")
204
  result_df = pd.DataFrame()
205
  result_df = add_predictions(result_df, [pred])
206
 
src/pages/2_Bulk_Predictions.py CHANGED
@@ -9,6 +9,29 @@ from sidebar import sidebar
9
  # Initialize sidebar
10
  sidebar()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  st.title("📊 Bulk Patient Predictions")
13
 
14
  uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
@@ -27,41 +50,25 @@ if uploaded_file:
27
 
28
  # TODO: Define the target column (customize this based on your use case)
29
  # target_col = st.session_state.target_col # "GVHD" # or "Acute GVHD(<100 days)", etc.
30
- st.session_state.target_col = st.selectbox(
31
- "Select target column to predict:",
32
- options=[
33
- "GVHD",
34
- "Acute GVHD(<100 days)",
35
- "Chronic GVHD>100 days",
36
- ],
37
- index=0
38
- )
 
 
39
 
40
  if st.button("Predict"):
41
  if "bulk_input_df" not in st.session_state:
42
  st.warning("Please preprocess data first.")
43
  else:
44
- if "ensemble" in st.session_state.selected_model:
45
- # ensemble model
46
- ensemble = True
47
- try:
48
- ensemble_data = load_model_ensemble(st.session_state.selected_model)
49
- st.session_state.trained_models = ensemble_data["model"]
50
- models = ensemble_data["model"]
51
- st.session_state.best_iterations = ensemble_data.get("best_iterations", [])
52
- st.session_state.fold_scores = ensemble_data.get("fold_scores", [])
53
-
54
- except Exception as e:
55
- st.error(f"Error loading ensemble: {str(e)}")
56
- else:
57
- # single model
58
- ensemble = False
59
- model_dict = load_model(st.session_state.selected_model)
60
- model = model_dict["model"]
61
-
62
  df = st.session_state.bulk_input_df
63
 
64
- target_col = st.session_state.target_col
65
 
66
  # Optional filtering depending on target choice
67
  if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
 
9
  # Initialize sidebar
10
  sidebar()
11
 
12
+ if "selected_model" in st.session_state:
13
+ if "ensemble" in st.session_state.selected_model:
14
+ # ensemble model
15
+ ensemble = True
16
+ try:
17
+ ensemble_data = load_model_ensemble(st.session_state.selected_model)
18
+ st.session_state.trained_models = ensemble_data["model"]
19
+ models = ensemble_data["model"]
20
+ st.session_state.best_iterations = ensemble_data.get("best_iterations", [])
21
+ st.session_state.fold_scores = ensemble_data.get("fold_scores", [])
22
+ target_col = ensemble_data.get("target_col", "UNKNOWN")
23
+
24
+ except Exception as e:
25
+ st.error(f"Error loading ensemble: {str(e)}")
26
+ else:
27
+ # single model
28
+ ensemble = False
29
+ model_dict = load_model(st.session_state.selected_model)
30
+ model = model_dict["model"]
31
+ target_col = model_dict.get("target_col", "UNKNOWN")
32
+
33
+ st.warning(f"The model selected will only predict the target \"{target_col}\". Please choose a different model if you want to predict a different target.")
34
+
35
  st.title("📊 Bulk Patient Predictions")
36
 
37
  uploaded_file = st.file_uploader("Upload CSV", type=["csv"])
 
50
 
51
  # TODO: Define the target column (customize this based on your use case)
52
  # target_col = st.session_state.target_col # "GVHD" # or "Acute GVHD(<100 days)", etc.
53
+
54
+ # # allow dynamic selection for target column
55
+ # st.session_state.target_col = st.selectbox(
56
+ # "Select target column to predict:",
57
+ # options=[
58
+ # "GVHD",
59
+ # "Acute GVHD(<100 days)",
60
+ # "Chronic GVHD>100 days",
61
+ # ],
62
+ # index=0
63
+ # )
64
 
65
  if st.button("Predict"):
66
  if "bulk_input_df" not in st.session_state:
67
  st.warning("Please preprocess data first.")
68
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  df = st.session_state.bulk_input_df
70
 
71
+ # target_col = st.session_state.target_col
72
 
73
  # Optional filtering depending on target choice
74
  if target_col in ["Acute GVHD(<100 days)", "Chronic GVHD>100 days"]: