| import json | |
| import pickle | |
| import shutil | |
| from pathlib import Path | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.errors import EntryNotFoundError | |
| def _normalize_state_dict(raw_obj): | |
| if isinstance(raw_obj, dict) and "state_dict" in raw_obj and isinstance(raw_obj["state_dict"], dict): | |
| raw_obj = raw_obj["state_dict"] | |
| if not isinstance(raw_obj, dict): | |
| raise ValueError("Checkpoint is not a valid state_dict dictionary") | |
| return {k.replace("module.", "", 1): v for k, v in raw_obj.items()} | |
def _download_roberta_tokenizer_files(local_dir: Path):
    """Copy the roberta-base tokenizer files from the HF Hub cache into local_dir.

    Required files propagate any download error; optional ones that are
    missing from the repo are announced and skipped.
    """
    # (filename, required) -- required entries first, mirroring download order.
    manifest = (
        ("tokenizer_config.json", True),
        ("vocab.json", True),
        ("merges.txt", True),
        ("special_tokens_map.json", False),
        ("tokenizer.json", False),
    )
    for name, required in manifest:
        try:
            cached_path = hf_hub_download(repo_id="roberta-base", filename=name)
        except EntryNotFoundError:
            if required:
                raise
            print(f"Optional tokenizer file not found and skipped: {name}")
            continue
        shutil.copy2(cached_path, local_dir / name)
def main():
    """Assemble Hugging Face-style artifacts from local training outputs.

    Reads multitask_model.pth and label_encoder.pkl from the working
    directory and emits config.json, pytorch_model.bin, and the roberta-base
    tokenizer files alongside them.

    Raises:
        FileNotFoundError: If either required input file is absent.
    """
    work_dir = Path(".")
    ckpt_path = work_dir / "multitask_model.pth"
    encoder_path = work_dir / "label_encoder.pkl"

    # Guard clauses: fail fast before producing any output.
    if not ckpt_path.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    def _dump_json(filename, payload):
        # json.dumps(..., indent=2) written as UTF-8 text matches
        # json.dump(payload, fh, indent=2) byte-for-byte.
        (work_dir / filename).write_text(json.dumps(payload, indent=2), encoding="utf-8")

    # NOTE(review): pickle.load on a local training artifact — only run this
    # script against files you produced yourself.
    with encoder_path.open("rb") as handle:
        label_encoder = pickle.load(handle)

    _dump_json("config.json", {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": len(label_encoder.classes_),
        "classifier_dropout": 0.1,
        "tokenizer_class": "RobertaTokenizerFast",
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    })

    # Re-save the weights under the standard HF filename, CPU-mapped and
    # with any DataParallel key prefixes normalized away.
    raw_checkpoint = torch.load(ckpt_path, map_location="cpu")
    torch.save(_normalize_state_dict(raw_checkpoint), work_dir / "pytorch_model.bin")

    _download_roberta_tokenizer_files(work_dir)

    # Overwrite the downloaded tokenizer_config.json with our own settings.
    _dump_json("tokenizer_config.json", {
        "tokenizer_class": "RobertaTokenizerFast",
        "model_max_length": 512,
        "padding_side": "right",
        "truncation_side": "right",
    })
    _dump_json("special_tokens_map.json", {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "sep_token": "</s>",
        "pad_token": "<pad>",
        "cls_token": "<s>",
        "mask_token": "<mask>",
    })
    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")


if __name__ == "__main__":
    main()