import math
import os

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from x_transformers import Decoder
from transformers import AutoTokenizer

# --- GLOBAL TOKENIZER SETUP ---
# Downstream code assumes `tokenizer` is available; a failure here will
# surface later as a NameError.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")


# ==================================================================
# SHIMMER ARCHITECTURE CLASSES
# ==================================================================

class ComplexDropout(nn.Module):
    """Dropout for complex tensors.

    Standard nn.Dropout does not support ComplexFloat. This module builds a
    real-valued mask of the same shape and applies it to the real and
    imaginary parts identically, which preserves phase.
    """
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        if not self.training or self.p == 0.0:
            return z
        # Generate the mask with F.dropout on a ones tensor shaped like the
        # real part; F.dropout handles the 1 / (1 - p) rescaling automatically.
        mask = torch.ones_like(z.real)
        mask = F.dropout(mask, self.p, self.training, inplace=False)
        # Multiplying by a real mask scales real and imaginary parts equally.
        return z * mask


class PhasePreservingLayerNorm(nn.Module):
    """LayerNorm over the magnitude of a complex tensor; the phase is kept."""
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        self.eps = eps

    def forward(self, x):
        mag = torch.abs(x)
        mag_norm = self.layernorm(mag)
        # x / (mag + eps) is the unit phasor; eps avoids division by zero.
        return mag_norm.to(x.dtype) * (x / (mag + self.eps))


class HarmonicEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        # 1. Learnable real and imaginary parts (Cartesian coordinates).
        # This lets the model learn both amplitude AND intrinsic phase implicitly.
        self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)
        # Fixed frequencies, one per channel (sinusoidal-PE schedule).
        freqs = torch.exp(
            torch.arange(0, embedding_dim, dtype=torch.float32)
            * -(math.log(max_period) / embedding_dim)
        )
        self.register_buffer('freqs', freqs)

    def forward(self, input_ids):
        # A. Get the learnable content (magnitude + intrinsic phase).
        # Shape: [Batch, Seq, Dim * 2]
        raw_embeds = self.complex_embedding(input_ids)

        # Split into real/imaginary halves.
        real = raw_embeds[..., :self.embedding_dim]
        imag = raw_embeds[..., self.embedding_dim:]

        # Convert to a complex tensor; this Z already carries amplitude
        # AND intrinsic phase content.
        content_z = torch.complex(real, imag)

        # B. Apply the positional rotation (the "clock").
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).float()
        angles = torch.outer(positions, self.freqs)

        # Build the rotation (phase shift): e^(i * theta) with unit magnitude.
        pos_rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)

        # C. Rotate the content: Z_final = Z_content * e^(i * theta).
        return content_z * pos_rotation
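
# Illustrative sketch (not part of the architecture): a quick sanity check of
# HarmonicEmbedding. Because the positional factor e^(i*theta) has unit
# magnitude, the rotation should change only the phase of the content, never
# its amplitude. The function name and tiny sizes are arbitrary assumptions.
def _demo_harmonic_embedding():
    emb = HarmonicEmbedding(num_embeddings=10, embedding_dim=8)
    ids = torch.tensor([[1, 2, 3, 4]])
    z = emb(ids)  # [1, 4, 8], complex
    raw = emb.complex_embedding(ids)
    content = torch.complex(raw[..., :8], raw[..., 8:])
    # Magnitudes match up to floating-point error; only the phase rotated.
    assert torch.allclose(torch.abs(z), torch.abs(content), atol=1e-5)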

class PRISMEncoder(nn.Module):
    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)]
        )
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            x = layer(x, src_mask)
        # Apply the final norm after the last layer.
        return self.final_norm(x)


class ModReLU(nn.Module):
    """modReLU activation: thresholds the magnitude, leaves the phase intact."""
    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        mag = torch.abs(z)
        new_mag = F.relu(mag + self.b)
        phase = z / (mag + 1e-6)
        return new_mag * phase


class PRISMLayer(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len

        # --- REMOVED GATING PARAMS ---
        # self.pre_gate = nn.Linear(d_model * 2, d_model)

        # Global filter: one learnable complex coefficient per (channel, frequency).
        self.global_filter = nn.Parameter(
            torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02
        )

        # Mixing
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        # (A + iB)(r + i*i) = (Ar - Bi) + i(Ai + Br)
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x, src_mask=None):
        residual = x
        x_norm = self.norm(x)

        if src_mask is not None:
            mask_expanded = src_mask.unsqueeze(-1)
            x_norm = x_norm.masked_fill(mask_expanded, 0.0)

        # --- REMOVED GATING LOGIC ---
        # Pass x_norm directly to the FFT.
        x_gated = x_norm

        # B. FFT resonance: filter in the frequency domain, then invert.
        B, L, D = x_gated.shape
        x_freq = torch.fft.fft(x_gated, n=self.filter_len, dim=1)
        filter_transposed = self.global_filter.transpose(-1, -2)  # [max_len, d_model]
        x_filtered = x_freq * filter_transposed
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :L, :]  # Crop back to the original sequence length.

        # C. Mix & activate.
        x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)

        return self.dropout(out) + residual


class ComplexToRealBridge(nn.Module):
    """Projects a complex tensor to a real one by concatenating Re and Im."""
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x_complex):
        cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)
        return self.proj(cat)
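
# Illustrative sketch (not part of the model): by the convolution theorem, the
# frequency-domain multiply in PRISMLayer.forward is a per-channel circular
# convolution of length filter_len. The helper name and the sizes below are
# arbitrary assumptions for demonstration only.
def _demo_global_filter_equivalence():
    L, D = 16, 4
    x = torch.randn(1, L, D, dtype=torch.cfloat)
    h_freq = torch.randn(D, L, dtype=torch.cfloat)  # stand-in for global_filter
    # Frequency-domain path, exactly as in PRISMLayer.forward (with n = L).
    y_fft = torch.fft.ifft(
        torch.fft.fft(x, n=L, dim=1) * h_freq.transpose(-1, -2), n=L, dim=1
    )
    # Time-domain path: circular convolution with the kernel ifft(h_freq).
    h_time = torch.fft.ifft(h_freq, dim=-1)  # [D, L]
    idx = (torch.arange(L).unsqueeze(1) - torch.arange(L).unsqueeze(0)) % L
    h_mat = h_time[:, idx]  # [D, L, L]; h_mat[d, n, m] = h_time[d, (n - m) % L]
    y_conv = torch.einsum('bmd,dnm->bnd', x, h_mat)
    assert torch.allclose(y_fft, y_conv, atol=1e-4)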

class PRISMHybrid_RoPE(nn.Module):
    def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
                 num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model

        # 1. Embeddings
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)

        # 2. Harmonic Body (PRISM Encoder)
        if num_encoder_layers > 0:
            self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        else:
            self.prism_encoder = None

        # 3. The Bridge
        self.bridge = ComplexToRealBridge(d_model)

        # 4. Refining Encoder
        if num_refining_layers > 0:
            refining_layer = nn.TransformerEncoderLayer(
                d_model, num_heads, dff, dropout,
                batch_first=True, norm_first=True
            )
            self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
        else:
            self.reasoning_encoder = None

        # 5. Decoder (x-transformers)
        self.decoder = Decoder(
            dim=d_model,
            depth=num_decoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True
        )

        # 6. Output Head (weight-tied with the target embedding)
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.tgt_embedding.weight

    def create_masks(self, src, tgt):
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        # The source padding mask doubles as the memory padding mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        # A. Harmonic Phase
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # PRISM Encoder pass (gradient-checkpointed during training).
        if self.prism_encoder is not None:
            if self.training:
                src_harmonic.requires_grad_(True)
                encoded_complex = torch.utils.checkpoint.checkpoint(
                    self.prism_encoder, src_harmonic, src_mask, use_reentrant=False
                )
            else:
                encoded_complex = self.prism_encoder(src_harmonic, src_mask)
        else:
            encoded_complex = src_harmonic

        # B. The Bridge
        coarse_memory = self.bridge(encoded_complex)

        # C. Refining Phase
        if self.reasoning_encoder is not None:
            refined_memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)
        else:
            refined_memory = coarse_memory

        # D. Decoder Prep. Note: tgt_mask is unused here because the
        # x-transformers Decoder is causal by default; x-transformers expects
        # True = attend, hence the inverted padding masks.
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(tgt_emb)
        context_mask = ~mem_pad if mem_pad is not None else None
        decoder_mask = ~tgt_pad if tgt_pad is not None else None

        # E. Decoder Pass (checkpointed during training).
        if self.training:
            tgt_emb.requires_grad_(True)
            output = torch.utils.checkpoint.checkpoint(
                self.decoder, tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask,
                use_reentrant=False
            )
        else:
            output = self.decoder(
                tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask
            )

        return self.final_linear(output)
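
    # Illustrative sketch (hypothetical helper, not called by the model): the
    # beam search in generate() flattens (batch, beam) pairs into one axis of
    # size batch * beams; this mirrors the `effective` index arithmetic there.
    # Example: batch_size=2, num_beams=3, beam_indices [[0, 2, 1], [1, 0, 2]]
    # maps to flat rows [0, 2, 1, 4, 3, 5].
    @staticmethod
    def _flat_beam_rows(beam_indices, num_beams):
        batch_size = beam_indices.shape[0]
        offsets = torch.arange(batch_size, device=beam_indices.device).unsqueeze(1) * num_beams
        return (offsets + beam_indices).view(-1)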

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        self.eval()
        src_mask = (src == tokenizer.pad_token_id)
        context_mask = ~src_mask

        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        if self.prism_encoder is not None:
            encoded_complex = self.prism_encoder(src_harmonic, src_mask)
        else:
            encoded_complex = src_harmonic

        coarse_memory = self.bridge(encoded_complex)
        if self.reasoning_encoder is not None:
            memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=src_mask)
        else:
            memory = coarse_memory

        batch_size = src.shape[0]
        memory = memory.repeat_interleave(num_beams, dim=0)
        context_mask = context_mask.repeat_interleave(num_beams, dim=0)

        # Marian convention: the pad token doubles as the decoder start token.
        beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id,
                           dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)

        for step in range(max_length - 1):
            if finished_beams.all():
                break

            tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
            tgt_emb = self.dropout(tgt_emb)

            # Decoder
            decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Masking: never emit pad; finished beams keep emitting EOS at no cost.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            # --- BEAM SEARCH LOGIC FIX ---
            if step == 0:
                # First step: expand from the first beam only, since all beams
                # hold identical start tokens. Reshape to (batch, beams, vocab).
                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)
                # Mask out every beam except the first one (-inf).
                total[:, 1:, :] = -torch.inf
                # Flatten back to (batch, beams * vocab) to pick the top k.
                total = total.view(batch_size, -1)
            else:
                # Subsequent steps: standard flatten.
                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)

            top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            # Now dimensions match: (batch_size, 1) + (batch_size, k); map
            # per-batch beam indices back into the flattened (batch * beams) layout.
            effective = (torch.arange(batch_size, device=src.device).unsqueeze(1)
                         * num_beams + beam_indices).view(-1)

            beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)

        # topk returns scores in descending order, so beam 0 is the best beam.
        final_beams = beams.view(batch_size, num_beams, -1)
        best_beams = final_beams[:, 0, :]

        self.train()  # Restore training mode.
        return best_beams
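
# Illustrative sketch (not part of the module): a minimal smoke test with
# arbitrary tiny hyperparameters. It assumes the global tokenizer loaded; the
# sizes below are assumptions chosen only so the model runs quickly on CPU.
if __name__ == "__main__":
    model = PRISMHybrid_RoPE(
        num_encoder_layers=1, num_refining_layers=1, num_decoder_layers=1,
        num_heads=4, d_model=64, dff=128,
        vocab_size=tokenizer.vocab_size, max_length=32, dropout=0.1,
    )
    src = torch.randint(0, tokenizer.vocab_size, (2, 10))
    tgt = torch.randint(0, tokenizer.vocab_size, (2, 8))
    src_pad = (src == tokenizer.pad_token_id)
    tgt_pad = (tgt == tokenizer.pad_token_id)
    model.eval()
    with torch.no_grad():
        # tgt_mask is passed as None because forward() does not use it; the
        # x-transformers Decoder is causal by default.
        logits = model(src, tgt, src_pad, tgt_pad, src_pad, None)
    print(logits.shape)  # Expected: torch.Size([2, 8, vocab_size])
    print(model.generate(src, max_length=8, num_beams=2).shape)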