prithivMLmods committed on
Commit
632c48d
·
verified ·
1 Parent(s): f85950b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -51
app.py CHANGED
@@ -15,9 +15,11 @@ import cv2
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
- AutoModelForCausalLM, # Added for Dots.OCR
 
19
  AutoProcessor,
20
  TextIteratorStreamer,
 
21
  )
22
  from transformers.image_utils import load_image
23
  from gradio.themes import Soft
@@ -123,25 +125,27 @@ if torch.cuda.is_available():
123
  print("Using device:", device)
124
 
125
  # --- Model Loading ---
126
- # Load Nanonets-OCR2-3B
127
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
128
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
129
- model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
130
  MODEL_ID_V,
131
  trust_remote_code=True,
132
  torch_dtype=torch.float16,
133
- #_attn_implementation="flash_attention_2"
134
- ).to(device).eval()
 
135
 
136
  # Load Dots.OCR (rednote-hilab/dots.ocr)
137
- MODEL_ID_D = "strangervisionhf/dot.fix"
138
  processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
139
  model_d = AutoModelForCausalLM.from_pretrained(
140
  MODEL_ID_D,
141
  trust_remote_code=True,
142
  torch_dtype=torch.float16,
143
- _attn_implementation="flash_attention_2"
144
- ).to(device).eval()
 
145
 
146
 
147
  @spaces.GPU
@@ -175,55 +179,35 @@ def generate_image(model_name: str, text: str, image: Image.Image,
175
  }]
176
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
177
 
 
178
  inputs = processor(
179
  text=[prompt_full],
180
  images=[image],
181
  return_tensors="pt",
182
- padding=True).to(device)
 
183
 
184
- # Nanonets model supports streaming
185
- if model_name == "Nanonets-OCR2-3B":
186
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
187
- generation_kwargs = {
188
- **inputs,
189
- "streamer": streamer,
190
- "max_new_tokens": max_new_tokens,
191
- "do_sample": True,
192
- "temperature": temperature,
193
- "top_p": top_p,
194
- "top_k": top_k,
195
- "repetition_penalty": repetition_penalty,
196
- }
197
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
198
- thread.start()
199
- buffer = ""
200
- for new_text in streamer:
201
- buffer += new_text
202
- buffer = buffer.replace("<|im_end|>", "")
203
- time.sleep(0.01)
204
- yield buffer, buffer
205
 
206
- # Dots.OCR does not use the streamer in the same way, generate full response
207
- elif model_name == "Dots.OCR":
208
- generation_kwargs = {
209
- **inputs,
210
- "max_new_tokens": max_new_tokens,
211
- "do_sample": True,
212
- "temperature": temperature,
213
- "top_p": top_p,
214
- "top_k": top_k,
215
- "repetition_penalty": repetition_penalty,
216
- }
217
- generated_ids = model.generate(**generation_kwargs)
218
- generated_ids_trimmed = [
219
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
220
- ]
221
- output_text = processor.batch_decode(
222
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
223
- )[0]
224
-
225
- output_text = output_text.replace("<|im_end|>", "").strip()
226
- yield output_text, output_text
227
 
228
 
229
  # Define examples for image inference
 
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
+ AutoModelForImageTextToText,
19
+ AutoModelForCausalLM,
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
+ AutoTokenizer
23
  )
24
  from transformers.image_utils import load_image
25
  from gradio.themes import Soft
 
125
  print("Using device:", device)
126
 
127
  # --- Model Loading ---
128
+ # Load Nanonets-OCR2-3B using AutoModelForImageTextToText
129
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
130
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
131
+ model_v = AutoModelForImageTextToText.from_pretrained(
132
  MODEL_ID_V,
133
  trust_remote_code=True,
134
  torch_dtype=torch.float16,
135
+ device_map="auto",
136
+ attn_implementation="flash_attention_2"
137
+ ).eval()
138
 
139
  # Load Dots.OCR (rednote-hilab/dots.ocr)
140
+ MODEL_ID_D = "rednote-hilab/dots.ocr"
141
  processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
142
  model_d = AutoModelForCausalLM.from_pretrained(
143
  MODEL_ID_D,
144
  trust_remote_code=True,
145
  torch_dtype=torch.float16,
146
+ device_map="auto",
147
+ attn_implementation="flash_attention_2"
148
+ ).eval()
149
 
150
 
151
  @spaces.GPU
 
179
  }]
180
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
181
 
182
+ # The model is loaded with device_map="auto", so move the inputs to the device the model reports (model.device) instead of a hard-coded device
183
  inputs = processor(
184
  text=[prompt_full],
185
  images=[image],
186
  return_tensors="pt",
187
+ padding=True
188
+ ).to(model.device)
189
 
190
+ # Both models now use a non-streaming generation approach
191
+ generation_kwargs = {
192
+ **inputs,
193
+ "max_new_tokens": max_new_tokens,
194
+ "do_sample": True,
195
+ "temperature": temperature,
196
+ "top_p": top_p,
197
+ "top_k": top_k,
198
+ "repetition_penalty": repetition_penalty,
199
+ }
 
 
 
 
 
 
 
 
 
 
 
200
 
201
+ generated_ids = model.generate(**generation_kwargs)
202
+ generated_ids_trimmed = [
203
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
204
+ ]
205
+ output_text = processor.batch_decode(
206
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
207
+ )[0]
208
+
209
+ output_text = output_text.replace("<|im_end|>", "").strip()
210
+ yield output_text, output_text
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  # Define examples for image inference