emelryan commited on
Commit
dd40da5
·
1 Parent(s): b5f69e6

batched pipeline

Browse files
app.py CHANGED
@@ -26,7 +26,6 @@ import numpy as np
26
  from huggingface_hub import hf_hub_download
27
  from PIL import Image, ImageDraw
28
 
29
- from nemotron_ocr.inference.pipeline import NemotronOCR
30
  from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2
31
 
32
  MODELS = {
@@ -35,9 +34,7 @@ MODELS = {
35
  "v1 (legacy, English-only)": "v1",
36
  }
37
 
38
- PIPELINE_CHOICES = ["v2 (batched)", "v1 (original)"]
39
-
40
- _pipelines: dict[str, object] = {}
41
 
42
  GROUP_COLORS = [
43
  (76, 175, 80),
@@ -86,23 +83,14 @@ def _ensure_v1_model_dir() -> str:
86
  return model_dir
87
 
88
 
89
- def _get_pipeline(lang_key: str, pipeline_type: str):
90
- cache_key = f"{lang_key}::{pipeline_type}"
91
- if cache_key not in _pipelines:
92
- use_v1_pipeline = pipeline_type == "v1 (original)"
93
-
94
  if lang_key in ("v1", "legacy"):
95
  model_dir = _ensure_v1_model_dir()
96
- if use_v1_pipeline:
97
- _pipelines[cache_key] = NemotronOCR(model_dir=model_dir)
98
- else:
99
- _pipelines[cache_key] = NemotronOCRV2(model_dir=model_dir)
100
  else:
101
- if use_v1_pipeline:
102
- _pipelines[cache_key] = NemotronOCR(lang=lang_key)
103
- else:
104
- _pipelines[cache_key] = NemotronOCRV2(lang=lang_key)
105
- return _pipelines[cache_key]
106
 
107
 
108
  def draw_boxes(image: Image.Image, predictions: list[dict]) -> Image.Image:
@@ -227,12 +215,12 @@ def format_text(predictions: list[dict], merge_level: str) -> str:
227
 
228
 
229
  @spaces.GPU(duration=120)
230
- def run_ocr(image: Image.Image, model_name: str, merge_level: str, pipeline_type: str):
231
  if image is None:
232
  return None, "Please upload an image."
233
 
234
  lang_key = MODELS[model_name]
235
- ocr = _get_pipeline(lang_key, pipeline_type)
236
  img_array = np.array(image.convert("RGB"))
237
 
238
  if merge_level == "layout":
@@ -294,11 +282,6 @@ with gr.Blocks(
294
  value="layout",
295
  label="Output Mode",
296
  )
297
- pipeline_type = gr.Radio(
298
- choices=PIPELINE_CHOICES,
299
- value="v2 (batched)",
300
- label="Pipeline",
301
- )
302
  run_btn = gr.Button("Run OCR", variant="primary")
303
 
304
  with gr.Column(scale=1):
@@ -311,7 +294,7 @@ with gr.Blocks(
311
 
312
  run_btn.click(
313
  fn=run_ocr,
314
- inputs=[input_image, model_choice, merge_level, pipeline_type],
315
  outputs=[output_image, output_text],
316
  )
317
 
 
26
  from huggingface_hub import hf_hub_download
27
  from PIL import Image, ImageDraw
28
 
 
29
  from nemotron_ocr.inference.pipeline_v2 import NemotronOCRV2
30
 
31
  MODELS = {
 
34
  "v1 (legacy, English-only)": "v1",
35
  }
36
 
37
+ _pipelines: dict[str, NemotronOCRV2] = {}
 
 
38
 
39
  GROUP_COLORS = [
40
  (76, 175, 80),
 
83
  return model_dir
84
 
85
 
86
+ def _get_pipeline(lang_key: str) -> NemotronOCRV2:
87
+ if lang_key not in _pipelines:
 
 
 
88
  if lang_key in ("v1", "legacy"):
89
  model_dir = _ensure_v1_model_dir()
90
+ _pipelines[lang_key] = NemotronOCRV2(model_dir=model_dir)
 
 
 
91
  else:
92
+ _pipelines[lang_key] = NemotronOCRV2(lang=lang_key)
93
+ return _pipelines[lang_key]
 
 
 
94
 
95
 
96
  def draw_boxes(image: Image.Image, predictions: list[dict]) -> Image.Image:
 
215
 
216
 
217
  @spaces.GPU(duration=120)
218
+ def run_ocr(image: Image.Image, model_name: str, merge_level: str):
219
  if image is None:
220
  return None, "Please upload an image."
221
 
222
  lang_key = MODELS[model_name]
223
+ ocr = _get_pipeline(lang_key)
224
  img_array = np.array(image.convert("RGB"))
225
 
226
  if merge_level == "layout":
 
282
  value="layout",
283
  label="Output Mode",
284
  )
 
 
 
 
 
285
  run_btn = gr.Button("Run OCR", variant="primary")
286
 
287
  with gr.Column(scale=1):
 
294
 
295
  run_btn.click(
296
  fn=run_ocr,
297
+ inputs=[input_image, model_choice, merge_level],
298
  outputs=[output_image, output_text],
299
  )
300
 
nemotron_ocr-1.0.0-cp312-cp312-linux_x86_64.whl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0ce2c3c3a382fdf90a2c9c5147fb91f9d6ccc516b312ae659ee42c97f99579ce
3
- size 45945424
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc42583de574879c5c127d4e76398f04254bc5e3db651c3df0ccc942a2b48fa2
3
+ size 45944873