hysts HF Staff committed on
Commit
329bc40
·
1 Parent(s): eb884c1
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -71,8 +71,8 @@ def _generate_on_gpu(
71
 
72
  thread.join()
73
  if exception_holder:
74
- msg = f"Generation failed: {exception_holder[0]}"
75
- raise gr.Error(msg)
76
 
77
 
78
  def generate(
@@ -87,15 +87,22 @@ def generate(
87
  if not message or not message.strip():
88
  raise gr.Error("Please enter a message.")
89
 
90
- conversation = [*chat_history, {"role": "user", "content": message}]
 
 
 
 
 
 
 
91
 
92
  input_ids = tokenizer.apply_chat_template(
93
  conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
94
  ).input_ids
95
  n_input_tokens = input_ids.shape[1]
96
  if n_input_tokens > MAX_INPUT_TOKENS:
97
- msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
98
- raise gr.Error(msg)
99
 
100
  max_new_tokens = min(max_new_tokens, MAX_INPUT_TOKENS - n_input_tokens)
101
  if max_new_tokens <= 0:
 
71
 
72
  thread.join()
73
  if exception_holder:
74
+ error_msg = f"Generation failed: {exception_holder[0]}"
75
+ raise gr.Error(error_msg)
76
 
77
 
78
  def generate(
 
87
  if not message or not message.strip():
88
  raise gr.Error("Please enter a message.")
89
 
90
+ conversation = []
91
+ for hist_msg in chat_history:
92
+ if isinstance(hist_msg["content"], list):
93
+ text = "".join(part["text"] for part in hist_msg["content"] if part["type"] == "text")
94
+ else:
95
+ text = str(hist_msg["content"])
96
+ conversation.append({"role": hist_msg["role"], "content": text})
97
+ conversation.append({"role": "user", "content": message})
98
 
99
  input_ids = tokenizer.apply_chat_template(
100
  conversation, add_generation_prompt=True, return_tensors="pt", return_dict=True
101
  ).input_ids
102
  n_input_tokens = input_ids.shape[1]
103
  if n_input_tokens > MAX_INPUT_TOKENS:
104
+ error_msg = f"Input too long ({n_input_tokens} tokens). Maximum is {MAX_INPUT_TOKENS} tokens."
105
+ raise gr.Error(error_msg)
106
 
107
  max_new_tokens = min(max_new_tokens, MAX_INPUT_TOKENS - n_input_tokens)
108
  if max_new_tokens <= 0: