import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
import torch.utils.checkpoint
import math
import os

from x_transformers import Decoder
from transformers import AutoTokenizer

# Load the tokenizer from the working directory if a local checkpoint exists;
# otherwise fall back to the pretrained Helsinki-NLP German-to-English tokenizer.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    tokenizer = None
    print(f"Warning: tokenizer load failed: {e}. "
          "Masking and generation below require a tokenizer with pad/eos token ids.")
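# ---------------------------------------------------------------------------
# Complex-valued building blocks
# ---------------------------------------------------------------------------

# ComplexDropout applies a single real-valued dropout mask to both the real
# and imaginary parts, so each complex unit is dropped (or kept) as a whole
# and the phase of surviving units is untouched.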
class ComplexDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        if not self.training or self.p == 0.0:
            return z
        mask = torch.ones_like(z.real)
        mask = F.dropout(mask, self.p, self.training, inplace=False)
        return z * mask

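# PhasePreservingLayerNorm: LayerNorm over the magnitudes only. The normalized
# magnitude is recombined with the unit-phase vector x/|x|, so the angular
# information carried by each complex feature survives normalization.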
class PhasePreservingLayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        self.eps = eps

    def forward(self, x):
        mag = torch.abs(x)
        mag_norm = self.layernorm(mag)
        return mag_norm.to(x.dtype) * (x / (mag + self.eps))

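# HarmonicEmbedding: token ids are embedded as complex vectors (real and
# imaginary halves learned in one table), then each position is rotated by a
# fixed frequency bank, analogous to sinusoidal/rotary position encodings but
# applied as a true complex rotation e^{i * pos * freq}.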
class HarmonicEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        # One real embedding table holds both halves: [real | imag].
        self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)
        freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32)
                          * -(math.log(max_period) / embedding_dim))
        self.register_buffer('freqs', freqs)

    def forward(self, input_ids):
        raw_embeds = self.complex_embedding(input_ids)
        real = raw_embeds[..., :self.embedding_dim]
        imag = raw_embeds[..., self.embedding_dim:]
        content_z = torch.complex(real, imag)
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).float()
        angles = torch.outer(positions, self.freqs)  # (seq_len, embedding_dim)
        pos_rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)
        return content_z * pos_rotation

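# ModReLU (cf. Arjovsky et al., 2016, unitary-evolution RNNs): thresholds the
# magnitude with a learnable per-feature bias and keeps the phase, i.e.
# modReLU(z) = ReLU(|z| + b) * z / |z|.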
class ModReLU(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        mag = torch.abs(z)
        new_mag = F.relu(mag + self.b)
        phase = z / (mag + 1e-6)
        return new_mag * phase

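# PRISMLayer: a spectral mixing block. The sequence is filtered by a learned
# complex global filter in the frequency domain (FFT -> elementwise filter ->
# inverse FFT, in the spirit of GFNet-style global filters), gated by sigmoids
# computed from the pre-filter activations, and then passed through a complex
# linear -> modReLU -> complex linear feed-forward path with a residual.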
class PRISMLayer(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len

        # Real-valued gate computed from [real | imag] of the normalized input.
        self.gate_proj = nn.Linear(d_model * 2, d_model * 2)

        # Learned complex filter: one response per feature per frequency bin.
        self.global_filter = nn.Parameter(
            torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)

        # Complex feed-forward path; each complex weight is split into two
        # real nn.Linear layers (see complex_linear below).
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        # (A + iB)(r + ii) = (Ar - Bi) + i(Ai + Br)
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x, src_mask=None):
        if x is None:
            return None
        residual = x
        x_norm = self.norm(x)

        # Zero out padded positions so they do not leak through the FFT.
        if src_mask is not None:
            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # Gates are computed from the pre-filter signal.
        x_cat = torch.cat([x_norm.real, x_norm.imag], dim=-1)
        gates = torch.sigmoid(self.gate_proj(x_cat))
        gate_r, gate_i = gates.chunk(2, dim=-1)

        # Global filtering in the frequency domain, zero-padded to filter_len.
        B, L, D = x_norm.shape
        x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
        x_filtered = x_freq * self.global_filter.transpose(-1, -2)
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :L, :]

        # Gate real and imaginary parts independently.
        gated_r = x_time.real * gate_r
        gated_i = x_time.imag * gate_i
        x_gated = torch.complex(gated_r, gated_i)

        # Complex feed-forward with modReLU nonlinearity.
        x_mixed = self.complex_linear(x_gated, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)
        return self.dropout(out) + residual

class PRISMEncoder(nn.Module):
    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)])
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.final_norm(x)

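# ComplexToRealBridge folds the complex encoder output back into the real
# domain for the standard attention stack: concatenate [real | imag], project
# 2*d_model -> d_model, then LayerNorm.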
class ComplexToRealBridge(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x_complex):
        if x_complex is None:
            raise ValueError("ComplexToRealBridge received None as input")
        cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)
        return self.norm(self.proj(cat))

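# PRISMHybrid_RoPE: the full encoder-decoder. Source tokens pass through the
# complex harmonic embedding and the PRISM spectral encoder, are bridged back
# to real values, optionally refined by a standard Transformer encoder, and
# finally cross-attended by an x_transformers rotary-embedding decoder with a
# weight-tied output projection.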
class PRISMHybrid_RoPE(nn.Module):
    def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
                 num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)

        if num_encoder_layers > 0:
            self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        else:
            self.prism_encoder = None

        self.bridge = ComplexToRealBridge(d_model)

        if num_refining_layers > 0:
            refining_layer = nn.TransformerEncoderLayer(
                d_model, num_heads, dff, dropout,
                batch_first=True, norm_first=True
            )
            self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
        else:
            self.reasoning_encoder = None

        self.decoder = Decoder(
            dim=d_model, depth=num_decoder_layers, heads=num_heads,
            attn_dim_head=d_model // num_heads, ff_mult=dff / d_model,
            rotary_pos_emb=True, cross_attend=True, attn_flash=True,
            attn_dropout=dropout, ff_dropout=dropout, use_rmsnorm=True
        )
        self.final_linear = nn.Linear(d_model, vocab_size)
        # Tie the output projection to the target embedding weights.
        self.final_linear.weight = self.tgt_embedding.weight

    def create_masks(self, src, tgt):
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # Build the boolean causal mask directly: generate_square_subsequent_mask
        # fills with -inf and cannot be materialized with dtype=torch.bool. The
        # x_transformers decoder applies its own causal mask, so tgt_mask is
        # returned only for API compatibility.
        tgt_mask = torch.triu(
            torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=src.device),
            diagonal=1)
        # The memory padding mask is identical to the source padding mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        if self.prism_encoder is not None:
            if self.training:
                # Gradient checkpointing trades compute for activation memory.
                src_harmonic.requires_grad_(True)
                encoded_complex = torch.utils.checkpoint.checkpoint(
                    self.prism_encoder.forward,
                    src_harmonic, src_mask, use_reentrant=False
                )
            else:
                encoded_complex = self.prism_encoder(src_harmonic, src_mask)
        else:
            encoded_complex = src_harmonic

        coarse_memory = self.bridge(encoded_complex)
        if self.reasoning_encoder is not None:
            refined_memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)
        else:
            refined_memory = coarse_memory

        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(tgt_emb)
        # x_transformers masks mean "True = attend", the inverse of padding masks.
        context_mask = ~mem_pad if mem_pad is not None else None
        decoder_mask = ~tgt_pad if tgt_pad is not None else None

        if self.training:
            tgt_emb.requires_grad_(True)
            output = torch.utils.checkpoint.checkpoint(
                self.decoder, tgt_emb, context=refined_memory,
                mask=decoder_mask, context_mask=context_mask, use_reentrant=False
            )
        else:
            output = self.decoder(tgt_emb, context=refined_memory,
                                  mask=decoder_mask, context_mask=context_mask)

        return self.final_linear(output)

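    # Beam-search decoding. Beams are laid out as a flat (batch * num_beams)
    # dimension; at each step the candidate scores are re-flattened per batch
    # element and the top num_beams (beam, token) pairs are kept.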
    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        was_training = self.training
        self.eval()
        src_mask = (src == tokenizer.pad_token_id)
        context_mask = ~src_mask
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        if self.prism_encoder is not None:
            encoded_complex = self.prism_encoder(src_harmonic, src_mask)
        else:
            encoded_complex = src_harmonic

        coarse_memory = self.bridge(encoded_complex)

        if self.reasoning_encoder is not None:
            memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=src_mask)
        else:
            memory = coarse_memory

        # Expand the encoder memory so each beam gets its own copy.
        batch_size = src.shape[0]
        memory = memory.repeat_interleave(num_beams, dim=0)
        context_mask = context_mask.repeat_interleave(num_beams, dim=0)

        # Every beam starts with the pad token, which Marian-style models use
        # as the decoder start token.
        beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id,
                           dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)

        for step in range(max_length - 1):
            if finished_beams.all():
                break
            tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
            tgt_emb = self.dropout(tgt_emb)  # inactive in eval mode

            decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Never emit pad; finished beams extend with EOS at zero cost so
            # their scores stay frozen.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            if step == 0:
                # All beams are identical at the first step: keep only beam 0
                # so topk selects num_beams distinct first tokens.
                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)
                total[:, 1:, :] = -torch.inf
                total = total.view(batch_size, -1)
            else:
                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)

            top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)

            # Decompose flat indices into (source beam, next token).
            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            effective = (torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams
                         + beam_indices).view(-1)
            beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)

        # topk returns scores in descending order, so beam 0 is the best.
        final_beams = beams.view(batch_size, num_beams, -1)
        best_beams = final_beams[:, 0, :]
        self.train(was_training)
        return best_beams
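
# A minimal smoke test: a sketch only, assuming the tokenizer above loaded
# successfully. All hyperparameters here are illustrative placeholders, not
# the settings used in training.
if __name__ == "__main__":
    if tokenizer is None:
        raise RuntimeError("Tokenizer unavailable; cannot run the smoke test.")
    model = PRISMHybrid_RoPE(
        num_encoder_layers=2, num_refining_layers=1, num_decoder_layers=2,
        num_heads=4, d_model=128, dff=512,
        vocab_size=tokenizer.vocab_size, max_length=64, dropout=0.1,
    )
    model.eval()  # skip the checkpointing paths for the shape check
    batch = tokenizer(["Ein kleiner Test.", "Noch ein Satz."],
                      return_tensors="pt", padding=True)
    src = batch["input_ids"]
    tgt = src.clone()  # placeholder target, only for the forward-pass shape check
    src_pad, tgt_pad, mem_pad, tgt_mask = model.create_masks(src, tgt)
    logits = model(src, tgt, src_pad, tgt_pad, mem_pad, tgt_mask)
    print("logits:", tuple(logits.shape))  # (batch, tgt_len, vocab_size)
    print("generated:", tuple(model.generate(src, max_length=16, num_beams=3).shape))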