| import argparse |
| import os |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from huggingface_hub import HfApi |
|
|
# Populate os.environ from a .env file at import time, so that HF_TOKEN
# (read via os.environ in upload_model) can be supplied there instead of
# being exported in the shell.
load_dotenv()
|
|
# Local directory containing one sub-directory per mode: models/<mode>/.
MODELS_ROOT: Path = Path("models")

# Every mode this script can upload; also used as the CLI --mode choices.
VALID_MODES: tuple[str, ...] = (
    "marker",
    "qa_m",
    "qa_b",
    "fasttext",
)

# Hub repo naming: non-fastText modes go to "<REPO_PREFIX>_<mode>", while the
# fastText baseline has its own dedicated repo.
# NOTE(review): "distillbert" spelling presumably matches the existing Hub
# repos — do not "correct" it without renaming those repos.
REPO_PREFIX: str = "lamossta/distillbert"
FASTTEXT_REPO: str = "lamossta/fasttext_baseline"
|
|
|
|
def _repo_id(mode: str) -> str:
    """Return the Hugging Face repo id that the model for *mode* is pushed to.

    The fastText baseline has a dedicated repo (``FASTTEXT_REPO``); every
    other mode is suffixed onto ``REPO_PREFIX``, e.g. ``"marker"`` maps to
    ``f"{REPO_PREFIX}_marker"``.
    """
    is_fasttext = mode == "fasttext"
    return FASTTEXT_REPO if is_fasttext else f"{REPO_PREFIX}_{mode}"
|
|
|
|
| def _model_filename(mode: str) -> str: |
| return "model.bin" if mode == "fasttext" else "model.onnx" |
|
|
|
|
def upload_model(mode: str, revision: str = "main") -> None:
    """Upload the local model artifact for *mode* to its Hugging Face repo.

    The file ``MODELS_ROOT/<mode>/<_model_filename(mode)>`` is pushed to the
    repo returned by ``_repo_id(mode)`` under the same filename. The repo is
    created first if it does not already exist.

    Args:
        mode: One of ``VALID_MODES``.
        revision: Branch to commit to. NOTE(review): this function does not
            create the branch, so a non-"main" value presumably has to exist
            on the Hub already — confirm against huggingface_hub behaviour.

    Raises:
        ValueError: If *mode* is not one of ``VALID_MODES``.
        FileNotFoundError: If the model file is missing or is not a regular file.
    """
    # Reject unknown modes before touching the Hub. Without this check, a typo
    # (or a differently-cased name on a case-insensitive filesystem) would
    # silently create a stray "<REPO_PREFIX>_<typo>" repo.
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {VALID_MODES}")

    filename = _model_filename(mode)
    model_path = MODELS_ROOT / mode / filename
    # is_file() rather than exists(): a directory at this path would pass
    # exists() and only fail later, inside upload_file.
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file not found: {model_path}")

    repo_id = _repo_id(mode)
    # The client keeps the token, so it is not repeated on every call. If
    # HF_TOKEN is unset (None), huggingface_hub falls back to its locally
    # saved login token.
    api = HfApi(token=os.environ.get("HF_TOKEN"))
    api.create_repo(repo_id, exist_ok=True)
    api.upload_file(
        repo_id=repo_id,
        path_or_fileobj=str(model_path),
        path_in_repo=filename,
        revision=revision,
    )
    print(f"Uploaded '{mode}' model to {repo_id}/{filename}")
|
|
|
|
def upload_all(revision: str = "main") -> list[str]:
    """Upload every mode in ``VALID_MODES`` whose model file is present locally.

    Modes without a model file are skipped with a printed message rather than
    treated as errors. An exception raised during an actual upload propagates
    and stops the remaining uploads.

    Args:
        revision: Branch passed through to ``upload_model`` for each upload.

    Returns:
        The modes that were uploaded, in ``VALID_MODES`` order.
    """
    uploaded: list[str] = []
    for mode in VALID_MODES:
        model_path = MODELS_ROOT / mode / _model_filename(mode)
        # is_file() so that a stray directory at the model path is skipped
        # here. Otherwise it would be handed to upload_model, where the
        # upload fails and aborts every remaining mode.
        if not model_path.is_file():
            print(f"Skipping '{mode}': {model_path} not found")
            continue
        upload_model(mode, revision)
        uploaded.append(mode)
    return uploaded
|
|
|
|
def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: upload a single mode, or all available modes.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``; passing a list allows programmatic use.
    """
    parser = argparse.ArgumentParser(
        # Not "ONNX" only: the fastText mode uploads a model.bin.
        description="Upload models (ONNX and fastText) to Hugging Face",
    )
    parser.add_argument(
        "--mode",
        default=None,
        choices=VALID_MODES,
        help="Single mode to upload (default: all)",
    )
    parser.add_argument(
        "--revision",
        default="main",
        help="Branch of the Hub repo to commit to (default: main)",
    )
    args = parser.parse_args(argv)

    if args.mode is not None:
        upload_model(args.mode, args.revision)
    else:
        uploaded = upload_all(args.revision)
        print(f"Uploaded models: {uploaded}")
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|