Kamal-prog-code committed
Commit ca7e05a · 1 Parent(s): 5ebb043

revert back to deepseek

Files changed (2):
  1. app.py +7 -184
  2. requirements.txt +9 -10
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModel, AutoTokenizer
 import torch
 import spaces
 import os
@@ -12,178 +12,12 @@ import re
 import numpy as np
 import base64
 from io import StringIO, BytesIO
-from huggingface_hub import snapshot_download
 
-def ensure_llama_flash_attn2():
-    try:
-        from transformers.models.llama import modeling_llama as llama_mod
-    except Exception:
-        return
-    if not hasattr(llama_mod, "LlamaFlashAttention2"):
-        class LlamaFlashAttention2:  # fallback shim; not used when attn impl is SDPA
-            pass
-        llama_mod.LlamaFlashAttention2 = LlamaFlashAttention2
-
-ensure_llama_flash_attn2()
-
-def ensure_dynamiccache_max_length():
-    try:
-        from transformers.cache_utils import DynamicCache
-    except Exception:
-        return
-    if not hasattr(DynamicCache, "get_max_length"):
-        def get_max_length(self):
-            return self.get_seq_length()
-        DynamicCache.get_max_length = get_max_length
-
-ensure_dynamiccache_max_length()
-
-def allow_none_video_processor():
-    try:
-        import transformers.processing_utils as proc_utils
-    except Exception:
-        return
-    original = proc_utils.ProcessorMixin.check_argument_for_proper_class
-    def patched(self, attribute_name, arg):
-        if attribute_name == "video_processor" and arg is None:
-            return
-        return original(self, attribute_name, arg)
-    proc_utils.ProcessorMixin.check_argument_for_proper_class = patched
-
-allow_none_video_processor()
-
-MODEL_NAME = "deepseek-ai/DeepSeek-OCR-2"
-MODEL_REVISION = "e6322a289fe5b5218278d276d4e7c58e8103f46a"
-DOTS_OCR_MODEL = "rednote-hilab/dots.ocr"
-DOTS_OCR_REVISION = "c69eab6fac32ae66aaa8deea1f28a550ca8adec7"
-DOTS_OCR_LOCAL_DIR = "./models/dots-ocr"
-
-def resolve_attn_impl():
-    if os.environ.get("DISABLE_FLASH_ATTN") == "1":
-        return "eager"
-    try:
-        import flash_attn  # noqa: F401
-        return "flash_attention_2"
-    except Exception:
-        return "eager"
-
-
-ATTN_IMPL = resolve_attn_impl()
-
-def resolve_torch_dtype():
-    if torch.cuda.is_available():
-        if os.environ.get("FORCE_BF16") == "1" and torch.cuda.is_bf16_supported():
-            return torch.bfloat16
-        return torch.float16
-    return torch.float32
-
-TORCH_DTYPE = resolve_torch_dtype()
-
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_NAME,
-    trust_remote_code=True,
-    revision=MODEL_REVISION,
-)
-model = AutoModel.from_pretrained(
-    MODEL_NAME,
-    attn_implementation=ATTN_IMPL,
-    torch_dtype=TORCH_DTYPE,
-    trust_remote_code=True,
-    use_safetensors=True,
-    revision=MODEL_REVISION,
-)
-model = model.eval()
-if torch.cuda.is_available():
-    model = model.to("cuda")
-    if TORCH_DTYPE == torch.float16:
-        model = model.to(torch.float16)
-
-try:
-    from qwen_vl_utils import process_vision_info
-except Exception:
-    process_vision_info = None
-
-DOTS_OCR_PROMPT = "Extract all text from this image."
-
-_DOTS_OCR_MODEL = None
-_DOTS_OCR_PROCESSOR = None
-
-def get_dots_ocr_model():
-    global _DOTS_OCR_MODEL, _DOTS_OCR_PROCESSOR
-    if _DOTS_OCR_MODEL is None or _DOTS_OCR_PROCESSOR is None:
-        os.makedirs(DOTS_OCR_LOCAL_DIR, exist_ok=True)
-        snapshot_download(
-            repo_id=DOTS_OCR_MODEL,
-            revision=DOTS_OCR_REVISION,
-            local_dir=DOTS_OCR_LOCAL_DIR,
-            local_dir_use_symlinks=False,
-        )
-        dtype = TORCH_DTYPE
-        model_kwargs = {
-            "torch_dtype": dtype,
-            "trust_remote_code": True,
-            "revision": DOTS_OCR_REVISION,
-        }
-        if torch.cuda.is_available():
-            model_kwargs["attn_implementation"] = ATTN_IMPL
-            model_kwargs["device_map"] = "auto"
-        _DOTS_OCR_MODEL = AutoModelForCausalLM.from_pretrained(
-            DOTS_OCR_LOCAL_DIR,
-            **model_kwargs,
-        )
-        _DOTS_OCR_PROCESSOR = AutoProcessor.from_pretrained(
-            DOTS_OCR_LOCAL_DIR,
-            trust_remote_code=True,
-            revision=DOTS_OCR_REVISION,
-        )
-    return _DOTS_OCR_MODEL, _DOTS_OCR_PROCESSOR
-
-def dots_ocr_infer(image, prompt=DOTS_OCR_PROMPT, max_new_tokens=4096):
-    if process_vision_info is None:
-        return "dots.ocr error: qwen_vl_utils is not available."
-    model, processor = get_dots_ocr_model()
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": prompt},
-            ],
-        }
-    ]
-    text = processor.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    device = next(model.parameters()).device
-    inputs = inputs.to(device)
-    if TORCH_DTYPE in (torch.float16, torch.bfloat16) and "pixel_values" in inputs:
-        inputs["pixel_values"] = inputs["pixel_values"].to(TORCH_DTYPE)
-    with torch.no_grad():
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            temperature=0.1,
-        )
-    generated_ids_trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False,
-    )
-    return output_text[0] if output_text else ""
+MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True)
+model = model.eval().cuda()
 
 BASE_SIZE = 1024
 IMAGE_SIZE = 768
@@ -422,11 +256,6 @@ with gr.Blocks(title="DeepSeek-OCR-2") as demo:
             )
             input_img = gr.Image(label="Input Image", type="pil", height=300, interactive=False)
             page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
-            model_choice = gr.Dropdown(
-                ["DeepSeek-OCR-2", "dots.ocr"],
-                value="DeepSeek-OCR-2",
-                label="Model",
-            )
             btn = gr.Button("Extract", variant="primary", size="lg")
 
         with gr.Column(scale=2):
@@ -449,19 +278,13 @@ with gr.Blocks(title="DeepSeek-OCR-2") as demo:
     multimodal_in.change(update_page_selector_from_multimodal, [multimodal_in], [page_selector])
     page_selector.change(load_image_from_multimodal, [multimodal_in, page_selector], [input_img])
 
-    def run(multimodal_value, page_num, model_name):
+    def run(multimodal_value, page_num):
         file_path = unpack_multimodal(multimodal_value)
         if file_path:
-            if model_name == "dots.ocr":
-                image = load_image(file_path, int(page_num))
-                if image is None:
-                    return "Error: Upload a file or image", "", "", None, []
-                dots_text = dots_ocr_infer(image)
-                return dots_text, dots_text, dots_text, None, []
             return process_file(file_path, int(page_num))
         return "Error: Upload a file or image", "", "", None, []
 
-    submit_event = btn.click(run, [multimodal_in, page_selector, model_choice],
+    submit_event = btn.click(run, [multimodal_in, page_selector],
                              [text_out, md_out, raw_out, img_out, gallery])
 
 if __name__ == "__main__":
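
For reference, the revert leaves app.py loading DeepSeek-OCR-2 unconditionally (bfloat16, flash-attention, CUDA) and drops the dots.ocr path along with the transformers compatibility shims. Below is a minimal sketch of driving the reloaded model; the `infer` entry point and its keyword arguments come from the DeepSeek-OCR model card's remote code and are assumed here to carry over to DeepSeek-OCR-2, and the prompt string and image path are placeholders:

# Sketch: driving the reverted model through its custom remote code.
# `model.infer` and its arguments follow the DeepSeek-OCR model card and
# are assumptions for DeepSeek-OCR-2; "doc.png" is a hypothetical input.
import torch
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2',
                                  torch_dtype=torch.bfloat16, trust_remote_code=True,
                                  use_safetensors=True)
model = model.eval().cuda()

result = model.infer(
    tokenizer,
    prompt="<image>\nFree OCR.",  # assumed prompt format from the model card
    image_file="doc.png",         # hypothetical input image
    base_size=1024,               # mirrors BASE_SIZE in app.py
    image_size=768,               # mirrors IMAGE_SIZE in app.py
    crop_mode=True,
)
print(result)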
requirements.txt CHANGED
@@ -1,12 +1,11 @@
-spaces
-huggingface_hub
-transformers
-torch
-torchvision
-qwen_vl_utils
-Pillow
-PyMuPDF
+torch==2.6.0
+transformers==4.46.3
+tokenizers
 accelerate
-addict
-matplotlib
 einops
+addict
+easydict
+torchvision
+flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+PyMuPDF
+hf_transfer
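
The flash-attn pin points at a prebuilt wheel for CUDA 12, torch 2.6, and CPython 3.10, which is why torch is pinned to 2.6.0 beside it. A quick sanity check for that combination, as a sketch rather than part of the commit:

# Sketch: verify the environment matches the pinned requirements above.
import sys
import torch
import transformers

assert torch.__version__.startswith("2.6"), "flash-attn wheel targets torch 2.6"
assert transformers.__version__ == "4.46.3"
assert sys.version_info[:2] == (3, 10), "flash-attn wheel is a cp310 build"

try:
    import flash_attn  # importable only if the prebuilt wheel matched the env
    print("flash-attn", flash_attn.__version__)
except ImportError as exc:
    # app.py hard-codes _attn_implementation='flash_attention_2', so the
    # model load will fail without this package.
    print("flash-attn unavailable:", exc)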