# sv-task / src/models/hf_download.py
# Author: lamossta
# Commit message: "hf upload/download and onnx export"
# Commit: 9f3aa4a
import argparse
import os
from pathlib import Path
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
# Load variables from a local .env file into os.environ at import time, so that
# HF_TOKEN (read in download_model) can be configured there instead of the shell.
load_dotenv()
# Relative path, so it resolves against the current working directory; each
# downloaded mode gets its own sub-directory beneath it.
MODELS_ROOT = Path("models")

# Every mode this script can fetch; also used as the CLI's --mode choices.
VALID_MODES = (
    "marker",
    "qa_m",
    "qa_b",
    "fasttext",
)

# Hub repository naming: non-fastText modes live at f"{REPO_PREFIX}_{mode}",
# while the fastText baseline has a dedicated repository.
REPO_PREFIX = "lamossta/distillbert"
FASTTEXT_REPO = "lamossta/fasttext_baseline"
def _repo_id(mode: str) -> str:
    """Map a download mode to its Hugging Face Hub repository id.

    ``"fasttext"`` has a dedicated repository (``FASTTEXT_REPO``); every other
    mode is derived as ``"<REPO_PREFIX>_<mode>"``. The mode is not validated here.
    """
    return FASTTEXT_REPO if mode == "fasttext" else f"{REPO_PREFIX}_{mode}"
def _model_filename(mode: str) -> str:
    """Return the name of the artifact file stored in the Hub repo for *mode*.

    The fastText baseline ships ``model.bin``; every other mode ships an ONNX
    export named ``model.onnx``.
    """
    if mode == "fasttext":
        return "model.bin"
    return "model.onnx"
def download_model(mode: str, revision: str = "main") -> Path:
    """Download the model artifact for *mode* from the Hugging Face Hub.

    The file is written into ``MODELS_ROOT / mode``, which is created if needed.
    ``HF_TOKEN`` from the environment (possibly loaded from ``.env``) is used for
    authentication when it is set to a non-empty value.

    Args:
        mode: One of ``VALID_MODES``.
        revision: Branch, tag or commit of the Hub repository to fetch.

    Returns:
        The local directory that now contains the downloaded file.

    Raises:
        ValueError: If *mode* is not one of ``VALID_MODES``.
        Any error raised by ``hf_hub_download`` (network, auth, missing file).
    """
    if mode not in VALID_MODES:
        # Fail fast with a clear message instead of creating a stray directory
        # and letting the Hub report a confusing "repository not found" error.
        raise ValueError(f"Unknown mode {mode!r}; expected one of {VALID_MODES}")

    # Treat an empty HF_TOKEN (e.g. "HF_TOKEN=" in .env) as unset; passing ""
    # explicitly would send an empty bearer token instead of letting the
    # library fall back to its default credential lookup.
    token = os.environ.get("HF_TOKEN") or None
    repo_id = _repo_id(mode)
    filename = _model_filename(mode)

    model_dir = MODELS_ROOT / mode
    model_dir.mkdir(parents=True, exist_ok=True)
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        token=token,
        local_dir=str(model_dir),
    )
    print(f"Downloaded {repo_id}/{filename} -> {model_dir / filename}")
    return model_dir
def download_all(revision: str = "main") -> dict[str, Path]:
    """Try to download every mode in ``VALID_MODES`` at *revision*.

    This is best-effort: a mode whose download raises is reported with a
    "Skipping" message and left out of the result, rather than aborting the batch.

    Returns:
        A mapping from each successfully downloaded mode to its local directory.
    """
    results: dict[str, Path] = {}
    for name in VALID_MODES:
        try:
            target = download_model(name, revision)
        except Exception as exc:  # deliberate: one failing repo must not stop the rest
            print(f"Skipping '{name}': {exc}")
            continue
        results[name] = target
    return results
def main():
    """Run the command-line entry point: download one mode (``--mode``) or all of them.

    When all modes are requested and none could be downloaded, exit with a
    non-zero status so shells and CI notice the failure.
    """
    parser = argparse.ArgumentParser(
        # Not only ONNX: the fastText mode downloads a model.bin artifact.
        description="Download models (ONNX exports and the fastText baseline) from Hugging Face"
    )
    parser.add_argument(
        "--mode",
        default=None,
        choices=VALID_MODES,
        help="Single mode to download (default: all)",
    )
    parser.add_argument(
        "--revision",
        default="main",
        help="Branch, tag or commit of the Hub repositories (default: main)",
    )
    args = parser.parse_args()

    if args.mode:
        download_model(args.mode, args.revision)
        return

    downloaded = download_all(args.revision)
    print(f"Downloaded models: {downloaded}")
    if not downloaded:
        # download_all swallows per-mode errors; surface total failure to the
        # caller. SystemExit with a message prints it to stderr and exits with 1.
        raise SystemExit("No models were downloaded")


if __name__ == "__main__":
    main()