{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cd408a7a-d4b6-4f33-83d3-c607dbc5f580",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "source": [
    "# Prompt Formatter Tutorial\n",
    "\n",
    "This tutorial introduces NeMo's PromptFormatter API, available in the `nemo.collections.common.prompts` module.\n",
    "After finishing this tutorial, you will be familiar with the existing prompt formatters, know how to use them, and know how to build your own.\n",
    "\n",
    "We cover the following topics:\n",
    "\n",
    "* Using existing prompt formatters with Llama2 as an example.\n",
    "\n",
    "* Defining your own prompt formatter.\n",
    "\n",
    "We also support applying prompt formatters to multimodal data and Lhotse-compatible data types. To learn more, see our other tutorial: [Multimodal Lhotse Dataloading](./Multimodal%20Lhotse%20Dataloading.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f87f30c-79c0-41e8-b126-283ff5436465",
   "metadata": {},
   "source": [
    "### Prerequisite: building a dummy tokenizer\n",
    "\n",
    "Prompt formatters need a tokenizer to work with, so we'll build a small dummy SentencePiece tokenizer for the purpose of this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e91ebef5-9a25-4eb1-8211-d0f5990f7c37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NeMo I 2024-10-23 11:26:41 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n"
     ]
    }
   ],
   "source": [
    "import string\n",
    "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n",
    "\n",
    "# Write every printable ASCII character to a small text file to train the dummy tokenizer on.\n",
    "with open(\"_tutorial_train_text.txt\", \"w\") as f:\n",
    "    f.write(\" \".join(string.printable) + \"\\n\")\n",
    "\n",
    "tok_path, vocab_path = create_spt_model(\n",
    "    data_file=\"_tutorial_train_text.txt\", \n",
    "    output_dir=\"_tutorial_spt\",\n",
    "    vocab_size=512, \n",
    "    sample_size=-1, \n",
    "    do_lower_case=False, \n",
    "    bos=True, \n",
    "    eos=True, \n",
    "    pad=True, \n",
    "    user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<<SYS>>\", \"<</SYS>>\", \"[audio]\"]\n",
    ")\n",
    "\n",
    "tokenizer = SentencePieceTokenizer(tok_path)\n",
    "\n",
    "def display(encoded_chat, with_mask=False):\n",
    "    \"\"\"Utility for printing prompt formatted chats.\"\"\"\n",
    "    for key, val in encoded_chat.items():\n",
    "        if key.endswith(\"_ids\"):\n",
    "            print(key, '--', tokenizer.ids_to_text(val), '\\n')\n",
    "        if key == \"mask\" and with_mask:\n",
    "            print(key, '--', val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c5c6c88-c882-4305-8757-585fec3eab46",
   "metadata": {},
   "source": [
    "## Using an existing PromptFormatter: Llama2\n",
    "\n",
    "**Instantiating the prompt formatter.** Let's start with a simple example that uses the Llama2 prompt format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c77a993e-453f-474e-8912-fd35c7fc39ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.common.prompts.llama import Llama2PromptFormatter\n",
    "\n",
    "prompt = Llama2PromptFormatter(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92054a0f-5b97-4178-94b8-a27e62acf97b",
   "metadata": {},
   "source": [
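    "**Chat example.** Below, we define a multi-turn conversation between the user and the assistant. Each turn is a dict with a `\"role\"` and a `\"slots\"` dict holding the values to substitute into that role's template:"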
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c5eabe5e-4160-41d7-ad85-a4df596de38b",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = [\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
    "]"
   ]
  },
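  {
   "cell_type": "markdown",
   "id": "toy-chat-check-md",
   "metadata": {},
   "source": [
    "Each turn must use a role that the formatter's template defines and provide the slot names that role expects. As an optional pre-flight check, here is a NeMo-independent sketch. The `EXPECTED_SLOTS` mapping and the `check_chat` helper are hypothetical and only mirror the Llama2 roles used in this tutorial; this is an illustration under our own assumptions, not part of NeMo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-chat-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical mapping from role name to the slot names it expects,\n",
    "# mirroring the Llama2 roles used in this tutorial.\n",
    "EXPECTED_SLOTS = {\n",
    "    \"system_and_user\": {\"system\", \"message\"},\n",
    "    \"user\": {\"message\"},\n",
    "    \"assistant\": {\"message\"},\n",
    "}\n",
    "\n",
    "\n",
    "def check_chat(turns, expected=EXPECTED_SLOTS):\n",
    "    \"\"\"Raise ValueError if a turn uses an unknown role or the wrong slot names.\"\"\"\n",
    "    for i, turn in enumerate(turns):\n",
    "        role = turn[\"role\"]\n",
    "        if role not in expected:\n",
    "            raise ValueError(f\"turn {i}: unknown role {role!r}\")\n",
    "        got = set(turn[\"slots\"])\n",
    "        if got != expected[role]:\n",
    "            raise ValueError(f\"turn {i}: expected slots {sorted(expected[role])}, got {sorted(got)}\")\n",
    "\n",
    "\n",
    "check_chat([\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
    "])\n",
    "print(\"chat is well-formed\")"
   ]
  },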
  {
   "cell_type": "markdown",
   "id": "eff61b98-c7be-4345-ac97-15573d1a9533",
   "metadata": {},
   "source": [
    "**Prompt formatter outputs.** Now we apply the prompt formatter to that conversation to obtain four tensors used for training:\n",
    "* `context_ids` encode the whole dialog history up to, but not including, the assistant's last response;\n",
    "* `answer_ids` encode the assistant's last response;\n",
    "* `input_ids` encode the full conversation;\n",
    "* `mask` is a boolean training loss mask that is `True` for every token belonging to any of the assistant's turns.\n",
    "\n",
    "Since raw token IDs are hard to read, we decode them back to text with the tokenizer to display the prompt-formatted example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a10216b3-2bbe-4a2f-8ca8-557c3b9056be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n",
      "\n",
      "context_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
      "\n",
      "answer_ids -- In order to build your own audio amplifier, start with ... \n",
      "\n",
      "mask -- tensor([False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False, False, False, False, False, False,\n",
      "        False, False, False, False, False,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "         True,  True,  True,  True,  True])\n"
     ]
    }
   ],
   "source": [
    "encoded = prompt.encode_dialog(chat)\n",
    "display(encoded, with_mask=True)"
   ]
  },
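  {
   "cell_type": "markdown",
   "id": "toy-outputs-md",
   "metadata": {},
   "source": [
    "To make the relationship between these four outputs concrete, the sketch below reproduces it on made-up data without NeMo. It assumes each turn has already been formatted and tokenized into a list of token IDs (the IDs are invented for illustration); it is a toy model of the outputs, not `PromptFormatter`'s actual implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-outputs-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration, not NeMo's implementation: each turn is (role, token_ids),\n",
    "# where the made-up IDs stand in for already formatted and tokenized turns.\n",
    "turns = [\n",
    "    (\"user\", [1, 2, 3]),\n",
    "    (\"assistant\", [4, 5]),\n",
    "    (\"user\", [6, 7]),\n",
    "    (\"assistant\", [8, 9, 10]),\n",
    "]\n",
    "\n",
    "input_ids = [t for _, ids in turns for t in ids]         # the full conversation\n",
    "context_ids = [t for _, ids in turns[:-1] for t in ids]  # everything before the last assistant turn\n",
    "answer_ids = turns[-1][1]                                # the last assistant turn only\n",
    "# The loss mask is True for tokens of every assistant turn, not only the last one.\n",
    "mask = [role == \"assistant\" for role, ids in turns for _ in ids]\n",
    "\n",
    "assert input_ids == context_ids + answer_ids\n",
    "assert len(mask) == len(input_ids)\n",
    "print(\"input_ids:\", input_ids)\n",
    "print(\"mask:\", mask)"
   ]
  },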
  {
   "cell_type": "markdown",
   "id": "e181618e-6df8-44b2-b986-15660133e486",
   "metadata": {},
   "source": [
    "**System prompt.** System prompts are supported as well. Since a system prompt changes the Llama2 prompt format in a non-trivial way, it is defined as a separate role, `\"system_and_user\"`, which has two slots: `\"system\"` and `\"message\"`. We omit printing the mask for brevity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2c3476a4-b301-4f35-9520-90d4b919363d",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_with_system = [\n",
    "    {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5c8c329d-f8b3-48cb-b664-baed0fcd90ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids -- [INST] <<SYS>> You are a sales rep in an electronics store. <</SYS>> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n",
      "\n",
      "context_ids -- [INST] <<SYS>> You are a sales rep in an electronics store. <</SYS>> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
      "\n",
      "answer_ids -- In order to build your own audio amplifier, start with ... \n",
      "\n"
     ]
    }
   ],
   "source": [
    "encoded = prompt.encode_dialog(chat_with_system)\n",
    "display(encoded)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a453345a-6456-43ed-a663-0554c459fddb",
   "metadata": {},
   "source": [
    "**Constructing inference-time prompts.** During inference, the assistant's last turn is not known yet, so we only want to construct the `context_ids` tensor. In that case, simply omit the last assistant turn. The prompt formatter then returns the `context_ids` tensor (along with `input_ids` as an alias for it)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4ede7100-9d28-4cf0-ab75-bfede9936218",
   "metadata": {},
   "outputs": [],
   "source": [
    "inference_chat = [\n",
    "    {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "61bf8e77-0630-4a84-bd30-ca4c27f8d898",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids -- [INST] <<SYS>> You are a sales rep in an electronics store. <</SYS>> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
      "\n",
      "context_ids -- [INST] <<SYS>> You are a sales rep in an electronics store. <</SYS>> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n",
      "\n"
     ]
    }
   ],
   "source": [
    "encoded = prompt.encode_dialog(inference_chat)\n",
    "display(encoded)"
   ]
  },
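  {
   "cell_type": "markdown",
   "id": "toy-inference-chat-md",
   "metadata": {},
   "source": [
    "If you already have complete training dialogs, you can build inference-time inputs by dropping the trailing assistant turn before calling `encode_dialog`. The `to_inference_chat` function below is a hypothetical helper, not part of NeMo; applied to `chat_with_system`, it would yield the same turns as `inference_chat` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-inference-chat-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_inference_chat(turns, output_role=\"assistant\"):\n",
    "    \"\"\"Hypothetical helper: return the turns without a trailing output-role turn.\"\"\"\n",
    "    if turns and turns[-1][\"role\"] == output_role:\n",
    "        return turns[:-1]\n",
    "    return list(turns)\n",
    "\n",
    "\n",
    "example = [\n",
    "    {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n",
    "    {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n",
    "]\n",
    "print(to_inference_chat(example))"
   ]
  },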
  {
   "cell_type": "markdown",
   "id": "a334e00a-9530-4333-98de-5cb8fb08eb47",
   "metadata": {},
   "source": [
    "### How `Llama2PromptFormatter` is built\n",
    "\n",
    "`Llama2PromptFormatter` is a small class that contains only the prompt definition. It inherits from `PromptFormatter`, which implements the logic for applying the prompt format and tokenization to multi-turn conversations.\n",
    "\n",
    "Let's take a look at the `Llama2PromptFormatter` definition:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f29fbf2f-3caa-4b27-86ca-5012d9fc6ba5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class Llama2PromptFormatter(PromptFormatter):\n",
      "    \"\"\"\n",
      "    This template has been validated to provide identical tokenized results to the official code\n",
      "    in https://github.com/meta-llama/llama/blob/main/llama/generation.py\n",
      "    \"\"\"\n",
      "\n",
      "    NAME = \"llama2\"\n",
      "    OUTPUT_ROLE = \"assistant\"\n",
      "    TEMPLATE = {\n",
      "        \"system_and_user\": {\n",
      "            \"template\": f\"{BOS_SLOT}[INST] <<SYS>>\\n|system|\\n<</SYS>>\\n\\n|message| [/INST]\",\n",
      "            \"slots\": {\n",
      "                \"system\": Modality.Text,\n",
      "                \"message\": Modality.Text,\n",
      "            },\n",
      "        },\n",
      "        \"user\": {\n",
      "            \"template\": f\"{BOS_SLOT}[INST] |message| [/INST]\",\n",
      "            \"slots\": {\n",
      "                \"message\": Modality.Text,\n",
      "            },\n",
      "        },\n",
      "        OUTPUT_ROLE: {\n",
      "            \"template\": f\"|message| {EOS_SLOT}\",\n",
      "            \"slots\": {\n",
      "                \"message\": Modality.Text,\n",
      "            },\n",
      "        },\n",
      "    }\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import inspect\n",
    "print(inspect.getsource(Llama2PromptFormatter))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b24e9310-b8ed-4e35-9dda-d24aa62cfb6a",
   "metadata": {},
   "source": [
    "As you can see, the definition consists of the following key components:\n",
    "* It derives from the `PromptFormatter` parent class.\n",
    "* It specifies `NAME`, which is used to dynamically resolve a string to the class via `cls = PromptFormatter.resolve(name)`.\n",
    "* It specifies `OUTPUT_ROLE`, the name of the role that holds the assistant's responses (typically `\"assistant\"`).\n",
    "* It specifies `TEMPLATE`, which defines the dialog structure and how user-provided values (slots) are inserted into prompts. Notably:\n",
    "  * Slots are wrapped in pipe characters (`\"|\"`) in the template definition and are substituted with user-provided values before tokenization.\n",
    "  * The `\"system_and_user\"` role has two slots, `\"system\"` and `\"message\"`, and a template that wraps them with Llama2 special tokens.\n",
    "  * `BOS_SLOT` and `EOS_SLOT` insert the SentencePiece tokenizer's `bos_id` and `eos_id` in the right places. SentencePiece won't produce these tokens from text, so they have to be inserted programmatically.\n",
    "  * Each slot has a type. Currently supported types are `Modality.Text` and `Modality.TextLiteral(value1, value2, ...)`; the latter restricts a slot to a fixed set of allowed values."
   ]
  },
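  {
   "cell_type": "markdown",
   "id": "toy-fill-template-md",
   "metadata": {},
   "source": [
    "The slot substitution described above can be sketched in plain Python. This is a simplified illustration under our own assumptions, not NeMo's code: `fill_template` is a hypothetical helper, its `allowed` dict stands in for `Modality.TextLiteral`, and the literal string `<s>` stands in for `BOS_SLOT`, which the real formatter turns into the tokenizer's `bos_id` rather than text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-fill-template-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def fill_template(template, slots, allowed=None):\n",
    "    \"\"\"Hypothetical helper: replace every |name| placeholder with slots[name].\"\"\"\n",
    "    allowed = allowed or {}\n",
    "    for name, value in slots.items():\n",
    "        # Emulate a TextLiteral-style restriction: reject values outside the allowed set.\n",
    "        if name in allowed and value not in allowed[name]:\n",
    "            raise ValueError(f\"{value!r} is not an allowed value for slot {name!r}\")\n",
    "    # Substitute in a single pass; a placeholder without a value raises KeyError.\n",
    "    return re.sub(r\"\\|(\\w+)\\|\", lambda m: slots[m.group(1)], template)\n",
    "\n",
    "\n",
    "# \"<s>\" stands in for BOS_SLOT; the real formatter inserts the BOS token ID instead.\n",
    "user_template = \"<s>[INST] |message| [/INST]\"\n",
    "print(fill_template(user_template, {\"message\": \"How to build my own audio amplifier?\"}))"
   ]
  },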
  {
   "cell_type": "markdown",
   "id": "8cbdca6c-6c0f-42a9-a4a7-b936684c6e12",
   "metadata": {},
   "source": [
    "## Defining your own prompt formatter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a9b6d2-d004-4f7f-8b24-4fd6d4eae244",
   "metadata": {},
   "source": [
    "Generally, you can follow the definitions of existing prompt formatters to define your own.\n",
    "We have several prompt formats implemented for Llama, Gemma, Phi, and other model families.\n",
    "\n",
    "As an illustration, we'll define a simple custom prompt format without a system prompt:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b69f6532-24d8-4419-b1da-42184c3d72de",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.common.prompts.formatter import PromptFormatter, Modality\n",
    "\n",
    "class MyPrompt(PromptFormatter):\n",
    "    NAME = \"myprompt\"\n",
    "    OUTPUT_ROLE = \"assistant\"\n",
    "    TEMPLATE = {\n",
    "        \"user\": {\n",
    "            \"template\": \"User: |message|\\n\",\n",
    "            \"slots\": {\"message\": Modality.Text},\n",
    "        },\n",
    "        \"assistant\": {\n",
    "            \"template\": \"Assistant: |message|\\n\",\n",
    "            \"slots\": {\"message\": Modality.Text},\n",
    "        },\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a97c6589-1303-446c-952f-d2b4007ca7e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? Assistant: In order to build your own audio amplifier, start with ... \n",
      "\n",
      "context_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? \n",
      "\n",
      "answer_ids -- Assistant: In order to build your own audio amplifier, start with ... \n",
      "\n"
     ]
    }
   ],
   "source": [
    "my_prompt_cls = PromptFormatter.resolve(\"myprompt\")  # it is auto-registered\n",
    "my_prompt = my_prompt_cls(tokenizer)\n",
    "display(my_prompt.encode_dialog(chat))"
   ]
  },
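  {
   "cell_type": "markdown",
   "id": "toy-registry-md",
   "metadata": {},
   "source": [
    "The `# it is auto-registered` comment means that merely defining a subclass with a `NAME` makes it resolvable by name. One common Python pattern for this kind of automatic registration is `__init_subclass__`, sketched below with a toy base class. `ToyFormatter` and `ToyPrompt` are hypothetical; this illustrates the general pattern and is not a description of how NeMo implements it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "toy-registry-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyFormatter:\n",
    "    \"\"\"Toy base class that registers every subclass under its NAME.\"\"\"\n",
    "\n",
    "    NAME = None\n",
    "    _registry = {}\n",
    "\n",
    "    def __init_subclass__(cls, **kwargs):\n",
    "        # Runs once for each new subclass, right after its class body is executed.\n",
    "        super().__init_subclass__(**kwargs)\n",
    "        if cls.NAME is None:\n",
    "            raise TypeError(f\"{cls.__name__} must define NAME\")\n",
    "        ToyFormatter._registry[cls.NAME] = cls\n",
    "\n",
    "    @classmethod\n",
    "    def resolve(cls, name):\n",
    "        if name not in cls._registry:\n",
    "            raise KeyError(f\"unknown formatter {name!r}; known: {sorted(cls._registry)}\")\n",
    "        return cls._registry[name]\n",
    "\n",
    "\n",
    "class ToyPrompt(ToyFormatter):\n",
    "    NAME = \"toyprompt\"\n",
    "\n",
    "\n",
    "print(ToyFormatter.resolve(\"toyprompt\"))"
   ]
  },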
  {
   "cell_type": "markdown",
   "id": "30f9c96a-6cf8-4cd3-b0e8-6b461c86100f",
   "metadata": {},
   "source": [
    "## Applying prompt formatters to multimodal data\n",
    "\n",
    "We refer the reader to our other tutorial, [Multimodal Lhotse Dataloading](./Multimodal%20Lhotse%20Dataloading.ipynb), where this is discussed in detail."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}