Update utils/inference_ESM2.py
Browse files- utils/inference_ESM2.py +12 -21
utils/inference_ESM2.py
CHANGED
|
@@ -56,7 +56,7 @@ def esm2_embed_batch(seqs, tokenizer, model, device, batch_size=8):
|
|
| 56 |
@torch.no_grad()
|
| 57 |
def get_final_score(
|
| 58 |
epitope,
|
| 59 |
-
metadata_df,
|
| 60 |
tokenizer,
|
| 61 |
esm_model,
|
| 62 |
device,
|
|
@@ -65,37 +65,28 @@ def get_final_score(
|
|
| 65 |
encoder_s2,
|
| 66 |
threshold_s1=0.0
|
| 67 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# 1. embedding
|
| 69 |
-
emb = esm2_embed_batch(
|
| 70 |
-
[epitope],
|
| 71 |
-
tokenizer,
|
| 72 |
-
esm_model,
|
| 73 |
-
device
|
| 74 |
-
)
|
| 75 |
df_emb = pd.DataFrame(emb)
|
| 76 |
|
| 77 |
-
#
|
| 78 |
-
|
| 79 |
-
# ------------------
|
| 80 |
-
X_s1 = pd.concat(
|
| 81 |
-
[df_emb, metadata_df[['assay','method','state','disease']]],
|
| 82 |
-
axis=1
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
p1 = model_s1.predict_proba(X_s1)[0, 1]
|
| 86 |
|
| 87 |
if p1 < threshold_s1:
|
| 88 |
return p1, None, None
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
|
| 92 |
-
# ------------------
|
| 93 |
-
X_cat_s2 = encoder_s2.transform(
|
| 94 |
-
metadata_df[['assay','method','state','disease']]
|
| 95 |
-
)
|
| 96 |
X_s2 = np.hstack([df_emb.values, X_cat_s2])
|
| 97 |
|
| 98 |
p2 = model_s2.predict_proba(X_s2)[0, 1]
|
| 99 |
final = 0.4 * p1 + 0.6 * p2
|
| 100 |
|
| 101 |
return p1, p2, final
|
|
|
|
|
|
| 56 |
@torch.no_grad()
|
| 57 |
def get_final_score(
|
| 58 |
epitope,
|
| 59 |
+
metadata_df,
|
| 60 |
tokenizer,
|
| 61 |
esm_model,
|
| 62 |
device,
|
|
|
|
| 65 |
encoder_s2,
|
| 66 |
threshold_s1=0.0
|
| 67 |
):
|
| 68 |
+
FEATURE_ORDER = ['assay', 'method', 'state', 'disease']
|
| 69 |
+
|
| 70 |
+
# 🔒 컬럼 순서 강제 (Streamlit rerun 방어)
|
| 71 |
+
metadata_df = metadata_df[FEATURE_ORDER]
|
| 72 |
+
|
| 73 |
# 1. embedding
|
| 74 |
+
emb = esm2_embed_batch([epitope], tokenizer, esm_model, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
df_emb = pd.DataFrame(emb)
|
| 76 |
|
| 77 |
+
# Stage 1
|
| 78 |
+
X_s1 = pd.concat([df_emb, metadata_df], axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
p1 = model_s1.predict_proba(X_s1)[0, 1]
|
| 80 |
|
| 81 |
if p1 < threshold_s1:
|
| 82 |
return p1, None, None
|
| 83 |
|
| 84 |
+
# Stage 2
|
| 85 |
+
X_cat_s2 = encoder_s2.transform(metadata_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
X_s2 = np.hstack([df_emb.values, X_cat_s2])
|
| 87 |
|
| 88 |
p2 = model_s2.predict_proba(X_s2)[0, 1]
|
| 89 |
final = 0.4 * p1 + 0.6 * p2
|
| 90 |
|
| 91 |
return p1, p2, final
|
| 92 |
+
|