# PRISM-Baseline-4-6 / modeling_baseline.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from x_transformers import Encoder, Decoder
from transformers import AutoTokenizer
# --- SMART TOKENIZER SETUP ---
try:
if os.path.exists("tokenizer_config.json"):
tokenizer = AutoTokenizer.from_pretrained(".")
else:
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")
    raise  # everything below needs `tokenizer` (pad / eos ids), so fail fast
# -----------------------------
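# Downstream code relies on just two special ids from this tokenizer:
# tokenizer.pad_token_id (also reused as the decoder start token in
# `generate`, Marian-style) and tokenizer.eos_token_id.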
class RoPETransformer(nn.Module):
def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
# We REMOVE self.pos_encoder (RoPE handles position internally)
self.dropout_layer = nn.Dropout(dropout)
# --- x-transformers Encoder ---
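        # x-transformers expects a feed-forward *multiplier*, so ff_mult =
        # dff / d_model recovers an FFN inner size of roughly dff
        # (inner_dim = int(dim * ff_mult)); the same mapping is used for the
        # Decoder below.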
self.encoder = Encoder(
dim = d_model,
depth = num_encoder_layers,
heads = num_heads,
attn_dim_head = d_model // num_heads,
ff_mult = dff / d_model,
rotary_pos_emb = True,
attn_flash = True,
attn_dropout = dropout,
ff_dropout = dropout,
use_rmsnorm = True
)
# --- x-transformers Decoder ---
self.decoder = Decoder(
dim = d_model,
depth = num_decoder_layers,
heads = num_heads,
attn_dim_head = d_model // num_heads,
ff_mult = dff / d_model,
rotary_pos_emb = True,
cross_attend = True,
attn_flash = True,
attn_dropout = dropout,
ff_dropout = dropout,
use_rmsnorm = True
)
self.final_linear = nn.Linear(d_model, vocab_size)
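        # Weight tying: the output projection shares the (vocab_size, d_model)
        # embedding matrix.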
self.final_linear.weight = self.embedding.weight
def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
# 1. Embeddings (No Absolute Positional Encoding added!)
src_emb = self.embedding(src) * math.sqrt(self.d_model)
src_emb = self.dropout_layer(src_emb)
tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
tgt_emb = self.dropout_layer(tgt_emb)
# 2. Mask Conversion
# User provides True=PAD. x-transformers wants True=KEEP.
# We invert the boolean mask using ~
enc_mask = ~src_padding_mask if src_padding_mask is not None else None
dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None
# Note: 'tgt_mask' (causal mask) is handled automatically by x-transformers Decoder!
# We do NOT pass the square causal mask manually.
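        # Example, length-4 source with one trailing pad:
        #   src_padding_mask = [F, F, F, T]   (True = PAD)
        #   enc_mask         = [T, T, T, F]   (True = attend)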
# 3. Encoder
# x-transformers takes embeddings directly
memory = self.encoder(src_emb, mask=enc_mask)
# 4. Decoder
# context = memory (from encoder)
# context_mask = mask for memory (encoder mask)
decoder_output = self.decoder(
tgt_emb,
context=memory,
mask=dec_mask,
context_mask=enc_mask
)
return self.final_linear(decoder_output)
    # create_masks is kept mostly for data-processing compatibility.
def create_masks(self, src, tgt):
src_padding_mask = (src == tokenizer.pad_token_id)
tgt_padding_mask = (tgt == tokenizer.pad_token_id)
        # We still generate this for compatibility, though x-transformers
        # handles causality internally. Built with triu directly, since
        # generate_square_subsequent_mask fills with -inf, which cannot be
        # cast to torch.bool.
        tgt_mask = torch.triu(
            torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=src.device),
            diagonal=1,
        )
        # The source padding mask doubles as memory_key_padding_mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask
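    # Illustrative wiring (hypothetical tensors of shape (batch, seq)):
    #   src_pad, tgt_pad, mem_pad, tgt_mask = model.create_masks(src, tgt)
    #   logits = model(src, tgt, src_pad, tgt_pad, mem_pad, tgt_mask)
    #   # logits: (batch, tgt_len, vocab_size)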
@torch.no_grad()
def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        was_training = self.training
        self.eval()
# Create Mask (True=PAD)
src_padding_mask = (src == tokenizer.pad_token_id)
# Invert for x-transformers (True=KEEP)
enc_mask = ~src_padding_mask
# Encode
src_emb = self.embedding(src) * math.sqrt(self.d_model)
# No Pos Encoder
memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)
batch_size = src.shape[0]
# Expand for beams
memory = memory.repeat_interleave(num_beams, dim=0)
enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)
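        # Marian-style: the pad token doubles as the decoder start token.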
initial_token = tokenizer.pad_token_id
beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)
beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
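        # Beam state is stored flat: row b * num_beams + k is beam k of batch
        # item b (matching the repeat_interleave ordering of memory above).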
        for step in range(max_length - 1):
if finished_beams.all(): break
# Embed beams
tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)
# No Pos Encoder
# Decode
# x-transformers automatically handles the causal masking for the sequence length of tgt_emb
decoder_output = self.decoder(
self.dropout_layer(tgt_emb),
context=memory,
context_mask=enc_mask
)
logits = self.final_linear(decoder_output[:, -1, :])
log_probs = F.log_softmax(logits, dim=-1)
            # Never emit PAD; freeze finished beams by forcing EOS (log-prob 0)
            # and blocking every other continuation so their scores stay fixed.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams] = -torch.inf
                log_probs[finished_beams, tokenizer.eos_token_id] = 0
            total_scores = beam_scores.unsqueeze(1) + log_probs  # (B * num_beams, vocab)
            if step == 0:
                # All beams start identical, so keep only beam 0 per batch item;
                # otherwise topk would pick num_beams copies of the same candidate.
                total_scores = total_scores.view(batch_size, num_beams, -1)
                total_scores[:, 1:, :] = -torch.inf
            total_scores = total_scores.view(batch_size, -1)  # (B, num_beams * vocab)
top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)
beam_indices = top_indices // log_probs.shape[-1]
token_indices = top_indices % log_probs.shape[-1]
batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)
effective_indices = (batch_indices * num_beams + beam_indices).view(-1)
beams = beams[effective_indices]
beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)
beam_scores = top_scores.view(-1)
finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)
final_beams = beams.view(batch_size, num_beams, -1)
final_scores = beam_scores.view(batch_size, num_beams)
normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)
best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]
        if was_training:
            self.train()
return best_beams
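
# --- Minimal usage sketch (illustrative only; the hyperparameters below are
# assumptions, not the released PRISM-Baseline config) ---
if __name__ == "__main__":
    model = RoPETransformer(
        num_encoder_layers=6, num_decoder_layers=6, num_heads=8,
        d_model=512, dff=2048, vocab_size=len(tokenizer),
        max_length=128, dropout=0.1,
    )
    batch = tokenizer(["Guten Morgen!"], return_tensors="pt", padding=True)
    src = batch["input_ids"]
    # Untrained weights produce gibberish; this only exercises the plumbing.
    out_ids = model.generate(src, max_length=32, num_beams=5)
    print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))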