import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from x_transformers import Encoder, Decoder
from transformers import AutoTokenizer
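
# The tokenizer is loaded at module level because the masking and beam-search
# code below relies on its pad and EOS token ids.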
try:
    # Prefer a tokenizer saved alongside this file; otherwise download the
    # Helsinki-NLP German-to-English tokenizer.
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")
    raise  # the model cannot build masks or decode without a tokenizer


class RoPETransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        # Note: max_length is accepted for interface compatibility but is not
        # needed internally, since rotary position embeddings (RoPE) impose no
        # fixed maximum sequence length.

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.dropout_layer = nn.Dropout(dropout)

        self.encoder = Encoder(
            dim = d_model,
            depth = num_encoder_layers,
            heads = num_heads,
            attn_dim_head = d_model // num_heads,
            ff_mult = dff / d_model,  # feed-forward inner dimension = dff
            rotary_pos_emb = True,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
            use_rmsnorm = True
        )

        self.decoder = Decoder(
            dim = d_model,
            depth = num_decoder_layers,
            heads = num_heads,
            attn_dim_head = d_model // num_heads,
            ff_mult = dff / d_model,
            rotary_pos_emb = True,
            cross_attend = True,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
            use_rmsnorm = True
        )

        self.final_linear = nn.Linear(d_model, vocab_size)
        # Tie the output projection to the input embedding weights.
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        # Scale the embeddings by sqrt(d_model), as in the original Transformer.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        src_emb = self.dropout_layer(src_emb)

        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout_layer(tgt_emb)

        # x_transformers masks are True for positions that may be attended to,
        # i.e. the inverse of the padding masks built in create_masks().
        enc_mask = ~src_padding_mask if src_padding_mask is not None else None
        dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None

        memory = self.encoder(src_emb, mask=enc_mask)

        # The x_transformers Decoder is causal by default, so tgt_mask is not
        # passed on; it and memory_key_padding_mask are accepted only to keep
        # the usual seq2seq forward signature.
        decoder_output = self.decoder(
            tgt_emb,
            context=memory,
            mask=dec_mask,
            context_mask=enc_mask
        )

        return self.final_linear(decoder_output)

    def create_masks(self, src, tgt):
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)

        # Boolean causal mask (True above the diagonal marks positions that
        # must not be attended to). forward() does not need it, but it is
        # returned for compatibility with standard seq2seq training loops.
        tgt_mask = torch.triu(
            torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=src.device),
            diagonal=1,
        )
        # The source padding mask doubles as the memory key padding mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        was_training = self.training
        self.eval()

        src_padding_mask = (src == tokenizer.pad_token_id)
        enc_mask = ~src_padding_mask

        # Encode the source once, then repeat the memory so that every beam of
        # a sentence attends to the same encoder output.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)

        batch_size = src.shape[0]
        memory = memory.repeat_interleave(num_beams, dim=0)
        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)

        # Marian-style models use the pad token as the decoder start token.
        initial_token = tokenizer.pad_token_id
        beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)
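
        # Beam search over flattened (batch_size * num_beams) hypotheses: each
        # step decodes every prefix, scores all next-token continuations, and
        # keeps the num_beams best hypotheses per source sentence.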
        for step in range(max_length - 1):
            if finished_beams.all():
                break

            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)

            decoder_output = self.decoder(
                self.dropout_layer(tgt_emb),
                context=memory,
                context_mask=enc_mask
            )

            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Never generate pad; finished beams may only be extended with EOS
            # at zero cost, so their scores stay frozen.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams] = -torch.inf
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            total_scores = beam_scores.unsqueeze(1) + log_probs
            if step == 0:
                # All beams start from the same prefix, so keep only the first
                # beam per sentence to avoid num_beams identical hypotheses.
                total_scores = total_scores.view(batch_size, num_beams, -1)
                total_scores[:, 1:, :] = -torch.inf
                total_scores = total_scores.view(batch_size * num_beams, -1)

            # Pick the best num_beams (beam, token) pairs per source sentence.
            total_scores = total_scores.view(batch_size, -1)
            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)

            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)
            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)

            # Reorder the beams (and their finished flags) to follow the
            # selected parents, then append the chosen tokens.
            beams = beams[effective_indices]
            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams[effective_indices] | (beams[:, -1] == tokenizer.eos_token_id)

        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        # Length-normalise each beam's score by its number of non-pad tokens,
        # then return the best hypothesis per source sentence.
        normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)
        best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]

        self.train(was_training)
        return best_beams
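

if __name__ == "__main__":
    # Minimal smoke test with illustrative hyperparameters and untrained
    # weights; it only checks that the forward pass and beam search run
    # end to end, assuming the tokenizer above loaded successfully.
    model = RoPETransformer(
        num_encoder_layers=2,
        num_decoder_layers=2,
        num_heads=8,
        d_model=256,
        dff=1024,
        vocab_size=len(tokenizer),
        max_length=128,
        dropout=0.1,
    )

    src = tokenizer(["Das ist ein Test."], return_tensors="pt", padding=True)["input_ids"]
    tgt = tokenizer(["This is a test."], return_tensors="pt", padding=True)["input_ids"]

    # Training-style forward pass (shape check only; a real training loop
    # would prepend the decoder start token and shift the targets).
    logits = model(src, tgt[:, :-1], *model.create_masks(src, tgt[:, :-1]))
    print(logits.shape)  # (batch, tgt_len - 1, vocab_size)

    # Beam-search decoding; output is meaningless until the model is trained.
    output_ids = model.generate(src, max_length=16, num_beams=3)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))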