File size: 2,862 Bytes
8a3099e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
"""Push a locally trained CodeBERT model to Hugging Face Hub."""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

# Directory two levels above this file's resolved (absolute, symlink-free) path.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Prepend it to sys.path so the `src` package imported further down is looked
# up there first (presumably `src/` lives directly under this directory).
sys.path.insert(0, str(PROJECT_ROOT))

from huggingface_hub import HfApi, create_repo

from src.codebert_labels import load_codebert_labels

# Model directory uploaded when no other directory is given (CLI: --model-dir).
DEFAULT_MODEL_DIR = PROJECT_ROOT / "models" / "codebert-cross-encoder"
# Markdown model card; push() uploads it as README.md when this file exists.
MODEL_CARD = PROJECT_ROOT / "hub" / "CODEBERT_MODEL_CARD.md"


def push(
    model_dir: Path = DEFAULT_MODEL_DIR,
    repo_id: str = "",
    private: bool = False,
    token: str | None = None,
) -> str:
    """Upload a local model directory to a Hugging Face Hub model repo.

    Args:
        model_dir: Local directory whose contents are uploaded.
        repo_id: Target repository id, e.g. ``"user/model-name"``.
        private: Passed to ``create_repo``; only matters if the repo is
            created by this call (``exist_ok=True`` is used).
        token: Hub access token. When falsy, the ``HF_TOKEN`` and then
            ``HUGGING_FACE_HUB_TOKEN`` environment variables are consulted.

    Returns:
        The ``https://huggingface.co/<repo_id>`` URL of the repository.

    Raises:
        FileNotFoundError: If ``model_dir`` does not exist.
        ValueError: If ``repo_id`` is empty or no token can be found.
    """
    if not model_dir.exists():
        raise FileNotFoundError(f"Model not found at {model_dir}. Train first.")

    if not repo_id:
        raise ValueError("--repo-id required, e.g. nishu08/sql-codebert-classifier")

    token = token or os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if not token:
        raise ValueError("Set HF_TOKEN environment variable")

    # Ensure label config exists for inference
    label_config = model_dir / "label_config.json"
    if not label_config.exists():
        # Build the payload BEFORE opening the file: if loading the labels
        # fails, no empty label_config.json is left behind to be mistaken for
        # a valid config (and uploaded) on the next run.
        config = {
            "labels": load_codebert_labels(),
            "model_name": "microsoft/codebert-base",
            "architecture": "codebert-cross-encoder",
            "threshold": 0.5,
            "max_length": 512,
        }
        # Explicit encoding so the file does not depend on the locale default.
        with open(label_config, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    api = HfApi(token=token)
    create_repo(repo_id, repo_type="model", private=private, exist_ok=True, token=token)

    print(f"Uploading {model_dir} -> {repo_id} ...")
    api.upload_folder(
        folder_path=str(model_dir),
        repo_id=repo_id,
        repo_type="model",
        token=token,
        commit_message="Upload CodeBERT SQL error classifier",
    )

    # The model card is uploaded after the folder, so it becomes the repo's
    # README.md even if the model directory contained one.
    if MODEL_CARD.exists():
        api.upload_file(
            path_or_fileobj=str(MODEL_CARD),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type="model",
            token=token,
        )

    url = f"https://huggingface.co/{repo_id}"
    print(f"Done: {url}")
    return url


def main() -> None:
    """Command-line entry point: parse the flags and hand them to ``push``.

    Recognised flags are ``--repo-id`` (required), ``--model-dir`` (defaults
    to ``DEFAULT_MODEL_DIR``), ``--private`` (boolean switch) and ``--token``
    (defaults to ``None``).
    """
    # Flag name -> keyword arguments for ``add_argument``, in declaration order.
    flag_specs = (
        ("--model-dir", {"type": Path, "default": DEFAULT_MODEL_DIR}),
        ("--repo-id", {"type": str, "required": True}),
        ("--private", {"action": "store_true"}),
        ("--token", {"type": str, "default": None}),
    )

    cli = argparse.ArgumentParser(description="Push CodeBERT model to HF Hub")
    for flag, spec in flag_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    push(
        model_dir=opts.model_dir,
        repo_id=opts.repo_id,
        private=opts.private,
        token=opts.token,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()