"""Download the project's model artifacts from the Hugging Face Hub.

Each mode in ``VALID_MODES`` maps to one Hub repository and one file, which is
saved under ``models/<mode>/``. ``HF_TOKEN`` is read from the environment
(optionally populated from a ``.env`` file via ``load_dotenv``) and passed to
the Hub client.
"""

import argparse
import os
import sys
from pathlib import Path

from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

# Populate os.environ (e.g. HF_TOKEN) from a local .env file, if one exists.
load_dotenv()

MODELS_ROOT = Path("models")
VALID_MODES = ("marker", "qa_m", "qa_b", "fasttext")
REPO_PREFIX = "lamossta/distillbert"
FASTTEXT_REPO = "lamossta/fasttext_baseline"


def _repo_id(mode: str) -> str:
    """Return the Hub repository id that hosts the model for *mode*."""
    if mode == "fasttext":
        return FASTTEXT_REPO
    return f"{REPO_PREFIX}_{mode}"


def _model_filename(mode: str) -> str:
    """Return the artifact filename to fetch for *mode*."""
    return "model.bin" if mode == "fasttext" else "model.onnx"


def download_model(mode: str, revision: str = "main") -> Path:
    """Download the model artifact for *mode* into ``MODELS_ROOT / mode``.

    Args:
        mode: One of ``VALID_MODES``.
        revision: Hub revision to download from (passed to ``hf_hub_download``).

    Returns:
        The local directory the file was downloaded into.

    Raises:
        ValueError: If *mode* is not one of ``VALID_MODES``.
        Whatever ``hf_hub_download`` raises (e.g. network or auth failures).
    """
    # Validate before touching the filesystem: otherwise an unknown mode would
    # create a stray models/<mode> directory and build a repo id that does not exist.
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {VALID_MODES}")

    token = os.environ.get("HF_TOKEN")
    repo_id = _repo_id(mode)
    model_dir = MODELS_ROOT / mode
    model_dir.mkdir(parents=True, exist_ok=True)
    filename = _model_filename(mode)

    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        token=token,
        local_dir=str(model_dir),
    )
    print(f"Downloaded {repo_id}/{filename} -> {model_dir / filename}")
    return model_dir


def download_all(revision: str = "main") -> dict[str, Path]:
    """Download every mode in ``VALID_MODES``, skipping ones that fail.

    Downloads are best-effort: a failure for one mode is reported on stderr
    and does not stop the remaining modes.

    Args:
        revision: Hub revision to download from for every mode.

    Returns:
        Mapping of mode -> local directory for each mode that succeeded.
    """
    downloaded: dict[str, Path] = {}
    for mode in VALID_MODES:
        try:
            downloaded[mode] = download_model(mode, revision)
        except Exception as e:  # deliberate best-effort boundary; reported below
            print(f"Skipping '{mode}': {e}", file=sys.stderr)
    return downloaded


def main() -> int:
    """Parse CLI arguments and download one or all models.

    Returns:
        Process exit status: 0 on success, 1 if downloading all modes
        produced no model at all.
    """
    parser = argparse.ArgumentParser(
        description="Download ONNX / fastText models from Hugging Face"
    )
    parser.add_argument(
        "--mode",
        default=None,
        choices=VALID_MODES,
        help="Single mode to download (default: all)",
    )
    parser.add_argument(
        "--revision",
        default="main",
        help="Hub revision (branch, tag or commit) to download",
    )
    args = parser.parse_args()

    if args.mode is not None:
        # A single explicit mode: let any failure propagate as a traceback.
        download_model(args.mode, args.revision)
        return 0

    downloaded = download_all(args.revision)
    print(f"Downloaded models: {downloaded}")
    # Individual skips are tolerated, but signal failure if nothing was fetched.
    return 0 if downloaded else 1


if __name__ == "__main__":
    sys.exit(main())