File size: 2,019 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download

# Default manifest filename looked up inside a repo. Read once at import time
# from the DIA2_ASSET_MANIFEST environment variable, falling back to
# "dia2_assets.json" when the variable is unset.
ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")


@dataclass(frozen=True)
class AssetBundle:
    """Immutable record of the assets selected by ``resolve_assets``.

    ``tokenizer_id`` and ``mimi_id`` are only populated when the bundle is
    resolved from a repo; they are ``None`` when explicit local paths are used.
    """

    # Location of the config file: the download result for repo-based
    # resolution, or ``str(config_path)`` for explicit local paths.
    config_path: str
    # Location of the weights file, resolved the same way as ``config_path``.
    weights_path: str
    # Manifest "tokenizer" entry, falling back to the repo id; None for local paths.
    tokenizer_id: Optional[str]
    # Manifest "mimi" entry, or None when the manifest has none / for local paths.
    mimi_id: Optional[str]
    # The ``repo`` argument that was passed to ``resolve_assets``.
    repo_id: Optional[str]


def resolve_assets(
    *,
    repo: Optional[str],
    config_path: Optional[str | Path],
    weights_path: Optional[str | Path],
    manifest_name: Optional[str] = None,
) -> AssetBundle:
    """Resolve config and weights from a repo or from explicit local paths.

    Exactly one source is expected: either ``repo`` alone, or both
    ``config_path`` and ``weights_path``.

    Args:
        repo: Repository id passed to ``hf_hub_download``, or ``None``.
        config_path: Local config path, or ``None``.
        weights_path: Local weights path, or ``None``.
        manifest_name: Manifest filename to look up in the repo; a falsy value
            selects the module default ``ASSET_MANIFEST``.

    Returns:
        An ``AssetBundle``. For repo-based resolution the file names come from
        the manifest ("config" / "weights" entries, defaulting to
        "config.json" / "model.safetensors") and the tokenizer id defaults to
        the repo id. For local paths the tokenizer and mimi ids are ``None``.

    Raises:
        ValueError: If a repo is combined with any local path, or if a local
            path is missing and no repo was given.
    """
    manifest_file = manifest_name or ASSET_MANIFEST

    # A repo and explicit paths are mutually exclusive sources.
    if repo and (config_path or weights_path):
        raise ValueError("Provide either repo or config+weights, not both")

    # Both local paths supplied: nothing to download.
    if config_path is not None and weights_path is not None:
        return AssetBundle(
            config_path=str(config_path),
            weights_path=str(weights_path),
            tokenizer_id=None,
            mimi_id=None,
            repo_id=repo,
        )

    if repo is None:
        raise ValueError("Must specify repo or config+weights")

    # The manifest may be empty, in which case the defaults below apply.
    manifest = load_manifest(repo, manifest_file)
    config_file = manifest.get("config", "config.json")
    weights_file = manifest.get("weights", "model.safetensors")
    downloaded_config = hf_hub_download(repo, config_file)
    downloaded_weights = hf_hub_download(repo, weights_file)
    tokenizer = manifest.get("tokenizer") or repo
    return AssetBundle(
        config_path=downloaded_config,
        weights_path=downloaded_weights,
        tokenizer_id=tokenizer,
        mimi_id=manifest.get("mimi"),
        repo_id=repo,
    )


def load_manifest(repo_id: str, manifest_name: str) -> dict:
    """Fetch and parse the optional asset manifest of a repo.

    The lookup is best-effort: every failure mode yields an empty dict so the
    caller can fall back to its default file names.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.
        manifest_name: Manifest filename inside the repo; a falsy value skips
            the lookup entirely.

    Returns:
        The parsed manifest when it is a JSON object, otherwise ``{}`` (no
        name given, download failed, file is not valid UTF-8 JSON, or the
        top-level JSON value is not an object).
    """
    if not manifest_name:
        return {}
    try:
        path = hf_hub_download(repo_id, manifest_name)
    except Exception:
        # Deliberately broad: the manifest is optional, so any download
        # failure means "no manifest".
        return {}
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's
        # locale encoding when reading it.
        manifest = json.loads(Path(path).read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        return {}
    # Honour the ``dict`` return annotation: a top-level list/string/number
    # is treated as "no manifest".
    return manifest if isinstance(manifest, dict) else {}


# Names exported by ``from <module> import *``.
__all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]