# chat.py
from model_loader import load_model
from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE

# Per-model conversation history, keyed by model name.
conversation_memory = {}
current_model_name = DEFAULT_MODEL
tokenizer, model = load_model(current_model_name)


def switch_model(model_name):
    """Load a different model and start (or resume) its conversation history."""
    global tokenizer, model, current_model_name
    tokenizer, model = load_model(model_name)
    current_model_name = model_name
    if model_name not in conversation_memory:
        conversation_memory[model_name] = []
    return f"Switched to model: {model_name}"


def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """Yield the growing response text while keeping per-model conversation memory.

    Note: model.generate() runs to completion first; the loop below replays the
    finished output incrementally rather than streaming tokens as they are produced.
    """
    history = conversation_memory.setdefault(current_model_name, [])
    history.append(f"User: {prompt}")
    full_prompt = "\n".join(history) + "\nAI:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )[0]

    # Keep only the newly generated tokens, skipping the echoed prompt.
    new_tokens = outputs[inputs["input_ids"].shape[1]:]
    decoded_text = ""
    for i in range(len(new_tokens)):
        # Re-decode the whole generated sequence on each step: decoding token
        # ids one at a time drops spaces and breaks multi-byte characters with
        # BPE and SentencePiece tokenizers.
        decoded_text = tokenizer.decode(new_tokens[: i + 1], skip_special_tokens=True)
        yield decoded_text

    history.append(f"AI: {decoded_text}")


def reset_conversation():
    """Clear the history for the currently active model."""
    conversation_memory[current_model_name] = []
    return "Conversation reset."
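
# --- Optional: true token-by-token streaming ---
# generate_response_stream() above only yields after model.generate() has
# finished. Below is a minimal sketch of genuine incremental streaming using
# Hugging Face transformers' TextIteratorStreamer, assuming the module-level
# tokenizer/model above; the function name generate_response_stream_live is
# hypothetical and not part of the original module.
from threading import Thread

from transformers import TextIteratorStreamer


def generate_response_stream_live(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """Yield the growing response text as tokens are produced by the model."""
    history = conversation_memory.setdefault(current_model_name, [])
    history.append(f"User: {prompt}")
    full_prompt = "\n".join(history) + "\nAI:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done, so it runs in a worker thread while the
    # main thread consumes decoded text pieces from the streamer queue.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        ),
    )
    thread.start()

    decoded_text = ""
    for piece in streamer:  # pieces arrive as the model generates
        decoded_text += piece
        yield decoded_text
    thread.join()

    history.append(f"AI: {decoded_text}")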
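
# --- Hypothetical usage sketch (not part of the original module) ---
# Shows how a caller consumes the cumulative generator: printing only the
# delta between successive yields gives a typewriter effect in a terminal.
if __name__ == "__main__":
    shown = ""
    for partial in generate_response_stream("Hello! What can you do?"):
        print(partial[len(shown):], end="", flush=True)
        shown = partial
    print()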