Kamal-prog-code committed on
Commit
4073fa4
·
1 Parent(s): b99d870

Enhance OCR functionality by integrating new model and processor, and update requirements

Browse files
Files changed (2) hide show
  1. app.py +91 -10
  2. requirements.txt +8 -9
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoModel, AutoTokenizer
3
  import torch
4
  import spaces
5
  import os
@@ -16,17 +16,87 @@ from io import StringIO, BytesIO
16
  MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
19
- model = AutoModel.from_pretrained(
20
- MODEL_NAME,
21
- _attn_implementation="flash_attention_2",
22
- torch_dtype=torch.bfloat16,
23
- trust_remote_code=True,
24
- use_safetensors=True,
25
- )
26
  model = model.eval()
27
  if torch.cuda.is_available():
28
  model = model.to("cuda")
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  BASE_SIZE = 1024
31
  IMAGE_SIZE = 768
32
  CROP_MODE = True
@@ -264,6 +334,11 @@ with gr.Blocks(title="DeepSeek-OCR-2") as demo:
264
  )
265
  input_img = gr.Image(label="Input Image", type="pil", height=300, interactive=False)
266
  page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
 
 
 
 
 
267
  btn = gr.Button("Extract", variant="primary", size="lg")
268
 
269
  with gr.Column(scale=2):
@@ -286,13 +361,19 @@ with gr.Blocks(title="DeepSeek-OCR-2") as demo:
286
  multimodal_in.change(update_page_selector_from_multimodal, [multimodal_in], [page_selector])
287
  page_selector.change(load_image_from_multimodal, [multimodal_in, page_selector], [input_img])
288
 
289
- def run(multimodal_value, page_num):
290
  file_path = unpack_multimodal(multimodal_value)
291
  if file_path:
 
 
 
 
 
 
292
  return process_file(file_path, int(page_num))
293
  return "Error: Upload a file or image", "", "", None, []
294
 
295
- submit_event = btn.click(run, [multimodal_in, page_selector],
296
  [text_out, md_out, raw_out, img_out, gallery])
297
 
298
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoProcessor
3
  import torch
4
  import spaces
5
  import os
 
16
  MODEL_NAME = 'deepseek-ai/DeepSeek-OCR-2'
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
19
+ model = AutoModel.from_pretrained(MODEL_NAME, _attn_implementation='flash_attention_2', torch_dtype=torch.bfloat16, trust_remote_code=True, use_safetensors=True)
 
 
 
 
 
 
20
  model = model.eval()
21
  if torch.cuda.is_available():
22
  model = model.to("cuda")
23
 
24
+ try:
25
+ from qwen_vl_utils import process_vision_info
26
+ except Exception:
27
+ process_vision_info = None
28
+
29
+ DOTS_OCR_PROMPT = "Extract all text from this image."
30
+
31
+ _DOTS_OCR_MODEL = None
32
+ _DOTS_OCR_PROCESSOR = None
33
+
34
def get_dots_ocr_model():
    """Lazily load and cache the dots.ocr model and processor.

    Returns:
        A ``(model, processor)`` tuple. The first call downloads and
        instantiates both from the ``rednote-hilab/dots.ocr`` repo and
        stores them in module-level globals; later calls return the
        cached pair.
    """
    global _DOTS_OCR_MODEL, _DOTS_OCR_PROCESSOR
    if _DOTS_OCR_MODEL is not None and _DOTS_OCR_PROCESSOR is not None:
        return _DOTS_OCR_MODEL, _DOTS_OCR_PROCESSOR

    cuda_ok = torch.cuda.is_available()
    load_kwargs = {
        # bf16 on GPU; fall back to fp32 on CPU where bf16 support varies.
        "torch_dtype": torch.bfloat16 if cuda_ok else torch.float32,
        "trust_remote_code": True,
    }
    if cuda_ok:
        # Flash attention + automatic device placement only make sense on GPU.
        load_kwargs["attn_implementation"] = "flash_attention_2"
        load_kwargs["device_map"] = "auto"

    _DOTS_OCR_MODEL = AutoModelForCausalLM.from_pretrained(
        "rednote-hilab/dots.ocr",
        **load_kwargs,
    )
    _DOTS_OCR_PROCESSOR = AutoProcessor.from_pretrained(
        "rednote-hilab/dots.ocr",
        trust_remote_code=True,
    )
    return _DOTS_OCR_MODEL, _DOTS_OCR_PROCESSOR
54
+
55
def dots_ocr_infer(image, prompt=DOTS_OCR_PROMPT, max_new_tokens=4096):
    """Run OCR on a single image with the dots.ocr vision-language model.

    Args:
        image: A PIL image (anything ``qwen_vl_utils.process_vision_info``
            accepts as an ``"image"`` content entry).
        prompt: Instruction text sent alongside the image.
        max_new_tokens: Generation budget for the decoded answer.

    Returns:
        The decoded model output as a string, or an error message when the
        optional ``qwen_vl_utils`` dependency failed to import.
    """
    if process_vision_info is None:
        # qwen_vl_utils was unavailable at module load; degrade gracefully
        # instead of crashing the Gradio handler.
        return "dots.ocr error: qwen_vl_utils is not available."
    model, processor = get_dots_ocr_model()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own placement: with device_map="auto" the weights
    # may live on a different device than the default.
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    with torch.no_grad():
        # Greedy decoding. Do NOT pass temperature here: with
        # do_sample=False transformers ignores it and emits an
        # "unused generation flags" warning, so it was dead config.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0] if output_text else ""
99
+
100
  BASE_SIZE = 1024
101
  IMAGE_SIZE = 768
102
  CROP_MODE = True
 
334
  )
335
  input_img = gr.Image(label="Input Image", type="pil", height=300, interactive=False)
336
  page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
337
+ model_choice = gr.Dropdown(
338
+ ["DeepSeek-OCR-2", "dots.ocr"],
339
+ value="DeepSeek-OCR-2",
340
+ label="Model",
341
+ )
342
  btn = gr.Button("Extract", variant="primary", size="lg")
343
 
344
  with gr.Column(scale=2):
 
361
  multimodal_in.change(update_page_selector_from_multimodal, [multimodal_in], [page_selector])
362
  page_selector.change(load_image_from_multimodal, [multimodal_in, page_selector], [input_img])
363
 
364
def run(multimodal_value, page_num, model_name):
    """Dispatch an extraction request to the selected OCR backend.

    Returns the 5-tuple expected by the Gradio outputs:
    (text, markdown, raw, annotated image, gallery).
    """
    file_path = unpack_multimodal(multimodal_value)
    if not file_path:
        return "Error: Upload a file or image", "", "", None, []

    if model_name != "dots.ocr":
        # Default path: the DeepSeek-OCR-2 pipeline with layout outputs.
        return process_file(file_path, int(page_num))

    page_image = load_image(file_path, int(page_num))
    if page_image is None:
        return "Error: Upload a file or image", "", "", None, []
    extracted = dots_ocr_infer(page_image)
    # dots.ocr yields plain text only, so mirror it into every text
    # output and leave the image/gallery slots empty.
    return extracted, extracted, extracted, None, []
375
 
376
+ submit_event = btn.click(run, [multimodal_in, page_selector, model_choice],
377
  [text_out, md_out, raw_out, img_out, gallery])
378
 
379
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,11 +1,10 @@
1
- torch==2.6.0
2
- transformers==4.46.3
3
- tokenizers==0.20.3
4
- accelerate
5
- einops
6
- addict
7
- easydict
8
  torchvision
9
- flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
 
10
  PyMuPDF
11
- hf_transfer
 
 
1
+ spaces
2
+ huggingface_hub
3
+ transformers==4.51.3
4
+ torch
 
 
 
5
  torchvision
6
+ qwen_vl_utils
7
+ Pillow
8
  PyMuPDF
9
+ accelerate
10
+ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl