Upload handler.py
Browse files- handler.py +29 -5
handler.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
-
from typing import Dict, Any
|
| 2 |
import logging
|
| 3 |
-
|
| 4 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 5 |
from peft import PeftConfig, PeftModel
|
| 6 |
import torch.cuda
|
| 7 |
|
|
@@ -105,5 +104,30 @@ def generate(
|
|
| 105 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
| 106 |
input_ids = input_ids.to(model.device)
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 4 |
from peft import PeftConfig, PeftModel
|
| 5 |
import torch.cuda
|
| 6 |
|
|
|
|
| 104 |
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
| 105 |
input_ids = input_ids.to(model.device)
|
| 106 |
|
| 107 |
+
# Create a TextIteratorStreamer instance
|
| 108 |
+
streamer = TextIteratorStreamer(
|
| 109 |
+
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Generate the response using TextIteratorStreamer
|
| 113 |
+
generate_kwargs = dict(
|
| 114 |
+
{"input_ids": input_ids},
|
| 115 |
+
streamer=streamer,
|
| 116 |
+
max_new_tokens=max_new_tokens,
|
| 117 |
+
do_sample=True,
|
| 118 |
+
top_p=top_p,
|
| 119 |
+
top_k=top_k,
|
| 120 |
+
temperature=temperature,
|
| 121 |
+
num_beams=1,
|
| 122 |
+
repetition_penalty=repetition_penalty,
|
| 123 |
+
)
|
| 124 |
+
model.generate(**generate_kwargs)
|
| 125 |
+
|
| 126 |
+
outputs = []
|
| 127 |
+
for text in streamer:
|
| 128 |
+
outputs.append(text)
|
| 129 |
+
if "[/INST]" in "".join(outputs):
|
| 130 |
+
return "".join(outputs).replace("[/INST]","")
|
| 131 |
+
if "[INST]" in "".join(outputs):
|
| 132 |
+
return "".join(outputs).replace("[INST]","")
|
| 133 |
+
return "".join(outputs)
|