# CPT Training and Inference
This notebook demonstrates how to train and evaluate a model with Context-Aware Prompt Tuning (CPT) using the Hugging Face Trainer. For more details, refer to the [paper](https://huggingface.co/papers/2410.17222).


## Sections Overview:
1. **Setup**: Import libraries and configure the environment.
2. **Data Preparation**: Load and preprocess the dataset.
3. **Model Training**: Configure and train the model.
4. **Evaluation**: Test the model's performance and visualize results.

# Setup

---




## Installation

In [1]:
!pip install datasets
!pip install git+https://github.com/huggingface/peft

Collecting git+https://github.com/huggingface/peft
  Cloning https://github.com/huggingface/peft to /tmp/pip-req-build-0mbyx_z_
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/peft /tmp/pip-req-build-0mbyx_z_
  Resolved https://github.com/huggingface/peft to commit 131efba5d48753a3355ecd4f3833ae010a0510d6
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


## Imports

In [2]:
from typing import Any, Dict, List, Union

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import CPTConfig, get_peft_model


MAX_INPUT_LENGTH = 1024
MAX_ICL_SAMPLES = 10
NUM_TRAINING_SAMPLES = 100
model_id = 'bigscience/bloom-1b7'

# Data Preparation
---

In [3]:
# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_id,               # The name or path of the pre-trained tokenizer (e.g., "bert-base-uncased").
    cache_dir='.',          # Directory to cache the tokenizer files locally.
    padding_side='right',   # Specifies that padding should be added to the right side of sequences.
    trust_remote_code=True  # Allows loading tokenizer implementations from external sources.
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Load the SST-2 dataset from the GLUE benchmark
dataset = load_dataset('glue', 'sst2')

def add_string_labels(example):
    """
    Converts numerical labels into human-readable string labels.

    Args:
        example (dict): A single example from the dataset with a numerical 'label'.

    Returns:
        dict: The example augmented with a 'label_text' field.
    """
    # Map numerical label to string label
    example['label_text'] = "positive" if example['label'] == 1 else "negative"
    return example

# Subset and process the training dataset
context_dataset = dataset['train'].select(range(MAX_ICL_SAMPLES)).map(add_string_labels)
train_dataset = dataset['train'].select(range(MAX_ICL_SAMPLES, NUM_TRAINING_SAMPLES + MAX_ICL_SAMPLES)).map(add_string_labels)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]
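
The two `select` calls above carve the context examples and the training examples out of consecutive index ranges of the training split. As a minimal sketch (plain Python ranges standing in for the dataset, with the constants copied from the imports cell), the ranges can be checked for overlap like this:

```python
MAX_ICL_SAMPLES = 10        # copied from the notebook's constants
NUM_TRAINING_SAMPLES = 100  # copied from the notebook's constants

# Index ranges passed to `select` for the context and training subsets
context_indices = range(MAX_ICL_SAMPLES)
train_indices = range(MAX_ICL_SAMPLES, NUM_TRAINING_SAMPLES + MAX_ICL_SAMPLES)

# The context examples are never reused as training examples
no_overlap = set(context_indices).isdisjoint(train_indices)
print(len(context_indices), len(train_indices), no_overlap)
```

Keeping the two ranges disjoint means the examples baked into the CPT context are not also seen as training targets.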

**Note:** This notebook uses small subsets of the dataset so that it runs quickly. For a proper evaluation, use the entire validation set by setting `quick_review` to `False`.

In [5]:
quick_review = True # set to False for a comprehensive evaluation
num_of_test_examples = 100 if quick_review else len(dataset['validation'])
# Subset and process the validation dataset
test_dataset = dataset['validation'].select(range(num_of_test_examples)).map(add_string_labels)

In [6]:
class CPTDataset(Dataset):
    def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH):
        """
        Initialize the CPTDataset with samples, a tokenizer, and a template.

        Args:
            samples (list): List of samples containing input sentences and labels.
            tokenizer: Tokenizer instance for encoding text.
            template (dict): Dictionary defining input/output templates and separators.
            max_length (int): Maximum input length for truncation.
        """
        self.template = template
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Storage for tokenized inputs and masks
        self.attention_mask = []
        self.input_ids = []
        self.input_type_mask = []
        self.inter_seperator_ids = self._get_input_ids(template['inter_seperator'])

        # Tokenize each sample and prepare inputs
        for sample_i in tqdm(samples):
            input_text, label = sample_i['sentence'], sample_i['label_text']
            input_ids, attention_mask, input_type_mask = self.preprocess_sentence(input_text, label)

            self.input_ids.append(input_ids)
            self.attention_mask.append(attention_mask)
            self.input_type_mask.append(input_type_mask)


    def _get_input_ids(self, text):
        """
        Tokenize the given text into input IDs.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: Tokenized input IDs.
        """
        return self.tokenizer(text, add_special_tokens=False)["input_ids"]


    def preprocess_sentence(self, input_text, label):
        """
        Preprocess a sentence and its corresponding label using templates.

        Args:
            input_text (str): The input sentence.
            label (str): The label text (e.g., "positive", "negative").

        Returns:
            tuple: (input_ids, attention_mask, input_type_mask)
        """

        # Split input template into parts
        input_template_part_1_text, input_template_part_2_text = self.template['input'].split('{}')
        input_template_tokenized_part1 = self._get_input_ids(input_template_part_1_text)
        input_tokenized = self._get_input_ids(input_text)
        input_template_tokenized_part2 = self._get_input_ids(input_template_part_2_text)

        # Separator token
        sep_tokenized = self._get_input_ids(self.template['intra_seperator'])

        # Process the label using the template
        label_template_part_1, label_template_part_2 = self.template['output'].split('{}')
        label_template_part1_tokenized = self._get_input_ids(label_template_part_1)
        label_tokenized = self._get_input_ids(label)
        label_template_part2_tokenized = self._get_input_ids(label_template_part_2)

        # End-of-sequence token
        eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []

        # Concatenate all tokenized parts
        input_ids = input_template_tokenized_part1 + input_tokenized + input_template_tokenized_part2 + sep_tokenized + label_template_part1_tokenized + label_tokenized + label_template_part2_tokenized + eos

        # Generate attention and type masks
        attention_mask = [1] * len(input_ids)
        input_type_mask = (
            [1] * len(input_template_tokenized_part1)
            + [2] * len(input_tokenized)
            + [1] * len(input_template_tokenized_part2)
            + [0] * len(sep_tokenized)
            + [3] * len(label_template_part1_tokenized)
            + [4] * len(label_tokenized)
            + [3] * len(label_template_part2_tokenized)
            + [0] * len(eos)
        )

        # Ensure all masks and inputs are the same length
        assert len(input_type_mask) == len(input_ids) == len(attention_mask)

        return input_ids, attention_mask, input_type_mask


    def __len__(self):
        """
        Get the number of examples in the dataset.

        Returns:
            int: Number of examples.
        """
        return len(self.input_ids)


    def __getitem__(self, idx):
        """
        Get the tokenized representation for the given index.

        Args:
            idx (int): Index of the example.

        Returns:
            dict: Tokenized inputs with attention and type masks.
        """

        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "input_type_mask": self.input_type_mask[idx]
        }

# Define templates for tokenization
templates = {
    'input': 'input: {}',     # Input template with placeholder
    'intra_seperator': ' ',   # Separator between input and output
    'output': 'output: {}',   # Output template with placeholder
    'inter_seperator': '\n'   # Separator between examples
}

# Initialize the dataset
cpt_train_dataset = CPTDataset(train_dataset, tokenizer, templates)


# - `templates`: Define how inputs and outputs should be formatted and separated.
# - `CPTDataset`: Converts the raw dataset into tokenized input IDs and masks.

100%|██████████| 100/100 [00:00<00:00, 874.85it/s]
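
To make the `input_type_mask` layout concrete, here is a minimal sketch that rebuilds the same mask with a toy whitespace tokenizer. The helpers `toy_tokenize` and `build_type_mask` are hypothetical stand-ins that are not part of this notebook or of PEFT. Real token counts depend on the BLOOM tokenizer, which, unlike this toy version, also produces tokens for the separator and appends an EOS token.

```python
def toy_tokenize(text):
    """Hypothetical stand-in for the real tokenizer: split on whitespace."""
    return text.split()


def build_type_mask(sentence, label, template):
    """Mirror the type-mask layout of CPTDataset.preprocess_sentence (without EOS)."""
    in_part1, in_part2 = template['input'].split('{}')
    out_part1, out_part2 = template['output'].split('{}')
    return (
        [1] * len(toy_tokenize(in_part1))                     # input template
        + [2] * len(toy_tokenize(sentence))                   # input sentence
        + [1] * len(toy_tokenize(in_part2))                   # rest of the input template
        + [0] * len(toy_tokenize(template['intra_seperator']))  # separator
        + [3] * len(toy_tokenize(out_part1))                  # output template
        + [4] * len(toy_tokenize(label))                      # label
        + [3] * len(toy_tokenize(out_part2))                  # rest of the output template
    )


toy_template = {'input': 'input: {}', 'intra_seperator': ' ', 'output': 'output: {}'}
toy_mask = build_type_mask("a charming journey", "positive", toy_template)
print(toy_mask)
```

The value `4` marks the label tokens, which is the type the evaluation code later uses to locate the position to score.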


In [7]:
# Initialize storage for context-level information
context_ids = []                # Concatenated input IDs for all samples
context_attention_mask = []     # Concatenated attention masks
context_input_type_mask = []    # Concatenated input type masks
first_type_mask = 0             # Initial offset for input type mask

cpt_context_dataset = CPTDataset(context_dataset, tokenizer, templates)

# Iterate through the CPT context dataset
for i in range(len(context_dataset)):
    # Add input IDs to the context
    context_ids += cpt_context_dataset[i]['input_ids']

    # Add attention mask to the context
    context_attention_mask += cpt_context_dataset[i]['attention_mask']

    # Offset the non-zero type IDs so each context example gets its own range
    context_input_type_mask += [
        type_id + first_type_mask if type_id > 0 else 0
        for type_id in cpt_context_dataset[i]['input_type_mask']
    ]

    # Increment the type mask offset after processing the sample
    first_type_mask += 4

100%|██████████| 10/10 [00:00<00:00, 1133.50it/s]
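
The offset logic in the loop above can be sketched in isolation. The example mask below is hypothetical (it is not produced by the real tokenizer), and `offset_type_mask` is an illustrative helper, not a PEFT function. Each context example shifts its non-zero type IDs by a multiple of 4, while `0` entries (separator and EOS) stay `0`.

```python
def offset_type_mask(type_mask, offset):
    """Shift non-zero type IDs by `offset`; keep 0 (separator/EOS) unchanged."""
    return [type_id + offset if type_id > 0 else 0 for type_id in type_mask]


# Hypothetical per-example mask: template, sentence, separator, template, label, EOS
toy_example_mask = [1, 2, 2, 0, 3, 4, 0]

toy_context_mask = []
for example_idx in range(3):
    # Same effect as `first_type_mask += 4` after every context example
    toy_context_mask += offset_type_mask(toy_example_mask, 4 * example_idx)

print(toy_context_mask)
```

With this scheme the first example uses types 1 to 4, the second 5 to 8, the third 9 to 12, so every context example keeps a distinct range of type IDs.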


# Model Training

---

## Load model

In [8]:
# Load a pre-trained causal language model
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir='.',
    torch_dtype=torch.float16,
    device_map='auto'
)

# Initialize the CPT configuration
config = CPTConfig(
            cpt_token_ids=context_ids,
            cpt_mask=context_attention_mask,
            cpt_tokens_type_mask=context_input_type_mask,

            opt_weighted_loss_type='decay',
            opt_loss_decay_factor=0.95,         # exponential decay factor applied to the loss
            opt_projection_epsilon=0.2,         # epsilon for the projection over the input tokens
            opt_projection_format_epsilon=0.1,  # epsilon for the projection over the input and output templates

            tokenizer_name_or_path=model_id,
)

# Initialize the CPT model with PEFT
model = get_peft_model(base_model, config)

## Setting Collate Function

In [9]:
class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, training=True, mlm=False):
        """
        Custom collator for CPT-style language modeling.

        Args:
            tokenizer: The tokenizer to handle tokenization and special tokens.
            training (bool): If True, operates in training mode; otherwise in evaluation mode.
            mlm (bool): If True, enables masked language modeling.
        """

        super().__init__(tokenizer, mlm=mlm) # Initialize the parent class
        self.training = training

        # Add a special padding token if not already defined
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Process a batch of examples for language modeling.

        Args:
            examples (List): A batch of examples with tokenized inputs and optional sample masks.

        Returns:
            Dict: A dictionary containing padded and tensor-converted inputs, attention masks,
                  input type masks, and optional sample masks and labels.
        """

        # Initialize a list to collect sample masks if provided
        list_sample_mask = []
        for i in range(len(examples)):
            if "sample_mask" in examples[i].keys():
                list_sample_mask.append(examples[i].pop("sample_mask"))

        # Find the length of the longest sequence in the batch
        max_len = max(len(ex["input_ids"]) for ex in examples)

        # Define a helper function for padding sequences to the maximum length
        def pad_sequence(sequence, max_len, pad_value=0):
            return sequence + [pad_value] * (max_len - len(sequence))

        # Pad and convert `input_ids`, `attention_mask`, and `input_type_mask` to tensors
        input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples])
        attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples])
        input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples])

        # Create the initial batch dictionary
        batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask}

        # Create a zero-initialized tensor to store sample masks
        tensor_sample_mask = torch.zeros_like(batch["input_ids"], dtype=torch.long)

        # Populate the tensor with the provided sample masks
        for i in range(len(list_sample_mask)):
            tensor_sample_mask[i, : len(list_sample_mask[i])] = list_sample_mask[i]

        # Copy `input_ids` to use as `labels`
        batch["labels"] = batch["input_ids"].clone()

        # If in evaluation mode, include the `sample_mask` in the batch
        if not self.training:
            batch["sample_mask"] = tensor_sample_mask

        return batch
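
The padding step inside `torch_call` can be illustrated on its own with plain Python lists. This is a minimal sketch on a hypothetical two-example batch; the token IDs are made up. Note that with `per_device_train_batch_size=1`, as used in this notebook, every batch holds a single example and no padding is actually added.

```python
def pad_sequence(sequence, max_len, pad_value=0):
    """Right-pad `sequence` with `pad_value` up to `max_len`."""
    return sequence + [pad_value] * (max_len - len(sequence))


# Hypothetical batch with sequences of different lengths
toy_batch = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [8], "attention_mask": [1]},
]

max_len = max(len(ex["input_ids"]) for ex in toy_batch)
padded_ids = [pad_sequence(ex["input_ids"], max_len) for ex in toy_batch]
padded_attention = [pad_sequence(ex["attention_mask"], max_len) for ex in toy_batch]

print(padded_ids)
print(padded_attention)
```

Because the attention mask is padded with `0` as well, the padded positions are ignored by the model's attention.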

## Training

In [10]:
training_args = TrainingArguments(
    output_dir='../.',
    use_cpu=False,
    auto_find_batch_size=False,
    learning_rate=1e-4,
    logging_steps=100,
    per_device_train_batch_size=1,
    save_total_limit=1,
    remove_unused_columns=False,
    num_train_epochs=5,
    fp16=True,
    save_strategy='no',
    logging_dir="logs",
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=cpt_train_dataset,  # Custom CPT training dataset.
    data_collator=CPTDataCollatorForLanguageModeling(tokenizer, training=True, mlm=False)
)

trainer.train()

| Step | Training Loss |
|------|---------------|
| 100  | 0.4008        |
| 200  | 0.036         |
| 300  | 0.0263        |
| 400  | 0.0161        |
| 500  | 0.0116        |


TrainOutput(global_step=500, training_loss=0.09815525007247924, metrics={'train_runtime': 90.6767, 'train_samples_per_second': 5.514, 'train_steps_per_second': 5.514, 'total_flos': 79477977907200.0, 'train_loss': 0.09815525007247924, 'epoch': 5.0})
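
The reported `global_step=500` follows from the training configuration. As a rough sketch of the arithmetic, assuming the `Trainer` default of no gradient accumulation and a single device (neither is set explicitly in this notebook):

```python
import math

num_samples = 100   # NUM_TRAINING_SAMPLES
batch_size = 1      # per_device_train_batch_size
num_epochs = 5      # num_train_epochs
grad_accum = 1      # assumption: default gradient accumulation
num_devices = 1     # assumption: a single device

# Optimizer steps per epoch, then over the whole run
steps_per_epoch = math.ceil(num_samples / (batch_size * num_devices * grad_accum))
total_steps = steps_per_epoch * num_epochs
print(total_steps)
```

With `logging_steps=100`, this yields the five logged loss values shown above.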

# Model Evaluation

---
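
The evaluation cell below classifies each sentence by looking only at the logits of the candidate label tokens at the label position and taking the larger one. A minimal sketch of that selection step, using plain Python and made-up numbers (the vocabulary size, logits, and token IDs here are hypothetical, and `predict_label` is an illustrative helper rather than part of the notebook):

```python
def predict_label(logits_row, candidate_token_ids, label_names):
    """Pick the label whose token has the highest logit among the candidates."""
    candidate_scores = [logits_row[token_id] for token_id in candidate_token_ids]
    best_index = max(range(len(candidate_scores)), key=candidate_scores.__getitem__)
    return label_names[best_index]


# Hypothetical logits over a 5-token vocabulary at the label position
toy_logits = [0.1, 2.5, -1.0, 3.0, 0.7]
# Hypothetical token IDs for `negative` and `positive`
toy_label_token_ids = [1, 3]

toy_prediction = predict_label(toy_logits, toy_label_token_ids, ['negative', 'positive'])
print(toy_prediction)
```

Restricting the comparison to the candidate tokens means the prediction is always one of the two labels, even if some unrelated token has a higher logit.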

In [11]:
model.eval()

# Select relevant columns from the test dataset
test_dataset = test_dataset.select_columns(['sentence', 'label_text'])

# Convert the test dataset to a CPT-compatible format
cpt_test_dataset = CPTDataset(test_dataset, tokenizer, templates)

# Get the device where the model is loaded (CPU, GPU or XPU)
device = model.device
list_bool_predictions = []

for i in range(len(test_dataset)):
    input_ids, input_type_mask = cpt_test_dataset[i]['input_ids'], cpt_test_dataset[i]['input_type_mask']

    # Pass the inputs through the model
    outputs = model(
        input_ids=torch.Tensor(input_ids).long().to(device=device).view(1, -1),
        labels=torch.Tensor(input_ids).long().to(device=device).view(1, -1),
        input_type_mask=torch.Tensor(input_type_mask).long().to(device=device).view(1, -1)
    )

    # Shift logits to exclude the last token and match the labels
    shifted_logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)[0, -len(input_ids) + 1:]
    shift_labels = torch.tensor(input_ids, dtype=torch.long, device=device)[1:]
    shifted_input_type_mask = torch.tensor(input_type_mask, dtype=torch.long, device=device)[1:]

    # Create a mask for the type `4` tokens (label tokens)
    mask = shifted_input_type_mask == 4

    # Extract logits and labels corresponding to the mask
    logit = shifted_logits[mask]
    label = shift_labels[mask]

    # Token IDs of the candidate labels (index 0: `negative`, index 1: `positive`);
    # this assumes each label maps to exactly one token
    all_labels = torch.tensor([tokenizer(name, add_special_tokens=False)["input_ids"] for name in ['negative', 'positive']]).long().to(device).view(-1,)

    # Compare the logits of the candidate label tokens and infer the prediction
    prediction = logit[0, all_labels].argmax()
    prediction_text = 'negative' if prediction == 0 else 'positive'
    print(f"Sentence: {tokenizer.decode(input_ids)} \n \t The prediction is: {prediction_text}\n \t The GT is {tokenizer.decode(label)}")
    list_bool_predictions.append(prediction_text == tokenizer.decode(label))

print(f'The model Acc is {100 * np.mean(list_bool_predictions)}%')

100%|██████████| 100/100 [00:00<00:00, 1972.82it/s]


Sentence: input: it 's a charming and often affecting journey .  output: positive</s> 
 	 The prediction is: positive
 	 The GT is positive
Sentence: input: unflinchingly bleak and desperate  output: negative</s> 
 	 The prediction is: negative
 	 The GT is negative
Sentence: input: allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker .  output: positive</s> 
 	 The prediction is: positive
 	 The GT is positive
Sentence: input: the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales .  output: positive</s> 
 	 The prediction is: positive
 	 The GT is positive
Sentence: input: it 's slow -- very , very slow .  output: negative</s> 
 	 The prediction is: negative
 	 The GT is negative
Sentence: input: although laced with humor and a few fanciful touches , the film is a refreshingly serious look at young women .  output: positive</s> 
 	 The prediction is: positive
 	 The GT is p