prelington committed
Commit 3143e00 · verified · 1 Parent(s): 7381a1f

Update chat.py

Files changed (1)
  chat.py +27 -23
chat.py CHANGED
@@ -1,26 +1,35 @@
 # chat.py
-import torch
 from model_loader import load_model
-from config import MAX_TOKENS, TEMPERATURE
+from config import DEFAULT_MODEL, MAX_TOKENS, TEMPERATURE
 
-tokenizer, model = load_model()
+# Conversation memory per model
+conversation_memory = {}
 
-# Conversation memory: stores previous exchanges
-conversation_history = []
+# Load default model
+current_model_name = DEFAULT_MODEL
+tokenizer, model = load_model(current_model_name)
+
+def switch_model(model_name):
+    """Switch to a different model"""
+    global tokenizer, model, current_model_name
+    tokenizer, model = load_model(model_name)
+    current_model_name = model_name
+    if model_name not in conversation_memory:
+        conversation_memory[model_name] = []
+    return f"Switched to model: {model_name}"
 
 def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
-    """
-    Generate a response from the model with memory of previous conversation.
-    """
-    global conversation_history
+    """Generate response with conversation memory per model"""
+    global conversation_memory
+
+    if current_model_name not in conversation_memory:
+        conversation_memory[current_model_name] = []
 
-    # Add user prompt to conversation
-    conversation_history.append(f"User: {prompt}")
+    history = conversation_memory[current_model_name]
 
-    # Combine conversation history for context
-    full_prompt = "\n".join(conversation_history) + "\nAI:"
+    history.append(f"User: {prompt}")
+    full_prompt = "\n".join(history) + "\nAI:"
 
-    # Tokenize and generate
     inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(
         **inputs,
@@ -30,19 +39,14 @@ def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
         pad_token_id=tokenizer.eos_token_id
     )
 
-    # Decode response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Extract AI's latest message
     response = response.split("AI:")[-1].strip()
 
-    # Add AI response to conversation history
-    conversation_history.append(f"AI: {response}")
-
+    history.append(f"AI: {response}")
     return response
 
 def reset_conversation():
-    """Reset conversation memory"""
-    global conversation_history
-    conversation_history = []
+    """Reset memory for current model"""
+    global conversation_memory
+    conversation_memory[current_model_name] = []
     return "Conversation reset."