import json
import pickle
import shutil
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
def _normalize_state_dict(raw_obj):
if isinstance(raw_obj, dict) and "state_dict" in raw_obj and isinstance(raw_obj["state_dict"], dict):
raw_obj = raw_obj["state_dict"]
if not isinstance(raw_obj, dict):
raise ValueError("Checkpoint is not a valid state_dict dictionary")
return {k.replace("module.", "", 1): v for k, v in raw_obj.items()}
def _download_roberta_tokenizer_files(local_dir: Path):
    """Fetch roberta-base tokenizer files from the HF Hub into *local_dir*.

    Required files must exist on the Hub; optional ones are skipped with a
    notice when the Hub reports them missing.
    """
    # (filename, is_required) specs; required entries propagate download errors,
    # optional ones tolerate EntryNotFoundError.
    file_specs = [
        ("tokenizer_config.json", True),
        ("vocab.json", True),
        ("merges.txt", True),
        ("special_tokens_map.json", False),
        ("tokenizer.json", False),
    ]
    for name, required in file_specs:
        try:
            cached_path = hf_hub_download(repo_id="roberta-base", filename=name)
        except EntryNotFoundError:
            if required:
                raise
            print(f"Optional tokenizer file not found and skipped: {name}")
            continue
        # copy2 preserves file metadata from the local HF cache copy.
        shutil.copy2(cached_path, local_dir / name)
def main(root_dir="."):
    """Convert local training artifacts into a Hugging Face model folder.

    Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from *root_dir*
    and writes ``config.json``, ``pytorch_model.bin``, and the tokenizer
    files (roberta-base vocab/merges plus locally generated
    ``tokenizer_config.json`` and ``special_tokens_map.json``) back into it.

    Args:
        root_dir: Directory holding the input artifacts and receiving the
            generated files. Defaults to the current directory, matching the
            original hard-coded behavior.

    Raises:
        FileNotFoundError: If either required input artifact is missing.
        ValueError: If the checkpoint does not contain a dict state_dict.
    """
    root = Path(root_dir)
    model_ckpt = root / "multitask_model.pth"
    label_encoder_path = root / "label_encoder.pkl"

    # Fail fast with explicit messages before doing any network/disk work.
    if not model_ckpt.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not label_encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # SECURITY: pickle.load executes arbitrary code from the file — only run
    # this script on label encoders you produced yourself.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)
    # assumes a fitted sklearn-style LabelEncoder exposing .classes_ — TODO confirm
    num_ai_classes = len(label_encoder.classes_)

    # Custom-architecture config; auto_map points AutoConfig/AutoModel at the
    # companion configuration/modeling modules shipped alongside the weights.
    config = {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": num_ai_classes,
        "classifier_dropout": 0.1,
        "tokenizer_class": "RobertaTokenizerFast",
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    }
    with open(root / "config.json", "w", encoding="utf-8") as file:
        json.dump(config, file, indent=2)

    # SECURITY: torch.load unpickles the checkpoint; load only trusted files.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which can
    # reject full-checkpoint pickles — verify against the torch version in use.
    state_dict = torch.load(model_ckpt, map_location="cpu")
    state_dict = _normalize_state_dict(state_dict)
    torch.save(state_dict, root / "pytorch_model.bin")

    _download_roberta_tokenizer_files(root)

    # Deliberately overwrites the tokenizer_config.json just downloaded with
    # project-specific settings (512-token limit, right padding/truncation).
    tokenizer_config = {
        "tokenizer_class": "RobertaTokenizerFast",
        "model_max_length": 512,
        "padding_side": "right",
        "truncation_side": "right",
    }
    with open(root / "tokenizer_config.json", "w", encoding="utf-8") as file:
        json.dump(tokenizer_config, file, indent=2)

    # Standard RoBERTa special-token set; also overwrites any downloaded copy.
    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "sep_token": "</s>",
        "pad_token": "<pad>",
        "cls_token": "<s>",
        "mask_token": "<mask>",
    }
    with open(root / "special_tokens_map.json", "w", encoding="utf-8") as file:
        json.dump(special_tokens, file, indent=2)

    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")


if __name__ == "__main__":
    main()