import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RecursiveLanguageModelConfig(PretrainedConfig):
    """Configuration for the recursive language model.

    The *_recursion_steps fields map the router's three difficulty classes
    (simple / medium / complex) to the number of extra recursion passes.
    """

    model_type = "recursive_language_model"

    def __init__(
        self,
        vocab_size=50260,
        embedding_dim=768,
        num_layers=16,
        num_attention_heads=12,
        max_recursion_steps=5,
        max_position_embeddings=512,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        intermediate_size=3072,
        layer_norm_eps=1e-5,
        pad_token_id=50257,
        bos_token_id=50256,
        eos_token_id=50256,
        simple_recursion_steps=1,
        medium_recursion_steps=3,
        complex_recursion_steps=5,
        initializer_range=0.02,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.max_recursion_steps = max_recursion_steps
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.simple_recursion_steps = simple_recursion_steps
        self.medium_recursion_steps = medium_recursion_steps
        self.complex_recursion_steps = complex_recursion_steps
        self.initializer_range = initializer_range
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding (RoPE): builds cos/sin tables for a given sequence length.

    max_seq_len is kept for interface compatibility; the tables are computed on the fly in forward.
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()


def rotate_half(x):
    # Rotate halves of the last dimension: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # Broadcast the (seq_len, head_dim) tables over batch and head dimensions.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.embedding_dim // config.num_attention_heads
        self.embed_dim = config.embedding_dim
        assert self.embed_dim % self.num_heads == 0, "embedding_dim must be divisible by num_attention_heads"
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.attn_drop = nn.Dropout(config.attention_dropout_prob)
        self.rope = RotaryPositionalEmbedding(self.head_dim, config.max_position_embeddings)

    def forward(self, x, causal_mask=None):
        B, T, C = x.shape
        # Project and reshape to (B, num_heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rope(T, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        if causal_mask is not None:
            scores = scores + causal_mask
        # Clamp scores and zero out non-finite weights from fully masked rows.
        scores = scores.clamp(min=-1e4, max=1e4)
        attn = F.softmax(scores, dim=-1)
        attn = torch.nan_to_num(attn, nan=0.0, posinf=0.0, neginf=0.0)
        attn = self.attn_drop(attn)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class FeedForward(nn.Module):
    """Position-wise MLP with GELU activation and dropout."""

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.embedding_dim, config.intermediate_size, bias=False)
        self.fc2 = nn.Linear(config.intermediate_size, config.embedding_dim, bias=False)
        self.drop = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(F.gelu(self.fc1(x)))))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.ln2 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x
class SequenceLevelRouter(nn.Module):
    """Classifies each sequence as simple / medium / complex and maps the class to a recursion-step count."""

    def __init__(self, config):
        super().__init__()
        self.pooler = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.act = nn.Tanh()
        self.head = nn.Sequential(
            nn.Linear(config.embedding_dim, config.embedding_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(config.embedding_dim // 2, 3)
        )
        self.register_buffer("steps_map", torch.tensor([
            config.simple_recursion_steps,
            config.medium_recursion_steps,
            config.complex_recursion_steps,
        ], dtype=torch.long))

    def forward(self, x, valid_mask=None):
        # Mean-pool hidden states over non-padding positions.
        if valid_mask is not None:
            m = valid_mask.unsqueeze(-1).float()
            pooled = (x * m).sum(1) / m.sum(1).clamp(min=1e-9)
        else:
            pooled = x.mean(1)
        pooled = self.act(self.pooler(pooled))
        logits = self.head(pooled)
        cls = logits.argmax(dim=-1)
        return logits, cls, self.steps_map[cls]


class RecursionLayer(nn.Module):
    """A single shared transformer block that is applied repeatedly during recursion."""

    def __init__(self, config):
        super().__init__()
        self.block = TransformerBlock(config)

    def forward(self, x, mask=None):
        return self.block(x, mask)
class RecursiveLanguageModel(PreTrainedModel):
    config_class = RecursiveLanguageModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.embedding_dim,
                                  padding_idx=config.pad_token_id)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.router = SequenceLevelRouter(config)
        self.rec_layer = RecursionLayer(config)
        self.ln_f = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, v):
        self.embed = v

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, v):
        self.lm_head = v

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def _make_causal_mask(self, input_ids):
        # Additive mask: 0 for allowed positions, -1e4 for future and padding positions.
        B, T = input_ids.shape
        device = input_ids.device
        mask = torch.zeros(T, T, device=device)
        mask = mask.masked_fill(
            torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1),
            -1e4
        )
        mask = mask.unsqueeze(0).unsqueeze(0)
        pad_mask = (input_ids == self.config.pad_token_id)
        valid_mask = ~pad_mask
        if pad_mask.any():
            pad_key_mask = pad_mask.unsqueeze(1).unsqueeze(2).float() * -1e4
            mask = mask + pad_key_mask
        return mask, valid_mask
    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
        # Padding is inferred from pad_token_id; an explicit attention_mask is accepted but unused.
        B, T = input_ids.shape
        x = self.embed(input_ids)
        causal_mask, valid_mask = self._make_causal_mask(input_ids)
        # Fixed transformer stack.
        for layer in self.layers:
            x = layer(x, causal_mask)
        # Route each sequence to a number of extra recursion passes.
        router_logits, cls, steps = self.router(x, valid_mask)
        max_steps = int(steps.max().item())
        for s in range(max_steps):
            # Sequences whose assigned step count is exhausted pass through unchanged.
            gate = (steps > s).float().view(B, 1, 1)
            x = gate * self.rec_layer(x, causal_mask) + (1 - gate) * x
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard shifted next-token prediction loss.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            lm_loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
            # Router pseudo-labels derived from per-sequence perplexity (no gradient).
            with torch.no_grad():
                per_tok = F.cross_entropy(
                    shift_logits.view(-1, self.config.vocab_size),
                    shift_labels.view(-1),
                    ignore_index=-100, reduction="none"
                ).view(B, -1)
                valid_tok = (shift_labels != -100).sum(1).clamp(min=1).float()
                ppl = torch.exp((per_tok.sum(1) / valid_tok).clamp(max=20))
                pseudo = torch.zeros(B, dtype=torch.long, device=input_ids.device)
                pseudo[(ppl >= 20) & (ppl < 50)] = 1
                pseudo[ppl >= 50] = 2
            router_loss = F.cross_entropy(router_logits, pseudo)
            loss = lm_loss + 0.1 * router_loss
        return CausalLMOutputWithPast(loss=loss, logits=logits)
    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, temperature=0.8,
                 top_p=0.9, do_sample=True, **kwargs):
        self.eval()
        gen = input_ids
        for _ in range(max_new_tokens):
            # Truncate the context to the maximum supported length.
            ctx = gen[:, -self.config.max_position_embeddings:]
            logits = self.forward(ctx).logits[:, -1, :]
            if temperature != 1.0:
                logits = logits / temperature
            if do_sample:
                # Nucleus (top-p) sampling: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p, always including the top token.
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                remove = cum_probs - sorted_probs > top_p
                sorted_probs[remove] = 0.0
                sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
                next_tok = torch.gather(sorted_idx, -1, torch.multinomial(sorted_probs, 1))
            else:
                next_tok = logits.argmax(dim=-1, keepdim=True)
            gen = torch.cat([gen, next_tok], dim=-1)
            if (next_tok == self.config.eos_token_id).all():
                break
        return gen
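

# Minimal usage sketch (an assumption, not part of the model definition above): build a
# deliberately tiny config, run one supervised forward pass, and sample a short
# continuation. The hyperparameters and random token ids below are illustrative only.
if __name__ == "__main__":
    cfg = RecursiveLanguageModelConfig(
        vocab_size=50260,
        embedding_dim=64,
        num_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=64,
    )
    model = RecursiveLanguageModel(cfg)

    ids = torch.randint(0, 50256, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    sample = model.generate(ids[:, :4], max_new_tokens=8)
    print("generated shape:", tuple(sample.shape))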