File size: 3,529 Bytes
e317cf2
 
 
 
 
 
 
453cc78
e317cf2
 
 
 
 
 
 
 
 
 
 
 
 
453cc78
e317cf2
 
 
453cc78
 
 
e317cf2
 
 
453cc78
e317cf2
 
 
453cc78
 
 
 
 
 
 
e317cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb0552
e317cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb0552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e317cf2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import pickle
import shutil
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError


def _normalize_state_dict(raw_obj):
    if isinstance(raw_obj, dict) and "state_dict" in raw_obj and isinstance(raw_obj["state_dict"], dict):
        raw_obj = raw_obj["state_dict"]

    if not isinstance(raw_obj, dict):
        raise ValueError("Checkpoint is not a valid state_dict dictionary")

    return {k.replace("module.", "", 1): v for k, v in raw_obj.items()}


def _download_roberta_tokenizer_files(local_dir: Path, repo_id: str = "roberta-base"):
    """Copy tokenizer files from a Hugging Face Hub repo into ``local_dir``.

    ``vocab.json`` and ``merges.txt`` are required; any error downloading them
    propagates to the caller. The other files are optional and are skipped,
    with a printed message, when the repo does not contain them.

    Args:
        local_dir: Destination directory; created if it does not exist.
        repo_id: Hub repository to download from (defaults to ``"roberta-base"``).

    Raises:
        EntryNotFoundError: If a required file is missing from ``repo_id``.
    """
    required_files = (
        "vocab.json",
        "merges.txt",
    )
    # tokenizer_config.json is optional: main() writes its own version right
    # after this call, and a repo that may not ship one should not abort.
    optional_files = (
        "tokenizer_config.json",
        "special_tokens_map.json",
        "tokenizer.json",
    )

    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    for name in required_files:
        downloaded = hf_hub_download(repo_id=repo_id, filename=name)
        shutil.copy2(downloaded, local_dir / name)

    for name in optional_files:
        try:
            downloaded = hf_hub_download(repo_id=repo_id, filename=name)
        except EntryNotFoundError:
            print(f"Optional tokenizer file not found and skipped: {name}")
            continue
        shutil.copy2(downloaded, local_dir / name)


def _write_json(path, obj):
    """Write ``obj`` to ``path`` as UTF-8 JSON indented by two spaces."""
    with open(path, "w", encoding="utf-8") as file:
        json.dump(obj, file, indent=2)


def main(root: Path = Path(".")):
    """Convert a local training checkpoint into Hugging Face-style artifacts.

    Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from ``root``.
    It then writes ``config.json``, ``pytorch_model.bin``, the RoBERTa
    tokenizer files, ``tokenizer_config.json`` and ``special_tokens_map.json``
    into the same directory.

    Args:
        root: Directory holding the inputs and receiving the outputs
            (defaults to the current working directory).

    Raises:
        FileNotFoundError: If either input file is missing.
        ValueError: If the checkpoint does not contain a state_dict dictionary.
    """
    root = Path(root)
    model_ckpt = root / "multitask_model.pth"
    label_encoder_path = root / "label_encoder.pkl"

    if not model_ckpt.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not label_encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # SECURITY: pickle.load (and torch.load, depending on the installed torch
    # version's weights_only default) can execute arbitrary code. Only run
    # this on artifacts you produced yourself.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)

    num_ai_classes = len(label_encoder.classes_)

    # Load and validate the checkpoint before writing any output, so a bad
    # checkpoint does not leave a fresh config.json beside stale weights.
    state_dict = _normalize_state_dict(torch.load(model_ckpt, map_location="cpu"))

    config = {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": num_ai_classes,
        "classifier_dropout": 0.1,
        "tokenizer_class": "RobertaTokenizerFast",
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    }
    _write_json(root / "config.json", config)
    torch.save(state_dict, root / "pytorch_model.bin")

    _download_roberta_tokenizer_files(root)

    # The two files below overwrite any tokenizer_config.json /
    # special_tokens_map.json just copied from the Hub with this project's settings.
    tokenizer_config = {
        "tokenizer_class": "RobertaTokenizerFast",
        "model_max_length": 512,
        "padding_side": "right",
        "truncation_side": "right",
    }
    _write_json(root / "tokenizer_config.json", tokenizer_config)

    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "sep_token": "</s>",
        "pad_token": "<pad>",
        "cls_token": "<s>",
        "mask_token": "<mask>",
    }
    _write_json(root / "special_tokens_map.json", special_tokens)

    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")


if __name__ == "__main__":
    # Run the export against the current working directory when executed as a script.
    main()