prelington committed
Commit 9d82262 · verified · Parent: b7da0c3

Update chat.py

Files changed (1):
chat.py (+17 −11)
chat.py CHANGED
@@ -1,6 +1,7 @@
 # chat.py
 from model_loader import load_model
 from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
+import torch
 
 # Conversation memory per model
 conversation_memory = {}
@@ -18,32 +19,37 @@ def switch_model(model_name):
         conversation_memory[model_name] = []
     return f"Switched to model: {model_name}"
 
-def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
-    """Generate response with conversation memory per model"""
+def generate_response_stream(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
+    """Stream response token by token for typing effect"""
     global conversation_memory
 
     if current_model_name not in conversation_memory:
         conversation_memory[current_model_name] = []
 
     history = conversation_memory[current_model_name]
-
     history.append(f"User: {prompt}")
     full_prompt = "\n".join(history) + "\nAI:"
 
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(
+
+    # Generate with streaming
+    output_ids = model.generate(
         **inputs,
         max_length=max_length,
         do_sample=True,
         temperature=temperature,
         pad_token_id=tokenizer.eos_token_id
-    )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = response.split("AI:")[-1].strip()
-
-    history.append(f"AI: {response}")
-    return response
+    )[0]
+
+    # Decode token by token
+    decoded_text = ""
+    for token_id in output_ids[len(inputs["input_ids"][0]):]:  # Skip prompt tokens
+        decoded_token = tokenizer.decode(token_id)
+        decoded_text += decoded_token
+        yield decoded_text
+
+    # Save to conversation memory
+    conversation_memory[current_model_name].append(f"AI: {decoded_text}")
 
 def reset_conversation():
     """Reset memory for current model"""