# This file is modified from TRELLIS:
# https://github.com/microsoft/TRELLIS
# Original license: MIT
# Copyright (c) the TRELLIS authors
# Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
"""Model registry with lazy (on-first-access) imports, plus ``from_pretrained``.

Public model classes are listed in ``__attributes`` together with the
submodule that defines them. A class's submodule is not imported until the
class name is first accessed; the module-level ``__getattr__`` (PEP 562) then
imports it and caches the class in this module's globals.
"""
import importlib

# Public class name -> name of the submodule (relative to this package)
# that defines it.
__attributes = {
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',

    'SparseStructureFlowModel': 'sparse_structure_flow',

    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',

    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',

    'SceneSLatFlowModel': 'scene_structured_latent_flow',
    'ElasticSceneSLatFlowModel': 'scene_structured_latent_flow',

    'SceneSparseStructureFlowModule': 'scene_sparse_structure_flow',
}

# Submodules that may be accessed lazily as attributes (currently none).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name: str):
    """Resolve ``name`` lazily, importing its defining submodule on first use.

    Python invokes this (PEP 562) only for attributes not found by normal
    module lookup; ``from_pretrained`` below also calls it directly.

    Args:
        name: A key of ``__attributes`` or an entry of ``__submodules``.

    Returns:
        The requested class or submodule. It is also stored in ``globals()``
        so later lookups no longer reach this function.

    Raises:
        AttributeError: If ``name`` is neither a registered class nor a
            registered submodule.
    """
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint, without extension. Either a local
            path prefix or a Hugging Face Hub reference of the form
            ``'<owner>/<repo>/<file prefix>'``.
            NOTE: the config file is ``f'{path}.json'``. The weights file is
            ``f'{path}.safetensors'`` or, for local checkpoints only,
            ``f'{path}.pt'``; ``.safetensors`` is preferred when both exist.
            Hub checkpoints must provide ``.safetensors`` weights.
        **kwargs: Additional arguments for the model constructor. They are
            passed alongside the config's ``args``; a key present in both
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The model class named by the config's ``name``, constructed with its
        ``args`` and with the checkpoint weights loaded. If the model's
        ``dtype`` is ``torch.float16``, ``convert_to_fp16()`` is called on it.

    Raises:
        ValueError: If ``path`` is not a local checkpoint and has fewer than
            three ``/``-separated parts, so it cannot name a Hub file.
    """
    import os
    import json
    import torch
    from safetensors.torch import load_file
    from ..utils.dist_utils import read_file_dist

    # Treat the checkpoint as local only if its config AND a weights file exist.
    is_local = os.path.exists(f"{path}.json") and (
        os.path.exists(f"{path}.safetensors") or os.path.exists(f"{path}.pt")
    )

    if is_local:
        config_file = f"{path}.json"
        # Prefer safetensors (no pickle involved) over a torch .pt file.
        model_file = (
            f"{path}.safetensors"
            if os.path.exists(f"{path}.safetensors")
            else f"{path}.pt"
        )
    else:
        from huggingface_hub import hf_hub_download

        path_parts = path.split('/')
        # A Hub reference needs '<owner>/<repo>/<file prefix>'. Without this
        # check a bare name dies with an opaque IndexError and a two-part name
        # requests a file literally called '.json'.
        if len(path_parts) < 3:
            raise ValueError(
                f"No local checkpoint found at '{path}' (.json plus "
                ".safetensors or .pt), and it is not a Hugging Face Hub "
                "reference of the form '<owner>/<repo>/<file prefix>'."
            )
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        model_name = '/'.join(path_parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    # JSON is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # config['name'] must be a registered class name; its submodule is
    # imported lazily by __getattr__.
    model = __getattr__(config['name'])(**config['args'], **kwargs)

    if model_file.endswith(".safetensors"):
        model.load_state_dict(load_file(model_file))
    else:
        # read_file_dist comes from utils.dist_utils; presumably it coordinates
        # the file read across distributed ranks — verify there.
        # weights_only=True restricts torch.load's unpickler to tensors and
        # plain containers.
        model_ckpt = torch.load(
            read_file_dist(model_file), map_location='cpu', weights_only=True
        )
        model.load_state_dict(model_ckpt)

    # `dtype` and `convert_to_fp16` are defined by the model classes, not here.
    if model.dtype == torch.float16:
        model.convert_to_fp16()

    return model


# For Pylance
# These imports never run on a normal import (``__name__`` is then the package
# name, not '__main__'); they only expose the lazily-loaded names to static
# analyzers.
if __name__ == '__main__':
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )

    from .sparse_structure_flow import SparseStructureFlowModel

    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )

    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )

    from .scene_sparse_structure_flow import (
        SceneSparseStructureFlowModule
    )

    from .scene_structured_latent_flow import (
        SceneSLatFlowModel,
        ElasticSceneSLatFlowModel
    )