# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. # SPDX-FileCopyrightText: All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import importlib import inspect import io import json import keyword import logging import os import re import tarfile import tempfile import warnings import zipfile from pathlib import Path from typing import Any, Dict, Optional, Set, Union import torch import physicsnemo from physicsnemo.models.meta import ModelMetaData from physicsnemo.registry import ModelRegistry from physicsnemo.utils.filesystem import _download_cached, _get_fs # Used for saving checkpoints of nested modules _BASE_CKPT_PREFIX = "__physicsnemo.Module__" def _load_state_dict_with_logging( module: torch.nn.Module, state_dict: Dict[str, Any], *args, **kwargs ): """Load state dictionary and log missing and unexpected keys Parameters ---------- module : torch.nn.Module Module to load state dictionary into state_dict : Dict[str, Any] State dictionary to load *args, **kwargs Additional arguments to pass to load_state_dict """ missing_keys, unexpected_keys = module.load_state_dict(state_dict, *args, **kwargs) if missing_keys: logging.warning( f"Missing keys when loading {module.__class__.__name__}: {missing_keys}" ) if unexpected_keys: logging.warning( f"Unexpected keys when loading {module.__class__.__name__}: {unexpected_keys}" ) return missing_keys, unexpected_keys 
class Module(torch.nn.Module):
    """The base class for all network models in PhysicsNeMo.

    This should be used as a direct replacement for ``torch.nn.Module`` and
    provides additional functionality for saving and loading models, as well as
    handling file system abstractions.

    There is one important requirement for all models in PhysicsNeMo. They must
    have json serializable arguments in their ``__init__`` function. This is
    required for saving and loading models and allow models to be instantiated
    from a checkpoint. The only one exception to this rule is when the argument
    passed to the ``__init__`` function is itself a ``physicsnemo.Module``
    instance. In this case, it is possible to construct, save and load nested
    Modules, with multiple levels of nesting and/or multiple
    ``physicsnemo.Module`` instances at each level.

    To be able to pass a ``torch.nn.Module`` instance as an argument to the
    ``__init__`` function, it is necessary to first use the
    ``Module.from_torch`` method to convert the ``torch.nn.Module`` subclass to
    a ``physicsnemo.Module`` subclass. To pass nested ``torch.nn.Module``
    instances as arguments to the ``__init__`` function, it is necessary to
    convert **all** nested ``torch.nn.Module`` instances to
    ``physicsnemo.Module`` instances using the ``Module.from_torch`` method.
    See the examples below for more details.

    Parameters
    ----------
    meta : ModelMetaData, optional
        Meta data class for storing info regarding model, by default None

    Examples
    --------
    To construct nested ``physicsnemo.Module`` instances with multiple levels
    of nesting and/or multiple ``physicsnemo.Module`` instances at each level:

    .. code-block:: python

        class InnerModel(physicsnemo.Module):
            def __init__(self, hidden_size):
                super().__init__(meta=ModelMetaData())
                self.hidden_size = hidden_size

        class OuterModel(physicsnemo.Module):
            def __init__(self, inner_model):
                super().__init__(meta=ModelMetaData())
                self.inner_model = inner_model

        # Create and save nested model
        model = OuterModel(inner_model=InnerModel(128))
        model.save("checkpoint.mdlus")
        loaded = physicsnemo.Module.from_checkpoint("checkpoint.mdlus")

    Applying this to a ``torch.nn.Module`` instance is also possible, as long
    as all nested ``torch.nn.Module`` instances are converted to
    ``physicsnemo.Module`` instances using the ``Module.from_torch`` method:

    .. code-block:: python

        class TorchInnerModel(torch.nn.Module):
            def __init__(self, size):
                super().__init__()
                self.size = size

        class TorchMyModel(torch.nn.Module):
            def __init__(self, inner_model):
                super().__init__()
                self.inner_model = inner_model

        # Convert both torch.nn.Module to physicsnemo.Module
        PNMInnerModel = physicsnemo.Module.from_torch(
            TorchInnerModel, meta=ModelMetaData()
        )
        PNMMyModel = physicsnemo.Module.from_torch(
            TorchMyModel, meta=ModelMetaData()
        )

        # Create nested model with converted torch modules
        model = PNMMyModel(inner_model=PNMInnerModel(size=128))
    """

    _file_extension = ".mdlus"  # Set file extension for saving and loading
    __model_checkpoint_version__ = (
        "0.1.0"  # Used for file versioning and is not the same as physicsnemo version
    )
    # Dict of supported model checkpoints and corresponding warnings messages
    __supported_model_checkpoint_version__ = {}

    # __init__ arguments that can be overridden. By default all arguments are
    # protected. Subclasses can override this to allow for overriding of specific
    # __init__'s arguments with the ``from_checkpoint`` method.
    _overridable_args: Set[str] = set()

    def __new__(cls, *args, **kwargs):
        """Create the instance and record the ``__init__`` arguments on it.

        The arguments are stored in ``out._args`` (keys ``__name__``,
        ``__module__`` and ``__args__``) so that ``save`` can later serialize
        them and ``instantiate`` can rebuild the object.
        """
        out = super().__new__(cls)

        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)

        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        bound_args.apply_defaults()

        # Get args and kwargs (excluding self and unroll kwargs)
        instantiate_args = {}
        for k, v in bound_args.arguments.items():
            # Skip self
            if k == "self":
                continue

            # Look the parameter up by name: zipping parameters with bound
            # arguments by position mislabels kinds whenever a required
            # argument is absent from the partial binding.
            if sig.parameters[k].kind == inspect.Parameter.VAR_KEYWORD:
                instantiate_args.update(v)
            else:
                instantiate_args[k] = v

        # Store args needed for instantiation
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out

    def __init__(self, meta: Union[ModelMetaData, None] = None):
        super().__init__()
        self.meta = meta
        # Empty buffer whose device is reported by the ``device`` property
        self.register_buffer("device_buffer", torch.empty(0))
        self._setup_logger()

    def _setup_logger(self):
        """Attach the shared ``core.module`` logger to this instance.

        The logger is shared by every ``Module`` instance, so a stream handler
        is only installed when the logger does not have one yet. Adding one
        per instance would duplicate every log line.
        """
        self.logger = logging.getLogger("core.module")
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.WARNING)

    @staticmethod
    def _safe_members(tar, local_path):
        """Yield only the archive members that are safe to extract.

        A member is rejected (and reported on stdout) when its name is an
        absolute path, contains a ``..`` path component, or resolves to a
        location outside ``local_path``.

        Parameters
        ----------
        tar : tarfile.TarFile
            Opened tar archive.
        local_path : str or Path
            Directory the archive is going to be extracted into.

        Yields
        ------
        tarfile.TarInfo
            Members whose extraction target stays inside ``local_path``.
        """
        root = os.path.realpath(local_path)
        for member in tar.getmembers():
            name = member.name
            target = os.path.realpath(os.path.join(root, name))
            try:
                # commonpath compares whole path components, unlike a plain
                # string-prefix test ("/tmp/abc" vs "/tmp/abcd")
                inside_root = os.path.commonpath([root, target]) == root
            except ValueError:
                # Raised e.g. for paths on different drives
                inside_root = False
            has_parent_ref = ".." in name.replace("\\", "/").split("/")

            if inside_root and not has_parent_ref and not os.path.isabs(name):
                yield member
            else:
                print(f"Skipping potentially malicious file: {member.name}")

    @classmethod
    def _backward_compat_arg_mapper(
        cls, version: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map arguments from older versions to current version format.

        This base implementation does nothing. Child classes should override
        this method to handle version-specific argument mappings.

        Parameters
        ----------
        version : str
            Version of the checkpoint being loaded
        args : Dict[str, Any]
            Arguments dictionary from the checkpoint

        Returns
        -------
        Dict[str, Any]
            Updated arguments dictionary compatible with current version
        """
        return args

    @classmethod
    def _override_args(
        cls, args: Dict[str, Any], override_args: Dict[str, Any]
    ) -> None:
        """Safely override ``__init__`` arguments stored in a checkpoint.

        This updates ``args`` *in-place* with the values provided in
        ``override_args``. Only keys defined in ``cls._overridable_args`` are
        allowed to be modified. Attempting to override any other key will raise
        a ``ValueError``.

        Parameters
        ----------
        args : Dict[str, Any]
            Keyword arguments that will be forwarded to the model constructor
            (e.g. ``args["__args__"]`` from a checkpoint).
        override_args : Dict[str, Any]
            Dictionary containing the desired argument overrides.

        Raises
        ------
        ValueError
            If a key of ``override_args`` is not in ``cls._overridable_args``.
        """
        for key, value in override_args.items():
            if key not in cls._overridable_args:
                raise ValueError(
                    f"Argument '{key}' cannot be overridden for {cls.__name__}."
                )
            # In this case we are not overriding, but we are adding a new arg
            if key not in args:
                warnings.warn(f"New argument '{key}' added for {cls.__name__}.")
            args[key] = value

    @classmethod
    def _get_class_from_args(cls, arg_dict: Dict[str, Any]) -> type:
        """Get the class from a dictionary of arguments.

        Parameters
        ----------
        arg_dict : Dict[str, Any]
            Dictionary of arguments containing '__name__' and '__module__' keys.

        Returns
        -------
        type
            The class to instantiate.

        Notes
        -----
        When the class cannot be found as an attribute of the imported module,
        ``cls`` itself is returned as a fallback. Import errors are not caught
        and propagate to the caller.
        """
        # ``import importlib`` alone does not guarantee that the ``metadata``
        # submodule is loaded, so import it explicitly before using it below.
        import importlib.metadata

        _cls_name = arg_dict["__name__"]
        registry = ModelRegistry()

        if cls.__name__ == arg_dict["__name__"]:  # If cls is the class
            return cls
        elif _cls_name in registry.list_models():  # Built in registry
            return registry.factory(_cls_name)
        else:
            try:
                # Check if module is using modulus import and change it to
                # physicsnemo instead
                if arg_dict["__module__"].split(".")[0] == "modulus":
                    warnings.warn(
                        "Using modulus import in model checkpoint. "
                        "This is deprecated and will be removed in future versions. "
                        "Please use physicsnemo instead."
                    )
                    arg_module = (
                        "physicsnemo" + arg_dict["__module__"][len("modulus") :]
                    )
                else:
                    arg_module = arg_dict["__module__"]

                # Otherwise, try to import the class
                _mod = importlib.import_module(arg_module)
                _cls = getattr(_mod, arg_dict["__name__"])
            except AttributeError:
                # Cross fingers and hope for the best (maybe the class name
                # changed)
                _cls = cls

        # This works with the importlib.metadata.EntryPoint
        if isinstance(_cls, importlib.metadata.EntryPoint):
            _cls = _cls.load()

        return _cls

    @classmethod
    def instantiate(cls, arg_dict: Dict[str, Any]) -> "Module":
        """Instantiate a model from a dictionary of arguments

        Parameters
        ----------
        arg_dict : Dict[str, Any]
            Dictionary of arguments to instantiate model with. This should have
            three keys: '__name__', '__module__', and '__args__'. The first two
            are used to import the class and the last is used to instantiate
            the class. The '__args__' key should be a dictionary of arguments
            to pass to the class's __init__ function.

        Returns
        -------
        Module

        Examples
        --------
        >>> from physicsnemo.models import Module
        >>> from physicsnemo.registry import ModelRegistry
        >>> registry = ModelRegistry()
        >>> model_entry = registry.factory('FullyConnected')
        >>> fcn = model_entry(**{'in_features': 10})
        >>> fcn
        FullyConnected(
          (layers): ModuleList(
            (0): FCLayer(
              (activation_fn): SiLU()
              (linear): Linear(in_features=10, out_features=512, bias=True)
            )
            (1-5): 5 x FCLayer(
              (activation_fn): SiLU()
              (linear): Linear(in_features=512, out_features=512, bias=True)
            )
          )
          (final_layer): FCLayer(
            (activation_fn): Identity()
            (linear): Linear(in_features=512, out_features=512, bias=True)
          )
        )
        """
        _cls = cls._get_class_from_args(arg_dict)
        return _cls(**arg_dict["__args__"])

    def debug(self):
        """Turn on debug logging"""
        # Fall back to the class name when no metadata (or no name) is set,
        # instead of failing on ``None.name``
        model_name = getattr(self.meta, "name", None) or self.__class__.__name__

        self.logger.handlers.clear()
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            f"[%(asctime)s - %(levelname)s - {model_name}] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.DEBUG)
        # TODO: set up debug log
        # fh = logging.FileHandler(f'physicsnemo-core-{self.meta.name}.log')

    def save(
        self,
        file_name: Union[str, None] = None,
        verbose: bool = False,
        legacy_format: bool = False,
    ) -> None:
        """
        Utility method for saving a ``Module`` instance to a '.mdlus'
        checkpoint file.

        Parameters
        ----------
        file_name : Union[str,None], optional, default=None
            File name to save the model checkpoint to. When ``None`` is
            provided it will default to the model's name set in the meta data
            (the model's metadata must have a 'name' attribute in this case).
        verbose : bool, optional, default=False
            Whether to save the model in verbose mode which will include git
            hash, etc.
        legacy_format : bool, optional, default=False
            Whether to save the model in legacy tar format. If True, saves as
            tar archive. If False (default), saves as zip archive.

        Raises
        ------
        ValueError
            If file_name does not end with .mdlus extension, or if it is not
            provided and the model metadata has no 'name' attribute.
        TypeError
            If an ``__init__`` argument is a ``torch.nn.Module`` that was not
            converted with ``Module.from_torch``.

        Examples
        --------
        >>> from physicsnemo.models.mlp import FullyConnected
        >>> model = FullyConnected(in_features=32, out_features=64)
        >>> # Save a checkpoint with the default file name 'FullyConnected.mdlus'.
        >>> # In this case, the model.meta.name coincides with the model class name, but that is not always the case.
        >>> model.save()
        >>> # Save a checkpoint to a specified file name 'my_model.mdlus'
        >>> model.save("my_model.mdlus")
        """

        # Define some helper functions
        def _save_process(module, args, metadata, mod_prefix="") -> None:
            """Recursively serialize nested physicsnemo.Module instances for
            checkpoint saving.

            Performs a depth-first search through the module's ``__init__``
            arguments. When an argument is a ``physicsnemo.Module`` instance,
            it is replaced with a placeholder string (prefixed with
            ``_BASE_CKPT_PREFIX``) and the nested module's information
            (``__name__``, ``__module__``, ``__args__``) is stored at the root
            level of the ``args`` dictionary. The nested module metadata (e.g.,
            ``__model_checkpoint_version__``) is also added at the root level
            of ``metadata`` dictionary, with keys prefixed with
            ``_BASE_CKPT_PREFIX``. This allows for reconstruction of
            arbitrarily nested module hierarchies during checkpoint loading.

            Parameters
            ----------
            module : physicsnemo.Module
                The module being processed
            args : Dict[str, Any]
                Dictionary to populate with serialized module arguments.
                Modified in-place. Keys prefixed with ``_BASE_CKPT_PREFIX``
                store nested module metadata.
            metadata : Dict[str, Any]
                Dictionary to populate with module metadata (e.g., version
                info). Modified in-place.
            mod_prefix : str, optional
                Current module's prefix in the nested hierarchy, by default "".
                Root module uses empty string; nested modules use format
                ``_BASE_CKPT_PREFIX.arg_name``.

            Raises
            ------
            TypeError
                If an argument is a ``torch.nn.Module`` instance that has not
                been converted to a ``physicsnemo.Module`` using
                ``Module.from_torch``.
            """
            # Pointer to args["__args__"] for submodules. A copy is taken so
            # that the module's own ``_args`` is never modified.
            if mod_prefix == "":
                args_ptr = args["__args__"].copy()
            else:
                args_ptr = args[mod_prefix]["__args__"].copy()

            # Iterate over a snapshot because ``args_ptr`` values are replaced
            # while walking the arguments
            for arg_name, arg_value in list(args_ptr.items()):
                if isinstance(arg_value, Module):
                    next_mod_prefix = (
                        f"{mod_prefix if mod_prefix else _BASE_CKPT_PREFIX}.{arg_name}"
                    )
                    args[next_mod_prefix] = arg_value._args.copy()
                    args_ptr[arg_name] = next_mod_prefix
                    metadata[f"{next_mod_prefix}.mdlus_file_version"] = (
                        arg_value.__model_checkpoint_version__
                    )
                    _save_process(arg_value, args, metadata, next_mod_prefix)
                elif isinstance(arg_value, torch.nn.Module):
                    raise TypeError(
                        f"Submodule {arg_name} of module {module.__class__.__name__} is"
                        f" a PyTorch module, which is not supported by 'Module.save'. Please "
                        f"first convert it to a PhysicsNeMo module using 'Module.from_torch'."
                    )

            if mod_prefix == "":
                args["__args__"] = args_ptr
            else:
                args[mod_prefix]["__args__"] = args_ptr
            return

        if file_name is not None and not file_name.endswith(self._file_extension):
            raise ValueError(
                f"File name must end with {self._file_extension} extension"
            )

        # Strip out torch dynamo wrapper
        if isinstance(self, torch._dynamo.eval_frame.OptimizedModule):
            # Forward every option, including the requested archive format
            self._orig_mod.save(file_name, verbose, legacy_format)
            return

        # Save the physicsnemo version and git hash (if available)
        metadata_info = {
            "physicsnemo_version": physicsnemo.__version__,
            "mdlus_file_version": self.__model_checkpoint_version__,
        }

        if verbose:
            import git

            try:
                repo = git.Repo(search_parent_directories=True)
                metadata_info["git_hash"] = repo.head.object.hexsha
            except git.InvalidGitRepositoryError:
                metadata_info["git_hash"] = None

        # Copy self._args to avoid side effects
        _args = self._args.copy()

        # Recursively populate _args and metadata_info with submodules
        # information
        _save_process(self, _args, metadata_info)

        # If file_name is not provided, use the model's name from the metadata
        if file_name is None:
            meta_name = getattr(self.meta, "name", None)
            if meta_name is None:
                raise ValueError(
                    "Model metadata does not have a 'name' attribute, please set it "
                    "explicitly or pass a 'file_name' argument to save a checkpoint."
                )
            file_name = f"{meta_name}.mdlus"

        # Write checkpoint file
        fs = _get_fs(file_name)

        if not legacy_format:
            # Save in zip format (default)
            # Initialized up-front so the cleanup below never sees an unbound
            # name if the temporary file cannot be created
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
                    tmp_path = tmp.name

                with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as archive:
                    # Save model state dict
                    state_dict_buffer = io.BytesIO()
                    torch.save(self.state_dict(), state_dict_buffer)
                    archive.writestr("model.pt", state_dict_buffer.getvalue())

                    # Save args
                    args_str = json.dumps(_args)
                    archive.writestr("args.json", args_str)

                    # Save metadata
                    metadata_str = json.dumps(metadata_info)
                    archive.writestr("metadata.json", metadata_str)

                # Upload to final destination
                fs.put(tmp_path, file_name)
            finally:
                # Clean up temporary file
                if tmp_path is not None and os.path.exists(tmp_path):
                    os.remove(tmp_path)
        else:
            # Save in legacy tar format
            with tempfile.TemporaryDirectory() as temp_dir:
                local_path = Path(temp_dir)

                # Save model state dict
                torch.save(self.state_dict(), local_path / "model.pt")

                # Save args
                with open(local_path / "args.json", "w") as f:
                    json.dump(_args, f)

                # Save metadata
                with open(local_path / "metadata.json", "w") as f:
                    json.dump(metadata_info, f)

                # Create tar archive
                with tarfile.open(local_path / "model.tar", "w") as tar:
                    for file in local_path.iterdir():
                        tar.add(str(file), arcname=file.name)

                # Upload to final destination
                fs.put(str(local_path / "model.tar"), file_name)

    @staticmethod
    def _detect_checkpoint_format(file_path: str) -> str:
        """Detect whether checkpoint is zip or tar format

        Parameters
        ----------
        file_path : str
            Path to checkpoint file

        Returns
        -------
        str
            Either 'zip' or 'tar'

        Raises
        ------
        IOError
            If file format cannot be determined
        """
        try:
            # NOTE: the check for tarfile MUST come first, as older checkpoints
            # will be both zip and tar archives, but newer checkpoints will
            # only be zip.
            if tarfile.is_tarfile(file_path):
                return "tar"
            elif zipfile.is_zipfile(file_path):
                return "zip"
            else:
                raise IOError(
                    f"Checkpoint file {file_path} is neither a valid zip "
                    f"nor tar archive"
                )
        except Exception as e:
            raise IOError(
                f"Could not determine checkpoint format for {file_path}: {e}"
            ) from e

    @staticmethod
    def _check_checkpoint(local_path: Path | str) -> None:
        """Check that an extracted checkpoint directory holds all expected files.

        Parameters
        ----------
        local_path : Path or str
            Directory containing the extracted checkpoint.

        Raises
        ------
        IOError
            If ``args.json``, ``metadata.json`` or ``model.pt`` is missing.
        """
        local_path = Path(local_path)
        expected_files = ["args.json", "metadata.json", "model.pt"]
        for file in expected_files:
            if not (local_path / file).exists():
                raise IOError(f"File '{file}' not found in checkpoint")

    def load(
        self,
        file_name: str,
        map_location: Union[None, str, torch.device] = None,
        strict: bool = True,
    ) -> None:
        """
        Utility method for loading the model weights from a '.mdlus'
        checkpoint file. Unlike
        :meth:`~physicsnemo.models.module.Module.from_checkpoint`, this method
        *does not* instantiate the model, but rather loads the ``state_dict``
        for an already instantiated model.

        Parameters
        ----------
        file_name : str
            Checkpoint file name. Must be a valid '.mdlus' checkpoint file.
        map_location : Union[None, str, torch.device], optional, default=None
            Map location for loading the model weights, ``None`` will use the
            model's device.
        strict: bool, optional, default=True
            Whether to strictly enforce that the keys in ``state_dict`` match.

        Raises
        ------
        IOError
            If ``file_name`` provided does not exist or is not a valid
            checkpoint

        Examples
        --------
        Basic example loading the model weights (state_dict) from a
        checkpoint:

        .. code-block:: python

            from physicsnemo.models.mlp import FullyConnected

            # Create a model with the same architecture as the saved one
            model = FullyConnected(in_features=32, out_features=64)

            # Load the weights from checkpoint
            model.load("FullyConnected.mdlus")

        Loading with specific device mapping:

        .. code-block:: python

            import torch
            from physicsnemo.models.mlp import FullyConnected

            model = FullyConnected(in_features=32, out_features=64)

            # Load checkpoint to CPU even if it was saved on GPU
            model.load("FullyConnected.mdlus", map_location="cpu")

            # Or load to a specific GPU
            model.load("FullyConnected.mdlus", map_location=torch.device("cuda:0"))
        """

        # Download and cache the checkpoint file if needed
        cached_file_name = _download_cached(file_name)

        # Detect checkpoint format
        checkpoint_format = Module._detect_checkpoint_format(cached_file_name)

        device = map_location if map_location is not None else self.device

        if checkpoint_format == "zip":
            # Load directly from zip file (no extraction needed)
            with zipfile.ZipFile(cached_file_name, "r") as archive:
                # Check if all expected files are present
                expected_files = ["args.json", "metadata.json", "model.pt"]
                archive_files = archive.namelist()
                for expected_file in expected_files:
                    if expected_file not in archive_files:
                        raise IOError(f"File '{expected_file}' not found in checkpoint")

                # Read into memory
                model_bytes = archive.read("model.pt")

            # Load state dict after closing archive
            model_dict = torch.load(io.BytesIO(model_bytes), map_location=device)

            # Load state_dict into the model
            _load_state_dict_with_logging(self, model_dict, strict=strict)
        else:  # tar format (backward compatibility)
            # Use a temporary directory to extract the tar file
            with tempfile.TemporaryDirectory() as temp_dir:
                local_path = Path(temp_dir)

                # Open tar file and extract contents to temporary directory
                with tarfile.open(cached_file_name, "r") as tar:
                    # Safely extract while supporting Python < 3.12
                    extract_kwargs = dict(
                        path=local_path,
                        members=list(Module._safe_members(tar, local_path)),
                    )
                    if "filter" in tar.extractall.__code__.co_varnames:
                        extract_kwargs["filter"] = "data"
                    tar.extractall(**extract_kwargs)  # noqa: S202

                # Check if the checkpoint is valid
                Module._check_checkpoint(local_path)

                # Load the model weights
                model_dict = torch.load(
                    local_path.joinpath("model.pt"), map_location=device
                )

                # Load state dict into the model
                _load_state_dict_with_logging(self, model_dict, strict=strict)

    @classmethod
    def from_checkpoint(
        cls,
        file_name: str,
        override_args: Optional[Dict[str, Any]] = None,
        strict: bool = True,
    ) -> physicsnemo.Module:
        """
        Utility class method for instantiating and loading a ``Module``
        instance from a '.mdlus' checkpoint file.

        Parameters
        ----------
        file_name : str
            Checkpoint file name. Must be a valid '.mdlus' checkpoint file.
        override_args : Optional[Dict[str, Any]], optional, default=None
            Dictionary of arguments to override the ``__init__`` method's
            arguments saved in the checkpoint. The override of arguments occurs
            *before* the model is instantiated, which allows for *ad-hoc*
            modifications to the model's initialization. Argument overrides are
            however applied *before* the state-dict is loaded, which means that
            for parameters or buffers saved in the state-dict, the values
            contained in the state-dict will take precedence over the override.
            This might also result in unexpected behavior if the model is
            instantiated with different arguments than the ones saved in the
            checkpoint, and some mismatching keys are saved in the state-dict.

            *Note*: Only arguments defined in ``cls._overridable_args`` can be
            overridden. ``Module``'s subclasses by default disable this
            functionality, unless they explicitly define an
            ``_overridable_args`` class attribute. Attempting to override any
            other argument will raise a ``ValueError``. This API should be used
            with caution and only if you fully understand the implications of
            the override.
        strict : bool, optional
            Whether to strictly enforce that the keys in state_dict match, by
            default True

        Returns
        -------
        Module

        Raises
        ------
        IOError
            If file_name provided does not exist or is not a valid checkpoint
        ValueError
            If a key of ``override_args`` is malformed or not overridable.

        Examples
        --------
        Simple argument override:

        .. code-block:: python

            class MyModel(Module):
                _overridable_args = set(["a", "b"])
                def __init__(self, a, b=2.0):
                    super().__init__()
                    # ... model implementation ...

            model = MyModel(1.0, b=2.0)
            model.save("checkpoint.mdlus")
            model_loaded = MyModel.from_checkpoint("checkpoint.mdlus", override_args={"a": 5.0})

        For nested module, override is possible with dot-separated syntax:

        .. code-block:: python

            class SubModule(Module):
                _overridable_args = set(["a"])
                def __init__(self, a):
                    super().__init__()
                    # ... submodule implementation ...

            class MyModel(Module):
                def __init__(self, submodule):
                    super().__init__()
                    self.submodule = submodule
                    # ... model implementation ...

            submodule = SubModule(1.0)
            model = MyModel(submodule)
            model.save("checkpoint.mdlus")
            model = MyModel.from_checkpoint("checkpoint.mdlus", override_args={"submodule.a": 2.0})
        """

        # Validate the format of override_args keys
        override_args = override_args or {}
        for k in override_args.keys():
            if not isinstance(k, str):
                raise ValueError(
                    f"All keys in override_args must be strings, got {type(k)} for key {k}"
                )
            if not all(
                p and p.isidentifier() and not keyword.iskeyword(p)
                for p in k.split(".")
            ):
                raise ValueError(
                    f"Key {k} in override_args does not match the expected format "
                    f"arg_name1.arg_name2..."
                )

        # Define some helper functions
        def _from_checkpoint_process(
            cls_in,
            args,
            metadata,
            override_args,
            strict,
            mod_prefix="",
        ):
            """Recursively deserialize and instantiate nested
            physicsnemo.Module instances.

            Performs a depth-first reconstruction of the module hierarchy from
            a checkpoint. When an argument value is a placeholder string
            (prefixed with ``_BASE_CKPT_PREFIX``), it is replaced with a
            recursively instantiated ``physicsnemo.Module`` instance. This is
            the reciprocal operation of ``_save_process``, reconstructing the
            original nested module structure from the serialized checkpoint
            data.

            Parameters
            ----------
            cls_in : type
                The class of the module to instantiate at the current recursion
                level
            args : Dict[str, Any]
                Dictionary containing serialized module arguments from the
                checkpoint. Keys prefixed with ``_BASE_CKPT_PREFIX`` contain
                nested module metadata. Modified in-place as nested modules are
                processed and removed.
            metadata : Dict[str, Any]
                Dictionary containing module metadata (e.g., version info) from
                the checkpoint. Modified in-place as nested modules are
                processed and removed.
            override_args : Dict[str, Any]
                Dictionary of arguments to override in the module's
                ``__init__`` method. Supports dot-separated syntax for nested
                module arguments.
            strict : bool
                Whether to strictly enforce that state_dict keys match when
                loading weights
            mod_prefix : str, optional
                Current module's prefix in the nested hierarchy, by default "".
                Root module uses empty string; nested modules use format
                ``_BASE_CKPT_PREFIX.arg_name``.

            Returns
            -------
            physicsnemo.Module
                The instantiated module with all nested submodules recursively
                constructed

            Raises
            ------
            IOError
                If the checkpoint version is incompatible with the current
                model version
            ValueError
                If argument names or prefixes don't match the expected format
            """
            # Pointer to args (for submodules)
            if mod_prefix == "":
                args_ptr = {
                    k: v for k, v in args.items() if not k.startswith(_BASE_CKPT_PREFIX)
                }
                # Only un-dotted keys target the root module
                override_args_ptr = {
                    k: v
                    for k, v in override_args.items()
                    if k.isidentifier() and not keyword.iskeyword(k)
                }
            else:
                args_ptr = args[mod_prefix]
                prefix = mod_prefix[len(_BASE_CKPT_PREFIX) + 1 :]
                override_args_ptr = {}
                for k, v in override_args.items():
                    if k.startswith(f"{prefix}."):
                        suffix = k[len(prefix) + 1 :]  # +1 for the dot
                        if suffix.isidentifier() and not keyword.iskeyword(suffix):
                            override_args_ptr[suffix] = v

            # Get the checkpoint version
            version = metadata.get(
                f"{mod_prefix}{'.' if mod_prefix else ''}mdlus_file_version",
                cls_in.__model_checkpoint_version__,
            )

            # Get the class from args
            _cls = Module._get_class_from_args(args_ptr)

            # Check if the checkpoint version is compatible with the current
            # version. If not, apply backward compatibility mapping if method
            # exists
            if version != _cls.__model_checkpoint_version__:
                if version in _cls.__supported_model_checkpoint_version__:
                    warnings.warn(_cls.__supported_model_checkpoint_version__[version])
                    args_ptr["__args__"] = _cls._backward_compat_arg_mapper(
                        version, args_ptr["__args__"]
                    )
                else:
                    raise IOError(
                        f"Model checkpoint version {version} is not compatible with "
                        f"current version {_cls.__model_checkpoint_version__} of class "
                        f"{_cls.__name__}"
                    )

            # Process all args and recursively instantiate those that are
            # submodules. Iterate over a snapshot since values are replaced
            # while walking the arguments.
            for arg_name, arg_value in list(args_ptr["__args__"].items()):
                if not isinstance(arg_value, str):
                    continue
                # Plain prefix test: the prefix contains a "." which must be
                # matched literally, not as a regex wildcard
                if not arg_value.startswith(_BASE_CKPT_PREFIX):
                    continue

                suffix = arg_value[len(_BASE_CKPT_PREFIX) :]
                # Last dot-separated component is the argument name
                _arg_name = suffix.rpartition(".")[2]
                if not _arg_name:
                    # Make sure that arg_value has the expected format
                    raise ValueError(
                        f"Argument value '{arg_value}' for argument '{arg_name}' "
                        f"of module {_cls.__name__} does not match the expected format "
                        f"{_BASE_CKPT_PREFIX}.arg_name1.arg_name2..."
                    )

                # Make sure that arg_value has the expected format
                if _arg_name != arg_name:
                    raise ValueError(
                        f"Argument name '{_arg_name}' does not match the "
                        f"expected '{arg_name}' for module {_cls.__name__}"
                    )

                # Instantiate the submodule
                next_mod_prefix = arg_value
                args_ptr["__args__"][arg_name] = _from_checkpoint_process(
                    Module._get_class_from_args(args[next_mod_prefix]),
                    args,
                    metadata,
                    override_args,
                    strict,
                    mod_prefix=next_mod_prefix,
                )

                # Cleanup args and metadata by removing the items
                # related to the submodule
                args.pop(next_mod_prefix, None)
                metadata.pop(f"{next_mod_prefix}.mdlus_file_version", None)

            # Override args_ptr["__args__"] with override_args
            _cls._override_args(args_ptr["__args__"], override_args_ptr)

            # Instantiate the module
            model = Module.instantiate(args_ptr)

            return model

        # Download and cache the checkpoint file if needed
        cached_file_name = _download_cached(file_name)

        # Detect checkpoint format
        checkpoint_format = Module._detect_checkpoint_format(cached_file_name)

        if checkpoint_format == "zip":
            # Load directly from zip file (no extraction needed)
            with zipfile.ZipFile(cached_file_name, "r") as archive:
                # Check if all expected files are present
                expected_files = ["args.json", "metadata.json", "model.pt"]
                archive_files = archive.namelist()
                for expected_file in expected_files:
                    if expected_file not in archive_files:
                        raise IOError(f"File '{expected_file}' not found in checkpoint")

                # Load model arguments and instantiate the model
                with archive.open("args.json") as f:
                    args = json.loads(f.read().decode("utf-8"))

                # Load metadata to get version
                with archive.open("metadata.json") as f:
                    metadata = json.loads(f.read().decode("utf-8"))

                model = _from_checkpoint_process(
                    cls,
                    args,
                    metadata,
                    override_args,
                    strict,
                )

                # Read into memory
                model_bytes = archive.read("model.pt")

            # Load state dict after closing archive
            model_dict = torch.load(io.BytesIO(model_bytes), map_location=model.device)

            # Load state_dict into the model
            _load_state_dict_with_logging(model, model_dict, strict=strict)
        else:  # tar format (backward compatibility)
            # Use a temporary directory to extract the tar file
            with tempfile.TemporaryDirectory() as temp_dir:
                local_path = Path(temp_dir)

                # Open tar file and extract contents to temporary directory
                with tarfile.open(cached_file_name, "r") as tar:
                    # Safely extract while supporting Python < 3.12
                    extract_kwargs = dict(
                        path=local_path,
                        members=list(Module._safe_members(tar, local_path)),
                    )
                    if "filter" in tar.extractall.__code__.co_varnames:
                        extract_kwargs["filter"] = "data"
                    tar.extractall(**extract_kwargs)  # noqa: S202

                # Check if the checkpoint is valid
                Module._check_checkpoint(local_path)

                # Load model arguments and instantiate the model
                with open(local_path.joinpath("args.json"), "r") as f:
                    args = json.load(f)

                # Load metadata to get version
                with open(local_path.joinpath("metadata.json"), "r") as f:
                    metadata = json.load(f)

                model = _from_checkpoint_process(
                    cls,
                    args,
                    metadata,
                    override_args,
                    strict,
                )

                # Load the model weights
                model_dict = torch.load(
                    local_path.joinpath("model.pt"), map_location=model.device
                )

                # Load state_dict into the model
                _load_state_dict_with_logging(model, model_dict, strict=strict)

        return model

    @staticmethod
    def from_torch(
        torch_model_class: type[torch.nn.Module], meta: ModelMetaData | None = None
    ) -> type[Module]:
        """Construct a PhysicsNeMo module from a PyTorch module

        The returned class wraps ``torch_model_class`` (stored as
        ``inner_model``), exposes the same ``__init__`` signature, and is
        registered in the ``ModelRegistry`` under the name
        ``<torch class name>PhysicsNeMoModel``.

        Parameters
        ----------
        torch_model_class : torch.nn.Module
            PyTorch module class
        meta : ModelMetaData, optional
            Meta data for the model, by default None

        Returns
        -------
        Module
        """

        # Define an internal class as before
        class PhysicsNeMoModel(Module):
            def __init__(self, *args, **kwargs):
                super().__init__(meta=meta)
                self.inner_model = torch_model_class(*args, **kwargs)

            def forward(self, x):
                return self.inner_model(x)

        # Get the argument names and default values of the PyTorch model's init
        # method
        init_argspec = inspect.getfullargspec(torch_model_class.__init__)
        model_argnames = init_argspec.args[1:]  # Exclude 'self'
        model_defaults = init_argspec.defaults or []
        defaults_dict = dict(
            zip(model_argnames[-len(model_defaults) :], model_defaults)
        )

        # Define the signature of new init
        params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
        params += [
            inspect.Parameter(
                argname,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=defaults_dict.get(argname, inspect.Parameter.empty),
            )
            for argname in model_argnames
        ]
        init_signature = inspect.Signature(params)

        # Replace PhysicsNeMoModel.__init__ signature with new init signature
        PhysicsNeMoModel.__init__.__signature__ = init_signature

        # Generate a unique name for the created class
        new_class_name = f"{torch_model_class.__name__}PhysicsNeMoModel"
        PhysicsNeMoModel.__name__ = new_class_name

        # Add this class to the dict of models classes
        registry = ModelRegistry()
        registry.register(PhysicsNeMoModel, new_class_name)

        return PhysicsNeMoModel

    @property
    def device(self) -> torch.device:
        """Get device model is on

        Returns
        -------
        torch.device
            PyTorch device
        """
        return self.device_buffer.device

    def num_parameters(self) -> int:
        """Gets the number of learnable parameters"""
        return sum(param.numel() for param in self.parameters())