| """ |
| HYDRA: Hybrid Dynamic Recurrent Architecture |
| Novel non-transformer combining Mamba SSM, Griffin RG-LRU, RWKV mixing. |
| """ |
| import math, torch, torch.nn as nn, torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
@dataclass
class HydraConfig:
    """Hyperparameters for the HYDRA model.

    ``d_inner`` may be left as ``None``; ``__post_init__`` then resolves it
    to ``2 * d_model``.
    """

    vocab_size: int = 50257      # rows of the embedding table / width of the logits
    d_model: int = 512           # width of the residual stream
    n_layers: int = 8            # number of stacked blocks
    d_state: int = 16            # recurrent state size per scale
    d_conv: int = 4              # kernel size of the depthwise convolution
    d_inner: Optional[int] = None  # recurrence branch width; None -> 2 * d_model
    mlp_expand: int = 3          # channel-mixing hidden width = mlp_expand * d_model
    n_scales: int = 3            # number of parallel recurrences
    dropout: float = 0.1
    max_seq_len: int = 1024      # context window kept during generation
    pad_token_id: int = 50256    # label value ignored by the loss
    tie_weights: bool = True     # share the embedding matrix with the output head

    def __post_init__(self) -> None:
        """Fill in the derived default for ``d_inner``."""
        if self.d_inner is None:
            self.d_inner = 2 * self.d_model
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is divided by ``sqrt(mean(x**2) + eps)`` computed along the
    last axis and then multiplied by a learned per-feature ``weight``
    (initialised to ones).
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x`` rescaled to (approximately) unit RMS, times ``weight``."""
        mean_square = torch.mean(x ** 2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        return x / rms * self.weight
|
|
class SelectiveGatedRecurrence(nn.Module):
    """Multi-scale, input-gated diagonal linear recurrence (temporal mixing).

    The input is projected into two ``d_inner``-wide branches.  One branch
    goes through a causal depthwise convolution and SiLU and drives
    ``n_scales`` independent recurrences of size ``d_state``::

        a_t = sigmoid(Lambda) ** (gate_c * r_t)          # per-step decay
        h_t = a_t * h_{t-1} + sqrt(1 - a_t**2) * i_t * B_t

    where ``r_t`` and ``i_t`` are sigmoid gates and ``B_t`` a linear
    projection of the convolved branch.  The states of all scales are
    concatenated, projected back to ``d_inner``, multiplied by the GELU of
    the second branch and projected to ``d_model``.

    NOTE(review): ``C_proj`` is created (and kept, so existing state dicts
    still load) but its output never contributed to the result -- the
    original forward computed it and threw it away.  Its parameters
    therefore receive no gradient.  A readout such as ``h * C_t`` may have
    been intended; confirm with the author before changing the math.
    """

    def __init__(self, config):
        super().__init__()
        self.d_inner = config.d_inner
        self.d_state = config.d_state
        self.n_scales = config.n_scales
        self.gate_c = 8.0  # fixed multiplier applied to the gated log-decay

        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=False)
        # Depthwise conv; padding d_conv-1 plus truncation in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            config.d_inner,
            config.d_inner,
            config.d_conv,
            padding=config.d_conv - 1,
            groups=config.d_inner,
        )
        self.B_proj = nn.ModuleList(
            [nn.Linear(config.d_inner, config.d_state, bias=False) for _ in range(config.n_scales)]
        )
        self.C_proj = nn.ModuleList(
            [nn.Linear(config.d_inner, config.d_state, bias=False) for _ in range(config.n_scales)]
        )
        self.recurrence_gate = nn.ModuleList(
            [nn.Linear(config.d_inner, config.d_state, bias=True) for _ in range(config.n_scales)]
        )
        self.input_gate = nn.ModuleList(
            [nn.Linear(config.d_inner, config.d_state, bias=True) for _ in range(config.n_scales)]
        )

        # Lambda[s] stores logits of the base decay for scale s, drawn so that
        # sigmoid(Lambda[s]) lies roughly in [0.9, 0.999] ** (1 / (s + 1)).
        self.Lambda = nn.ParameterList()
        for s in range(config.n_scales):
            low_a = 0.9 ** (1.0 / (s + 1))
            high_a = 0.999 ** (1.0 / (s + 1))
            init_val = torch.empty(config.d_state).uniform_(
                math.log(low_a / (1 - low_a + 1e-8)),
                math.log(high_a / (1 - high_a + 1e-8)),
            )
            self.Lambda.append(nn.Parameter(init_val))

        self.scale_fusion = nn.Linear(config.n_scales * config.d_state, config.d_inner, bias=False)
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the recurrence to ``x`` of shape ``(batch, seq_len, d_model)``.

        Returns a tensor of the same shape.
        """
        batch, seq_len, _ = x.shape
        x_branch, z_branch = self.in_proj(x).chunk(2, dim=-1)
        # Keep only the first seq_len conv outputs so step t sees inputs <= t.
        x_conv = self.conv1d(x_branch.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        z_gate = F.gelu(z_branch)

        # Per-scale decay and driving input, each (batch, seq_len, d_state).
        decays = []
        drives = []
        for s in range(self.n_scales):
            b_t = self.B_proj[s](x_conv)
            r_t = torch.sigmoid(self.recurrence_gate[s](x_conv))
            i_t = torch.sigmoid(self.input_gate[s](x_conv))
            log_a_base = torch.log(torch.sigmoid(self.Lambda[s]).clamp(min=1e-8))
            a_t = torch.exp(self.gate_c * r_t * log_a_base)
            decays.append(a_t)
            # sqrt(1 - a^2) scales the input down as the decay approaches 1.
            drives.append(torch.sqrt((1 - a_t ** 2).clamp(min=1e-8)) * i_t * b_t)

        # The recurrence is elementwise, so all scales can share one time loop
        # over the concatenated state; the result equals scanning each scale
        # separately and concatenating afterwards.
        decay = torch.cat(decays, dim=-1)
        drive = torch.cat(drives, dim=-1)
        h = torch.zeros(batch, self.n_scales * self.d_state, device=x.device, dtype=x.dtype)
        states = []
        for a_step, u_step in zip(decay.unbind(dim=1), drive.unbind(dim=1)):
            h = a_step * h + u_step
            states.append(h)
        multi_scale = torch.stack(states, dim=1)

        fused = self.scale_fusion(multi_scale)
        output = self.out_proj(fused * z_gate)
        return self.dropout(output)
|
|
class GatedChannelMixing(nn.Module):
    """Gated feed-forward layer: ``down(gelu(gate(x)) * up(x))`` with dropout.

    The hidden width is ``config.mlp_expand * config.d_model``; all three
    projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.mlp_expand * config.d_model
        self.gate_proj = nn.Linear(config.d_model, hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.d_model, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the gated projection of ``x``; the last dimension stays ``d_model``."""
        gate = F.gelu(self.gate_proj(x))
        value = self.up_proj(x)
        mixed = self.down_proj(gate * value)
        return self.dropout(mixed)
|
|
class HydraBlock(nn.Module):
    """Pre-norm residual block: temporal mixing followed by channel mixing.

    Each sub-layer reads a normalised copy of the residual stream and adds
    its output back onto that stream::

        h   = x + temporal_mix(norm1(x))
        out = h + channel_mix(norm2(h))
    """

    def __init__(self, config):
        super().__init__()
        self.norm1 = RMSNorm(config.d_model)
        self.temporal_mix = SelectiveGatedRecurrence(config)
        self.norm2 = RMSNorm(config.d_model)
        self.channel_mix = GatedChannelMixing(config)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        h = x + self.temporal_mix(self.norm1(x))
        # The second residual is taken from h, not from the block input, so
        # the temporal-mix output stays on the residual stream instead of
        # only reaching later layers through channel_mix.
        # NOTE: checkpoints trained with the residual taken from x will
        # produce different outputs with this form.
        return h + self.channel_mix(self.norm2(h))
|
|
class HydraModel(nn.Module):
    """Language model: embedding -> stack of HydraBlock -> RMSNorm -> LM head.

    ``forward`` returns a dict with ``"logits"`` and, when labels are given,
    a next-token cross-entropy ``"loss"``.  ``generate`` samples tokens
    autoregressively by re-running the full forward pass on the current
    context at every step.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([HydraBlock(config) for _ in range(config.n_layers)])
        self.final_norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_weights:
            # Output head and embedding share one parameter tensor.
            self.lm_head.weight = self.embed.weight
        self._init_weights()

    def _init_weights(self):
        """Initialise Linear, Embedding and Conv1d weights; zero their biases.

        Linear weights use std ``1 / sqrt(fan_in)``; Embedding and Conv1d
        weights use std 0.02.

        NOTE(review): ``embed`` is registered before ``lm_head``, so with
        tied weights the shared matrix appears to end up with the Linear
        initialisation rather than the 0.02 embedding one -- confirm this
        is intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 1.0 / math.sqrt(m.weight.shape[1]))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, 0, 0.02)
            elif isinstance(m, nn.Conv1d):
                nn.init.normal_(m.weight, 0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, input_ids, labels=None):
        """Compute logits (and optionally the loss) for ``input_ids``.

        Args:
            input_ids: integer token ids, indexed as ``(batch, seq_len)``.
            labels: optional token ids of the same shape.  Position ``t`` of
                the logits is scored against ``labels[:, t + 1]``; entries
                equal to ``config.pad_token_id`` are ignored.

        Returns:
            ``{"logits": ...}`` plus ``"loss"`` when ``labels`` is given.
        """
        x = self.embed_dropout(self.embed(input_ids))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.final_norm(x))
        result = {"logits": logits}
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[:, 1:].contiguous().view(-1)
            result["loss"] = F.cross_entropy(
                shift_logits, shift_labels, ignore_index=self.config.pad_token_id
            )
        return result

    def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50):
        """Append ``max_new_tokens`` sampled tokens to ``input_ids``.

        Args:
            input_ids: prompt token ids, ``(batch, seq_len)``.
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling;
                ``0`` selects the arg-max token (greedy decoding).
            top_k: keep only the ``top_k`` highest logits before sampling;
                ``0`` or ``None`` disables the filter.  Values larger than
                the vocabulary are clamped.

        Returns:
            ``input_ids`` with the new tokens concatenated on the last axis.

        The model runs in eval mode under ``torch.no_grad()`` while sampling
        and is put back into whatever mode it was in before the call.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Only the most recent max_seq_len tokens are fed back in.
                    ctx = input_ids[:, -self.config.max_seq_len:]
                    logits = self.forward(ctx)["logits"][:, -1, :]
                    if temperature == 0:
                        # Dividing by zero would yield inf/nan logits.
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        logits = logits / temperature
                        if top_k is not None and top_k > 0:
                            k = min(top_k, logits.size(-1))
                            kth_value = torch.topk(logits, k).values[:, -1:]
                            logits = logits.masked_fill(logits < kth_value, float("-inf"))
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, 1)
                    input_ids = torch.cat([input_ids, next_token], dim=-1)
        finally:
            self.train(was_training)
        return input_ids

    @property
    def num_parameters(self):
        """Total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())
|
|