prithivMLmods commited on
Commit
388e24a
·
verified ·
1 Parent(s): d6bbd32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -14
app.py CHANGED
@@ -13,7 +13,7 @@ from transformers import (
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
- AutoTokenizer, # Added for DeepSeek, though AutoProcessor is used
17
  )
18
 
19
  from gradio.themes import Soft
@@ -135,27 +135,24 @@ if os.path.exists(modeling_file_path):
135
  with open(modeling_file_path, 'r', encoding='utf-8') as f:
136
  input_code = f.read()
137
 
138
- # The problematic import line
139
  original_import = "from transformers.models.llama.modeling_llama import (\n LlamaAttention,\n LlamaFlashAttention2\n)"
140
-
141
  if original_import in input_code:
142
- # Replace with a safe version that handles the ImportError
143
  safe_import = """from transformers.models.llama.modeling_llama import (
144
  LlamaAttention
145
  )
146
  try:
147
  from transformers.models.llama.modeling_llama import LlamaFlashAttention2
148
  except ImportError:
149
- print("Warning: `LlamaFlashAttention2` not found. Falling back to `LlamaAttention`.")
150
  LlamaFlashAttention2 = LlamaAttention"""
151
-
152
  patched_code = input_code.replace(original_import, safe_import)
153
-
154
  with open(modeling_file_path, 'w', encoding='utf-8') as f:
155
  f.write(patched_code)
156
  print("Patched modeling_deepseekv2.py successfully.")
157
  sys.path.append(model_path_s_local)
158
 
 
 
 
159
 
160
  # --- Model Loading ---
161
 
@@ -186,12 +183,13 @@ model_d = AutoModelForCausalLM.from_pretrained(
186
  trust_remote_code=True
187
  ).eval()
188
 
189
- # Load DeepSeek-OCR from the local, patched directory
190
  MODEL_PATH_S = model_path_s_local
191
  processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
192
- model_s = AutoModelForCausalLM.from_pretrained(
 
193
  MODEL_PATH_S,
194
- _attn_implementation='flash_attention_2',
195
  torch_dtype=torch.bfloat16,
196
  trust_remote_code=True,
197
  ).to(device).eval()
@@ -221,11 +219,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
221
 
222
  images = [image.convert("RGB")]
223
 
224
- # For DeepSeek-OCR, the recommended prompt format is slightly different
225
  if model_name == "DeepSeek-OCR":
226
- # Using a format found in documentation for better performance
227
- # Note: The processor is expected to handle the full templating.
228
- # This approach follows the user's implementation.
229
  messages = [
230
  {"role": "user", "content": f"<image>\n<|grounding|>{text}"}
231
  ]
 
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
+ AutoTokenizer,
17
  )
18
 
19
  from gradio.themes import Soft
 
135
  with open(modeling_file_path, 'r', encoding='utf-8') as f:
136
  input_code = f.read()
137
 
 
138
  original_import = "from transformers.models.llama.modeling_llama import (\n LlamaAttention,\n LlamaFlashAttention2\n)"
 
139
  if original_import in input_code:
 
140
  safe_import = """from transformers.models.llama.modeling_llama import (
141
  LlamaAttention
142
  )
143
  try:
144
  from transformers.models.llama.modeling_llama import LlamaFlashAttention2
145
  except ImportError:
 
146
  LlamaFlashAttention2 = LlamaAttention"""
 
147
  patched_code = input_code.replace(original_import, safe_import)
 
148
  with open(modeling_file_path, 'w', encoding='utf-8') as f:
149
  f.write(patched_code)
150
  print("Patched modeling_deepseekv2.py successfully.")
151
  sys.path.append(model_path_s_local)
152
 
153
+ # --- NEW: Import the specific model class for DeepSeek-OCR ---
154
+ from modeling_deepseekocr import DeepseekOCRForCausalLM
155
+
156
 
157
  # --- Model Loading ---
158
 
 
183
  trust_remote_code=True
184
  ).eval()
185
 
186
+ # Load DeepSeek-OCR from the local, patched directory using its specific class
187
  MODEL_PATH_S = model_path_s_local
188
  processor_s = AutoProcessor.from_pretrained(MODEL_PATH_S, trust_remote_code=True)
189
+ # --- MODIFIED: Use the specific class instead of AutoModelForCausalLM ---
190
+ model_s = DeepseekOCRForCausalLM.from_pretrained(
191
  MODEL_PATH_S,
192
+ _attn_implementation='eager',
193
  torch_dtype=torch.bfloat16,
194
  trust_remote_code=True,
195
  ).to(device).eval()
 
219
 
220
  images = [image.convert("RGB")]
221
 
 
222
  if model_name == "DeepSeek-OCR":
 
 
 
223
  messages = [
224
  {"role": "user", "content": f"<image>\n<|grounding|>{text}"}
225
  ]