prithivMLmods committed on
Commit
3a99e35
·
verified ·
1 Parent(s): 632c48d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -25
app.py CHANGED
@@ -125,10 +125,10 @@ if torch.cuda.is_available():
125
  print("Using device:", device)
126
 
127
  # --- Model Loading ---
128
- # Load Nanonets-OCR2-3B using AutoModelForImageTextToText
129
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
130
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
131
- model_v = AutoModelForImageTextToText.from_pretrained(
132
  MODEL_ID_V,
133
  trust_remote_code=True,
134
  torch_dtype=torch.float16,
@@ -179,35 +179,55 @@ def generate_image(model_name: str, text: str, image: Image.Image,
179
  }]
180
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
181
 
182
- # Since model is loaded with device_map="auto", we don't need to manually move inputs to device
183
  inputs = processor(
184
  text=[prompt_full],
185
  images=[image],
186
  return_tensors="pt",
187
- padding=True
188
- ).to(model.device)
189
 
190
- # Both models now use a non-streaming generation approach
191
- generation_kwargs = {
192
- **inputs,
193
- "max_new_tokens": max_new_tokens,
194
- "do_sample": True,
195
- "temperature": temperature,
196
- "top_p": top_p,
197
- "top_k": top_k,
198
- "repetition_penalty": repetition_penalty,
199
- }
200
-
201
- generated_ids = model.generate(**generation_kwargs)
202
- generated_ids_trimmed = [
203
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
204
- ]
205
- output_text = processor.batch_decode(
206
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
207
- )[0]
 
 
 
208
 
209
- output_text = output_text.replace("<|im_end|>", "").strip()
210
- yield output_text, output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  # Define examples for image inference
 
125
  print("Using device:", device)
126
 
127
  # --- Model Loading ---
128
+ # Load Nanonets-OCR2-3B with its architecture-specific class (Qwen2_5_VLForConditionalGeneration)
129
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
130
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
131
+ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
132
  MODEL_ID_V,
133
  trust_remote_code=True,
134
  torch_dtype=torch.float16,
 
179
  }]
180
  prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
181
 
 
182
  inputs = processor(
183
  text=[prompt_full],
184
  images=[image],
185
  return_tensors="pt",
186
+ padding=True).to(model.device)
 
187
 
188
+ # Nanonets model supports streaming, so we use it for a better UX
189
+ if model_name == "Nanonets-OCR2-3B":
190
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
191
+ generation_kwargs = {
192
+ **inputs,
193
+ "streamer": streamer,
194
+ "max_new_tokens": max_new_tokens,
195
+ "do_sample": True,
196
+ "temperature": temperature,
197
+ "top_p": top_p,
198
+ "top_k": top_k,
199
+ "repetition_penalty": repetition_penalty,
200
+ }
201
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
202
+ thread.start()
203
+ buffer = ""
204
+ for new_text in streamer:
205
+ buffer += new_text
206
+ buffer = buffer.replace("<|im_end|>", "")
207
+ time.sleep(0.01)
208
+ yield buffer, buffer
209
 
210
+ # Dots.OCR is not streamed here: generate the full response, then yield it once
211
+ elif model_name == "Dots.OCR":
212
+ generation_kwargs = {
213
+ **inputs,
214
+ "max_new_tokens": max_new_tokens,
215
+ "do_sample": True,
216
+ "temperature": temperature,
217
+ "top_p": top_p,
218
+ "top_k": top_k,
219
+ "repetition_penalty": repetition_penalty,
220
+ }
221
+ generated_ids = model.generate(**generation_kwargs)
222
+ generated_ids_trimmed = [
223
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
224
+ ]
225
+ output_text = processor.batch_decode(
226
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
227
+ )[0]
228
+
229
+ output_text = output_text.replace("<|im_end|>", "").strip()
230
+ yield output_text, output_text
231
 
232
 
233
  # Define examples for image inference