"""Generate sample text from a trained potterGPT checkpoint and save it to output/generated.txt."""
# --- Setup: imports, corpus, tokenizer, and model checkpoint -----------------
# Imports grouped per PEP 8: stdlib / third-party / local.
import os
from dataclasses import dataclass  # NOTE(review): unused here; kept in case other parts of the file rely on it

import torch
from tokenizers import Tokenizer  # NOTE(review): unused here; kept in case other parts of the file rely on it

from model import PotterGPT, Config, CharacterLevelTokenizer

# Path to the trained weights (a state_dict, not a full pickled model).
model_path = 'potterGPT/potterGPT.pth'

# The character-level tokenizer derives its vocabulary from the raw training
# corpus, so the same data file must be present at inference time.
with open('data/harry_potter_data', 'r', encoding='utf-8') as f:
    data = f.read()
tokenizer = CharacterLevelTokenizer(data)

lm = PotterGPT(Config)
# NOTE(review): on torch >= 1.13 prefer torch.load(..., weights_only=True) to
# avoid arbitrary-code execution if the checkpoint file is ever untrusted.
state_dict = torch.load(model_path, map_location='cpu')
lm.load_state_dict(state_dict)
# Fix: switch to eval mode before sampling — the original left the model in
# training mode, so dropout (if the architecture uses it) would stay active
# and degrade generation quality.
lm.eval()
# --- Generation: sample one text per requested length ------------------------
generated_texts = []
for length in [1000]:  # extend this list to emit samples of several lengths
    # Seed context: a single token with id 61 (assumed to be '\n' in this
    # tokenizer's vocabulary — TODO confirm via the tokenizer rather than
    # hard-coding). torch.full is the idiomatic equivalent of zeros(...) + 61.
    seed = torch.full((1, 1), 61, dtype=torch.long, device='cpu')
    # Fix: run sampling under no_grad() — the original built an autograd
    # graph for every generated token, wasting memory and time.
    with torch.no_grad():
        generated = lm.generate(seed, total=length)
    # generate() is assumed to return a (batch, tokens) LongTensor; decode
    # the single batch row back to a string.
    decoded = tokenizer.decode(generated[0].cpu().numpy())
    text = f'generated ({length} tokens)\n{"="*50}\n{decoded}\n{"="*50}\n\n'
    generated_texts.append(text)

print(generated_texts[0])
# --- Persist the generated samples -------------------------------------------
os.makedirs('output', exist_ok=True)
# Fix: write with an explicit UTF-8 encoding — the corpus is read as UTF-8,
# and the platform default (e.g. cp1252 on Windows) could raise
# UnicodeEncodeError on exotic characters. 'w+' read access was never used,
# so plain 'w' suffices; writelines batches the per-sample writes.
with open('output/generated.txt', 'w', encoding='utf-8') as f:
    f.writelines(generated_texts)