| import argparse |
| import os |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from huggingface_hub import hf_hub_download |
|
|
# Load variables from a .env file into the process environment up front, so
# that HF_TOKEN (read in download_model) can be supplied that way.
load_dotenv()
|
|
# Local directory under which each mode gets its own sub-directory.
MODELS_ROOT = Path("models")

# Every mode this script knows how to download; also used as the CLI choices.
VALID_MODES = (
    "marker",
    "qa_m",
    "qa_b",
    "fasttext",
)

# Non-fasttext repos are named "<REPO_PREFIX>_<mode>" (see _repo_id).
REPO_PREFIX = "lamossta/distillbert"

# The fasttext mode lives in its own, differently named repo.
FASTTEXT_REPO = "lamossta/fasttext_baseline"
|
|
|
|
def _repo_id(mode: str) -> str:
    """Return the Hugging Face repo id that hosts the model for *mode*.

    ``"fasttext"`` maps to ``FASTTEXT_REPO``; any other mode maps to
    ``REPO_PREFIX`` followed by ``_<mode>``.
    """
    return FASTTEXT_REPO if mode == "fasttext" else f"{REPO_PREFIX}_{mode}"
|
|
|
|
| def _model_filename(mode: str) -> str: |
| return "model.bin" if mode == "fasttext" else "model.onnx" |
|
|
|
|
def download_model(mode: str, revision: str = "main") -> Path:
    """Download the model file for *mode* into ``MODELS_ROOT / mode``.

    Args:
        mode: One of ``VALID_MODES``.
        revision: Revision of the Hub repo to fetch (passed straight to
            ``hf_hub_download``).

    Returns:
        The directory the model file was downloaded into.

    Raises:
        ValueError: If *mode* is not one of ``VALID_MODES``.

    Any exception raised by ``hf_hub_download`` propagates unchanged.
    """
    # Validate before touching the filesystem or the network: an unknown (or
    # path-like, e.g. "../x") mode would otherwise create a stray directory
    # and only then fail on a non-existent repo.
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {VALID_MODES}")

    # An empty HF_TOKEN (e.g. "HF_TOKEN=" in .env) is treated as "not set":
    # None leaves credential resolution to huggingface_hub itself.
    token = os.environ.get("HF_TOKEN") or None

    repo_id = _repo_id(mode)
    filename = _model_filename(mode)
    model_dir = MODELS_ROOT / mode
    model_dir.mkdir(parents=True, exist_ok=True)

    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        token=token,
        local_dir=str(model_dir),
    )
    print(f"Downloaded {repo_id}/{filename} -> {model_dir / filename}")

    return model_dir
|
|
|
|
def download_all(revision: str = "main") -> dict[str, Path]:
    """Try to download every mode in ``VALID_MODES`` at *revision*.

    This is best-effort: a mode whose download raises is reported on stdout
    and skipped, so one failure does not stop the remaining modes.

    Returns:
        A mapping from each successfully downloaded mode to its model
        directory. Modes that failed are absent.
    """
    downloaded: dict[str, Path] = {}
    for mode in VALID_MODES:
        try:
            model_dir = download_model(mode, revision)
        except Exception as e:
            # Deliberately broad: any failure only skips this one mode.
            print(f"Skipping '{mode}': {e}")
        else:
            downloaded[mode] = model_dir
    return downloaded
|
|
|
|
def main():
    """Command-line entry point.

    With ``--mode`` a single model is downloaded; without it every mode is
    attempted and the resulting mode-to-directory mapping is printed.
    ``--revision`` is forwarded to the download functions.
    """
    cli = argparse.ArgumentParser(description="Download ONNX models from Hugging Face")
    cli.add_argument(
        "--mode",
        choices=VALID_MODES,
        default=None,
        help="Single mode to download (default: all)",
    )
    cli.add_argument("--revision", default="main")
    opts = cli.parse_args()

    if not opts.mode:
        # No mode given: fetch everything and report what succeeded.
        downloaded = download_all(opts.revision)
        print(f"Downloaded models: {downloaded}")
        return

    download_model(opts.mode, opts.revision)
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|