prithivMLmods committed on
Commit
bf6df93
·
verified ·
1 Parent(s): 388e24a

Update app: replace DeepSeek-OCR with PaddleOCR

Browse files
Files changed (1) hide show
  1. app.py +38 -71
app.py CHANGED
@@ -13,7 +13,6 @@ from transformers import (
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
- AutoTokenizer,
17
  )
18
 
19
  from gradio.themes import Soft
@@ -92,66 +91,44 @@ css = """
92
  }
93
  """
94
 
95
- # --- Model Patching ---
96
 
97
- # Define a local directory to cache models
98
  CACHE_PATH = "./model_cache"
99
  if not os.path.exists(CACHE_PATH):
100
  os.makedirs(CACHE_PATH)
101
 
102
- # --- Fix for Dots.OCR Processor Loading ---
103
  model_path_d_local = snapshot_download(
104
  repo_id='rednote-hilab/dots.ocr',
105
- local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
106
  max_workers=20,
107
  local_dir_use_symlinks=False
108
  )
 
 
109
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
 
110
  if os.path.exists(config_file_path):
111
  with open(config_file_path, 'r') as f:
112
  input_code = f.read()
 
113
  lines = input_code.splitlines()
114
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
115
  output_lines = []
116
  for line in lines:
117
  output_lines.append(line)
118
  if line.strip().startswith("class DotsVLProcessor"):
 
119
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
 
 
120
  with open(config_file_path, 'w') as f:
121
  f.write('\n'.join(output_lines))
122
  print("Patched configuration_dots.py successfully.")
123
- sys.path.append(model_path_d_local)
124
-
125
 
126
- # --- Fix for DeepSeek-OCR ImportError ---
127
- model_path_s_local = snapshot_download(
128
- repo_id='deepseek-ai/DeepSeek-OCR',
129
- local_dir=os.path.join(CACHE_PATH, 'DeepSeek-OCR'),
130
- max_workers=20,
131
- local_dir_use_symlinks=False
132
- )
133
- modeling_file_path = os.path.join(model_path_s_local, "modeling_deepseekv2.py")
134
- if os.path.exists(modeling_file_path):
135
- with open(modeling_file_path, 'r', encoding='utf-8') as f:
136
- input_code = f.read()
137
-
138
- original_import = "from transformers.models.llama.modeling_llama import (\n LlamaAttention,\n LlamaFlashAttention2\n)"
139
- if original_import in input_code:
140
- safe_import = """from transformers.models.llama.modeling_llama import (
141
- LlamaAttention
142
- )
143
- try:
144
- from transformers.models.llama.modeling_llama import LlamaFlashAttention2
145
- except ImportError:
146
- LlamaFlashAttention2 = LlamaAttention"""
147
- patched_code = input_code.replace(original_import, safe_import)
148
- with open(modeling_file_path, 'w', encoding='utf-8') as f:
149
- f.write(patched_code)
150
- print("Patched modeling_deepseekv2.py successfully.")
151
- sys.path.append(model_path_s_local)
152
-
153
- # --- NEW: Import the specific model class for DeepSeek-OCR ---
154
- from modeling_deepseekocr import DeepseekOCRForCausalLM
155
 
156
 
157
  # --- Model Loading ---
@@ -177,21 +154,19 @@ MODEL_PATH_D = model_path_d_local
177
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
178
  model_d = AutoModelForCausalLM.from_pretrained(
179
  MODEL_PATH_D,
180
- _attn_implementation="flash_attention_2",
181
  torch_dtype=torch.bfloat16,
182
  device_map="auto",
183
  trust_remote_code=True
184
  ).eval()
185
 
186
- # Load DeepSeek-OCR from the local, patched directory using its specific class
187
- MODEL_PATH_S = model_path_s_local
188
- processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
189
- # --- MODIFIED: Use the specific class instead of AutoModelForCausalLM ---
190
- model_s = DeepseekOCRForCausalLM.from_pretrained(
191
- MODEL_PATH_S,
192
- _attn_implementation='eager',
193
- torch_dtype=torch.bfloat16,
194
  trust_remote_code=True,
 
195
  ).to(device).eval()
196
 
197
 
@@ -207,8 +182,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
207
  processor, model = processor_m, model_m
208
  elif model_name == "Dots.OCR":
209
  processor, model = processor_d, model_d
210
- elif model_name == "DeepSeek-OCR":
211
- processor, model = processor_s, model_s
212
  else:
213
  yield "Invalid model selected.", "Invalid model selected."
214
  return
@@ -219,24 +194,16 @@ def generate_image(model_name: str, text: str, image: Image.Image,
219
 
220
  images = [image.convert("RGB")]
221
 
222
- if model_name == "DeepSeek-OCR":
223
- messages = [
224
- {"role": "user", "content": f"<image>\n<|grounding|>{text}"}
225
- ]
226
- prompt = processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
227
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
228
- else:
229
- messages = [
230
- {
231
- "role": "user",
232
- "content": [{"type": "image"}] + [{"type": "text", "text": text}]
233
- }
234
- ]
235
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
236
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
237
-
238
 
239
- streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
240
  generation_kwargs = {
241
  **inputs,
242
  "streamer": streamer,
@@ -257,14 +224,14 @@ def generate_image(model_name: str, text: str, image: Image.Image,
257
 
258
  # Define examples for image inference
259
  image_examples = [
260
- ["Reconstruct the doc [table] as it is.", "images/a.jpg"],
261
- ["Extract all content.", "images/b.jpg"],
262
- ["OCR the image", "images/c.jpg"],
263
  ]
264
 
265
  # Create the Gradio Interface
266
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
267
- gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
268
  with gr.Row():
269
  with gr.Column(scale=2):
270
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
@@ -281,14 +248,14 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
281
 
282
  with gr.Column(scale=3):
283
  gr.Markdown("## Output", elem_id="output-title")
284
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=9, show_copy_button=True)
285
  with gr.Accordion("Formatted Result", open=False):
286
  formatted_output = gr.Markdown(label="Formatted Result")
287
 
288
  model_choice = gr.Radio(
289
- choices=["DeepSeek-OCR", "Nanonets-OCR2-3B", "Dots.OCR"],
290
  label="Select Model",
291
- value="DeepSeek-OCR"
292
  )
293
 
294
  image_submit.click(
 
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
 
16
  )
17
 
18
  from gradio.themes import Soft
 
91
  }
92
  """
93
 
94
+ # --- Fix for Dots.OCR Processor Loading ---
95
 
96
+ # Define a local directory to cache the model
97
  CACHE_PATH = "./model_cache"
98
  if not os.path.exists(CACHE_PATH):
99
  os.makedirs(CACHE_PATH)
100
 
101
+ # Download the model files locally
102
  model_path_d_local = snapshot_download(
103
  repo_id='rednote-hilab/dots.ocr',
104
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'), # Create a dedicated subfolder
105
  max_workers=20,
106
  local_dir_use_symlinks=False
107
  )
108
+
109
+ # Modify the configuration file to fix the processor loading issue
110
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
111
+
112
  if os.path.exists(config_file_path):
113
  with open(config_file_path, 'r') as f:
114
  input_code = f.read()
115
+
116
  lines = input_code.splitlines()
117
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
118
  output_lines = []
119
  for line in lines:
120
  output_lines.append(line)
121
  if line.strip().startswith("class DotsVLProcessor"):
122
+ # Insert the attributes line to specify which processors to load
123
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
124
+
125
+ # Write the modified content back to the file
126
  with open(config_file_path, 'w') as f:
127
  f.write('\n'.join(output_lines))
128
  print("Patched configuration_dots.py successfully.")
 
 
129
 
130
+ # Add the local model path to sys.path so transformers can use the modified code
131
+ sys.path.append(model_path_d_local)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  # --- Model Loading ---
 
154
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
155
  model_d = AutoModelForCausalLM.from_pretrained(
156
  MODEL_PATH_D,
157
+ attn_implementation="eager",
158
  torch_dtype=torch.bfloat16,
159
  device_map="auto",
160
  trust_remote_code=True
161
  ).eval()
162
 
163
+ # Load PaddleOCR
164
+ MODEL_ID_P = "strangervisionhf/paddle"
165
+ processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
166
+ model_p = AutoModelForCausalLM.from_pretrained(
167
+ MODEL_ID_P,
 
 
 
168
  trust_remote_code=True,
169
+ torch_dtype=torch.bfloat16
170
  ).to(device).eval()
171
 
172
 
 
182
  processor, model = processor_m, model_m
183
  elif model_name == "Dots.OCR":
184
  processor, model = processor_d, model_d
185
+ elif model_name == "PaddleOCR":
186
+ processor, model = processor_p, model_p
187
  else:
188
  yield "Invalid model selected.", "Invalid model selected."
189
  return
 
194
 
195
  images = [image.convert("RGB")]
196
 
197
+ messages = [
198
+ {
199
+ "role": "user",
200
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
201
+ }
202
+ ]
203
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
204
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
205
 
206
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
207
  generation_kwargs = {
208
  **inputs,
209
  "streamer": streamer,
 
224
 
225
  # Define examples for image inference
226
  image_examples = [
227
+ ["Reconstruct the doc [table] as it is.", "images/0.png"],
228
+ ["Describe the image!", "images/8.png"],
229
+ ["OCR the image", "images/2.jpg"],
230
  ]
231
 
232
  # Create the Gradio Interface
233
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
234
+ gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
235
  with gr.Row():
236
  with gr.Column(scale=2):
237
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
248
 
249
  with gr.Column(scale=3):
250
  gr.Markdown("## Output", elem_id="output-title")
251
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
252
  with gr.Accordion("Formatted Result", open=False):
253
  formatted_output = gr.Markdown(label="Formatted Result")
254
 
255
  model_choice = gr.Radio(
256
+ choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
257
  label="Select Model",
258
+ value="Nanonets-OCR2-3B"
259
  )
260
 
261
  image_submit.click(