prithivMLmods committed on
Commit
8690171
·
verified ·
1 Parent(s): 924cc45

update app

Browse files
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import (
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
 
16
  )
17
 
18
  from gradio.themes import Soft
@@ -160,6 +161,16 @@ model_d = AutoModelForCausalLM.from_pretrained(
160
  trust_remote_code=True
161
  ).eval()
162
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  @spaces.GPU
165
  def generate_image(model_name: str, text: str, image: Image.Image,
@@ -173,6 +184,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
173
  processor, model = processor_m, model_m
174
  elif model_name == "Dots.OCR":
175
  processor, model = processor_d, model_d
 
 
176
  else:
177
  yield "Invalid model selected.", "Invalid model selected."
178
  return
@@ -183,16 +196,29 @@ def generate_image(model_name: str, text: str, image: Image.Image,
183
 
184
  images = [image.convert("RGB")]
185
 
186
- messages = [
187
- {
188
- "role": "user",
189
- "content": [{"type": "image"}] + [{"type": "text", "text": text}]
190
- }
191
- ]
192
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
193
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
196
  generation_kwargs = {
197
  **inputs,
198
  "streamer": streamer,
@@ -237,14 +263,14 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
237
 
238
  with gr.Column(scale=3):
239
  gr.Markdown("## Output", elem_id="output-title")
240
- raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=10, show_copy_button=True)
241
  with gr.Accordion("Formatted Result", open=False):
242
  formatted_output = gr.Markdown(label="Formatted Result")
243
 
244
  model_choice = gr.Radio(
245
- choices=["Nanonets-OCR2-3B", "Dots.OCR"],
246
  label="Select Model",
247
- value="Nanonets-OCR2-3B"
248
  )
249
 
250
  image_submit.click(
 
13
  AutoModelForCausalLM,
14
  AutoProcessor,
15
  TextIteratorStreamer,
16
+ AutoTokenizer, # Added for DeepSeek, though AutoProcessor is used
17
  )
18
 
19
  from gradio.themes import Soft
 
161
  trust_remote_code=True
162
  ).eval()
163
 
164
+ # Load DeepSeek-OCR
165
+ MODEL_ID_S = 'deepseek-ai/DeepSeek-OCR'
166
+ processor_s = AutoProcessor.from_pretrained(MODEL_ID_S, trust_remote_code=True)
167
+ model_s = AutoModelForCausalLM.from_pretrained(
168
+ MODEL_ID_S,
169
+ _attn_implementation='flash_attention_2',
170
+ trust_remote_code=True,
171
+ use_safetensors=True
172
+ ).eval().to(device).to(torch.bfloat16)
173
+
174
 
175
  @spaces.GPU
176
  def generate_image(model_name: str, text: str, image: Image.Image,
 
184
  processor, model = processor_m, model_m
185
  elif model_name == "Dots.OCR":
186
  processor, model = processor_d, model_d
187
+ elif model_name == "DeepSeek-OCR":
188
+ processor, model = processor_s, model_s
189
  else:
190
  yield "Invalid model selected.", "Invalid model selected."
191
  return
 
196
 
197
  images = [image.convert("RGB")]
198
 
199
+ # For DeepSeek-OCR, the recommended prompt format is slightly different
200
+ if model_name == "DeepSeek-OCR":
201
+ # Using a format found in documentation for better performance
202
+ prompt_text = f"<image>\n<|grounding|>{text}"
203
+ messages = [
204
+ {"role": "user", "content": prompt_text}
205
+ ]
206
+ # The processor-level apply_chat_template is bypassed here: the prompt text is built manually above and then passed through the tokenizer's chat template
207
+ prompt = processor.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
208
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
209
+
210
+ else:
211
+ messages = [
212
+ {
213
+ "role": "user",
214
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
215
+ }
216
+ ]
217
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
218
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
219
+
220
 
221
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
222
  generation_kwargs = {
223
  **inputs,
224
  "streamer": streamer,
 
263
 
264
  with gr.Column(scale=3):
265
  gr.Markdown("## Output", elem_id="output-title")
266
+ raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=9, show_copy_button=True)
267
  with gr.Accordion("Formatted Result", open=False):
268
  formatted_output = gr.Markdown(label="Formatted Result")
269
 
270
  model_choice = gr.Radio(
271
+ choices=["DeepSeek-OCR", "Nanonets-OCR2-3B", "Dots.OCR"],
272
  label="Select Model",
273
+ value="DeepSeek-OCR"
274
  )
275
 
276
  image_submit.click(