|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
|
from torchmetrics.text import BLEUScore, SacreBLEUScore |
|
|
from tqdm.auto import tqdm |
|
|
import config |
|
|
from src import model, utils |
|
|
|
|
|
|
|
|
# Target-side vocabulary size; used below to flatten logits to (N, vocab)
# before handing them to the loss criterion.
TGT_VOCAB_SIZE: int = config.VOCAB_SIZE
|
|
|
|
|
|
|
|
def train_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    logger=None,
) -> float:
    """
    Runs a single training epoch.

    Args:
        model: The Transformer model.
        dataloader: The training DataLoader.
        optimizer: The optimizer.
        criterion: The loss function (e.g., CrossEntropyLoss).
        scheduler: Per-batch LR scheduler; stepped after every optimizer step.
        device: The device to run on (e.g., 'cuda').
        logger: Optional experiment logger exposing a ``log(dict)`` method
            (e.g., a wandb run); batch metrics are logged every 100 batches.

    Returns:
        The average training loss for the epoch.
    """
    model.train()

    total_loss = 0.0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    # enumerate(..., start=1) replaces the original hand-rolled batch counter.
    for batch_idx, batch in enumerate(progress_bar, start=1):
        # Move every tensor in the batch onto the target device.
        batch_gpu = {
            k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
        }

        # set_to_none=True releases grad storage instead of zero-filling it
        # (recommended by the PyTorch docs; numerically equivalent here).
        optimizer.zero_grad(set_to_none=True)

        logits = model(
            src=batch_gpu["src_ids"],
            tgt=batch_gpu["tgt_input_ids"],
            src_mask=batch_gpu["src_mask"],
            tgt_mask=batch_gpu["tgt_mask"],
        )

        # Flatten logits to (N, vocab) and labels to (N,) for the criterion.
        loss = criterion(logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1))

        loss.backward()

        # Clip the global gradient norm to stabilize Transformer training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # LR schedule advances once per optimizer step, not per epoch.
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

        if logger and batch_idx % 100 == 0:
            logger.log(
                {
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

    return total_loss / len(dataloader)
|
|
|
|
|
|
|
|
def validate_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """
    Runs a single validation epoch (no gradient updates).

    Args:
        model: The Transformer model.
        dataloader: The validation DataLoader.
        criterion: The loss function (e.g., CrossEntropyLoss).
        device: The device to run on (e.g., 'cuda').

    Returns:
        The average validation loss for the epoch.
    """
    model.eval()

    running_loss = 0.0

    progress_bar = tqdm(dataloader, desc="Validating", leave=False)

    with torch.no_grad():
        for raw_batch in progress_bar:
            # Ship every tensor in the batch to the target device.
            on_device = {
                key: value.to(device)
                for key, value in raw_batch.items()
                if isinstance(value, torch.Tensor)
            }

            predictions = model(
                src=on_device["src_ids"],
                tgt=on_device["tgt_input_ids"],
                src_mask=on_device["src_mask"],
                tgt_mask=on_device["tgt_mask"],
            )

            # Flatten to (N, vocab) / (N,) before applying the criterion.
            flat_logits = predictions.view(-1, TGT_VOCAB_SIZE)
            flat_labels = on_device["labels"].view(-1)
            batch_loss = criterion(flat_logits, flat_labels)

            running_loss += batch_loss.item()
            progress_bar.set_postfix(loss=batch_loss.item())

    return running_loss / len(dataloader)
|
|
|
|
|
|
|
|
def evaluate_model(
    model: model.Transformer,
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizerFast,
    device: torch.device,
    table=None,
) -> tuple[float, float]:
    """
    Runs final evaluation on the test set using greedy decoding
    and calculates the BLEU and SacreBLEU scores.

    Args:
        model: The Transformer model.
        dataloader: The evaluation DataLoader.
        tokenizer: Tokenizer used to map token ids back to token strings.
        device: The device to run on (e.g., 'cuda').
        table: Optional logging table exposing an
            ``add_data(expected, predicted)`` method (e.g., a wandb Table);
            skipped when None.

    Returns:
        A ``(bleu, sacrebleu)`` tuple, both scaled to the 0-100 range.
    """
    print("\n--- Starting Evaluation (BLEU + SacreBLEU) ---")

    model.eval()

    all_predicted_strings = []
    all_expected_strings = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch_gpu = {
                k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
            }

            src_ids = batch_gpu["src_ids"]
            src_mask = batch_gpu["src_mask"]
            expected_ids = batch_gpu["labels"]

            B = src_ids.size(0)

            # References: detokenize the gold label ids for this batch.
            batch_expected_strings = []
            for id_list in expected_ids.cpu().tolist():
                token_list = tokenizer.convert_ids_to_tokens(id_list)
                batch_expected_strings.append(
                    utils.filter_and_detokenize(token_list, skip_special=True)
                )

            # Hypotheses: greedily decode each source sentence one at a time.
            batch_predicted_strings = []
            for i in tqdm(range(B), desc="Decoding Batch", leave=False):
                predicted_ids = utils.greedy_decode_sentence(
                    model,
                    src_ids[i].unsqueeze(0),
                    src_mask[i].unsqueeze(0),
                    max_len=config.MAX_SEQ_LEN,
                    sos_token_id=config.SOS_TOKEN_ID,
                    eos_token_id=config.EOS_TOKEN_ID,
                    device=device,
                )

                predicted_token_list = tokenizer.convert_ids_to_tokens(
                    predicted_ids.cpu().tolist()
                )
                batch_predicted_strings.append(
                    utils.filter_and_detokenize(
                        predicted_token_list, skip_special=True
                    )
                )

            all_predicted_strings.extend(batch_predicted_strings)
            # torchmetrics expects one list of references per hypothesis.
            all_expected_strings.extend([[s] for s in batch_expected_strings])

    # Use the `device` argument (the original used config.DEVICE, which could
    # disagree with where the caller asked evaluation to run).
    bleu_metric = BLEUScore(n_gram=4, smooth=True).to(device)
    sacrebleu_metric = SacreBLEUScore(
        n_gram=4, smooth=True, tokenize="intl", lowercase=False
    ).to(device)

    print("\nCalculating final BLEU score...")
    final_bleu = bleu_metric(all_predicted_strings, all_expected_strings)

    print("\nCalculating final SacreBLEU score...")
    final_sacrebleu = sacrebleu_metric(all_predicted_strings, all_expected_strings)

    print("\n--- Translation Examples (Pred vs Exp) ---")
    for i in range(min(5, len(all_predicted_strings))):
        print(f" PRED: {all_predicted_strings[i]}")
        print(f" EXP: {all_expected_strings[i][0]}")
        print(" ---")

        # Bug fix: `table` defaults to None, but the original called
        # table.add_data unconditionally, raising AttributeError whenever
        # no table was supplied.
        if table is not None:
            table.add_data(all_expected_strings[i][0], all_predicted_strings[i])

    return final_bleu.item() * 100, final_sacrebleu.item() * 100
|
|
|