prithivMLmods committed on
Commit
d618111
·
verified ·
1 Parent(s): 5d80be9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -28
app.py CHANGED
@@ -14,6 +14,7 @@ from transformers import (
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
  )
 
17
  from gradio.themes import Soft
18
  from gradio.themes.utils import colors, fonts, sizes
19
 
@@ -100,7 +101,7 @@ if not os.path.exists(CACHE_PATH):
100
  # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
- local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
@@ -159,15 +160,6 @@ model_d = AutoModelForCausalLM.from_pretrained(
159
  trust_remote_code=True
160
  ).eval()
161
 
162
- # Load PaddleOCR
163
- MODEL_ID_P = "strangervisionhf/paddle"
164
- processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
165
- model_p = AutoModelForCausalLM.from_pretrained(
166
- MODEL_ID_P,
167
- trust_remote_code=True,
168
- torch_dtype=torch.bfloat16
169
- ).to(device).eval()
170
-
171
 
172
  @spaces.GPU
173
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -181,8 +173,6 @@ def generate_image(model_name: str, text: str, image: Image.Image,
181
  processor, model = processor_m, model_m
182
  elif model_name == "Dots.OCR":
183
  processor, model = processor_d, model_d
184
- elif model_name == "PaddleOCR":
185
- processor, model = processor_p, model_p
186
  else:
187
  yield "Invalid model selected.", "Invalid model selected."
188
  return
@@ -193,19 +183,13 @@ def generate_image(model_name: str, text: str, image: Image.Image,
193
 
194
  images = [image.convert("RGB")]
195
 
196
- # Create the prompt based on the specific model's requirements
197
- if model_name == "PaddleOCR":
198
- # PaddleOCR's template expects a single string with an image placeholder
199
- messages = [
200
- {"role": "user", "content": f"<image>\n{text}"}
201
- ]
202
- else:
203
- # Standard format for Nanonets and Dots.OCR
204
- messages = [
205
- {"role": "user", "content": [{"type": "image"}] + [{"type": "text", "text": text}]}
206
- ]
207
-
208
- prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
209
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
210
 
211
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
@@ -253,12 +237,12 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
253
 
254
  with gr.Column(scale=3):
255
  gr.Markdown("## Output", elem_id="output-title")
256
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=13, show_copy_button=True)
257
- with gr.Accordion("Formatted Result", open=True):
258
  formatted_output = gr.Markdown(label="Formatted Result")
259
 
260
  model_choice = gr.Radio(
261
- choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
262
  label="Select Model",
263
  value="Nanonets-OCR2-3B"
264
  )
 
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
  )
17
+
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
20
 
 
101
  # Download the model files locally
102
  model_path_d_local = snapshot_download(
103
  repo_id='rednote-hilab/dots.ocr',
104
+ local_dir=CACHE_PATH,
105
  max_workers=20,
106
  local_dir_use_symlinks=False
107
  )
 
160
  trust_remote_code=True
161
  ).eval()
162
 
 
 
 
 
 
 
 
 
 
163
 
164
  @spaces.GPU
165
  def generate_image(model_name: str, text: str, image: Image.Image,
 
173
  processor, model = processor_m, model_m
174
  elif model_name == "Dots.OCR":
175
  processor, model = processor_d, model_d
 
 
176
  else:
177
  yield "Invalid model selected.", "Invalid model selected."
178
  return
 
183
 
184
  images = [image.convert("RGB")]
185
 
186
+ messages = [
187
+ {
188
+ "role": "user",
189
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
190
+ }
191
+ ]
192
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
 
 
 
 
 
 
193
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
194
 
195
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
 
237
 
238
  with gr.Column(scale=3):
239
  gr.Markdown("## Output", elem_id="output-title")
240
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
241
+ with gr.Accordion("Formatted Result", open=False):
242
  formatted_output = gr.Markdown(label="Formatted Result")
243
 
244
  model_choice = gr.Radio(
245
+ choices=["Nanonets-OCR2-3B", "Dots.OCR"],
246
  label="Select Model",
247
  value="Nanonets-OCR2-3B"
248
  )