BytArch committed on
Commit
6d369e1
·
verified ·
1 Parent(s): dd5220e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -37,11 +37,11 @@ SYSTEM_PROMPT = (
37
 
38
 
39
  def build_context(user_message):
40
- return f"{SYSTEM_PROMPT}\n\nUser: {user_message}"
 
41
 
42
 
43
 
44
- # Generate response
45
  def generate_response(
46
  prompt,
47
  max_tokens=300,
@@ -74,13 +74,27 @@ def generate_response(
74
  eos_token_id=tokenizer.eos_token_id,
75
  )
76
 
 
77
  new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
78
  response = tokenizer.decode(new_tokens, skip_special_tokens=True)
79
-
80
-
81
- response = response.replace("<|im_end|>", "")
82
-
83
- return response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86
  # Respond function for Gradio
 
37
 
38
 
39
  def build_context(user_message):
40
+ return SYSTEM_PROMPT + "\n\nUser: " + user_message + "\n\nAssistant:"
41
+
42
 
43
 
44
 
 
45
  def generate_response(
46
  prompt,
47
  max_tokens=300,
 
74
  eos_token_id=tokenizer.eos_token_id,
75
  )
76
 
77
+ # Take only newly generated tokens
78
  new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
79
  response = tokenizer.decode(new_tokens, skip_special_tokens=True)
80
+
81
+ # Remove leftover special tokens
82
+ response = response.replace("<|im_end|>", "").strip()
83
+
84
+ # Remove any label and text following it on the same line
85
+ lines = response.splitlines()
86
+ cleaned_lines = []
87
+ for line in lines:
88
+ for label in ["Assistant:", "assistant:", "User:", "user:"]:
89
+ if label in line:
90
+ line = line.split(label)[0].strip()
91
+ if line: # keep non-empty lines
92
+ cleaned_lines.append(line)
93
+
94
+ response = "\n".join(cleaned_lines)
95
+
96
+ return response
97
+
98
 
99
 
100
  # Respond function for Gradio