File size: 2,713 Bytes
2807ff7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
from typing import Optional, Union
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
# Checkpoint to load; the MODEL_PATH environment variable overrides the default.
MODEL_PATH = os.getenv("MODEL_PATH", "THUDM/chatglm3-6b")
# Tokenizer location; falls back to MODEL_PATH when TOKENIZER_PATH is unset.
TOKENIZER_PATH = os.getenv("TOKENIZER_PATH", MODEL_PATH)

# Both objects are created once at import time and shared by the functions below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
)
# Switch to evaluation mode; this script only runs generation.
model = model.eval()
def batch(
    model,
    tokenizer,
    prompts: Union[str, list[str]],
    max_length: int = 8192,
    num_beams: int = 1,
    do_sample: bool = True,
    top_p: float = 0.8,
    temperature: float = 0.8,
    logits_processor: Optional["LogitsProcessorList"] = None,
) -> list[str]:
    """Generate one completion per prompt in a single padded batch.

    Args:
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer exposing ``eos_token_id``,
            ``get_command`` and ``decode``. Its ``encode_special_tokens``
            attribute is set to ``True`` as a side effect.
        prompts: A single prompt or a list of prompts.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        logits_processor: Extra logits processors forwarded to
            ``model.generate``. ``None`` (the default) means a fresh, empty
            ``LogitsProcessorList`` is created for this call.

    Returns:
        The decoded, whitespace-stripped continuation for each prompt, in
        input order. An empty prompt list yields an empty list.
    """
    tokenizer.encode_special_tokens = True
    if isinstance(prompts, str):
        prompts = [prompts]
    if not prompts:
        # Nothing to generate; skip tokenization and the model call entirely.
        return []
    if logits_processor is None:
        # Built per call so state is never shared between invocations
        # (a default instance in the signature would be created only once).
        logits_processor = LogitsProcessorList()

    batched_inputs = tokenizer(prompts, return_tensors="pt", padding="longest")
    batched_inputs = batched_inputs.to(model.device)

    # Stop at the regular EOS token or when a new chat turn begins.
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|assistant|>"),
    ]
    gen_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": eos_token_id,
    }
    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)

    # Drop the leading prompt-length slice of every output row so that only
    # the continuation is decoded.
    # NOTE(review): assumes each generated row starts with its padded prompt.
    return [
        tokenizer.decode(output_ids[len(input_ids):]).strip()
        for input_ids, output_ids in zip(batched_inputs.input_ids, batched_outputs)
    ]
def main(batch_queries):
    """Run ``batch`` on the module-level model and tokenizer.

    Uses fixed generation settings: at most 2048 tokens of total length,
    sampling enabled with ``top_p`` and ``temperature`` of 0.8, and a
    single beam.

    Args:
        batch_queries: Prompt or list of prompts forwarded to ``batch``.

    Returns:
        Whatever ``batch`` returns for these prompts.
    """
    return batch(
        model,
        tokenizer,
        batch_queries,
        max_length=2048,
        do_sample=True,
        top_p=0.8,
        temperature=0.8,
        num_beams=1,
    )
if __name__ == "__main__":
    # Demo prompts in the chat format: a user turn followed by an open
    # assistant turn for the model to complete.
    demo_prompts = [
        "<|user|>\n讲个故事\n<|assistant|>",
        "<|user|>\n讲个爱情故事\n<|assistant|>",
        "<|user|>\n讲个开心故事\n<|assistant|>",
        "<|user|>\n讲个睡前故事\n<|assistant|>",
        "<|user|>\n讲个励志的故事\n<|assistant|>",
        "<|user|>\n讲个少壮不努力的故事\n<|assistant|>",
        "<|user|>\n讲个青春校园恋爱故事\n<|assistant|>",
        "<|user|>\n讲个工作故事\n<|assistant|>",
        "<|user|>\n讲个旅游的故事\n<|assistant|>",
    ]
    # Print each completion under a ten-character separator line.
    for reply in main(demo_prompts):
        print("=" * 10, reply, sep="\n")
|