Create dia_use.py
Browse files- dia_use.py +20 -0
dia_use.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import hf_hub_download
|
| 2 |
+
|
| 3 |
+
def from_pretrained(
|
| 4 |
+
cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
|
| 5 |
+
) -> "Dia":
|
| 6 |
+
"""Loads the Dia model from a Hugging Face Hub repository.
|
| 7 |
+
Downloads the configuration and checkpoint files from the specified
|
| 8 |
+
repository ID and then loads the model.
|
| 9 |
+
Args:
|
| 10 |
+
model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
|
| 11 |
+
device: The device to load the model onto.
|
| 12 |
+
Returns:
|
| 13 |
+
An instance of the Dia model loaded with weights and set to eval mode.
|
| 14 |
+
Raises:
|
| 15 |
+
FileNotFoundError: If config or checkpoint download/loading fails.
|
| 16 |
+
RuntimeError: If there is an error loading the checkpoint.
|
| 17 |
+
"""
|
| 18 |
+
config_path = hf_hub_download(repo_id=model_name, filename="config.json")
|
| 19 |
+
checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
|
| 20 |
+
return cls.from_local(config_path, checkpoint_path, device)
|