# Author: sukritvemula
# Commit 81271b1 (verified): HYDRA v1: 19M param non-transformer with selective gated recurrence
"""
HYDRA: Hybrid Dynamic Recurrent Architecture
Novel non-transformer combining Mamba SSM, Griffin RG-LRU, RWKV mixing.
"""
import math, torch, torch.nn as nn, torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional
@dataclass
class HydraConfig:
    """Hyper-parameters for :class:`HydraModel`.

    ``d_inner`` may be left as ``None``, in which case ``__post_init__`` sets it
    to ``2 * d_model``.
    """

    vocab_size: int = 50257
    d_model: int = 512
    n_layers: int = 8
    d_state: int = 16  # per-scale recurrent state width
    d_conv: int = 4  # kernel size of the depthwise causal conv
    d_inner: Optional[int] = None  # expanded width; None -> 2 * d_model
    mlp_expand: int = 3  # channel-mixing hidden width = mlp_expand * d_model
    n_scales: int = 3  # number of parallel recurrence scales
    dropout: float = 0.1
    max_seq_len: int = 1024  # context window kept by HydraModel.generate
    pad_token_id: int = 50256  # target id ignored by the LM loss
    tie_weights: bool = True  # share the embedding and lm_head weight

    def __post_init__(self):
        if self.d_inner is None:
            self.d_inner = 2 * self.d_model
class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension: ``x / sqrt(mean(x**2) + eps) * weight``.

    The statistic is computed in at least float32 and the normalised tensor is
    cast back to the input dtype before the learnable scale is applied. Half
    precision inputs therefore neither overflow in ``x**2`` nor lose ``eps`` to
    rounding (1e-8 rounds to 0 in float16, which made all-zero rows NaN).
    float32/float64 inputs compute exactly as before.

    Args:
        d_model: size of the last (normalised) dimension.
        eps: added to the mean square before the square root.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # float16/bfloat16 -> float32; float32/float64 stay as they are.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xc = x.to(compute_dtype)
        rms = torch.sqrt(torch.mean(xc ** 2, dim=-1, keepdim=True) + self.eps)
        return (xc / rms).to(x.dtype) * self.weight
class SelectiveGatedRecurrence(nn.Module):
    """Multi-scale gated linear recurrence used as the temporal (token) mixer.

    Per forward pass:
      1. ``in_proj`` splits the input into a recurrence branch and a gate branch.
      2. The recurrence branch goes through a causal depthwise ``conv1d`` + SiLU.
      3. For each of ``n_scales`` scales, input-dependent projections give a
         per-step decay ``a_t = sigmoid(Lambda) ** (gate_c * r_t)``, an input
         ``B_t`` (scaled by ``sqrt(1 - a_t**2) * i_t``) and a readout ``C_t``
         over a ``d_state``-wide element-wise state
         ``h_t = a_t * h_{t-1} + sqrt(1 - a_t**2) * i_t * B_t``.
      4. The readouts ``C_t * h_t`` of all scales are concatenated, fused back to
         ``d_inner`` by ``scale_fusion``, gated by GELU(gate branch) and
         projected to ``d_model``.

    Args:
        config: object exposing ``d_model``, ``d_inner``, ``d_state``,
            ``d_conv``, ``n_scales`` and ``dropout`` (e.g. ``HydraConfig``).
    """

    def __init__(self, config):
        super().__init__()
        self.d_inner = config.d_inner
        self.d_state = config.d_state
        self.n_scales = config.n_scales
        # Exponent scale c in a_t = a_base ** (c * r_t).
        self.gate_c = 8.0
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=False)
        # Depthwise conv padded by (d_conv - 1) on both sides; forward keeps the
        # first seq_len outputs to make it causal.
        self.conv1d = nn.Conv1d(config.d_inner, config.d_inner, config.d_conv,
                                padding=config.d_conv - 1, groups=config.d_inner)
        self.B_proj = nn.ModuleList([nn.Linear(config.d_inner, config.d_state, bias=False) for _ in range(config.n_scales)])
        self.C_proj = nn.ModuleList([nn.Linear(config.d_inner, config.d_state, bias=False) for _ in range(config.n_scales)])
        self.recurrence_gate = nn.ModuleList([nn.Linear(config.d_inner, config.d_state, bias=True) for _ in range(config.n_scales)])
        self.input_gate = nn.ModuleList([nn.Linear(config.d_inner, config.d_state, bias=True) for _ in range(config.n_scales)])
        # Lambda[s] holds the logit of the base decay a_base = sigmoid(Lambda[s]).
        # Scale s starts with a_base uniform in logit space between
        # 0.9 ** (1/(s+1)) and 0.999 ** (1/(s+1)): higher scales start closer to 1.
        self.Lambda = nn.ParameterList()
        for s in range(config.n_scales):
            low_a = 0.9 ** (1.0 / (s + 1))
            high_a = 0.999 ** (1.0 / (s + 1))
            init_val = torch.empty(config.d_state).uniform_(
                math.log(low_a / (1 - low_a + 1e-8)), math.log(high_a / (1 - high_a + 1e-8)))
            self.Lambda.append(nn.Parameter(init_val))
        self.scale_fusion = nn.Linear(config.n_scales * config.d_state, config.d_inner, bias=False)
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Mix information along the sequence dimension.

        Args:
            x: tensor indexed as ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, seq_len, _ = x.shape
        x_branch, z_branch = self.in_proj(x).chunk(2, dim=-1)
        # Keeping only the first seq_len conv outputs means position t sees inputs <= t.
        x_conv = self.conv1d(x_branch.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x_conv = F.silu(x_conv)
        z_gate = F.gelu(z_branch)

        decays, inputs, readouts = [], [], []
        for s in range(self.n_scales):
            B_t = self.B_proj[s](x_conv)
            C_t = self.C_proj[s](x_conv)
            r_t = torch.sigmoid(self.recurrence_gate[s](x_conv))
            i_t = torch.sigmoid(self.input_gate[s](x_conv))
            log_a_base = torch.log(torch.sigmoid(self.Lambda[s]).clamp(min=1e-8))
            # a_t = a_base ** (gate_c * r_t), evaluated in log space.
            a_t = torch.exp(self.gate_c * r_t * log_a_base)
            # Input scaled by sqrt(1 - a_t^2) as in Griffin's RG-LRU (clamped so
            # the sqrt stays finite/differentiable when a_t -> 1).
            sqrt_term = torch.sqrt((1 - a_t ** 2).clamp(min=1e-8))
            decays.append(a_t)
            inputs.append(sqrt_term * i_t * B_t)
            readouts.append(C_t)

        # Every scale is an independent element-wise recurrence, so all scales run
        # in a single time loop over the concatenated state
        # (layout [scale 0 | scale 1 | ...], each d_state wide).
        a_all = torch.cat(decays, dim=-1)
        u_all = torch.cat(inputs, dim=-1)
        c_all = torch.cat(readouts, dim=-1)
        h = torch.zeros(batch, a_all.shape[-1], device=x.device, dtype=x.dtype)
        states = []
        for t in range(seq_len):
            h = a_all[:, t] * h + u_all[:, t]
            states.append(h)
        h_all = torch.stack(states, dim=1)

        # Selective readout: C_t weights each state element before fusion.
        # NOTE: earlier revisions computed C_t but discarded it, leaving C_proj
        # untrained (and breaking DDP without find_unused_parameters); models
        # trained that way will produce different outputs under this fix.
        fused = self.scale_fusion(c_all * h_all)
        return self.dropout(self.out_proj(fused * z_gate))
class GatedChannelMixing(nn.Module):
    """GELU-gated feed-forward channel mixer.

    Computes ``dropout(down_proj(gelu(gate_proj(x)) * up_proj(x)))`` with a
    hidden width of ``mlp_expand * d_model``; input and output widths are
    ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.mlp_expand * config.d_model
        width = config.d_model
        self.gate_proj = nn.Linear(width, hidden, bias=False)
        self.up_proj = nn.Linear(width, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = F.gelu(self.gate_proj(x))
        value = self.up_proj(x)
        mixed = self.down_proj(gate * value)
        return self.dropout(mixed)
class HydraBlock(nn.Module):
    """Pre-norm residual block: temporal mixing, then channel mixing.

    Computes ``h = x + temporal_mix(norm1(x))`` and returns
    ``h + channel_mix(norm2(h))``, so both sub-layers add into the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.norm1 = RMSNorm(config.d_model)
        self.temporal_mix = SelectiveGatedRecurrence(config)
        self.norm2 = RMSNorm(config.d_model)
        self.channel_mix = GatedChannelMixing(config)

    def forward(self, x):
        # The previous one-liner, x + channel_mix(norm2(x + temporal_mix(...))),
        # dropped the temporal update from the block output; keep it in h.
        h = x + self.temporal_mix(self.norm1(x))
        return h + self.channel_mix(self.norm2(h))
class HydraModel(nn.Module):
    """Causal LM: embedding -> ``n_layers`` HydraBlocks -> RMSNorm -> lm_head.

    Args:
        config: a ``HydraConfig`` (or an object with the same attributes).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([HydraBlock(config) for _ in range(config.n_layers)])
        self.final_norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_weights:
            # lm_head and embed now share a single Parameter.
            self.lm_head.weight = self.embed.weight
        self._init_weights()

    def _init_weights(self):
        """Initialise Linear, Embedding and Conv1d weights and zero their biases.

        Linear weights get std ``1/sqrt(fan_in)``; Embedding and Conv1d weights
        get std 0.02. A Linear whose weight *is* the embedding weight (tied
        lm_head) is skipped, so it no longer overwrites the embedding init.
        Parameters owned directly by modules (e.g. ``Lambda``, RMSNorm
        ``weight``) keep their own initialisation.
        """
        embed_weight = self.embed.weight
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if module.weight is not embed_weight:
                    nn.init.normal_(module.weight, 0, 1.0 / math.sqrt(module.weight.shape[1]))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, 0, 0.02)
            elif isinstance(module, nn.Conv1d):
                nn.init.normal_(module.weight, 0, 0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, input_ids, labels=None):
        """Compute logits and, optionally, the next-token loss.

        Args:
            input_ids: token ids indexed as ``(batch, seq_len)``.
            labels: optional ids aligned with ``input_ids``; the prediction at
                position t is scored against ``labels[:, t + 1]``.

        Returns:
            dict with ``"logits"`` of shape ``(batch, seq_len, vocab_size)`` and,
            when ``labels`` is given, ``"loss"``: mean cross-entropy over targets
            not equal to ``config.pad_token_id``.
        """
        x = self.embed_dropout(self.embed(input_ids))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.final_norm(x))
        result = {"logits": logits}
        if labels is not None:
            # Shift so position t predicts token t + 1.
            # NOTE(review): the default pad_token_id (50256) is the last id of the
            # default 50257-token vocab; if that id is also end-of-text, EOS
            # targets are excluded from the loss — confirm this is intended.
            shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            shift_labels = labels[:, 1:].reshape(-1)
            result["loss"] = F.cross_entropy(shift_logits, shift_labels,
                                             ignore_index=self.config.pad_token_id)
        return result

    def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50):
        """Append ``max_new_tokens`` sampled tokens to ``input_ids`` and return it.

        Each step re-runs the model on the last ``config.max_seq_len`` tokens.
        ``temperature <= 0`` selects greedy (argmax) decoding; ``top_k > 0``
        restricts sampling to the k most likely tokens (k is capped at the
        vocabulary size). The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    ctx = input_ids[:, -self.config.max_seq_len:]
                    logits = self.forward(ctx)["logits"][:, -1, :]
                    if temperature <= 0:
                        next_token = logits.argmax(dim=-1, keepdim=True)
                    else:
                        logits = logits / temperature
                        if top_k > 0:
                            k = min(top_k, logits.size(-1))
                            v, _ = torch.topk(logits, k)
                            logits[logits < v[:, -1:]] = float('-inf')
                        probs = F.softmax(logits, dim=-1)
                        next_token = torch.multinomial(probs, 1)
                    input_ids = torch.cat([input_ids, next_token], dim=-1)
        finally:
            self.train(was_training)
        return input_ids

    @property
    def num_parameters(self):
        """Total parameter count; a tied weight is yielded once by ``parameters()``."""
        return sum(p.numel() for p in self.parameters())