# ChatGPT-Tune / chat.py
# (Hugging Face page header below preserved as comments — it is not Python code)
# prelington's picture — Update chat.py
# efce57f verified
# chat.py
from model_loader import load_model
from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
import torch
conversation_memory = {}
current_model_name = DEFAULT_MODEL
tokenizer, model = load_model(current_model_name)
def switch_model(model_name):
global tokenizer, model, current_model_name
tokenizer, model = load_model(model_name)
current_model_name = model_name
if model_name not in conversation_memory:
conversation_memory[model_name] = []
return f"Switched to model: {model_name}"
def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
"""Streaming response with memory and optimized memory usage"""
global conversation_memory
if current_model_name not in conversation_memory:
conversation_memory[current_model_name] = []
history = conversation_memory[current_model_name]
history.append(f"User: {prompt}")
full_prompt = "\n".join(history) + "\nAI:"
# Tokenize in small batches to save memory
inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=max_length,
do_sample=True,
temperature=temperature,
pad_token_id=tokenizer.eos_token_id,
streamer=None
)[0]
decoded_text = ""
for token_id in outputs[len(inputs["input_ids"][0]):]:
decoded_token = tokenizer.decode(token_id)
decoded_text += decoded_token
yield decoded_text
conversation_memory[current_model_name].append(f"AI: {decoded_text}")
def reset_conversation():
global conversation_memory
conversation_memory[current_model_name] = []
return "Conversation reset."