# File size: 1,550 Bytes
from pathlib import Path
import pickle
import torch
from transformers import AutoTokenizer
from configuration_suave_multitask import SuaveMultitaskConfig
from modeling_suave_multitask import SuaveMultitaskModel
def main():
    """Export the trained multitask checkpoint as Hugging Face artifacts.

    Reads ``multitask_model.pth`` and ``label_encoder.pkl`` from the current
    working directory, builds a ``SuaveMultitaskConfig`` whose
    ``num_ai_classes`` equals the number of classes in the label encoder,
    loads the checkpoint into a ``SuaveMultitaskModel`` with strict key
    matching, and saves the model and the base-model tokenizer into the
    current directory.

    Raises:
        FileNotFoundError: If either input file is missing.
    """
    model_ckpt = Path("multitask_model.pth")
    label_encoder_path = Path("label_encoder.pkl")

    # Validate both inputs before anything is deserialized, checkpoint first.
    for required in (model_ckpt, label_encoder_path):
        if not required.exists():
            raise FileNotFoundError(f"{required.name} not found")

    # SECURITY NOTE: pickle.load can execute arbitrary code. Only run this
    # script on a label_encoder.pkl that comes from a trusted training run.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)
    num_ai_classes = len(label_encoder.classes_)

    config = SuaveMultitaskConfig(
        base_model_name="roberta-base",
        num_ai_classes=num_ai_classes,
        id2label={0: "human", 1: "ai"},
        label2id={"human": 0, "ai": 1},
    )
    # Record the custom classes in the saved config.
    # NOTE(review): presumably consumed by AutoConfig/AutoModel when loading
    # with remote code enabled -- confirm against the loading side.
    config.auto_map = {
        "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
        "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
    }

    model = SuaveMultitaskModel(config)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, so a tampered checkpoint cannot run
    # arbitrary code while being loaded.
    state_dict = torch.load(model_ckpt, map_location="cpu", weights_only=True)
    # strict=True makes any missing or unexpected key a hard error.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.save_pretrained(".", safe_serialization=True)

    # NOTE(review): from_pretrained may need network access or a local cache
    # for the base model's tokenizer files -- confirm for offline use.
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    tokenizer.save_pretrained(".")

    print("HF artifacts generated: config.json, model.safetensors, tokenizer files")
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
|