codys12 committed on
Commit
75fa31b
·
1 Parent(s): a4bbf5f

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -5
handler.py CHANGED
@@ -1,7 +1,6 @@
1
- from typing import Dict, Any
2
  import logging
3
-
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from peft import PeftConfig, PeftModel
6
  import torch.cuda
7
 
@@ -105,5 +104,30 @@ def generate(
105
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
106
  input_ids = input_ids.to(model.device)
107
 
108
- # Generate the response
109
- return tokenizer.decode(model.generate(input_ids, max_new_tokens=max_new_tokens))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import Dict, Any
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from peft import PeftConfig, PeftModel
5
  import torch.cuda
6
 
 
104
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
105
  input_ids = input_ids.to(model.device)
106
 
107
+ # Create a TextIteratorStreamer instance
108
+ streamer = TextIteratorStreamer(
109
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False
110
+ )
111
+
112
+ # Generate the response using TextIteratorStreamer
113
+ generate_kwargs = dict(
114
+ {"input_ids": input_ids},
115
+ streamer=streamer,
116
+ max_new_tokens=max_new_tokens,
117
+ do_sample=True,
118
+ top_p=top_p,
119
+ top_k=top_k,
120
+ temperature=temperature,
121
+ num_beams=1,
122
+ repetition_penalty=repetition_penalty,
123
+ )
124
+ model.generate(**generate_kwargs)
125
+
126
+ outputs = []
127
+ for text in streamer:
128
+ outputs.append(text)
129
+ if "[/INST]" in "".join(outputs):
130
+ return "".join(outputs).replace("[/INST]","")
131
+ if "[INST]" in "".join(outputs):
132
+ return "".join(outputs).replace("[INST]","")
133
+ return "".join(outputs)