import torch
from torch import nn


class AberLanguageModel(nn.Module):
    """Token-level language model: embedding -> (multi-layer) GRU -> dropout -> linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # nn.GRU only applies dropout between stacked layers; passing a
            # non-zero value with a single layer just triggers a warning.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, idx, hidden=None, targets=None):
        """Run the model over a batch of token-id sequences.

        Args:
            idx: integer token ids laid out as (batch, seq) (the GRU is batch_first).
            hidden: optional GRU hidden state carried over from a previous call.
            targets: optional token ids with the same layout as ``idx``; when
                given, the mean cross-entropy over all positions is returned.

        Returns:
            ``(logits, hidden, loss)`` where ``logits`` is (batch, seq, vocab_size),
            ``hidden`` is the final GRU state, and ``loss`` is ``None`` when no
            targets were supplied.
        """
        emb = self.embedding(idx)
        out, hidden = self.gru(emb, hidden)
        out = self.dropout(out)
        logits = self.head(out)

        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )
        return logits, hidden, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, eos_id, temperature=1.0, top_k=8):
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``idx``.

        The whole prompt is run through the GRU on the first step so the hidden
        state reflects all of it; each later step feeds only the newest token
        together with the carried hidden state.

        Dropout is applied if the module is in training mode; call ``eval()``
        first for normal inference.

        Args:
            idx: (batch, seq) integer prompt token ids; must contain at least one token.
            max_new_tokens: maximum number of tokens to append.
            eos_id: generation stops once every row has produced this id.
                Rows that finish early keep sampling until the whole batch is
                done, so callers should truncate each row at its first ``eos_id``.
            temperature: logits are divided by this (floored at 1e-4).
            top_k: if a positive int, sample only among the ``top_k`` most likely tokens.

        Returns:
            ``idx`` with the sampled tokens appended along dim 1.
        """
        hidden = None
        step_input = idx  # full prompt first, then one token at a time
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            logits, hidden, _ = self(step_input, hidden)
            next_logits = logits[:, -1, :] / max(temperature, 1e-4)

            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))
                kth_value = torch.topk(next_logits, k).values[:, -1:]
                next_logits = next_logits.masked_fill(
                    next_logits < kth_value, float("-inf")
                )

            probs = torch.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=1)

            # Per-row bookkeeping so batches larger than one are supported.
            finished |= next_token.squeeze(1) == eos_id
            if bool(finished.all()):
                break
            step_input = next_token

        return idx