Upload modeling_fast_esmfold.py with huggingface_hub
Browse files- modeling_fast_esmfold.py +195 -52
modeling_fast_esmfold.py
CHANGED
|
@@ -665,6 +665,59 @@ class FastEsmBackbone(nn.Module):
|
|
| 665 |
_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
|
| 666 |
|
| 667 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
@dataclass
|
| 669 |
class TTTConfig:
|
| 670 |
lr: float = 4e-4
|
|
@@ -683,7 +736,7 @@ class TTTConfig:
|
|
| 683 |
freeze_embeddings: bool = True
|
| 684 |
lora_rank: int = 8
|
| 685 |
lora_alpha: float = 32.0
|
| 686 |
-
|
| 687 |
|
| 688 |
def verify(self) -> None:
|
| 689 |
assert self.lr > 0.0, "TTT learning rate must be positive."
|
|
@@ -761,16 +814,19 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 761 |
super().__init__(config)
|
| 762 |
|
| 763 |
# Replace standard ESM2 backbone with FastESM2 (multi-backend attention)
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
self.esm.
|
|
|
|
|
|
|
| 768 |
|
| 769 |
# MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear)
|
| 770 |
self.mlm_head = EsmLMHead(config)
|
| 771 |
|
| 772 |
# TTT state (lazy initialization)
|
| 773 |
-
|
|
|
|
| 774 |
self._ttt_cfg.verify()
|
| 775 |
self._ttt_initialized = False
|
| 776 |
self._ttt_initial_state = None
|
|
@@ -800,6 +856,9 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 800 |
self.mlm_head.eval()
|
| 801 |
for p in self.mlm_head.parameters():
|
| 802 |
p.requires_grad = False
|
|
|
|
|
|
|
|
|
|
| 803 |
self._inject_lora()
|
| 804 |
else:
|
| 805 |
# Legacy path: jointly-trained random linear projection head
|
|
@@ -816,25 +875,32 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 816 |
return self._ttt_cfg.lora_rank > 0
|
| 817 |
|
| 818 |
def _inject_lora(self) -> None:
|
| 819 |
-
|
| 820 |
-
|
| 821 |
-
|
|
|
|
| 822 |
r=self._ttt_cfg.lora_rank,
|
| 823 |
-
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
|
|
|
| 827 |
)
|
| 828 |
-
inject_adapter_in_model(lora_config, self.esm, adapter_name="ttt")
|
| 829 |
|
| 830 |
# ---- TTT State Management ----
|
| 831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
def _ttt_get_state(self) -> Dict[str, Any]:
|
| 833 |
if self._uses_lora:
|
| 834 |
-
lora_state =
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
|
|
|
|
|
|
| 838 |
return {"_lora_state": lora_state}
|
| 839 |
return {
|
| 840 |
"esm": copy.deepcopy(self.esm),
|
|
@@ -843,9 +909,11 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 843 |
|
| 844 |
def _ttt_set_state(self, state: Dict[str, Any]) -> None:
|
| 845 |
if "_lora_state" in state:
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
|
|
|
|
|
|
| 849 |
return
|
| 850 |
if "esm" in state:
|
| 851 |
self.esm = copy.deepcopy(state["esm"])
|
|
@@ -993,11 +1061,9 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 993 |
|
| 994 |
for parameter in self.parameters():
|
| 995 |
parameter.requires_grad = False
|
| 996 |
-
for
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
lora_params = [p for n, p in self.esm.named_parameters() if "lora_" in n]
|
| 1000 |
-
optimizer = self._ttt_get_optimizer(iter(lora_params))
|
| 1001 |
optimizer.zero_grad(set_to_none=True)
|
| 1002 |
|
| 1003 |
self.eval()
|
|
@@ -1097,49 +1163,126 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 1097 |
|
| 1098 |
# ---- High-Level API ----
|
| 1099 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
def fold_protein(
|
| 1101 |
self,
|
| 1102 |
sequence: str,
|
| 1103 |
-
ttt: bool = False,
|
| 1104 |
return_pdb_string: bool = True,
|
| 1105 |
) -> Dict[str, Any]:
|
| 1106 |
-
"""Fold a protein sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1107 |
|
| 1108 |
Args:
|
| 1109 |
sequence: Protein sequence (single-letter amino acid codes)
|
| 1110 |
-
ttt: If True, run test-time training before folding (improves accuracy)
|
| 1111 |
return_pdb_string: If True, include PDB string in output
|
| 1112 |
|
| 1113 |
Returns:
|
| 1114 |
Dict with keys:
|
| 1115 |
-
- plddt: float, mean
|
| 1116 |
-
- ptm: float, predicted TM-score
|
| 1117 |
-
- pdb_string: str (if return_pdb_string=True), PDB
|
| 1118 |
-
-
|
|
|
|
| 1119 |
"""
|
| 1120 |
-
|
| 1121 |
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
|
|
|
|
|
|
| 1125 |
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
|
| 1129 |
-
|
| 1130 |
-
|
| 1131 |
-
|
| 1132 |
-
else:
|
| 1133 |
-
mean_plddt = float(plddt.mean().item())
|
| 1134 |
|
| 1135 |
-
|
| 1136 |
-
|
|
|
|
| 1137 |
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1141 |
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
|
| 1145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
_ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY")
|
| 666 |
|
| 667 |
|
| 668 |
+
class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Replaces an existing nn.Linear with base(x) + lora_up(lora_down(x)) * scale.
    Initialization follows cloneofsimo/lora: down=Normal(0, 1/r), up=zeros, so
    the wrapped layer is numerically identical to the base layer at injection
    time and only diverges once the adapter weights are trained.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        """Wrap ``original_linear`` with a rank-``r`` trainable low-rank update.

        Args:
            original_linear: The frozen base layer; kept and used as-is.
            r: LoRA rank; must not exceed either base dimension.
            scale: Multiplier applied to the low-rank update in forward().

        Raises:
            ValueError: If ``r`` exceeds min(in_features, out_features).
                (A bare assert here would be stripped under ``python -O``.)
        """
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} exceeds dimensions ({in_features}, {out_features})"
            )
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        # down=Normal(0, 1/r), up=zeros => initial adapter output is exactly 0.
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Base projection plus the scaled low-rank correction."""
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Replace nn.Linear children of matching modules with LoRA-augmented layers.

    Mirrors lora_diffusion's inject_trainable_lora: every module whose class
    name equals ``target_class_name`` has its direct nn.Linear children swapped
    in place for LoraInjectedLinear wrappers (base weights preserved, adapter
    weights trainable).

    Args:
        model: Root module to scan; modified in place.
        target_class_name: Class name of the parent modules to target
            (e.g. "EsmSelfAttention").
        r: LoRA rank passed to each injected layer.
        scale: Multiplier applied to the low-rank update in the forward pass.

    Returns:
        The list of newly created trainable LoRA parameters (down + up weights).
    """
    lora_params: List[nn.Parameter] = []
    # Materialize the matching parents BEFORE mutating: setattr below rewires
    # the module tree, and doing that while the named_modules() generator is
    # still live risks traversing freshly injected wrappers mid-iteration.
    parents = [
        module
        for _name, module in model.named_modules()
        if module.__class__.__name__ == target_class_name
    ]
    for parent_module in parents:
        for child_name, child_module in list(parent_module.named_children()):
            if not isinstance(child_module, nn.Linear):
                continue
            lora_linear = LoraInjectedLinear(child_module, r=r, scale=scale)
            # Keep the adapter on the same device/dtype as the wrapped weight.
            lora_linear = lora_linear.to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, lora_linear)
            lora_params.extend(lora_linear.lora_down.parameters())
            lora_params.extend(lora_linear.lora_up.parameters())
    return lora_params
|
| 719 |
+
|
| 720 |
+
|
| 721 |
@dataclass
|
| 722 |
class TTTConfig:
|
| 723 |
lr: float = 4e-4
|
|
|
|
| 736 |
freeze_embeddings: bool = True
|
| 737 |
lora_rank: int = 8
|
| 738 |
lora_alpha: float = 32.0
|
| 739 |
+
lora_target_class: str = "EsmSelfAttention"
|
| 740 |
|
| 741 |
def verify(self) -> None:
|
| 742 |
assert self.lr > 0.0, "TTT learning rate must be positive."
|
|
|
|
| 814 |
super().__init__(config)
|
| 815 |
|
| 816 |
# Replace standard ESM2 backbone with FastESM2 (multi-backend attention)
|
| 817 |
+
# unless use_standard_backbone is set (for TTT debugging/compatibility)
|
| 818 |
+
if not config.ttt_config.get("use_standard_backbone", False):
|
| 819 |
+
self.esm = FastEsmBackbone(config)
|
| 820 |
+
self.esm.requires_grad_(False)
|
| 821 |
+
if config.esmfold_config.fp16_esm:
|
| 822 |
+
self.esm.half()
|
| 823 |
|
| 824 |
# MLM head for TTT (pretrained EsmLMHead: Dense -> GELU -> LN -> Linear)
|
| 825 |
self.mlm_head = EsmLMHead(config)
|
| 826 |
|
| 827 |
# TTT state (lazy initialization)
|
| 828 |
+
ttt_kwargs = {k: v for k, v in config.ttt_config.items() if k != "use_standard_backbone"}
|
| 829 |
+
self._ttt_cfg = TTTConfig(**ttt_kwargs)
|
| 830 |
self._ttt_cfg.verify()
|
| 831 |
self._ttt_initialized = False
|
| 832 |
self._ttt_initial_state = None
|
|
|
|
| 856 |
self.mlm_head.eval()
|
| 857 |
for p in self.mlm_head.parameters():
|
| 858 |
p.requires_grad = False
|
| 859 |
+
# Seed global state before LoRA init for reproducible weight initialization
|
| 860 |
+
if self._ttt_cfg.seed is not None:
|
| 861 |
+
torch.manual_seed(self._ttt_cfg.seed)
|
| 862 |
self._inject_lora()
|
| 863 |
else:
|
| 864 |
# Legacy path: jointly-trained random linear projection head
|
|
|
|
| 875 |
return self._ttt_cfg.lora_rank > 0
|
| 876 |
|
| 877 |
def _inject_lora(self) -> None:
    """Inject LoRA adapters into ESM2 attention layers (matching lora_diffusion behavior).

    Side effect: stores the injected trainable parameters on ``self._lora_params``.

    Raises:
        RuntimeError: If no module matched ``lora_target_class`` and therefore
            no adapters were injected. An ``assert`` here would silently vanish
            under ``python -O``, so the check is an explicit raise.
    """
    self._lora_params = inject_trainable_lora(
        self.esm,
        target_class_name=self._ttt_cfg.lora_target_class,
        r=self._ttt_cfg.lora_rank,
        scale=self._ttt_cfg.lora_alpha,
    )
    if not self._lora_params:
        raise RuntimeError(
            f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' "
            f"matches attention modules in the backbone."
        )
|
|
|
|
| 889 |
|
| 890 |
# ---- TTT State Management ----
|
| 891 |
|
| 892 |
+
def _get_lora_modules(self) -> List[LoraInjectedLinear]:
    """Return every LoRA-injected linear layer currently present in the backbone."""
    found: List[LoraInjectedLinear] = []
    for submodule in self.esm.modules():
        if isinstance(submodule, LoraInjectedLinear):
            found.append(submodule)
    return found
|
| 895 |
+
|
| 896 |
def _ttt_get_state(self) -> Dict[str, Any]:
|
| 897 |
if self._uses_lora:
|
| 898 |
+
lora_state = []
|
| 899 |
+
for m in self._get_lora_modules():
|
| 900 |
+
lora_state.append({
|
| 901 |
+
"down": m.lora_down.weight.data.clone(),
|
| 902 |
+
"up": m.lora_up.weight.data.clone(),
|
| 903 |
+
})
|
| 904 |
return {"_lora_state": lora_state}
|
| 905 |
return {
|
| 906 |
"esm": copy.deepcopy(self.esm),
|
|
|
|
| 909 |
|
| 910 |
def _ttt_set_state(self, state: Dict[str, Any]) -> None:
|
| 911 |
if "_lora_state" in state:
|
| 912 |
+
modules = self._get_lora_modules()
|
| 913 |
+
assert len(modules) == len(state["_lora_state"])
|
| 914 |
+
for m, saved in zip(modules, state["_lora_state"]):
|
| 915 |
+
m.lora_down.weight.data.copy_(saved["down"])
|
| 916 |
+
m.lora_up.weight.data.copy_(saved["up"])
|
| 917 |
return
|
| 918 |
if "esm" in state:
|
| 919 |
self.esm = copy.deepcopy(state["esm"])
|
|
|
|
| 1061 |
|
| 1062 |
for parameter in self.parameters():
|
| 1063 |
parameter.requires_grad = False
|
| 1064 |
+
for p in self._lora_params:
|
| 1065 |
+
p.requires_grad = True
|
| 1066 |
+
optimizer = self._ttt_get_optimizer(self._lora_params)
|
|
|
|
|
|
|
| 1067 |
optimizer.zero_grad(set_to_none=True)
|
| 1068 |
|
| 1069 |
self.eval()
|
|
|
|
| 1163 |
|
| 1164 |
# ---- High-Level API ----
|
| 1165 |
|
| 1166 |
+
def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]:
    """Fold a sequence once; report mean pLDDT, ptm, and optionally a PDB string.

    Args:
        sequence: Protein sequence (single-letter amino acid codes).
        return_pdb_string: When True, also render the prediction to a PDB string.

    Returns:
        Dict with "plddt" (float mean), "ptm" (float or None when absent),
        and "pdb_string" when requested.
    """
    with torch.no_grad():
        output = self.infer(sequence)
        confidences = output["plddt"]
        # Average the trailing axis first for >=2-D layouts; assumes the
        # trailing axis is per-atom pLDDT — TODO confirm against infer().
        if confidences.dim() >= 2:
            mean_plddt = float(confidences.mean(dim=-1).mean().item())
        else:
            mean_plddt = float(confidences.mean().item())
        summary: Dict[str, Any] = {
            "plddt": mean_plddt,
            "ptm": float(output["ptm"].item()) if "ptm" in output else None,
        }
        if return_pdb_string:
            rendered = self.output_to_pdb(output)
            summary["pdb_string"] = rendered[0] if isinstance(rendered, list) else rendered
    return summary
|
| 1183 |
+
|
| 1184 |
def fold_protein(
    self,
    sequence: str,
    return_pdb_string: bool = True,
) -> Dict[str, Any]:
    """Fold a protein sequence with test-time training.

    Runs TTT (masked language model adaptation via LoRA) for the configured
    number of steps, folding after each optimizer step to track pLDDT. Returns
    the structure with the highest pLDDT across all steps (including baseline).

    Args:
        sequence: Protein sequence (single-letter amino acid codes)
        return_pdb_string: If True, include PDB string in output

    Returns:
        Dict with keys:
            - plddt: float, best mean pLDDT across all TTT steps
            - ptm: float, predicted TM-score from best step
            - pdb_string: str (if return_pdb_string=True), PDB from best step
            - step_plddts: list[float], pLDDT at each step [baseline, s1, ...]
            - best_step: int, which step produced best structure (0=baseline)
    """
    self._ensure_ttt_ready()

    # Cast to fp32 for TTT stability; original dtype is restored on exit.
    esm_dtype = next(self.esm.parameters()).dtype
    if esm_dtype != torch.float32:
        self.esm.float()
        self.mlm_head.float()

    device = next(self.parameters()).device
    non_blocking = device.type == "cuda"

    # Step 0: baseline fold (no TTT adaptation).
    best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
    step_plddts = [best["plddt"]]

    if self._ttt_cfg.steps > 0:
        # Tokenize once for masked-LM training.
        tokens = self._ttt_tokenize(sequence)

        # Freeze everything, then unfreeze only what TTT trains.
        for param in self.parameters():
            param.requires_grad = False
        if self._uses_lora:
            for param in self._lora_params:
                param.requires_grad = True
            optimizer = self._ttt_get_optimizer(self._lora_params)
        else:
            # Legacy path: full backbone + linear projection head.
            for param in self.esm.parameters():
                param.requires_grad = True
            if self._ttt_cfg.freeze_embeddings:
                for param in self.esm.embeddings.parameters():
                    param.requires_grad = False
            for param in self._ttt_lm_proj.parameters():
                param.requires_grad = True
            trainable = [p for p in self.parameters() if p.requires_grad]
            optimizer = self._ttt_get_optimizer(trainable)
        optimizer.zero_grad(set_to_none=True)

        self.eval()
        # `ags` micro-batches per optimizer step (gradient accumulation).
        for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
            batch_masked, targets, mask, _start = self._ttt_sample_batch(tokens)
            batch_masked = batch_masked.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)

            self.train()
            logits = self._ttt_predict_logits(batch_masked)
            loss = self._ttt_cross_entropy_loss(logits, targets, mask)
            loss.backward()

            if (step + 1) % self._ttt_cfg.ags == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                # Fold after each optimizer step (not each micro-batch)
                # to record the pLDDT trajectory and keep the best result.
                self.eval()
                current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
                step_plddts.append(current["plddt"])
                if current["plddt"] > best["plddt"]:
                    best = current

        self.eval()

    # Teardown: re-freeze everything for the frozen-backbone baseline state.
    for param in self.parameters():
        param.requires_grad = False

    # Reset LoRA weights so the next sequence starts from the initial state.
    self.ttt_reset()

    # Restore the original backbone/head dtype.
    if esm_dtype != torch.float32:
        self.esm.to(esm_dtype)
        self.mlm_head.to(esm_dtype)

    return {
        "plddt": best["plddt"],
        "ptm": best["ptm"],
        "pdb_string": best.get("pdb_string"),
        "step_plddts": step_plddts,
        "best_step": step_plddts.index(max(step_plddts)),
    }
|