yunuk0 commited on
Commit
f2fdddb
·
verified ·
1 Parent(s): a3b0f47

Update utils/inference_ESM2.py

Browse files
Files changed (1) hide show
  1. utils/inference_ESM2.py +12 -21
utils/inference_ESM2.py CHANGED
@@ -56,7 +56,7 @@ def esm2_embed_batch(seqs, tokenizer, model, device, batch_size=8):
56
  @torch.no_grad()
57
  def get_final_score(
58
  epitope,
59
- metadata_df, # 반드시 ['assay','method','state','disease'] string
60
  tokenizer,
61
  esm_model,
62
  device,
@@ -65,37 +65,28 @@ def get_final_score(
65
  encoder_s2,
66
  threshold_s1=0.0
67
  ):
 
 
 
 
 
68
  # 1. embedding
69
- emb = esm2_embed_batch(
70
- [epitope],
71
- tokenizer,
72
- esm_model,
73
- device
74
- )
75
  df_emb = pd.DataFrame(emb)
76
 
77
- # ------------------
78
- # Stage 1 (CatBoost)
79
- # ------------------
80
- X_s1 = pd.concat(
81
- [df_emb, metadata_df[['assay','method','state','disease']]],
82
- axis=1
83
- )
84
-
85
  p1 = model_s1.predict_proba(X_s1)[0, 1]
86
 
87
  if p1 < threshold_s1:
88
  return p1, None, None
89
 
90
- # ------------------
91
- # Stage 2 (XGBoost)
92
- # ------------------
93
- X_cat_s2 = encoder_s2.transform(
94
- metadata_df[['assay','method','state','disease']]
95
- )
96
  X_s2 = np.hstack([df_emb.values, X_cat_s2])
97
 
98
  p2 = model_s2.predict_proba(X_s2)[0, 1]
99
  final = 0.4 * p1 + 0.6 * p2
100
 
101
  return p1, p2, final
 
 
56
  @torch.no_grad()
57
  def get_final_score(
58
  epitope,
59
+ metadata_df,
60
  tokenizer,
61
  esm_model,
62
  device,
 
65
  encoder_s2,
66
  threshold_s1=0.0
67
  ):
68
+ FEATURE_ORDER = ['assay', 'method', 'state', 'disease']
69
+
70
+ # 🔒 컬럼 순서 강제 (Streamlit rerun 방어)
71
+ metadata_df = metadata_df[FEATURE_ORDER]
72
+
73
  # 1. embedding
74
+ emb = esm2_embed_batch([epitope], tokenizer, esm_model, device)
 
 
 
 
 
75
  df_emb = pd.DataFrame(emb)
76
 
77
+ # Stage 1
78
+ X_s1 = pd.concat([df_emb, metadata_df], axis=1)
 
 
 
 
 
 
79
  p1 = model_s1.predict_proba(X_s1)[0, 1]
80
 
81
  if p1 < threshold_s1:
82
  return p1, None, None
83
 
84
+ # Stage 2
85
+ X_cat_s2 = encoder_s2.transform(metadata_df)
 
 
 
 
86
  X_s2 = np.hstack([df_emb.values, X_cat_s2])
87
 
88
  p2 = model_s2.predict_proba(X_s2)[0, 1]
89
  final = 0.4 * p1 + 0.6 * p2
90
 
91
  return p1, p2, final
92
+