# Config Setup

In [1]:
! pip install transformers sentencepiece datasets
! pip install tqdm
! pip install torch
!pip install sacrebleu

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

#  Huggingface Configuration

In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Data Preparation

In [3]:
import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from google.colab import drive
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
import seaborn as sns
from torch import optim
from torch.nn import functional as F
import glob
import os
import gc
from tqdm.notebook import tqdm
from huggingface_hub import HfFolder
from huggingface_hub import HfApi, Repository, hf_hub_download

sns.set_style('dark')

from google.colab import drive
drive.mount('/content/drive')
curfile = os.getcwd()
print(curfile)


# Model and configuration parameters
model_name = 'Bildad/Swahili-English_Translation'
# branch_used="recovery"
max_length = 128
batch_size = 8
learning_rate = 5e-5
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Language tokens
lang_token = {
    'sw': '<sw>',
    'en': '<en>'

}

api = HfApi()  # Initialize the API client once
repo_id="JMwagunda/Trial"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)

repo_name="JMwagunda/Trial"
checkpoint_file = "checkpoint.pth"

# Add special tokens
special_tokens = {'additional_special_tokens': list(lang_token.values())}
tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(tokenizer))

# Load the dataset and check available splits
ds = load_dataset('Rogendo/English-Swahili-Sentence-Pairs')
print("Available splits:", ds.keys())


# Calculate the sizes based on percentages
total_rows = len(ds['train'])
train_size = int(0.8 * total_rows)  # 80% for training
validation_size = int(0.1 * total_rows)  # 10% for validation
test_size = total_rows - train_size - validation_size  # Remaining 10% for test

# Shuffle the dataset to ensure randomness in splitting
ds_shuffled = ds['train'].shuffle(seed=42)

# Create the splits based on the specified sizes
train_dataset = ds_shuffled.select(range(train_size)) # train_dataset should be a dataset, not an int
validation_dataset = ds_shuffled.select(range(train_size, train_size + validation_size))
test_dataset = ds_shuffled.select(range(train_size + validation_size, train_size + validation_size + test_size))

new_dataset = DatasetDict({
    'train': train_dataset,
    'validation': validation_dataset,
    'test': test_dataset
})

# Print the new dataset structure
print(new_dataset)

Mounted at /content/drive
/content


tokenizer_config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/817k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/823k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.45M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/416 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/297M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/288 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


README.md:   0%|          | 0.00/173 [00:00<?, ?B/s]

ensw.csv:   0%|          | 0.00/21.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/210471 [00:00<?, ? examples/s]

Available splits: dict_keys(['train'])
DatasetDict({
    train: Dataset({
        features: ['English sentence', 'Swahili Translation'],
        num_rows: 168376
    })
    validation: Dataset({
        features: ['English sentence', 'Swahili Translation'],
        num_rows: 21047
    })
    test: Dataset({
        features: ['English sentence', 'Swahili Translation'],
        num_rows: 21048
    })
})


# Preprocessor

In [4]:
import torch
from typing import Dict, List, Tuple, Optional, Union, Iterator
from dataclasses import dataclass
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from enum import Enum

class TranslationDirection(Enum):
    """Enum for translation directions"""
    SW2EN = "sw2en"
    EN2SW = "en2sw"


@dataclass
class TranslationPair:
    """Data class for holding translation pairs."""
    source_text: str
    target_text: str
    source_lang: str
    target_lang: str

    def __post_init__(self):
        """Validate inputs upon initialization."""
        # Convert None to empty string to avoid attribute errors
        self.source_text = str(self.source_text or "")
        self.target_text = str(self.target_text or "")
        self.source_lang = str(self.source_lang or "")
        self.target_lang = str(self.target_lang or "")

    def is_valid(self) -> bool:
        """Check if the translation pair is valid for training."""
        return (bool(self.source_text.strip()) and
                bool(self.target_text.strip()) and
                bool(self.source_lang.strip()) and
                bool(self.target_lang.strip()))

class TranslationPreprocessor:
    def __init__(self, tokenizer: AutoTokenizer, lang_tokens: Dict[str, str], max_length: int):
        """
        Initialize the preprocessor.

        Args:
            tokenizer: HuggingFace tokenizer
            lang_tokens: Dictionary mapping language codes to tokens
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.lang_tokens = lang_tokens
        self.max_length = max_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.key_mapping = {
            TranslationDirection.SW2EN.value: ("Swahili Translation", "English sentence"),
            TranslationDirection.EN2SW.value: ("English sentence", "Swahili Translation")
        }

    def add_language_token(self, text: str, lang_code: str) -> str:
        """Prepend language-specific token to text."""
        return f"{self.lang_tokens[lang_code]} {text}"

    def encode_input_str(self, text: str, target_lang: str, tokenizer: AutoTokenizer, seq_len: int, lang_token: Dict[str, str]) -> torch.Tensor:
        """Encode input string with source language token"""
        # Use source language token (English)
        source_lang_token = lang_token['sw']

        # Tokenize with English token
        input_ids = tokenizer.encode(
            text=f"{source_lang_token} {text}",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=seq_len
        )

        return input_ids[0]

    def encode_target_str(self, text: str, tokenizer: AutoTokenizer, seq_len: int, lang_token: Dict[str, str]) -> torch.Tensor:
        """Encode target string with target language token"""
        # For target, we'll keep the Giriama token as it's the target language
        target_lang_token = lang_token['en']

        token_ids = tokenizer.encode(
            text=f"{target_lang_token} {text}",
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=seq_len
        )

        return token_ids[0]

    def format_translation_data(self, translations: Dict[str, str], lang_token: Dict[str, str], tokenizer: AutoTokenizer, seq_len: int = 20) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Format translation data for model input"""
        # Define possible input-output language pairs and randomly select one
        input_lang, target_lang = 'sw', 'en'

        # Map language codes to column names
        column_map = {
            'sw': 'Swahili Translation',
            'en': 'English sentence'
        }

        # Extract the translations based on the chosen direction
        input_text = translations.get(column_map[input_lang])
        target_text = translations.get(column_map[target_lang])

        # Check for any missing translations
        if input_text is None or target_text is None:
            return None

        # Tokenize input with English as source language
        input_token_ids = self.encode_input_str(
            text=input_text,
            target_lang='en',
            tokenizer=tokenizer,
            seq_len=seq_len,
            lang_token=lang_token
        )

        # Tokenize target text
        target_token_ids = self.encode_target_str(
            text=target_text,
            tokenizer=tokenizer,
            seq_len=seq_len,
            lang_token=lang_token
        )

        return input_token_ids, target_token_ids

    def transform_batch(self, batch: Dict[str, List[str]], lang_token: Dict[str, str], tokenizer: AutoTokenizer, max_length: int) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Transform a batch of translations"""
        inputs = []
        targets = []

        # Get all English Sentences and Giriama translations from the batch
        swahili_translations = batch['Swahili Translation']
        english_sentences = batch['English sentence']


        # Process each pair in the batch
        for eng, sw in zip(english_sentences, swahili_translations):
            translation_set = {
                'Swahili Translation': sw,
                'English sentence': eng

            }

            # Process the translation pair
            formatted_data = self.format_translation_data(
                translation_set,
                lang_token,
                tokenizer,
                max_length
            )

            if formatted_data is None:
                continue

            input_ids, target_ids = formatted_data
            inputs.append(input_ids.unsqueeze(0))
            targets.append(target_ids.unsqueeze(0))

        if not inputs or not targets:  # Check if we have any valid translations
            return None

        # Concatenate tensors for batch
        device = "cuda" if torch.cuda.is_available() else "cpu"
        batch_input_ids = torch.cat(inputs).to(device)
        batch_target_ids = torch.cat(targets).to(device)

        return batch_input_ids, batch_target_ids

    def get_data_generator(self, dataset, lang_token: Dict[str, str], tokenizer: AutoTokenizer, batch_size: int = 8, direction: TranslationDirection = TranslationDirection.SW2EN) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Generate batches of processed translation data"""
        # dataset = dataset.shuffle()
        for i in range(0, len(dataset), batch_size):
            end_idx = min(i + batch_size, len(dataset))
            batch = dataset[i:end_idx]
            batch_dict = {
                'Swahili Translation': batch['Swahili Translation'],
                'English sentence': batch['English sentence']
            }
            processed_batch = self.transform_batch(batch, lang_token, tokenizer, self.max_length)
            if processed_batch is not None:
                yield processed_batch

    def process_translation_pair(self, pair: TranslationPair) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process a single translation pair"""
        # Add language tokens to source and target texts
        source_text_with_token = self.add_language_token(pair.source_text, pair.source_lang)
        target_text_with_token = self.add_language_token(pair.target_text, pair.target_lang)

        # Tokenize source and target texts
        source_ids = self.tokenizer.encode(
            source_text_with_token,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length
        )

        target_ids = self.tokenizer.encode(
            target_text_with_token,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length
        )

        return source_ids[0], target_ids[0]

# Optimizer


In [5]:
from transformers import get_cosine_schedule_with_warmup
from torch.optim import AdamW
import math
import numpy as np

# Configurable parameters
learning_rate = 5e-5
epsilon_value = 1e-8
batch_size = 8
num_epochs = 10

# Optimizer setup
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=epsilon_value,
)

# Calculate steps
n_batches = int(np.ceil(len(train_dataset) / batch_size))
total_steps = num_epochs * n_batches
num_warmup_steps = int(0.05 * total_steps)  # 5% warmup

# Use standard cosine scheduler
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

# Debugging function
def print_lr_details(optimizer, scheduler, step):
    print(f"Step {step}:")
    print(f"  Base LR: {optimizer.param_groups[0]['lr']:.2e}")
    print(f"  Scheduler LR: {scheduler.get_last_lr()[0]:.2e}")

# Example usage (replace with your actual training loop)
for step in range(5):
   print_lr_details(optimizer, scheduler, step)
   optimizer.step()
   scheduler.step()

Step 0:
  Base LR: 0.00e+00
  Scheduler LR: 0.00e+00
Step 1:
  Base LR: 4.75e-09
  Scheduler LR: 4.75e-09
Step 2:
  Base LR: 9.50e-09
  Scheduler LR: 9.50e-09
Step 3:
  Base LR: 1.43e-08
  Scheduler LR: 1.43e-08
Step 4:
  Base LR: 1.90e-08
  Scheduler LR: 1.90e-08


# Checkpoint Manager


In [6]:
import torch
from huggingface_hub import HfApi, hf_hub_download
import os
import gc
import matplotlib.pyplot as plt
import numpy as np

class CheckpointManager:
    def __init__(self, model, optimizer, scheduler, tokenizer, repo_name, device, branch_name="main"):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.tokenizer = tokenizer
        self.repo_name = repo_name
        self.device = device
        self.branch_name = branch_name
        self.api = HfApi()
        self.batches_per_epoch = None

    def set_dataset_info(self, num_samples, batch_size):
        """Store dataset info for accurate step calculations"""
        self.batches_per_epoch = int(np.ceil(num_samples / batch_size))
        self.batch_size = batch_size
        self.num_samples = num_samples
        print(f"CheckpointManager: Batches per epoch set to {self.batches_per_epoch}")


    def save_checkpoint(self, epoch, batch_idx, losses, total_steps=None, steps_taken=None, repo_name=None, branch_name=None, is_best=False):
        repo_name = repo_name or self.repo_name
        branch_name = branch_name or self.branch_name

        try:
            # Calculate steps_taken if not provided, using the stored batches_per_epoch
            if steps_taken is None and self.batches_per_epoch is not None:
                steps_taken = epoch * self.batches_per_epoch + batch_idx
                print(f"Calculated steps_taken for saving: {steps_taken}")

            # Get current learning rates
            current_lr = self.optimizer.param_groups[0]['lr']
            scheduler_lr = self.scheduler.get_last_lr()[0]

            # Prepare the checkpoint dictionary with enhanced information
            checkpoint = {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'epoch': epoch,
                'batch': batch_idx,
                'losses': losses,
                'current_lr': current_lr,
                'scheduler_lr': scheduler_lr,
                'total_steps': total_steps,
                'batches_per_epoch': self.batches_per_epoch,
                'steps_taken': steps_taken if steps_taken is not None else (epoch * batch_idx),
                'is_best': is_best
            }

            # Save the checkpoint locally
            checkpoint_filename = 'best_checkpoint.pth' if is_best else 'checkpoint.pth'
            torch.save(checkpoint, checkpoint_filename)
            print(f"{'Best m' if is_best else 'M'}odel checkpoint saved locally with learning rate: {current_lr:.2e}, scheduler lr: {scheduler_lr:.2e}")
            print(f"Saved at epoch {epoch + 1}, batch {batch_idx}, steps_taken: {steps_taken}")

            # Ensure all tensors in the model are contiguous
            for param in self.model.parameters():
                param.data = param.data.contiguous()

            # Upload model to Hugging Face Hub
            commit_message = f"{'Best m' if is_best else 'M'}odel checkpoint after epoch {epoch + 1}, batch {batch_idx}"
            self.model.push_to_hub(
                repo_id=repo_name,
                commit_message=commit_message,
                use_temp_dir=True,
                revision=branch_name
            )

            # Upload tokenizer to Hugging Face Hub
            self.tokenizer.push_to_hub(
                repo_id=repo_name,
                revision=branch_name
            )

            # Upload checkpoint file to Hugging Face Hub
            self.api.upload_file(
                path_or_fileobj=checkpoint_filename,
                path_in_repo=checkpoint_filename,
                repo_id=repo_name,
                revision=branch_name
            )

            # Upload notebook if available
            notebook_path = "/content/drive/MyDrive/Colab Notebooks/progress.ipynb"
            if os.path.exists(notebook_path):
                self.api.upload_file(
                    path_or_fileobj=notebook_path,
                    path_in_repo="current_notebook.ipynb",
                    repo_id=repo_name,
                    revision=branch_name
                )

            print(f"Checkpoint file uploaded to Hugging Face Hub.")
            print(f"{'Best m' if is_best else 'M'}odel checkpoint saved to Hugging Face Hub after epoch {epoch + 1}, batch {batch_idx}")

        except Exception as e:
            print(f"Error saving checkpoint: {e}")

    def load_checkpoint(self, repo_name=None, branch_name=None):
        repo_name = repo_name or self.repo_name
        branch_name = branch_name or self.branch_name

        try:
            # Clear memory before loading checkpoint
            torch.cuda.empty_cache()
            gc.collect()

            # Check if the checkpoint file exists in the repo
            if self.checkpoint_exists(repo_name, "checkpoint.pth", branch_name):
                # Download the checkpoint file
                checkpoint_path = hf_hub_download(
                    repo_id=repo_name,
                    filename="checkpoint.pth",
                    revision=branch_name
                )

                # Load the checkpoint
                checkpoint = torch.load(checkpoint_path, map_location=self.device)

                # Store learning rate before loading for comparison
                prev_lr = self.optimizer.param_groups[0]['lr'] if self.optimizer else None
                prev_scheduler_lr = self.scheduler.get_last_lr()[0] if self.scheduler else None

                # Load model, optimizer, and scheduler states
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                # Load training state
                epoch_idx = checkpoint['epoch']
                batch_idx = checkpoint['batch']
                losses = checkpoint['losses']

                # Extract additional information (with backward compatibility)
                current_lr = checkpoint.get('current_lr', self.optimizer.param_groups[0]['lr'])
                scheduler_lr = checkpoint.get('scheduler_lr', self.scheduler.get_last_lr()[0])
                steps_taken = checkpoint.get('steps_taken', None)
                total_steps = checkpoint.get('total_steps', None)
                is_best = checkpoint.get('is_best', False)

                if steps_taken is None and self.batches_per_epoch is not None:
                    # Calculate steps_taken from epoch and batch_idx
                    steps_taken = epoch_idx * self.batches_per_epoch + batch_idx
                    print(f"Calculated steps_taken from epoch/batch: {steps_taken}")

                # Verify learning rates
                print(f"LR before checkpoint loading: {prev_lr:.2e}, After: {self.optimizer.param_groups[0]['lr']:.2e}")
                print(f"Scheduler LR before: {prev_scheduler_lr:.2e}, After: {self.scheduler.get_last_lr()[0]:.2e}")
                print(f"Checkpoint saved with LR: {current_lr:.2e}, Scheduler LR: {scheduler_lr:.2e}")

                # Verify if the learning rates match up
                if abs(self.optimizer.param_groups[0]['lr'] - current_lr) > 1e-8:
                    print("WARNING: Current learning rate doesn't match the checkpoint's saved learning rate!")

                print(f"Loaded checkpoint from epoch {epoch_idx + 1}, batch {batch_idx}")
                if is_best:
                    print("This is marked as the best checkpoint with lowest validation loss.")

                if steps_taken is not None:
                    print(f"Steps taken according to checkpoint: {steps_taken}")
                    if self.batches_per_epoch is not None:
                        expected_steps = epoch_idx * self.batches_per_epoch + batch_idx
                        if expected_steps != steps_taken:
                            print(f"WARNING: Expected steps ({expected_steps}) doesn't match steps_taken in checkpoint ({steps_taken})!")

                return epoch_idx, batch_idx, losses, steps_taken, total_steps

            # Try loading the best checkpoint if regular checkpoint not found
            elif self.checkpoint_exists(repo_name, "best_checkpoint.pth", branch_name):
                print("Regular checkpoint not found. Attempting to load best checkpoint.")
                checkpoint_path = hf_hub_download(
                    repo_id=repo_name,
                    filename="best_checkpoint.pth",
                    revision=branch_name
                )

                # Process the best checkpoint (similar logic as above)
                checkpoint = torch.load(checkpoint_path, map_location=self.device)

                # Store learning rate before loading for comparison
                prev_lr = self.optimizer.param_groups[0]['lr'] if self.optimizer else None
                prev_scheduler_lr = self.scheduler.get_last_lr()[0] if self.scheduler else None

                # Load model, optimizer, and scheduler states
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                # Load training state
                epoch_idx = checkpoint['epoch']
                batch_idx = checkpoint['batch']
                losses = checkpoint['losses']

                # Extract additional information
                current_lr = checkpoint.get('current_lr', self.optimizer.param_groups[0]['lr'])
                scheduler_lr = checkpoint.get('scheduler_lr', self.scheduler.get_last_lr()[0])
                steps_taken = checkpoint.get('steps_taken', None)
                total_steps = checkpoint.get('total_steps', None)

                print(f"Loaded best checkpoint from epoch {epoch_idx + 1}, batch {batch_idx}")

                return epoch_idx, batch_idx, losses, steps_taken, total_steps
            else:
                print("No checkpoint file found in the repository.")

        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            torch.cuda.empty_cache()
            gc.collect()

        print("No checkpoint found or error occurred. Starting training from scratch.")
        return 0, 0, [], 0, None

    def checkpoint_exists(self, repo_name, filename, branch_name):
        try:
            # List all files in the specified branch
            files = self.api.list_repo_files(repo_id=repo_name, revision=branch_name)
            return filename in files
        except Exception as e:
            print(f"Error checking for checkpoint: {e}")
            return False

# Evaluate

In [7]:
def eval_model(model, val_dataset, tokenizer, LANG_TOKEN_MAPPING):
    model.eval()
    total_val_loss = 0
    val_data_generator = preprocessor.get_data_generator(
        val_dataset, LANG_TOKEN_MAPPING, tokenizer, batch_size=batch_size
    )
    total_val_batches = len(val_dataset) // batch_size
    if len(val_dataset) % batch_size != 0:
        total_val_batches += 1

    with torch.no_grad():
        for batch_idx, (input_batch, label_batch) in enumerate(val_data_generator):
            input_batch = input_batch.to(device)
            label_batch = label_batch.to(device)
            outputs = model(input_ids=input_batch, labels=label_batch)
            total_val_loss += outputs.loss.item()

    avg_val_loss = total_val_loss / total_val_batches
    return avg_val_loss

# Bleue Score

In [8]:
import sacrebleu

def calculate_bleu(predictions, references):
    """
    Calculate BLEU score using sacrebleu.

    Args:
        predictions: List of predicted translations
        references: List of reference translations

    Returns:
        BLEU score
    """
    # Convert single references to the format expected by sacrebleu
    refs = [[ref] for ref in references]

    # Calculate BLEU score
    bleu = sacrebleu.corpus_bleu(predictions, refs)

    return bleu.score

In [9]:
%matplotlib inline
from transformers import get_cosine_schedule_with_warmup
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

def plot_training_history(history):
    """
    Plot training loss, validation loss, and BLEU scores.

    Args:
        history: Dictionary containing training and validation metrics.
    """
    plt.figure(figsize=(12, 6))

    # Plot training and validation loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    # Plot BLEU score
    plt.subplot(1, 2, 2)
    plt.plot(history['val_bleu'], label='Validation BLEU', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('BLEU Score')
    plt.title('Validation BLEU Score')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Training loop

In [10]:
def train_model(
    model, train_dataset, val_dataset, optimizer, scheduler, num_epochs, device, tokenizer,
    preprocessor, LANG_TOKEN_MAPPING, batch_size=8, max_length=128,
    learning_rate=5e-5, repo_name="JMwagunda/Trial"  # Reduced default learning rate
):
    """
    Train the model with checkpointing, validation, and BLEU score calculation.
    """
    # Memory management and initialization
    torch.cuda.empty_cache()
    gc.collect()

    # Move model to device
    model.to(device)

    # Store training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_bleu': [],
    }

    # Calculate total batches and steps
    n_batches = int(np.ceil(len(train_dataset) / batch_size))
    total_steps = num_epochs * n_batches

    # Reduced warmup ratio to 3%
    num_warmup_steps = int(0.03 * total_steps)  # 3% warmup
    print(f"Number of warmup steps: {num_warmup_steps}")
    print(f"Total training steps: {total_steps}")
    print(f"Batches per epoch: {n_batches}")

    # Initialize CheckpointManager
    checkpoint_manager = CheckpointManager(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        tokenizer=tokenizer,
        repo_name=repo_name,
        device=device
    )

    # Set dataset info for accurate calculations
    checkpoint_manager.set_dataset_info(len(train_dataset), batch_size)

    # Load checkpoint if it exists - pass dataset info here
    start_epoch, start_batch, checkpoint_losses, steps_taken, checkpoint_total_steps = checkpoint_manager.load_checkpoint()

    # Check if total_steps has changed and print warning
    if checkpoint_total_steps is not None and checkpoint_total_steps != total_steps:
        print(f"WARNING: Total steps in checkpoint ({checkpoint_total_steps}) differs from current setting ({total_steps})")

    # Manually adjust the scheduler state if resuming from a checkpoint
    if start_epoch > 0 or start_batch > 0:
        current_steps = steps_taken if steps_taken is not None else (start_epoch * n_batches + start_batch)
        expected_steps = start_epoch * n_batches + start_batch

        print(f"Checkpoint indicates {current_steps} steps taken, calculated steps: {expected_steps}")

        # If steps_taken is not available, manually step the scheduler
        if steps_taken is None:
            print(f"Stepping scheduler {expected_steps} times to match checkpoint position")
            # Reset scheduler to initial state
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
            )
            # Step to match checkpoint position
            for _ in range(expected_steps):
                scheduler.step()

        print(f"Learning rate after scheduler adjustment: {optimizer.param_groups[0]['lr']:.2e}")
        print(f"Scheduler learning rate after adjustment: {scheduler.get_last_lr()[0]:.2e}")

    # Check if training has already completed
    if start_epoch >= num_epochs:
        print(f"Training already completed. Checkpoint found at epoch {start_epoch}, batch {start_batch}.")
        return history, start_batch

    # Early stopping variables
    best_val_loss = float('inf')
    patience = 5  # Number of epochs to wait for improvement
    patience_counter = 0
    early_stop = False

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print(f"Starting from batch: {start_batch if epoch == start_epoch else 0}")
        print(f"Current learning rate: {optimizer.param_groups[0]['lr']:.2e}")
        print(f"Current scheduler learning rate: {scheduler.get_last_lr()[0]:.2e}")

        model.train()

        # Reset epoch-specific tracking variables
        epoch_losses = []  # Track losses only for this epoch
        processed_batches = 0  # Count of actually processed batches

        # Initialize total_train_loss only with losses from the current epoch
        total_train_loss = 0.0

        # Create data generator for this epoch
        train_data_generator = preprocessor.get_data_generator(
            train_dataset, LANG_TOKEN_MAPPING, tokenizer, batch_size=batch_size
        )

        # Calculate total number of batches
        total_batches = len(train_dataset) // batch_size
        if len(train_dataset) % batch_size != 0:
            total_batches += 1

        # Training progress bar
        pbar = tqdm(train_data_generator, total=total_batches, desc=f"Epoch {epoch + 1}/{num_epochs}", ncols=100)

        # Keep track of steps for this session
        session_steps = 0

        # Add gradient accumulation
        gradient_accumulation_steps = 4  # Accumulate gradients over 4 batches
        optimizer.zero_grad()

        for batch_idx, (input_batch, label_batch) in enumerate(pbar):
            try:
                # Skip batches already processed in the checkpoint
                if epoch == start_epoch and batch_idx < start_batch:
                    continue

                # Memory management
                torch.cuda.empty_cache()

                # Move batches to device
                input_batch = input_batch.to(device)
                label_batch = label_batch.to(device)

                # Forward pass with memory management
                try:
                    outputs = model(input_ids=input_batch, labels=label_batch)
                    loss = outputs.loss / gradient_accumulation_steps  # Scale loss
                except RuntimeError as e:
                    if 'CUDA out of memory' in str(e):
                        print("Out of memory error. Clearing memory and retrying.")
                        torch.cuda.empty_cache()
                        gc.collect()
                        continue
                    else:
                        raise

                # Backward pass
                loss.backward()

                # Step the optimizer every gradient_accumulation_steps batches
                if (batch_idx + 1) % gradient_accumulation_steps == 0 or (batch_idx + 1) == total_batches:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    session_steps += 1

                # Update metrics - only track losses for current epoch
                # Multiply by gradient_accumulation_steps to get the actual loss value
                current_loss = loss.item() * gradient_accumulation_steps
                epoch_losses.append(current_loss)
                total_train_loss += current_loss
                processed_batches += 1

                # Calculate steps taken more accurately
                steps_from_previous_epochs = start_epoch * n_batches
                steps_from_current_epoch = session_steps
                if epoch == start_epoch:
                    steps_from_current_epoch += start_batch
                total_steps_taken = steps_from_previous_epochs + steps_from_current_epoch

                # Update progress bar with current epoch's metrics only
                current_avg_loss = total_train_loss / processed_batches if processed_batches > 0 else 0
                pbar.set_postfix({
                    "loss": current_loss,
                    "avg_loss": current_avg_loss,
                    "lr": scheduler.get_last_lr()[0],
                    "steps": total_steps_taken
                })

            except Exception as e:
                print(f"\nError processing training batch {batch_idx}: {str(e)}")

                # Save checkpoint before exiting in case of error
                steps_from_previous_epochs = start_epoch * n_batches
                steps_from_current_epoch = session_steps
                if epoch == start_epoch:
                    steps_from_current_epoch += start_batch
                total_steps_taken = steps_from_previous_epochs + steps_from_current_epoch

                # Combine previous checkpoint losses with current epoch losses for saving
                all_losses = checkpoint_losses + epoch_losses if epoch == start_epoch else epoch_losses

                checkpoint_manager.save_checkpoint(
                    epoch=epoch,
                    batch_idx=batch_idx,
                    losses=all_losses,
                    total_steps=total_steps,
                    steps_taken=total_steps_taken
                )
                print(f"Checkpoint saved after error at epoch {epoch + 1}, batch {batch_idx}")
                continue

        pbar.close()

        # Calculate average training loss for this epoch only
        avg_train_loss = total_train_loss / processed_batches if processed_batches > 0 else 0
        history['train_loss'].append(avg_train_loss)
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        # Validation loop
        model.eval()
        total_val_loss = 0
        processed_val_batches = 0  # Count actual processed validation batches
        predictions = []
        references = []

        # Create validation data generator
        val_data_generator = preprocessor.get_data_generator(
            val_dataset, LANG_TOKEN_MAPPING, tokenizer, batch_size=batch_size
        )

        # Calculate total number of validation batches
        total_val_batches = len(val_dataset) // batch_size
        if len(val_dataset) % batch_size != 0:
            total_val_batches += 1

        # Validation progress bar
        pbar = tqdm(val_data_generator, total=total_val_batches, desc="Validating", ncols=100)

        with torch.no_grad():
            for batch_idx, (input_batch, label_batch) in enumerate(pbar):
                try:
                    # Move batches to device
                    input_batch = input_batch.to(device)
                    label_batch = label_batch.to(device)

                    # Calculate validation loss
                    model_outputs = model(input_ids=input_batch, labels=label_batch)
                    batch_val_loss = model_outputs.loss.item()
                    total_val_loss += batch_val_loss
                    processed_val_batches += 1  # Count only successfully processed batches

                    # Generate translations
                    outputs = model.generate(input_batch, max_length=max_length)

                    # Decode predictions and references
                    pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                    ref_texts = tokenizer.batch_decode(label_batch, skip_special_tokens=True)

                    # Print some sample translations for debugging (first 2 in the batch)
                    if batch_idx == 0:
                        print("\nSample translations:")
                        for i in range(min(2, len(pred_texts))):
                            print(f"Reference: {ref_texts[i]}")
                            print(f"Prediction: {pred_texts[i]}")
                            print("---")

                    # Add to predictions and references
                    predictions.extend(pred_texts)
                    references.extend(ref_texts)

                    # Update progress bar with current batch loss and running average
                    current_avg_val_loss = total_val_loss / processed_val_batches
                    pbar.set_postfix({
                        "batch_loss": batch_val_loss,
                        "avg_val_loss": current_avg_val_loss
                    })

                except Exception as e:
                    print(f"Error processing validation batch {batch_idx}: {str(e)}")
                    continue

        pbar.close()

        # Calculate average validation loss based on actually processed batches
        avg_val_loss = total_val_loss / processed_val_batches if processed_val_batches > 0 else float('inf')
        history['val_loss'].append(avg_val_loss)

        # Calculate BLEU score with the fixed implementation
        bleu_score = calculate_bleu(predictions, references)
        history['val_bleu'].append(bleu_score)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation BLEU: {bleu_score:.4f}")

        # Update the combined list of losses for checkpoint saving
        if epoch == start_epoch:
            all_losses = checkpoint_losses + epoch_losses
        else:
            all_losses = epoch_losses if epoch == 0 else history.get('checkpoint_losses', []) + epoch_losses

        # Store losses for future epochs
        history['checkpoint_losses'] = all_losses

        # Save checkpoint at the end of the epoch - only every 10 epochs or final epoch
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:  # Save every 10 epochs or at the final epoch
            total_steps_taken = (epoch + 1) * n_batches  # End of epoch
            checkpoint_manager.save_checkpoint(
                epoch=epoch + 1,  # Save as the start of the next epoch
                batch_idx=0,      # Start from batch 0 of next epoch
                losses=all_losses,
                total_steps=total_steps,
                steps_taken=total_steps_taken
            )
            print(f"Checkpoint saved at epoch {epoch + 1}")
        else:
            print(f"Skipping checkpoint save at epoch {epoch + 1} (saving every 10 epochs)")

        # Print epoch summary
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Training Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"BLEU Score: {bleu_score:.4f}")
        print(f"Final Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Early stopping check - always save the best model regardless of epoch
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Save best model
            checkpoint_manager.save_checkpoint(
                epoch=epoch + 1,
                batch_idx=0,
                losses=all_losses,
                total_steps=total_steps,
                steps_taken=total_steps_taken,
                is_best=True  # Important to mark this as the best checkpoint
            )
            print("New best model saved!")
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs")
            early_stop = True
            break

        if early_stop:
            break

    # Plot training history
    plot_training_history(history)

    return history

# Run training loop

In [None]:
# Device setup
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#clear the cache
torch.cuda.empty_cache()

# Preprocessing setup
preprocessor = TranslationPreprocessor(tokenizer, lang_token, max_length)

# Shuffle the dataset before creating the data generator
new_dataset['train'] = new_dataset['train'].shuffle(seed=42)
new_dataset['validation'] = new_dataset['validation'].shuffle(seed=42)

# Clear the cache
torch.cuda.empty_cache()

# Run training loop
history = train_model(
    model=model,
    train_dataset=new_dataset['train'],
    val_dataset=new_dataset['validation'],
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    tokenizer=tokenizer,
    device=device,
    preprocessor=preprocessor,
    LANG_TOKEN_MAPPING=lang_token,
    repo_name=repo_name
)
# history, last_batch_idx = train_model(
#     model=model,
#     train_dataset=new_dataset['train'],
#     val_dataset=new_dataset['validation'],
#     optimizer=optimizer,
#     scheduler=scheduler,
#     num_epochs=num_epochs,
#     tokenizer=tokenizer,
#     device=device,
#     preprocessor=preprocessor,
#     LANG_TOKEN_MAPPING=lang_token,
#     repo_name=repo_name
# )

Number of warmup steps: 6314
Total training steps: 210470
Batches per epoch: 21047
CheckpointManager: Batches per epoch set to 21047
Regular checkpoint not found. Attempting to load best checkpoint.


best_checkpoint.pth:   0%|          | 0.00/895M [00:00<?, ?B/s]

Loaded best checkpoint from epoch 3, batch 0
Checkpoint indicates 5262 steps taken, calculated steps: 42094
Learning rate after scheduler adjustment: 5.00e-05
Scheduler learning rate after adjustment: 5.00e-05

Epoch 3/10
Starting from batch: 0
Current learning rate: 5.00e-05
Current scheduler learning rate: 5.00e-05


Epoch 3/10:   0%|                                                         | 0/21047 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Epoch 3/10: 100%|█| 21047/21047 [1:06:34<00:00,  5.27it/s, loss=0.114, avg_loss=0.236, lr=4.99e-5, s


Average Training Loss: 0.2364


Validating:   0%|          | 1/2631 [00:01<1:25:08,  1.94s/it, batch_loss=0.346, avg_val_loss=0.346]


Sample translations:
Reference: Tom quickly realized that something was wrong
Prediction: Tom quickly found that something was wrong
---
Reference: He must've missed the bus
Prediction: She must have missed the bus
---


Validating: 100%|█████████| 2631/2631 [30:13<00:00,  1.45it/s, batch_loss=0.068, avg_val_loss=0.206]


Validation Loss: 0.2062
Validation BLEU: 53.7285
Skipping checkpoint save at epoch 3 (saving every 10 epochs)

Epoch 3 Summary:
Training Loss: 0.2364
Validation Loss: 0.2062
BLEU Score: 53.7285
Final Learning Rate: 4.99e-05
Best model checkpoint saved locally with learning rate: 4.99e-05, scheduler lr: 4.99e-05
Saved at epoch 4, batch 0, steps_taken: 47356


README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/297M [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


best_checkpoint.pth:   0%|          | 0.00/895M [00:00<?, ?B/s]

# Plotting code

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

def plot_training_history(history: Dict[str, List[float]], save_path: Optional[str] = None):
    """
    Plot training loss, validation loss, and BLEU scores.

    Args:
        history: Dictionary containing training and validation metrics.
        save_path: Optional path to save the plot to a file.
    """
    plt.figure(figsize=(12, 6))

    # Plot training and validation loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    # Plot BLEU score
    plt.subplot(1, 2, 2)
    plt.plot(history['val_bleu'], label='Validation BLEU', color='green')
    plt.xlabel('Epoch')
    plt.ylabel('BLEU Score')
    plt.title('Validation BLEU Score')
    plt.legend()

    plt.tight_layout()

    # Save the plot if a path is provided
    if save_path:
        plt.savefig(save_path)
        print(f"Plot saved to {save_path}")

    plt.show()

# Example usage
if __name__ == "__main__":
    history = {}

    # Call the plotting function
    plot_training_history(history, save_plot_path)

# Save Final model

In [None]:
# Save the final checkpoint
checkpoint_manager = CheckpointManager(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    tokenizer=tokenizer,
    repo_name="Lingua-Connect/JMW_TrainerImproved",
    device=device
)

# Save the final model and tokenizer
checkpoint_manager.save_checkpoint(
    epoch=num_epochs - 1,  # Final epoch
    batch_idx= 0,
    losses=history['train_loss']
)

print("Final model and tokenizer pushed to Hugging Face Hub.")

# Load the saved model




In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dictionary for language tokens
lang_token = {
    'en': '<en>',
    'nys': '<nys>'
}

# Function to load model and tokenizer
def load_model_and_tokenizer(model_name):
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(device)
    return model, tokenizer

# Load English to Giriama model
eng_gir_model, eng_gir_tokenizer = load_model_and_tokenizer("Lingua-Connect/ENG_GIR")
print("English to Giriama model and tokenizer loaded.")

# Load Giriama to English model
gir_eng_model, gir_eng_tokenizer = load_model_and_tokenizer("Lingua-Connect/GIR_ENG")
print("Giriama to English model and tokenizer loaded.")

# Perform inference

In [None]:
def translate(text, model, tokenizer, source_lang, target_lang, max_length=128, num_beams=5, num_return_sequences=1):
    # Add language token to the input text
    input_text = f"{lang_token[source_lang]} {text}"

    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True)
    input_ids = input_ids.to(device)

    # Generate translation
    outputs = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True
    )

    # Decode the generated tokens to text
    translated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    return translated_texts

In [None]:
# Example usage - English to Giriama
eng_text = "Where are you?"
gir_translations = translate(
    eng_text,
    eng_gir_model,
    eng_gir_tokenizer,
    source_lang="en",
    target_lang="nys",
    num_beams=5,
    num_return_sequences=3
)

print(f"\nEnglish to Giriama:")
print(f"Input: {eng_text}")
for i, translation in enumerate(gir_translations):
    print(f"Translation {i+1}: {translation}")

In [None]:
# Example usage - Giriama to English
gir_text = gir_translations[0]  # Use the first translation as input
eng_back_translations = translate(
    gir_text,
    gir_eng_model,
    gir_eng_tokenizer,
    source_lang="nys",
    target_lang="en",
    num_beams=5,
    num_return_sequences=3
)

print(f"\nGiriama to English:")
print(f"Input: {gir_text}")
for i, translation in enumerate(eng_back_translations):
    print(f"Translation {i+1}: {translation}")

# Validation Inference


In [None]:
# If you have a validation dataset
if 'validation_dataset' in locals():
    print("\n===== VALIDATION DATASET TESTS =====")
    sample = validation_dataset[76]  # Change the index to test different samples
    input_text = sample['English sentence']
    target_text = sample['Swahili Translation']  # This should be 'Giriama Translation'

    print(f"Dataset sample:")
    print(f"English: {input_text}")
    print(f"Reference Giriama: {target_text}")

    # Translate English to Giriama
    gir_translations = translate(
        input_text,
        eng_gir_model,
        eng_gir_tokenizer,
        source_lang="en",
        target_lang="gir",
        num_beams=5,
        num_return_sequences=3
    )

    print("\nGenerated translations:")
    for i, translation in enumerate(gir_translations):
        print(f"Translation {i+1}: {translation}")