HarryPotterGPT / loading_helper.py
CamiloVega's picture
Upload improved HarryPotterGPT with better compatibility
6ec2e4a verified
raw
history blame contribute delete
626 Bytes
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import os
def load_harrypotter_gpt(model_path, device="auto"):
    """Load the HarryPotterGPT model and its tokenizer.

    Args:
        model_path: Hugging Face Hub repo id or local directory accepted by
            ``from_pretrained`` (e.g. ``"CamiloVega/HarryPotterGPT-v2"``).
        device: Where to place the model after loading. ``"auto"`` (the
            default) picks ``"cuda"`` when a CUDA device is available and
            ``"cpu"`` otherwise. Any other value (``"cpu"``, ``"cuda:0"``, a
            ``torch.device`` ...) is passed straight to ``model.to``. ``None``
            leaves the model wherever ``from_pretrained`` put it.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # transformers' PyTorch GPT-2 model already requires torch, so this adds
    # no new dependency; imported locally to keep the helper self-contained.
    import torch

    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 tokenizers ship without a pad token; reuse EOS so that padded /
    # batched tokenization works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = GPT2LMHeadModel.from_pretrained(model_path)
    # Keep the model config consistent with the tokenizer's pad token so
    # generation does not have to fall back to a guessed pad id.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    # Honour the `device` argument (previously accepted but ignored).
    if isinstance(device, str) and device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if device is not None:
        model = model.to(device)

    return model, tokenizer
# Example usage:
# model, tokenizer = load_harrypotter_gpt("CamiloVega/HarryPotterGPT-v2")