"""Export a trained multitask checkpoint into Hugging Face Hub layout.

Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from the current
directory and writes ``config.json``, ``pytorch_model.bin``, the RoBERTa
tokenizer files, ``tokenizer_config.json``, and ``special_tokens_map.json``
alongside them.
"""

import json
import pickle
import shutil
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError


def _normalize_state_dict(raw_obj):
    """Return a plain state_dict with any leading DataParallel ``module.`` prefix stripped.

    Accepts either a bare state_dict or a checkpoint dict that wraps one
    under the ``"state_dict"`` key.

    Raises:
        ValueError: if *raw_obj* is not (and does not wrap) a dict.
    """
    if isinstance(raw_obj, dict) and isinstance(raw_obj.get("state_dict"), dict):
        raw_obj = raw_obj["state_dict"]
    if not isinstance(raw_obj, dict):
        raise ValueError("Checkpoint is not a valid state_dict dictionary")
    # Strip only a *leading* "module." (torch.nn.DataParallel wrapper).
    # The previous `k.replace("module.", "", 1)` removed the first occurrence
    # anywhere, mangling keys such as "encoder.module.layer".
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in raw_obj.items()
    }


def _download_roberta_tokenizer_files(local_dir: Path):
    """Copy the roberta-base tokenizer files from the Hub into *local_dir*.

    Required files raise if missing on the Hub; optional files are skipped
    with a message when absent.
    """
    required_files = [
        "tokenizer_config.json",
        "vocab.json",
        "merges.txt",
    ]
    optional_files = [
        "special_tokens_map.json",
        "tokenizer.json",
    ]
    for name in required_files:
        downloaded = hf_hub_download(repo_id="roberta-base", filename=name)
        shutil.copy2(downloaded, local_dir / name)
    for name in optional_files:
        try:
            downloaded = hf_hub_download(repo_id="roberta-base", filename=name)
            shutil.copy2(downloaded, local_dir / name)
        except EntryNotFoundError:
            print(f"Optional tokenizer file not found and skipped: {name}")


def main():
    """Assemble all Hugging Face artifacts in the current directory.

    Raises:
        FileNotFoundError: if the checkpoint or label encoder is missing.
    """
    root = Path(".")
    model_ckpt = root / "multitask_model.pth"
    label_encoder_path = root / "label_encoder.pkl"

    if not model_ckpt.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not label_encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # NOTE(review): pickle.load on a local artifact we produced ourselves;
    # never point this at an untrusted file.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)
    num_ai_classes = len(label_encoder.classes_)

    config = {
        "architectures": ["SuaveMultitaskModel"],
        "model_type": "suave_multitask",
        "base_model_name": "roberta-base",
        "num_ai_classes": num_ai_classes,
        "classifier_dropout": 0.1,
        "tokenizer_class": "RobertaTokenizerFast",
        "id2label": {"0": "human", "1": "ai"},
        "label2id": {"human": 0, "ai": 1},
        "auto_map": {
            "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
            "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
        },
    }
    with open(root / "config.json", "w", encoding="utf-8") as file:
        json.dump(config, file, indent=2)

    # NOTE(review): torch.load's weights_only default flipped to True in
    # torch 2.6; left unspecified here to preserve existing behavior.
    state_dict = torch.load(model_ckpt, map_location="cpu")
    state_dict = _normalize_state_dict(state_dict)
    torch.save(state_dict, root / "pytorch_model.bin")

    _download_roberta_tokenizer_files(root)

    # Deliberately overwrite the downloaded tokenizer_config.json and
    # special_tokens_map.json with our own pinned versions below.
    tokenizer_config = {
        "tokenizer_class": "RobertaTokenizerFast",
        "model_max_length": 512,
        "padding_side": "right",
        "truncation_side": "right",
    }
    with open(root / "tokenizer_config.json", "w", encoding="utf-8") as file:
        json.dump(tokenizer_config, file, indent=2)

    # Standard RoBERTa special tokens. The previous version wrote empty
    # strings for every token (angle-bracket values apparently lost), which
    # would produce a broken tokenizer.
    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "sep_token": "</s>",
        "pad_token": "<pad>",
        "cls_token": "<s>",
        "mask_token": "<mask>",
    }
    with open(root / "special_tokens_map.json", "w", encoding="utf-8") as file:
        json.dump(special_tokens, file, indent=2)

    print("HF artifacts generated: config.json, pytorch_model.bin, tokenizer files")


if __name__ == "__main__":
    main()