File size: 2,042 Bytes
2b84626
 
 
 
 
 
 
 
b67ab05
 
2b84626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import sys
from pathlib import Path

from huggingface_hub import hf_hub_download


# Default Hub repo id to download from when HF_MODEL_REPO is not set.
DEFAULT_MODEL_REPO = "piyazon/whisper_uyghur_pt"
# Default for HF_MODEL_FILENAMES: comma- and/or whitespace-separated filenames to fetch.
DEFAULT_MODEL_FILENAMES = "uyghur_whisper_tiny.pt"
# Default for WHISPER_MODEL_FILENAME: the file whose local path main() prints on success.
DEFAULT_SELECTED_MODEL = "uyghur_whisper_tiny.pt"


def _split_filenames(value: str) -> list[str]:
    return [item.strip() for item in value.replace(",", " ").split() if item.strip()]


def _log(message: str) -> None:
    print(message, file=sys.stderr, flush=True)


def _env_or_default(name: str, default: str) -> str:
    """Return env var *name* stripped of whitespace, or *default* if unset/blank.

    ``os.getenv(name, default)`` returns ``""`` for a variable that is set but
    empty, which would silently replace a sensible default with an empty value.
    """
    value = (os.getenv(name) or "").strip()
    return value or default


def main() -> int:
    """Ensure the configured Whisper model file(s) exist in a local directory.

    Configuration is read from environment variables:

    - ``HF_MODEL_REPO``: Hub repo id (default ``DEFAULT_MODEL_REPO``).
    - ``WHISPER_MODEL_FILENAME``: file whose local path is printed on success
      (default ``DEFAULT_SELECTED_MODEL``).
    - ``MODEL_DIR``: destination directory, ``~`` expanded
      (default ``/home/user/models``).
    - ``HF_MODEL_FILENAMES``: comma/whitespace-separated files to fetch
      (default ``DEFAULT_MODEL_FILENAMES``). The selected file is always
      added if it is missing from this list.
    - ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``: optional Hub access token.

    For the repo, selected file, and directory, an empty or blank value is
    treated like an unset one. Files already present in ``MODEL_DIR`` as
    non-empty regular files are reused without calling ``hf_hub_download``.

    Returns:
        0 after printing the selected model's local path to stdout.
        1 if the model directory cannot be created or any download fails;
        the reason is logged to stderr.
    """
    repo_id = _env_or_default("HF_MODEL_REPO", DEFAULT_MODEL_REPO)
    selected_filename = _env_or_default("WHISPER_MODEL_FILENAME", DEFAULT_SELECTED_MODEL)
    model_dir = Path(_env_or_default("MODEL_DIR", "/home/user/models")).expanduser()
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")

    # An explicitly empty HF_MODEL_FILENAMES means "no extra files", so the raw
    # value is used here; the selected model is still always included below.
    filenames = _split_filenames(os.getenv("HF_MODEL_FILENAMES", DEFAULT_MODEL_FILENAMES))
    if selected_filename not in filenames:
        filenames.insert(0, selected_filename)

    try:
        model_dir.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        _log(f"Failed to create model directory {model_dir}: {exc}")
        return 1

    downloaded_paths: dict[str, str] = {}
    for filename in filenames:
        cached_path = model_dir / filename
        # Require a non-empty regular file. A directory (e.g. the model dir
        # itself) or a zero-byte leftover must not count as a cached model.
        if cached_path.is_file() and cached_path.stat().st_size > 0:
            _log(f"Using cached model: {cached_path}")
            downloaded_paths[filename] = str(cached_path)
            continue

        _log(f"Downloading {repo_id}/{filename} to {model_dir}")
        try:
            downloaded_paths[filename] = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=str(model_dir),
                token=token,
            )
        except Exception as exc:  # script boundary: report any failure, exit non-zero
            _log(f"Failed to download {repo_id}/{filename}: {exc}")
            _log(
                "If the model repository is private, add an HF_TOKEN Space secret "
                "with read access to the model repo."
            )
            return 1

    # This is the function's only stdout output (its diagnostics go to stderr
    # via _log), so a caller can capture the model path directly.
    print(downloaded_paths[selected_filename], flush=True)
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the process exit status.
    sys.exit(main())