KarthiEz committed · verified
Commit e4b0b88 · 1 Parent(s): 8848bb4

Update app.py

Files changed (1): app.py (+216, -230)
app.py CHANGED
@@ -52,6 +52,37 @@ def _get_args():
     args = parser.parse_args()
     return args
 
+def build_chatbot_messages(task_history):
+    """
+    Convert internal task_history [(q, a), ...] into Gradio Chatbot
+    messages format: [{"role": "...", "content": ...}, ...]
+    """
+    messages = []
+    for q, a in task_history:
+        # User side
+        if isinstance(q, (tuple, list)):
+            # Image-only turn
+            img_path = q[0]
+            messages.append({
+                "role": "user",
+                "content": [{"type": "image", "image": img_path}],
+            })
+        else:
+            messages.append({
+                "role": "user",
+                "content": q,
+            })
+
+        # Assistant side
+        if a is not None:
+            messages.append({
+                "role": "assistant",
+                "content": a,
+            })
+
+    return messages
+
+
 
 def _load_model_processor(args):
     # ZeroGPU: Model loads on CPU, uses eager mode
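
For reference, a round trip of the new helper on the app's internal (question, answer) history looks roughly like this (values illustrative):

    task_history = [
        (("examples/spotting.jpg",), None),        # image-only user turn, no reply yet
        ("Detect all text.", "detected text..."),  # text turn with a model reply
    ]
    build_chatbot_messages(task_history)
    # -> [{"role": "user", "content": [{"type": "image", "image": "examples/spotting.jpg"}]},
    #     {"role": "user", "content": "Detect all text."},
    #     {"role": "assistant", "content": "detected text..."}]
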
@@ -130,14 +161,15 @@ def _launch_demo(args, model, processor):
     # Track first call
     first_call = [True]
 
-    # Uses closure to access model and processor
-    # Duration increased to 120s to avoid timeout during peak hours
+    # =========================
+    # Model call (unchanged)
+    # =========================
     @spaces.GPU(duration=120)
     def call_local_model(messages):
         import time
         import sys
         start_time = time.time()
-
+
         if first_call[0]:
             print(f"[INFO] ========== First inference call ==========")
             first_call[0] = False
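
For context, `@spaces.GPU` is the ZeroGPU allocation hook: the Space keeps the model on CPU at startup and attaches a GPU only while a decorated call runs. A minimal sketch of the pattern (assumes a ZeroGPU Space with the Hugging Face `spaces` package):

    import spaces

    @spaces.GPU(duration=120)  # request a GPU for up to 120 s per call
    def run_on_gpu(tensor):
        # CUDA is only expected to be available inside this call
        return tensor.to("cuda") * 2
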
@@ -154,13 +186,13 @@ def _launch_demo(args, model, processor):
             print(f"[DEBUG] Device name: {torch.cuda.get_device_name(0)}")
             print(f"[DEBUG] GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
             print(f"[DEBUG] GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
-
-        # Ensure model is on GPU
+
+        # Ensure model is on correct device
         model_device = next(model.parameters()).device
         print(f"[DEBUG] Model device: {model_device}")
         print(f"[DEBUG] Model dtype: {next(model.parameters()).dtype}")
 
-        if str(model_device) == 'cpu':
+        if str(model_device) == "cpu":
             print(f"[ERROR] Model on CPU! Attempting to move to GPU...")
             if torch.cuda.is_available():
                 move_start = time.time()
@@ -170,12 +202,10 @@ def _launch_demo(args, model, processor):
                 print(f"[DEBUG] Model moved to GPU in: {move_time:.2f}s")
             else:
                 print(f"[CRITICAL] CUDA unavailable! Running on CPU will be slow!")
-                print(f"[CRITICAL] This may be due to ZeroGPU resource constraints")
-        else:
-            print(f"[INFO] Model already on GPU: {model_device}")
-
+
+        # Hunyuan expects a list of conversations → wrap once
        messages = [messages]
-
+
         # Build input using processor
         texts = [
             processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
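
The device handling that survives this hunk reduces to a defensive recovery pattern; a condensed sketch (the actual move and its timing logs sit in the unchanged lines between hunks):

    model_device = next(model.parameters()).device
    if model_device.type == "cpu" and torch.cuda.is_available():
        model = model.to("cuda")  # recover if ZeroGPU handed us a CPU-resident model
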
@@ -186,14 +216,6 @@ def _launch_demo(args, model, processor):
         image_inputs, video_inputs = process_vision_info(messages)
         print(f"[DEBUG] Image processing done, elapsed: {time.time() - start_time:.2f}s")
 
-        # Check image input size
-        if image_inputs:
-            for idx, img in enumerate(image_inputs):
-                if hasattr(img, 'size'):
-                    print(f"[DEBUG] Image {idx} size: {img.size}")
-                elif isinstance(img, np.ndarray):
-                    print(f"[DEBUG] Image {idx} shape: {img.shape}")
-
         print(f"[DEBUG] Starting processor encoding...")
         processor_start = time.time()
         inputs = processor(
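
The `processor(...)` arguments themselves are unchanged and sit outside the hunk. Based on the Qwen-VL-style API the surrounding code uses (`apply_chat_template` plus `process_vision_info`), the call presumably has the shape below; the keyword names are an assumption, not part of this diff:

    inputs = processor(
        text=texts,           # chat-templated prompt strings (assumed)
        images=image_inputs,  # extracted by process_vision_info (assumed)
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
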
@@ -205,239 +227,204 @@ def _launch_demo(args, model, processor):
         )
         print(f"[DEBUG] Processor encoding done, elapsed: {time.time() - processor_start:.2f}s")
 
-        # Ensure inputs on GPU
+        # Move to device
         to_device_start = time.time()
-        inputs = inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
-        print(f"[DEBUG] Inputs moved to device, elapsed: {time.time() - to_device_start:.2f}s")
-        print(f"[DEBUG] Input preparation done, total elapsed: {time.time() - start_time:.2f}s")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        inputs = inputs.to(device)
+        print(f"[DEBUG] Inputs moved to {device}, elapsed: {time.time() - to_device_start:.2f}s")
         print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
-        print(f"[DEBUG] Input device: {inputs.input_ids.device}")
-        print(f"[DEBUG] Input sequence length: {inputs.input_ids.shape[1]}")
-
+
         # Generation
         gen_start = time.time()
-        print(f"[DEBUG] ========== Starting token generation ==========")
-
-        # Optimized max_new_tokens for OCR tasks
         max_new_tokens = 2048
-        print(f"[DEBUG] max_new_tokens: {max_new_tokens}")
+        print(f"[DEBUG] ========== Starting token generation (max_new_tokens={max_new_tokens}) ==========")
 
-        # Progress callback
-        token_count = [0]
-        last_time = [gen_start]
-
-        def progress_callback(input_ids, scores, **kwargs):
-            token_count[0] += 1
-            current_time = time.time()
-            if token_count[0] % 10 == 0 or (current_time - last_time[0]) > 2.0:
-                elapsed = current_time - gen_start
-                tokens_per_sec = token_count[0] / elapsed if elapsed > 0 else 0
-                print(f"[DEBUG] Generated {token_count[0]} tokens, speed: {tokens_per_sec:.2f} tokens/s, elapsed: {elapsed:.2f}s")
-                last_time[0] = current_time
-            return False
-
         with torch.no_grad():
-            print(f"[DEBUG] Entered torch.no_grad() context, elapsed: {time.time() - start_time:.2f}s")
-
-            # Test forward pass
             print(f"[DEBUG] Testing forward pass...")
             forward_test_start = time.time()
             try:
-                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                    test_outputs = model(**inputs, use_cache=False)
+                if device == "cuda":
+                    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                        _ = model(**inputs, use_cache=False)
+                else:
+                    _ = model(**inputs, use_cache=False)
                 print(f"[DEBUG] Forward pass test successful, elapsed: {time.time() - forward_test_start:.2f}s")
             except Exception as e:
                 print(f"[WARNING] Forward pass test failed: {e}")
 
             print(f"[DEBUG] Starting model.generate()... (elapsed: {time.time() - start_time:.2f}s)")
             generate_call_start = time.time()
-
-            try:
-                # Deterministic generation
-                generated_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    temperature=0
-                )
-                print(f"[DEBUG] model.generate() returned, elapsed: {time.time() - generate_call_start:.2f}s")
-            except Exception as e:
-                print(f"[ERROR] Generation failed: {e}")
-                import traceback
-                traceback.print_exc()
-                raise
-
-            print(f"[DEBUG] Exited torch.no_grad() context")
+            generated_ids = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                temperature=0,
+            )
+            print(f"[DEBUG] model.generate() returned, elapsed: {time.time() - generate_call_start:.2f}s")
 
         gen_time = time.time() - gen_start
-        print(f"[DEBUG] ========== Generation complete ==========")
         print(f"[DEBUG] Generation time: {gen_time:.2f}s")
         print(f"[DEBUG] Output shape: {generated_ids.shape}")
 
-        # Decode output
-        if "input_ids" in inputs:
-            input_ids = inputs.input_ids
-        else:
-            input_ids = inputs.inputs
-
+        # Decode
+        input_ids = inputs.input_ids
         generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
+            out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
         ]
-
+
         actual_tokens = len(generated_ids_trimmed[0])
         print(f"[DEBUG] Actual tokens generated: {actual_tokens}")
-        print(f"[DEBUG] Time per token: {gen_time/actual_tokens if actual_tokens > 0 else 0:.3f}s")
 
         output_texts = processor.batch_decode(
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-
-
+
         total_time = time.time() - start_time
-        print(f"[DEBUG] ========== All done ==========")
         print(f"[DEBUG] Total time: {total_time:.2f}s")
-        print(f"[DEBUG] Output length: {len(output_texts[0])} chars")
         print(f"[DEBUG] Output preview: {output_texts[0][:100]}...")
         output_texts[0] = clean_repeated_substrings(output_texts[0])
         return output_texts
-
 
+    # =========================
+    # Chat logic
+    # =========================
     def create_predict_fn():
-
-        def predict(_chatbot, task_history):
+        def predict(chatbot_value, task_history):
             nonlocal model, processor
-            chat_query = _chatbot[-1][0]
+
+            if not task_history:
+                return chatbot_value, task_history
+
             query = task_history[-1][0]
-            if len(chat_query) == 0:
-                _chatbot.pop()
-                task_history.pop()
-                return _chatbot
-            print('User: ', query)
+            print("User:", query)
             history_cp = copy.deepcopy(task_history)
-            full_response = ''
+
+            # Build messages for Hunyuan
             messages = []
             content = []
             for q, a in history_cp:
                 if isinstance(q, (tuple, list)):
-                    # Check if URL or local path
                     img_path = q[0]
-                    if img_path.startswith(('http://', 'https://')):
-                        content.append({'type': 'image', 'image': img_path})
+                    if img_path.startswith(("http://", "https://")):
+                        content.append({"type": "image", "image": img_path})
                     else:
-                        content.append({'type': 'image', 'image': f'{os.path.abspath(img_path)}'})
+                        content.append({"type": "image", "image": os.path.abspath(img_path)})
                 else:
-                    content.append({'type': 'text', 'text': q})
-                messages.append({'role': 'user', 'content': content})
-                messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': a}]})
-                content = []
-            messages.pop()
-
-            # Call model to get response
+                    content.append({"type": "text", "text": q})
+
+                messages.append({"role": "user", "content": content})
+                content = []
+
+                if a is not None:
+                    messages.append(
+                        {
+                            "role": "assistant",
+                            "content": [{"type": "text", "text": a}],
+                        }
+                    )
+
+            if messages and messages[-1]["role"] == "assistant" and history_cp[-1][1] is None:
+                messages.pop()
+
             response_list = call_local_model(messages)
             response = response_list[0] if response_list else ""
-
-            _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
             full_response = _parse_text(response)
 
             task_history[-1] = (query, full_response)
-            print('HunyuanOCR: ' + _parse_text(full_response))
-            yield _chatbot
+            print("HunyuanOCR:", full_response)
+
+            chatbot_messages = build_chatbot_messages(task_history)
+            return chatbot_messages, task_history
 
         return predict
-
-    def create_regenerate_fn():
 
-        def regenerate(_chatbot, task_history):
-            nonlocal model, processor
-            if not task_history:
-                return _chatbot
-            item = task_history[-1]
-            if item[1] is None:
-                return _chatbot
-            task_history[-1] = (item[0], None)
-            chatbot_item = _chatbot.pop(-1)
-            if chatbot_item[0] is None:
-                _chatbot[-1] = (_chatbot[-1][0], None)
-            else:
-                _chatbot.append((chatbot_item[0], None))
-            # Use outer predict function
-            _chatbot_gen = predict(_chatbot, task_history)
-            for _chatbot in _chatbot_gen:
-                yield _chatbot
+    def create_regenerate_fn():
+        def regenerate(chatbot_value, task_history):
+            # No-op regenerate for now
+            return chatbot_value, task_history
 
         return regenerate
 
     predict = create_predict_fn()
     regenerate = create_regenerate_fn()
 
-    def add_text(history, task_history, text):
+    def add_text(chatbot_value, task_history, text):
         task_text = text
-        history = history if history is not None else []
         task_history = task_history if task_history is not None else []
-        history = history + [(_parse_text(text), None)]
         task_history = task_history + [(task_text, None)]
-        return history, task_history, ''
+        chatbot_messages = build_chatbot_messages(task_history)
+        return chatbot_messages, task_history, ""
 
-    def add_file(history, task_history, file):
-        history = history if history is not None else []
+    def add_file(chatbot_value, task_history, file):
         task_history = task_history if task_history is not None else []
-        history = history + [((file.name,), None)]
         task_history = task_history + [((file.name,), None)]
-        return history, task_history
-
-    def download_url_image(url):
-        """Download URL image to local temp file"""
-        try:
-            # Use URL hash as filename to avoid duplicate downloads
-            url_hash = hashlib.md5(url.encode()).hexdigest()
-            temp_dir = tempfile.gettempdir()
-            temp_path = os.path.join(temp_dir, f"hyocr_demo_{url_hash}.png")
-
-            # Return cached file if exists
-            if os.path.exists(temp_path):
-                return temp_path
-
-            # Download image
-            response = requests.get(url, timeout=10)
-            response.raise_for_status()
-            with open(temp_path, 'wb') as f:
-                f.write(response.content)
-            return temp_path
-        except Exception as e:
-            print(f"Failed to download image: {url}, error: {e}")
-            return url  # Return original URL on failure
+        chatbot_messages = build_chatbot_messages(task_history)
+        return chatbot_messages, task_history
 
     def reset_user_input():
-        return gr.update(value='')
+        return gr.update(value="")
 
-    def reset_state(_chatbot, task_history):
-        task_history.clear()
-        _chatbot.clear()
+    def reset_state(chatbot_value, task_history):
         _gc()
-        return []
+        return [], []
 
-    # Example image paths - local files
+    # Example image paths
     EXAMPLE_IMAGES = {
         "spotting": "examples/spotting.jpg",
         "parsing": "examples/parsing.jpg",
         "ie": "examples/ie.jpg",
         "vqa": "examples/vqa.jpg",
-        "translation": "examples/translation.jpg"
+        "translation": "examples/translation.jpg",
     }
 
+    # Example loaders: they only touch task_history; chatbot is rebuilt via helper
+    def load_example_1(chatbot_value, task_hist):
+        prompt = "Detect and recognize all text in this image. Output the text with bounding box coordinates."
+        task_hist = [((EXAMPLE_IMAGES["spotting"],), None)]
+        chatbot_messages = build_chatbot_messages(task_hist)
+        return chatbot_messages, task_hist, prompt
+
+    def load_example_2(chatbot_value, task_hist):
+        prompt = (
+            "Extract all text from this document in markdown format. Use HTML for tables "
+            "and LaTeX for equations. Parse in reading order."
+        )
+        task_hist = [((EXAMPLE_IMAGES["parsing"],), None)]
+        chatbot_messages = build_chatbot_messages(task_hist)
+        return chatbot_messages, task_hist, prompt
+
+    def load_example_3(chatbot_value, task_hist):
+        prompt = "Extract the following fields from this receipt and return as JSON: ['total', 'subtotal', 'tax', 'date', 'items']"
+        task_hist = [((EXAMPLE_IMAGES["ie"],), None)]
+        chatbot_messages = build_chatbot_messages(task_hist)
+        return chatbot_messages, task_hist, prompt
+
+    def load_example_4(chatbot_value, task_hist):
+        prompt = "Look at this chart and answer: Which quarter had the highest revenue? What was the Sales value in Q4?"
+        task_hist = [((EXAMPLE_IMAGES["vqa"],), None)]
+        chatbot_messages = build_chatbot_messages(task_hist)
+        return chatbot_messages, task_hist, prompt
+
+    def load_example_5(chatbot_value, task_hist):
+        prompt = "Translate all text in this image to English."
+        task_hist = [((EXAMPLE_IMAGES["translation"],), None)]
+        chatbot_messages = build_chatbot_messages(task_hist)
+        return chatbot_messages, task_hist, prompt
+
+    # =========================
+    # UI
+    # =========================
     with gr.Blocks() as demo:
-        # Header
         gr.Markdown("# HunyuanOCR\n*Powered by Tencent Hunyuan Team*")
-
+
         with gr.Column():
-            # Chat area
             chatbot = gr.Chatbot(
                 label="Chat",
                 height=600,
-                type="tuple",  # 👈 tell Gradio we are using (user, bot) tuples
+                # DO NOT PASS type=... here this env doesn't support it
            )
+            task_history = gr.State([])
 
-            # Input panel
             with gr.Group():
                 query = gr.Textbox(
                     lines=2,
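
One detail in the hunk above: `model.generate` returns the prompt and the completion concatenated, so the decode step slices the prompt off each row before `batch_decode`. An equivalent vectorized sketch (same effect as the per-row comprehension when prompts are padded to one length):

    prompt_len = inputs.input_ids.shape[1]
    completions = generated_ids[:, prompt_len:]  # drop the echoed prompt tokens
    text = processor.batch_decode(completions, skip_special_tokens=True)[0]
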
@@ -454,75 +441,78 @@ def _launch_demo(args, model, processor):
                 submit_btn = gr.Button("Send", variant="primary", scale=3)
                 regen_btn = gr.Button("Regenerate")
                 empty_bin = gr.Button("Clear")
-
-            # Examples section
-            gr.Markdown("### Quick Examples - Click to load")
 
+            gr.Markdown("### Quick Examples - Click to load")
             with gr.Row():
                 example_1_btn = gr.Button("Text Detection")
                 example_2_btn = gr.Button("Document Parsing")
                 example_3_btn = gr.Button("Info Extraction")
                 example_4_btn = gr.Button("Visual Q&A")
                 example_5_btn = gr.Button("Translation")
-
-            task_history = gr.State([])
-
-
-            # Example 1: Text Detection
-            def load_example_1(history, task_hist):
-                prompt = "Detect and recognize all text in this image. Output the text with bounding box coordinates."
-                image_path = EXAMPLE_IMAGES["spotting"]
-                history = [((image_path,), None)]
-                task_hist = [((image_path,), None)]
-                return history, task_hist, prompt
-
-            # Example 2: Document Parsing
-            def load_example_2(history, task_hist):
-                prompt = "Extract all text from this document in markdown format. Use HTML for tables and LaTeX for equations. Parse in reading order."
-                image_path = EXAMPLE_IMAGES["parsing"]
-                history = [((image_path,), None)]
-                task_hist = [((image_path,), None)]
-                return history, task_hist, prompt
-
-            # Example 3: Information Extraction
-            def load_example_3(history, task_hist):
-                prompt = "Extract the following fields from this receipt and return as JSON: ['total', 'subtotal', 'tax', 'date', 'items']"
-                image_path = EXAMPLE_IMAGES["ie"]
-                history = [((image_path,), None)]
-                task_hist = [((image_path,), None)]
-                return history, task_hist, prompt
-
-            # Example 4: Visual Q&A
-            def load_example_4(history, task_hist):
-                prompt = "Look at this chart and answer: Which quarter had the highest revenue? What was the Sales value in Q4?"
-                image_path = EXAMPLE_IMAGES["vqa"]
-                history = [((image_path,), None)]
-                task_hist = [((image_path,), None)]
-                return history, task_hist, prompt
-
-            # Example 5: Translation
-            def load_example_5(history, task_hist):
-                prompt = "Translate all text in this image to English."
-                image_path = EXAMPLE_IMAGES["translation"]
-                history = [((image_path,), None)]
-                task_hist = [((image_path,), None)]
-                return history, task_hist, prompt
-
-            # Bind events
-            example_1_btn.click(load_example_1, [chatbot, task_history], [chatbot, task_history, query])
-            example_2_btn.click(load_example_2, [chatbot, task_history], [chatbot, task_history, query])
-            example_3_btn.click(load_example_3, [chatbot, task_history], [chatbot, task_history, query])
-            example_4_btn.click(load_example_4, [chatbot, task_history], [chatbot, task_history, query])
-            example_5_btn.click(load_example_5, [chatbot, task_history], [chatbot, task_history, query])
-
-            submit_btn.click(add_text, [chatbot, task_history, query],
-                             [chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
+
+        # Example bindings
+        example_1_btn.click(
+            load_example_1,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history, query],
+        )
+        example_2_btn.click(
+            load_example_2,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history, query],
+        )
+        example_3_btn.click(
+            load_example_3,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history, query],
+        )
+        example_4_btn.click(
+            load_example_4,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history, query],
+        )
+        example_5_btn.click(
+            load_example_5,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history, query],
+        )
+
+        # Main flow
+        submit_btn.click(
+            add_text,
+            inputs=[chatbot, task_history, query],
+            outputs=[chatbot, task_history, query],
+        ).then(
+            predict,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history],
+            show_progress=True,
+        )
         submit_btn.click(reset_user_input, [], [query])
-        empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
-        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
-        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
 
-        # Feature descriptions
+        empty_bin.click(
+            reset_state,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history],
+            show_progress=True,
+        )
+
+        regen_btn.click(
+            regenerate,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history],
+            show_progress=True,
+        )
+
+        # Upload: pass only chatbot + state; file comes as extra arg
+        addfile_btn.upload(
+            add_file,
+            inputs=[chatbot, task_history],
+            outputs=[chatbot, task_history],
+            show_progress=True,
+        )
+
+        # Descriptive section (unchanged)
         with gr.Row():
             with gr.Column(scale=1):
                 gr.Markdown("""
@@ -543,17 +533,13 @@ def _launch_demo(args, model, processor):
                 - **Use Cases** - OCR, document digitization, receipt recognition, translation
                 """)
 
-        # Footer
         gr.Markdown("---\n*2025 Tencent Hunyuan Team. For research and educational use.*")
 
         demo.queue().launch(
             share=args.share,
             inbrowser=args.inbrowser,
-            # server_port=args.server_port,
-            # server_name=args.server_name,
         )
 
-
 def main():
     args = _get_args()
     model, processor = _load_model_processor(args)