ynuozhang committed on
Commit
83f5778
·
1 Parent(s): a164d37

update models

Browse files
Files changed (1) hide show
  1. inference.py +39 -13
inference.py CHANGED
@@ -113,7 +113,8 @@ MODEL_ALIAS = {
113
  "XGB": "xgb",
114
  "XGB_REG": "xgb_reg",
115
  "POOLED": "pooled",
116
- "UNPOOLED": "unpooled"
 
117
  }
118
  def canon_model(label: Optional[str]) -> Optional[str]:
119
  if label is None:
@@ -719,15 +720,25 @@ class PeptiVersePredictor:
719
  self.models[(prop_key, mode)] = obj
720
  else:
721
  # rebuild NN architecture
722
- self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device)
 
 
 
 
 
 
 
 
723
 
724
  self.meta[(prop_key, mode)] = {
725
- "task_type": row.task_type,
726
- "threshold": thr,
727
- "artifact": str(art),
728
- "model_name": m,
729
- "kind": kind,
730
- }
 
 
731
 
732
  def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
733
  """
@@ -783,6 +794,14 @@ class PeptiVersePredictor:
783
  X, M = self._get_features_for_model(prop_key, mode, input_str)
784
  with torch.no_grad():
785
  y = model(X, M).squeeze().float().cpu().item()
 
 
 
 
 
 
 
 
786
  if task_type == "classifier":
787
  prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
788
  out = {"property": prop_key, "mode": mode, "score": prob}
@@ -793,15 +812,22 @@ class PeptiVersePredictor:
793
  else:
794
  return {"property": prop_key, "mode": mode, "score": float(y)}
795
 
796
- # xgb path
797
  if kind == "xgb":
798
- feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
799
  dmat = xgb.DMatrix(feats)
800
  pred = float(model.predict(dmat)[0])
 
 
 
 
 
 
 
 
 
 
801
  out = {"property": prop_key, "mode": mode, "score": pred}
802
- if task_type == "classifier" and thr is not None:
803
- out["label"] = int(pred >= float(thr))
804
- out["threshold"] = float(thr)
805
  return out
806
 
807
  # joblib path (svm/enet/svr)
 
113
  "XGB": "xgb",
114
  "XGB_REG": "xgb_reg",
115
  "POOLED": "pooled",
116
+ "UNPOOLED": "unpooled",
117
+ "TRANSFORMER_WT_LOG": "transformer_wt_log",
118
  }
119
  def canon_model(label: Optional[str]) -> Optional[str]:
120
  if label is None:
 
720
  self.models[(prop_key, mode)] = obj
721
  else:
722
  # rebuild NN architecture
723
+ arch = m
724
+ if arch.startswith("transformer"):
725
+ arch = "transformer"
726
+ elif arch.startswith("mlp"):
727
+ arch = "mlp"
728
+ elif arch.startswith("cnn"):
729
+ arch = "cnn"
730
+
731
+ self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device)
732
 
733
  self.meta[(prop_key, mode)] = {
734
+ "task_type": row.task_type,
735
+ "threshold": thr,
736
+ "artifact": str(art),
737
+ "model_name": m,
738
+ "arch_name": arch,
739
+ "kind": kind,
740
+ }
741
+
742
 
743
  def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
744
  """
 
794
  X, M = self._get_features_for_model(prop_key, mode, input_str)
795
  with torch.no_grad():
796
  y = model(X, M).squeeze().float().cpu().item()
797
+ # invert log1p(hours) ONLY for WT half-life log models
798
+ model_name = meta.get("model_name", "")
799
+ if (
800
+ prop_key == "halflife"
801
+ and mode == "wt"
802
+ and model_name in {"xgb_wt_log", "transformer_wt_log"}
803
+ ):
804
+ y = float(np.expm1(y))
805
  if task_type == "classifier":
806
  prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
807
  out = {"property": prop_key, "mode": mode, "score": prob}
 
812
  else:
813
  return {"property": prop_key, "mode": mode, "score": float(y)}
814
 
 
815
  if kind == "xgb":
816
+ feats = self._get_features_for_model(prop_key, mode, input_str)
817
  dmat = xgb.DMatrix(feats)
818
  pred = float(model.predict(dmat)[0])
819
+
820
+ # invert log1p(hours) ONLY for WT half-life log models
821
+ model_name = meta.get("model_name", "")
822
+ if (
823
+ prop_key == "halflife"
824
+ and mode == "wt"
825
+ and model_name in {"xgb_wt_log", "transformer_wt_log"}
826
+ ):
827
+ pred = float(np.expm1(pred))
828
+
829
  out = {"property": prop_key, "mode": mode, "score": pred}
830
+
 
 
831
  return out
832
 
833
  # joblib path (svm/enet/svr)