# BartPhi-2.8 / Model
# Author: TuringsSolutions
# Commit message: "Create Model"
# Commit: 6e73fc0
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- Adjustable generation hyperparameters ---
# Sampling temperature: lower values give more conservative, less random text.
# (transformers' generate() only applies it when sampling is enabled.)
temperature: float = 0.7
# Top-k filtering: restrict each sampling step to the k most probable tokens.
top_k: int = 50
# Values above 1.0 make the model less likely to repeat tokens it already emitted.
repetition_penalty: float = 1.2
# --- Load models ---
# Run on the GPU when one is available, otherwise fall back to the CPU
# instead of failing at load time with a hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

# Phi-1.5: the model that "cleans up" text; its tokenizer comes from the
# same checkpoint.
phi_model_name = "microsoft/phi-1_5"
tokenizer_name = phi_model_name
phi_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name).to(device)

# TinyStories-33M: the small "assistant" model that drafts and refines text.
assistant_model_name = "roneneldan/TinyStories-33M"
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_model_name).to(device)
# --- Generation pipeline ---

# Tokenizers resolved for assistant models, keyed by checkpoint name, so the
# lookup happens once per checkpoint rather than on every call.
_assistant_tokenizer_cache = {}


def _resolve_assistant_tokenizer(assistant_model):
    """Return the tokenizer saved with *assistant_model*'s checkpoint.

    Falls back to the module-level ``phi_tokenizer`` (the previous behaviour)
    when the checkpoint does not ship a loadable tokenizer.
    """
    name = assistant_model.name_or_path
    if name not in _assistant_tokenizer_cache:
        try:
            _assistant_tokenizer_cache[name] = AutoTokenizer.from_pretrained(name)
        except (OSError, ValueError):
            _assistant_tokenizer_cache[name] = phi_tokenizer
    return _assistant_tokenizer_cache[name]


def _generate_text(model, tokenizer, prompt, max_length, temperature, top_k,
                   repetition_penalty, drop_attention_mask=False):
    """Sample a continuation of *prompt* with *model*; return prompt + continuation as text."""
    # Put the inputs on whatever device the model lives on instead of assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    if drop_attention_mask:
        # The Phi stages have always been run without an attention mask; keep
        # that. Harmless here because each call encodes one unpadded sequence.
        inputs.pop("attention_mask", None)
    output_ids = model.generate(
        **inputs,
        # NOTE(review): max_length counts the prompt tokens too, so a stage
        # whose prompt is already longer than this adds (almost) nothing --
        # e.g. the 100-token refine step after a 500-token Phi pass.
        # max_new_tokens would avoid that but changes output lengths, so the
        # original budgets are kept.
        max_length=max_length,
        # Without do_sample=True, generate() decodes greedily and silently
        # ignores temperature and top_k.
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        # A single unpadded sequence never needs padding; passing EOS as the
        # pad id presumably matches generate()'s own fallback when no pad
        # token is configured, minus the per-call warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate_response(
    user_input,
    assistant_model,
    phi_model,
    temperature=temperature,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    assistant_tokenizer=None,
):
    """Run the four-stage assistant/Phi pipeline on *user_input*.

    Stages:
      1. The assistant drafts a story (max_length 25).
      2. Phi extends it (max_length 500).
      3. The assistant extends Phi's text (max_length 100).
      4. Phi extends that once more (max_length 500).
    Each stage's full decoded output, prompt included, becomes the next
    stage's prompt.

    Args:
        user_input: Prompt text typed by the user.
        assistant_model: Causal LM used for the drafting/refining stages.
        phi_model: Causal LM used for the clean-up stages; its text is
            tokenized with the module-level ``phi_tokenizer``.
        temperature, top_k, repetition_penalty: Sampling settings passed to
            every ``generate()`` call. The defaults are the module-level
            values captured when this function was defined.
        assistant_tokenizer: Tokenizer for ``assistant_model``. When None,
            the tokenizer stored with the assistant's checkpoint is used
            (falling back to ``phi_tokenizer``). Each model must be fed token
            ids from its own vocabulary; feeding Phi's ids to the assistant
            is only safe if the two vocabularies agree.

    Returns:
        The decoded text produced by the final Phi stage.
    """
    if assistant_tokenizer is None:
        assistant_tokenizer = _resolve_assistant_tokenizer(assistant_model)
    sampling = {
        "temperature": temperature,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    # Assistant generates the initial story.
    story_text = _generate_text(
        assistant_model, assistant_tokenizer, user_input, 25, **sampling
    )
    # Phi cleans it up.
    cleaned_text = _generate_text(
        phi_model, phi_tokenizer, story_text, 500,
        drop_attention_mask=True, **sampling
    )
    # Assistant refines it.
    refined_text = _generate_text(
        assistant_model, assistant_tokenizer, cleaned_text, 100, **sampling
    )
    # Final clean-up by Phi.
    return _generate_text(
        phi_model, phi_tokenizer, refined_text, 500,
        drop_attention_mask=True, **sampling
    )
# Adjust the hyperparameters here, before the interactive loop starts, e.g.:
# temperature = 0.6
# top_k = 100
# repetition_penalty = 1.5
# The loop passes the current values to generate_response explicitly: its
# parameter defaults were fixed when the function was defined, so
# reassigning these module-level names alone would have no effect.

if __name__ == "__main__":
    # Interactive loop: type "exit" or "quit" (or press Ctrl-D / Ctrl-C) to stop.
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or interrupt: leave cleanly instead of with a traceback.
            print("\nGoodbye!")
            break
        if not user_input:
            # An empty prompt gives the models nothing to continue.
            continue
        if user_input.lower() in ("exit", "quit"):
            print("Goodbye!")
            break
        response = generate_response(
            user_input,
            assistant_model,
            phi_model,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        print("BartPhi-2.8:", response)