"""Stream text generation from a locally reconstructed Phi-2 model.

Rebuilds the simplified Phi-2 model from a local config, copies in the
pretrained weights from ``microsoft/phi-2``, then streams generated tokens
to stdout using ``TextIteratorStreamer`` with ``generate`` running on a
worker thread.
"""

import json
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from .configuration_phi import PhiConfig
from .modeling_phi import PhiForCausalLM

# Attempt 1: this works, but is not streaming.
"""
if __name__ == "__main__":
    device = "cuda"
    model_config = PhiConfig(**json.load(open("simplified_phi2/config.json")))
    model = PhiForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model.load_state_dict(phi_model.state_dict())
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    text = "Write an essay on sea monkeys: "
    tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False).to(device)
    outputs = model.generate(**tokens, max_length=200)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(text)
"""

# Attempt 2: this is streaming, but does not work because you can't set
# trust_remote_code=True on the InferenceClient.
"""
if __name__ == "__main__":
    client = InferenceClient(model="microsoft/phi-2")
    text = "How do you make cheese?"
    for token in client.text_generation(text, max_new_tokens=500, stream=True):
        print(token, end="")
"""

# Attempt 3 (active): stream locally with the TextIteratorStreamer class.
if __name__ == "__main__":
    # Make and load the tokenizer; it doubles as the streamer's decoder.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    token_streamer = TextIteratorStreamer(tokenizer)

    # Build the local model from the simplified config and copy the
    # pretrained Phi-2 weights into it.
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU
    # Use a context manager so the config file handle is closed (was leaked before).
    with open("simplified_phi2/config.json") as config_file:
        model_config = PhiConfig(**json.load(config_file))
    model = PhiForCausalLM(model_config).to(device)
    phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model.load_state_dict(phi_model.state_dict())

    # generate() blocks until completion, so it runs on a worker thread while
    # the main thread consumes tokens from the streamer as they appear.
    prompt_tokens = tokenizer(  # returns a BatchEncoding (a mapping of tensors)
        "Here is an essay on sea monkeys: ",
        return_tensors="pt",
        return_attention_mask=False,
    ).to(device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            prompt_tokens,  # mapping expands to input_ids (and friends) kwargs
            streamer=token_streamer,
            max_new_tokens=500,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    thread.start()

    # Drain the streamer; iteration ends when generation finishes.
    pieces = []
    for new_token in token_streamer:
        pieces.append(new_token)
        print(new_token, end="", flush=True)
    my_output = "".join(pieces)  # join once instead of quadratic += in the loop
    print()
    thread.join()  # ensure the generation thread has fully finished before exit