| """Waveformer model for HuggingFace transformers.""" |
| import torch, torch.nn as nn, math |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.generation import GenerationMixin |
| try: |
| from .configuration_waveformer import WaveformerConfig |
| except ImportError: |
| from configuration_waveformer import WaveformerConfig |
|
|
|
|
| class _OscillatorAttention(nn.Module): |
| def __init__(self, d_model, d_out, n_osc): |
| super().__init__() |
| self.n_osc = n_osc |
| self.perturb = nn.Linear(d_model, n_osc, bias=False) |
| self.readout = nn.Linear(n_osc, d_out, bias=False) |
| omega = (torch.arange(n_osc).float() * 1.618033988749895).fmod(1.0) * 2 * math.pi |
| self.register_buffer('omega', omega) |
| idx = torch.arange(n_osc) |
| dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs().float() |
| K = torch.zeros(n_osc, n_osc) |
| K[dist > 0] = torch.exp(-dist[dist > 0] / 100.0) |
| self.register_buffer('coupling', K) |
|
|
| def forward(self, x): |
| B, S, D = x.shape |
| theta = self.perturb(x.mean(1)) |
| for _ in range(3): |
| sd = torch.sin(theta.unsqueeze(-1) - theta.unsqueeze(-2)) |
| theta = theta + 0.1 * (self.omega + (self.coupling * sd).sum(-1)) |
| return self.readout(torch.cos(theta)) |
|
|
|
|
class WaveformerPreTrainedModel(PreTrainedModel):
    """Shared HuggingFace base class for Waveformer models.

    Declares the config class and base-model prefix, and provides the
    per-module weight initializer that HF's initialization (triggered by
    ``post_init``) applies to every submodule.
    """

    config_class = WaveformerConfig
    base_model_prefix = "waveformer"

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        Linear and Embedding weights are drawn from N(0, std), where ``std`` is
        ``config.initializer_range`` if the config defines it and 0.02 otherwise.
        Linear biases are zeroed, as is the embedding row at ``padding_idx``
        when one is set. Other module types are left untouched.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
|
|
|
|
class WaveformerLayer(nn.Module):
    """One residual block: an oscillator mixing branch and a SiLU MLP branch.

    Both branches read RMS-normalized copies of the same input, and their
    outputs are added to that input in parallel.
    """

    def __init__(self, config):
        super().__init__()
        d = config.d_model
        self.osc = _OscillatorAttention(d, d, d * 2)
        self.norm1 = nn.RMSNorm(d, eps=1e-5)
        self.norm2 = nn.RMSNorm(d, eps=1e-5)
        hidden = d * 8 // 3
        self.ffn = nn.Sequential(
            nn.Linear(d, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, d, bias=False),
        )
        # When True (default), position t only sees positions <= t. Set
        # ``config.causal_oscillator = False`` to get the previous whole-sequence
        # pooling, e.g. to reproduce outputs of checkpoints trained with it.
        self.causal_oscillator = getattr(config, "causal_oscillator", True)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, S, d_model); returns the same shape.

        ``self.osc`` averages its input over dim 1 and returns one vector per
        row. Feeding it the whole sequence would give every position the same
        summary, including tokens after it, which leaks future tokens into
        next-token predictions. In causal mode, each position instead gets the
        oscillator output for the mean of its own prefix ``x[:, :t + 1]``.
        ``self.osc`` is then evaluated on B * S rows instead of B.
        """
        b, s, d = x.shape
        h = self.norm1(x)
        if self.causal_oscillator:
            # Running prefix mean, accumulated in float32 so long sequences stay
            # accurate in half/bfloat16.
            counts = torch.arange(1, s + 1, device=x.device, dtype=torch.float32).view(1, s, 1)
            prefix_mean = (h.float().cumsum(1) / counts).to(h.dtype)
            # Each prefix mean becomes a length-1 "sequence", so the pooling
            # inside self.osc leaves it unchanged.
            mixed = self.osc(prefix_mean.reshape(b * s, 1, d)).view(b, s, -1)
        else:
            mixed = self.osc(h).unsqueeze(1).expand(-1, s, -1)
        return x + mixed + self.ffn(self.norm2(x))
|
|
|
|
class WaveformerModel(WaveformerPreTrainedModel, GenerationMixin):
    """Waveformer language model returning ``CausalLMOutputWithPast``.

    The pipeline is:
    - token embeddings;
    - a fixed (non-learned) per-position phase modulation;
    - ``config.n_layers`` ``WaveformerLayer`` blocks and a final RMSNorm;
    - ``lm_head``, whose weight is the embedding matrix itself.
    """

    # lm_head.weight is the same Parameter as embed.weight (assigned in __init__).
    # Declaring it lets HF's safetensors save_pretrained drop the duplicate instead
    # of rejecting the shared tensor (list form, transformers 4.x convention).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)

        # Precomputed phase rows for positions [0, max_seq_len); longer inputs
        # are extended on the fly in _phase_tables.
        kam_sin, kam_cos = self._kam_phases(0, config.max_seq_len, config.d_model)
        self.register_buffer('kam_sin', kam_sin)
        self.register_buffer('kam_cos', kam_cos)

        self.layers = nn.ModuleList([
            WaveformerLayer(config) for _ in range(config.n_layers)
        ])
        self.norm_f = nn.RMSNorm(config.d_model, eps=1e-5)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight
        self.post_init()

    @staticmethod
    def _kam_phases(start, stop, d_model):
        """Return float32 (sin, cos) tables of shape (stop - start, d_model).

        Row p holds the angle frac(p * golden_ratio) * 2*pi + 0.01 * feature_index.
        """
        pos = torch.arange(start, stop).float()
        omega = (pos * 1.618033988749895).fmod(1.0) * 2 * math.pi
        sub = torch.arange(d_model).float() * 0.01
        theta = omega.unsqueeze(1) + sub.unsqueeze(0)
        return torch.sin(theta), torch.cos(theta)

    def _phase_tables(self, seq_len):
        """Return (sin, cos) rows for positions [0, seq_len), extending the cached buffers if needed."""
        cached = self.kam_sin.shape[0]
        if seq_len <= cached:
            return self.kam_sin[:seq_len], self.kam_cos[:seq_len]
        extra_sin, extra_cos = self._kam_phases(cached, seq_len, self.kam_sin.shape[1])
        # Match the buffers' current device and dtype (e.g. after model.to(...)).
        extra_sin = extra_sin.to(self.kam_sin)
        extra_cos = extra_cos.to(self.kam_cos)
        return torch.cat([self.kam_sin, extra_sin]), torch.cat([self.kam_cos, extra_cos])

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits (and optionally a loss) for ``input_ids``.

        Args:
            input_ids: 2-D tensor of token ids, (B, S).
            attention_mask: accepted for HF API compatibility but NOT used;
                padded positions are processed like real tokens.
            labels: optional (B, S) token ids. When given, ``loss`` is the mean
                cross-entropy of ``logits[:, :-1]`` against ``labels[:, 1:]``,
                skipping positions labelled -100.
            **kwargs: ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape (B, S, vocab_size)
            and ``loss`` (None when ``labels`` is None).
        """
        _, seq_len = input_ids.shape
        sin, cos = self._phase_tables(seq_len)
        x = self.embed(input_ids)
        # Each feature is blended with its neighbour along the feature axis
        # (roll by 1) using position-dependent cos/sin weights.
        x = x * cos.unsqueeze(0) + x.roll(1, -1) * sin.unsqueeze(0)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm_f(x))

        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:].to(logits.device)
            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the full ``input_ids`` at each generation step (no cache is kept)."""
        return {"input_ids": input_ids}

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        # Exposed so HF tie_weights / resize_token_embeddings also handle lm_head.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
|
|