prithivMLmods commited on
Commit
867049a
·
verified ·
1 Parent(s): 1dc1493

update app

Browse files
Files changed (1) hide show
  1. app.py +21 -25
app.py CHANGED
@@ -15,11 +15,9 @@ import cv2
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
- AutoModelForImageTextToText,
19
- AutoModelForCausalLM,
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
- AutoTokenizer
23
  )
24
  from transformers.image_utils import load_image
25
  from gradio.themes import Soft
@@ -125,27 +123,25 @@ if torch.cuda.is_available():
125
  print("Using device:", device)
126
 
127
  # --- Model Loading ---
128
- # Load Nanonets-OCR2-3B using its specific, correct class
129
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
130
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
131
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
132
  MODEL_ID_V,
133
  trust_remote_code=True,
134
- torch_dtype=torch.float16,
135
- device_map="auto",
136
- attn_implementation="flash_attention_2"
137
- ).eval()
138
 
139
- # Load Dots.OCR (rednote-hilab/dots.ocr)
140
- MODEL_ID_D = "rednote-hilab/dots.ocr"
141
- processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
142
- model_d = AutoModelForCausalLM.from_pretrained(
143
- MODEL_ID_D,
 
144
  trust_remote_code=True,
145
- torch_dtype=torch.float16,
146
- device_map="auto",
147
- attn_implementation="flash_attention_2"
148
- ).eval()
149
 
150
 
151
  @spaces.GPU
@@ -163,9 +159,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
163
  if model_name == "Nanonets-OCR2-3B":
164
  processor = processor_v
165
  model = model_v
166
- elif model_name == "Dots.OCR":
167
- processor = processor_d
168
- model = model_d
169
  else:
170
  yield "Invalid model selected.", "Invalid model selected."
171
  return
@@ -183,9 +179,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
183
  text=[prompt_full],
184
  images=[image],
185
  return_tensors="pt",
186
- padding=True).to(model.device)
187
 
188
- # Nanonets model supports streaming, so we use it for a better UX
189
  if model_name == "Nanonets-OCR2-3B":
190
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
191
  generation_kwargs = {
@@ -207,8 +203,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
207
  time.sleep(0.01)
208
  yield buffer, buffer
209
 
210
- # Dots.OCR does not use the streamer in the same way, generate full response
211
- elif model_name == "Dots.OCR":
212
  generation_kwargs = {
213
  **inputs,
214
  "max_new_tokens": max_new_tokens,
@@ -266,7 +262,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
266
  markdown_output = gr.Markdown(label="(Result.Md)")
267
 
268
  model_choice = gr.Radio(
269
- choices=["Nanonets-OCR2-3B", "Dots.OCR"],
270
  label="Select Model",
271
  value="Nanonets-OCR2-3B"
272
  )
 
15
 
16
  from transformers import (
17
  Qwen2_5_VLForConditionalGeneration,
18
+ PaddleOCRVLForConditionalGeneration, # Added for PaddleOCR-VL
 
19
  AutoProcessor,
20
  TextIteratorStreamer,
 
21
  )
22
  from transformers.image_utils import load_image
23
  from gradio.themes import Soft
 
123
  print("Using device:", device)
124
 
125
  # --- Model Loading ---
126
+ # Load Nanonets-OCR2-3B
127
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
128
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
129
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
130
  MODEL_ID_V,
131
  trust_remote_code=True,
132
+ torch_dtype=torch.float16
133
+ ).to(device).eval()
 
 
134
 
135
+ # Load PaddleOCR-VL
136
+ MODEL_ID_P = "PaddlePaddle/PaddleOCR-VL"
137
+ SUBFOLDER_P = "PaddleOCR-VL-0.9B"
138
+ processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True, subfolder=SUBFOLDER_P)
139
+ model_p = PaddleOCRVLForConditionalGeneration.from_pretrained(
140
+ MODEL_ID_P,
141
  trust_remote_code=True,
142
+ subfolder=SUBFOLDER_P,
143
+ torch_dtype=torch.float16
144
+ ).to(device).eval()
 
145
 
146
 
147
  @spaces.GPU
 
159
  if model_name == "Nanonets-OCR2-3B":
160
  processor = processor_v
161
  model = model_v
162
+ elif model_name == "PaddleOCR-VL":
163
+ processor = processor_p
164
+ model = model_p
165
  else:
166
  yield "Invalid model selected.", "Invalid model selected."
167
  return
 
179
  text=[prompt_full],
180
  images=[image],
181
  return_tensors="pt",
182
+ padding=True).to(device)
183
 
184
+ # Nanonets model supports streaming
185
  if model_name == "Nanonets-OCR2-3B":
186
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
187
  generation_kwargs = {
 
203
  time.sleep(0.01)
204
  yield buffer, buffer
205
 
206
+ # PaddleOCR-VL does not use a streamer, generate full response
207
+ elif model_name == "PaddleOCR-VL":
208
  generation_kwargs = {
209
  **inputs,
210
  "max_new_tokens": max_new_tokens,
 
262
  markdown_output = gr.Markdown(label="(Result.Md)")
263
 
264
  model_choice = gr.Radio(
265
+ choices=["Nanonets-OCR2-3B", "PaddleOCR-VL"],
266
  label="Select Model",
267
  value="Nanonets-OCR2-3B"
268
  )