Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import Literal | |
| from ase import Atoms | |
| from chgnet.model.dynamics import CHGNetCalculator | |
| from chgnet.model.model import CHGNet as CHGNetModel | |
| from mlip_arena.models.utils import get_freer_device | |
class CHGNet(CHGNetCalculator):
    """CHGNet ASE calculator with automatic device selection.

    Subclass of ``chgnet.model.dynamics.CHGNetCalculator`` that adds two things:

    * when no ``device`` is given, one is chosen with ``get_freer_device()``;
    * after every ``calculate`` call, the ``crystal_fea`` entry is removed from
      ``self.results`` (for ``ase.io.write`` compatibility).
    """

    def __init__(
        self,
        checkpoint: CHGNetModel | None = None,  # TODO: specify model version
        device: str | None = None,
        stress_weight: float | None = 1 / 160.21766208,
        on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
        **kwargs,
    ) -> None:
        """Build the calculator and forward settings to ``CHGNetCalculator``.

        Args:
            checkpoint: CHGNet model instance, passed through as ``model=``.
                ``None`` leaves model selection to ``CHGNetCalculator``
                (presumably its default pretrained model — confirm upstream).
            device: Device string. Any falsy value (``None``/``""``) falls
                back to ``get_freer_device()``. The chosen value is always
                converted with ``str`` before being passed as ``use_device``.
            stress_weight: Forwarded unchanged. The default ``1 / 160.21766208``
                is presumably the GPa -> eV/Å^3 conversion factor — confirm
                against the CHGNetCalculator docs.
            on_isolated_atoms: Forwarded unchanged; one of ``"ignore"``,
                ``"warn"`` or ``"error"``.
            **kwargs: Any further keyword arguments for ``CHGNetCalculator``.
        """
        # Only fall back to automatic selection when the caller gave no usable device.
        chosen_device = device if device else get_freer_device()
        super().__init__(
            model=checkpoint,
            use_device=str(chosen_device),
            stress_weight=stress_weight,
            on_isolated_atoms=on_isolated_atoms,
            **kwargs,
        )

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list | None = None,
        system_changes: list | None = None,
    ) -> None:
        """Run the parent CHGNet calculation, then drop ``crystal_fea``.

        The arguments are passed positionally to
        ``CHGNetCalculator.calculate`` without modification.
        """
        super().calculate(atoms, properties, system_changes)
        # The crystal-feature entry is removed so the results stay
        # writable with ase.io.write; absence of the key is not an error.
        if "crystal_fea" in self.results:
            del self.results["crystal_fea"]