Upload modeling_fast_esmfold.py with huggingface_hub
Browse files
- modeling_fast_esmfold.py +14 -2
modeling_fast_esmfold.py
CHANGED
|
@@ -1079,9 +1079,21 @@ class FastEsmForProteinFolding(EsmForProteinFolding):
|
|
| 1079 |
Dict with "losses" key containing per-step MLM loss values
|
| 1080 |
"""
|
| 1081 |
self._ensure_ttt_ready()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1082 |
if self._uses_lora:
|
| 1083 |
-
|
| 1084 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1085 |
|
| 1086 |
# ---- High-Level API ----
|
| 1087 |
|
|
|
|
| 1079 |
Dict with "losses" key containing per-step MLM loss values
|
| 1080 |
"""
|
| 1081 |
self._ensure_ttt_ready()
|
| 1082 |
+
# TTT requires fp32 for stable gradient computation. ESMFold typically
|
| 1083 |
+
# runs the backbone in fp16, but small LoRA updates vanish in half precision.
|
| 1084 |
+
esm_dtype = next(self.esm.parameters()).dtype
|
| 1085 |
+
if esm_dtype != torch.float32:
|
| 1086 |
+
self.esm.float()
|
| 1087 |
+
self.mlm_head.float()
|
| 1088 |
if self._uses_lora:
|
| 1089 |
+
result = self._lora_ttt(seq)
|
| 1090 |
+
else:
|
| 1091 |
+
result = self._legacy_ttt(seq)
|
| 1092 |
+
# Restore original dtype (backbone back to fp16 for inference)
|
| 1093 |
+
if esm_dtype != torch.float32:
|
| 1094 |
+
self.esm.to(esm_dtype)
|
| 1095 |
+
self.mlm_head.to(esm_dtype)
|
| 1096 |
+
return result
|
| 1097 |
|
| 1098 |
# ---- High-Level API ----
|
| 1099 |
|