Rady10 committed on
Commit
5b9d376
·
verified ·
1 Parent(s): 56d265c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -98,10 +98,7 @@ def chunk_to_text(chunk) -> str:
98
 
99
 
100
  def to_content_list(content) -> list:
101
- """
102
- apply_chat_template requires content to ALWAYS be a list of dicts.
103
- Never a plain string — that causes: TypeError: string indices must be integers
104
- """
105
  if isinstance(content, str):
106
  return [{"type": "text", "text": content}]
107
  if isinstance(content, list):
@@ -151,19 +148,17 @@ def build_full_messages(messages: list, image: Image.Image, rag_context: str) ->
151
  )
152
  system_prompt = "\n\n".join(system_parts)
153
 
154
- # ⚠️ content MUST be list of dicts — never a plain string
155
  full_messages = [
156
  {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
157
  {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
158
  ]
159
 
160
- # normalize every incoming message too
161
  norm = [
162
  {"role": m["role"], "content": to_content_list(m.get("content", ""))}
163
  for m in messages
164
  ]
165
 
166
- # inject image into last user turn
167
  if image is not None:
168
  for i in range(len(norm) - 1, -1, -1):
169
  if norm[i]["role"] == "user":
@@ -183,12 +178,16 @@ def chat(req: ChatRequest):
183
  rag_context = "" if image else retrieve_rag_context(req.messages)
184
  full_messages = build_full_messages(req.messages, image, rag_context)
185
 
 
 
186
  inputs = processor.apply_chat_template(
187
  full_messages,
188
  add_generation_prompt=True,
189
  tokenize=True,
190
  return_tensors="pt",
191
- ).to(model.device)
 
 
192
 
193
  with torch.no_grad():
194
  output_ids = model.generate(
@@ -198,7 +197,10 @@ def chat(req: ChatRequest):
198
  top_p=0.9,
199
  )
200
 
201
- response_text = processor.decode(output_ids[0], skip_special_tokens=True)
 
 
 
202
 
203
  return {
204
  "response": response_text,
 
98
 
99
 
100
  def to_content_list(content) -> list:
101
+ """content must always be a list of dicts for apply_chat_template"""
 
 
 
102
  if isinstance(content, str):
103
  return [{"type": "text", "text": content}]
104
  if isinstance(content, list):
 
148
  )
149
  system_prompt = "\n\n".join(system_parts)
150
 
151
+ # content MUST be list of dicts — never plain string
152
  full_messages = [
153
  {"role": "user", "content": [{"type": "text", "text": system_prompt}]},
154
  {"role": "assistant", "content": [{"type": "text", "text": "Understood. I will use this knowledge to help you."}]},
155
  ]
156
 
 
157
  norm = [
158
  {"role": m["role"], "content": to_content_list(m.get("content", ""))}
159
  for m in messages
160
  ]
161
 
 
162
  if image is not None:
163
  for i in range(len(norm) - 1, -1, -1):
164
  if norm[i]["role"] == "user":
 
178
  rag_context = "" if image else retrieve_rag_context(req.messages)
179
  full_messages = build_full_messages(req.messages, image, rag_context)
180
 
181
+ # apply_chat_template with tokenize=True returns a plain Tensor, not a dict
182
+ # use return_dict=True to get {"input_ids": ..., "attention_mask": ...}
183
  inputs = processor.apply_chat_template(
184
  full_messages,
185
  add_generation_prompt=True,
186
  tokenize=True,
187
  return_tensors="pt",
188
+ return_dict=True, # ← fixes: argument after ** must be a mapping, not Tensor
189
+ )
190
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
191
 
192
  with torch.no_grad():
193
  output_ids = model.generate(
 
197
  top_p=0.9,
198
  )
199
 
200
+ # decode only the newly generated tokens (skip the input prompt)
201
+ input_len = inputs["input_ids"].shape[1]
202
+ new_tokens = output_ids[0][input_len:]
203
+ response_text = processor.decode(new_tokens, skip_special_tokens=True)
204
 
205
  return {
206
  "response": response_text,