import torch

from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
from model_loader import load_model

# Per-model chat history: maps a model name to its list of
# "User: ..." / "AI: ..." turns.
conversation_memory = {}

current_model_name = DEFAULT_MODEL
tokenizer, model = load_model(current_model_name)


def switch_model(model_name):
    """Swap the active model in place, keeping one history list per model."""
    global tokenizer, model, current_model_name
    tokenizer, model = load_model(model_name)
    current_model_name = model_name
    if model_name not in conversation_memory:
        conversation_memory[model_name] = []
    return f"Switched to model: {model_name}"


def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """Generate a reply with conversation memory, yielding text incrementally.

    Note: generation runs to completion first; the loop below then replays
    the decoded response token by token for a streaming-style consumer.
    """
    if current_model_name not in conversation_memory:
        conversation_memory[current_model_name] = []

    history = conversation_memory[current_model_name]
    history.append(f"User: {prompt}")
    full_prompt = "\n".join(history) + "\nAI:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping to save memory
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )[0]

    # Decode a growing slice of the completion rather than concatenating
    # per-token decodes, so multi-piece (BPE) tokens render correctly.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = ""
    for end in range(prompt_len + 1, len(outputs) + 1):
        decoded_text = tokenizer.decode(outputs[prompt_len:end], skip_special_tokens=True)
        yield decoded_text

    history.append(f"AI: {decoded_text}")
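

# A minimal sketch of true token-by-token streaming, assuming load_model
# returns a Hugging Face transformers tokenizer/model pair (the generate()
# call above already presumes that API). TextIteratorStreamer yields text
# while generate() runs in a background thread, so callers see tokens as
# they are produced instead of after the fact. The function name is
# illustrative, not part of the existing module API.
def generate_response_stream_live(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    from threading import Thread

    from transformers import TextIteratorStreamer

    history = conversation_memory.setdefault(current_model_name, [])
    history.append(f"User: {prompt}")
    full_prompt = "\n".join(history) + "\nAI:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        ),
    )
    thread.start()

    decoded_text = ""
    for chunk in streamer:  # blocks until the next decoded chunk is ready
        decoded_text += chunk
        yield decoded_text

    thread.join()
    history.append(f"AI: {decoded_text}")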


def reset_conversation():
    """Clear the stored history for the currently active model."""
    conversation_memory[current_model_name] = []
    return "Conversation reset."