| """ |
| Recursive Language Model - ZeroNet |
| """ |
| import torch |
| import torch.nn as nn |
|
|
| class RecursiveBlock(nn.Module): |
| """A recursive processing block that can be applied multiple times""" |
| def __init__(self, config): |
| super().__init__() |
| self.attention = nn.MultiheadAttention( |
| embed_dim=config.hidden_size, |
| num_heads=config.num_heads, |
| dropout=config.dropout, |
| batch_first=True |
| ) |
| self.feed_forward = nn.Sequential( |
| nn.Linear(config.hidden_size, config.hidden_size * 4), |
| nn.GELU(), |
| nn.Dropout(config.dropout), |
| nn.Linear(config.hidden_size * 4, config.hidden_size), |
| nn.Dropout(config.dropout) |
| ) |
| self.layer_norm1 = nn.LayerNorm(config.hidden_size) |
| self.layer_norm2 = nn.LayerNorm(config.hidden_size) |
| |
| def forward(self, x, mask=None): |
| attn_output, _ = self.attention(x, x, x, attn_mask=mask) |
| x = self.layer_norm1(x + attn_output) |
| ff_output = self.feed_forward(x) |
| x = self.layer_norm2(x + ff_output) |
| return x |
|
|
| class RecursiveLanguageModel(nn.Module): |
| """Recursive Language Model that recursively processes sequences""" |
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
| |
| self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) |
| self.position_embedding = nn.Embedding(config.max_seq_length, config.hidden_size) |
| self.dropout = nn.Dropout(config.dropout) |
| |
| self.recursive_block = RecursiveBlock(config) |
| self.layers = nn.ModuleList([ |
| RecursiveBlock(config) for _ in range(config.num_layers) |
| ]) |
| |
| self.layer_norm = nn.LayerNorm(config.hidden_size) |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| self.lm_head.weight = self.token_embedding.weight |
| |
| def forward(self, input_ids, attention_mask=None, labels=None): |
| batch_size, seq_length = input_ids.shape |
| |
| position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) |
| position_ids = position_ids.unsqueeze(0).expand_as(input_ids) |
| |
| token_embeds = self.token_embedding(input_ids) |
| position_embeds = self.position_embedding(position_ids) |
| hidden_states = self.dropout(token_embeds + position_embeds) |
| |
| causal_mask = torch.triu( |
| torch.ones(seq_length, seq_length, device=input_ids.device) * float('-inf'), |
| diagonal=1 |
| ) |
| |
| for _ in range(self.config.recursive_depth): |
| hidden_states = self.recursive_block(hidden_states, mask=causal_mask) |
| |
| for layer in self.layers: |
| hidden_states = layer(hidden_states, mask=causal_mask) |
| |
| hidden_states = self.layer_norm(hidden_states) |
| logits = self.lm_head(hidden_states) |
| |
| loss = None |
| if labels is not None: |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:].contiguous() |
| loss_fct = nn.CrossEntropyLoss() |
| loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) |
| |
| return {"loss": loss, "logits": logits} |
|
|