# AHA-OLMO2 — generation.py
# Example generation script shipped with the model repository
# (uploaded via huggingface_hub, revision 36e3d1b).
from transformers import TextStreamer
import torch
from transformers.generation.streamers import BaseStreamer
import transformers
class TokenStreamer(BaseStreamer):
    """
    Minimal streamer that prints the decoded text of every generated token,
    one token per line, wrapped in ``repr`` so whitespace/newlines are visible.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode token ids into text.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the prompt tokens in the output. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        # The first tensor handed to `put` contains the prompt tokens.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Decode each received token id and print it on its own line."""
        shape = value.shape
        if len(shape) > 1:
            # Only unbatched (or batch-of-one) generation is supported.
            if shape[0] > 1:
                raise ValueError("TokenStreamer only supports batch size 1")
            value = value[0]
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        for tid in value.tolist():
            decoded = self.tokenizer.decode([tid])
            print(f"={repr(decoded)}", end="\n", flush=True)

    def end(self):
        """Reset the prompt flag and print a trailing newline."""
        self.next_tokens_are_prompt = True
        print()  # Print a newline at the end
# Model path: this script lives alongside the model files in the repo root.
model_id = "./"

# Load tokenizer and model once. `trust_remote_code=True` is required because
# the checkpoint ships custom modeling code.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# NOTE: the model is already instantiated and device-mapped above, so the
# pipeline must NOT be given `model_kwargs` / `device_map` again — Transformers
# ignores them (with a warning) when a model object is passed instead of a name.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
)

messages = [
{"role": "user", "content": \
"""
Please continue writing the sequence:
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100
"""},
]

# Stream each generated token to stdout as it is produced.
streamer = TokenStreamer(tokenizer)
outputs = pipeline(
    messages,
    max_new_tokens=300,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    streamer=streamer,
)