# PRISM-Shimmer / modeling_prism.py
# Uploaded by Yujivus using huggingface_hub (commit 9cff9de).
import math
import os

import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import AutoTokenizer
from x_transformers import Decoder
# --- GLOBAL TOKENIZER SETUP ---
# Prefer a tokenizer saved in the current working directory (detected via
# tokenizer_config.json); otherwise fall back to the pretrained
# "Helsinki-NLP/opus-mt-de-en" tokenizer.
# `tokenizer` is always bound after this block. If loading fails it stays None
# and a warning is printed, so importing the module never raises. Code that
# reads `tokenizer.pad_token_id` will still fail until a tokenizer is assigned.
tokenizer = None
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")
# ==================================================================
# SHIMMER ARCHITECTURE CLASSES
# ==================================================================
class ComplexDropout(nn.Module):
    """Dropout for complex-valued tensors.

    ``nn.Dropout`` does not accept complex dtypes. This module therefore builds
    one real-valued mask by running ``F.dropout`` on a tensor of ones. That call
    also applies the ``1 / (1 - p)`` rescaling. The complex input is then
    multiplied by the mask. The same real factor scales both the real and
    imaginary part of every element, so kept elements keep their phase.

    Args:
        p: probability of zeroing an element.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        # Acts as the identity in eval mode or when dropout is disabled.
        active = self.training and self.p != 0.0
        if not active:
            return z
        keep = F.dropout(torch.ones_like(z.real), self.p, self.training, inplace=False)
        return z * keep
class PhasePreservingLayerNorm(nn.Module):
    """LayerNorm applied to the magnitude of a complex tensor.

    The magnitude ``|x|`` is normalised by ``nn.LayerNorm`` over the last
    dimension. The result is re-attached to the unit-phase tensor
    ``x / (|x| + eps)``.

    Note: LayerNorm subtracts the mean, so its output can be negative. A
    negative scale rotates that element's phase by pi.

    Args:
        d_model: size of the last (normalised) dimension.
        eps: used both as the LayerNorm epsilon and to keep the phase
            division finite when ``|x| == 0``.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        self.eps = eps

    def forward(self, x):
        magnitude = x.abs()
        scale = self.layernorm(magnitude).to(x.dtype)
        unit_phase = x / (magnitude + self.eps)
        return scale * unit_phase
class HarmonicEmbedding(nn.Module):
    """Complex token embedding with a positional phase rotation.

    Each token id looks up ``2 * embedding_dim`` learnable reals. The first half
    is the real part and the second half the imaginary part of a complex
    vector, so both amplitude and an intrinsic phase are learnable. Position
    ``p`` then multiplies channel ``k`` by ``exp(i * p * freqs[k])``, where
    ``freqs[k] = exp(-k * log(max_period) / embedding_dim)``.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: number of complex channels produced per token.
        max_period: base of the geometric frequency ladder.
    """

    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Cartesian (real | imag) parameters, concatenated along the last axis.
        self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)
        # Fixed frequencies; registered as a buffer so they are not trained.
        decay = -(math.log(max_period) / embedding_dim)
        channel_index = torch.arange(0, embedding_dim, dtype=torch.float32)
        self.register_buffer('freqs', torch.exp(channel_index * decay))

    def forward(self, input_ids):
        half = self.embedding_dim
        lookup = self.complex_embedding(input_ids)
        # Learned content (amplitude + intrinsic phase) as a complex tensor.
        content = torch.complex(lookup[..., :half], lookup[..., half:])

        # Dimension 1 of input_ids is treated as the sequence axis.
        length = input_ids.shape[1]
        position = torch.arange(length, device=input_ids.device).float()
        theta = position[:, None] * self.freqs[None, :]
        # Unit-magnitude rotors e^(i * theta), broadcast over the batch axis.
        rotor = torch.polar(torch.ones_like(theta), theta)
        return content * rotor.unsqueeze(0)
class PRISMEncoder(nn.Module):
    """Stack of ``PRISMLayer`` blocks followed by a final ``PhasePreservingLayerNorm``.

    Args:
        num_layers: number of PRISM layers.
        d_model: complex channel width.
        max_len: FFT / global-filter length passed to every layer.
        dropout: dropout probability passed to every layer.
    """

    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)
        )
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        # The same padding mask is handed to every layer.
        hidden = x
        for prism_layer in self.layers:
            hidden = prism_layer(hidden, src_mask)
        return self.final_norm(hidden)
class ModReLU(nn.Module):
    """modReLU activation for complex tensors.

    Computes ``relu(|z| + b) * z / (|z| + 1e-6)``. A learnable per-feature bias
    ``b`` (initialised to zero) shifts the magnitude, ReLU clips it at zero,
    and the phase of ``z`` is reused unchanged.

    Args:
        features: size of the last dimension (length of ``b``).
    """

    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        magnitude = torch.abs(z)
        direction = z / (magnitude + 1e-6)
        return F.relu(magnitude + self.b) * direction
class PRISMLayer(nn.Module):
    """Pre-norm residual block that mixes tokens with a learnable global frequency filter.

    Pipeline (all complex-valued):
      1. ``PhasePreservingLayerNorm`` on the input; positions where ``src_mask``
         is True are zeroed.
      2. FFT along the sequence axis (dim 1), zero-padded to ``max_len``.
         Multiply by the learnable per-channel spectrum ``global_filter``
         (``d_model x max_len``), inverse FFT, and crop back to the input length.
         Each output position can therefore depend on every input position.
      3. Complex linear -> ``ModReLU`` -> complex linear, then ``ComplexDropout``
         and the residual add.

    Args:
        d_model: complex channel width (last dimension of the input).
        max_len: FFT length. Inputs longer than this are rejected, because the
            FFT would otherwise silently truncate them.
        dropout: probability for ``ComplexDropout`` on the block output.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len
        # One learnable complex frequency response per channel.
        self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)
        # Real/imaginary weight pairs for the two complex linear maps.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)
        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        """Apply the complex map ``(W_r + i W_i)`` to ``x`` using two real Linear layers.

        Both layers carry their own bias. The effective complex bias is
        therefore ``(b_r - b_i) + i (b_r + b_i)``.
        """
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x, src_mask=None):
        """Run the block on ``x`` (assumed ``[batch, seq, d_model]`` complex — TODO confirm).

        Args:
            x: complex input; dim 1 is the sequence axis.
            src_mask: optional boolean mask, True at positions to zero before the FFT.

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.filter_len:
            # fft(n=filter_len) would truncate the input and the residual add
            # would then fail with an opaque shape error; fail early and clearly.
            raise ValueError(
                f"PRISMLayer: sequence length {seq_len} exceeds max_len={self.filter_len}"
            )

        residual = x
        x_norm = self.norm(x)
        if src_mask is not None:
            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # Global filtering in the frequency domain; the filter is transposed to
        # (max_len, d_model) so it broadcasts over the batch axis.
        x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
        x_filtered = x_freq * self.global_filter.transpose(-1, -2)
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :seq_len, :]

        # Channel mixing with a phase-preserving nonlinearity.
        x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)
        return self.dropout(out) + residual
class ComplexToRealBridge(nn.Module):
    """Project a complex ``[..., d_model]`` tensor to a real ``[..., d_model]`` tensor.

    The real and imaginary parts are concatenated along the last axis as
    ``[real | imag]``. The result goes through a single
    ``Linear(2 * d_model, d_model)``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x_complex):
        stacked = torch.cat((x_complex.real, x_complex.imag), dim=-1)
        return self.proj(stacked)
class PRISMHybrid_RoPE(nn.Module):
    """Hybrid seq2seq model: complex PRISM encoder, real refining encoder, RoPE decoder.

    Source path: ``HarmonicEmbedding`` (complex) -> optional ``PRISMEncoder``
    -> ``ComplexToRealBridge`` -> optional pre-norm ``nn.TransformerEncoder``
    ("refining" encoder).

    Target path: ``nn.Embedding`` scaled by ``sqrt(d_model)`` -> x-transformers
    ``Decoder`` (rotary position embeddings, cross-attention to the source
    memory) -> output projection tied to the target embedding.

    ``max_length`` is the PRISM FFT length, so source sequences must not be
    longer than it. Padding is identified through the module-level
    ``tokenizer.pad_token_id``.
    """

    def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
                 num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        # 1. Embeddings: complex harmonic embedding (source), real embedding (target).
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)
        # 2. Harmonic body (complex PRISM encoder); disabled when num_encoder_layers == 0.
        if num_encoder_layers > 0:
            self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        else:
            self.prism_encoder = None
        # 3. Complex -> real projection.
        self.bridge = ComplexToRealBridge(d_model)
        # 4. Optional real-valued (pre-norm) Transformer encoder over the bridged memory.
        if num_refining_layers > 0:
            refining_layer = nn.TransformerEncoderLayer(
                d_model, num_heads, dff, dropout,
                batch_first=True, norm_first=True
            )
            self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
        else:
            self.reasoning_encoder = None
        # 5. Decoder (x-transformers) with rotary position embeddings and cross-attention.
        self.decoder = Decoder(
            dim=d_model,
            depth=num_decoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True
        )
        # 6. Output head, weight-tied to the target embedding.
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.tgt_embedding.weight

    def create_masks(self, src, tgt):
        """Build the mask arguments expected by ``forward``.

        Returns:
            ``(src_padding_mask, tgt_padding_mask, memory_padding_mask, tgt_mask)``.
            The padding masks are True at pad tokens, and the memory padding mask
            is the source padding mask. ``tgt_mask`` is the boolean square mask
            from ``nn.Transformer.generate_square_subsequent_mask``.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    def _encode_source(self, src, src_mask, mem_pad):
        """Encode source token ids into the real-valued decoder memory.

        Shared by ``forward`` and ``generate``. The PRISM encoder is
        activation-checkpointed while ``self.training`` is True.

        Args:
            src: source token ids.
            src_mask: boolean mask, True at pad positions (zeroed in the complex
                embedding and passed to the PRISM layers), or None.
            mem_pad: key-padding mask for the refining encoder, or None.
        """
        # A. Harmonic phase: complex embedding with padded positions zeroed.
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # PRISM encoder pass (checkpointed during training to save memory).
        if self.prism_encoder is not None:
            if self.training:
                src_harmonic.requires_grad_(True)
                encoded_complex = torch.utils.checkpoint.checkpoint(
                    self.prism_encoder, src_harmonic, src_mask, use_reentrant=False
                )
            else:
                encoded_complex = self.prism_encoder(src_harmonic, src_mask)
        else:
            encoded_complex = src_harmonic

        # B. Bridge to real values; C. optional refining encoder.
        coarse_memory = self.bridge(encoded_complex)
        if self.reasoning_encoder is None:
            return coarse_memory
        return self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        """Teacher-forced forward pass returning vocabulary logits for every target position.

        Args:
            src: source token ids.
            tgt: target (decoder input) token ids.
            src_mask: boolean source padding mask (True = pad), or None.
            tgt_pad: boolean target padding mask (True = pad), or None.
            mem_pad: boolean memory padding mask (True = pad), or None.
            tgt_mask: accepted for interface compatibility but not used here.
                Presumably the x-transformers ``Decoder`` applies causal masking
                itself — TODO confirm.
        """
        refined_memory = self._encode_source(src, src_mask, mem_pad)

        # D. Decoder prep.
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(tgt_emb)
        # Padding masks are inverted, so the decoder receives True at real
        # (non-pad) tokens — presumably x-transformers' keep-mask convention.
        context_mask = ~mem_pad if mem_pad is not None else None
        decoder_mask = ~tgt_pad if tgt_pad is not None else None

        # E. Decoder pass (checkpointed during training).
        if self.training:
            tgt_emb.requires_grad_(True)
            output = torch.utils.checkpoint.checkpoint(
                self.decoder,
                tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask,
                use_reentrant=False
            )
        else:
            output = self.decoder(
                tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask
            )
        return self.final_linear(output)

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """Beam-search decode ``src`` and return the best hypothesis per batch element.

        Each hypothesis starts from ``tokenizer.pad_token_id``, and pad is never
        emitted afterwards. Once a hypothesis emits ``eos_token_id`` it is
        finished: it can only be extended with further EOS tokens, at zero
        cost. Scores are summed log-probabilities with no length
        normalisation. The module's train/eval mode is restored on return,
        including when an exception is raised.

        Args:
            src: source token ids; dim 0 is the batch.
            max_length: maximum returned length, including the start token.
            num_beams: beam width.

        Returns:
            Token ids of shape ``[batch, n]`` with ``n <= max_length``.
        """
        was_training = self.training
        self.eval()
        try:
            pad_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id
            device = src.device
            batch_size = src.shape[0]

            src_mask = (src == pad_id)
            memory = self._encode_source(src, src_mask, src_mask)

            # Expand encoder outputs so every beam has its own copy.
            memory = memory.repeat_interleave(num_beams, dim=0)
            context_mask = (~src_mask).repeat_interleave(num_beams, dim=0)

            beams = torch.full((batch_size * num_beams, 1), pad_id, dtype=torch.long, device=device)
            beam_scores = torch.zeros(batch_size * num_beams, device=device)
            finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=device)
            # Row offset of each batch element's first beam in the flattened layout.
            batch_offsets = torch.arange(batch_size, device=device).unsqueeze(1) * num_beams

            for step in range(max_length - 1):
                if finished_beams.all():
                    break

                tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
                tgt_emb = self.dropout(tgt_emb)
                decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
                logits = self.final_linear(decoder_output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                vocab_size = log_probs.shape[-1]

                log_probs[:, pad_id] = -torch.inf
                if finished_beams.any():
                    # Finished hypotheses may only be padded with EOS, score unchanged.
                    log_probs[finished_beams] = -torch.inf
                    log_probs[finished_beams, eos_id] = 0.0

                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, vocab_size)
                if step == 0:
                    # All beams start identical; expand only the first to avoid duplicates.
                    total[:, 1:, :] = -torch.inf
                total = total.view(batch_size, -1)

                top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
                beam_indices = top_indices // vocab_size
                token_indices = top_indices % vocab_size

                # Flattened row of the parent beam for each surviving candidate.
                effective = (batch_offsets + beam_indices).view(-1)
                beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
                beam_scores = top_scores.view(-1)
                # Carry each parent's finished flag along with the reordered beams.
                finished_beams = finished_beams[effective] | (token_indices.view(-1) == eos_id)

            # topk returns candidates sorted by score, so beam 0 is the best.
            return beams.view(batch_size, num_beams, -1)[:, 0, :]
        finally:
            self.train(was_training)