prithivMLmods commited on
Commit
d6bbd32
·
verified ·
1 Parent(s): ebd6535

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -29
app.py CHANGED
@@ -92,9 +92,9 @@ css = """
92
  }
93
  """
94
 
95
- # --- Local Model Caching and Patching ---
96
 
97
- # Define a local directory to cache all models
98
  CACHE_PATH = "./model_cache"
99
  if not os.path.exists(CACHE_PATH):
100
  os.makedirs(CACHE_PATH)
@@ -102,16 +102,14 @@ if not os.path.exists(CACHE_PATH):
102
  # --- Fix for Dots.OCR Processor Loading ---
103
  model_path_d_local = snapshot_download(
104
  repo_id='rednote-hilab/dots.ocr',
105
- local_dir=os.path.join(CACHE_PATH, "dots.ocr"),
106
  max_workers=20,
107
  local_dir_use_symlinks=False
108
  )
109
-
110
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
111
  if os.path.exists(config_file_path):
112
  with open(config_file_path, 'r') as f:
113
  input_code = f.read()
114
-
115
  lines = input_code.splitlines()
116
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
117
  output_lines = []
@@ -122,31 +120,40 @@ if os.path.exists(config_file_path):
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
125
-
126
  sys.path.append(model_path_d_local)
127
 
 
128
  # --- Fix for DeepSeek-OCR ImportError ---
129
  model_path_s_local = snapshot_download(
130
  repo_id='deepseek-ai/DeepSeek-OCR',
131
- local_dir=os.path.join(CACHE_PATH, "DeepSeek-OCR"),
132
  max_workers=20,
133
  local_dir_use_symlinks=False
134
  )
 
 
 
 
135
 
136
- deepseek_modeling_file = os.path.join(model_path_s_local, "modeling_deepseekv2.py")
137
- if os.path.exists(deepseek_modeling_file):
138
- with open(deepseek_modeling_file, 'r', encoding='utf-8') as f:
139
- content = f.read()
140
-
141
- # Check if the problematic import exists and hasn't been patched yet
142
- problematic_import_str = "from transformers.models.llama.modeling_llama import (\n LlamaFlashAttention2,"
143
- if problematic_import_str in content:
144
- # Patch the file by commenting out the LlamaFlashAttention2 import
145
- patched_content = content.replace("LlamaFlashAttention2,", "# LlamaFlashAttention2,")
146
- with open(deepseek_modeling_file, 'w', encoding='utf-8') as f:
147
- f.write(patched_content)
 
 
 
 
 
 
148
  print("Patched modeling_deepseekv2.py successfully.")
149
-
150
  sys.path.append(model_path_s_local)
151
 
152
 
@@ -184,11 +191,10 @@ MODEL_PATH_S = model_path_s_local
184
  processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
185
  model_s = AutoModelForCausalLM.from_pretrained(
186
  MODEL_PATH_S,
187
- _attn_implementation='eager',
188
  torch_dtype=torch.bfloat16,
189
- device_map="auto",
190
  trust_remote_code=True,
191
- ).eval()
192
 
193
 
194
  @spaces.GPU
@@ -214,12 +220,17 @@ def generate_image(model_name: str, text: str, image: Image.Image,
214
  return
215
 
216
  images = [image.convert("RGB")]
217
-
218
- # Use the model's appropriate processor and chat template
219
  if model_name == "DeepSeek-OCR":
220
- messages = [{"role": "user", "content": f"<image>\n{text}"}]
221
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
222
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
 
 
 
 
 
223
  else:
224
  messages = [
225
  {
@@ -228,7 +239,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
228
  }
229
  ]
230
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
231
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
232
 
233
 
234
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
92
  }
93
  """
94
 
95
+ # --- Model Patching ---
96
 
97
+ # Define a local directory to cache models
98
  CACHE_PATH = "./model_cache"
99
  if not os.path.exists(CACHE_PATH):
100
  os.makedirs(CACHE_PATH)
 
102
  # --- Fix for Dots.OCR Processor Loading ---
103
  model_path_d_local = snapshot_download(
104
  repo_id='rednote-hilab/dots.ocr',
105
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
106
  max_workers=20,
107
  local_dir_use_symlinks=False
108
  )
 
109
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
110
  if os.path.exists(config_file_path):
111
  with open(config_file_path, 'r') as f:
112
  input_code = f.read()
 
113
  lines = input_code.splitlines()
114
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
115
  output_lines = []
 
120
  with open(config_file_path, 'w') as f:
121
  f.write('\n'.join(output_lines))
122
  print("Patched configuration_dots.py successfully.")
 
123
  sys.path.append(model_path_d_local)
124
 
125
+
126
  # --- Fix for DeepSeek-OCR ImportError ---
127
  model_path_s_local = snapshot_download(
128
  repo_id='deepseek-ai/DeepSeek-OCR',
129
+ local_dir=os.path.join(CACHE_PATH, 'DeepSeek-OCR'),
130
  max_workers=20,
131
  local_dir_use_symlinks=False
132
  )
133
+ modeling_file_path = os.path.join(model_path_s_local, "modeling_deepseekv2.py")
134
+ if os.path.exists(modeling_file_path):
135
+ with open(modeling_file_path, 'r', encoding='utf-8') as f:
136
+ input_code = f.read()
137
 
138
+ # The problematic import line
139
+ original_import = "from transformers.models.llama.modeling_llama import (\n LlamaAttention,\n LlamaFlashAttention2\n)"
140
+
141
+ if original_import in input_code:
142
+ # Replace with a safe version that handles the ImportError
143
+ safe_import = """from transformers.models.llama.modeling_llama import (
144
+ LlamaAttention
145
+ )
146
+ try:
147
+ from transformers.models.llama.modeling_llama import LlamaFlashAttention2
148
+ except ImportError:
149
+ print("Warning: `LlamaFlashAttention2` not found. Falling back to `LlamaAttention`.")
150
+ LlamaFlashAttention2 = LlamaAttention"""
151
+
152
+ patched_code = input_code.replace(original_import, safe_import)
153
+
154
+ with open(modeling_file_path, 'w', encoding='utf-8') as f:
155
+ f.write(patched_code)
156
  print("Patched modeling_deepseekv2.py successfully.")
 
157
  sys.path.append(model_path_s_local)
158
 
159
 
 
191
  processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
192
  model_s = AutoModelForCausalLM.from_pretrained(
193
  MODEL_PATH_S,
194
+ _attn_implementation='flash_attention_2',
195
  torch_dtype=torch.bfloat16,
 
196
  trust_remote_code=True,
197
+ ).to(device).eval()
198
 
199
 
200
  @spaces.GPU
 
220
  return
221
 
222
  images = [image.convert("RGB")]
223
+
224
+ # For DeepSeek-OCR, the recommended prompt format is slightly different
225
  if model_name == "DeepSeek-OCR":
226
+ # Using a format found in documentation for better performance
227
+ # Note: The processor is expected to handle the full templating.
228
+ # This approach follows the user's implementation.
229
+ messages = [
230
+ {"role": "user", "content": f"<image>\n<|grounding|>{text}"}
231
+ ]
232
+ prompt = processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
233
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
234
  else:
235
  messages = [
236
  {
 
239
  }
240
  ]
241
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
242
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
243
 
244
 
245
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)