BytArch committed on
Commit
05ca5ff
·
verified ·
1 Parent(s): 3ff63d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -26
app.py CHANGED
@@ -9,6 +9,17 @@ model = AutoModelForCausalLM.from_pretrained(model_path)
9
  if tokenizer.pad_token is None:
10
  tokenizer.pad_token = tokenizer.eos_token
11
 
 
 
 
 
 
 
 
 
 
 
 
12
  def generate_response(
13
  prompt,
14
  system_message,
@@ -19,25 +30,7 @@ def generate_response(
19
  repetition_penalty=1.031,
20
  top_k=55,
21
  ):
22
- context = ""
23
- if conversation_history:
24
- recent = conversation_history[-30:] if len(conversation_history) > 30 else conversation_history
25
- is_first_message = False
26
- for i, message in enumerate(recent):
27
- if i == 0:
28
- is_first_message = True
29
- context += (
30
- f"<|start|>User:<|message|>{system_message}<|end|>\n"
31
- f"<|start|>Assistant:<|message|>Hello, nice to meet you!<|end|>\n"
32
- )
33
- if message["role"] == "user":
34
- context += f"<|start|>User:<|message|>{message['content']}<|end|>\n"
35
- else:
36
- context += f"<|start|>Assistant:<|message|>{message['content']}<|end|>\n"
37
-
38
- formatted_input = (
39
- f"{context}<|start|>User:<|message|>{prompt}<|end|>\n<|start|>Assistant:<|message|>"
40
- )
41
 
42
  inputs = tokenizer(
43
  formatted_input,
@@ -64,6 +57,9 @@ def generate_response(
64
  new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
65
  response = tokenizer.decode(new_tokens, skip_special_tokens=False)
66
 
 
 
 
67
  return response.strip()
68
 
69
  def respond(
@@ -76,22 +72,17 @@ def respond(
76
  repetition_penalty,
77
  top_k,
78
  ):
79
- conversation_history = history
80
  response = generate_response(
81
  message,
82
  system_message,
83
- conversation_history,
84
  max_tokens=max_tokens,
85
  temperature=temperature,
86
  top_p=top_p,
87
  repetition_penalty=repetition_penalty,
88
  top_k=top_k,
89
  )
90
-
91
- if "<|end|>" in response:
92
- response = response.split("<|end|>")[0]
93
-
94
- return response.strip()
95
 
96
  chatbot = gr.ChatInterface(
97
  respond,
 
9
  if tokenizer.pad_token is None:
10
  tokenizer.pad_token = tokenizer.eos_token
11
 
12
def build_context(system_message, conversation_history, user_message):
    """Assemble the model prompt in the `<|start|>Role:<|message|>...<|end|>` chat format.

    The prompt opens with a System turn carrying `system_message`, replays each
    prior turn from `conversation_history` (a list of {"role", "content"} dicts;
    roles other than "user"/"assistant" are silently skipped), appends the new
    `user_message` as a User turn, and ends with an open Assistant header so the
    model continues from there. Returns the assembled prompt string.
    """
    role_labels = {"user": "User", "assistant": "Assistant"}
    parts = [f"<|start|>System:<|message|>{system_message}<|end|>\n"]
    for turn in conversation_history or ():
        label = role_labels.get(turn["role"])
        if label is not None:
            parts.append(f"<|start|>{label}:<|message|>{turn['content']}<|end|>\n")
    # Current user turn plus an open Assistant tag for the model to complete.
    parts.append(f"<|start|>User:<|message|>{user_message}<|end|>\n<|start|>Assistant:<|message|>")
    return "".join(parts)
22
+
23
  def generate_response(
24
  prompt,
25
  system_message,
 
30
  repetition_penalty=1.031,
31
  top_k=55,
32
  ):
33
+ formatted_input = build_context(system_message, conversation_history, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  inputs = tokenizer(
36
  formatted_input,
 
57
  new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
58
  response = tokenizer.decode(new_tokens, skip_special_tokens=False)
59
 
60
+ if "<|end|>" in response:
61
+ response = response.split("<|end|>")[0]
62
+
63
  return response.strip()
64
 
65
  def respond(
 
72
  repetition_penalty,
73
  top_k,
74
  ):
 
75
  response = generate_response(
76
  message,
77
  system_message,
78
+ history,
79
  max_tokens=max_tokens,
80
  temperature=temperature,
81
  top_p=top_p,
82
  repetition_penalty=repetition_penalty,
83
  top_k=top_k,
84
  )
85
+ return response
 
 
 
 
86
 
87
  chatbot = gr.ChatInterface(
88
  respond,