prelington committed on
Commit
80d089c
·
verified ·
1 Parent(s): f910256

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +29 -2
chat.py CHANGED
@@ -5,11 +5,23 @@ from config import MAX_TOKENS, TEMPERATURE
5
 
6
  tokenizer, model = load_model()
7
 
 
 
 
8
  def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
9
  """
10
- Generate a response from the model
11
  """
12
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
13
  outputs = model.generate(
14
  **inputs,
15
  max_length=max_length,
@@ -17,5 +29,20 @@ def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
17
  temperature=temperature,
18
  pad_token_id=tokenizer.eos_token_id
19
  )
 
 
20
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
21
  return response
 
 
 
 
 
 
 
5
 
6
  tokenizer, model = load_model()
7
 
8
# Conversation memory: stores previous exchanges as alternating
# "User: ..." / "AI: ..." lines that generate_response joins into the
# model prompt for context.  NOTE(review): grows without bound across calls.
conversation_history = []
11
def generate_response(prompt, max_length=MAX_TOKENS, temperature=TEMPERATURE):
    """
    Generate a response from the model with memory of previous conversation.

    The user prompt is appended to the module-level ``conversation_history``,
    the full history is joined into one prompt ending in ``"AI:"``, and the
    text the model produces after the final ``"AI:"`` marker is returned
    (and also recorded in the history).

    :param prompt: the user's message for this turn
    :param max_length: total sequence budget (prompt + generated tokens)
        forwarded to ``model.generate``.  NOTE(review): because the history
        keeps growing, the tokenized prompt alone will eventually exceed
        this budget — consider trimming old turns or switching to
        ``max_new_tokens``.
    :param temperature: sampling temperature forwarded to ``model.generate``
    :return: the model's reply for this turn, as a string
    """
    global conversation_history

    # Add user prompt to conversation
    conversation_history.append(f"User: {prompt}")

    # Combine conversation history for context
    full_prompt = "\n".join(conversation_history) + "\nAI:"

    # Tokenize and generate
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        # NOTE(review): one unchanged kwarg of the original file (old line 16 /
        # new line 28, the gap between diff hunks) is not visible in this diff
        # view — confirm against chat.py before relying on this reconstruction.
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract AI's latest message: the decoded text echoes the whole prompt,
    # so keep only what follows the last "AI:" marker.  NOTE(review): brittle
    # if the model itself emits the literal string "AI:" inside its reply.
    response = response.split("AI:")[-1].strip()

    # Add AI response to conversation history
    conversation_history.append(f"AI: {response}")

    return response
43
+
44
def reset_conversation():
    """
    Forget all stored conversation turns.

    Rebinds the module-level ``conversation_history`` to a fresh empty list
    so the next ``generate_response`` call starts from a clean context.

    :return: a confirmation message suitable for display in a UI
    """
    global conversation_history
    conversation_history = []
    return "Conversation reset."