{ "cells": [ { "cell_type": "markdown", "id": "cd408a7a-d4b6-4f33-83d3-c607dbc5f580", "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true } }, "source": [ "# Prompt Formatter Tutorial\n", "\n", "This tutorial introduces NeMo's PromptFormatter API available in module `nemo.collections.common.prompts`.\n", "After finishing this tutorial, you will be familiar with the existing prompt formatters, how to use them, and how to build your own.\n", "\n", "We cover the following topics:\n", "\n", "* Using existing prompt formatters with Llama2 as an example.\n", "\n", "* Defining your own prompt formatter.\n", "\n", "We also support applying prompt formatters to multimodal data and Lhotse-compatible data types. To learn more, see our other tutorial: [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb)" ] }, { "cell_type": "markdown", "id": "3f87f30c-79c0-41e8-b126-283ff5436465", "metadata": {}, "source": [ "### Prerequisite: building a dummy tokenizer\n", "\n", "We're going to need a tokenizer to work with prompt formatters - we'll just build a dummy one for the purpose of this tutorial." ] }, { "cell_type": "code", "execution_count": 1, "id": "e91ebef5-9a25-4eb1-8211-d0f5990f7c37", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/pzelasko/miniforge3/envs/nemo/lib/python3.10/site-packages/transformers/utils/generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. 
Please use `torch.utils._pytree.register_pytree_node` instead.\n", " _torch_pytree._register_pytree_node(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[NeMo I 2024-10-23 11:26:41 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n" ] } ], "source": [ "import string\n", "import shlex\n", "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n", "\n", "!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n", "\n", "tok_path, vocab_path = create_spt_model(\n", " data_file=\"_tutorial_train_text.txt\", \n", " output_dir=\"_tutorial_spt\",\n", " vocab_size=512, \n", " sample_size=-1, \n", " do_lower_case=False, \n", " bos=True, \n", " eos=True, \n", " pad=True, \n", " user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n", ")\n", "\n", "tokenizer = SentencePieceTokenizer(tok_path)\n", "\n", "def display(encoded_chat, with_mask=False):\n", " \"\"\"Utility for printing prompt formatted chats.\"\"\"\n", " for key, val in encoded_chat.items():\n", " if key.endswith(\"_ids\"):\n", " print(key, '--', tokenizer.ids_to_text(val), '\\n')\n", " if key == \"mask\" and with_mask:\n", " print(key, '--', val)" ] }, { "cell_type": "markdown", "id": "4c5c6c88-c882-4305-8757-585fec3eab46", "metadata": {}, "source": [ "## Using an existing PromptFormatter: Llama2\n", "\n", "\n", "**Instantiating the prompt formatter.** Let's start with a simple example of using the Llama2 prompt format." 
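, "\n", "\n", "Instead of importing the class directly, the formatter can also be resolved dynamically by its registered name (a short sketch; `PromptFormatter.resolve` is demonstrated later in this tutorial):\n", "\n", "```python\n", "from nemo.collections.common.prompts.formatter import PromptFormatter\n", "from nemo.collections.common.prompts.llama import Llama2PromptFormatter  # ensures \"llama2\" is registered\n", "\n", "# NAME-based lookup returns the registered formatter class, so\n", "# llama2_cls(tokenizer) is equivalent to Llama2PromptFormatter(tokenizer).\n", "llama2_cls = PromptFormatter.resolve(\"llama2\")\n", "print(llama2_cls.NAME)\n", "```"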
] }, { "cell_type": "code", "execution_count": 2, "id": "c77a993e-453f-474e-8912-fd35c7fc39ba", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.common.prompts.llama import Llama2PromptFormatter\n", "from pprint import pprint\n", "\n", "prompt = Llama2PromptFormatter(tokenizer)" ] }, { "cell_type": "markdown", "id": "92054a0f-5b97-4178-94b8-a27e62acf97b", "metadata": {}, "source": [ "**Chat example.** We'll define a multi-turn conversation between the user and assistant below:" ] }, { "cell_type": "code", "execution_count": 3, "id": "c5eabe5e-4160-41d7-ad85-a4df596de38b", "metadata": {}, "outputs": [], "source": [ "chat = [\n", " {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n", " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", "]" ] }, { "cell_type": "markdown", "id": "eff61b98-c7be-4345-ac97-15573d1a9533", "metadata": {}, "source": [ "**Prompt formatter outputs.** Now, we apply the prompt formatter to that conversation to obtain four tensors useful for training:\n", "* `context_ids` encode the whole dialog history up to the last response of the assistant;\n", "* `answer_ids` encode the last response of the assistant;\n", "* `input_ids` encode the full conversation;\n", "* `mask` is a boolean training loss mask that's set to `True` for every token belonging to the assistant's turns.\n", "\n", "Since the token IDs are not human-readable, we'll apply the reverse tokenizer to display the prompt-formatted example." ] }, { "cell_type": "code", "execution_count": 4, "id": "a10216b3-2bbe-4a2f-8ca8-557c3b9056be", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input_ids -- [INST] Do you know something about electronics? 
[/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", "\n", "context_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", "\n", "answer_ids -- In order to build your own audio amplifier, start with ... \n", "\n", "mask -- tensor([False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True, True, True, True, True, True,\n", " True, True, True, True, True])\n" ] } ], "source": [ "encoded = prompt.encode_dialog(chat)\n", "display(encoded, with_mask=True)" ] }, { "cell_type": "markdown", "id": "e181618e-6df8-44b2-b986-15660133e486", "metadata": {}, "source": [ "**System prompt.** We also support the system prompt. Since it affects the prompt format in a non-trivial way, it is defined as a separate role `\"system_and_user\"`, which has two slots `\"system\"` and `\"message\"`. 
We'll omit printing the mask for brevity." ] }, { "cell_type": "code", "execution_count": 5, "id": "2c3476a4-b301-4f35-9520-90d4b919363d", "metadata": {}, "outputs": [], "source": [ "chat_with_system = [\n", " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", "]" ] }, { "cell_type": "code", "execution_count": 6, "id": "5c8c329d-f8b3-48cb-b664-baed0fcd90ab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", "\n", "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", "\n", "answer_ids -- In order to build your own audio amplifier, start with ... \n", "\n" ] } ], "source": [ "encoded = prompt.encode_dialog(chat_with_system)\n", "display(encoded)" ] }, { "cell_type": "markdown", "id": "a453345a-6456-43ed-a663-0554c459fddb", "metadata": {}, "source": [ "**Constructing inference-time prompts.** During inference, the assistant's last turn is not known yet - we only want to construct the ``context_ids`` tensor. In that case, simply omit the final assistant turn. The prompt formatter will then return the ``context_ids`` tensor (with ``input_ids`` as an alias for it)." 
] }, { "cell_type": "code", "execution_count": 7, "id": "4ede7100-9d28-4cf0-ab75-bfede9936218", "metadata": {}, "outputs": [], "source": [ "inference_chat = [\n", " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", "]" ] }, { "cell_type": "code", "execution_count": 8, "id": "61bf8e77-0630-4a84-bd30-ca4c27f8d898", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", "\n", "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", "\n" ] } ], "source": [ "encoded = prompt.encode_dialog(inference_chat)\n", "display(encoded)" ] }, { "cell_type": "markdown", "id": "a334e00a-9530-4333-98de-5cb8fb08eb47", "metadata": {}, "source": [ "### How is the Llama2 PromptFormatter built\n", "\n", "`Llama2PromptFormatter` is a small class with a prompt definition that inherits from `PromptFormatter`, which implements the logic for applying the prompt format and tokenization to multi-turn conversations. 
\n", "\n", "Let's take a look at the `Llama2PromptFormatter` definition:" ] }, { "cell_type": "code", "execution_count": 9, "id": "f29fbf2f-3caa-4b27-86ca-5012d9fc6ba5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "class Llama2PromptFormatter(PromptFormatter):\n", " \"\"\"\n", " This template has been validated to provide identical tokenized results to the official code\n", " in https://github.com/meta-llama/llama/blob/main/llama/generation.py\n", " \"\"\"\n", "\n", " NAME = \"llama2\"\n", " OUTPUT_ROLE = \"assistant\"\n", " TEMPLATE = {\n", " \"system_and_user\": {\n", " \"template\": f\"{BOS_SLOT}[INST] <>\\n|system|\\n<>\\n\\n|message| [/INST]\",\n", " \"slots\": {\n", " \"system\": Modality.Text,\n", " \"message\": Modality.Text,\n", " },\n", " },\n", " \"user\": {\n", " \"template\": f\"{BOS_SLOT}[INST] |message| [/INST]\",\n", " \"slots\": {\n", " \"message\": Modality.Text,\n", " },\n", " },\n", " OUTPUT_ROLE: {\n", " \"template\": f\"|message| {EOS_SLOT}\",\n", " \"slots\": {\n", " \"message\": Modality.Text,\n", " },\n", " },\n", " }\n", "\n" ] } ], "source": [ "import inspect\n", "print(inspect.getsource(Llama2PromptFormatter))" ] }, { "cell_type": "markdown", "id": "b24e9310-b8ed-4e35-9dda-d24aa62cfb6a", "metadata": {}, "source": [ "As you can see, the definition consists of the following key components:\n", "* It derives from the `PromptFormatter` parent class.\n", "* It specifies `NAME`, which is used for dynamic resolution of a string to the class via `cls = PromptFormatter.resolve(name)`.\n", "* It specifies `OUTPUT_ROLE`, which is the name of the role containing the assistant's responses (typically `\"assistant\"`).\n", "* It specifies `TEMPLATE`, which defines the dialog structure and how user-provided values (slots) are applied to prompts. 
Notably:\n", " * Slots are wrapped in pipe characters `\"|\"` in the prompt template definition and substituted with user-provided values before tokenization.\n", " * The `\"system_and_user\"` role has two slots, `\"system\"` and `\"message\"`, and a template that wraps them with Llama2 special tokens.\n", " * We use `BOS_SLOT` and `EOS_SLOT` to insert the sentencepiece tokenizer's `bos_id` and `eos_id` in the right places (remember that sentencepiece won't tokenize them from text, so they need to be inserted programmatically).\n", " * Each slot has a type; the currently supported types are `Modality.Text` and `Modality.TextLiteral(value1, value2, ...)`, which restricts a slot to a fixed set of allowed values." ] }, { "cell_type": "markdown", "id": "8cbdca6c-6c0f-42a9-a4a7-b936684c6e12", "metadata": {}, "source": [ "## Defining your own prompt formatter" ] }, { "cell_type": "markdown", "id": "25a9b6d2-d004-4f7f-8b24-4fd6d4eae244", "metadata": {}, "source": [ "Generally, you can follow the definition of existing prompt formatters to define your own. \n", "We have several prompt formats implemented for Llama, Gemma, Phi, etc. 
\n", "\n", "As an illustration, we'll define a simple custom prompt format without a system prompt below:" ] }, { "cell_type": "code", "execution_count": 10, "id": "b69f6532-24d8-4419-b1da-42184c3d72de", "metadata": {}, "outputs": [], "source": [ "from nemo.collections.common.prompts.formatter import PromptFormatter, Modality\n", "\n", "class MyPrompt(PromptFormatter):\n", " NAME = \"myprompt\"\n", " OUTPUT_ROLE = \"assistant\"\n", " TEMPLATE = {\n", " \"user\": {\n", " \"template\": \"User: |message|\\n\",\n", " \"slots\": {\"message\": Modality.Text},\n", " },\n", " \"assistant\": {\n", " \"template\": \"Assistant: |message|\\n\",\n", " \"slots\": {\"message\": Modality.Text},\n", " },\n", " }" ] }, { "cell_type": "code", "execution_count": 11, "id": "a97c6589-1303-446c-952f-d2b4007ca7e9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "input_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? Assistant: In order to build your own audio amplifier, start with ... \n", "\n", "context_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? \n", "\n", "answer_ids -- Assistant: In order to build your own audio amplifier, start with ... \n", "\n" ] } ], "source": [ "my_prompt_cls = PromptFormatter.resolve(\"myprompt\") # it is auto-registered\n", "my_prompt = my_prompt_cls(tokenizer)\n", "display(my_prompt.encode_dialog(chat))" ] }, { "cell_type": "markdown", "id": "30f9c96a-6cf8-4cd3-b0e8-6b461c86100f", "metadata": {}, "source": [ "## Applying prompt formatters to multimodal data\n", "\n", "We refer the reader to our other tutorial, [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb), where this is discussed in detail." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }