prithivMLmods committed on
Commit
ebd6535
·
verified ·
1 Parent(s): 5838753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -29
app.py CHANGED
@@ -9,12 +9,11 @@ import spaces
9
  import torch
10
  from PIL import Image
11
  from transformers import (
12
- Qwen2VLForConditionalGeneration,
13
  Qwen2_5_VLForConditionalGeneration,
14
  AutoModelForCausalLM,
15
  AutoProcessor,
16
  TextIteratorStreamer,
17
- AutoTokenizer,
18
  )
19
 
20
  from gradio.themes import Soft
@@ -93,24 +92,22 @@ css = """
93
  }
94
  """
95
 
96
- # --- Fix for Dots.OCR Processor Loading ---
97
 
98
- # Define a local directory to cache the model
99
  CACHE_PATH = "./model_cache"
100
  if not os.path.exists(CACHE_PATH):
101
  os.makedirs(CACHE_PATH)
102
 
103
- # Download the model files locally
104
  model_path_d_local = snapshot_download(
105
  repo_id='rednote-hilab/dots.ocr',
106
- local_dir=CACHE_PATH,
107
  max_workers=20,
108
  local_dir_use_symlinks=False
109
  )
110
 
111
- # Modify the configuration file to fix the processor loading issue
112
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
113
-
114
  if os.path.exists(config_file_path):
115
  with open(config_file_path, 'r') as f:
116
  input_code = f.read()
@@ -121,17 +118,37 @@ if os.path.exists(config_file_path):
121
  for line in lines:
122
  output_lines.append(line)
123
  if line.strip().startswith("class DotsVLProcessor"):
124
- # Insert the attributes line to specify which processors to load
125
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
126
-
127
- # Write the modified content back to the file
128
  with open(config_file_path, 'w') as f:
129
  f.write('\n'.join(output_lines))
130
  print("Patched configuration_dots.py successfully.")
131
 
132
- # Add the local model path to sys.path so transformers can use the modified code
133
  sys.path.append(model_path_d_local)
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # --- Model Loading ---
137
 
@@ -162,13 +179,14 @@ model_d = AutoModelForCausalLM.from_pretrained(
162
  trust_remote_code=True
163
  ).eval()
164
 
165
- # Load DeepSeek-OCR
166
- MODEL_ID_S = 'deepseek-ai/DeepSeek-OCR'
167
- processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
168
  model_s = AutoModelForCausalLM.from_pretrained(
169
- MODEL_ID_S,
170
- _attn_implementation='flash_attention_2',
171
  torch_dtype=torch.bfloat16,
 
172
  trust_remote_code=True,
173
  ).eval()
174
 
@@ -196,18 +214,12 @@ def generate_image(model_name: str, text: str, image: Image.Image,
196
  return
197
 
198
  images = [image.convert("RGB")]
199
-
200
- # For DeepSeek-OCR, the recommended prompt format is slightly different
201
  if model_name == "DeepSeek-OCR":
202
- # Using a format found in documentation for better performance
203
- prompt_text = f"<image>\n<|grounding|>{text}"
204
- messages = [
205
- {"role": "user", "content": prompt_text}
206
- ]
207
- # apply_chat_template is not used directly, instead we build the prompt manually
208
- prompt = processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
209
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
210
-
211
  else:
212
  messages = [
213
  {
@@ -216,7 +228,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
216
  }
217
  ]
218
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
219
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
220
 
221
 
222
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
9
  import torch
10
  from PIL import Image
11
  from transformers import (
 
12
  Qwen2_5_VLForConditionalGeneration,
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
+ AutoTokenizer, # Added for DeepSeek-OCR, although its processor is actually loaded via AutoProcessor
17
  )
18
 
19
  from gradio.themes import Soft
 
92
  }
93
  """
94
 
95
+ # --- Local Model Caching and Patching ---
96
 
97
+ # Define a local directory to cache all models
98
  CACHE_PATH = "./model_cache"
99
  if not os.path.exists(CACHE_PATH):
100
  os.makedirs(CACHE_PATH)
101
 
102
+ # --- Fix for Dots.OCR Processor Loading ---
103
  model_path_d_local = snapshot_download(
104
  repo_id='rednote-hilab/dots.ocr',
105
+ local_dir=os.path.join(CACHE_PATH, "dots.ocr"),
106
  max_workers=20,
107
  local_dir_use_symlinks=False
108
  )
109
 
 
110
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
 
111
  if os.path.exists(config_file_path):
112
  with open(config_file_path, 'r') as f:
113
  input_code = f.read()
 
118
  for line in lines:
119
  output_lines.append(line)
120
  if line.strip().startswith("class DotsVLProcessor"):
 
121
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
 
 
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
125
 
 
126
  sys.path.append(model_path_d_local)
127
 
128
+ # --- Fix for DeepSeek-OCR ImportError ---
129
+ model_path_s_local = snapshot_download(
130
+ repo_id='deepseek-ai/DeepSeek-OCR',
131
+ local_dir=os.path.join(CACHE_PATH, "DeepSeek-OCR"),
132
+ max_workers=20,
133
+ local_dir_use_symlinks=False
134
+ )
135
+
136
+ deepseek_modeling_file = os.path.join(model_path_s_local, "modeling_deepseekv2.py")
137
+ if os.path.exists(deepseek_modeling_file):
138
+ with open(deepseek_modeling_file, 'r', encoding='utf-8') as f:
139
+ content = f.read()
140
+
141
+ # Check if the problematic import exists and hasn't been patched yet
142
+ problematic_import_str = "from transformers.models.llama.modeling_llama import (\n LlamaFlashAttention2,"
143
+ if problematic_import_str in content:
144
+ # Patch the file by commenting out the LlamaFlashAttention2 import
145
+ patched_content = content.replace("LlamaFlashAttention2,", "# LlamaFlashAttention2,")
146
+ with open(deepseek_modeling_file, 'w', encoding='utf-8') as f:
147
+ f.write(patched_content)
148
+ print("Patched modeling_deepseekv2.py successfully.")
149
+
150
+ sys.path.append(model_path_s_local)
151
+
152
 
153
  # --- Model Loading ---
154
 
 
179
  trust_remote_code=True
180
  ).eval()
181
 
182
+ # Load DeepSeek-OCR from the local, patched directory
183
+ MODEL_PATH_S = model_path_s_local
184
+ processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
185
  model_s = AutoModelForCausalLM.from_pretrained(
186
+ MODEL_PATH_S,
187
+ _attn_implementation='eager',
188
  torch_dtype=torch.bfloat16,
189
+ device_map="auto",
190
  trust_remote_code=True,
191
  ).eval()
192
 
 
214
  return
215
 
216
  images = [image.convert("RGB")]
217
+
218
+ # Use the model's appropriate processor and chat template
219
  if model_name == "DeepSeek-OCR":
220
+ messages = [{"role": "user", "content": f"<image>\n{text}"}]
221
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
222
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
 
 
 
 
 
 
223
  else:
224
  messages = [
225
  {
 
228
  }
229
  ]
230
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
231
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
232
 
233
 
234
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)