|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json") |
|
|
|
|
|
|
|
|
@dataclass(frozen=True) |
|
|
class AssetBundle: |
|
|
config_path: str |
|
|
weights_path: str |
|
|
tokenizer_id: Optional[str] |
|
|
mimi_id: Optional[str] |
|
|
repo_id: Optional[str] |
|
|
|
|
|
|
|
|
def resolve_assets( |
|
|
*, |
|
|
repo: Optional[str], |
|
|
config_path: Optional[str | Path], |
|
|
weights_path: Optional[str | Path], |
|
|
manifest_name: Optional[str] = None, |
|
|
) -> AssetBundle: |
|
|
repo_id = repo |
|
|
manifest_name = manifest_name or ASSET_MANIFEST |
|
|
if repo_id and (config_path or weights_path): |
|
|
raise ValueError("Provide either repo or config+weights, not both") |
|
|
if config_path is None or weights_path is None: |
|
|
if repo_id is None: |
|
|
raise ValueError("Must specify repo or config+weights") |
|
|
manifest = load_manifest(repo_id, manifest_name) |
|
|
config_name = manifest.get("config", "config.json") |
|
|
weights_name = manifest.get("weights", "model.safetensors") |
|
|
config_local = hf_hub_download(repo_id, config_name) |
|
|
weights_local = hf_hub_download(repo_id, weights_name) |
|
|
return AssetBundle( |
|
|
config_path=config_local, |
|
|
weights_path=weights_local, |
|
|
tokenizer_id=manifest.get("tokenizer") or repo_id, |
|
|
mimi_id=manifest.get("mimi"), |
|
|
repo_id=repo_id, |
|
|
) |
|
|
return AssetBundle(str(config_path), str(weights_path), None, None, repo_id) |
|
|
|
|
|
|
|
|
def load_manifest(repo_id: str, manifest_name: str) -> dict: |
|
|
if not manifest_name: |
|
|
return {} |
|
|
try: |
|
|
path = hf_hub_download(repo_id, manifest_name) |
|
|
except Exception: |
|
|
return {} |
|
|
try: |
|
|
return json.loads(Path(path).read_text()) |
|
|
except json.JSONDecodeError: |
|
|
return {} |
|
|
|
|
|
|
|
|
__all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"] |
|
|
|