| from huggingface_hub import hf_hub_download | |
def from_pretrained(
    cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
) -> "Dia":
    """Load the Dia model from a Hugging Face Hub repository.

    Downloads ``config.json`` and the ``dia-v0_1.pth`` checkpoint from the
    given repository with ``hf_hub_download``, then passes the resulting
    local paths (and ``device``) to ``cls.from_local``.

    Args:
        model_name: The Hugging Face Hub repository ID
            (e.g., "nari-labs/Dia-1.6B", which is the default).
        device: The device to load the model onto; forwarded unchanged to
            ``cls.from_local``. The default is a CUDA device, so callers on
            CPU-only machines presumably need to pass one explicitly —
            confirm against ``from_local``.

    Returns:
        Whatever ``cls.from_local`` returns (annotated as ``Dia``).
        NOTE(review): loading the weights and switching to eval mode are
        done inside ``from_local``, which is not shown here — confirm.

    Raises:
        Exceptions raised by ``hf_hub_download`` propagate unchanged. These
        are huggingface_hub's own error types (e.g. for a missing repository
        or file, or an HTTP failure), not necessarily ``FileNotFoundError``.
        Exceptions raised by ``cls.from_local`` also propagate. Earlier docs
        listed ``FileNotFoundError`` / ``RuntimeError`` for config/checkpoint
        loading failures — presumably raised there; verify.
    """
    # hf_hub_download returns a local filesystem path to the fetched file.
    config_path = hf_hub_download(repo_id=model_name, filename="config.json")
    checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
    # All actual model construction / weight loading is delegated.
    return cls.from_local(config_path, checkpoint_path, device)