Waveformer / modeling_waveformer.py
wannaq's picture
Upload 3 files
07f5733 verified
Raw
History Blame Contribute Delete
3.78 kB
"""Waveformer model for HuggingFace transformers."""
import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .configuration_waveformer import WaveformerConfig
except ImportError:
    from configuration_waveformer import WaveformerConfig
class _OscillatorAttention(nn.Module):
def __init__(self, d_model, d_out, n_osc):
super().__init__()
self.n_osc = n_osc
self.perturb = nn.Linear(d_model, n_osc, bias=False)
self.readout = nn.Linear(n_osc, d_out, bias=False)
omega = (torch.arange(n_osc).float() * 1.618033988749895).fmod(1.0) * 2 * math.pi
self.register_buffer('omega', omega)
idx = torch.arange(n_osc)
dist = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs().float()
K = torch.zeros(n_osc, n_osc)
K[dist > 0] = torch.exp(-dist[dist > 0] / 100.0)
self.register_buffer('coupling', K)
def forward(self, x):
B, S, D = x.shape
theta = self.perturb(x.mean(1))
for _ in range(3):
sd = torch.sin(theta.unsqueeze(-1) - theta.unsqueeze(-2))
theta = theta + 0.1 * (self.omega + (self.coupling * sd).sum(-1))
return self.readout(torch.cos(theta))
class WaveformerPreTrainedModel(PreTrainedModel):
    """Common base class binding Waveformer models to ``WaveformerConfig``."""

    config_class = WaveformerConfig
    base_model_prefix = "waveformer"

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        Only ``nn.Linear`` modules are touched: weights are drawn from a
        normal distribution (mean 0, std 0.02) and a bias, when present, is
        zeroed.  Every other module type keeps whatever values it already has.
        """
        if not isinstance(module, nn.Linear):
            return
        linear = module
        linear.weight.data.normal_(mean=0.0, std=0.02)
        bias = linear.bias
        if bias is not None:
            bias.data.zero_()
class WaveformerLayer(nn.Module):
    """Residual block combining a pooled oscillator branch with a SiLU feed-forward branch.

    Both branches read an RMS-normalised copy of the input and their outputs
    are added to the unnormalised input.
    """

    def __init__(self, config):
        """Build the sub-modules from ``config.d_model``."""
        super().__init__()
        width = config.d_model
        hidden = width * 8 // 3  # feed-forward inner width
        self.osc = _OscillatorAttention(width, width, width * 2)
        self.norm1 = nn.RMSNorm(width, eps=1e-5)
        self.norm2 = nn.RMSNorm(width, eps=1e-5)
        self.ffn = nn.Sequential(
            nn.Linear(width, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, width, bias=False),
        )

    def forward(self, x):
        """Return ``x`` plus the oscillator and feed-forward contributions.

        The oscillator branch yields one vector per batch element; it is
        repeated along axis 1 so that every position receives the same vector.
        """
        seq_len = x.shape[1]
        pooled = self.osc(self.norm1(x))
        per_position = pooled.unsqueeze(1).expand(-1, seq_len, -1)
        feed_forward = self.ffn(self.norm2(x))
        return x + per_position + feed_forward
class WaveformerModel(WaveformerPreTrainedModel, GenerationMixin):
    """Waveformer language model: embedding, positional mixing, layer stack, LM head.

    Token embeddings are combined with a fixed position-dependent sin/cos
    table, passed through ``config.n_layers`` ``WaveformerLayer`` blocks,
    RMS-normalised and projected to vocabulary logits.  The output projection
    shares its weight matrix with the input embedding.
    """

    # Multiplier whose fractional multiples give each position its base phase.
    _GOLDEN_RATIO = 1.618033988749895

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        cos_table, sin_table = self._kam_tables(0, config.max_seq_len, config.d_model)
        self.register_buffer('kam_sin', sin_table)
        self.register_buffer('kam_cos', cos_table)
        self.layers = nn.ModuleList([
            WaveformerLayer(config) for _ in range(config.n_layers)
        ])
        self.norm_f = nn.RMSNorm(config.d_model, eps=1e-5)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Output projection reuses the embedding matrix (tied weights).
        self.lm_head.weight = self.embed.weight
        self.post_init()

    @staticmethod
    def _kam_tables(start, stop, d_model, device=None):
        """Return ``(cos, sin)`` tables of shape ``(stop - start, d_model)`` for positions ``start..stop-1``."""
        pos = torch.arange(start, stop, device=device).float()
        omega = (pos * WaveformerModel._GOLDEN_RATIO).fmod(1.0) * 2 * math.pi
        sub = torch.arange(d_model, device=device).float() * 0.01
        theta = omega.unsqueeze(1) + sub.unsqueeze(0)
        return torch.cos(theta), torch.sin(theta)

    def _kam_rotation(self, seq_len):
        """Return ``(cos, sin)`` rows for the first ``seq_len`` positions.

        Rows inside the registered tables are taken from the buffers
        unchanged; rows beyond ``max_seq_len`` are computed on the fly with
        the same formula so longer inputs do not fail on a shape mismatch.
        """
        table_len, d_model = self.kam_sin.shape
        if seq_len <= table_len:
            return self.kam_cos[:seq_len], self.kam_sin[:seq_len]
        extra_cos, extra_sin = self._kam_tables(
            table_len, seq_len, d_model, device=self.kam_sin.device
        )
        cos = torch.cat([self.kam_cos, extra_cos.to(self.kam_cos.dtype)], dim=0)
        sin = torch.cat([self.kam_sin, extra_sin.to(self.kam_sin.dtype)], dim=0)
        return cos, sin

    @staticmethod
    def _causal_lm_loss(logits, labels, num_items_in_batch=None):
        """Cross-entropy between position ``t`` logits and the label at ``t + 1``.

        Labels equal to ``-100`` are ignored.  When ``num_items_in_batch`` is
        given the summed loss is divided by it; otherwise the mean is used.
        """
        shift_logits = logits[:, :-1, :].float()
        shift_labels = labels[:, 1:].to(logits.device)
        reduction = "mean" if num_items_in_batch is None else "sum"
        loss = nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction=reduction,
        )
        if num_items_in_batch is not None:
            if torch.is_tensor(num_items_in_batch):
                num_items_in_batch = num_items_in_batch.to(loss.device)
            loss = loss / num_items_in_batch
        return loss

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute vocabulary logits (and optionally a loss) for ``input_ids``.

        Args:
            input_ids: Integer tensor with two axes ``(batch, seq)``.
            attention_mask: Accepted for interface compatibility; not read.
            labels: Optional integer tensor shaped like ``input_ids``.  When
                given, a shifted next-token cross-entropy loss is returned.
            **kwargs: Extra keyword arguments; only ``num_items_in_batch`` is
                read (loss normalisation), the rest are ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` set when ``labels`` is
            provided.
        """
        _, seq_len = input_ids.shape
        cos, sin = self._kam_rotation(seq_len)
        x = self.embed(input_ids)
        # Mix each feature with its neighbour (roll by one along the feature
        # axis) using the position-dependent cos/sin weights.
        x = x * cos.unsqueeze(0) + x.roll(1, -1) * sin.unsqueeze(0)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm_f(x))
        loss = None
        if labels is not None:
            loss = self._causal_lm_loss(
                logits, labels, kwargs.get("num_items_in_batch")
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the forward inputs for one generation step (the full prefix, no cache)."""
        return {"input_ids": input_ids}

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value