"""Export a trained SUAVE multitask checkpoint to Hugging Face format.

Reads ``multitask_model.pth`` (a raw PyTorch state_dict) and
``label_encoder.pkl`` (a pickled sklearn-style LabelEncoder) from the
current directory, rebuilds the model around a ``SuaveMultitaskConfig``,
and writes standard HF artifacts (``config.json``, ``model.safetensors``,
tokenizer files) into the current directory.
"""

from pathlib import Path
import pickle

import torch
from transformers import AutoTokenizer

from configuration_suave_multitask import SuaveMultitaskConfig
from modeling_suave_multitask import SuaveMultitaskModel


def main():
    """Convert the local checkpoint + label encoder into HF artifacts.

    Raises:
        FileNotFoundError: if either input artifact is missing.
        RuntimeError: if the checkpoint's keys do not match the model
            (``load_state_dict(strict=True)``).
    """
    model_ckpt = Path("multitask_model.pth")
    label_encoder_path = Path("label_encoder.pkl")

    # Fail early with a clear message rather than deep inside torch/pickle.
    if not model_ckpt.exists():
        raise FileNotFoundError("multitask_model.pth not found")
    if not label_encoder_path.exists():
        raise FileNotFoundError("label_encoder.pkl not found")

    # NOTE(security): pickle.load executes arbitrary code on load — only
    # ever point this at a trusted, locally produced label encoder.
    with open(label_encoder_path, "rb") as file:
        label_encoder = pickle.load(file)
    # The AI-model classification head is sized from the encoder's classes.
    num_ai_classes = len(label_encoder.classes_)

    config = SuaveMultitaskConfig(
        base_model_name="roberta-base",
        num_ai_classes=num_ai_classes,
        id2label={0: "human", 1: "ai"},
        label2id={"human": 0, "ai": 1},
    )
    # auto_map lets AutoConfig/AutoModel resolve the custom classes from
    # the repo's own module files (trust_remote_code loading).
    config.auto_map = {
        "AutoConfig": "configuration_suave_multitask.SuaveMultitaskConfig",
        "AutoModel": "modeling_suave_multitask.SuaveMultitaskModel",
    }

    model = SuaveMultitaskModel(config)
    # weights_only=True restricts unpickling to tensors/containers — safe
    # for a plain state_dict and avoids arbitrary-code-execution on load
    # (and the FutureWarning on recent PyTorch).
    state_dict = torch.load(model_ckpt, map_location="cpu", weights_only=True)
    # strict=True surfaces any architecture/checkpoint mismatch immediately.
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # Serialize weights as safetensors alongside config.json.
    model.save_pretrained(".", safe_serialization=True)

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    tokenizer.save_pretrained(".")

    print("HF artifacts generated: config.json, model.safetensors, tokenizer files")


if __name__ == "__main__":
    main()