| import os | |
| import pickle | |
| import numpy as np | |
# Build a character-level vocabulary from the raw chat transcript.
input_file_path = os.path.join('data', 'chats', 'data.txt')
# Decode explicitly as UTF-8 so the text (and therefore the vocabulary)
# does not depend on the platform's default locale encoding.
# NOTE(review): assumes data.txt is UTF-8 encoded — confirm.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()

# Every distinct character in the corpus becomes one token; sorting makes
# the id assignment deterministic across runs.
chars = sorted(set(data))
vocab_size = len(chars)
# Bidirectional lookup tables: character -> id and id -> character.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
def encode(s):
    """Convert a string into a list of integer token ids.

    Each character is looked up in the module-level ``stoi`` table; a
    character that is not in the vocabulary raises ``KeyError``.
    """
    return [stoi[ch] for ch in s]
def decode(l):
    """Convert a sequence of integer token ids back into a string.

    Each id is looked up in the module-level ``itos`` table; an id that is
    not in the vocabulary raises ``KeyError``.
    """
    pieces = (itos[idx] for idx in l)
    return "".join(pieces)
# Split the corpus 90% / 10% by character position into train and
# validation text (no shuffling: the validation set is the tail).
n = len(data)
split_idx = int(n * 0.9)
train_data = data[:split_idx]
val_data = data[split_idx:]

# Token ids are stored as uint16, which can only represent ids 0..65535.
# Fail loudly here instead of letting larger ids wrap around (older NumPy)
# or surface as an opaque OverflowError during array construction.
max_ids = np.iinfo(np.uint16).max + 1
if vocab_size > max_ids:
    raise ValueError(
        f"vocab_size {vocab_size} does not fit in uint16 token ids "
        f"(max {max_ids})"
    )

train_ids = np.array(encode(train_data), dtype=np.uint16)
val_ids = np.array(encode(val_data), dtype=np.uint16)

# tofile writes raw bytes only; dtype and endianness are not recorded, so
# readers must load these files as uint16.
train_ids.tofile(os.path.join('data', 'chats', 'train.bin'))
val_ids.tofile(os.path.join('data', 'chats', 'val.bin'))

# Persist the vocabulary so ids in the .bin files can be decoded later.
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join('data', 'chats', 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)