prithivMLmods commited on
Commit
9efae34
·
verified ·
1 Parent(s): 022a079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -30
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import (
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
 
16
  )
17
  from gradio.themes import Soft
18
  from gradio.themes.utils import colors, fonts, sizes
@@ -100,7 +101,7 @@ if not os.path.exists(CACHE_PATH):
100
  # Download the model files locally
101
  model_path_d_local = snapshot_download(
102
  repo_id='rednote-hilab/dots.ocr',
103
- local_dir=CACHE_PATH,
104
  max_workers=20,
105
  local_dir_use_symlinks=False
106
  )
@@ -159,6 +160,12 @@ model_d = AutoModelForCausalLM.from_pretrained(
159
  trust_remote_code=True
160
  ).eval()
161
 
 
 
 
 
 
 
162
 
163
  @spaces.GPU
164
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -168,10 +175,14 @@ def generate_image(model_name: str, text: str, image: Image.Image,
168
  top_k: int = 50,
169
  repetition_penalty: float = 1.2):
170
  """Generate responses for image input using the selected model."""
 
171
  if model_name == "Nanonets-OCR2-3B":
172
  processor, model = processor_m, model_m
173
  elif model_name == "Dots.OCR":
174
  processor, model = processor_d, model_d
 
 
 
175
  else:
176
  yield "Invalid model selected.", "Invalid model selected."
177
  return
@@ -180,35 +191,48 @@ def generate_image(model_name: str, text: str, image: Image.Image,
180
  yield "Please upload an image.", "Please upload an image."
181
  return
182
 
183
- images = [image.convert("RGB")]
184
-
185
- messages = [
186
- {
187
- "role": "user",
188
- "content": [{"type": "image"}] + [{"type": "text", "text": text}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  }
190
- ]
191
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
192
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
193
-
194
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
195
- generation_kwargs = {
196
- **inputs,
197
- "streamer": streamer,
198
- "max_new_tokens": max_new_tokens,
199
- "temperature": temperature,
200
- "top_p": top_p,
201
- "top_k": top_k,
202
- "repetition_penalty": repetition_penalty,
203
- "do_sample": True
204
- }
205
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
206
- thread.start()
207
-
208
- buffer = ""
209
- for new_text in streamer:
210
- buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
211
- yield buffer, buffer
212
 
213
  # Define examples for image inference
214
  image_examples = [
@@ -241,7 +265,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
241
  formatted_output = gr.Markdown(label="Formatted Result")
242
 
243
  model_choice = gr.Radio(
244
- choices=["Nanonets-OCR2-3B", "Dots.OCR"],
245
  label="Select Model",
246
  value="Nanonets-OCR2-3B"
247
  )
 
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
+ VisionEncoderDecoderModel,
17
  )
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
 
101
  # Download the model files locally
102
  model_path_d_local = snapshot_download(
103
  repo_id='rednote-hilab/dots.ocr',
104
+ local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
105
  max_workers=20,
106
  local_dir_use_symlinks=False
107
  )
 
160
  trust_remote_code=True
161
  ).eval()
162
 
163
+ # Load ByteDance/Dolphin
164
+ MODEL_ID_B = "ByteDance/Dolphin"
165
+ processor_b = AutoProcessor.from_pretrained(MODEL_ID_B)
166
+ model_b = VisionEncoderDecoderModel.from_pretrained(MODEL_ID_B)
167
+ model_b.to(device).eval().half()
168
+
169
 
170
  @spaces.GPU
171
  def generate_image(model_name: str, text: str, image: Image.Image,
 
175
  top_k: int = 50,
176
  repetition_penalty: float = 1.2):
177
  """Generate responses for image input using the selected model."""
178
+ is_streaming = True
179
  if model_name == "Nanonets-OCR2-3B":
180
  processor, model = processor_m, model_m
181
  elif model_name == "Dots.OCR":
182
  processor, model = processor_d, model_d
183
+ elif model_name == "Dolphin":
184
+ processor, model = processor_b, model_b
185
+ is_streaming = False
186
  else:
187
  yield "Invalid model selected.", "Invalid model selected."
188
  return
 
191
  yield "Please upload an image.", "Please upload an image."
192
  return
193
 
194
+ image_rgb = image.convert("RGB")
195
+
196
+ if is_streaming:
197
+ messages = [
198
+ {
199
+ "role": "user",
200
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
201
+ }
202
+ ]
203
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
204
+ inputs = processor(text=prompt, images=[image_rgb], return_tensors="pt").to(device)
205
+
206
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
207
+ generation_kwargs = {
208
+ **inputs,
209
+ "streamer": streamer,
210
+ "max_new_tokens": max_new_tokens,
211
+ "temperature": temperature,
212
+ "top_p": top_p,
213
+ "top_k": top_k,
214
+ "repetition_penalty": repetition_penalty,
215
+ "do_sample": True
216
  }
217
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
218
+ thread.start()
219
+
220
+ buffer = ""
221
+ for new_text in streamer:
222
+ buffer += new_text.replace("<|im_end|>", "").replace("<end_of_utterance>", "")
223
+ yield buffer, buffer
224
+ else:
225
+ # Handle non-streaming generation for ByteDance/Dolphin
226
+ pixel_values = processor(images=[image_rgb], return_tensors="pt").pixel_values.to(device).half()
227
+
228
+ # Note: The user's text query is not explicitly used here as the VisionEncoderDecoderModel
229
+ # pipeline primarily generates captions from images directly.
230
+ generated_ids = model.generate(pixel_values, max_new_tokens=max_new_tokens)
231
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
232
+
233
+ # For this model, the output appears at once.
234
+ yield generated_text, generated_text
235
+
 
 
 
236
 
237
  # Define examples for image inference
238
  image_examples = [
 
265
  formatted_output = gr.Markdown(label="Formatted Result")
266
 
267
  model_choice = gr.Radio(
268
+ choices=["Nanonets-OCR2-3B", "Dots.OCR", "Dolphin"],
269
  label="Select Model",
270
  value="Nanonets-OCR2-3B"
271
  )