# CodexTrouter / ProTalk_MemoryChat.py
#
# NOTE: the lines below are Hugging Face file-page metadata that was captured
# along with the source; kept (commented out) for provenance so the module parses.
#   prelington's picture — Create ProTalk_MemoryChat.py
#   e9e5efd verified — raw / history / blame — 1.79 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading
# Model and device configuration for the chat session.
model_name = "microsoft/phi-2"
# Prefer GPU when available; falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Half precision on GPU to cut memory; full float32 on CPU where fp16 is slow/unsupported.
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
low_cpu_mem_usage=True
).to(device)
# System prompt prepended to every turn; the "memory" is the plain-text
# transcript accumulated in chat_history below (session-scoped only).
system_prompt = (
"You are ProTalk, a professional AI assistant. "
"You remember everything the user said in this session and respond politely, "
"clearly, and intelligently. Keep a coherent conversation history."
)
# Alternating "User: ..." / "ProTalk: ..." transcript lines for this session.
chat_history = []
def chat_loop():
    """Run an interactive chat REPL against the phi-2 model with streamed output.

    Each turn appends the user's message to the module-level ``chat_history``,
    rebuilds the full prompt (system prompt + transcript), and streams the
    model's reply token-by-token to stdout while ``model.generate`` runs in a
    background thread. The reply is appended to ``chat_history`` so the next
    turn sees the whole conversation. Exits on 'exit' or EOF (Ctrl-D).
    """
    print("ProTalk Memory Chat Online — type 'exit' to quit.\n")
    while True:
        try:
            user_input = input("User: ").strip()
        except EOFError:
            # Treat end-of-input (e.g. piped stdin exhausted, Ctrl-D) as a clean quit
            # instead of crashing with a traceback.
            break
        if user_input.lower() == "exit":
            break
        if not user_input:
            # Skip blank turns: they would pollute the transcript and waste a generation.
            continue
        chat_history.append(f"User: {user_input}")
        prompt = system_prompt + "\n" + "\n".join(chat_history) + "\nProTalk:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # skip_prompt=True so only newly generated tokens are streamed, not the echo
        # of the (ever-growing) prompt.
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        # generate() blocks, so run it in a worker thread and consume the streamer here.
        thread = threading.Thread(target=model.generate, kwargs={
            "input_ids": inputs["input_ids"],
            # Pass the mask explicitly: without it transformers has to infer one
            # and emits a warning (and can mis-handle padding).
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": 300,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.2,
            # phi-2's tokenizer defines no pad token; pin padding to EOS to avoid
            # the runtime warning and undefined pad behavior.
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        })
        thread.start()
        output_text = ""
        for token in streamer:
            print(token, end="", flush=True)
            output_text += token
        thread.join()
        print()
        # Strip leading/trailing whitespace so the transcript stays one clean line per turn.
        chat_history.append(f"ProTalk: {output_text.strip()}")
# Script entry point: start the interactive REPL only when run directly,
# not when imported as a module.
if __name__ == "__main__":
    chat_loop()