Geraldine committed on
Commit
9704588
·
verified ·
1 Parent(s): 28c56d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import gc
3
  import json
4
  import base64
5
  import time
 
6
  from io import BytesIO
7
  from threading import Thread
8
 
@@ -10,6 +11,7 @@ import gradio as gr
10
  import spaces
11
  import torch
12
  from PIL import Image
 
13
 
14
  from transformers import (
15
  Qwen2VLForConditionalGeneration,
@@ -41,6 +43,53 @@ if torch.cuda.is_available():
41
  print("Using device:", device)
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
45
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
46
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -51,9 +100,10 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
51
  ).to(device).eval()
52
 
53
  MODEL_ID_Y = "rednote-hilab/dots.ocr"
54
- processor_y = AutoProcessor.from_pretrained(MODEL_ID_Y, trust_remote_code=True)
 
55
  model_y = AutoModelForCausalLM.from_pretrained(
56
- MODEL_ID_Y,
57
  attn_implementation="kernels-community/flash-attn2",
58
  trust_remote_code=True,
59
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
3
  import json
4
  import base64
5
  import time
6
+ from pathlib import Path
7
  from io import BytesIO
8
  from threading import Thread
9
 
 
11
  import spaces
12
  import torch
13
  from PIL import Image
14
+ from huggingface_hub import snapshot_download
15
 
16
  from transformers import (
17
  Qwen2VLForConditionalGeneration,
 
43
  print("Using device:", device)
44
 
45
 
46
def patch_dots_ocr_configuration(repo_path: str) -> None:
    """Rewrite ``configuration_dots.py`` in place so ``DotsVLProcessor`` is
    compatible with transformers releases that pass a ``video_processor``.

    Three targeted string edits are applied (each at most once):
      * declare ``attributes = ["image_processor", "tokenizer"]`` under the class header,
      * add a ``video_processor=None`` parameter to ``__init__``,
      * forward ``video_processor`` to the ``super().__init__`` call.

    No-op when the file is missing or already patched.
    """
    config_file = Path(repo_path) / "configuration_dots.py"
    if not config_file.exists():
        return

    original_text = config_file.read_text(encoding="utf-8")
    text = original_text

    attr_line = 'attributes = ["image_processor", "tokenizer"]'
    # Insert the processor-attributes declaration directly below the class header.
    if attr_line not in text:
        text = text.replace(
            "class DotsVLProcessor(Qwen2_5_VLProcessor):\n",
            "class DotsVLProcessor(Qwen2_5_VLProcessor):\n    " + attr_line + "\n",
            1,
        )

    old_init = "def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):"
    # Widen the constructor signature to accept the new keyword.
    if old_init in text:
        text = text.replace(
            old_init,
            "def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):",
            1,
        )

    old_super = "super().__init__(image_processor, tokenizer, chat_template=chat_template)"
    # Thread video_processor through to the base-class constructor.
    if old_super in text:
        text = text.replace(
            old_super,
            "super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)",
            1,
        )

    # Write back only when something actually changed.
    if text != original_text:
        config_file.write_text(text, encoding="utf-8")
        print(f"Patched dots.OCR processor config: {config_file}")
78
+
79
+
80
def resolve_dots_ocr_model_path(repo_id: str) -> str:
    """Return a path/ID from which the dots.OCR processor can be loaded.

    Probes ``AutoProcessor.from_pretrained`` on the hub repo first. When that
    fails with the known ``video_processor`` TypeError (remote code predating
    newer transformers), the repo is snapshotted locally, patched via
    ``patch_dots_ocr_configuration``, and the local path is returned instead.
    Any other TypeError is re-raised untouched.
    """
    try:
        # Probe: loads fine on transformers versions the remote code supports.
        AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    except TypeError as exc:
        if "video_processor" not in str(exc):
            raise
        print("dots.OCR processor compatibility issue detected, applying local patch...")
        # Materialize real files (no symlinks) so the config can be edited in place.
        snapshot = snapshot_download(
            repo_id=repo_id,
            local_dir="/tmp/dots_ocr_model",
            local_dir_use_symlinks=False,
        )
        patch_dots_ocr_configuration(snapshot)
        return snapshot
    else:
        return repo_id
91
+
92
+
93
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
94
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
95
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
100
  ).to(device).eval()
101
 
102
  MODEL_ID_Y = "rednote-hilab/dots.ocr"
103
+ MODEL_PATH_Y = resolve_dots_ocr_model_path(MODEL_ID_Y)
104
+ processor_y = AutoProcessor.from_pretrained(MODEL_PATH_Y, trust_remote_code=True)
105
  model_y = AutoModelForCausalLM.from_pretrained(
106
+ MODEL_PATH_Y,
107
  attn_implementation="kernels-community/flash-attn2",
108
  trust_remote_code=True,
109
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32