DaJulster committed on
Commit
e317cf2
·
verified ·
1 Parent(s): 6316722

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. prepare_hf_artifacts_light.py +76 -0
  2. upload.py +1 -1
prepare_hf_artifacts_light.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pickle
3
+ import shutil
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+
9
+
10
+ def _normalize_state_dict(raw_obj):
11
+ if isinstance(raw_obj, dict) and "state_dict" in raw_obj and isinstance(raw_obj["state_dict"], dict):
12
+ raw_obj = raw_obj["state_dict"]
13
+
14
+ if not isinstance(raw_obj, dict):
15
+ raise ValueError("Checkpoint is not a valid state_dict dictionary")
16
+
17
+ return {k.replace("module.", "", 1): v for k, v in raw_obj.items()}
18
+
19
+
20
def _download_roberta_tokenizer_files(local_dir: Path):
    """Fetch the ``roberta-base`` tokenizer artifacts into *local_dir*.

    Each file is pulled from the Hugging Face Hub cache (downloading on a
    cache miss) and then copied — metadata included — next to the model
    artifacts so the repo is self-contained.

    Args:
        local_dir: Destination directory; assumed to already exist.
    """
    tokenizer_assets = (
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.json",
    )
    for asset in tokenizer_assets:
        cached_path = hf_hub_download(repo_id="roberta-base", filename=asset)
        # copy2 preserves file metadata (timestamps) alongside content.
        shutil.copy2(cached_path, local_dir / asset)
32
+
33
+
34
def main():
    """Produce Hugging Face-loadable artifacts from a local checkpoint.

    Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from the current
    directory and writes ``config.json`` and ``pytorch_model.bin``, then
    copies the roberta-base tokenizer files alongside them.

    Raises:
        FileNotFoundError: If either required input file is missing.
        ValueError: If the checkpoint is not a valid state_dict.
    """
    workdir = Path(".")
    checkpoint_path = workdir / "multitask_model.pth"
    encoder_path = workdir / "label_encoder.pkl"

    # Fail fast with a clear message before doing any work.
    if not checkpoint_path.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # NOTE(review): pickle.load / torch.load execute arbitrary code from the
    # file — acceptable here only because the checkpoint is locally produced.
    with open(encoder_path, "rb") as fh:
        encoder = pickle.load(fh)
    class_count = len(encoder.classes_)

    # Minimal config with auto_map entries so transformers' Auto* classes can
    # resolve the custom architecture via trust_remote_code.
    hf_config = {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": class_count,
        "classifier_dropout": 0.1,
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    }

    with open(workdir / "config.json", "w", encoding="utf-8") as fh:
        json.dump(hf_config, fh, indent=2)

    # Load on CPU so this works on machines without a GPU, normalize the key
    # names, and re-save in the filename transformers expects.
    raw_checkpoint = torch.load(checkpoint_path, map_location="cpu")
    torch.save(_normalize_state_dict(raw_checkpoint), workdir / "pytorch_model.bin")

    _download_roberta_tokenizer_files(workdir)

    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")


if __name__ == "__main__":
    main()
upload.py CHANGED
@@ -24,7 +24,7 @@ api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
24
  # 2. Generate HF-compatible artifacts from existing checkpoint (optional)
25
  skip_prepare = os.environ.get("SKIP_HF_PREPARE", "0") == "1"
26
  if not skip_prepare:
27
- from prepare_hf_artifacts import main as prepare_hf_artifacts
28
 
29
  prepare_hf_artifacts()
30
  else:
 
24
  # 2. Generate HF-compatible artifacts from existing checkpoint (optional)
25
  skip_prepare = os.environ.get("SKIP_HF_PREPARE", "0") == "1"
26
  if not skip_prepare:
27
+ from prepare_hf_artifacts_light import main as prepare_hf_artifacts
28
 
29
  prepare_hf_artifacts()
30
  else: