prithivMLmods commited on
Commit
476fc05
·
verified ·
1 Parent(s): b737007

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -18,7 +18,6 @@ from transformers import (
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
20
 
21
- # --- Theme and CSS Definition ---
22
 
23
  colors.steel_blue = colors.Color(
24
  name="steel_blue",
@@ -27,7 +26,7 @@ colors.steel_blue = colors.Color(
27
  c200="#A8CCE1",
28
  c300="#7DB3D2",
29
  c400="#529AC3",
30
- c500="#4682B4", # SteelBlue base color
31
  c600="#3E72A0",
32
  c700="#36638C",
33
  c800="#2E5378",
@@ -91,22 +90,20 @@ css = """
91
  }
92
  """
93
 
94
- # --- Fix for Dots.OCR Processor Loading ---
95
 
96
- # Define a local directory to cache the model
97
  CACHE_PATH = "./model_cache"
98
  if not os.path.exists(CACHE_PATH):
99
  os.makedirs(CACHE_PATH)
100
 
101
- # Download the model files locally
102
  model_path_d_local = snapshot_download(
103
  repo_id='rednote-hilab/dots.ocr',
104
- local_dir=os.path.join(CACHE_PATH, 'dots.ocr'), # Create a dedicated subfolder
105
  max_workers=20,
106
  local_dir_use_symlinks=False
107
  )
108
 
109
- # Modify the configuration file to fix the processor loading issue
110
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
111
 
112
  if os.path.exists(config_file_path):
@@ -119,28 +116,24 @@ if os.path.exists(config_file_path):
119
  for line in lines:
120
  output_lines.append(line)
121
  if line.strip().startswith("class DotsVLProcessor"):
122
- # Insert the attributes line to specify which processors to load
123
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
124
 
125
- # Write the modified content back to the file
126
  with open(config_file_path, 'w') as f:
127
  f.write('\n'.join(output_lines))
128
  print("Patched configuration_dots.py successfully.")
129
 
130
- # Add the local model path to sys.path so transformers can use the modified code
131
  sys.path.append(model_path_d_local)
132
 
133
 
134
- # --- Model Loading ---
135
 
136
- # Constants for text generation
137
  MAX_MAX_NEW_TOKENS = 4096
138
  DEFAULT_MAX_NEW_TOKENS = 2048
139
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
140
 
141
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
142
 
143
- # Load Nanonets-OCR2-3B
144
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
145
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
146
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -149,7 +142,7 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
149
  torch_dtype=torch.float16
150
  ).to(device).eval()
151
 
152
- # Load Dots.OCR from the local, patched directory
153
  MODEL_PATH_D = model_path_d_local
154
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
155
  model_d = AutoModelForCausalLM.from_pretrained(
@@ -160,7 +153,7 @@ model_d = AutoModelForCausalLM.from_pretrained(
160
  trust_remote_code=True
161
  ).eval()
162
 
163
- # Load PaddleOCR
164
  MODEL_ID_P = "strangervisionhf/paddle"
165
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
166
  model_p = AutoModelForCausalLM.from_pretrained(
@@ -222,16 +215,16 @@ def generate_image(model_name: str, text: str, image: Image.Image,
222
  buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
223
  yield buffer, buffer
224
 
225
- # Define examples for image inference
226
  image_examples = [
227
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
228
  ["Describe the image!", "images/8.png"],
229
  ["OCR the image", "images/2.jpg"],
230
  ]
231
 
232
- # Create the Gradio Interface
233
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
234
- gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
235
  with gr.Row():
236
  with gr.Column(scale=2):
237
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
20
 
 
21
 
22
  colors.steel_blue = colors.Color(
23
  name="steel_blue",
 
26
  c200="#A8CCE1",
27
  c300="#7DB3D2",
28
  c400="#529AC3",
29
+ c500="#4682B4",
30
  c600="#3E72A0",
31
  c700="#36638C",
32
  c800="#2E5378",
 
90
  }
91
  """
92
 
 
93
 
 
94
  CACHE_PATH = "./model_cache"
95
  if not os.path.exists(CACHE_PATH):
96
  os.makedirs(CACHE_PATH)
97
 
98
+
99
  model_path_d_local = snapshot_download(
100
  repo_id='rednote-hilab/dots.ocr',
101
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
102
  max_workers=20,
103
  local_dir_use_symlinks=False
104
  )
105
 
106
+
107
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
108
 
109
  if os.path.exists(config_file_path):
 
116
  for line in lines:
117
  output_lines.append(line)
118
  if line.strip().startswith("class DotsVLProcessor"):
 
119
  output_lines.append(" attributes = [\"image_processor\", \"tokenizer\"]")
120
 
 
121
  with open(config_file_path, 'w') as f:
122
  f.write('\n'.join(output_lines))
123
  print("Patched configuration_dots.py successfully.")
124
 
125
+
126
  sys.path.append(model_path_d_local)
127
 
128
 
 
129
 
 
130
  MAX_MAX_NEW_TOKENS = 4096
131
  DEFAULT_MAX_NEW_TOKENS = 2048
132
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
133
 
134
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
135
 
136
+
137
  MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
138
  processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
139
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
142
  torch_dtype=torch.float16
143
  ).to(device).eval()
144
 
145
+
146
  MODEL_PATH_D = model_path_d_local
147
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
148
  model_d = AutoModelForCausalLM.from_pretrained(
 
153
  trust_remote_code=True
154
  ).eval()
155
 
156
+
157
  MODEL_ID_P = "strangervisionhf/paddle"
158
  processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
159
  model_p = AutoModelForCausalLM.from_pretrained(
 
215
  buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
216
  yield buffer, buffer
217
 
218
+
219
  image_examples = [
220
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
221
  ["Describe the image!", "images/8.png"],
222
  ["OCR the image", "images/2.jpg"],
223
  ]
224
 
225
+
226
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
227
+ gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
228
  with gr.Row():
229
  with gr.Column(scale=2):
230
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")