import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from x_transformers import Encoder, Decoder
from transformers import AutoTokenizer

# Load the tokenizer: prefer a local copy if one is present, otherwise fall
# back to the pretrained Helsinki-NLP/opus-mt-de-en checkpoint.
try:
    if os.path.exists("tokenizer_config.json"):
        tokenizer = AutoTokenizer.from_pretrained(".")
    else:
        tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-en")
except Exception as e:
    print(f"Warning: Tokenizer load failed: {e}")
    tokenizer = None
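
# Sequence-to-sequence de->en translation model. Rotary position embeddings
# (RoPE) are applied inside the attention layers of both stacks, so no
# positional encoding is added to the token embeddings in forward().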
class RoPETransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.dropout_layer = nn.Dropout(dropout)

        # Encoder stack with rotary position embeddings; ff_mult is derived so
        # the feed-forward hidden size comes out to dff.
        self.encoder = Encoder(
            dim = d_model,
            depth = num_encoder_layers,
            heads = num_heads,
            attn_dim_head = d_model // num_heads,
            ff_mult = dff / d_model,
            rotary_pos_emb = True,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
            use_rmsnorm = True
        )

        # Decoder stack; cross_attend adds cross-attention over the encoder output.
        self.decoder = Decoder(
            dim = d_model,
            depth = num_decoder_layers,
            heads = num_heads,
            attn_dim_head = d_model // num_heads,
            ff_mult = dff / d_model,
            rotary_pos_emb = True,
            cross_attend = True,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
            use_rmsnorm = True
        )

        # Output projection tied to the input embedding matrix (weight tying).
        self.final_linear = nn.Linear(d_model, vocab_size)
        self.final_linear.weight = self.embedding.weight

    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        src_emb = self.dropout_layer(src_emb)

        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.dropout_layer(tgt_emb)

        # x_transformers masks are True at positions to keep, so invert the
        # padding masks (which are True at pad tokens).
        enc_mask = ~src_padding_mask if src_padding_mask is not None else None
        dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None

        memory = self.encoder(src_emb, mask=enc_mask)

        # The Decoder applies causal masking itself, so tgt_mask and
        # memory_key_padding_mask are accepted only for interface compatibility.
        decoder_output = self.decoder(
            tgt_emb,
            context=memory,
            mask=dec_mask,
            context_mask=enc_mask
        )

        return self.final_linear(decoder_output)

    def create_masks(self, src, tgt):
        # Padding masks follow the nn.Transformer convention: True at pad positions.
        src_padding_mask = (src == tokenizer.pad_token_id)
        tgt_padding_mask = (tgt == tokenizer.pad_token_id)

        # Boolean causal mask: True above the diagonal marks blocked (future) positions.
        tgt_mask = torch.triu(
            torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool, device=src.device),
            diagonal=1,
        )
        # The source padding mask doubles as the memory key padding mask.
        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask

    @torch.no_grad()
    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:
        """Translate src with beam search, returning the best token sequence per batch element."""
        self.eval()

        src_padding_mask = (src == tokenizer.pad_token_id)
        enc_mask = ~src_padding_mask

        # Encode the source once.
        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)

        batch_size = src.shape[0]

        # Expand the encoder output so that every beam gets its own copy.
        memory = memory.repeat_interleave(num_beams, dim=0)
        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)

        # Marian-style opus-mt models use the pad token as the decoder start token.
        initial_token = tokenizer.pad_token_id
        beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)
        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)
        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)

        for step in range(max_length - 1):
            if finished_beams.all():
                break

            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)

            decoder_output = self.decoder(
                self.dropout_layer(tgt_emb),
                context=memory,
                context_mask=enc_mask
            )

            logits = self.final_linear(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Never emit the pad token; finished beams extend with EOS at zero
            # cost so their scores stay frozen.
            log_probs[:, tokenizer.pad_token_id] = -torch.inf
            if finished_beams.any():
                log_probs[finished_beams, tokenizer.eos_token_id] = 0

            total_scores = beam_scores.unsqueeze(1) + log_probs
            if step == 0:
                # All beams start identical, so keep only the first beam per
                # batch element to avoid selecting duplicate candidates.
                total_scores = total_scores.view(batch_size, num_beams, -1)
                total_scores[:, 1:, :] = -torch.inf
                total_scores = total_scores.view(batch_size * num_beams, -1)

            total_scores = total_scores.view(batch_size, -1)
            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)

            # Recover which beam and which token each top candidate came from.
            beam_indices = top_indices // log_probs.shape[-1]
            token_indices = top_indices % log_probs.shape[-1]

            batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)
            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)

            beams = beams[effective_indices]
            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)
            beam_scores = top_scores.view(-1)
            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)

        # Length-normalize the scores and return the best beam per batch element.
        final_beams = beams.view(batch_size, num_beams, -1)
        final_scores = beam_scores.view(batch_size, num_beams)
        normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)
        best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]
        self.train()
        return best_beams
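
# Minimal usage sketch: the hyperparameters and the sample sentence below are
# illustrative assumptions (they do not come from a particular training setup);
# the block only demonstrates how create_masks, forward and generate fit together.
if __name__ == "__main__" and tokenizer is not None:
    model = RoPETransformer(
        num_encoder_layers=6, num_decoder_layers=6, num_heads=8,
        d_model=512, dff=2048, vocab_size=len(tokenizer),
        max_length=128, dropout=0.1,
    )

    # Tokenize a German source sentence; reuse it as a stand-in target just to
    # exercise the training-style forward pass.
    src = tokenizer(["Guten Morgen!"], return_tensors="pt").input_ids
    tgt = src.clone()

    src_pad, tgt_pad, mem_pad, tgt_mask = model.create_masks(src, tgt)
    logits = model(src, tgt, src_pad, tgt_pad, mem_pad, tgt_mask)  # (batch, tgt_len, vocab)

    # With untrained weights the decoded text is meaningless; this only shows the API.
    out_ids = model.generate(src, max_length=32, num_beams=5)
    print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))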