# simplified_phi2/streaming_inference.py
# Last commit: "Removed output check" (7da2fc9) by BucketOfFish
import json
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from .configuration_phi import PhiConfig
from .modeling_phi import PhiForCausalLM
# Attempt 1 (kept for reference): non-streaming generation. This works, but it
# blocks until generate() returns the full output instead of yielding tokens
# as they are produced.
"""
if __name__ == "__main__":
device = "cuda"
model_config = PhiConfig(**json.load(open("simplified_phi2/config.json")))
model = PhiForCausalLM(model_config).to(device)
phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model.load_state_dict(phi_model.state_dict())
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
text = "Write an essay on sea monkeys: "
tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
outputs = model.generate(**tokens, max_length=200)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(text)
"""
# Attempt 2 (kept for reference): streaming via the hosted InferenceClient.
# Does not work for this model because the hosted endpoint cannot be started
# with trust_remote_code=True.
"""
if __name__ == "__main__":
client = InferenceClient(model="microsoft/phi-2")
text = "How do you make cheese?"
for token in client.text_generation(text, max_new_tokens=500, stream=True):
print(token, end="")
"""
# Attempt 3 (live): stream tokens locally with TextIteratorStreamer.
# generate() runs on a worker thread and pushes decoded text into the
# streamer; the main thread consumes and prints tokens as they arrive.
if __name__ == "__main__":
    # The tokenizer both encodes the prompt and decodes streamed token ids
    # inside TextIteratorStreamer.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)

    # Build the local simplified model and copy the reference phi-2 weights
    # into it (architectures must match for load_state_dict to succeed).
    device = "cuda"
    with open("simplified_phi2/config.json") as config_file:  # close the handle promptly
        model_config = PhiConfig(**json.load(config_file))
    model = PhiForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model.load_state_dict(phi_model.state_dict())

    # Encode the prompt once, then launch generation on a worker thread so
    # the main thread is free to iterate the streamer concurrently.
    tokens = tokenizer(  # returns a BatchEncoding (dict of torch tensors)
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **tokens,  # input_ids etc., unpacked explicitly rather than via a positional mapping
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    # Consume tokens as they stream in; the iterator ends when generate()
    # finishes. Accumulate the full text in case it is needed later.
    my_output = ""
    for new_token in token_streamer:
        my_output += new_token
        print(new_token, end="", flush=True)
    print()
    thread.join()  # ensure the generation thread has fully finished before exiting