Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import tqdm | |
| from torch.nn import functional as F | |
| from core.layers import LlamaBlock, RMSNorm | |
class LlamaLanguageModel(nn.Module):
    """Decoder-only language model built from a stack of ``LlamaBlock`` layers.

    The model embeds token ids, passes the embeddings through ``n_layer``
    blocks, applies a final ``RMSNorm`` and projects to vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        n_embd: Embedding / hidden width.
        block_size: Context window; ``generate`` never feeds more than this
            many trailing tokens to the model.
        n_head: Forwarded to each ``LlamaBlock``.
        n_layer: Number of ``LlamaBlock`` layers.
        dropout: Forwarded to each ``LlamaBlock``.
        device: Stored on the instance as ``self.device``; this class does
            not move any tensors itself.
        name: Free-form model name stored as ``self.name``.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer, dropout, device, name="llama"):
        super().__init__()
        self.name = name
        self.block_size = block_size
        self.device = device
        self.vocab_size = vocab_size
        # Empty dict kept on the model; nothing in this class writes to it.
        self.history = {}

        # Sub-modules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.Sequential(
            *[LlamaBlock(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = RMSNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # Re-initialise every Linear / Embedding (including those nested in
        # the blocks) with the scheme in ``_init_weights``.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise ``module`` in place: N(0, 0.02) weights, zero biases.

        Only ``nn.Linear`` and ``nn.Embedding`` modules are touched; any other
        module type is left as it is.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, idx):
        """Run the model on a batch of token ids.

        Args:
            idx: Integer tensor of token ids with exactly two dimensions
                (batch, time).

        Returns:
            A tuple ``(logits, hidden)`` where ``hidden`` is the output of the
            final norm and ``logits`` is ``lm_head(hidden)``.

        Raises:
            ValueError: If ``idx`` is not two-dimensional.
        """
        if idx.dim() != 2:
            raise ValueError(
                f"idx must be a 2-D (batch, time) tensor, got shape {tuple(idx.shape)}"
            )

        # No key/value cache is maintained by this model; every block is
        # called with ``None`` in the cache slot.
        kv_cache = None

        hidden = self.token_embedding_table(idx)
        # ``self.blocks`` is an nn.Sequential, but it is iterated manually
        # because each block takes the extra ``kv_cache`` argument.
        for block in self.blocks:
            hidden = block(hidden, kv_cache)
        hidden = self.ln_f(hidden)
        logits = self.lm_head(hidden)
        return logits, hidden

    def generate(self, idx, max_new_tokens, max_seq_length=200, temperature=1.0):
        """Autoregressively extend ``idx``, yielding the sequence after each step.

        This is a generator: nothing runs until it is iterated, and it yields
        ``max_new_tokens`` times.

        Args:
            idx: Integer tensor of token ids, shape (batch, time), used as the
                prompt.
            max_new_tokens: Number of tokens to sample (and number of yields).
            max_seq_length: Before each step the running sequence is trimmed
                to its last ``max_seq_length`` tokens, so a yielded tensor is
                at most ``max_seq_length + 1`` tokens long and may no longer
                contain the start of the prompt.
            temperature: Divisor applied to the last-position logits before
                softmax sampling. ``0`` selects greedy (argmax) decoding
                instead of dividing by zero.

        Yields:
            The running token tensor after each newly appended token.

        Note:
            The module's train/eval mode is not changed here; call
            ``model.eval()`` first if dropout should be disabled.
        """
        for _ in range(max_new_tokens):
            # Keep the running sequence bounded.
            if idx.size(1) > max_seq_length:
                idx = idx[:, -max_seq_length:]

            # The model only ever sees the trailing ``block_size`` tokens.
            idx_cond = idx[:, -self.block_size:]

            # Sampling needs no gradients. The context manager is closed
            # before ``yield`` so the caller's grad mode is never affected.
            with torch.no_grad():
                logits, _ = self(idx_cond)
            last_logits = logits[:, -1, :]

            if temperature == 0:
                # Greedy decoding: the zero-temperature limit of sampling.
                idx_next = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)
            yield idx