|
|
| import torch |
| import argparse |
| from torch.nn import functional as F |
| import time |
| from attention_head import AttentionHead,Head, MultiHeadAttention, TransFormerBlock |
| torch.manual_seed(1337) |
|
|
def get_batch(batch_size, dataset, block_size):
    """Sample a random batch of next-element prediction pairs from ``dataset``.

    Every input row is a window of ``block_size`` consecutive elements that
    starts at a uniformly random offset; the matching target row is the same
    window shifted one element to the right.

    Args:
        batch_size: Number of windows to draw.
        dataset: 1-D sequence of integer ids (the callers in this file pass a
            1-D integer tensor).
        block_size: Length of every window.

    Returns:
        A tuple ``(xb, yb)`` of ``torch.long`` tensors, each of shape
        ``(batch_size, block_size)``, where ``yb[i, t]`` is the element that
        follows ``xb[i, t]`` in ``dataset``.

    Raises:
        ValueError: If ``dataset`` holds fewer than ``block_size + 1``
            elements, i.e. not even one input/target pair fits.
    """
    data = torch.as_tensor(dataset)

    # A start offset s is valid when the target window s+1 .. s+block_size
    # still fits, i.e. s <= len - block_size - 1. randint's upper bound is
    # exclusive, so ``len - block_size`` makes the final window reachable.
    n_starts = len(data) - block_size
    if n_starts < 1:
        raise ValueError(
            f"dataset of length {len(data)} is too short for "
            f"block_size={block_size}; need at least {block_size + 1} elements"
        )

    starts = torch.randint(high=n_starts, size=(batch_size, 1))
    # Broadcasting (batch_size, 1) + (block_size,) yields one row of
    # consecutive indices per sampled start offset.
    window = starts + torch.arange(block_size)
    xb = data[window].to(torch.long)
    yb = data[window + 1].to(torch.long)
    return xb, yb
|
|
@torch.no_grad()
def eval(model, batch_size, block_size, dataset):
    """Return the model's loss on one randomly sampled batch of ``dataset``.

    The model is put into evaluation mode for the forward pass so that
    train-only layers (e.g. dropout) do not perturb the measurement, and its
    previous train/eval mode is restored afterwards. Gradients are disabled
    for the whole call.

    NOTE: the name shadows the ``eval`` builtin; it is kept because ``train``
    in this file calls it by this name.

    Args:
        model: Module called as ``model(xb, yb)`` and returning
            ``(logits, loss)``.
        batch_size: Number of windows in the evaluation batch.
        block_size: Length of every window.
        dataset: Sequence of ids to sample the batch from.

    Returns:
        The loss of the sampled batch as a Python float.
    """
    was_training = model.training
    model.eval()
    try:
        xb, yb = get_batch(batch_size, dataset, block_size)
        _, loss = model(xb, yb)
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return loss.item()
|
|
def train(model, optimizer, batch_size, block_size, train_ds, val_ds, steps,
          log_interval=1000, eval_batch_size=30):
    """Run ``steps`` optimisation steps and periodically report the losses.

    Every ``log_interval`` steps the mean training loss over that interval
    and the loss of one freshly sampled validation batch are printed, after
    which the running training loss is reset.

    Args:
        model: Module called as ``model(xb, yb)`` and returning
            ``(logits, loss)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        batch_size: Number of windows per training batch.
        block_size: Length of every window.
        train_ds: Sequence of ids that training batches are sampled from.
        val_ds: Sequence of ids that validation batches are sampled from.
        steps: Total number of optimisation steps to run.
        log_interval: Number of steps between two progress reports.
        eval_batch_size: Number of windows in each validation batch.

    Raises:
        ValueError: If ``log_interval`` is not positive.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be positive, got {log_interval}")

    running_loss = 0.0
    for step in range(1, steps + 1):
        xb, yb = get_batch(batch_size, train_ds, block_size)
        _, loss = model(xb, yb)
        running_loss += loss.item()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            val_loss = eval(model, eval_batch_size, block_size, val_ds)
            print(f"step {step} || train loss: {running_loss/log_interval} , val loss: {val_loss}")
            # Start a fresh average for the next reporting interval.
            running_loss = 0.0
|
|
class Transformer(torch.nn.Module):
    """Language model built from embeddings, ``TransFormerBlock``s and a head.

    Token and position embeddings are summed, passed through ``n_tf``
    sequentially applied ``TransFormerBlock`` modules, and projected to one
    logit per vocabulary entry.
    """

    def __init__(self, vocab_size, n_tf=3, block_size=8, token_embed_dim=16) -> None:
        """Build the embeddings, the block stack and the output projection.

        Args:
            vocab_size: Number of distinct token ids.
            n_tf: Number of ``TransFormerBlock`` modules to stack.
            block_size: Maximum sequence length the model can attend over.
            token_embed_dim: Width of the token and position embeddings.
        """
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        # NOTE(review): the trailing 16 and 8 are presumably head_size and
        # n_heads of TransFormerBlock -- confirm against attention_head.py.
        self.tf_blocks = torch.nn.Sequential(
            *[TransFormerBlock(token_embed_dim, block_size, 16, 8) for _ in range(n_tf)]
        )
        # The head consumes the output of the block stack. The blocks are
        # chained, so they are assumed to preserve the embedding width; tie
        # the head to ``token_embed_dim`` rather than a fixed 128 so widths
        # other than 128 work as well.
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional integer tensor of next-token ids with shape
                ``(B, T)``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to ``(B*T, vocab_size)`` and ``loss`` is
            the cross-entropy against the flattened targets.
        """
        _, T = idx.shape
        token_embed = self.token_embedding_table(idx)
        # Create the position indices on the input's device so the model also
        # works when it has been moved off the CPU.
        positions = torch.arange(T, device=idx.device)
        x = token_embed + self.positional_embedding(positions)
        x = self.tf_blocks(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Sampling runs without gradient tracking so no autograd graph is
        accumulated across the generated tokens.

        Args:
            idx: Integer tensor of context token ids with shape ``(B, T)``.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape ``(B, T + max_new_tokens)`` holding the
            original context followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Only the most recent ``block_size`` tokens fit the position table.
            logits, _loss = self(idx[:, -self.block_size:])
            # Keep the prediction for the last position only.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
class BigramLanguageModel(torch.nn.Module):
    """Language model with embeddings, one multi-head attention layer and a head.

    Token and position embeddings are summed, passed through a single
    ``MultiHeadAttention`` module with four heads, and projected to one logit
    per vocabulary entry.
    """

    def __init__(self, vocab_size, block_size=8, token_embed_dim=16):
        """Build the embeddings, the attention layer and the output projection.

        Args:
            vocab_size: Number of distinct token ids.
            block_size: Maximum sequence length the model can attend over.
            token_embed_dim: Width of the token and position embeddings; it
                is split evenly across the four attention heads.
        """
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = torch.nn.Embedding(vocab_size, token_embed_dim)
        self.positional_embedding = torch.nn.Embedding(block_size, token_embed_dim)
        self.attention_head = MultiHeadAttention(
            n_embed=token_embed_dim,
            timesteps=block_size,
            head_size=token_embed_dim // 4,
            n_heads=4,
        )
        self.lm_head = torch.nn.Linear(token_embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional integer tensor of next-token ids with shape
                ``(B, T)``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(B, T, vocab_size)`` and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to ``(B*T, vocab_size)`` and ``loss`` is
            the cross-entropy against the flattened targets.
        """
        _, T = idx.shape
        token_embedding = self.token_embedding_table(idx)
        # Create the position indices on the input's device so the model also
        # works when it has been moved off the CPU.
        positions = torch.arange(T, dtype=torch.long, device=idx.device)
        x = token_embedding + self.positional_embedding(positions)
        x = self.attention_head(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Sampling runs without gradient tracking so no autograd graph is
        accumulated across the generated tokens.

        Args:
            idx: Integer tensor of context token ids with shape ``(B, T)``.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape ``(B, T + max_new_tokens)`` holding the
            original context followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Only the most recent ``block_size`` tokens fit the position table.
            logits, _loss = self(idx[:, -self.block_size:])
            # Keep the prediction for the last position only.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
def main():
    """Train (or just load) the character-level Transformer and print a sample.

    Command-line flags:
        -c / --cont     load weights from ``transformer.pth`` before running.
        -e / --eval     skip training and only generate text.
        -v / --verbose  print the model, its parameter count and the
                        training time.

    Reads the corpus from ``input.txt``, uses the first 80% of it for
    training and the rest for validation, saves the weights to
    ``transformer.pth`` after training (or when training is interrupted with
    Ctrl-C), and finally prints 1000 sampled characters.
    """
    batch_size = 32
    block_size = 128
    n_embed = 128
    n_tf = 3
    checkpoint_path = 'transformer.pth'
    prompt_len = 32

    parser = argparse.ArgumentParser(
        description='Train a bigram language model'
    )
    parser.add_argument('-c', '--cont', action='store_true')
    parser.add_argument('-e', '--eval', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    # Parse once, up front, so bad flags and --help are handled before any
    # file is read or any model is built.
    args = parser.parse_args()

    with open('input.txt', encoding='utf-8') as corpus_file:
        text = corpus_file.read()

    # Character-level vocabulary: id -> char and char -> id lookups.
    characters = sorted(set(text))
    decoder = dict(enumerate(characters))
    encoder = {char: token_id for token_id, char in decoder.items()}
    # Size the model from the corpus itself instead of assuming 65 symbols.
    vocab_size = len(characters)

    text_tensor = torch.tensor([encoder[char] for char in text])
    split_at = int(len(text_tensor) * 0.8)
    train_tensor = text_tensor[:split_at]
    val_tensor = text_tensor[split_at:]

    model = Transformer(vocab_size=vocab_size, n_tf=n_tf, block_size=block_size, token_embed_dim=n_embed)
    if args.verbose:
        print(model)
        num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('parameters:', num_params)

    if args.cont:
        state_dict = torch.load(checkpoint_path)
        model.load_state_dict(state_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
    started = time.time()
    if not args.eval:
        try:
            train(model, optimizer, batch_size=batch_size, block_size=block_size,
                  train_ds=train_tensor, val_ds=val_tensor, steps=100000)
        except KeyboardInterrupt:
            # Keep the progress made so far, then stop without sampling.
            torch.save(model.state_dict(), checkpoint_path)
            raise SystemExit
        if args.verbose:
            print('training time: ', time.time() - started)
        torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    # Prompt with ``prompt_len`` zero tokens and print only the continuation.
    prompt = torch.zeros(1, prompt_len, dtype=torch.long)
    generated = model.generate(prompt, 1000)[0].tolist()[prompt_len:]
    print(''.join(decoder[token_id] for token_id in generated))


if __name__ == '__main__':
    main()
|
|
|
|