# zeronet/modeling_rlm.py
"""
Recursive Language Model - ZeroNet
"""
import torch
import torch.nn as nn
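
# The classes below read their hyperparameters off a `config` object. The real
# config class presumably lives elsewhere in the repo; the dataclass below is
# only an illustrative sketch of the attributes this module actually accesses,
# with hypothetical default values.
from dataclasses import dataclass


@dataclass
class RLMConfigSketch:
    vocab_size: int = 32000      # size of the token embedding table
    hidden_size: int = 512       # model width
    num_heads: int = 8           # attention heads (must divide hidden_size)
    num_layers: int = 4          # per-layer (non-shared) blocks
    recursive_depth: int = 3     # times the shared block is re-applied
    max_seq_length: int = 1024   # learned position embedding capacity
    dropout: float = 0.1
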
class RecursiveBlock(nn.Module):
    """A recursive processing block that can be applied multiple times."""

    def __init__(self, config):
        super().__init__()
        # Self-attention over the sequence (batch_first, so inputs are [B, T, H]).
        self.attention = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        # Standard transformer MLP with a 4x hidden expansion.
        self.feed_forward = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
            nn.Dropout(config.dropout),
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, mask=None):
        # Post-norm residual attention: attend, add, then normalize.
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.layer_norm1(x + attn_output)
        # Post-norm residual feed-forward.
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x
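
# A minimal smoke test of the block on random activations (shapes only; the
# config values are illustrative, taken from the RLMConfigSketch stand-in):
#
#     config = RLMConfigSketch()
#     block = RecursiveBlock(config)
#     x = torch.randn(2, 16, config.hidden_size)   # [batch, seq, hidden]
#     assert block(x).shape == x.shape             # block preserves shape
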
class RecursiveLanguageModel(nn.Module):
    """Recursive Language Model that recursively processes sequences."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(config.max_seq_length, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        # A single weight-shared block, applied recursive_depth times in forward().
        self.recursive_block = RecursiveBlock(config)
        # A conventional stack of per-layer (non-shared) blocks applied afterwards.
        self.layers = nn.ModuleList([
            RecursiveBlock(config) for _ in range(config.num_layers)
        ])
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids, attention_mask=None, labels=None):
        # NOTE: attention_mask is accepted for API compatibility but is not
        # currently combined with the causal mask built below.
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        token_embeds = self.token_embedding(input_ids)
        position_embeds = self.position_embedding(position_ids)
        hidden_states = self.dropout(token_embeds + position_embeds)
        # Additive causal mask: -inf above the diagonal blocks attention to
        # future positions; zeros elsewhere leave past positions visible.
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device),
            diagonal=1,
        )
        # Apply the shared block recursively, then the per-layer stack.
        for _ in range(self.config.recursive_depth):
            hidden_states = self.recursive_block(hidden_states, mask=causal_mask)
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=causal_mask)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Standard next-token objective: shift so position t predicts t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return {"loss": loss, "logits": logits}