Spaces:
Sleeping
Sleeping
Scale ESMFold pLDDT to percent
Browse files
folding_api_service/backends.py
CHANGED
|
@@ -150,12 +150,16 @@ def _esmfold_output_to_pdb(output: Any) -> str:
|
|
| 150 |
final_atom_positions = final_atom_positions.detach().cpu().numpy()
|
| 151 |
final_atom_mask = cpu_data["atom37_atom_exists"]
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
protein = OpenFoldProtein(
|
| 154 |
aatype=cpu_data["aatype"][0],
|
| 155 |
atom_positions=final_atom_positions[0],
|
| 156 |
atom_mask=final_atom_mask[0],
|
| 157 |
residue_index=cpu_data["residue_index"][0] + 1,
|
| 158 |
-
b_factors=
|
| 159 |
chain_index=cpu_data.get("chain_index", [None])[0],
|
| 160 |
)
|
| 161 |
return to_pdb(protein)
|
|
@@ -167,8 +171,12 @@ def _mean_plddt(output: Any) -> float | None:
|
|
| 167 |
if plddt is None:
|
| 168 |
return None
|
| 169 |
if hasattr(plddt, "detach"):
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
|
| 174 |
def make_backend() -> FoldingBackend:
|
|
@@ -178,4 +186,3 @@ def make_backend() -> FoldingBackend:
|
|
| 178 |
if backend == "esmfold":
|
| 179 |
return EsmFoldBackend(os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1"))
|
| 180 |
raise ValueError(f"unsupported FOLD_BACKEND: {backend}")
|
| 181 |
-
|
|
|
|
| 150 |
final_atom_positions = final_atom_positions.detach().cpu().numpy()
|
| 151 |
final_atom_mask = cpu_data["atom37_atom_exists"]
|
| 152 |
|
| 153 |
+
b_factors = cpu_data["plddt"][0]
|
| 154 |
+
if float(b_factors.max()) <= 1.5:
|
| 155 |
+
b_factors = b_factors * 100.0
|
| 156 |
+
|
| 157 |
protein = OpenFoldProtein(
|
| 158 |
aatype=cpu_data["aatype"][0],
|
| 159 |
atom_positions=final_atom_positions[0],
|
| 160 |
atom_mask=final_atom_mask[0],
|
| 161 |
residue_index=cpu_data["residue_index"][0] + 1,
|
| 162 |
+
b_factors=b_factors,
|
| 163 |
chain_index=cpu_data.get("chain_index", [None])[0],
|
| 164 |
)
|
| 165 |
return to_pdb(protein)
|
|
|
|
| 171 |
if plddt is None:
|
| 172 |
return None
|
| 173 |
if hasattr(plddt, "detach"):
|
| 174 |
+
value = float(plddt.detach().float().mean().cpu().item())
|
| 175 |
+
else:
|
| 176 |
+
value = float(plddt.mean())
|
| 177 |
+
if value <= 1.5:
|
| 178 |
+
value *= 100.0
|
| 179 |
+
return round(value, 4)
|
| 180 |
|
| 181 |
|
| 182 |
def make_backend() -> FoldingBackend:
|
|
|
|
| 186 |
if backend == "esmfold":
|
| 187 |
return EsmFoldBackend(os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1"))
|
| 188 |
raise ValueError(f"unsupported FOLD_BACKEND: {backend}")
|
|
|