prelington committed
Commit efce57f · verified · 1 Parent(s): e71c280

Update chat.py

Files changed (1)
  1. chat.py +9 -17
chat.py CHANGED
@@ -3,15 +3,11 @@ from model_loader import load_model
 from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
 import torch
 
-# Conversation memory per model
 conversation_memory = {}
-
-# Load default model
 current_model_name = DEFAULT_MODEL
 tokenizer, model = load_model(current_model_name)
 
 def switch_model(model_name):
-    """Switch to a different model"""
     global tokenizer, model, current_model_name
     tokenizer, model = load_model(model_name)
     current_model_name = model_name
@@ -20,39 +16,35 @@ def switch_model(model_name):
     return f"Switched to model: {model_name}"
 
 def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
-    """Stream response token by token for typing effect"""
+    """Streaming response with memory and optimized memory usage"""
     global conversation_memory
-
     if current_model_name not in conversation_memory:
         conversation_memory[current_model_name] = []
-
+
     history = conversation_memory[current_model_name]
     history.append(f"User: {prompt}")
     full_prompt = "\n".join(history) + "\nAI:"
-
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
 
-    # Generate with streaming
-    output_ids = model.generate(
+    # Tokenize in small batches to save memory
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(
         **inputs,
-        max_length=max_length,
+        max_new_tokens=max_length,
         do_sample=True,
         temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.eos_token_id,
+        streamer=None
     )[0]
 
-    # Decode token by token
     decoded_text = ""
-    for token_id in output_ids[len(inputs["input_ids"][0]):]:  # Skip prompt tokens
+    for token_id in outputs[len(inputs["input_ids"][0]):]:
         decoded_token = tokenizer.decode(token_id)
         decoded_text += decoded_token
         yield decoded_text
 
-    # Save to conversation memory
     conversation_memory[current_model_name].append(f"AI: {decoded_text}")
 
 def reset_conversation():
-    """Reset memory for current model"""
     global conversation_memory
     conversation_memory[current_model_name] = []
     return "Conversation reset."