In [1]:
%env CUDA_VISIBLE_DEVICES=0  # force using CUDA GPU device 0
%env ZE_AFFINITY_MASK=0  # force using Intel XPU device 0
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=0  # force using CUDA GPU device 0
env: ZE_AFFINITY_MASK=0  # force using Intel XPU device 0
env: TOKENIZERS_PARALLELISM=false


## Initialize PolyModel

In [None]:
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import load_dataset, concatenate_datasets
from peft import PolyConfig, get_peft_model, TaskType, PeftModel, PeftConfig

# --- Checkpoint the Poly adapter is attached to ---
model_name_or_path = "google/flan-t5-xl"

# --- Optimisation settings ---
num_epochs = 8
batch_size = 8
lr = 5e-5

# --- Poly adapter hyper-parameters ---
n_tasks = 4  # how many tasks are routed
n_skills = 2  # how many skills (LoRA modules)
n_splits = 4  # how many heads
r = 8  # LoRA rank used inside Poly

In [3]:
# Load the tokenizer and the seq2seq base model from the same checkpoint.
# `trust_remote_code` is intentionally not enabled: this checkpoint is reloaded
# later in the notebook without it, and leaving it off avoids executing
# repository-supplied Python code on the local machine.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 22.43it/s]


In [4]:
# Poly adapter configuration, sized by the hyper-parameters chosen in the
# configuration cell.
peft_config = PolyConfig(
    r=r,
    n_splits=n_splits,
    n_skills=n_skills,
    n_tasks=n_tasks,
    poly_type="poly",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

# Wrap the base model with the adapter and report how many parameters will be trained.
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()

trainable params: 9,441,792 || all params: 2,859,198,976 || trainable%: 0.3302


## Prepare datasets

For this example, we use four `SuperGLUE` benchmark tasks — `boolq`, `multirc`, `rte`, and `wic` — and subsample each to 1,000 training examples and 100 evaluation examples.

In [5]:
def _load_superglue_task(task_name, build_input, choices):
    """Load one SuperGLUE task and convert it to a multiple-choice text format.

    Args:
        task_name: SuperGLUE configuration name passed to `load_dataset`; it is
            also stored in every example's "task_name" column.
        build_input: callable that receives one raw example and returns the
            prompt string.
        choices: sequence indexed by the example's integer label, giving the
            answer letter ("A" or "B") for that label.

    Returns:
        The loaded dataset (all splits) reduced to the columns "input",
        "output" and "task_name".

    Note:
        The label is used as a plain Python index, so a negative label selects
        from the end of `choices` instead of raising. Splits without gold
        labels (presumably encoded as -1 -- verify against the dataset card)
        would therefore silently receive the last choice.
    """
    return (
        load_dataset("super_glue", task_name)
        .map(
            lambda x: {
                "input": build_input(x),
                "output": choices[int(x["label"])],
                "task_name": task_name,
            }
        )
        .select_columns(["input", "output", "task_name"])
    )


# boolq
boolq_dataset = _load_superglue_task(
    "boolq",
    lambda x: f"{x['passage']}\nQuestion: {x['question']}\nA. Yes\nB. No\nAnswer:",
    # 0 - False
    # 1 - True
    ["B", "A"],
)
print("boolq example: ")
print(boolq_dataset["train"][0])

# multirc
multirc_dataset = _load_superglue_task(
    "multirc",
    lambda x: (
        f"{x['paragraph']}\nQuestion: {x['question']}\nAnswer: {x['answer']}\nIs it"
        " true?\nA. Yes\nB. No\nAnswer:"
    ),
    # 0 - False
    # 1 - True
    ["B", "A"],
)
print("multirc example: ")
print(multirc_dataset["train"][0])

# rte
rte_dataset = _load_superglue_task(
    "rte",
    lambda x: (
        f"{x['premise']}\n{x['hypothesis']}\nIs the sentence below entailed by the"
        " sentence above?\nA. Yes\nB. No\nAnswer:"
    ),
    # 0 - entailment
    # 1 - not_entailment
    ["A", "B"],
)
print("rte example: ")
print(rte_dataset["train"][0])

# wic
wic_dataset = _load_superglue_task(
    "wic",
    lambda x: (
        f"Sentence 1: {x['sentence1']}\nSentence 2: {x['sentence2']}\nAre '{x['word']}'"
        " in the above two sentences the same?\nA. Yes\nB. No\nAnswer:"
    ),
    # 0 - False
    # 1 - True
    ["B", "A"],
)
print("wic example: ")
print(wic_dataset["train"][0])

boolq example: 
{'input': 'Persian language -- Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.\nQuestion: do iran and afghanistan speak the same language\nA. Yes\nB. No\nAnswer:', 'output': 'A', 'task_name': 'boolq'}
multirc example: 
{'input': 'While this process moved along, diplomacy continued its rounds. Direct pressure on the Taliban had proved unsuccessful. As one NSC staff note put it, "Under the Taliban, Afghanistan is not so much a state sponsor of

In [6]:
# Integer id for each task name; ids follow the order in which the tasks are listed.
TASK2ID = {task: idx for idx, task in enumerate(("boolq", "multirc", "rte", "wic"))}


def tokenize(examples):
    """Tokenize a batch of prompts and answers and attach labels and task ids.

    Args:
        examples: batch mapping with the columns "input" (prompt strings),
            "output" (answer strings) and "task_name" (keys of `TASK2ID`).

    Returns:
        The tokenizer output for the prompts (padded/truncated to 512 tokens)
        extended with two entries: "labels", the answer token ids
        (padded/truncated to 2 tokens) with padding positions set to -100, and
        "task_ids", a long tensor holding one single-element row per example.
    """
    prompts, answers = examples["input"], examples["output"]
    encoded = tokenizer(
        prompts, max_length=512, padding="max_length", truncation=True, return_tensors="pt"
    )
    answer_ids = tokenizer(
        answers, max_length=2, padding="max_length", truncation=True, return_tensors="pt"
    )["input_ids"]
    # Overwrite padding positions with -100 (presumably so the loss skips them,
    # following the usual Hugging Face convention).
    answer_ids[answer_ids == tokenizer.pad_token_id] = -100
    encoded["labels"] = answer_ids
    encoded["task_ids"] = torch.tensor([[TASK2ID[name]] for name in examples["task_name"]]).long()
    return encoded

In [7]:
def get_superglue_dataset(
    split="train",
    n_samples=500,
    seed=42,
):
    """Build a tokenized multi-task dataset from the four SuperGLUE tasks.

    Each task dataset is shuffled, cut down to its first `n_samples` examples,
    and the four subsets are concatenated (boolq, multirc, rte, wic order)
    before being tokenized with `tokenize`.

    Args:
        split: name of the split to draw from in every task dataset.
        n_samples: number of examples kept per task; every task's split must
            contain at least this many examples.
        seed: seed forwarded to `Dataset.shuffle` so the same subset is picked
            on every run. Pass `None` for an unseeded shuffle.

    Returns:
        The concatenated dataset with the raw "input", "output" and
        "task_name" columns replaced by the features produced by `tokenize`.
    """
    task_datasets = (boolq_dataset, multirc_dataset, rte_dataset, wic_dataset)
    ds = concatenate_datasets(
        [task_ds[split].shuffle(seed=seed).select(range(n_samples)) for task_ds in task_datasets]
    )
    ds = ds.map(
        tokenize,
        batched=True,
        remove_columns=["input", "output", "task_name"],
        # Always re-tokenize instead of reusing a previously cached result.
        load_from_cache_file=False,
    )
    return ds

As a toy example, we select only 1,000 examples from each sub-dataset for training and 100 from each for evaluation.

In [8]:
superglue_train_dataset = get_superglue_dataset(split="train", n_samples=1000)
# Evaluate on "validation" rather than "test": the SuperGLUE test splits are
# published without gold labels (label == -1), which the label-to-letter lookup
# would silently turn into a constant answer and make the accuracy meaningless.
superglue_eval_dataset = get_superglue_dataset(split="validation", n_samples=100)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map: 100%|██████████| 4000/4000 [00:02<00:00, 1880.98 examples/s]
Map: 100%|██████████| 400/400 [00:00<00:00, 2124.88 examples/s]


## Train and evaluate

In [None]:
# training and evaluation
def compute_metrics(eval_preds):
    """Compute exact-match accuracy between decoded predictions and labels.

    Args:
        eval_preds: pair `(preds, labels)` of token-id sequences. Entries equal
            to -100 are dropped from both before decoding.

    Returns:
        `{"accuracy": value}` where `value` is the fraction of pairs whose
        decoded, whitespace-stripped texts are identical. An empty evaluation
        set yields 0.0 instead of raising `ZeroDivisionError`.
    """
    preds, labels = eval_preds
    # -100 is a placeholder, not a vocabulary id, so it must be removed before decoding.
    preds = [[i for i in seq if i != -100] for seq in preds]
    labels = [[i for i in seq if i != -100] for seq in labels]
    pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)

    pairs = list(zip(pred_texts, label_texts))
    if not pairs:
        return {"accuracy": 0.0}
    correct = sum(pred.strip() == true.strip() for pred, true in pairs)
    return {"accuracy": correct / len(pairs)}


# Training configuration: evaluate and log once per epoch, never write checkpoints,
# and report to no external tracker.
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    num_train_epochs=num_epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    report_to=[],
    # Evaluation decodes with generation, capped at 2 tokens.
    predict_with_generate=True,
    generation_max_length=2,
    # Presumably kept off so the extra `task_ids` column reaches the model -- confirm.
    remove_unused_columns=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=default_data_collator,
    train_dataset=superglue_train_dataset,
    eval_dataset=superglue_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

In [15]:
# Save the trained adapter to a directory named after the base checkpoint,
# the PEFT type and the task type.
# NOTE(review): the next line re-assigns the same value as the configuration cell;
# it keeps this cell runnable on its own, but it would override a different
# checkpoint name chosen there.
model_name_or_path = "google/flan-t5-xl"
# The checkpoint name contains a "/", so the adapter lands in a nested directory.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)

In [16]:
# List the saved adapter directory; IPython expands $peft_model_id to the Python variable's value.
!ls -lh $peft_model_id

total 37M
-rw-r--r-- 1 root root 5.1K Aug  4 20:25 README.md
-rw-r--r-- 1 root root  381 Aug  4 20:25 adapter_config.json
-rw-r--r-- 1 root root  37M Aug  4 20:25 adapter_model.safetensors


## Load and infer

In [17]:
# Pick the device for inference. Availability is checked at runtime rather than
# assumed: `torch.accelerator.current_accelerator()` can return None, and older
# PyTorch builds have no `torch.accelerator` module at all.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device_type = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device_type = "xpu"
else:
    device_type = "cpu"
# Accelerators are addressed by index 0; the CPU takes no index.
device = f"{device_type}:0" if device_type != "cpu" else "cpu"

In [18]:
# Directory the adapter was saved to (same naming scheme as the saving cell).
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# Read the adapter config to find its base checkpoint, load that checkpoint,
# attach the saved adapter, then move to the target device in eval mode.
config = PeftConfig.from_pretrained(peft_model_id)
reloaded_base = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(reloaded_base, peft_model_id).to(device).eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 22.17it/s]


In [19]:
# Run one RTE validation example through the reloaded model.
i = 5
example = rte_dataset["validation"][i]

inputs = tokenizer(example["input"], return_tensors="pt")
# The Poly adapter is told which task this example belongs to.
inputs["task_ids"] = torch.LongTensor([TASK2ID["rte"]])
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

print(example["input"])
print(example["output"])
print(inputs)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=2)

print(outputs[0])
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

In 1979, the leaders signed the Egypt-Israel peace treaty on the White House lawn. Both President Begin and Sadat received the Nobel Peace Prize for their work. The two nations have enjoyed peaceful relations to this day.
The Israel-Egypt Peace Agreement was signed in 1979.
Is the sentence below entailed by the sentence above?
A. Yes
B. No
Answer:
A
{'input_ids': tensor([[   86, 15393,     6,     8,  2440,  3814,     8, 10438,    18, 30387,
          3065,  2665,    63,    30,     8,  1945,  1384,  8652,     5,  2867,
          1661, 10129,    77,    11, 18875,   144,  1204,     8, 22232, 11128,
         11329,    21,    70,   161,     5,    37,   192,  9352,    43,  2994,
          9257,  5836,    12,    48,   239,     5,    37,  3352,    18,   427,
           122,    63,   102,    17, 11128,  7139,    47,  3814,    16, 15393,
             5,    27,     7,     8,  7142,   666,     3,   295, 10990,    57,
             8,  7142,   756,    58,    71,     5,  2163,   272,     5,   465,
  