Spaces:
Running
Running
ynuozhang
commited on
Commit
·
83f5778
1
Parent(s):
a164d37
update models
Browse files- inference.py +39 -13
inference.py
CHANGED
|
@@ -113,7 +113,8 @@ MODEL_ALIAS = {
|
|
| 113 |
"XGB": "xgb",
|
| 114 |
"XGB_REG": "xgb_reg",
|
| 115 |
"POOLED": "pooled",
|
| 116 |
-
"UNPOOLED": "unpooled"
|
|
|
|
| 117 |
}
|
| 118 |
def canon_model(label: Optional[str]) -> Optional[str]:
|
| 119 |
if label is None:
|
|
@@ -719,15 +720,25 @@ class PeptiVersePredictor:
|
|
| 719 |
self.models[(prop_key, mode)] = obj
|
| 720 |
else:
|
| 721 |
# rebuild NN architecture
|
| 722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
|
| 724 |
self.meta[(prop_key, mode)] = {
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
|
|
|
|
|
|
| 731 |
|
| 732 |
def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
|
| 733 |
"""
|
|
@@ -783,6 +794,14 @@ class PeptiVersePredictor:
|
|
| 783 |
X, M = self._get_features_for_model(prop_key, mode, input_str)
|
| 784 |
with torch.no_grad():
|
| 785 |
y = model(X, M).squeeze().float().cpu().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
if task_type == "classifier":
|
| 787 |
prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
|
| 788 |
out = {"property": prop_key, "mode": mode, "score": prob}
|
|
@@ -793,15 +812,22 @@ class PeptiVersePredictor:
|
|
| 793 |
else:
|
| 794 |
return {"property": prop_key, "mode": mode, "score": float(y)}
|
| 795 |
|
| 796 |
-
# xgb path
|
| 797 |
if kind == "xgb":
|
| 798 |
-
feats = self._get_features_for_model(prop_key, mode, input_str)
|
| 799 |
dmat = xgb.DMatrix(feats)
|
| 800 |
pred = float(model.predict(dmat)[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
out = {"property": prop_key, "mode": mode, "score": pred}
|
| 802 |
-
|
| 803 |
-
out["label"] = int(pred >= float(thr))
|
| 804 |
-
out["threshold"] = float(thr)
|
| 805 |
return out
|
| 806 |
|
| 807 |
# joblib path (svm/enet/svr)
|
|
|
|
| 113 |
"XGB": "xgb",
|
| 114 |
"XGB_REG": "xgb_reg",
|
| 115 |
"POOLED": "pooled",
|
| 116 |
+
"UNPOOLED": "unpooled",
|
| 117 |
+
"TRANSFORMER_WT_LOG": "transformer_wt_log",
|
| 118 |
}
|
| 119 |
def canon_model(label: Optional[str]) -> Optional[str]:
|
| 120 |
if label is None:
|
|
|
|
| 720 |
self.models[(prop_key, mode)] = obj
|
| 721 |
else:
|
| 722 |
# rebuild NN architecture
|
| 723 |
+
arch = m
|
| 724 |
+
if arch.startswith("transformer"):
|
| 725 |
+
arch = "transformer"
|
| 726 |
+
elif arch.startswith("mlp"):
|
| 727 |
+
arch = "mlp"
|
| 728 |
+
elif arch.startswith("cnn"):
|
| 729 |
+
arch = "cnn"
|
| 730 |
+
|
| 731 |
+
self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device)
|
| 732 |
|
| 733 |
self.meta[(prop_key, mode)] = {
|
| 734 |
+
"task_type": row.task_type,
|
| 735 |
+
"threshold": thr,
|
| 736 |
+
"artifact": str(art),
|
| 737 |
+
"model_name": m,
|
| 738 |
+
"arch_name": arch,
|
| 739 |
+
"kind": kind,
|
| 740 |
+
}
|
| 741 |
+
|
| 742 |
|
| 743 |
def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
|
| 744 |
"""
|
|
|
|
| 794 |
X, M = self._get_features_for_model(prop_key, mode, input_str)
|
| 795 |
with torch.no_grad():
|
| 796 |
y = model(X, M).squeeze().float().cpu().item()
|
| 797 |
+
# invert log1p(hours) ONLY for WT half-life log models
|
| 798 |
+
model_name = meta.get("model_name", "")
|
| 799 |
+
if (
|
| 800 |
+
prop_key == "halflife"
|
| 801 |
+
and mode == "wt"
|
| 802 |
+
and model_name in {"xgb_wt_log", "transformer_wt_log"}
|
| 803 |
+
):
|
| 804 |
+
y = float(np.expm1(y))
|
| 805 |
if task_type == "classifier":
|
| 806 |
prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
|
| 807 |
out = {"property": prop_key, "mode": mode, "score": prob}
|
|
|
|
| 812 |
else:
|
| 813 |
return {"property": prop_key, "mode": mode, "score": float(y)}
|
| 814 |
|
|
|
|
| 815 |
if kind == "xgb":
|
| 816 |
+
feats = self._get_features_for_model(prop_key, mode, input_str)
|
| 817 |
dmat = xgb.DMatrix(feats)
|
| 818 |
pred = float(model.predict(dmat)[0])
|
| 819 |
+
|
| 820 |
+
# invert log1p(hours) ONLY for WT half-life log models
|
| 821 |
+
model_name = meta.get("model_name", "")
|
| 822 |
+
if (
|
| 823 |
+
prop_key == "halflife"
|
| 824 |
+
and mode == "wt"
|
| 825 |
+
and model_name in {"xgb_wt_log", "transformer_wt_log"}
|
| 826 |
+
):
|
| 827 |
+
pred = float(np.expm1(pred))
|
| 828 |
+
|
| 829 |
out = {"property": prop_key, "mode": mode, "score": pred}
|
| 830 |
+
|
|
|
|
|
|
|
| 831 |
return out
|
| 832 |
|
| 833 |
# joblib path (svm/enet/svr)
|