PEFT
nisten commited on
Commit
f63c8cd
·
1 Parent(s): 947d408

added inference UI code, testing the model halfway through training

Files changed (1) hide show
  1. README.md +321 -1
README.md CHANGED
@@ -1,7 +1,8 @@
1
  ---
2
  library_name: peft
 
3
  ---
4
- ## Training procedure
5
 
6
 
7
  The following `bitsandbytes` quantization config was used during training:
@@ -19,3 +20,322 @@ The following `bitsandbytes` quantization config was used during training:
19
 
20
 
21
  - PEFT 0.6.0.dev0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  library_name: peft
3
+ license: mit
4
  ---
5
+ ## Training procedure (only 2000/5000 complete)
6
 
7
 
8
  The following `bitsandbytes` quantization config was used during training:
 
20
 
21
 
22
  - PEFT 0.6.0.dev0
23
+
24
+ - To load the model, start a Jupyter notebook and run the code below; it is split into 2 parts
25
+
26
+ ```
27
+ !pip install -q -U bitsandbytes
28
+ !pip install -q -U git+https://github.com/huggingface/transformers.git
29
+ !pip install -q -U git+https://github.com/huggingface/peft.git
30
+ !pip install -q -U git+https://github.com/huggingface/accelerate.git
31
+ !pip install -q -U gradio
32
+ !pip install -q -U sentencepiece
33
+
34
+
35
+
36
+ import torch
37
+ from peft import PeftModel
38
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
39
+
40
# Base checkpoint and the PEFT adapter that is applied on top of it.
model_name = "TheBloke/CodeLlama-34B-Python-fp16"
adapters_name = 'nisten/bigdoc-c34b-python-v1'

# Llama-family end-of-sequence token. The previous literal '/s' is not this
# token, so looking it up produced the wrong id for EOS/pad/stop handling.
eos_token = "</s>"

print(f"Starting to load the model {model_name} into memory")

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,  # 19GB in 4bit, 38GB with load_in_8bit, 67GB in full f16 if you just delete this line
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
# Attach the adapter, then fold its weights into the base model.
m = PeftModel.from_pretrained(m, adapters_name)
m = m.merge_and_unload()

tok = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tok.convert_tokens_to_ids(eos_token)
tok.eos_token = eos_token
tok.pad_token = tok.eos_token  # reuse EOS as the padding token
tok.padding_side = 'right'
tok.eos_token_id = eos_token_id
# Read by StopOnTokens in the UI cell to decide when generation should end.
stop_token_ids = eos_token_id

print(f"Successfully loaded the model {model_name} into memory")
62
+
63
+ ```
64
+
65
+ ### And now for the UI
66
+
67
+ ```
68
+
69
+ # Setup the gradio Demo.
70
+
71
+ import datetime
72
+ import os
73
+ from threading import Event, Thread
74
+ from uuid import uuid4
75
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
76
+ import gradio as gr
77
+ import requests
78
+
79
# Value passed as `max_new_tokens` to `m.generate` for every reply.
max_new_tokens = 2369

# Text placed in front of the conversation turns when the prompt is built.
start_message = (
    "A chat between a chill human asking ( Question: ) and an AI doctor ( Answer: ). "
    "The doctor answers in helpful, detailed, and exhaustively nerdy extensive answers "
    "to the user's every medical Question:"
)
81
+
82
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Reads the module-level ``stop_token_ids``, which may be a single id, a
    list/tuple/set of ids, or a tensor of ids.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recent token of the first sequence is a stop id."""
        # Normalise every accepted form of `stop_token_ids` to an iterable of ids.
        if isinstance(stop_token_ids, torch.Tensor):
            candidates = stop_token_ids.flatten().tolist()
        elif isinstance(stop_token_ids, (list, tuple, set)):
            candidates = stop_token_ids
        else:  # Assumes scalar
            candidates = [stop_token_ids]

        # Only the newest token is inspected. Scanning the whole sequence (as
        # the list branch used to) would also match stop ids inside the prompt
        # and halt generation before anything is produced.
        last_token = int(input_ids[0][-1])
        return any(last_token == int(stop_id) for stop_id in candidates)
92
+
93
+
94
def convert_history_to_text(history):
    """Render the chat history as the prompt string fed to the model.

    Args:
        history: Sequence of turns, each indexable as ``turn[0]`` (the
            question) and ``turn[1]`` (the answer, possibly empty for the
            turn that is still being generated).

    Returns:
        ``start_message`` followed by every turn formatted as
        ``" Question: <q>\\n\\n\\n Answer: <a>\\n"``. An empty history yields
        just ``start_message`` instead of raising ``IndexError``.
    """
    # Every turn, including the last one, uses the same template, so a single
    # pass over the history replaces the former "all but last" + "last" split.
    turns = [
        f" Question: {turn[0]}\n" f"\n\n Answer: {turn[1]}\n"
        for turn in history
    ]
    return start_message + "".join(turns)
117
+
118
+
119
def log_conversation(conversation_id, history, messages, generate_kwargs, timeout=10.0):
    """POST one conversation record to the URL in the ``LOGGING_URL`` env var.

    Does nothing when ``LOGGING_URL`` is unset. Request failures are printed
    rather than raised, so logging never interrupts the chat.

    Args:
        conversation_id: Identifier stored under ``"conversation_id"``.
        history: Chat history stored under ``"history"``.
        messages: Prompt text stored under ``"messages"``.
        generate_kwargs: Sampling settings stored under ``"generate_kwargs"``.
        timeout: Seconds to wait for the logging endpoint before giving up.
            Without a timeout a stalled endpoint would block the calling
            thread indefinitely.
    """
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        requests.post(logging_url, json=data, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")
138
+
139
+
140
def user(message, history):
    """Queue the user's message as a new, still-unanswered turn.

    Returns an empty string (the first output value) and a new history list
    with ``[message, ""]`` appended; the list passed in is left unmodified.
    """
    pending_turn = [message, ""]
    extended_history = history + [pending_turn]
    return "", extended_history
143
+
144
+
145
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    """Generate the reply for the last turn in ``history`` and stream it back.

    Generation runs on a worker thread; a second thread waits for it to end
    and then calls ``log_conversation``. This generator yields ``history``
    after each streamed chunk, with the last turn's answer set to the text
    accumulated so far.

    Args:
        history: Chat turns; ``history[-1][1]`` is overwritten as text arrives.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        top_p: Nucleus-sampling value forwarded to ``m.generate``.
        top_k: Top-k value forwarded to ``m.generate``.
        repetition_penalty: Repetition penalty forwarded to ``m.generate``.
        conversation_id: Identifier passed through to ``log_conversation``.

    Yields:
        The updated ``history`` list after each new piece of text.
    """
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Construct the input message string for the model by concatenating the
    # current system message and conversation history
    messages = convert_history_to_text(history)

    # Tokenize the messages string and move it to the model's device
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        # Always signal completion, even if generation raises; otherwise the
        # logging thread below would block on `stream_complete.wait()` forever.
        try:
            m.generate(**generate_kwargs)
        finally:
            stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Accumulate the streamed text and publish it into the last turn
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
201
+
202
+
203
def get_uuid():
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh_id = uuid4()
    return str(fresh_id)
205
+
206
+
207
# Build the chat UI: a chatbot pane, a message box with Submit/Stop/Clear
# buttons, sampling sliders, and the event wiring that connects them to
# `user` and `bot`.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """<h1><center>Nisten's 34b Doctor v1</center></h1>
"""
    )
    # Layout options are passed to the constructors: the `.style(...)` helper
    # is deprecated and no longer exists in Gradio 4.
    chatbot = gr.Chatbot(height=969)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
                container=False,
            )
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.7,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.9,
                            minimum=0.0,
                            maximum=1,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition — 1.0 to disable.",
                        )
    with gr.Row():
        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    with gr.Row():
        gr.Markdown(
            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
            elem_classes=["disclaimer"],
        )

    # Inputs handed to `bot`, in its positional parameter order. Shared by the
    # Enter-key and Submit-button event chains below.
    bot_inputs = [
        chatbot,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        conversation_id,
    ]

    # Pressing Enter: record the user's turn immediately, then stream the reply.
    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )
    # Clicking Submit: same chain as pressing Enter.
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )
    # Stop cancels whichever of the two chains is currently running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear resets the chatbot pane.
    clear.click(lambda: None, None, chatbot, queue=False)

# NOTE(review): Gradio 4+ rejects `concurrency_count` and uses
# `default_concurrency_limit` instead; Gradio 3.x does not know the new
# keyword and raises TypeError, in which case the original call is used.
try:
    demo.queue(max_size=128, default_concurrency_limit=2)
except TypeError:
    demo.queue(max_size=128, concurrency_count=2)
demo.launch(share=True)  # remove share=True to keep the demo private
340
+
341
+ ```