""" HYDRA: Hybrid Dynamic Recurrent Architecture Novel non-transformer combining Mamba SSM, Griffin RG-LRU, RWKV mixing. """ import math, torch, torch.nn as nn, torch.nn.functional as F from dataclasses import dataclass from typing import Optional @dataclass class HydraConfig: vocab_size: int = 50257 d_model: int = 512 n_layers: int = 8 d_state: int = 16 d_conv: int = 4 d_inner: int = None mlp_expand: int = 3 n_scales: int = 3 dropout: float = 0.1 max_seq_len: int = 1024 pad_token_id: int = 50256 tie_weights: bool = True def __post_init__(self): if self.d_inner is None: self.d_inner = 2 * self.d_model class RMSNorm(nn.Module): def __init__(self, d_model, eps=1e-8): super().__init__(); self.eps=eps; self.weight=nn.Parameter(torch.ones(d_model)) def forward(self, x): return x / torch.sqrt(torch.mean(x**2,dim=-1,keepdim=True)+self.eps) * self.weight class SelectiveGatedRecurrence(nn.Module): def __init__(self, config): super().__init__() self.d_inner=config.d_inner; self.d_state=config.d_state; self.n_scales=config.n_scales; self.gate_c=8.0 self.in_proj=nn.Linear(config.d_model, 2*config.d_inner, bias=False) self.conv1d=nn.Conv1d(config.d_inner,config.d_inner,config.d_conv,padding=config.d_conv-1,groups=config.d_inner) self.B_proj=nn.ModuleList([nn.Linear(config.d_inner,config.d_state,bias=False) for _ in range(config.n_scales)]) self.C_proj=nn.ModuleList([nn.Linear(config.d_inner,config.d_state,bias=False) for _ in range(config.n_scales)]) self.recurrence_gate=nn.ModuleList([nn.Linear(config.d_inner,config.d_state,bias=True) for _ in range(config.n_scales)]) self.input_gate=nn.ModuleList([nn.Linear(config.d_inner,config.d_state,bias=True) for _ in range(config.n_scales)]) self.Lambda=nn.ParameterList() for s in range(config.n_scales): low_a=0.9**(1.0/(s+1)); high_a=0.999**(1.0/(s+1)) init_val=torch.empty(config.d_state).uniform_(math.log(low_a/(1-low_a+1e-8)),math.log(high_a/(1-high_a+1e-8))) self.Lambda.append(nn.Parameter(init_val)) self.scale_fusion=nn.Linear(config.n_scales*config.d_state,config.d_inner,bias=False) self.out_proj=nn.Linear(config.d_inner,config.d_model,bias=False) self.dropout=nn.Dropout(config.dropout) def forward(self, x): batch,seq_len,_=x.shape xz=self.in_proj(x); x_branch,z_branch=xz.chunk(2,dim=-1) x_conv=self.conv1d(x_branch.transpose(1,2))[:,:,:seq_len].transpose(1,2) x_conv=F.silu(x_conv); z_gate=F.gelu(z_branch) scale_outputs=[] for s in range(self.n_scales): B_t=self.B_proj[s](x_conv); C_t=self.C_proj[s](x_conv) r_t=torch.sigmoid(self.recurrence_gate[s](x_conv)) i_t=torch.sigmoid(self.input_gate[s](x_conv)) a_base=torch.sigmoid(self.Lambda[s]) log_a_base=torch.log(a_base.clamp(min=1e-8)) log_a_t=self.gate_c*r_t*log_a_base.unsqueeze(0).unsqueeze(0) a_t=torch.exp(log_a_t) sqrt_term=torch.sqrt((1-a_t**2).clamp(min=1e-8)) input_contrib=sqrt_term*i_t*B_t h_prev=torch.zeros(batch,self.d_state,device=x.device,dtype=x.dtype) h_states=[] for t in range(seq_len): h_prev=a_t[:,t]*h_prev+input_contrib[:,t] h_states.append(h_prev.unsqueeze(1)) h=torch.cat(h_states,dim=1) scale_outputs.append(h) multi_scale=torch.cat(scale_outputs,dim=-1) fused=self.scale_fusion(multi_scale) output=self.out_proj(fused*z_gate) return self.dropout(output) class GatedChannelMixing(nn.Module): def __init__(self, config): super().__init__(); d_ff=config.mlp_expand*config.d_model self.gate_proj=nn.Linear(config.d_model,d_ff,bias=False) self.up_proj=nn.Linear(config.d_model,d_ff,bias=False) self.down_proj=nn.Linear(d_ff,config.d_model,bias=False) self.dropout=nn.Dropout(config.dropout) def forward(self, x): return self.dropout(self.down_proj(F.gelu(self.gate_proj(x))*self.up_proj(x))) class HydraBlock(nn.Module): def __init__(self, config): super().__init__() self.norm1=RMSNorm(config.d_model); self.temporal_mix=SelectiveGatedRecurrence(config) self.norm2=RMSNorm(config.d_model); self.channel_mix=GatedChannelMixing(config) def forward(self, x): return x+self.channel_mix(self.norm2(x+self.temporal_mix(self.norm1(x)))) class HydraModel(nn.Module): def __init__(self, config): super().__init__(); self.config=config self.embed=nn.Embedding(config.vocab_size,config.d_model) self.embed_dropout=nn.Dropout(config.dropout) self.blocks=nn.ModuleList([HydraBlock(config) for _ in range(config.n_layers)]) self.final_norm=RMSNorm(config.d_model) self.lm_head=nn.Linear(config.d_model,config.vocab_size,bias=False) if config.tie_weights: self.lm_head.weight=self.embed.weight self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight,0,1.0/math.sqrt(m.weight.shape[1])); hasattr(m,'bias') and m.bias is not None and nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight,0,0.02) elif isinstance(m, nn.Conv1d): nn.init.normal_(m.weight,0,0.02); hasattr(m,'bias') and m.bias is not None and nn.init.zeros_(m.bias) def forward(self, input_ids, labels=None): x=self.embed_dropout(self.embed(input_ids)) for b in self.blocks: x=b(x) x=self.final_norm(x); logits=self.lm_head(x) result={"logits":logits} if labels is not None: loss=F.cross_entropy(logits[:,:-1,:].contiguous().view(-1,self.config.vocab_size),labels[:,1:].contiguous().view(-1),ignore_index=self.config.pad_token_id) result["loss"]=loss return result def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50): self.eval() with torch.no_grad(): for _ in range(max_new_tokens): ctx=input_ids[:,-self.config.max_seq_len:] logits=self.forward(ctx)["logits"][:,-1,:]/temperature if top_k>0: v,_=torch.topk(logits,top_k); logits[logits