prithivMLmods committed on
Commit
858d0e5
·
verified ·
1 Parent(s): 87b573a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -45
app.py CHANGED
@@ -100,7 +100,7 @@ if not os.path.exists(CACHE_PATH):
100
  # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
- local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
@@ -118,7 +118,10 @@ if os.path.exists(config_file_path):
118
  for line in lines:
119
  output_lines.append(line)
120
  if line.strip().startswith("class DotsVLProcessor"):
 
121
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
 
 
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
@@ -156,18 +159,9 @@ model_d = AutoModelForCausalLM.from_pretrained(
156
  trust_remote_code=True
157
  ).eval()
158
 
159
- # Load PaddleOCR
160
- MODEL_ID_P = "strangervisionhf/paddle"
161
- processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
162
- model_p = AutoModelForCausalLM.from_pretrained(
163
- MODEL_ID_P,
164
- trust_remote_code=True,
165
- torch_dtype=torch.bfloat16
166
- ).to(device).eval()
167
-
168
 
169
  @spaces.GPU
170
- def generate_image(model_name: str, text: str, image: Image.Image, task_type: str,
171
  max_new_tokens: int = 1024,
172
  temperature: float = 0.6,
173
  top_p: float = 0.9,
@@ -178,8 +172,6 @@ def generate_image(model_name: str, text: str, image: Image.Image, task_type: st
178
  processor, model = processor_m, model_m
179
  elif model_name == "Dots.OCR":
180
  processor, model = processor_d, model_d
181
- elif model_name == "PaddleOCR":
182
- processor, model = processor_p, model_p
183
  else:
184
  yield "Invalid model selected.", "Invalid model selected."
185
  return
@@ -189,28 +181,15 @@ def generate_image(model_name: str, text: str, image: Image.Image, task_type: st
189
  return
190
 
191
  images = [image.convert("RGB")]
192
-
193
- # --- FIX: Use task-specific prompts for PaddleOCR for structured output ---
194
- if model_name == "PaddleOCR":
195
- task_prompts = {
196
- "General OCR": "Recognize the text in this image.",
197
- "Table Recognition": "Recognize the table in this image.",
198
- "Formula Recognition": "Recognize the formula in this image.",
199
- "Layout Analysis": "Analyze the layout of this document. Return the result in markdown format."
200
  }
201
- # Use the task-specific prompt and ignore the user's free-form text query
202
- prompt_text = task_prompts.get(task_type, "Recognize the text in this image.")
203
- messages = [{"role": "user", "content": prompt_text}]
204
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
205
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
206
- else:
207
- # For other models, use the standard user-provided text query
208
- messages = [
209
- {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}
210
- ]
211
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
212
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
213
- # --- END FIX ---
214
 
215
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
216
  generation_kwargs = {
@@ -262,23 +241,14 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
262
  formatted_output = gr.Markdown(label="Formatted Result")
263
 
264
  model_choice = gr.Radio(
265
- choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
266
  label="Select Model",
267
  value="Nanonets-OCR2-3B"
268
  )
269
-
270
- # --- NEW UI ELEMENT FOR PADDLEOCR ---
271
- task_type_dropdown = gr.Radio(
272
- choices=["General OCR", "Table Recognition", "Formula Recognition", "Layout Analysis"],
273
- label="Select Task for PaddleOCR",
274
- value="General OCR",
275
- info="This selection is used ONLY for the PaddleOCR model to ensure structured output. The 'Query Input' box will be ignored."
276
- )
277
- # --- END NEW UI ELEMENT ---
278
 
279
  image_submit.click(
280
  fn=generate_image,
281
- inputs=[model_choice, image_query, image_upload, task_type_dropdown, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
282
  outputs=[raw_output, formatted_output]
283
  )
284
 
 
100
  # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
+ local_dir=CACHE_PATH,
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
 
118
  for line in lines:
119
  output_lines.append(line)
120
  if line.strip().startswith("class DotsVLProcessor"):
121
+ # Insert the attributes line to specify which processors to load
122
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
123
+
124
+ # Write the modified content back to the file
125
  with open(config_file_path, 'w') as f:
126
  f.write('\n'.join(output_lines))
127
  print("Patched configuration_dots.py successfully.")
 
159
  trust_remote_code=True
160
  ).eval()
161
 
 
 
 
 
 
 
 
 
 
162
 
163
  @spaces.GPU
164
+ def generate_image(model_name: str, text: str, image: Image.Image,
165
  max_new_tokens: int = 1024,
166
  temperature: float = 0.6,
167
  top_p: float = 0.9,
 
172
  processor, model = processor_m, model_m
173
  elif model_name == "Dots.OCR":
174
  processor, model = processor_d, model_d
 
 
175
  else:
176
  yield "Invalid model selected.", "Invalid model selected."
177
  return
 
181
  return
182
 
183
  images = [image.convert("RGB")]
184
+
185
+ messages = [
186
+ {
187
+ "role": "user",
188
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
 
 
 
189
  }
190
+ ]
191
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
192
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
193
 
194
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
195
  generation_kwargs = {
 
241
  formatted_output = gr.Markdown(label="Formatted Result")
242
 
243
  model_choice = gr.Radio(
244
+ choices=["Nanonets-OCR2-3B", "Dots.OCR"],
245
  label="Select Model",
246
  value="Nanonets-OCR2-3B"
247
  )
 
 
 
 
 
 
 
 
 
248
 
249
  image_submit.click(
250
  fn=generate_image,
251
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
252
  outputs=[raw_output, formatted_output]
253
  )
254