prithivMLmods committed on
Commit
87b573a
·
verified ·
1 Parent(s): eebb9c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -49
app.py CHANGED
@@ -91,19 +91,27 @@ css = """
91
  """
92
 
93
  # --- Fix for Dots.OCR Processor Loading ---
 
 
94
  CACHE_PATH = "./model_cache"
95
  if not os.path.exists(CACHE_PATH):
96
  os.makedirs(CACHE_PATH)
 
 
97
  model_path_d_local = snapshot_download(
98
  repo_id='rednote-hilab/dots.ocr',
99
  local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
100
  max_workers=20,
101
  local_dir_use_symlinks=False
102
  )
 
 
103
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
 
104
  if os.path.exists(config_file_path):
105
  with open(config_file_path, 'r') as f:
106
  input_code = f.read()
 
107
  lines = input_code.splitlines()
108
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
109
  output_lines = []
@@ -114,52 +122,58 @@ if os.path.exists(config_file_path):
114
  with open(config_file_path, 'w') as f:
115
  f.write('\n'.join(output_lines))
116
  print("Patched configuration_dots.py successfully.")
 
 
117
  sys.path.append(model_path_d_local)
118
 
119
 
120
  # --- Model Loading ---
 
 
121
  MAX_MAX_NEW_TOKENS = 4096
122
  DEFAULT_MAX_NEW_TOKENS = 2048
 
 
123
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
124
 
125
  # Load Nanonets-OCR2-3B
126
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
127
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
128
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
129
- MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
 
 
130
  ).to(device).eval()
131
 
132
  # Load Dots.OCR from the local, patched directory
133
  MODEL_PATH_D = model_path_d_local
134
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
135
  model_d = AutoModelForCausalLM.from_pretrained(
136
- MODEL_PATH_D, attn_implementation="eager", torch_dtype=torch.bfloat16,
137
- device_map="auto", trust_remote_code=True
 
 
 
138
  ).eval()
139
 
140
  # Load PaddleOCR
141
  MODEL_ID_P = "strangervisionhf/paddle"
142
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
143
  model_p = AutoModelForCausalLM.from_pretrained(
144
- MODEL_ID_P, trust_remote_code=True, torch_dtype=torch.bfloat16
 
 
145
  ).to(device).eval()
146
 
147
 
148
  @spaces.GPU
149
- def generate_image(model_name: str, text: str, paddle_task: str, image: Image.Image,
150
  max_new_tokens: int = 1024,
151
  temperature: float = 0.6,
152
  top_p: float = 0.9,
153
  top_k: int = 50,
154
  repetition_penalty: float = 1.2):
155
  """Generate responses for image input using the selected model."""
156
- PROMPTS = {
157
- "OCR": "OCR:",
158
- "Table Recognition": "Table Recognition:",
159
- "Chart Recognition": "Chart Recognition:",
160
- "Formula Recognition": "Formula Recognition:",
161
- }
162
-
163
  if model_name == "Nanonets-OCR2-3B":
164
  processor, model = processor_m, model_m
165
  elif model_name == "Dots.OCR":
@@ -175,16 +189,22 @@ def generate_image(model_name: str, text: str, paddle_task: str, image: Image.Im
175
  return
176
 
177
  images = [image.convert("RGB")]
178
-
179
- # --- FIX: Handle different prompt formats required by models ---
180
  if model_name == "PaddleOCR":
181
- # PaddleOCR expects specific, predefined prompts for its tasks.
182
- prompt_text = PROMPTS.get(paddle_task, "OCR:")
 
 
 
 
 
 
183
  messages = [{"role": "user", "content": prompt_text}]
184
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
185
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
186
  else:
187
- # Nanonets and Dots.OCR support the modern list format for multimodal content.
188
  messages = [
189
  {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}
190
  ]
@@ -223,17 +243,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
223
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
224
  with gr.Row():
225
  with gr.Column(scale=2):
226
- # General query input, visible by default
227
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...", visible=True)
228
-
229
- # Specific task selector for PaddleOCR, hidden by default
230
- paddle_task = gr.Radio(
231
- label="Select PaddleOCR Task",
232
- choices=["OCR", "Table Recognition", "Chart Recognition", "Formula Recognition"],
233
- value="OCR",
234
- visible=False
235
- )
236
-
237
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
238
  image_submit = gr.Button("Submit", variant="primary")
239
  gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
@@ -256,30 +266,19 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
256
  label="Select Model",
257
  value="Nanonets-OCR2-3B"
258
  )
259
-
260
- # Function to dynamically update the UI based on model selection
261
- def update_ui_for_model(model_name):
262
- if model_name == "PaddleOCR":
263
- return {
264
- image_query: gr.Textbox(visible=False),
265
- paddle_task: gr.Radio(visible=True)
266
- }
267
- else:
268
- return {
269
- image_query: gr.Textbox(visible=True),
270
- paddle_task: gr.Radio(visible=False)
271
- }
272
-
273
- # Attach the function to the model_choice radio button's change event
274
- model_choice.change(
275
- fn=update_ui_for_model,
276
- inputs=model_choice,
277
- outputs=[image_query, paddle_task]
278
- )
279
 
280
  image_submit.click(
281
  fn=generate_image,
282
- inputs=[model_choice, image_query, paddle_task, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
283
  outputs=[raw_output, formatted_output]
284
  )
285
 
 
91
  """
92
 
93
  # --- Fix for Dots.OCR Processor Loading ---
94
+
95
+ # Define a local directory to cache the model
96
  CACHE_PATH = "./model_cache"
97
  if not os.path.exists(CACHE_PATH):
98
  os.makedirs(CACHE_PATH)
99
+
100
+ # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
  local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
107
+
108
+ # Modify the configuration file to fix the processor loading issue
109
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
110
+
111
  if os.path.exists(config_file_path):
112
  with open(config_file_path, 'r') as f:
113
  input_code = f.read()
114
+
115
  lines = input_code.splitlines()
116
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
117
  output_lines = []
 
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
125
+
126
+ # Add the local model path to sys.path so transformers can use the modified code
127
  sys.path.append(model_path_d_local)
128
 
129
 
130
  # --- Model Loading ---
131
+
132
+ # Constants for text generation
133
  MAX_MAX_NEW_TOKENS = 4096
134
  DEFAULT_MAX_NEW_TOKENS = 2048
135
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
136
+
137
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
138
 
139
  # Load Nanonets-OCR2-3B
140
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
141
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
142
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
143
+ MODEL_ID_M,
144
+ trust_remote_code=True,
145
+ torch_dtype=torch.float16
146
  ).to(device).eval()
147
 
148
  # Load Dots.OCR from the local, patched directory
149
  MODEL_PATH_D = model_path_d_local
150
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
151
  model_d = AutoModelForCausalLM.from_pretrained(
152
+ MODEL_PATH_D,
153
+ attn_implementation="eager",
154
+ torch_dtype=torch.bfloat16,
155
+ device_map="auto",
156
+ trust_remote_code=True
157
  ).eval()
158
 
159
  # Load PaddleOCR
160
  MODEL_ID_P = "strangervisionhf/paddle"
161
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
162
  model_p = AutoModelForCausalLM.from_pretrained(
163
+ MODEL_ID_P,
164
+ trust_remote_code=True,
165
+ torch_dtype=torch.bfloat16
166
  ).to(device).eval()
167
 
168
 
169
  @spaces.GPU
170
+ def generate_image(model_name: str, text: str, image: Image.Image, task_type: str,
171
  max_new_tokens: int = 1024,
172
  temperature: float = 0.6,
173
  top_p: float = 0.9,
174
  top_k: int = 50,
175
  repetition_penalty: float = 1.2):
176
  """Generate responses for image input using the selected model."""
 
 
 
 
 
 
 
177
  if model_name == "Nanonets-OCR2-3B":
178
  processor, model = processor_m, model_m
179
  elif model_name == "Dots.OCR":
 
189
  return
190
 
191
  images = [image.convert("RGB")]
192
+
193
+ # --- FIX: Use task-specific prompts for PaddleOCR for structured output ---
194
  if model_name == "PaddleOCR":
195
+ task_prompts = {
196
+ "General OCR": "Recognize the text in this image.",
197
+ "Table Recognition": "Recognize the table in this image.",
198
+ "Formula Recognition": "Recognize the formula in this image.",
199
+ "Layout Analysis": "Analyze the layout of this document. Return the result in markdown format."
200
+ }
201
+ # Use the task-specific prompt and ignore the user's free-form text query
202
+ prompt_text = task_prompts.get(task_type, "Recognize the text in this image.")
203
  messages = [{"role": "user", "content": prompt_text}]
204
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
205
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
206
  else:
207
+ # For other models, use the standard user-provided text query
208
  messages = [
209
  {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}
210
  ]
 
243
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
244
  with gr.Row():
245
  with gr.Column(scale=2):
246
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
 
 
 
 
 
 
 
 
 
247
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
248
  image_submit = gr.Button("Submit", variant="primary")
249
  gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
266
  label="Select Model",
267
  value="Nanonets-OCR2-3B"
268
  )
269
+
270
+ # --- NEW UI ELEMENT FOR PADDLEOCR ---
271
+ task_type_dropdown = gr.Radio(
272
+ choices=["General OCR", "Table Recognition", "Formula Recognition", "Layout Analysis"],
273
+ label="Select Task for PaddleOCR",
274
+ value="General OCR",
275
+ info="This selection is used ONLY for the PaddleOCR model to ensure structured output. The 'Query Input' box will be ignored."
276
+ )
277
+ # --- END NEW UI ELEMENT ---
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  image_submit.click(
280
  fn=generate_image,
281
+ inputs=[model_choice, image_query, image_upload, task_type_dropdown, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
282
  outputs=[raw_output, formatted_output]
283
  )
284