File size: 1,710 Bytes
ab9ffec
 
3143e00
9d82262
ab9ffec
3143e00
 
 
 
 
 
 
 
 
 
 
80d089c
9d82262
efce57f
3143e00
 
 
efce57f
3143e00
 
 
9d82262
efce57f
 
 
ab9ffec
efce57f
ab9ffec
 
efce57f
 
9d82262
 
 
efce57f
9d82262
 
 
 
 
80d089c
 
3143e00
 
80d089c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# chat.py
from model_loader import load_model
from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
import torch

# Per-model chat transcripts, keyed by model name; each value is a list of
# alternating "User: ..." / "AI: ..." turn strings.
conversation_memory = {}
# Name of the model currently serving requests; rebound by switch_model().
current_model_name = DEFAULT_MODEL
# Active (tokenizer, model) pair, loaded eagerly at import time for the
# default model and swapped in place by switch_model().
tokenizer, model = load_model(current_model_name)

def switch_model(model_name):
    """Load *model_name*, make it the active model, and return a status string.

    Rebinds the module-level ``tokenizer``, ``model`` and
    ``current_model_name``, and guarantees the new model has a (possibly
    empty) entry in ``conversation_memory``.
    """
    global tokenizer, model, current_model_name
    tokenizer, model = load_model(model_name)
    current_model_name = model_name
    # Create the history slot only if this model has never been used before.
    conversation_memory.setdefault(model_name, [])
    return f"Switched to model: {model_name}"

def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """Generate a reply to *prompt*, yielding the cumulative decoded text.

    The prompt is appended to the active model's conversation history, the
    full transcript is fed to the model, and the growing response string is
    yielded once per generated token.  When generation finishes, the final
    reply is recorded in the history as an "AI: ..." turn.

    Args:
        prompt: The user's message for this turn.
        max_length: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.

    Yields:
        str: The response decoded so far (cumulative, not per-token deltas).
    """
    if current_model_name not in conversation_memory:
        conversation_memory[current_model_name] = []

    history = conversation_memory[current_model_name]
    history.append(f"User: {prompt}")
    full_prompt = "\n".join(history) + "\nAI:"

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    # inference_mode: no autograd graph is built during generation, which
    # noticeably lowers peak memory (the function's stated goal).
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )[0]

    prompt_len = inputs["input_ids"].shape[1]
    generated_ids = outputs[prompt_len:]

    decoded_text = ""
    # Decode the growing prefix of ids rather than one id at a time:
    # byte-level/BPE tokenizers split single characters across several ids,
    # so per-id decode can emit replacement characters and wrong spacing.
    # skip_special_tokens keeps the EOS/pad token out of the transcript.
    for end in range(1, len(generated_ids) + 1):
        decoded_text = tokenizer.decode(generated_ids[:end], skip_special_tokens=True)
        yield decoded_text

    history.append(f"AI: {decoded_text}")

def reset_conversation():
    """Wipe the active model's chat history and return a confirmation string.

    Note: mutating ``conversation_memory`` in place needs no ``global``
    declaration — only rebinding the name would.
    """
    conversation_memory[current_model_name] = []
    return "Conversation reset."