Luis J Camargo committed on
Commit
610efd0
·
1 Parent(s): b7745a8

feat: Implement streaming OCR results by converting `run_inference` to a generator and updating UI wrappers to consume partial outputs.

Browse files
Files changed (1) hide show
  1. app.py +47 -31
app.py CHANGED
@@ -173,20 +173,22 @@ def _escape_inequalities_in_math(md: str) -> str:
173
 
174
  def run_inference(img_path, task_type="ocr", progress=gr.Progress()):
175
  if not PADDLE_AVAILABLE:
176
- return "❌ Paddle backend not installed.", "", "", ""
 
177
 
178
  if pipeline is None:
179
- return "❌ Pipeline is not initialized. Check server logs for error details.", "", "", ""
 
180
 
181
  if not img_path:
182
- return "⚠️ No image provided.", "", "", ""
 
183
 
184
  try:
185
  logger.info(f"--- Inference Start: {task_type} ---")
186
  progress(0, desc="📦 Initializing inference engine...")
187
  output = pipeline.predict(input=img_path)
188
  logger.info(f"Output object type: {type(output)}")
189
- logger.info(f"Output object: {output}")
190
 
191
  md_content = ""
192
  json_content = ""
@@ -196,52 +198,60 @@ def run_inference(img_path, task_type="ocr", progress=gr.Progress()):
196
  run_output_dir = os.path.join(OUTPUT_DIR, run_id)
197
  os.makedirs(run_output_dir, exist_ok=True)
198
 
199
- logger.info(f"will iterate")
200
- progress(0.2, desc="🔍 Parsing document structure...")
201
 
202
  for i, res in enumerate(output):
203
  logger.info(f"Processing segment {i+1}...")
204
- progress((i + 1) / 5, desc=f"✍️ Recognizing content (segment {i+1})...")
 
 
205
 
206
- # Save results
207
- res.save_to_json(save_path=run_output_dir)
208
- res.save_to_markdown(save_path=run_output_dir)
 
 
209
  res.print()
210
 
211
- # Read back generated files
212
- fnames = os.listdir(run_output_dir)
213
  for fname in fnames:
214
- fpath = os.path.join(run_output_dir, fname)
215
  if fname.endswith(".md"):
216
  with open(fpath, 'r', encoding='utf-8') as f:
217
  content = f.read()
218
- if content not in md_content:
219
- md_content += content + "\n\n"
220
  elif fname.endswith(".json"):
221
  with open(fpath, 'r', encoding='utf-8') as f:
222
  content = f.read()
223
- if content not in json_content:
224
- json_content += content + "\n\n"
225
  elif fname.endswith((".png", ".jpg", ".jpeg")) and ("res" in fname or "vis" in fname):
226
  vis_src = image_to_base64_data_url(fpath)
227
  vis_html += f'<div style="margin-bottom:20px; border: 2px solid #10b981; border-radius: 12px; overflow: hidden; background:white;">'
228
  vis_html += f'<img src="{vis_src}" alt="Vis {i+1}" style="width:100%;">'
229
  vis_html += f'</div>'
230
 
 
 
 
231
  logger.info(f"Finished processing segment {i+1}")
232
 
233
  if not md_content:
234
  md_content = "⚠️ Finished but no content was recognized."
 
 
 
 
 
235
 
236
- md_preview = _escape_inequalities_in_math(md_content)
237
  logger.info("--- Inference Finished Successfully ---")
238
- progress(1.0, desc="✅ Recovery complete")
239
- return md_preview, md_content, vis_html, json_content
240
 
241
  except Exception as e:
242
  logger.error(f"❌ Inference Error: {e}")
243
  logger.error(traceback.format_exc())
244
- return f"❌ Error: {str(e)}", "", "", ""
 
245
 
246
  # --- UI Components ---
247
 
@@ -281,7 +291,7 @@ with gr.Blocks() as demo:
281
  with gr.Row():
282
  with gr.Column(scale=5):
283
  file_doc = gr.Image(label="Upload Image", type="filepath")
284
- btn_parse = gr.Button("οΏ½ Start Parsing", variant="primary")
285
  with gr.Row():
286
  chart_switch = gr.Checkbox(label="Chart OCR", value=True)
287
  unwarp_switch = gr.Checkbox(label="Unwarping", value=False)
@@ -296,9 +306,11 @@ with gr.Blocks() as demo:
296
  md_raw_doc = gr.Code(language="markdown")
297
 
298
  def parse_doc_wrapper(fp, ch, uw):
299
- if not fp: return "⚠️ Please upload an image.", "", ""
300
- res_preview, res_raw, res_vis, res_json = run_inference(fp, task_type="Document")
301
- return res_preview, res_vis, res_raw
 
 
302
 
303
  btn_parse.click(
304
  parse_doc_wrapper,
@@ -325,9 +337,11 @@ with gr.Blocks() as demo:
325
  md_raw_vl = gr.Code(language="markdown")
326
 
327
  def run_vl_wrapper(fp, prompt):
328
- if not fp: return "⚠️ Please upload an image.", ""
329
- res_preview, res_raw, _, _ = run_inference(fp, task_type=prompt)
330
- return res_preview, res_raw
 
 
331
 
332
  for btn, prompt in [(btn_ocr, "Text"), (btn_formula, "Formula"), (btn_table, "Table")]:
333
  btn.click(
@@ -352,9 +366,11 @@ with gr.Blocks() as demo:
352
  json_spot = gr.Code(label="JSON", language="json")
353
 
354
  def run_spotting_wrapper(fp):
355
- if not fp: return "", ""
356
- _, _, vis, js = run_inference(fp, task_type="Spotting")
357
- return vis, js
 
 
358
 
359
  btn_run_spot.click(
360
  run_spotting_wrapper,
 
173
 
174
  def run_inference(img_path, task_type="ocr", progress=gr.Progress()):
175
  if not PADDLE_AVAILABLE:
176
+ yield "❌ Paddle backend not installed.", "", "", ""
177
+ return
178
 
179
  if pipeline is None:
180
+ yield "❌ Pipeline is not initialized. Check server logs for error details.", "", "", ""
181
+ return
182
 
183
  if not img_path:
184
+ yield "⚠️ No image provided.", "", "", ""
185
+ return
186
 
187
  try:
188
  logger.info(f"--- Inference Start: {task_type} ---")
189
  progress(0, desc="📦 Initializing inference engine...")
190
  output = pipeline.predict(input=img_path)
191
  logger.info(f"Output object type: {type(output)}")
 
192
 
193
  md_content = ""
194
  json_content = ""
 
198
  run_output_dir = os.path.join(OUTPUT_DIR, run_id)
199
  os.makedirs(run_output_dir, exist_ok=True)
200
 
201
+ logger.info(f"Inference generator ready. Starting iteration...")
202
+ progress(0.1, desc="🔍 Document preprocessing...")
203
 
204
  for i, res in enumerate(output):
205
  logger.info(f"Processing segment {i+1}...")
206
+ # Use dynamic progress increment
207
+ p_val = min(0.1 + (i + 1) * 0.15, 0.95)
208
+ progress(p_val, desc=f"✍️ Recognizing content (segment {i+1})...")
209
 
210
+ # Save results to unique dir
211
+ seg_dir = os.path.join(run_output_dir, f"seg_{i}")
212
+ os.makedirs(seg_dir, exist_ok=True)
213
+ res.save_to_json(save_path=seg_dir)
214
+ res.save_to_markdown(save_path=seg_dir)
215
  res.print()
216
 
217
+ # Gather files specifically from this segment
218
+ fnames = os.listdir(seg_dir)
219
  for fname in fnames:
220
+ fpath = os.path.join(seg_dir, fname)
221
  if fname.endswith(".md"):
222
  with open(fpath, 'r', encoding='utf-8') as f:
223
  content = f.read()
224
+ md_content += content + "\n\n"
 
225
  elif fname.endswith(".json"):
226
  with open(fpath, 'r', encoding='utf-8') as f:
227
  content = f.read()
228
+ json_content += content + "\n\n"
 
229
  elif fname.endswith((".png", ".jpg", ".jpeg")) and ("res" in fname or "vis" in fname):
230
  vis_src = image_to_base64_data_url(fpath)
231
  vis_html += f'<div style="margin-bottom:20px; border: 2px solid #10b981; border-radius: 12px; overflow: hidden; background:white;">'
232
  vis_html += f'<img src="{vis_src}" alt="Vis {i+1}" style="width:100%;">'
233
  vis_html += f'</div>'
234
 
235
+ # Yield partial results to keep UI alive
236
+ partial_md = _escape_inequalities_in_math(md_content)
237
+ yield partial_md, md_content, vis_html, json_content
238
  logger.info(f"Finished processing segment {i+1}")
239
 
240
  if not md_content:
241
  md_content = "⚠️ Finished but no content was recognized."
242
+ yield md_content, md_content, "", ""
243
+ else:
244
+ final_md = _escape_inequalities_in_math(md_content)
245
+ progress(1.0, desc="✅ Complete")
246
+ yield final_md, md_content, vis_html, json_content
247
 
 
248
  logger.info("--- Inference Finished Successfully ---")
 
 
249
 
250
  except Exception as e:
251
  logger.error(f"❌ Inference Error: {e}")
252
  logger.error(traceback.format_exc())
253
+ yield f"❌ Error: {str(e)}", "", "", ""
254
+ return
255
 
256
  # --- UI Components ---
257
 
 
291
  with gr.Row():
292
  with gr.Column(scale=5):
293
  file_doc = gr.Image(label="Upload Image", type="filepath")
294
+ btn_parse = gr.Button("🔍 Start Parsing", variant="primary")
295
  with gr.Row():
296
  chart_switch = gr.Checkbox(label="Chart OCR", value=True)
297
  unwarp_switch = gr.Checkbox(label="Unwarping", value=False)
 
306
  md_raw_doc = gr.Code(language="markdown")
307
 
308
  def parse_doc_wrapper(fp, ch, uw):
309
+ if not fp:
310
+ yield "⚠️ Please upload an image.", "", ""
311
+ return
312
+ for res_preview, res_raw, res_vis, res_json in run_inference(fp, task_type="Document"):
313
+ yield res_preview, res_vis, res_raw
314
 
315
  btn_parse.click(
316
  parse_doc_wrapper,
 
337
  md_raw_vl = gr.Code(language="markdown")
338
 
339
  def run_vl_wrapper(fp, prompt):
340
+ if not fp:
341
+ yield "⚠️ Please upload an image.", ""
342
+ return
343
+ for res_preview, res_raw, _, _ in run_inference(fp, task_type=prompt):
344
+ yield res_preview, res_raw
345
 
346
  for btn, prompt in [(btn_ocr, "Text"), (btn_formula, "Formula"), (btn_table, "Table")]:
347
  btn.click(
 
366
  json_spot = gr.Code(label="JSON", language="json")
367
 
368
  def run_spotting_wrapper(fp):
369
+ if not fp:
370
+ yield "", ""
371
+ return
372
+ for _, _, vis, js in run_inference(fp, task_type="Spotting"):
373
+ yield vis, js
374
 
375
  btn_run_spot.click(
376
  run_spotting_wrapper,