prithivMLmods commited on
Commit
eebb9c6
·
verified ·
1 Parent(s): 1113989
Files changed (1) hide show
  1. app.py +47 -31
app.py CHANGED
@@ -91,27 +91,19 @@ css = """
91
  """
92
 
93
  # --- Fix for Dots.OCR Processor Loading ---
94
-
95
- # Define a local directory to cache the model
96
  CACHE_PATH = "./model_cache"
97
  if not os.path.exists(CACHE_PATH):
98
  os.makedirs(CACHE_PATH)
99
-
100
- # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
  local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
107
-
108
- # Modify the configuration file to fix the processor loading issue
109
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
110
-
111
  if os.path.exists(config_file_path):
112
  with open(config_file_path, 'r') as f:
113
  input_code = f.read()
114
-
115
  lines = input_code.splitlines()
116
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
117
  output_lines = []
@@ -122,58 +114,52 @@ if os.path.exists(config_file_path):
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
125
-
126
- # Add the local model path to sys.path so transformers can use the modified code
127
  sys.path.append(model_path_d_local)
128
 
129
 
130
  # --- Model Loading ---
131
-
132
- # Constants for text generation
133
  MAX_MAX_NEW_TOKENS = 4096
134
  DEFAULT_MAX_NEW_TOKENS = 2048
135
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
136
-
137
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
138
 
139
  # Load Nanonets-OCR2-3B
140
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
141
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
142
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
143
- MODEL_ID_M,
144
- trust_remote_code=True,
145
- torch_dtype=torch.float16
146
  ).to(device).eval()
147
 
148
  # Load Dots.OCR from the local, patched directory
149
  MODEL_PATH_D = model_path_d_local
150
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
151
  model_d = AutoModelForCausalLM.from_pretrained(
152
- MODEL_PATH_D,
153
- attn_implementation="eager",
154
- torch_dtype=torch.bfloat16,
155
- device_map="auto",
156
- trust_remote_code=True
157
  ).eval()
158
 
159
  # Load PaddleOCR
160
  MODEL_ID_P = "strangervisionhf/paddle"
161
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
162
  model_p = AutoModelForCausalLM.from_pretrained(
163
- MODEL_ID_P,
164
- trust_remote_code=True,
165
- torch_dtype=torch.bfloat16
166
  ).to(device).eval()
167
 
168
 
169
  @spaces.GPU
170
- def generate_image(model_name: str, text: str, image: Image.Image,
171
  max_new_tokens: int = 1024,
172
  temperature: float = 0.6,
173
  top_p: float = 0.9,
174
  top_k: int = 50,
175
  repetition_penalty: float = 1.2):
176
  """Generate responses for image input using the selected model."""
 
 
 
 
 
 
 
177
  if model_name == "Nanonets-OCR2-3B":
178
  processor, model = processor_m, model_m
179
  elif model_name == "Dots.OCR":
@@ -192,9 +178,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
192
 
193
  # --- FIX: Handle different prompt formats required by models ---
194
  if model_name == "PaddleOCR":
195
- # PaddleOCR's template expects a simple string content for the text part.
196
- # The image is passed to the processor separately.
197
- messages = [{"role": "user", "content": text}]
198
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
199
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
200
  else:
@@ -237,7 +223,17 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
237
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
238
  with gr.Row():
239
  with gr.Column(scale=2):
240
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
 
 
 
 
 
 
 
 
 
241
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
242
  image_submit = gr.Button("Submit", variant="primary")
243
  gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
@@ -261,9 +257,29 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
261
  value="Nanonets-OCR2-3B"
262
  )
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  image_submit.click(
265
  fn=generate_image,
266
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
267
  outputs=[raw_output, formatted_output]
268
  )
269
 
 
91
  """
92
 
93
  # --- Fix for Dots.OCR Processor Loading ---
 
 
94
  CACHE_PATH = "./model_cache"
95
  if not os.path.exists(CACHE_PATH):
96
  os.makedirs(CACHE_PATH)
 
 
97
  model_path_d_local = snapshot_download(
98
  repo_id='rednote-hilab/dots.ocr',
99
  local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
100
  max_workers=20,
101
  local_dir_use_symlinks=False
102
  )
 
 
103
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
 
104
  if os.path.exists(config_file_path):
105
  with open(config_file_path, 'r') as f:
106
  input_code = f.read()
 
107
  lines = input_code.splitlines()
108
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
109
  output_lines = []
 
114
  with open(config_file_path, 'w') as f:
115
  f.write('\n'.join(output_lines))
116
  print("Patched configuration_dots.py successfully.")
 
 
117
  sys.path.append(model_path_d_local)
118
 
119
 
120
  # --- Model Loading ---
 
 
121
  MAX_MAX_NEW_TOKENS = 4096
122
  DEFAULT_MAX_NEW_TOKENS = 2048
 
 
123
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
124
 
125
  # Load Nanonets-OCR2-3B
126
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
127
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
128
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
129
+ MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
 
 
130
  ).to(device).eval()
131
 
132
  # Load Dots.OCR from the local, patched directory
133
  MODEL_PATH_D = model_path_d_local
134
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
135
  model_d = AutoModelForCausalLM.from_pretrained(
136
+ MODEL_PATH_D, attn_implementation="eager", torch_dtype=torch.bfloat16,
137
+ device_map="auto", trust_remote_code=True
 
 
 
138
  ).eval()
139
 
140
  # Load PaddleOCR
141
  MODEL_ID_P = "strangervisionhf/paddle"
142
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
143
  model_p = AutoModelForCausalLM.from_pretrained(
144
+ MODEL_ID_P, trust_remote_code=True, torch_dtype=torch.bfloat16
 
 
145
  ).to(device).eval()
146
 
147
 
148
  @spaces.GPU
149
+ def generate_image(model_name: str, text: str, paddle_task: str, image: Image.Image,
150
  max_new_tokens: int = 1024,
151
  temperature: float = 0.6,
152
  top_p: float = 0.9,
153
  top_k: int = 50,
154
  repetition_penalty: float = 1.2):
155
  """Generate responses for image input using the selected model."""
156
+ PROMPTS = {
157
+ "OCR": "OCR:",
158
+ "Table Recognition": "Table Recognition:",
159
+ "Chart Recognition": "Chart Recognition:",
160
+ "Formula Recognition": "Formula Recognition:",
161
+ }
162
+
163
  if model_name == "Nanonets-OCR2-3B":
164
  processor, model = processor_m, model_m
165
  elif model_name == "Dots.OCR":
 
178
 
179
  # --- FIX: Handle different prompt formats required by models ---
180
  if model_name == "PaddleOCR":
181
+ # PaddleOCR expects specific, predefined prompts for its tasks.
182
+ prompt_text = PROMPTS.get(paddle_task, "OCR:")
183
+ messages = [{"role": "user", "content": prompt_text}]
184
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
185
  inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
186
  else:
 
223
  gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
224
  with gr.Row():
225
  with gr.Column(scale=2):
226
+ # General query input, visible by default
227
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...", visible=True)
228
+
229
+ # Specific task selector for PaddleOCR, hidden by default
230
+ paddle_task = gr.Radio(
231
+ label="Select PaddleOCR Task",
232
+ choices=["OCR", "Table Recognition", "Chart Recognition", "Formula Recognition"],
233
+ value="OCR",
234
+ visible=False
235
+ )
236
+
237
  image_upload = gr.Image(type="pil", label="Upload Image", height=320)
238
  image_submit = gr.Button("Submit", variant="primary")
239
  gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
 
257
  value="Nanonets-OCR2-3B"
258
  )
259
 
260
+ # Function to dynamically update the UI based on model selection
261
+ def update_ui_for_model(model_name):
262
+ if model_name == "PaddleOCR":
263
+ return {
264
+ image_query: gr.Textbox(visible=False),
265
+ paddle_task: gr.Radio(visible=True)
266
+ }
267
+ else:
268
+ return {
269
+ image_query: gr.Textbox(visible=True),
270
+ paddle_task: gr.Radio(visible=False)
271
+ }
272
+
273
+ # Attach the function to the model_choice radio button's change event
274
+ model_choice.change(
275
+ fn=update_ui_for_model,
276
+ inputs=model_choice,
277
+ outputs=[image_query, paddle_task]
278
+ )
279
+
280
  image_submit.click(
281
  fn=generate_image,
282
+ inputs=[model_choice, image_query, paddle_task, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
283
  outputs=[raw_output, formatted_output]
284
  )
285