prithivMLmods commited on
Commit
8aa52e7
·
verified ·
1 Parent(s): 6b0450c
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -14,7 +14,6 @@ from transformers import (
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
  )
17
-
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
20
 
@@ -101,7 +100,7 @@ if not os.path.exists(CACHE_PATH):
101
  # Download the model files locally
102
  model_path_d_local = snapshot_download(
103
  repo_id='rednote-hilab/dots.ocr',
104
- local_dir=CACHE_PATH,
105
  max_workers=20,
106
  local_dir_use_symlinks=False
107
  )
@@ -160,6 +159,15 @@ model_d = AutoModelForCausalLM.from_pretrained(
160
  trust_remote_code=True
161
  ).eval()
162
 
 
 
 
 
 
 
 
 
 
163
 
164
  @spaces.GPU
165
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -173,6 +181,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
173
  processor, model = processor_m, model_m
174
  elif model_name == "Dots.OCR":
175
  processor, model = processor_d, model_d
 
 
176
  else:
177
  yield "Invalid model selected.", "Invalid model selected."
178
  return
@@ -237,12 +247,12 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
237
 
238
  with gr.Column(scale=3):
239
  gr.Markdown("## Output", elem_id="output-title")
240
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
241
- with gr.Accordion("Formatted Result", open=False):
242
  formatted_output = gr.Markdown(label="Formatted Result")
243
 
244
  model_choice = gr.Radio(
245
- choices=["Nanonets-OCR2-3B", "Dots.OCR"],
246
  label="Select Model",
247
  value="Nanonets-OCR2-3B"
248
  )
 
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
  )
 
17
  from gradio.themes import Soft
18
  from gradio.themes.utils import colors, fonts, sizes
19
 
 
100
  # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
 
159
  trust_remote_code=True
160
  ).eval()
161
 
162
+ # Load PaddleOCR
163
+ MODEL_ID_P = "strangervisionhf/paddle"
164
+ processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
165
+ model_p = AutoModelForCausalLM.from_pretrained(
166
+ MODEL_ID_P,
167
+ trust_remote_code=True,
168
+ torch_dtype=torch.bfloat16
169
+ ).to(device).eval()
170
+
171
 
172
  @spaces.GPU
173
  def generate_image(model_name: str, text: str, image: Image.Image,
 
181
  processor, model = processor_m, model_m
182
  elif model_name == "Dots.OCR":
183
  processor, model = processor_d, model_d
184
+ elif model_name == "PaddleOCR":
185
+ processor, model = processor_p, model_p
186
  else:
187
  yield "Invalid model selected.", "Invalid model selected."
188
  return
 
247
 
248
  with gr.Column(scale=3):
249
  gr.Markdown("## Output", elem_id="output-title")
250
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=13, show_copy_button=True)
251
+ with gr.Accordion("Formatted Result", open=True):
252
  formatted_output = gr.Markdown(label="Formatted Result")
253
 
254
  model_choice = gr.Radio(
255
+ choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
256
  label="Select Model",
257
  value="Nanonets-OCR2-3B"
258
  )