lhallee committed on
Commit
cacf282
·
verified ·
1 Parent(s): 1050cef

Upload modeling_fast_esmfold.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fast_esmfold.py +14 -2
modeling_fast_esmfold.py CHANGED
@@ -1079,9 +1079,21 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
1079
  Dict with "losses" key containing per-step MLM loss values
1080
  """
1081
  self._ensure_ttt_ready()
 
 
 
 
 
 
1082
  if self._uses_lora:
1083
- return self._lora_ttt(seq)
1084
- return self._legacy_ttt(seq)
 
 
 
 
 
 
1085
 
1086
  # ---- High-Level API ----
1087
 
 
1079
  Dict with "losses" key containing per-step MLM loss values
1080
  """
1081
  self._ensure_ttt_ready()
1082
+ # TTT requires fp32 for stable gradient computation. ESMFold typically
1083
+ # runs the backbone in fp16, but small LoRA updates vanish in half precision.
1084
+ esm_dtype = next(self.esm.parameters()).dtype
1085
+ if esm_dtype != torch.float32:
1086
+ self.esm.float()
1087
+ self.mlm_head.float()
1088
  if self._uses_lora:
1089
+ result = self._lora_ttt(seq)
1090
+ else:
1091
+ result = self._legacy_ttt(seq)
1092
+ # Restore original dtype (backbone back to fp16 for inference)
1093
+ if esm_dtype != torch.float32:
1094
+ self.esm.to(esm_dtype)
1095
+ self.mlm_head.to(esm_dtype)
1096
+ return result
1097
 
1098
  # ---- High-Level API ----
1099