prithivMLmods commited on
Commit
f75e630
·
verified ·
1 Parent(s): f50453e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -70
app.py CHANGED
@@ -13,7 +13,6 @@ from transformers import (
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
- VisionEncoderDecoderModel,
17
  )
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
@@ -92,20 +91,27 @@ css = """
92
  """
93
 
94
  # --- Fix for Dots.OCR Processor Loading ---
 
 
95
  CACHE_PATH = "./model_cache"
96
  if not os.path.exists(CACHE_PATH):
97
  os.makedirs(CACHE_PATH)
98
 
 
99
  model_path_d_local = snapshot_download(
100
  repo_id='rednote-hilab/dots.ocr',
101
- local_dir=CACHE_PATH,
102
  max_workers=20,
103
  local_dir_use_symlinks=False
104
  )
 
 
105
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
 
106
  if os.path.exists(config_file_path):
107
  with open(config_file_path, 'r') as f:
108
  input_code = f.read()
 
109
  lines = input_code.splitlines()
110
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
111
  output_lines = []
@@ -116,12 +122,18 @@ if os.path.exists(config_file_path):
116
  with open(config_file_path, 'w') as f:
117
  f.write('\n'.join(output_lines))
118
  print("Patched configuration_dots.py successfully.")
 
 
119
  sys.path.append(model_path_d_local)
120
 
 
121
  # --- Model Loading ---
 
 
122
  MAX_MAX_NEW_TOKENS = 4096
123
  DEFAULT_MAX_NEW_TOKENS = 2048
124
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
125
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
126
 
127
  # Load Nanonets-OCR2-3B
@@ -133,7 +145,7 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
133
  torch_dtype=torch.float16
134
  ).to(device).eval()
135
 
136
- # Load Dots.OCR
137
  MODEL_PATH_D = model_path_d_local
138
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
139
  model_d = AutoModelForCausalLM.from_pretrained(
@@ -144,10 +156,14 @@ model_d = AutoModelForCausalLM.from_pretrained(
144
  trust_remote_code=True
145
  ).eval()
146
 
147
- # Load ByteDance/Dolphin
148
- MODEL_ID_B = "ByteDance/Dolphin"
149
- processor_b = AutoProcessor.from_pretrained(MODEL_ID_B)
150
- model_b = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_B, torch_dtype=torch.float16).to(device).eval()
 
 
 
 
151
 
152
 
153
  @spaces.GPU
@@ -158,76 +174,48 @@ def generate_image(model_name: str, text: str, image: Image.Image,
158
  top_k: int = 50,
159
  repetition_penalty: float = 1.2):
160
  """Generate responses for image input using the selected model."""
161
- if image is None:
162
- yield "Please upload an image.", "Please upload an image."
163
- return
164
-
165
- images = [image.convert("RGB")]
166
-
167
  if model_name == "Nanonets-OCR2-3B":
168
  processor, model = processor_m, model_m
169
- messages = [{"role": "user", "content": [{"type": "image"}] + [{"type": "text", "text": text}]}]
170
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
171
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
172
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
173
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": True}
174
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
175
- thread.start()
176
- buffer = ""
177
- for new_text in streamer:
178
- buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
179
- yield buffer, buffer
180
-
181
  elif model_name == "Dots.OCR":
182
  processor, model = processor_d, model_d
183
- messages = [{"role": "user", "content": [{"type": "image"}] + [{"type": "text", "text": text}]}]
184
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
185
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
186
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
187
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": True}
188
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
189
- thread.start()
190
- buffer = ""
191
- for new_text in streamer:
192
- buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
193
- yield buffer, buffer
194
-
195
- elif model_name == "ByteDance/Dolphin":
196
- processor, model = processor_b, model_b
197
- pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device, torch.float16)
198
-
199
- prompt_template = f"<s>{text} <Answer/>"
200
- prompt_inputs = processor.tokenizer(
201
- [prompt_template],
202
- add_special_tokens=False,
203
- return_tensors="pt"
204
- )
205
- prompt_ids = prompt_inputs.input_ids.to(device)
206
- attention_mask = prompt_inputs.attention_mask.to(device)
207
-
208
- outputs = model.generate(
209
- pixel_values=pixel_values,
210
- decoder_input_ids=prompt_ids,
211
- decoder_attention_mask=attention_mask,
212
- max_length=max_new_tokens,
213
- pad_token_id=processor.tokenizer.pad_token_id,
214
- eos_token_id=processor.tokenizer.eos_token_id,
215
- use_cache=True,
216
- bad_words_ids=[[processor.tokenizer.unk_token_id]],
217
- return_dict_in_generate=True,
218
- do_sample=False, # Dolphin works best with greedy decoding
219
- num_beams=1,
220
- repetition_penalty=repetition_penalty
221
- )
222
-
223
- sequence = processor.tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
224
- cleaned_output = sequence.replace(prompt_template, "").replace("<pad>", "").replace("</s>", "").strip()
225
- yield cleaned_output, cleaned_output
226
-
227
  else:
228
  yield "Invalid model selected.", "Invalid model selected."
229
  return
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # Define examples for image inference
232
  image_examples = [
233
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
@@ -259,7 +247,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
259
  formatted_output = gr.Markdown(label="Formatted Result")
260
 
261
  model_choice = gr.Radio(
262
- choices=["Nanonets-OCR2-3B", "Dots.OCR", "ByteDance/Dolphin"],
263
  label="Select Model",
264
  value="Nanonets-OCR2-3B"
265
  )
 
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
 
16
  )
17
  from gradio.themes import Soft
18
  from gradio.themes.utils import colors, fonts, sizes
 
91
  """
92
 
93
  # --- Fix for Dots.OCR Processor Loading ---
94
+
95
+ # Define a local directory to cache the model
96
  CACHE_PATH = "./model_cache"
97
  if not os.path.exists(CACHE_PATH):
98
  os.makedirs(CACHE_PATH)
99
 
100
+ # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
107
+
108
+ # Modify the configuration file to fix the processor loading issue
109
  config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
110
+
111
  if os.path.exists(config_file_path):
112
  with open(config_file_path, 'r') as f:
113
  input_code = f.read()
114
+
115
  lines = input_code.splitlines()
116
  if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
117
  output_lines = []
 
122
  with open(config_file_path, 'w') as f:
123
  f.write('\n'.join(output_lines))
124
  print("Patched configuration_dots.py successfully.")
125
+
126
+ # Add the local model path to sys.path so transformers can use the modified code
127
  sys.path.append(model_path_d_local)
128
 
129
+
130
  # --- Model Loading ---
131
+
132
+ # Constants for text generation
133
  MAX_MAX_NEW_TOKENS = 4096
134
  DEFAULT_MAX_NEW_TOKENS = 2048
135
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
136
+
137
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
138
 
139
  # Load Nanonets-OCR2-3B
 
145
  torch_dtype=torch.float16
146
  ).to(device).eval()
147
 
148
+ # Load Dots.OCR from the local, patched directory
149
  MODEL_PATH_D = model_path_d_local
150
  processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
151
  model_d = AutoModelForCausalLM.from_pretrained(
 
156
  trust_remote_code=True
157
  ).eval()
158
 
159
+ # Load PaddleOCR
160
+ MODEL_ID_P = "strangervisionhf/paddle"
161
+ processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
162
+ model_p = AutoModelForCausalLM.from_pretrained(
163
+ MODEL_ID_P,
164
+ trust_remote_code=True,
165
+ torch_dtype=torch.bfloat16
166
+ ).to(device).eval()
167
 
168
 
169
  @spaces.GPU
 
174
  top_k: int = 50,
175
  repetition_penalty: float = 1.2):
176
  """Generate responses for image input using the selected model."""
 
 
 
 
 
 
177
  if model_name == "Nanonets-OCR2-3B":
178
  processor, model = processor_m, model_m
 
 
 
 
 
 
 
 
 
 
 
 
179
  elif model_name == "Dots.OCR":
180
  processor, model = processor_d, model_d
181
+ elif model_name == "PaddleOCR":
182
+ processor, model = processor_p, model_p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  else:
184
  yield "Invalid model selected.", "Invalid model selected."
185
  return
186
 
187
+ if image is None:
188
+ yield "Please upload an image.", "Please upload an image."
189
+ return
190
+
191
+ images = [image.convert("RGB")]
192
+
193
+ messages = [
194
+ {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}
195
+ ]
196
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
197
+
198
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
199
+
200
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
201
+ generation_kwargs = {
202
+ **inputs,
203
+ "streamer": streamer,
204
+ "max_new_tokens": max_new_tokens,
205
+ "temperature": temperature,
206
+ "top_p": top_p,
207
+ "top_k": top_k,
208
+ "repetition_penalty": repetition_penalty,
209
+ "do_sample": True
210
+ }
211
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
212
+ thread.start()
213
+
214
+ buffer = ""
215
+ for new_text in streamer:
216
+ buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
217
+ yield buffer, buffer
218
+
219
  # Define examples for image inference
220
  image_examples = [
221
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
 
247
  formatted_output = gr.Markdown(label="Formatted Result")
248
 
249
  model_choice = gr.Radio(
250
+ choices=["Nanonets-OCR2-3B", "Dots.OCR", "PaddleOCR"],
251
  label="Select Model",
252
  value="Nanonets-OCR2-3B"
253
  )