import torch
import torch.nn as nn
import tokenizers

from .config import HPARAMS
from .model import TransformerNMT


@torch.no_grad()
def greedy_decode(model, src_ids, pad_id, bos_id, eos_id, max_len, device):
    """
    Greedy decoding for the Transformer model: computes encoder memory once,
    then iteratively generates target tokens using prior decoder outputs and
    the memory. Supports batched inference, stops at EOS or max_len, and
    builds its own padding and causal masks.
    """
    batch_size = src_ids.size(0)
    model.eval()
    src_ids = src_ids.to(device)
    src_key_padding_mask = (src_ids == pad_id).to(device)  # (N, S)

    # compute encoder memory once; it is reused at every decoding step
    src_emb = model.positional_embedding(model.shared_embedding(src_ids))  # (N, S, E)
    memory = model.transformer.encoder(
        src=src_emb, src_key_padding_mask=src_key_padding_mask
    )  # (N, S, E)

    # prepare initial decoder input: a single BOS token per sequence
    current_tokens = torch.full((batch_size, 1), bos_id, dtype=torch.long).to(device)  # (N, 1)
    finished = torch.zeros(batch_size, dtype=torch.bool).to(device)
    outputs = [[] for _ in range(batch_size)]

    # decoding loop
    for step in range(max_len):
        # target embedding & masks (causal/padding)
        tgt_emb = model.positional_embedding(model.shared_embedding(current_tokens)).to(device)  # (N, L, E)
        tgt_key_padding_mask = (current_tokens == pad_id).to(device)  # usually all False (N, L)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt_emb.size(1), dtype=torch.bool
        ).to(device)  # (L, L)

        # decoder outputs
        decoder_outputs = model.transformer.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (N, L, E)
        next_logits = model.output(decoder_outputs)[:, -1, :]  # (N, vocab_size)
        next_tokens = next_logits.argmax(dim=-1)  # (N,)

        # update current decoded tokens; finished rows keep receiving argmax
        # tokens, but their recorded outputs are frozen below
        current_tokens = torch.cat([current_tokens, next_tokens.unsqueeze(1)], dim=1)  # (N, L+1)

        # store output tokens & stop once every sequence has emitted EOS
        for i in range(batch_size):
            if not finished[i]:
                outputs[i].append(int(next_tokens[i].item()))
                if next_tokens[i] == eos_id:
                    finished[i] = True
        if finished.all():
            break
    return outputs


def translate(model, tokenizer, src_list, max_len=64, device=None):
    """
    args:
        model (TransformerNMT): trained model used for decoding.
        tokenizer (tokenizers.Tokenizer): tokenizer with padding enabled.
        src_list (List[str]): source sentences to translate.
        max_len (int): maximum length of the generated output sequence.
        device (torch.device, optional)
    returns:
        List[str]: translated target sentences.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pad_id, bos_id, eos_id = [tokenizer.token_to_id(t) for t in ["[PAD]", "[BOS]", "[EOS]"]]
    src_ids = torch.tensor(
        [enc.ids for enc in tokenizer.encode_batch(src_list)], dtype=torch.long
    )  # (N, S)
    outputs = greedy_decode(model, src_ids, pad_id, bos_id, eos_id, max_len, device)
    return tokenizer.decode_batch(outputs)


def load_model_and_tokenizer(tokenizer_path, model_checkpoint_path, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Torch Device: {device}")
    hp = HPARAMS()
    try:
        tokenizer = tokenizers.Tokenizer.from_file(tokenizer_path)
        tokenizer.enable_truncation(hp.max_seq_len)
        tokenizer.enable_padding(pad_id=0, pad_token="[PAD]")
        model = TransformerNMT(tokenizer.get_vocab_size(), hp.max_seq_len, **hp.model_hparams).to(device)
        state_dict = torch.load(model_checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        return model, tokenizer
    except Exception as e:
        print(f"Error loading model/tokenizer: {e}")
        return None, None
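

# Minimal usage sketch tying the pieces above together. The paths
# "tokenizer.json" and "checkpoint.pt" are hypothetical placeholders for a
# trained tokenizer file and model checkpoint; run as a module
# (python -m <package>.<this_module>) so the relative imports resolve.
if __name__ == "__main__":
    model, tokenizer = load_model_and_tokenizer("tokenizer.json", "checkpoint.pt")
    if model is not None:
        # translate a small batch of source sentences greedily
        print(translate(model, tokenizer, ["Hello, world!"], max_len=64))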