lhallee committed on
Commit
bc7152f
·
verified ·
1 Parent(s): cacf282

Upload modeling_fast_esmfold.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fast_esmfold.py +195 -52
modeling_fast_esmfold.py CHANGED
@@ -665,6 +665,59 @@ class FastEsmBackbone(nn.Module):
665
  _ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
666
 
667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
  @dataclass
669
  class TTTConfig:
670
  lr: float = 4e-4
@@ -683,7 +736,7 @@ class TTTConfig:
683
  freeze_embeddings: bool = True
684
  lora_rank: int = 8
685
  lora_alpha: float = 32.0
686
- lora_target_modules: Tuple[str, ...] = ("query", "key", "value")
687
 
688
  def verify(self) -> None:
689
  assert self.lr > 0.0, "TTT learning rate must be positive."
@@ -761,16 +814,19 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
761
  super().__init__(config)
762
 
763
  # Replace standard ESM2 backbone with FastESM2 (multi-backend attention)
764
- self.esm = FastEsmBackbone(config)
765
- self.esm.requires_grad_(False)
766
- if config.esmfold_config.fp16_esm:
767
- self.esm.half()
 
 
768
 
769
  # MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear)
770
  self.mlm_head = EsmLMHead(config)
771
 
772
  # TTT state (lazy initialization)
773
- self._ttt_cfg = TTTConfig(**config.ttt_config)
 
774
  self._ttt_cfg.verify()
775
  self._ttt_initialized = False
776
  self._ttt_initial_state = None
@@ -800,6 +856,9 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
800
  self.mlm_head.eval()
801
  for p in self.mlm_head.parameters():
802
  p.requires_grad = False
 
 
 
803
  self._inject_lora()
804
  else:
805
  # Legacy path: jointly-trained random linear projection head
@@ -816,25 +875,32 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
816
  return self._ttt_cfg.lora_rank > 0
817
 
818
  def _inject_lora(self) -> None:
819
- from peft import LoraConfig, inject_adapter_in_model
820
-
821
- lora_config = LoraConfig(
 
822
  r=self._ttt_cfg.lora_rank,
823
- lora_alpha=self._ttt_cfg.lora_alpha,
824
- target_modules=list(self._ttt_cfg.lora_target_modules),
825
- lora_dropout=0.0,
826
- bias="none",
 
827
  )
828
- inject_adapter_in_model(lora_config, self.esm, adapter_name="ttt")
829
 
830
  # ---- TTT State Management ----
831
 
 
 
 
 
832
  def _ttt_get_state(self) -> Dict[str, Any]:
833
  if self._uses_lora:
834
- lora_state = {
835
- k: v.clone() for k, v in self.esm.state_dict().items()
836
- if "lora_" in k
837
- }
 
 
838
  return {"_lora_state": lora_state}
839
  return {
840
  "esm": copy.deepcopy(self.esm),
@@ -843,9 +909,11 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
843
 
844
  def _ttt_set_state(self, state: Dict[str, Any]) -> None:
845
  if "_lora_state" in state:
846
- current_state = self.esm.state_dict()
847
- current_state.update(state["_lora_state"])
848
- self.esm.load_state_dict(current_state)
 
 
849
  return
850
  if "esm" in state:
851
  self.esm = copy.deepcopy(state["esm"])
@@ -993,11 +1061,9 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
993
 
994
  for parameter in self.parameters():
995
  parameter.requires_grad = False
996
- for name, parameter in self.esm.named_parameters():
997
- if "lora_" in name:
998
- parameter.requires_grad = True
999
- lora_params = [p for n, p in self.esm.named_parameters() if "lora_" in n]
1000
- optimizer = self._ttt_get_optimizer(iter(lora_params))
1001
  optimizer.zero_grad(set_to_none=True)
1002
 
1003
  self.eval()
@@ -1097,49 +1163,126 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
1097
 
1098
  # ---- High-Level API ----
1099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  def fold_protein(
1101
  self,
1102
  sequence: str,
1103
- ttt: bool = False,
1104
  return_pdb_string: bool = True,
1105
  ) -> Dict[str, Any]:
1106
- """Fold a protein sequence, optionally with test-time training.
 
 
 
 
1107
 
1108
  Args:
1109
  sequence: Protein sequence (single-letter amino acid codes)
1110
- ttt: If True, run test-time training before folding (improves accuracy)
1111
  return_pdb_string: If True, include PDB string in output
1112
 
1113
  Returns:
1114
  Dict with keys:
1115
- - plddt: float, mean per-residue pLDDT confidence score
1116
- - ptm: float, predicted TM-score
1117
- - pdb_string: str (if return_pdb_string=True), PDB format structure
1118
- - ttt_losses: list[float] (if ttt=True), per-step MLM losses
 
1119
  """
1120
- result: Dict[str, Any] = {}
1121
 
1122
- if ttt:
1123
- ttt_result = self.ttt(sequence)
1124
- result["ttt_losses"] = ttt_result["losses"]
 
 
1125
 
1126
- with torch.no_grad():
1127
- output = self.infer(sequence)
1128
 
1129
- plddt = output["plddt"]
1130
- if plddt.dim() >= 2:
1131
- mean_plddt = float(plddt.mean(dim=-1).mean().item())
1132
- else:
1133
- mean_plddt = float(plddt.mean().item())
1134
 
1135
- result["plddt"] = mean_plddt
1136
- result["ptm"] = float(output["ptm"].item()) if "ptm" in output else None
 
1137
 
1138
- if return_pdb_string:
1139
- pdb_strings = self.output_to_pdb(output)
1140
- result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
 
1142
- if ttt:
1143
- self.ttt_reset()
1144
 
1145
- return result
 
 
 
 
 
 
 
 
 
 
 
 
665
  _ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
666
 
667
 
668
class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Replaces an existing nn.Linear with base(x) + lora_up(lora_down(x)) * scale.
    Initialization follows cloneofsimo/lora: down=Normal(0, 1/r), up=zeros, so
    the wrapper is initially an exact identity over the base layer's output.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        assert r <= min(in_features, out_features), f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        # Plain float multiplier on the residual branch (lora_diffusion semantics).
        self.scale = scale
        # down ~ N(0, 1/r), up = 0  ->  residual branch starts at exactly zero.
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base projection plus scaled low-rank residual."""
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale


def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Replace nn.Linear layers inside modules matching target_class_name with LoRA.

    Matches lora_diffusion's inject_trainable_lora behavior: finds all modules whose
    class name matches target_class_name, then replaces their nn.Linear children with
    LoraInjectedLinear (preserving each replaced layer's device and dtype).

    Args:
        model: Root module; mutated in place.
        target_class_name: Class name of parent modules whose Linear children
            should be wrapped (e.g. "EsmSelfAttention").
        r: LoRA rank.
        scale: Multiplier applied to the LoRA residual branch.

    Returns:
        The list of newly created trainable LoRA parameters (down/up weights).
    """
    # Snapshot matching parents BEFORE mutating the tree: replacing children
    # while the named_modules() generator is still traversing is fragile, since
    # the freshly inserted LoraInjectedLinear wrappers (which contain the
    # original Linear) would themselves be visited mid-iteration.
    target_parents = [
        module
        for _parent_name, module in model.named_modules()
        if module.__class__.__name__ == target_class_name
    ]

    lora_params: List[nn.Parameter] = []
    for parent_module in target_parents:
        # list(...) so child replacement does not disturb this iteration either.
        for child_name, child_module in list(parent_module.named_children()):
            if not isinstance(child_module, nn.Linear):
                continue
            lora_linear = LoraInjectedLinear(child_module, r=r, scale=scale)
            # Match the replaced layer's placement so the model stays homogeneous.
            lora_linear = lora_linear.to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, lora_linear)
            lora_params.extend(lora_linear.lora_down.parameters())
            lora_params.extend(lora_linear.lora_up.parameters())
    return lora_params
719
+
720
+
721
  @dataclass
722
  class TTTConfig:
723
  lr: float = 4e-4
 
736
  freeze_embeddings: bool = True
737
  lora_rank: int = 8
738
  lora_alpha: float = 32.0
739
+ lora_target_class: str = "EsmSelfAttention"
740
 
741
  def verify(self) -> None:
742
  assert self.lr > 0.0, "TTT learning rate must be positive."
 
814
  super().__init__(config)
815
 
816
  # Replace standard ESM2 backbone with FastESM2 (multi-backend attention)
817
+ # unless use_standard_backbone is set (for TTT debugging/compatibility)
818
+ if not config.ttt_config.get("use_standard_backbone", False):
819
+ self.esm = FastEsmBackbone(config)
820
+ self.esm.requires_grad_(False)
821
+ if config.esmfold_config.fp16_esm:
822
+ self.esm.half()
823
 
824
  # MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear)
825
  self.mlm_head = EsmLMHead(config)
826
 
827
  # TTT state (lazy initialization)
828
+ ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"}
829
+ self._ttt_cfg = TTTConfig(**ttt_kwargs)
830
  self._ttt_cfg.verify()
831
  self._ttt_initialized = False
832
  self._ttt_initial_state = None
 
856
  self.mlm_head.eval()
857
  for p in self.mlm_head.parameters():
858
  p.requires_grad = False
859
+ # Seed global state before LoRA init for reproducible weight initialization
860
+ if self._ttt_cfg.seed is not None:
861
+ torch.manual_seed(self._ttt_cfg.seed)
862
  self._inject_lora()
863
  else:
864
  # Legacy path: jointly-trained random linear projection head
 
875
  return self._ttt_cfg.lora_rank > 0
876
 
877
def _inject_lora(self) -> None:
    """Wrap the ESM2 backbone's attention Linear layers with LoRA adapters.

    Delegates to ``inject_trainable_lora`` (lora_diffusion-style injection) and
    caches the resulting trainable parameters on ``self._lora_params``.
    """
    cfg = self._ttt_cfg
    params = inject_trainable_lora(
        self.esm,
        target_class_name=cfg.lora_target_class,
        r=cfg.lora_rank,
        scale=cfg.lora_alpha,
    )
    self._lora_params = params
    # An empty result means the configured class name matched nothing.
    assert len(self._lora_params) > 0, (
        f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' "
        f"matches attention modules in the backbone."
    )
 
889
 
890
  # ---- TTT State Management ----
891
 
892
def _get_lora_modules(self) -> List[LoraInjectedLinear]:
    """Collect every LoRA-wrapped linear currently present in the backbone."""
    found: List[LoraInjectedLinear] = []
    for module in self.esm.modules():
        if isinstance(module, LoraInjectedLinear):
            found.append(module)
    return found
895
+
896
  def _ttt_get_state(self) -> Dict[str, Any]:
897
  if self._uses_lora:
898
+ lora_state = []
899
+ for m in self._get_lora_modules():
900
+ lora_state.append({
901
+ "down": m.lora_down.weight.data.clone(),
902
+ "up": m.lora_up.weight.data.clone(),
903
+ })
904
  return {"_lora_state": lora_state}
905
  return {
906
  "esm": copy.deepcopy(self.esm),
 
909
 
910
  def _ttt_set_state(self, state: Dict[str, Any]) -> None:
911
  if "_lora_state" in state:
912
+ modules = self._get_lora_modules()
913
+ assert len(modules) == len(state["_lora_state"])
914
+ for m, saved in zip(modules, state["_lora_state"]):
915
+ m.lora_down.weight.data.copy_(saved["down"])
916
+ m.lora_up.weight.data.copy_(saved["up"])
917
  return
918
  if "esm" in state:
919
  self.esm = copy.deepcopy(state["esm"])
 
1061
 
1062
  for parameter in self.parameters():
1063
  parameter.requires_grad = False
1064
+ for p in self._lora_params:
1065
+ p.requires_grad = True
1066
+ optimizer = self._ttt_get_optimizer(self._lora_params)
 
 
1067
  optimizer.zero_grad(set_to_none=True)
1068
 
1069
  self.eval()
 
1163
 
1164
  # ---- High-Level API ----
1165
 
1166
+ def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]:
1167
+ """Fold a sequence once and return pLDDT, ptm, and optionally PDB string."""
1168
+ with torch.no_grad():
1169
+ output = self.infer(sequence)
1170
+ plddt = output["plddt"]
1171
+ if plddt.dim() >= 2:
1172
+ mean_plddt = float(plddt.mean(dim=-1).mean().item())
1173
+ else:
1174
+ mean_plddt = float(plddt.mean().item())
1175
+ result = {
1176
+ "plddt": mean_plddt,
1177
+ "ptm": float(output["ptm"].item()) if "ptm" in output else None,
1178
+ }
1179
+ if return_pdb_string:
1180
+ pdb_strings = self.output_to_pdb(output)
1181
+ result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings
1182
+ return result
1183
+
1184
def fold_protein(
    self,
    sequence: str,
    return_pdb_string: bool = True,
) -> Dict[str, Any]:
    """Fold a protein sequence with test-time training.

    Runs TTT (masked language model adaptation via LoRA) for the configured
    number of steps, re-folding after every optimizer step so the trajectory of
    pLDDT scores is tracked, and returns the structure from the best-scoring
    step. The un-adapted model counts as step 0 (the baseline).

    Args:
        sequence: Protein sequence (single-letter amino acid codes)
        return_pdb_string: If True, include PDB string in output

    Returns:
        Dict with keys:
        - plddt: float, best mean pLDDT across all TTT steps
        - ptm: float, predicted TM-score from best step
        - pdb_string: str (if return_pdb_string=True), PDB from best step
        - step_plddts: list[float], pLDDT at each step [baseline, s1, ...]
        - best_step: int, which step produced best structure (0=baseline)
    """
    self._ensure_ttt_ready()

    # TTT is numerically fragile in half precision; temporarily train in fp32.
    esm_dtype = next(self.esm.parameters()).dtype
    if esm_dtype != torch.float32:
        self.esm.float()
        self.mlm_head.float()

    device = next(self.parameters()).device
    non_blocking = device.type == "cuda"

    # Step 0: fold once without any adaptation to establish the baseline.
    best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
    step_plddts = [best["plddt"]]

    if self._ttt_cfg.steps > 0:
        # Tokenize once; masked batches are resampled from this each iteration.
        x = self._ttt_tokenize(sequence)

        # Freeze everything, then re-enable only what TTT is allowed to adapt.
        for p in self.parameters():
            p.requires_grad = False
        if self._uses_lora:
            for p in self._lora_params:
                p.requires_grad = True
            optimizer = self._ttt_get_optimizer(self._lora_params)
        else:
            for p in self.esm.parameters():
                p.requires_grad = True
            if self._ttt_cfg.freeze_embeddings:
                for p in self.esm.embeddings.parameters():
                    p.requires_grad = False
            for p in self._ttt_lm_proj.parameters():
                p.requires_grad = True
            trainable = [p for p in self.parameters() if p.requires_grad]
            optimizer = self._ttt_get_optimizer(trainable)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        # ags = gradient-accumulation steps; one optimizer update per `ags` batches.
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, _start = self._ttt_sample_batch(x)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()

            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                # Fold with the freshly updated weights; keep the best so far.
                self.eval()
                current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
                step_plddts.append(current["plddt"])
                if current["plddt"] > best["plddt"]:
                    best = current

        self.eval()

        # Leave the model fully frozen again.
        for p in self.parameters():
            p.requires_grad = False

        # Reset LoRA weights so the next sequence starts from the initial state.
        self.ttt_reset()

    # Restore the backbone's original precision.
    if esm_dtype != torch.float32:
        self.esm.to(esm_dtype)
        self.mlm_head.to(esm_dtype)

    return {
        "plddt": best["plddt"],
        "ptm": best["ptm"],
        "pdb_string": best.get("pdb_string"),
        "step_plddts": step_plddts,
        "best_step": step_plddts.index(max(step_plddts)),
    }