"""Waveformer model for HuggingFace transformers."""

import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .configuration_waveformer import WaveformerConfig
except ImportError:
    from configuration_waveformer import WaveformerConfig

# Golden ratio; per-index phases/frequencies below are (k * _PHI) mod 1, scaled to [0, 2*pi).
_PHI = 1.618033988749895


class _OscillatorAttention(nn.Module):
    """Sequence-level mixer built from a bank of coupled phase oscillators.

    The (optionally padding-masked) mean of the input over the sequence axis is
    projected to ``n_osc`` initial phases. The phases are then advanced for
    ``n_steps`` forward-Euler updates driven by fixed natural frequencies
    ``omega`` and a fixed coupling matrix ``exp(-|i - j| / coupling_length)``
    (zero on the diagonal). ``cos`` of the final phases is projected to ``d_out``.

    Args:
        d_model: Size of the last axis of the input.
        d_out: Size of the last axis of the output.
        n_osc: Number of oscillators.
        n_steps: Number of phase-update iterations (default 3).
        dt: Step size of each update (default 0.1).
        coupling_length: Decay length of the coupling kernel (default 100.0).
    """

    def __init__(self, d_model, d_out, n_osc, n_steps=3, dt=0.1, coupling_length=100.0):
        super().__init__()
        self.n_osc = n_osc
        self.n_steps = n_steps
        self.dt = dt
        self.perturb = nn.Linear(d_model, n_osc, bias=False)
        self.readout = nn.Linear(n_osc, d_out, bias=False)
        omega = (torch.arange(n_osc).float() * _PHI).fmod(1.0) * 2 * math.pi
        self.register_buffer('omega', omega)
        idx = torch.arange(n_osc)
        dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs().float()
        K = torch.zeros(n_osc, n_osc)
        K[dist > 0] = torch.exp(-dist[dist > 0] / coupling_length)
        self.register_buffer('coupling', K)

    def forward(self, x, attention_mask=None):
        """Return one ``(B, d_out)`` vector per sequence.

        Args:
            x: Input of shape ``(B, S, d_model)``.
            attention_mask: Optional ``(B, S)`` tensor; nonzero entries mark real
                tokens. Padding positions are excluded from the sequence mean.
                ``None`` averages over all ``S`` positions.
        """
        if attention_mask is None:
            pooled = x.mean(1)
        else:
            # accumulate in float32 so token counts stay exact under bf16/fp16
            m = attention_mask.unsqueeze(-1).to(torch.float32)
            # clamp avoids 0/0 for rows that are entirely padding
            pooled = ((x * m).sum(1) / m.sum(1).clamp(min=1.0)).to(x.dtype)
        # NOTE(review): pooling over the whole sequence lets every position see
        # future tokens, so this mixer is not causal; confirm that is intended.
        theta = self.perturb(pooled)
        for _ in range(self.n_steps):
            # sd[..., i, j] = sin(theta_i - theta_j). NOTE(review): the classic
            # Kuramoto term is sin(theta_j - theta_i); confirm the sign is intended.
            sd = torch.sin(theta.unsqueeze(-1) - theta.unsqueeze(-2))
            theta = theta + self.dt * (self.omega + (self.coupling * sd).sum(-1))
        return self.readout(torch.cos(theta))


class WaveformerPreTrainedModel(PreTrainedModel):
    """Base class hooking Waveformer into the HuggingFace ``PreTrainedModel`` machinery."""

    config_class = WaveformerConfig
    base_model_prefix = "waveformer"

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # explicit, instead of relying on the tied lm_head Linear being visited later
            module.weight.data.normal_(mean=0.0, std=0.02)


class WaveformerLayer(nn.Module):
    """Parallel-residual block: ``x + osc(norm1(x)) + ffn(norm2(x))``.

    The oscillator branch yields one vector per sequence, which is broadcast to
    every position. The FFN is a SiLU MLP with hidden width ``8 * d_model // 3``.
    """

    def __init__(self, config):
        super().__init__()
        D = config.d_model
        hidden = D * 8 // 3
        self.osc = _OscillatorAttention(D, D, D * 2)
        self.norm1 = nn.RMSNorm(D, eps=1e-5)
        self.norm2 = nn.RMSNorm(D, eps=1e-5)
        self.ffn = nn.Sequential(
            nn.Linear(D, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, D, bias=False),
        )

    def forward(self, x, attention_mask=None):
        """Apply the block to ``x`` of shape ``(B, S, d_model)``; ``attention_mask`` is optional ``(B, S)``."""
        mixed = self.osc(self.norm1(x), attention_mask)
        mixed = mixed.unsqueeze(1).expand(-1, x.shape[1], -1)
        return x + mixed + self.ffn(self.norm2(x))


class WaveformerModel(WaveformerPreTrainedModel, GenerationMixin):
    """Waveformer causal language model with an embedding-tied LM head.

    Token embeddings are modulated position-wise (``x * cos + roll(x) * sin``
    with golden-ratio phase tables), passed through ``config.n_layers``
    ``WaveformerLayer`` blocks and a final RMSNorm, then projected to the
    vocabulary by ``lm_head``. ``lm_head`` shares its weight with ``embed``.
    """

    # lm_head.weight is the embedding weight; declaring it lets save_pretrained /
    # from_pretrained treat the shared tensor as an intentional tie.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        kam_sin, kam_cos = self._kam_tables(torch.arange(config.max_seq_len), config.d_model)
        self.register_buffer('kam_sin', kam_sin)
        self.register_buffer('kam_cos', kam_cos)
        self.layers = nn.ModuleList(
            WaveformerLayer(config) for _ in range(config.n_layers)
        )
        self.norm_f = nn.RMSNorm(config.d_model, eps=1e-5)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight
        self.post_init()

    @staticmethod
    def _kam_tables(positions, d_model):
        """Return ``(sin, cos)`` tables of shape ``positions.shape + (d_model,)``.

        Angle for position ``p`` and feature ``d`` is
        ``frac(p * _PHI) * 2 * pi + 0.01 * d``.
        """
        omega = (positions.float() * _PHI).fmod(1.0) * 2 * math.pi
        sub = torch.arange(d_model, device=positions.device).float() * 0.01
        theta = omega.unsqueeze(-1) + sub
        return torch.sin(theta), torch.cos(theta)

    def _positional_tables(self, position_ids, seq_len):
        """Return ``(sin, cos)`` broadcastable against ``(B, seq_len, d_model)`` embeddings.

        Positions past the precomputed buffers are computed on the fly with the
        same formula instead of failing on a shape mismatch.
        """
        table_len, d_model = self.kam_sin.shape
        if position_ids is None:
            if seq_len <= table_len:
                return self.kam_sin[:seq_len].unsqueeze(0), self.kam_cos[:seq_len].unsqueeze(0)
            position_ids = torch.arange(seq_len, device=self.kam_sin.device).unsqueeze(0)
        if position_ids.numel() == 0 or int(position_ids.max()) < table_len:
            return self.kam_sin[position_ids], self.kam_cos[position_ids]
        sin, cos = self._kam_tables(position_ids, d_model)
        # match buffer dtype so a half-precision model is not promoted to float32
        return sin.to(self.kam_sin.dtype), cos.to(self.kam_cos.dtype)

    def forward(self, input_ids, attention_mask=None, labels=None, position_ids=None, **kwargs):
        """Run the model on ``input_ids`` of shape ``(B, S)``.

        Args:
            input_ids: Token ids, ``(B, S)``.
            attention_mask: Optional ``(B, S)``; nonzero marks real tokens.
                Padding is excluded from the oscillator pooling.
            labels: Optional ``(B, S)`` target ids. When given, the loss is the
                mean next-token cross-entropy (logits at ``t`` predict
                ``labels[t + 1]``), with ``-100`` entries ignored.
            position_ids: Optional ``(B, S)`` or ``(1, S)`` positions; defaults
                to ``0..S-1``.
            **kwargs: Accepted and ignored (e.g. extra arguments from ``generate``).

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape ``(B, S, vocab)``
            and ``loss`` (``None`` unless ``labels`` is given).
        """
        _, seq_len = input_ids.shape
        x = self.embed(input_ids)
        sin, cos = self._positional_tables(position_ids, seq_len)
        # roll(1, -1) pairs feature d with feature d-1 (cyclically)
        x = x * cos + x.roll(1, -1) * sin
        for layer in self.layers:
            x = layer(x, attention_mask)
        logits = self.lm_head(self.norm_f(x))

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)).float(),
                labels[:, 1:].reshape(-1).to(logits.device),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Build ``forward`` kwargs for one generation step.

        No cache is returned by ``forward``, so the full ``input_ids`` prefix is
        re-encoded each step. A supplied ``attention_mask`` is forwarded, and
        position ids are derived from it so left padding does not shift positions.
        """
        inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(attention_mask == 0, 0)
            inputs["attention_mask"] = attention_mask
            inputs["position_ids"] = position_ids
        return inputs

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # Exposed so tie_weights / resize_token_embeddings also handle lm_head.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings