prithivMLmods committed
Commit f50453e · verified · 1 Parent(s): 9efae34

Update app.py

Files changed (1)
  1. app.py +58 -64
app.py CHANGED
@@ -92,52 +92,36 @@ css = """
 """
 
 # --- Fix for Dots.OCR Processor Loading ---
-
-# Define a local directory to cache the model
 CACHE_PATH = "./model_cache"
 if not os.path.exists(CACHE_PATH):
     os.makedirs(CACHE_PATH)
 
-# Download the model files locally
 model_path_d_local = snapshot_download(
     repo_id='rednote-hilab/dots.ocr',
-    local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
+    local_dir=CACHE_PATH,
     max_workers=20,
     local_dir_use_symlinks=False
 )
-
-# Modify the configuration file to fix the processor loading issue
 config_file_path = os.path.join(model_path_d_local, "configuration_dots.py")
-
 if os.path.exists(config_file_path):
     with open(config_file_path, 'r') as f:
         input_code = f.read()
-
     lines = input_code.splitlines()
     if "class DotsVLProcessor" in input_code and not any("attributes = " in line for line in lines):
         output_lines = []
         for line in lines:
             output_lines.append(line)
             if line.strip().startswith("class DotsVLProcessor"):
-                # Insert the attributes line to specify which processors to load
                 output_lines.append("    attributes = [\"image_processor\", \"tokenizer\"]")
-
-        # Write the modified content back to the file
         with open(config_file_path, 'w') as f:
             f.write('\n'.join(output_lines))
         print("Patched configuration_dots.py successfully.")
-
-# Add the local model path to sys.path so transformers can use the modified code
 sys.path.append(model_path_d_local)
 
-
 # --- Model Loading ---
-
-# Constants for text generation
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Load Nanonets-OCR2-3B
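
Side note on the hunk above: the runtime patch adds a class-level `attributes` list to `DotsVLProcessor`. In `transformers`, `ProcessorMixin` reads that list to know which sub-components (`image_processor`, `tokenizer`) to instantiate when the processor is loaded, which is why its absence breaks processor loading. A minimal standalone sketch of the same patch step; the file and class names come from the diff, the function name is illustrative:

    def patch_dots_processor(config_path: str) -> None:
        """Insert an `attributes` declaration right after the class definition."""
        with open(config_path, "r") as f:
            source = f.read()
        lines = source.splitlines()
        # Nothing to do if the class is absent or the attribute already exists
        if "class DotsVLProcessor" not in source or any("attributes = " in ln for ln in lines):
            return
        patched = []
        for ln in lines:
            patched.append(ln)
            if ln.strip().startswith("class DotsVLProcessor"):
                # ProcessorMixin uses this list to pick which components to load
                patched.append('    attributes = ["image_processor", "tokenizer"]')
        with open(config_path, "w") as f:
            f.write("\n".join(patched))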
@@ -149,7 +133,7 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load Dots.OCR from the local, patched directory
+# Load Dots.OCR
 MODEL_PATH_D = model_path_d_local
 processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
 model_d = AutoModelForCausalLM.from_pretrained(
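
Because the snapshot was patched in place and added to `sys.path`, the Dots.OCR processor must be loaded from that same local directory; `trust_remote_code=True` tells `transformers` to execute the repo's custom Python files, which is how the patched `configuration_dots.py` takes effect. A minimal sketch, reusing `model_path_d_local` from the first hunk:

    from transformers import AutoProcessor

    # Executes the locally patched configuration_dots.py; the `attributes`
    # fix above is what this call relies on to assemble both components.
    processor_d = AutoProcessor.from_pretrained(model_path_d_local, trust_remote_code=True)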
@@ -163,8 +147,7 @@ model_d = AutoModelForCausalLM.from_pretrained(
 # Load ByteDance/Dolphin
 MODEL_ID_B = "ByteDance/Dolphin"
 processor_b = AutoProcessor.from_pretrained(MODEL_ID_B)
-model_b = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_B)
-model_b.to(device).eval().half()
+model_b = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_B, torch_dtype=torch.float16).to(device).eval()
 
 
 @spaces.GPU
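
The Dolphin loading change swaps a post-hoc `.half()` for `torch_dtype=torch.float16` at load time, so the checkpoint is materialized in fp16 directly instead of being loaded in fp32 and converted afterwards (the latter briefly holds the larger fp32 copy in memory). Before and after, as a sketch:

    import torch
    from transformers import VisionEncoderDecoderModel

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Before: fp32 load, then in-place conversion
    # model_b = VisionEncoderDecoderModel.from_pretrained("ByteDance/Dolphin")
    # model_b.to(device).eval().half()

    # After: weights land in fp16 from the start
    model_b = VisionEncoderDecoderModel.from_pretrained(
        "ByteDance/Dolphin", torch_dtype=torch.float16
    ).to(device).eval()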
@@ -175,64 +158,75 @@ def generate_image(model_name: str, text: str, image: Image.Image,
                    top_k: int = 50,
                    repetition_penalty: float = 1.2):
     """Generate responses for image input using the selected model."""
-    is_streaming = True
-    if model_name == "Nanonets-OCR2-3B":
-        processor, model = processor_m, model_m
-    elif model_name == "Dots.OCR":
-        processor, model = processor_d, model_d
-    elif model_name == "Dolphin":
-        processor, model = processor_b, model_b
-        is_streaming = False
-    else:
-        yield "Invalid model selected.", "Invalid model selected."
-        return
-
     if image is None:
         yield "Please upload an image.", "Please upload an image."
         return
 
-    image_rgb = image.convert("RGB")
+    images = [image.convert("RGB")]
 
-    if is_streaming:
-        messages = [
-            {
-                "role": "user",
-                "content": [{"type": "image"}] + [{"type": "text", "text": text}]
-            }
-        ]
+    if model_name == "Nanonets-OCR2-3B":
+        processor, model = processor_m, model_m
+        messages = [{"role": "user", "content": [{"type": "image"}] + [{"type": "text", "text": text}]}]
         prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
-        inputs = processor(text=prompt, images=[image_rgb], return_tensors="pt").to(device)
-
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "do_sample": True
-        }
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": True}
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
         for new_text in streamer:
             buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
             yield buffer, buffer
-    else:
-        # Handle non-streaming generation for ByteDance/Dolphin
-        pixel_values = processor(images=[image_rgb], return_tensors="pt").pixel_values.to(device).half()
-
-        # Note: The user's text query is not explicitly used here as the VisionEncoderDecoderModel
-        # pipeline primarily generates captions from images directly.
-        generated_ids = model.generate(pixel_values, max_new_tokens=max_new_tokens)
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # For this model, the output appears at once.
-        yield generated_text, generated_text
-
+
+    elif model_name == "Dots.OCR":
+        processor, model = processor_d, model_d
+        messages = [{"role": "user", "content": [{"type": "image"}] + [{"type": "text", "text": text}]}]
+        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "do_sample": True}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
+            yield buffer, buffer
+
+    elif model_name == "ByteDance/Dolphin":
+        processor, model = processor_b, model_b
+        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device, torch.float16)
+
+        prompt_template = f"<s>{text} <Answer/>"
+        prompt_inputs = processor.tokenizer(
+            [prompt_template],
+            add_special_tokens=False,
+            return_tensors="pt"
+        )
+        prompt_ids = prompt_inputs.input_ids.to(device)
+        attention_mask = prompt_inputs.attention_mask.to(device)
+
+        outputs = model.generate(
+            pixel_values=pixel_values,
+            decoder_input_ids=prompt_ids,
+            decoder_attention_mask=attention_mask,
+            max_length=max_new_tokens,
+            pad_token_id=processor.tokenizer.pad_token_id,
+            eos_token_id=processor.tokenizer.eos_token_id,
+            use_cache=True,
+            bad_words_ids=[[processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+            do_sample=False,  # Dolphin works best with greedy decoding
+            num_beams=1,
+            repetition_penalty=repetition_penalty
+        )
+
+        sequence = processor.tokenizer.decode(outputs.sequences[0], skip_special_tokens=False)
+        cleaned_output = sequence.replace(prompt_template, "").replace("<pad>", "").replace("</s>", "").strip()
+        yield cleaned_output, cleaned_output
+
+    else:
+        yield "Invalid model selected.", "Invalid model selected."
+        return
 
 # Define examples for image inference
 image_examples = [
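
Both streaming branches above follow the same pattern: `model.generate` runs on a worker thread while the Gradio generator drains a `TextIteratorStreamer` chunk by chunk. Stripped of model-specific details, a sketch of that pattern (function and argument names are illustrative):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_chat(model, processor, inputs, max_new_tokens=2048):
        # skip_prompt=True drops the echoed prompt from the stream
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True}
        Thread(target=model.generate, kwargs=kwargs).start()  # generation runs off the main thread
        buffer = ""
        for chunk in streamer:  # blocks until the next decoded piece is ready
            buffer += chunk
            yield buffer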
@@ -265,7 +259,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     formatted_output = gr.Markdown(label="Formatted Result")
 
     model_choice = gr.Radio(
-        choices=["Nanonets-OCR2-3B", "Dots.OCR", "Dolphin"],
+        choices=["Nanonets-OCR2-3B", "Dots.OCR", "ByteDance/Dolphin"],
         label="Select Model",
         value="Nanonets-OCR2-3B"
     )
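
One coupling to keep in mind: the `gr.Radio` choice string is passed verbatim to `generate_image` as `model_name`, so renaming the option to "ByteDance/Dolphin" works only because the new `elif model_name == "ByteDance/Dolphin":` branch matches it exactly; any mismatch falls through to the "Invalid model selected." path. The Dolphin branch also encodes that model's prompt convention: the query is wrapped as `<s>{text} <Answer/>` and supplied via `decoder_input_ids`, since a `VisionEncoderDecoderModel` conditions its text decoder on the encoded image rather than on a chat template.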
 