ASR_Streaming_Server / download_models.py
Piyazon
push again
b67ab05
import os
import sys
from pathlib import Path
from huggingface_hub import hf_hub_download
# Hugging Face repository that hosts the checkpoints; overridable via HF_MODEL_REPO.
DEFAULT_MODEL_REPO = "piyazon/whisper_uyghur_pt"
# Files to fetch when HF_MODEL_FILENAMES is unset. Parsed as a comma- and/or
# whitespace-separated list, so several names may be given in one string.
DEFAULT_MODEL_FILENAMES = "uyghur_whisper_tiny.pt"
# Checkpoint whose local path is printed on stdout when WHISPER_MODEL_FILENAME is unset.
DEFAULT_SELECTED_MODEL = "uyghur_whisper_tiny.pt"
def _split_filenames(value: str) -> list[str]:
    """Parse a comma- and/or whitespace-separated list of filenames.

    Commas are treated as separators equivalent to whitespace, so
    ``"a.pt, b.pt"``, ``"a.pt b.pt"`` and ``"a.pt,b.pt"`` all yield
    ``["a.pt", "b.pt"]``. Empty entries are dropped, and repeated names are
    kept only once, at the position of their first occurrence.

    Args:
        value: The raw list, typically read from an environment variable.

    Returns:
        The filenames in their original order, without duplicates.
    """
    # str.split() with no argument already discards empty and
    # whitespace-only fields, so no extra strip()/filter is needed.
    # dict.fromkeys keeps insertion order and drops repeats, so a file listed
    # twice is not processed twice by the caller.
    return list(dict.fromkeys(value.replace(",", " ").split()))
def _log(message: str) -> None:
    """Emit a diagnostic line on stderr, flushed immediately.

    Diagnostics are kept off stdout, because the script prints the selected
    model path there.
    """
    err_stream = sys.stderr
    print(message, end="\n", flush=True, file=err_stream)
def _env_or_default(name: str, default: str) -> str:
    """Return the stripped value of env var *name*, or *default* if unset or blank.

    ``os.getenv(name, default)`` alone returns ``""`` for a variable that is
    defined but left empty, which would silently bypass the default.
    """
    value = os.getenv(name, "").strip()
    return value or default


def main() -> int:
    """Make sure the configured model checkpoint(s) exist in a local directory.

    Configuration comes from environment variables:

    * ``HF_MODEL_REPO``: repository id (blank or unset -> ``DEFAULT_MODEL_REPO``).
    * ``WHISPER_MODEL_FILENAME``: checkpoint whose local path is reported
      (blank or unset -> ``DEFAULT_SELECTED_MODEL``).
    * ``HF_MODEL_FILENAMES``: comma/whitespace-separated files to fetch
      (unset -> ``DEFAULT_MODEL_FILENAMES``; an empty value means "no extra
      files"). The selected checkpoint is always included.
    * ``MODEL_DIR``: destination directory, ``~`` expanded
      (blank or unset -> ``/home/user/models``).
    * ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN``: optional access token.

    A file already present in the destination as a non-empty regular file is
    reused. Otherwise it is fetched with ``hf_hub_download``.

    Progress and errors are written to stderr. On success the local path of
    the selected checkpoint is the only output on stdout.

    Returns:
        ``0`` on success. ``1`` as soon as a download fails; any remaining
        files are not attempted.
    """
    repo_id = _env_or_default("HF_MODEL_REPO", DEFAULT_MODEL_REPO)
    selected_filename = _env_or_default("WHISPER_MODEL_FILENAME", DEFAULT_SELECTED_MODEL)
    model_dir = Path(_env_or_default("MODEL_DIR", "/home/user/models")).expanduser()
    # A blank token is normalised to None so the library applies its default
    # credential handling instead of receiving an empty credential.
    token = (
        _env_or_default("HF_TOKEN", "")
        or _env_or_default("HUGGING_FACE_HUB_TOKEN", "")
        or None
    )
    # Deliberately plain getenv: an explicitly empty HF_MODEL_FILENAMES means
    # "only the selected checkpoint", not "use the default list".
    filenames = _split_filenames(os.getenv("HF_MODEL_FILENAMES", DEFAULT_MODEL_FILENAMES))
    # The selected checkpoint must always be fetched, since its path is what
    # this script reports; prepend it when it is not already listed.
    if selected_filename not in filenames:
        filenames.insert(0, selected_filename)

    model_dir.mkdir(parents=True, exist_ok=True)

    downloaded_paths: dict[str, str] = {}
    for filename in filenames:
        cached_path = model_dir / filename
        # is_file() rather than exists(), so a directory of the same name is
        # never mistaken for a checkpoint. A zero-byte file is treated as
        # unusable and fetched again.
        if cached_path.is_file() and cached_path.stat().st_size > 0:
            _log(f"Using cached model: {cached_path}")
            downloaded_paths[filename] = str(cached_path)
            continue

        _log(f"Downloading {repo_id}/{filename} to {model_dir}")
        try:
            downloaded_paths[filename] = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=str(model_dir),
                token=token,
            )
        except Exception as exc:  # top-level script boundary: report and exit non-zero
            _log(f"Failed to download {repo_id}/{filename}: {exc}")
            _log(
                "If the model repository is private, add an HF_TOKEN Space secret "
                "with read access to the model repo."
            )
            return 1

    # stdout carries only this path so callers can capture it directly.
    print(downloaded_paths[selected_filename], flush=True)
    return 0
if __name__ == "__main__":
    # Use main()'s return value (0 or 1) as the process exit status.
    sys.exit(main())