{
"cells": [
{
"cell_type": "markdown",
"id": "cd408a7a-d4b6-4f33-83d3-c607dbc5f580",
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"source": [
"# Prompt Formatter Tutorial\n",
"\n",
"This tutorial introduces NeMo's PromptFormatter API available in module `nemo.collections.common.prompts`.\n",
"After finishing this tutorial you will be familiar with the existing prompt formatters, how to use them, and how to build your own.\n",
"\n",
"We cover the following topics:\n",
"\n",
"* Using existing prompt formatters with Llama2 as an example.\n",
"\n",
"* Defining your own prompt formatter.\n",
"\n",
"We also support applying prompt formatters for multimodal data and Lhotse-compatible data types. To learn more, see our other tutorial: [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb)"
]
},
{
"cell_type": "markdown",
"id": "3f87f30c-79c0-41e8-b126-283ff5436465",
"metadata": {},
"source": [
"### Pre-requsite: building a dummy tokenizer\n",
"\n",
"We're going to need a tokenizer to work with prompt formatters - we'll just build a dummy one for the purpose of this tutorial."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e91ebef5-9a25-4eb1-8211-d0f5990f7c37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[NeMo I 2024-10-23 11:26:41 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n"
]
}
],
"source": [
"import string\n",
"import shlex\n",
"from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n",
"\n",
"!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n",
"\n",
"tok_path, vocab_path = create_spt_model(\n",
" data_file=\"_tutorial_train_text.txt\", \n",
" output_dir=\"_tutorial_spt\",\n",
" vocab_size=512, \n",
" sample_size=-1, \n",
" do_lower_case=False, \n",
" bos=True, \n",
" eos=True, \n",
" pad=True, \n",
" user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n",
")\n",
"\n",
"tokenizer = SentencePieceTokenizer(tok_path)\n",
"\n",
"def display(encoded_chat, with_mask=False):\n",
" \"\"\"Utility for printing prompt formatted chats.\"\"\"\n",
" for key, val in encoded_chat.items():\n",
" if key.endswith(\"_ids\"):\n",
" print(key, '--', tokenizer.ids_to_text(val), '\\n')\n",
" if key == \"mask\" and with_mask:\n",
" print(key, '--', val)"
]
},
{
"cell_type": "markdown",
"id": "4c5c6c88-c882-4305-8757-585fec3eab46",
"metadata": {},
"source": [
"## Using an existing PromptFormatter: Llama2\n",
"\n",
"\n",
"**Instanting the prompt formatter.** Let's start with a simple example of Llama2 prompt format use."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c77a993e-453f-474e-8912-fd35c7fc39ba",
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.common.prompts.llama import Llama2PromptFormatter\n",
"from pprint import pprint\n",
"\n",
"prompt = Llama2PromptFormatter(tokenizer)"
]
},
{
"cell_type": "markdown",
"id": "92054a0f-5b97-4178-94b8-a27e62acf97b",
"metadata": {},
"source": [
"**Chat example.** We'll define a multi-turn conversation between the user and assistant below:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c5eabe5e-4160-41d7-ad85-a4df596de38b",
"metadata": {},
"outputs": [],
"source": [
"chat = [\n",
" {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n",
" {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
" {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
" {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "eff61b98-c7be-4345-ac97-15573d1a9533",
"metadata": {},
"source": [
"**Prompt formatter outputs.** Now, we apply prompt formatter to that conversation to obtain four tensors useful for training:\n",
"* `context_ids` encode the whole dialog history up to the last response of the assistant;\n",
"* `answer_ids` encode the last response of the assistant;\n",
"* `input_ids` encode the full conversation;\n",
"* `mask` is a boolean training loss mask that's set to `True` for every token belonging to assistant's turns.\n",
"\n",
"Since the token IDs are meaningless, we'll apply reverse tokenizer for displaying the prompt formatted example."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a10216b3-2bbe-4a2f-8ca8-557c3b9056be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n",
"\n",
"context_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
"\n",
"answer_ids -- In order to build your own audio amplifier, start with ... \n",
"\n",
"mask -- tensor([False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True])\n"
]
}
],
"source": [
"encoded = prompt.encode_dialog(chat)\n",
"display(encoded, with_mask=True)"
]
},
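{
"cell_type": "markdown",
"id": "5d2f8a1c-3b7e-4c19-a6d4-91e0b7c2f3a8",
"metadata": {},
"source": [
"**How the four outputs relate (sketch).** The relationship between the outputs listed above can be illustrated without NeMo. The snippet below is a minimal sketch, not NeMo's implementation: `toy_tokenize` and `build_outputs` are hypothetical helpers, a whitespace split stands in for the real tokenizer, and the last turn is assumed to belong to the assistant.\n",
"\n",
"```python\n",
"def toy_tokenize(text):\n",
"    \"\"\"Stand-in tokenizer (assumption): every whitespace-separated word is one token.\"\"\"\n",
"    return text.split()\n",
"\n",
"\n",
"def build_outputs(turns, output_role=\"assistant\"):\n",
"    \"\"\"Build the four training outputs from (role, already formatted text) pairs.\"\"\"\n",
"    per_turn = [(role, toy_tokenize(text)) for role, text in turns]\n",
"    if not per_turn or per_turn[-1][0] != output_role:\n",
"        raise ValueError(\"this sketch expects the last turn to be the assistant's\")\n",
"    input_ids = [tok for _, toks in per_turn for tok in toks]\n",
"    # True for every token that belongs to an assistant turn.\n",
"    mask = [role == output_role for role, toks in per_turn for _ in toks]\n",
"    answer_ids = per_turn[-1][1]\n",
"    context_ids = input_ids[: len(input_ids) - len(answer_ids)]\n",
"    return {\n",
"        \"input_ids\": input_ids,\n",
"        \"context_ids\": context_ids,\n",
"        \"answer_ids\": answer_ids,\n",
"        \"mask\": mask,\n",
"    }\n",
"\n",
"\n",
"turns = [\n",
"    (\"user\", \"[INST] Hi there [/INST]\"),\n",
"    (\"assistant\", \"Hello !\"),\n",
"    (\"user\", \"[INST] Bye [/INST]\"),\n",
"    (\"assistant\", \"See you\"),\n",
"]\n",
"out = build_outputs(turns)\n",
"assert out[\"context_ids\"] + out[\"answer_ids\"] == out[\"input_ids\"]\n",
"assert len(out[\"mask\"]) == len(out[\"input_ids\"])\n",
"print(out[\"mask\"])\n",
"```\n",
"\n",
"The invariant to keep in mind is that `context_ids` followed by `answer_ids` reproduces `input_ids`, and that `mask` has one entry per token of `input_ids`."
]
},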
{
"cell_type": "markdown",
"id": "e181618e-6df8-44b2-b986-15660133e486",
"metadata": {},
"source": [
"**System prompt.** We also support the system prompt. Since it affects the prompt format in a non-trivial way, it is defined as a separate role `\"system_and_user\"`, which has two slots `\"system\"` and `\"message\"`. We'll omit printing the mask for brevity."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2c3476a4-b301-4f35-9520-90d4b919363d",
"metadata": {},
"outputs": [],
"source": [
"chat_with_system = [\n",
" {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n",
" {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
" {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
" {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5c8c329d-f8b3-48cb-b664-baed0fcd90ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n",
"\n",
"context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
"\n",
"answer_ids -- In order to build your own audio amplifier, start with ... \n",
"\n"
]
}
],
"source": [
"encoded = prompt.encode_dialog(chat_with_system)\n",
"display(encoded)"
]
},
{
"cell_type": "markdown",
"id": "a453345a-6456-43ed-a663-0554c459fddb",
"metadata": {},
"source": [
"**Constructing inference-time prompts.** During inference, we don't know what's the last turn of the assistant - we only want to construct the ``context_ids`` tensor. In those cases, just omit the last assistant's turn. The prompt formatter will return the ``context_ids`` tensor (with ``input_ids`` alias for it too)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4ede7100-9d28-4cf0-ab75-bfede9936218",
"metadata": {},
"outputs": [],
"source": [
"inference_chat = [\n",
" {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n",
" {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
" {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "61bf8e77-0630-4a84-bd30-ca4c27f8d898",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
"\n",
"context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
"\n"
]
}
],
"source": [
"encoded = prompt.encode_dialog(inference_chat)\n",
"display(encoded)"
]
},
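{
"cell_type": "markdown",
"id": "8e4b6d20-1f9c-4a73-b5e2-6c7d0a9f4b15",
"metadata": {},
"source": [
"**Deriving an inference chat from a training chat (sketch).** If you already have full training conversations, dropping the trailing assistant turn as described above can be automated. The helper below is a minimal sketch on plain Python lists; `to_inference_chat` is a hypothetical name and not part of NeMo.\n",
"\n",
"```python\n",
"def to_inference_chat(chat, output_role=\"assistant\"):\n",
"    \"\"\"Return a copy of the chat without a trailing assistant turn, if there is one.\"\"\"\n",
"    if chat and chat[-1][\"role\"] == output_role:\n",
"        return chat[:-1]\n",
"    return list(chat)\n",
"\n",
"\n",
"training_chat = [\n",
"    {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n",
"    {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
"    {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
"    {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
"]\n",
"inference_only = to_inference_chat(training_chat)\n",
"assert inference_only[-1][\"role\"] == \"user\"\n",
"assert len(training_chat) == 4  # the original list is left untouched\n",
"```\n",
"\n",
"The resulting list has the same structure as `inference_chat` above, so it can be passed to `prompt.encode_dialog(...)` in the same way."
]
},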
{
"cell_type": "markdown",
"id": "a334e00a-9530-4333-98de-5cb8fb08eb47",
"metadata": {},
"source": [
"### How is Llama2 PromptFormatter built\n",
"\n",
"`Llama2PromptFormatter` is a small class with prompt definition that inherits `PromptFormatter`, which implements the logic for applying prompt format and tokenization to multi-turn conversations. \n",
"\n",
"Let's take a look at `Llama2PromptFormatter` definition:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f29fbf2f-3caa-4b27-86ca-5012d9fc6ba5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class Llama2PromptFormatter(PromptFormatter):\n",
" \"\"\"\n",
" This template has been validated to provide identical tokenized results to the official code\n",
" in https://github.com/meta-llama/llama/blob/main/llama/generation.py\n",
" \"\"\"\n",
"\n",
" NAME = \"llama2\"\n",
" OUTPUT_ROLE = \"assistant\"\n",
" TEMPLATE = {\n",
" \"system_and_user\": {\n",
" \"template\": f\"{BOS_SLOT}[INST] <>\\n|system|\\n<>\\n\\n|message| [/INST]\",\n",
" \"slots\": {\n",
" \"system\": Modality.Text,\n",
" \"message\": Modality.Text,\n",
" },\n",
" },\n",
" \"user\": {\n",
" \"template\": f\"{BOS_SLOT}[INST] |message| [/INST]\",\n",
" \"slots\": {\n",
" \"message\": Modality.Text,\n",
" },\n",
" },\n",
" OUTPUT_ROLE: {\n",
" \"template\": f\"|message| {EOS_SLOT}\",\n",
" \"slots\": {\n",
" \"message\": Modality.Text,\n",
" },\n",
" },\n",
" }\n",
"\n"
]
}
],
"source": [
"import inspect\n",
"print(inspect.getsource(Llama2PromptFormatter))"
]
},
{
"cell_type": "markdown",
"id": "b24e9310-b8ed-4e35-9dda-d24aa62cfb6a",
"metadata": {},
"source": [
"As you can see, the definition consist of the following key components:\n",
"* Derives `PromptFormatter` parent class.\n",
"* Specifies `NAME`, which is used for dynamic resolution of string to class via `cls = PromptFormatter.resolve(name)`.\n",
"* Specifies `OUTPUT_ROLE`, which is the name for the role with assistant's responses (typically `\"assistant\"`).\n",
"* Specifies `TEMPLATE` which defines the dialog structure and how user-provided values (slots) are applied to prompts. Notably:\n",
" * The slots are wrapped into pipe operators `\"|\"` in the prompt template definition, and substituted with user provided values before tokenization.\n",
" * `\"system_and_user`\" role has two slots, `\"system\"` and `\"message\"`, and a template that wraps them with Llama2 special tokens.\n",
" * We use `BOS_SLOT` and `EOS_SLOT` to insert sentencepiece tokenizer's `bos_id` and `eos_id` in the right places (remember that sentencepiece won't tokenize them from text, they need to be inserted programmatically).\n",
" * The slots have a type, currently supported types are `Modality.Text` and `Modality.TextLiteral(value1, value2, ...)` that allows to restrict the set of slots values."
]
},
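{
"cell_type": "markdown",
"id": "c3a79e54-2d18-4b6f-8f0a-47b1e5d9c2e6",
"metadata": {},
"source": [
"**Slot substitution (sketch).** To make the `|slot|` syntax described above concrete, here is a minimal sketch of slot substitution on plain strings. It is an illustration under assumptions, not NeMo's implementation: `fill_template` and `SLOT_PATTERN` are hypothetical names, slot names are assumed to consist of word characters, and tokenization and BOS/EOS handling are left out.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Assumption: a slot is written as |name|, where name is made of word characters.\n",
"SLOT_PATTERN = re.compile(r\"\\|(\\w+)\\|\")\n",
"\n",
"\n",
"def fill_template(template, slots):\n",
"    \"\"\"Replace every |name| marker in the template with slots[name].\"\"\"\n",
"    expected = set(SLOT_PATTERN.findall(template))\n",
"    if expected != set(slots):\n",
"        raise ValueError(f\"expected slots {sorted(expected)}, got {sorted(slots)}\")\n",
"    # A function replacement avoids backslash processing of the slot values.\n",
"    return SLOT_PATTERN.sub(lambda match: slots[match.group(1)], template)\n",
"\n",
"\n",
"user_template = \"[INST] |message| [/INST]\"\n",
"rendered = fill_template(user_template, {\"message\": \"How to build my own audio amplifier?\"})\n",
"print(rendered)  # prints: [INST] How to build my own audio amplifier? [/INST]\n",
"```\n",
"\n",
"Checking the provided slot names against the template up front turns a misspelled or missing slot into an immediate error instead of a silently malformed prompt."
]
},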
{
"cell_type": "markdown",
"id": "8cbdca6c-6c0f-42a9-a4a7-b936684c6e12",
"metadata": {},
"source": [
"## Defining your own prompt formatter"
]
},
{
"cell_type": "markdown",
"id": "25a9b6d2-d004-4f7f-8b24-4fd6d4eae244",
"metadata": {},
"source": [
"Generally you can follow the definition of existing prompt formatters to define your own. \n",
"We have several prompt formats implemented for Llama, Gemma, Phi, etc. \n",
"\n",
"We'll define a custom simple prompt format that has no system prompt below as an illustration:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b69f6532-24d8-4419-b1da-42184c3d72de",
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.common.prompts.formatter import PromptFormatter, Modality\n",
"\n",
"class MyPrompt(PromptFormatter):\n",
" NAME = \"myprompt\"\n",
" OUTPUT_ROLE = \"assistant\"\n",
" TEMPLATE = {\n",
" \"user\": {\n",
" \"template\": \"User: |message|\\n\",\n",
" \"slots\": {\"message\": Modality.Text},\n",
" },\n",
" \"assistant\": {\n",
" \"template\": \"Assistant: |message|\\n\",\n",
" \"slots\": {\"message\": Modality.Text},\n",
" },\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a97c6589-1303-446c-952f-d2b4007ca7e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? Assistant: In order to build your own audio amplifier, start with ... \n",
"\n",
"context_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? \n",
"\n",
"answer_ids -- Assistant: In order to build your own audio amplifier, start with ... \n",
"\n"
]
}
],
"source": [
"my_prompt_cls = PromptFormatter.resolve(\"myprompt\") # it is auto-registered\n",
"my_prompt = my_prompt_cls(tokenizer)\n",
"display(my_prompt.encode_dialog(chat))"
]
},
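{
"cell_type": "markdown",
"id": "f19d0b37-6a42-4e85-9c3b-2a8e7d51b094",
"metadata": {},
"source": [
"**Auto-registration by name (sketch).** The example above resolves `\"myprompt\"` without any explicit registration call. One common way to get this kind of behavior in Python is a registry filled in by `__init_subclass__`. The snippet below is a sketch of that general technique only: `ToyFormatter` and `MyToyPrompt` are hypothetical classes, and NeMo's actual mechanism may differ.\n",
"\n",
"```python\n",
"class ToyFormatter:\n",
"    \"\"\"Toy base class with a name-to-class registry (hypothetical, not NeMo's code).\"\"\"\n",
"\n",
"    _REGISTRY = {}\n",
"\n",
"    def __init_subclass__(cls, **kwargs):\n",
"        # Runs once for every subclass at class-definition time.\n",
"        super().__init_subclass__(**kwargs)\n",
"        name = getattr(cls, \"NAME\", None)\n",
"        if name is None:\n",
"            raise TypeError(f\"{cls.__name__} must define NAME\")\n",
"        cls._REGISTRY[name] = cls\n",
"\n",
"    @classmethod\n",
"    def resolve(cls, name):\n",
"        \"\"\"Look up a registered subclass by its NAME.\"\"\"\n",
"        if name not in cls._REGISTRY:\n",
"            raise KeyError(f\"unknown formatter {name!r}; known: {sorted(cls._REGISTRY)}\")\n",
"        return cls._REGISTRY[name]\n",
"\n",
"\n",
"class MyToyPrompt(ToyFormatter):\n",
"    NAME = \"mytoyprompt\"\n",
"\n",
"\n",
"assert ToyFormatter.resolve(\"mytoyprompt\") is MyToyPrompt\n",
"```\n",
"\n",
"With this pattern, defining the subclass is enough to make it resolvable by name, which matches how `MyPrompt` is used above."
]
},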
{
"cell_type": "markdown",
"id": "30f9c96a-6cf8-4cd3-b0e8-6b461c86100f",
"metadata": {},
"source": [
"## Applying prompt formatter to multimodal data\n",
"\n",
"We refer the reader to our other tutorial, [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb), where this is discussed in detail."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}