Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
class AberLanguageModel(nn.Module):
    """GRU language model: embedding -> (stacked) GRU -> dropout -> linear head.

    The GRU is built with ``batch_first=True``, so token tensors are laid out
    as ``(batch, seq)`` and logits as ``(batch, seq, vocab_size)``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        """Build the layers.

        Args:
            vocab_size: Number of distinct token ids (embedding rows and
                output logits).
            embed_dim: Size of each token embedding.
            hidden_dim: GRU hidden size.
            num_layers: Number of stacked GRU layers.
            dropout: Dropout probability applied to the GRU output, and
                between GRU layers when there is more than one layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # nn.GRU only applies dropout *between* layers and warns when a
            # non-zero value is given for a single-layer network.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, idx, hidden=None, targets=None):
        """Run the model over a batch of token ids.

        Args:
            idx: Integer tensor of token ids, ``(batch, seq)``.
            hidden: Optional GRU state carried over from a previous call;
                ``None`` lets the GRU start from its default (zero) state.
            targets: Optional integer tensor with one target id per position
                of ``idx``. When given, the mean cross-entropy over all
                positions is returned as ``loss``.

        Returns:
            A tuple ``(logits, hidden, loss)`` where ``logits`` is
            ``(batch, seq, vocab_size)``, ``hidden`` is the updated GRU state
            and ``loss`` is a scalar tensor or ``None`` if no targets were
            supplied.
        """
        emb = self.embedding(idx)
        out, hidden = self.gru(emb, hidden)
        out = self.dropout(out)
        logits = self.head(out)

        loss = None
        if targets is not None:
            # cross_entropy wants (N, C) logits against (N,) class ids, so
            # the batch and time axes are flattened together.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, hidden, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, eos_id, temperature=1.0, top_k=8):
        """Sample a continuation of ``idx`` one token at a time.

        The whole prompt is fed through the GRU on the first step so that the
        recurrent state reflects every prompt token; later steps feed only
        the newly sampled token together with the carried state.

        Sampling runs under ``torch.no_grad()``. The module's train/eval mode
        is not changed here, so call ``model.eval()`` first if dropout should
        be disabled while generating.

        Args:
            idx: Integer tensor of prompt token ids, ``(batch, seq)`` with at
                least one token per row.
            max_new_tokens: Upper bound on the number of tokens appended.
            eos_id: End-of-sequence token id. A row that samples it is
                finished and is padded with ``eos_id`` from then on;
                generation stops as soon as every row is finished. ``None``
                (or an id that is never sampled) disables early stopping.
            temperature: Softmax temperature; clamped below at ``1e-4`` to
                avoid division by zero.
            top_k: If a positive integer, sampling is restricted to the
                ``top_k`` highest-scoring tokens. ``None`` or ``0`` samples
                from the full distribution.

        Returns:
            ``idx`` with the sampled tokens concatenated along dim 1.
        """
        hidden = None
        step_input = idx  # first step consumes the entire prompt
        scale = max(temperature, 1e-4)
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            logits, hidden, _ = self(step_input, hidden)
            next_logits = logits[:, -1, :] / scale

            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                # Smallest of each row's k best scores; anything strictly
                # below it is removed from the sampling distribution.
                kth_best = torch.topk(next_logits, k).values[:, -1:]
                next_logits = next_logits.masked_fill(
                    next_logits < kth_best, float("-inf")
                )

            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            if eos_id is not None:
                # Rows that already ended keep emitting EOS as padding so the
                # batch stays rectangular.
                next_token = next_token.masked_fill(finished.unsqueeze(1), eos_id)
                finished = finished | (next_token.squeeze(1) == eos_id)

            idx = torch.cat([idx, next_token], dim=1)
            if eos_id is not None and bool(finished.all()):
                break
            step_input = next_token

        return idx