|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
# --- One-off setup: load the model and seed the chat transcript. ---

# Hugging Face model id for the small GPT-2 variant this demo runs on.
MODEL_NAME = "distilgpt2"

# Both calls download (or reuse a cached copy of) the model artifacts.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Running plain-text transcript; the opening line acts as a lightweight
# system prompt to nudge the model's tone.
conversation = "You are a kind AI assistant. Stay on topic.\n"

print("\nType your messages below. Type 'quit' to exit.\n")
|
|
|
|
|
# Main read-generate-print loop. Exits on 'quit'/'exit', Ctrl-C, or Ctrl-D.
while True:
    # Read one user turn; Ctrl-C (KeyboardInterrupt) and Ctrl-D / piped-in
    # EOF (EOFError) both end the chat cleanly instead of crashing.
    try:
        user_input = input("You: ")
    except (KeyboardInterrupt, EOFError):
        print("\nEnding chat. Bye!")
        break

    if user_input.lower() in ["quit", "exit"]:
        print("Ending chat. Bye!")
        break

    # Skip blank lines rather than sending an empty turn to the model.
    if not user_input.strip():
        continue

    conversation += f"User: {user_input}\nAI:"

    # Tokenize the full transcript, then keep only the most recent 512
    # tokens. (truncation=True with max_length cuts from the RIGHT, which
    # would drop the newest user message instead of the oldest history.)
    inputs = tokenizer(conversation, return_tensors="pt")
    inputs = {name: tensor[:, -512:] for name, tensor in inputs.items()}

    output = model.generate(
        **inputs,
        max_new_tokens=40,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
    )

    # The decoded output echoes the whole prompt, so take the text after
    # the last "AI:" marker ...
    ai_response = tokenizer.decode(output[0], skip_special_tokens=True).split("AI:")[-1]
    # ... and cut at any hallucinated "User:" turn, since small causal LMs
    # frequently keep role-playing both sides. Without this, fake user
    # turns get printed AND appended to the transcript, poisoning context.
    ai_response = ai_response.split("User:")[0].strip()

    print(f"Baby AI: {ai_response}")

    # Append the reply so the next turn sees the full conversation so far.
    conversation += f"{ai_response}\n"
|
|
|