| import transformers |
| from transformers import TextStreamer |
| import torch |
| from transformers.generation.streamers import BaseStreamer |
| from datasets import load_dataset |
| import random |
|
|
class TokenStreamer(BaseStreamer):
    """
    Minimal streamer that prints every generated token on its own line.

    Each token id is decoded on its own and printed as ``=<repr(text)>``, so
    whitespace, newlines and special characters stay visible. This is meant
    for inspecting how generated text is tokenized, not for readable output.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the first `put` call after construction or after
            `end()`. transformers' `generate` sends the prompt ids in that
            first call, so this hides the prompt. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        # True until the first `put` of a generation has been consumed;
        # `end()` resets it so the streamer can be reused for the next call.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Decode and print each token id in `value`.

        Parameters:
            value: A tensor of token ids. A 2-D input must have exactly one
                row; that row is used.

        Raises:
            ValueError: If `value` is 2-D with more than one row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            # Swallow the prompt chunk exactly once per generation.
            self.next_tokens_are_prompt = False
            return

        # reshape(-1) so a 0-d tensor (a single scalar id) still iterates;
        # `tolist()` on a 0-d tensor would return a bare int.
        for token_id in value.reshape(-1).tolist():
            token_text = self.tokenizer.decode([token_id])
            print(f"={token_text!r}", flush=True)

    def end(self):
        """Finish a generation: re-arm prompt skipping and print a blank line."""
        self.next_tokens_are_prompt = True
        print()
|
|
# Local checkpoint directory. trust_remote_code=True below executes Python
# code shipped with this checkpoint, so it must come from a trusted source.
model_id = "../"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Fall back to EOS as the pad token only when the tokenizer has none of its
# own; overwriting an existing pad token would change the tokenizer config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# The model is already loaded in bfloat16 with its device placement. The
# pipeline only wraps it; load-time options such as `model_kwargs` or
# `device_map` would be ignored here for a pre-instantiated model.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
)
|
|
|
|
dataset = load_dataset("EdinburghNLP/xsum", split="test")

# How many test documents to evaluate, and the seed that fixes which ones.
NUM_SAMPLES = 100
SAMPLE_SEED = 0

# Use a dedicated RNG instead of reseeding the global `random` module.
# Random(0) draws exactly the same sample that random.seed(0) followed by
# random.sample(...) did, without side effects on other users of `random`.
# The sample size is capped so a smaller split does not raise ValueError.
rng = random.Random(SAMPLE_SEED)
sampled_indices = rng.sample(range(len(dataset)), min(NUM_SAMPLES, len(dataset)))
sampled_dataset = dataset.select(sampled_indices)
|
|
|
|
# One streamer instance is reused: its end() hook re-arms prompt skipping
# after every generation.
streamer = TokenStreamer(tokenizer)

for sample in sampled_dataset:
    document = sample["document"]
    # Ask the model to reproduce the article verbatim inside a chat turn.
    prompt = (
        "Please copy this paragraph: <paragraph>"
        + document
        + "</paragraph> Directly output the copied paragraph here: "
    )
    messages = [dict(role="user", content=prompt)]

    # "===" lines bracket the per-token output the streamer prints.
    print("===")
    outputs = pipeline(
        messages,
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.6,
        top_p=1.0,
    )
    print("===")