Upload folder using huggingface_hub
Browse files- prepare_hf_artifacts_light.py +76 -0
- upload.py +1 -1
prepare_hf_artifacts_light.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pickle
|
| 3 |
+
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _normalize_state_dict(raw_obj):
|
| 11 |
+
if isinstance(raw_obj, dict) and "state_dict" in raw_obj and isinstance(raw_obj["state_dict"], dict):
|
| 12 |
+
raw_obj = raw_obj["state_dict"]
|
| 13 |
+
|
| 14 |
+
if not isinstance(raw_obj, dict):
|
| 15 |
+
raise ValueError("Checkpoint is not a valid state_dict dictionary")
|
| 16 |
+
|
| 17 |
+
return {k.replace("module.", "", 1): v for k, v in raw_obj.items()}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _download_roberta_tokenizer_files(local_dir: Path):
    """Copy the roberta-base tokenizer artifacts from the HF Hub into *local_dir*.

    Each file is fetched (or served from the local hub cache) via
    ``hf_hub_download`` and then copied, metadata included, next to the
    model artifacts.
    """
    tokenizer_assets = (
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.json",
    )

    for asset in tokenizer_assets:
        cached_path = hf_hub_download(repo_id="roberta-base", filename=asset)
        shutil.copy2(cached_path, local_dir / asset)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main(root="."):
    """Generate Hugging Face-compatible artifacts from a local checkpoint.

    Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from *root* and
    writes ``config.json``, ``pytorch_model.bin``, and the roberta-base
    tokenizer files alongside them.

    Args:
        root: Directory holding the input checkpoint files and receiving the
            generated artifacts. Defaults to the current working directory,
            preserving the original zero-argument behavior.

    Raises:
        FileNotFoundError: If either required input file is missing.
        ValueError: If the checkpoint is not a valid state_dict.
    """
    root = Path(root)
    model_ckpt = root / "multitask_model.pth"
    label_encoder_path = root / "label_encoder.pkl"

    if not model_ckpt.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not label_encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # NOTE(review): pickle.load executes arbitrary code on load — only use
    # with a locally produced, trusted label encoder.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)

    # One output unit per AI-model class seen by the fitted label encoder.
    num_ai_classes = len(label_encoder.classes_)

    config = {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": num_ai_classes,
        "classifier_dropout": 0.1,
        # Labels for the binary human-vs-AI head; the AI-model-id head size
        # is carried separately via num_ai_classes.
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        # auto_map lets transformers resolve the custom config/model classes
        # when loading with trust_remote_code=True.
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    }

    with open(root / "config.json", "w", encoding="utf-8") as file:
        json.dump(config, file, indent=2)

    # NOTE(review): torch.load unpickles arbitrary objects by default —
    # trusted, locally produced checkpoints only.
    state_dict = torch.load(model_ckpt, map_location="cpu")
    state_dict = _normalize_state_dict(state_dict)
    torch.save(state_dict, root / "pytorch_model.bin")

    _download_roberta_tokenizer_files(root)

    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Script entry point: generate the HF artifacts when run directly.
if __name__ == "__main__":
    main()
|
upload.py
CHANGED
|
@@ -24,7 +24,7 @@ api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
|
|
| 24 |
# 2. Generate HF-compatible artifacts from existing checkpoint (optional)
|
| 25 |
skip_prepare = os.environ.get("SKIP_HF_PREPARE", "0") == "1"
|
| 26 |
if not skip_prepare:
|
| 27 |
-
from
|
| 28 |
|
| 29 |
prepare_hf_artifacts()
|
| 30 |
else:
|
|
|
|
| 24 |
# 2. Generate HF-compatible artifacts from existing checkpoint (optional)
|
| 25 |
skip_prepare = os.environ.get("SKIP_HF_PREPARE", "0") == "1"
|
| 26 |
if not skip_prepare:
|
| 27 |
+
from prepare_hf_artifacts_light import main as prepare_hf_artifacts
|
| 28 |
|
| 29 |
prepare_hf_artifacts()
|
| 30 |
else:
|