NariLabs commited on
Commit
92fbf87
·
verified ·
1 Parent(s): f385dd1

Delete assets.py

Browse files
Files changed (1) hide show
  1. assets.py +0 -65
assets.py DELETED
@@ -1,65 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from typing import Optional
8
-
9
- from huggingface_hub import hf_hub_download
10
-
11
- ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")
12
-
13
-
14
- @dataclass(frozen=True)
15
- class AssetBundle:
16
- config_path: str
17
- weights_path: str
18
- tokenizer_id: Optional[str]
19
- mimi_id: Optional[str]
20
- repo_id: Optional[str]
21
-
22
-
23
- def resolve_assets(
24
- *,
25
- repo: Optional[str],
26
- config_path: Optional[str | Path],
27
- weights_path: Optional[str | Path],
28
- manifest_name: Optional[str] = None,
29
- ) -> AssetBundle:
30
- repo_id = repo
31
- manifest_name = manifest_name or ASSET_MANIFEST
32
- if repo_id and (config_path or weights_path):
33
- raise ValueError("Provide either repo or config+weights, not both")
34
- if config_path is None or weights_path is None:
35
- if repo_id is None:
36
- raise ValueError("Must specify repo or config+weights")
37
- manifest = load_manifest(repo_id, manifest_name)
38
- config_name = manifest.get("config", "config.json")
39
- weights_name = manifest.get("weights", "model.safetensors")
40
- config_local = hf_hub_download(repo_id, config_name)
41
- weights_local = hf_hub_download(repo_id, weights_name)
42
- return AssetBundle(
43
- config_path=config_local,
44
- weights_path=weights_local,
45
- tokenizer_id=manifest.get("tokenizer") or repo_id,
46
- mimi_id=manifest.get("mimi"),
47
- repo_id=repo_id,
48
- )
49
- return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)
50
-
51
-
52
- def load_manifest(repo_id: str, manifest_name: str) -> dict:
53
- if not manifest_name:
54
- return {}
55
- try:
56
- path = hf_hub_download(repo_id, manifest_name)
57
- except Exception:
58
- return {}
59
- try:
60
- return json.loads(Path(path).read_text())
61
- except json.JSONDecodeError:
62
- return {}
63
-
64
-
65
- __all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]