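# Fine-tunes the Sunbird "mul-en" mBART checkpoint on a Luganda-to-English
# parallel corpus (train/val/test .lug/.en files) and reports a corpus-level
# BLEU score on the held-out test set.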
import datasets
import numpy as np
import os
import pandas as pd
import random
import sentencepiece
import sacrebleu
import sacremoses
import tqdm
import transformers
import torch
import wandb

from transformers import AutoTokenizer
from torch.utils.data import Dataset
from typing import List

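# A paired dataset that tokenizes one source and one target sentence per item.
# Both sides are padded/truncated to max_length (32 by default), so any
# sentence longer than 32 tokens is clipped.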
class TranslationDataset(Dataset):
    def __init__(self, source_sentences: List[str], target_sentences: List[str], tokenizer, max_length=32):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.source_sentences)

    def __getitem__(self, idx):
        source_sentence = self.source_sentences[idx]
        target_sentence = self.target_sentences[idx]

        # Each encoding has shape (1, max_length); the extra batch dimension
        # is squeezed out in the training loop.
        tokenized_source = self.tokenizer(source_sentence, truncation=True, padding="max_length",
                                          max_length=self.max_length, return_tensors="pt")
        tokenized_target = self.tokenizer(target_sentence, truncation=True, padding="max_length",
                                          max_length=self.max_length, return_tensors="pt")

        return tokenized_source, tokenized_target

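# Read a file with one sentence per line, dropping empty strings (such as
# the one produced by a trailing newline).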
def load_sentences(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        sentences = f.read().split("\n")
    sentences = [sentence for sentence in sentences if sentence]
    return sentences

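# The tokenizer comes from the same checkpoint as the model (presumably a
# multilingual-to-English mBART variant), so a single tokenizer handles both
# the Luganda source and the English target.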
tokenizer = AutoTokenizer.from_pretrained("Sunbird/sunbird-mul-en-mbart-merged")

source_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/train.lug")
target_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/train.en")

dataset = TranslationDataset(source_sentences, target_sentences, tokenizer)

valid_source_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/val.lug")
valid_target_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/val.en")

valid_dataset = TranslationDataset(valid_source_sentences, valid_target_sentences, tokenizer)

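# Load the sequence-to-sequence weights from the same checkpoint that
# supplied the tokenizer.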
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("Sunbird/sunbird-mul-en-mbart-merged")

from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers.optimization import Adafactor, AdafactorSchedule
import torch.nn.functional as F

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
# Validation order does not affect the averaged loss, so no shuffle is needed.
val_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
optimizer = AdamW(model.parameters(), lr=1e-6)

model = model.to("cuda")

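# Standard fine-tuning loop: one full pass over the training data per epoch,
# then a full validation pass. The model is checkpointed whenever the average
# validation loss improves, and training stops after `early_stop` epochs
# without improvement, so num_epochs=5000 effectively means "until early
# stopping fires".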
def train_model(model, dataloader, val_dataloader, optimizer, num_epochs=5000,
                save_path="Total Combined Data V2 Aug 16 2023/Models/mul_en_base_v2.bin", early_stop=10):
    best_val_loss = float("inf")
    early_stop_counter = 0

    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            optimizer.zero_grad()
            # squeeze(1) drops the singleton dimension added by return_tensors="pt"
            # without also collapsing the batch dimension when a batch has size 1.
            input_ids = batch[0]["input_ids"].squeeze(1).to("cuda")
            attention_mask = batch[0]["attention_mask"].squeeze(1).to("cuda")
            labels = batch[1]["input_ids"].squeeze(1).to("cuda")
            # Mask padding positions so the cross-entropy loss ignores them.
            labels[labels == tokenizer.pad_token_id] = -100
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch[0]["input_ids"].squeeze(1).to("cuda")
                attention_mask = batch[0]["attention_mask"].squeeze(1).to("cuda")
                labels = batch[1]["input_ids"].squeeze(1).to("cuda")
                labels[labels == tokenizer.pad_token_id] = -100
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                total_val_loss += outputs.loss.item()
        avg_val_loss = total_val_loss / len(val_dataloader)
        print(f"Validation loss at epoch {epoch}: {avg_val_loss}")

        # Checkpoint on improvement; otherwise count toward early stopping.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= early_stop:
            print("Early stopping triggered")
            break


print("Training Begins Here!")
train_model(model, dataloader, val_dataloader, optimizer)

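# Evaluation: translate the test set with model.generate() using default
# decoding settings, then score against references with sacrebleu. The
# references are decoded from the tokenized labels, so they pass through the
# same 32-token truncation as the training data.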
print("Scoring Has Begun!")
from sacrebleu import corpus_bleu

test_source_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/test.lug")
test_target_sentences = load_sentences("Total Combined Data V2 Aug 16 2023/test.en")
test_dataset = TranslationDataset(test_source_sentences, test_target_sentences, tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=16)

# Reload the best checkpoint saved during training so scoring uses the weights
# with the lowest validation loss rather than those from the final epoch.
model.load_state_dict(torch.load("Total Combined Data V2 Aug 16 2023/Models/mul_en_base_v2.bin"))
model.eval()
model.to("cuda")

predictions = []
actuals = []

with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch[0]["input_ids"].squeeze(1).to("cuda")
        attention_mask = batch[0]["attention_mask"].squeeze(1).to("cuda")
        labels = batch[1]["input_ids"].squeeze(1)
        # Note: mBART-style models often require forced_bos_token_id for the
        # target language; this assumes the checkpoint's generation config
        # already targets English.
        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask)

        # skip_special_tokens strips padding and sentinel tokens that would
        # otherwise distort the BLEU score.
        pred_sentences = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in outputs]
        actual_sentences = [tokenizer.decode(tokens, skip_special_tokens=True) for tokens in labels]
        predictions.extend(pred_sentences)
        actuals.extend(actual_sentences)

# sacrebleu takes a list of hypotheses and a list of reference streams.
bleu_score = corpus_bleu(predictions, [actuals]).score
print(f"BLEU score: {bleu_score}")