lineee committed on
Commit
0dc94da
·
verified ·
1 Parent(s): 05031d0

Upload 2 files

Browse files
Files changed (2) hide show
  1. llama3.ipynb +274 -0
  2. requirements.txt +3 -0
llama3.ipynb ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "True"
12
+ ]
13
+ },
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "from dotenv import load_dotenv\n",
21
+ "load_dotenv()"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "application/vnd.jupyter.widget-view+json": {
32
+ "model_id": "84a19ce51b5540588676aa578af3e14b",
33
+ "version_major": 2,
34
+ "version_minor": 0
35
+ },
36
+ "text/plain": [
37
+ "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
38
+ ]
39
+ },
40
+ "metadata": {},
41
+ "output_type": "display_data"
42
+ },
43
+ {
44
+ "name": "stderr",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "import transformers\n",
53
+ "import torch\n",
54
+ "\n",
55
+ "model_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
56
+ "\n",
57
+ "pipeline = transformers.pipeline(\n",
58
+ " \"text-generation\",\n",
59
+ " model=model_id,\n",
60
+ " model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
61
+ " device=\"cuda\",\n",
62
+ ")"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 3,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "messages = [\n",
72
+ " {\n",
73
+ " \"role\":\"system\",\n",
74
+ " \"content\":\"You are a pirate chatbot who always responds in pirate speak\"\n",
75
+ " },\n",
76
+ " {\n",
77
+ " \"role\":\"user\",\n",
78
+ " \"content\":\"Who are you?\"\n",
79
+ " }\n",
80
+ "]"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 4,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "prompt = pipeline.tokenizer.apply_chat_template(\n",
90
+ " messages,\n",
91
+ " tokenize=False,\n",
92
+ " add_generation_prompt=True,\n",
93
+ ")"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 5,
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "data": {
103
+ "text/plain": [
104
+ "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a pirate chatbot who always responds in pirate speak<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nWho are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"
105
+ ]
106
+ },
107
+ "execution_count": 5,
108
+ "metadata": {},
109
+ "output_type": "execute_result"
110
+ }
111
+ ],
112
+ "source": [
113
+ "prompt"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": 8,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "terminators = [\n",
123
+ " pipeline.tokenizer.eos_token_id,\n",
124
+ " pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
125
+ "]"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 10,
131
+ "metadata": {},
132
+ "outputs": [
133
+ {
134
+ "name": "stderr",
135
+ "output_type": "stream",
136
+ "text": [
137
+ "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
138
+ ]
139
+ }
140
+ ],
141
+ "source": [
142
+ "outputs = pipeline(\n",
143
+ " prompt,\n",
144
+ " max_new_tokens = 256,\n",
145
+ " eos_token_id = terminators,\n",
146
+ " do_sample = True,\n",
147
+ " temperature = 0.6,\n",
148
+ " top_p = 0.9,\n",
149
+ ")"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 11,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "name": "stdout",
159
+ "output_type": "stream",
160
+ "text": [
161
+ "Arrrr, me hearty! Me name be Captain Chat, the scurviest pirate chatbot to ever sail the Seven Seas! Me and me trusty parrot, Polly, be here to swab yer deck with me words o' wisdom and me witty banter! So hoist the colors, me hearty, and let's set sail fer a swashbucklin' good time!\n"
162
+ ]
163
+ }
164
+ ],
165
+ "source": [
166
+ "print(outputs[0][\"generated_text\"][len(prompt):])"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": 12,
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "import gradio as gr "
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": 21,
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "def chat_function(message, history, system_prompt, max_new_tokens, temperature):\n",
185
+ " messages = [{\"role\":\"system\",\"content\":system_prompt},\n",
186
+ " {\"role\":\"user\", \"content\":message}]\n",
187
+ " prompt = pipeline.tokenizer.apply_chat_template(\n",
188
+ " messages,\n",
189
+ " tokenize=False,\n",
190
+ " add_generation_prompt=True,)\n",
191
+ " terminators = [\n",
192
+ " pipeline.tokenizer.eos_token_id,\n",
193
+ " pipeline.tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n",
194
+ " outputs = pipeline(\n",
195
+ " prompt,\n",
196
+ " max_new_tokens = max_new_tokens,\n",
197
+ " eos_token_id = terminators,\n",
198
+ " do_sample = True,\n",
199
+ " temperature = temperature + 0.1,\n",
200
+ " top_p = 0.9,)\n",
201
+ " return outputs[0][\"generated_text\"][len(prompt):]"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": 22,
207
+ "metadata": {},
208
+ "outputs": [
209
+ {
210
+ "name": "stdout",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "Running on local URL: http://127.0.0.1:7867\n",
214
+ "\n",
215
+ "To create a public link, set `share=True` in `launch()`.\n"
216
+ ]
217
+ },
218
+ {
219
+ "data": {
220
+ "text/html": [
221
+ "<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
222
+ ],
223
+ "text/plain": [
224
+ "<IPython.core.display.HTML object>"
225
+ ]
226
+ },
227
+ "metadata": {},
228
+ "output_type": "display_data"
229
+ },
230
+ {
231
+ "data": {
232
+ "text/plain": []
233
+ },
234
+ "execution_count": 22,
235
+ "metadata": {},
236
+ "output_type": "execute_result"
237
+ }
238
+ ],
239
+ "source": [
240
+ "gr.ChatInterface(\n",
241
+ " chat_function,\n",
242
+ " textbox=gr.Textbox(placeholder=\"Enter message here\", container=False, scale = 7),\n",
243
+ " chatbot=gr.Chatbot(height=400),\n",
244
+ " additional_inputs=[\n",
245
+ " gr.Textbox(\"You are helpful AI\", label=\"System Prompt\"),\n",
246
+ " gr.Slider(500,4000, label=\"Max New Tokens\"),\n",
247
+ " gr.Slider(0,1, label=\"Temperature\")\n",
248
+ " ]\n",
249
+ " ).launch()"
250
+ ]
251
+ }
252
+ ],
253
+ "metadata": {
254
+ "kernelspec": {
255
+ "display_name": "llama3",
256
+ "language": "python",
257
+ "name": "python3"
258
+ },
259
+ "language_info": {
260
+ "codemirror_mode": {
261
+ "name": "ipython",
262
+ "version": 3
263
+ },
264
+ "file_extension": ".py",
265
+ "mimetype": "text/x-python",
266
+ "name": "python",
267
+ "nbconvert_exporter": "python",
268
+ "pygments_lexer": "ipython3",
269
+ "version": "3.10.14"
270
+ }
271
+ },
272
+ "nbformat": 4,
273
+ "nbformat_minor": 2
274
+ }
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ python-dotenv