Bapt120 commited on
Commit
80b2df0
·
verified ·
1 Parent(s): 5b6bee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -73
app.py CHANGED
@@ -1,29 +1,51 @@
1
  #!/usr/bin/env python3
2
- import os
3
- import json
4
- import base64
5
- import requests
 
 
 
6
  import gradio as gr
7
  from PIL import Image
8
  from io import BytesIO
9
  import pypdfium2 as pdfium
 
 
 
 
 
10
 
11
- ENDPOINT = os.environ.get("VLLM_ENDPOINT")
12
- MODEL = os.environ.get("VLLM_MODEL")
13
 
14
- if not ENDPOINT or not MODEL:
15
- raise ValueError("VLLM_ENDPOINT and VLLM_MODEL environment variables must be set.")
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
17
 
18
- def image_to_base64(image):
19
- buffered = BytesIO()
20
- if image.mode == 'RGBA':
21
- image = image.convert('RGB')
22
- image.save(buffered, format="PNG")
23
- return base64.b64encode(buffered.getvalue()).decode("utf-8")
24
 
25
 
26
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
 
27
  width, height = page.get_size()
28
  pixel_width = width * scale
29
  pixel_height = height * scale
@@ -33,6 +55,7 @@ def render_pdf_page(page, max_resolution=1540, scale=2.77):
33
 
34
 
35
  def process_pdf(pdf_path, page_num=1):
 
36
  pdf = pdfium.PdfDocument(pdf_path)
37
  total_pages = len(pdf)
38
  page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
@@ -44,7 +67,109 @@ def process_pdf(pdf_path, page_num=1):
44
  return img, total_pages, page_idx + 1
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def process_input(file_input, temperature, page_num):
 
48
  if file_input is None:
49
  yield "Please upload an image or PDF first.", "", "", None, gr.update()
50
  return
@@ -54,78 +179,35 @@ def process_input(file_input, temperature, page_num):
54
 
55
  file_path = file_input if isinstance(file_input, str) else file_input.name
56
 
 
57
  if file_path.lower().endswith('.pdf'):
58
  try:
59
  image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
60
  page_info = f"Processing page {actual_page} of {total_pages}"
61
  except Exception as e:
62
- yield f"Error processing PDF", "", "", None, gr.update()
63
  return
 
64
  else:
65
  try:
66
  image_to_process = Image.open(file_path)
67
  page_info = "Processing image"
68
  except Exception as e:
69
- yield f"Error opening image", "", "", None, gr.update()
70
  return
71
 
72
- content = [
73
- {"type": "text", "text": ""},
74
- {
75
- "type": "image_url",
76
- "image_url": {"url": f"data:image/png;base64,{image_to_base64(image_to_process)}"}
77
- }
78
- ]
79
-
80
- payload = {
81
- "model": MODEL,
82
- "messages": [{"role": "user", "content": content}],
83
- "temperature": temperature,
84
- "stream": True
85
- }
86
-
87
  try:
88
- response = requests.post(
89
- ENDPOINT,
90
- headers={"Content-Type": "application/json"},
91
- data=json.dumps(payload),
92
- stream=True
93
- )
94
- response.raise_for_status()
95
-
96
- accumulated_response = ""
97
- first_chunk = True
98
 
99
- for line in response.iter_lines():
100
- if line:
101
- line = line.decode('utf-8')
102
- if line.startswith('data: '):
103
- line = line[6:]
104
-
105
- if line.strip() == '[DONE]':
106
- break
107
-
108
- try:
109
- chunk = json.loads(line)
110
- if 'choices' in chunk and len(chunk['choices']) > 0:
111
- delta = chunk['choices'][0].get('delta', {})
112
- content_delta = delta.get('content', '')
113
- if content_delta:
114
- accumulated_response += content_delta
115
- if first_chunk:
116
- yield accumulated_response, accumulated_response, page_info, image_to_process, gr.update()
117
- first_chunk = False
118
- else:
119
- yield accumulated_response, accumulated_response, page_info, gr.update(), gr.update()
120
- except json.JSONDecodeError:
121
- continue
122
-
123
  except Exception as e:
124
- error_msg = f"Error"
125
  yield error_msg, error_msg, page_info, image_to_process, gr.update()
126
 
127
 
128
  def update_slider(file_input):
 
129
  if file_input is None:
130
  return gr.update(maximum=20, value=1)
131
 
@@ -143,17 +225,22 @@ def update_slider(file_input):
143
  return gr.update(maximum=1, value=1)
144
 
145
 
146
- with gr.Blocks(title="📖 Image/PDF OCR", theme=gr.themes.Soft()) as demo:
147
- gr.Markdown("""
148
- # 📖 Image/PDF to Text Extraction
 
149
 
150
  **💡 How to use:**
151
  1. Upload an image or PDF
152
  2. For PDFs: select which page to extract (1-20)
153
- 3. Adjust temperature if needed
154
  4. Click "Extract Text"
155
 
156
- **Note:** The Markdown rendering for tables is not always correct, check the raw output for complex tables!
 
 
 
 
157
  """)
158
 
159
  with gr.Row():
@@ -183,11 +270,12 @@ with gr.Blocks(title="📖 Image/PDF OCR", theme=gr.themes.Soft()) as demo:
183
  interactive=False
184
  )
185
  temperature = gr.Slider(
186
- minimum=0.1,
187
  maximum=1.0,
188
  value=0.2,
189
  step=0.05,
190
- label="Temperature"
 
191
  )
192
  submit_btn = gr.Button("Extract Text", variant="primary")
193
  clear_btn = gr.Button("Clear", variant="secondary")
@@ -208,6 +296,7 @@ with gr.Blocks(title="📖 Image/PDF OCR", theme=gr.themes.Soft()) as demo:
208
  show_copy_button=True
209
  )
210
 
 
211
  submit_btn.click(
212
  fn=process_input,
213
  inputs=[file_input, temperature, num_pages],
 
1
  #!/usr/bin/env python3
2
+ import subprocess
3
+ import sys
4
+ import threading
5
+
6
+ import spaces
7
+ import torch
8
+
9
  import gradio as gr
10
  from PIL import Image
11
  from io import BytesIO
12
  import pypdfium2 as pdfium
13
+ from transformers import (
14
+ LightOnOCRForConditionalGeneration,
15
+ LightOnOCRProcessor,
16
+ TextIteratorStreamer,
17
+ )
18
 
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
20
 
21
+ # Choose best attention implementation based on device
22
+ if device == "cuda":
23
+ attn_implementation = "sdpa"
24
+ dtype = torch.bfloat16
25
+ print("Using sdpa for GPU")
26
+ else:
27
+ attn_implementation = "eager" # Best for CPU
28
+ dtype = torch.float32
29
+ print("Using eager attention for CPU")
30
 
31
+ # Initialize the LightOnOCR model and processor
32
+ print(f"Loading model on {device} with {attn_implementation} attention...")
33
+ model = LightOnOCRForConditionalGeneration.from_pretrained(
34
+ "lightonai/LightOnOCR-1B-1025",
35
+ attn_implementation=attn_implementation,
36
+ torch_dtype=dtype,
37
+ trust_remote_code=True
38
+ ).to(device).eval()
39
 
40
+ processor = LightOnOCRProcessor.from_pretrained(
41
+ "lightonai/LightOnOCR-1B-1025",
42
+ trust_remote_code=True
43
+ )
44
+ print("Model loaded successfully!")
 
45
 
46
 
47
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
48
+ """Render a PDF page to PIL Image."""
49
  width, height = page.get_size()
50
  pixel_width = width * scale
51
  pixel_height = height * scale
 
55
 
56
 
57
  def process_pdf(pdf_path, page_num=1):
58
+ """Extract a specific page from PDF."""
59
  pdf = pdfium.PdfDocument(pdf_path)
60
  total_pages = len(pdf)
61
  page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
 
67
  return img, total_pages, page_idx + 1
68
 
69
 
70
+ def clean_output_text(text):
71
+ """Remove chat template artifacts from output."""
72
+ # Remove common chat template markers
73
+ markers_to_remove = ["system", "user", "assistant"]
74
+
75
+ # Split by lines and filter
76
+ lines = text.split('\n')
77
+ cleaned_lines = []
78
+
79
+ for line in lines:
80
+ stripped = line.strip()
81
+ # Skip lines that are just template markers
82
+ if stripped.lower() not in markers_to_remove:
83
+ cleaned_lines.append(line)
84
+
85
+ # Join back and strip leading/trailing whitespace
86
+ cleaned = '\n'.join(cleaned_lines).strip()
87
+
88
+ # Alternative approach: if there's an "assistant" marker, take everything after it
89
+ if "assistant" in text.lower():
90
+ parts = text.split("assistant", 1)
91
+ if len(parts) > 1:
92
+ cleaned = parts[1].strip()
93
+
94
+ return cleaned
95
+
96
+
97
+ @spaces.GPU
98
+ def extract_text_from_image(image, temperature=0.2, stream=False):
99
+ """Extract text from image using LightOnOCR model."""
100
+ # Prepare the chat format
101
+ chat = [
102
+ {
103
+ "role": "user",
104
+ "content": [
105
+ {"type": "image", "url": image},
106
+ ],
107
+ }
108
+ ]
109
+
110
+ # Apply chat template and tokenize
111
+ inputs = processor.apply_chat_template(
112
+ chat,
113
+ add_generation_prompt=True,
114
+ tokenize=True,
115
+ return_dict=True,
116
+ return_tensors="pt"
117
+ )
118
+
119
+ # Move inputs to device AND convert to the correct dtype
120
+ inputs = {
121
+ k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
122
+ else v.to(device) if isinstance(v, torch.Tensor)
123
+ else v
124
+ for k, v in inputs.items()
125
+ }
126
+
127
+ generation_kwargs = dict(
128
+ **inputs,
129
+ max_new_tokens=2048,
130
+ temperature=temperature if temperature > 0 else 0.0,
131
+ use_cache=True,
132
+ do_sample=temperature > 0,
133
+ )
134
+
135
+ if stream:
136
+ # Setup streamer for streaming generation
137
+ streamer = TextIteratorStreamer(
138
+ processor.tokenizer,
139
+ skip_prompt=True,
140
+ skip_special_tokens=True
141
+ )
142
+ generation_kwargs["streamer"] = streamer
143
+
144
+ # Run generation in a separate thread
145
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
146
+ thread.start()
147
+
148
+ # Yield chunks as they arrive
149
+ full_text = ""
150
+ for new_text in streamer:
151
+ full_text += new_text
152
+ # Clean the accumulated text
153
+ cleaned_text = clean_output_text(full_text)
154
+ yield cleaned_text
155
+
156
+ thread.join()
157
+ else:
158
+ # Non-streaming generation
159
+ with torch.no_grad():
160
+ outputs = model.generate(**generation_kwargs)
161
+
162
+ # Decode the output
163
+ output_text = processor.decode(outputs[0], skip_special_tokens=True)
164
+
165
+ # Clean the output
166
+ cleaned_text = clean_output_text(output_text)
167
+
168
+ yield cleaned_text
169
+
170
+
171
  def process_input(file_input, temperature, page_num):
172
+ """Process uploaded file (image or PDF) and extract text with streaming."""
173
  if file_input is None:
174
  yield "Please upload an image or PDF first.", "", "", None, gr.update()
175
  return
 
179
 
180
  file_path = file_input if isinstance(file_input, str) else file_input.name
181
 
182
+ # Handle PDF files
183
  if file_path.lower().endswith('.pdf'):
184
  try:
185
  image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
186
  page_info = f"Processing page {actual_page} of {total_pages}"
187
  except Exception as e:
188
+ yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
189
  return
190
+ # Handle image files
191
  else:
192
  try:
193
  image_to_process = Image.open(file_path)
194
  page_info = "Processing image"
195
  except Exception as e:
196
+ yield f"Error opening image: {str(e)}", "", "", None, gr.update()
197
  return
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  try:
200
+ # Extract text using LightOnOCR with streaming
201
+ for extracted_text in extract_text_from_image(image_to_process, temperature, stream=True):
202
+ yield extracted_text, extracted_text, page_info, image_to_process, gr.update()
 
 
 
 
 
 
 
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  except Exception as e:
205
+ error_msg = f"Error during text extraction: {str(e)}"
206
  yield error_msg, error_msg, page_info, image_to_process, gr.update()
207
 
208
 
209
  def update_slider(file_input):
210
+ """Update page slider based on PDF page count."""
211
  if file_input is None:
212
  return gr.update(maximum=20, value=1)
213
 
 
225
  return gr.update(maximum=1, value=1)
226
 
227
 
228
+ # Create Gradio interface
229
+ with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
230
+ gr.Markdown(f"""
231
+ # 📖 Image/PDF to Text Extraction with LightOnOCR
232
 
233
  **💡 How to use:**
234
  1. Upload an image or PDF
235
  2. For PDFs: select which page to extract (1-20)
236
+ 3. Adjust temperature if needed (0.0 for deterministic, higher for more varied output)
237
  4. Click "Extract Text"
238
 
239
+ **Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
240
+
241
+ **Model:** LightOnOCR-1B-1025 by LightOn AI
242
+ **Device:** {device.upper()}
243
+ **Attention:** {attn_implementation}
244
  """)
245
 
246
  with gr.Row():
 
270
  interactive=False
271
  )
272
  temperature = gr.Slider(
273
+ minimum=0.0,
274
  maximum=1.0,
275
  value=0.2,
276
  step=0.05,
277
+ label="Temperature",
278
+ info="0.0 = deterministic, Higher = more varied"
279
  )
280
  submit_btn = gr.Button("Extract Text", variant="primary")
281
  clear_btn = gr.Button("Clear", variant="secondary")
 
296
  show_copy_button=True
297
  )
298
 
299
+ # Event handlers
300
  submit_btn.click(
301
  fn=process_input,
302
  inputs=[file_input, temperature, num_pages],