|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.fft |
|
|
import math |
|
|
from x_transformers import Decoder |
|
|
from transformers import AutoTokenizer |
|
|
import os |
|
|
|
|
|
|
|
|
# Load the shared tokenizer once at import time.  A tokenizer saved in the
# current working directory (detected through its "tokenizer_config.json")
# takes precedence over the hub checkpoint.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    # Best-effort load: keep the module importable, but always bind the name so
    # later code sees an explicit ``None`` instead of hitting a NameError that
    # hides the real cause.
    tokenizer = None
    print(f"Warning: Tokenizer load failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ComplexDropout(nn.Module):
    """Dropout for complex-valued tensors.

    ``nn.Dropout`` does not accept ComplexFloat input, so a real-valued mask
    with the shape of the input is sampled instead and multiplied into the
    complex tensor.  The real and imaginary parts of an element share one mask
    value, which keeps the phase of every surviving element intact.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, z):
        """Return ``z`` unchanged outside training or when ``p`` is 0, else masked."""
        dropout_active = self.training and self.p != 0.0
        if dropout_active:
            # Sample the mask on a real tensor of ones; F.dropout's own
            # rescaling of kept entries is carried by the mask values.
            keep = F.dropout(torch.ones_like(z.real), self.p, self.training, inplace=False)
            return z * keep
        return z
|
|
|
|
|
class PhasePreservingLayerNorm(nn.Module):
    """Layer normalisation applied to the magnitudes of a complex tensor.

    The element magnitudes are normalised by a regular ``nn.LayerNorm`` over
    the last dimension and then multiplied back onto ``x / (|x| + eps)``, the
    direction of each complex element.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model, eps=eps)
        self.eps = eps

    def forward(self, x):
        """Normalise ``|x|`` and re-attach the direction of every element."""
        magnitude = x.abs()
        # eps keeps the division finite for zero-magnitude entries.
        direction = x / (magnitude + self.eps)
        # Cast the real normalised magnitude to the complex dtype of x before
        # combining it with the direction.
        rescaled = self.layernorm(magnitude).to(x.dtype)
        return rescaled * direction
|
|
|
|
|
class HarmonicEmbedding(nn.Module):
    """Complex token embedding with a position-dependent phase rotation.

    Each token owns ``2 * embedding_dim`` real parameters: the first half is
    read as the real part and the second half as the imaginary part of a
    complex vector.  That vector is multiplied element-wise by
    ``exp(i * position * freq_k)``, where the ``embedding_dim`` frequencies
    are ``max_period ** (-k / embedding_dim)`` for ``k = 0 .. embedding_dim - 1``.
    """

    def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):
        super().__init__()
        self.embedding_dim = embedding_dim

        # A single real table stores both halves of the complex embedding.
        self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)

        # Fixed (non-trainable) per-channel angular frequencies.
        log_step = -(math.log(max_period) / embedding_dim)
        channel_index = torch.arange(0, embedding_dim, dtype=torch.float32)
        self.register_buffer('freqs', torch.exp(channel_index * log_step))

    def forward(self, input_ids):
        """Embed ``input_ids`` (indexed as ``(batch, seq_len)``) as complex vectors."""
        table_rows = self.complex_embedding(input_ids)
        token_vectors = torch.complex(
            table_rows[..., :self.embedding_dim],
            table_rows[..., self.embedding_dim:],
        )

        # One rotation angle per (position, channel) pair.
        steps = torch.arange(input_ids.shape[1], device=input_ids.device).float()
        phase = torch.outer(steps, self.freqs)

        # Unit-modulus rotation, given a leading axis so it broadcasts over the batch.
        rotation = torch.polar(torch.ones_like(phase), phase)
        return token_vectors * rotation.unsqueeze(0)
|
|
|
|
|
class PRISMEncoder(nn.Module):
    """A stack of ``PRISMLayer`` blocks followed by a final magnitude norm."""

    def __init__(self, num_layers, d_model, max_len, dropout=0.1):
        super().__init__()
        blocks = [PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.final_norm = PhasePreservingLayerNorm(d_model)

    def forward(self, x, src_mask=None):
        """Pass ``x`` through every layer in order, then through the final norm.

        ``src_mask`` is forwarded unchanged to each layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.final_norm(hidden)
|
|
|
|
|
class ModReLU(nn.Module):
    """Magnitude-thresholding activation for complex inputs.

    A learnable per-feature bias ``b`` (initialised to zero) is added to the
    magnitude, the sum is passed through ReLU, and the result is multiplied by
    the direction ``z / (|z| + 1e-6)`` of the input.
    """

    def __init__(self, features):
        super().__init__()
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, z):
        """Apply the biased ReLU to ``|z|`` while keeping the direction of ``z``."""
        modulus = z.abs()
        # The small constant avoids dividing by zero for zero-magnitude entries.
        direction = z / (modulus + 1e-6)
        return F.relu(modulus + self.b) * direction
|
|
|
|
|
class PRISMLayer(nn.Module):
    """Residual complex block: learned frequency-domain filter plus complex MLP.

    The input is normalised, optionally zeroed at masked positions, multiplied
    by a learnable complex filter in the FFT domain along the sequence axis,
    and then passed through two complex linear maps with a ``ModReLU`` in
    between.  The result goes through dropout and is added to the input.

    Args:
        d_model: Feature width of the complex input.
        max_len: Number of FFT bins of the learned filter; also the longest
            sequence the layer accepts.
        dropout: Drop probability handed to ``ComplexDropout``.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.filter_len = max_len

        # One learnable complex coefficient per (feature, frequency bin).
        self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)

        # Real-valued layers that complex_linear() combines into complex maps.
        self.mix_real = nn.Linear(d_model, d_model)
        self.mix_imag = nn.Linear(d_model, d_model)
        self.out_real = nn.Linear(d_model, d_model)
        self.out_imag = nn.Linear(d_model, d_model)

        self.activation = ModReLU(d_model)
        self.norm = PhasePreservingLayerNorm(d_model)
        self.dropout = ComplexDropout(dropout)

    def complex_linear(self, x, l_real, l_imag):
        """Apply two real layers to complex ``x`` following the complex product rule.

        The real output is ``l_real(re) - l_imag(im)`` and the imaginary output
        is ``l_real(im) + l_imag(re)``; each layer's bias is included in every
        call to it.
        """
        r, i = x.real, x.imag
        new_r = l_real(r) - l_imag(i)
        new_i = l_real(i) + l_imag(r)
        return torch.complex(new_r, new_i)

    def forward(self, x, src_mask=None):
        """Filter and mix ``x`` and return it with the residual added.

        Args:
            x: Complex tensor unpacked as ``(batch, seq_len, d_model)``.
            src_mask: Optional boolean tensor; positions where it is ``True``
                are zeroed before filtering.  It is unsqueezed on the last
                axis, so it is expected to index ``(batch, seq_len)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``filter_len``.
        """
        _, seq_len, _ = x.shape
        if seq_len > self.filter_len:
            # fft(n=filter_len) would silently truncate the sequence and the
            # residual addition below would then fail (or broadcast) on a
            # shape mismatch, so reject the input up front.
            raise ValueError(
                f"sequence length {seq_len} exceeds the filter length {self.filter_len}"
            )

        residual = x
        x_norm = self.norm(x)
        if src_mask is not None:
            x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)

        # The sequence axis is zero-padded to filter_len bins by the FFT, scaled
        # per (bin, feature) by the learned filter, and transformed back.
        x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)
        x_filtered = x_freq * self.global_filter.transpose(-1, -2)
        x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)
        x_time = x_time[:, :seq_len, :]  # drop the padded tail

        x_mixed = self.complex_linear(x_time, self.mix_real, self.mix_imag)
        x_act = self.activation(x_mixed)
        out = self.complex_linear(x_act, self.out_real, self.out_imag)

        return self.dropout(out) + residual
|
|
|
|
|
class ComplexToRealBridge(nn.Module):
    """Project a complex tensor onto a real tensor of width ``d_model``.

    The real and imaginary parts are concatenated along the last dimension
    (giving ``2 * d_model`` real features) and mapped to ``d_model`` features
    by a single linear layer.
    """

    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model * 2, d_model)

    def forward(self, x_complex):
        """Return ``proj([Re(x), Im(x)])`` for the complex input ``x_complex``."""
        parts = (x_complex.real, x_complex.imag)
        stacked = torch.cat(parts, dim=-1)
        return self.proj(stacked)
|
|
|
|
|
class PRISMHybrid_RoPE(nn.Module):
    """Sequence-to-sequence model: complex PRISM encoder feeding a rotary decoder.

    Source pipeline: ``HarmonicEmbedding`` -> optional ``PRISMEncoder`` ->
    ``ComplexToRealBridge`` -> optional ``nn.TransformerEncoder`` refinement.
    The resulting real-valued memory is cross-attended by an x_transformers
    ``Decoder``; the output projection shares its weight with the target
    embedding.

    Padding positions are detected with the module-level ``tokenizer``
    (``pad_token_id``), which must therefore have been loaded successfully
    before ``create_masks`` or ``generate`` is called.

    Args:
        num_encoder_layers: Number of ``PRISMLayer`` blocks; ``0`` skips the
            PRISM encoder entirely.
        num_refining_layers: Number of ``nn.TransformerEncoderLayer`` blocks
            applied after the bridge; ``0`` skips the refinement stage.
        num_decoder_layers: Depth of the decoder.
        num_heads: Attention heads for the refinement layers and the decoder.
        d_model: Model width shared by every stage.
        dff: Feed-forward width of the refinement layers; the decoder receives
            it as the ratio ``dff / d_model``.
        vocab_size: Size of the (shared) source/target vocabulary tables.
        max_length: Passed to ``PRISMEncoder`` as its ``max_len``.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,
                 num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model

        # Embeddings: complex for the source, real for the target.
        self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.dropout = nn.Dropout(dropout)

        # Optional complex-valued encoder.
        if num_encoder_layers > 0:
            self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)
        else:
            self.prism_encoder = None

        # Complex -> real projection.
        self.bridge = ComplexToRealBridge(d_model)

        # Optional real-valued self-attention refinement of the memory.
        if num_refining_layers > 0:
            refining_layer = nn.TransformerEncoderLayer(
                d_model, num_heads, dff, dropout,
                batch_first=True, norm_first=True
            )
            self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)
        else:
            self.reasoning_encoder = None

        self.decoder = Decoder(
            dim=d_model,
            depth=num_decoder_layers,
            heads=num_heads,
            attn_dim_head=d_model // num_heads,
            ff_mult=dff / d_model,
            rotary_pos_emb=True,
            cross_attend=True,
            attn_flash=True,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_rmsnorm=True
        )

        # Output projection tied to the target embedding table.
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.tgt_embedding.weight

    def create_masks(self, src, tgt):
        """Build the mask tuple in the order ``forward`` takes it.

        Returns:
            ``(src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask)``
            where the padding masks are ``True`` at pad tokens and ``tgt_mask``
            is the boolean square subsequent mask for ``tgt.size(1)`` steps.
        """
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            sz=tgt.size(1), device=src.device, dtype=torch.bool
        )
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    def _encode_source(self, src, src_mask, mem_pad):
        """Run embedding -> PRISM encoder -> bridge -> refinement; return the memory."""
        src_harmonic = self.harmonic_embedding(src)
        if src_mask is not None:
            src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)

        if self.prism_encoder is None:
            encoded_complex = src_harmonic
        elif self.training:
            # Imported here so the submodule is guaranteed to be loaded even
            # though the file only imports the top-level ``torch`` package.
            from torch.utils.checkpoint import checkpoint

            # Activation checkpointing during training to save memory.
            src_harmonic.requires_grad_(True)
            encoded_complex = checkpoint(
                self.prism_encoder, src_harmonic, src_mask, use_reentrant=False
            )
        else:
            encoded_complex = self.prism_encoder(src_harmonic, src_mask)

        coarse_memory = self.bridge(encoded_complex)
        if self.reasoning_encoder is None:
            return coarse_memory
        return self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)

    def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):
        """Return vocabulary logits for every target position.

        Args:
            src: Source token ids.
            tgt: Target token ids fed to the decoder.
            src_mask: Boolean mask, ``True`` at source positions to zero out;
                may be ``None``.
            tgt_pad: Boolean mask, ``True`` at target padding; inverted before
                being handed to the decoder.  May be ``None``.
            mem_pad: Boolean mask, ``True`` at memory padding; used as the key
                padding mask of the refinement encoder and, inverted, as the
                decoder's context mask.  May be ``None``.
            tgt_mask: Accepted for interface compatibility; not read by this
                method (the x_transformers ``Decoder`` is expected to apply
                causal masking on its own).
        """
        refined_memory = self._encode_source(src, src_mask, mem_pad)

        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout(tgt_emb)

        # The decoder expects masks that are True where tokens are kept.
        context_mask = ~mem_pad if mem_pad is not None else None
        decoder_mask = ~tgt_pad if tgt_pad is not None else None

        if self.training:
            from torch.utils.checkpoint import checkpoint

            tgt_emb.requires_grad_(True)
            output = checkpoint(
                self.decoder,
                tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask,
                use_reentrant=False
            )
        else:
            output = self.decoder(
                tgt_emb,
                context=refined_memory,
                mask=decoder_mask,
                context_mask=context_mask
            )

        return self.final_linear(output)

    @torch.no_grad()
    def generate(self, src, max_length, num_beams=5):
        """Beam-search decode ``src`` and return the best hypothesis per example.

        Every hypothesis starts with ``pad_token_id``.  A hypothesis that has
        produced ``eos_token_id`` is frozen: it can only be extended with
        further EOS tokens at zero cost, so its score no longer changes.
        Scores are plain sums of log-probabilities (no length normalisation).

        Args:
            src: Source token ids, one row per example.
            max_length: Maximum number of tokens per hypothesis, including the
                start token.
            num_beams: Beam width.

        Returns:
            Long tensor with one row per example holding the top-scoring beam.
            The model's training/eval mode is restored to what it was on entry.
        """
        was_training = self.training
        self.eval()
        try:
            pad_id = tokenizer.pad_token_id
            eos_id = tokenizer.eos_token_id

            src_mask = (src == pad_id)
            memory = self._encode_source(src, src_mask, src_mask)
            context_mask = ~src_mask

            # Replicate the memory so each beam has its own copy.
            batch_size = src.shape[0]
            total_beams = batch_size * num_beams
            memory = memory.repeat_interleave(num_beams, dim=0)
            context_mask = context_mask.repeat_interleave(num_beams, dim=0)

            beams = torch.full((total_beams, 1), pad_id, dtype=torch.long, device=src.device)
            beam_scores = torch.zeros(total_beams, device=src.device)
            finished_beams = torch.zeros(total_beams, dtype=torch.bool, device=src.device)
            # Index of the first beam of every example in the flattened layout.
            batch_offsets = torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams

            for step in range(max_length - 1):
                if finished_beams.all():
                    break

                tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)
                tgt_emb = self.dropout(tgt_emb)

                decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)
                logits = self.final_linear(decoder_output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1)
                vocab_size = log_probs.shape[-1]

                # Never emit padding.
                log_probs[:, pad_id] = -torch.inf
                if finished_beams.any():
                    # Freeze finished hypotheses: EOS at zero cost is their
                    # only continuation, so they yield exactly one candidate.
                    log_probs[finished_beams] = -torch.inf
                    log_probs[finished_beams, eos_id] = 0

                total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)
                if step == 0:
                    # All beams are identical at the first step; keep only the
                    # first one so top-k does not select duplicates.
                    total[:, 1:, :] = -torch.inf
                total = total.view(batch_size, -1)

                top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)
                beam_indices = top_indices // vocab_size
                token_indices = top_indices % vocab_size

                # Flattened index of the parent beam of every selected candidate.
                effective = (batch_offsets + beam_indices).view(-1)
                beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)
                beam_scores = top_scores.view(-1)
                # Carry the finished flag along with the reordered parents.
                finished_beams = finished_beams[effective] | (beams[:, -1] == eos_id)

            final_beams = beams.view(batch_size, num_beams, -1)
            # topk returns candidates sorted by score, so beam 0 is the best.
            return final_beams[:, 0, :]
        finally:
            self.train(was_training)
|
|
|