ZennyKenny commited on
Commit
c9bfe98
·
verified ·
1 Parent(s): 509aed5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -215
app.py CHANGED
@@ -1,229 +1,167 @@
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from gradio_pdf import PDF
5
- from pdf2image import convert_from_path
6
- from PIL import Image
7
- from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
8
-
9
- model_path = "nanonets/Nanonets-OCR-s"
10
-
11
- # Load model once at startup
12
- print("Loading Nanonets OCR model...")
13
- model = AutoModelForImageTextToText.from_pretrained(
14
- model_path,
15
- torch_dtype="auto",
16
- device_map="auto",
17
- attn_implementation="flash_attention_2",
18
  )
19
- model.eval()
20
-
21
- tokenizer = AutoTokenizer.from_pretrained(model_path)
22
- processor = AutoProcessor.from_pretrained(model_path)
23
- print("Model loaded successfully!")
24
-
25
-
26
- @spaces.GPU()
27
- def ocr_image_gradio(image, max_tokens=4096):
28
- """Process image through Nanonets OCR model for Gradio interface"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if image is None:
30
- return "Please upload an image."
 
31
 
32
- prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
33
-
34
- # Convert PIL image if needed
35
- if not isinstance(image, Image.Image):
36
- image = Image.fromarray(image)
37
 
38
  messages = [
39
- {"role": "system", "content": "You are a helpful assistant."},
40
  {
41
  "role": "user",
42
- "content": [
43
- {"type": "image", "image": image},
44
- {"type": "text", "text": prompt},
45
- ],
46
- },
47
  ]
48
-
49
- text = processor.apply_chat_template(
50
- messages, tokenize=False, add_generation_prompt=True
51
- )
52
- inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
53
- inputs = inputs.to(model.device)
54
-
55
- with torch.no_grad():
56
- output_ids = model.generate(
57
- **inputs,
58
- max_new_tokens=max_tokens,
59
- do_sample=False,
60
- repetition_penalty=1.25,
61
- )
62
- generated_ids = [
63
- output_ids[len(input_ids) :]
64
- for input_ids, output_ids in zip(inputs.input_ids, output_ids)
65
- ]
66
-
67
- output_text = processor.batch_decode(
68
- generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  )
70
- return output_text[0]
71
-
72
-
73
- @spaces.GPU()
74
- def ocr_pdf_gradio(pdf_path, max_tokens=4096, progress=gr.Progress()):
75
- """Process each page of a PDF through Nanonets OCR model"""
76
- if pdf_path is None:
77
- return "Please upload a PDF file."
78
-
79
- # Convert PDF to images
80
- progress(0, desc="Converting PDF to images...")
81
- pdf_images = convert_from_path(pdf_path)
82
-
83
- # Process each page
84
- all_text = []
85
- total_pages = len(pdf_images)
86
-
87
- for i, image in enumerate(pdf_images):
88
- progress(
89
- (i + 1) / total_pages, desc=f"Processing page {i + 1}/{total_pages}..."
90
- )
91
- page_text = ocr_image_gradio(image, max_tokens)
92
- all_text.append(f"--- PAGE {i + 1} ---\n{page_text}\n")
93
-
94
- # Combine results
95
- combined_text = "\n".join(all_text)
96
- return combined_text
97
-
98
-
99
- # Create Gradio interface
100
- with gr.Blocks(title="Nanonets OCR Demo") as demo:
101
- # Replace simple markdown with styled HTML header that includes resources
102
- gr.HTML("""
103
- <div class="title" style="text-align: center">
104
- <h1>🔍 Nanonets OCR - Document Text Extraction</h1>
105
- <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
106
- A state-of-the-art image-to-markdown OCR model for intelligent document processing
107
- </p>
108
- <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
109
- <a href="https://huggingface.co/nanonets/Nanonets-OCR-s" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
110
- 📚 Hugging Face Model
111
- </a>
112
- <a href="https://nanonets.com/research/nanonets-ocr-s/" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
113
- 📝 Release Blog
114
- </a>
115
- <a href="https://github.com/NanoNets/docext" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
116
- 💻 GitHub Repository
117
- </a>
118
- </div>
119
- </div>
120
- """)
121
-
122
- with gr.Tabs() as tabs:
123
- # Image tab
124
- with gr.TabItem("Image OCR"):
125
- with gr.Row():
126
- with gr.Column(scale=1):
127
- image_input = gr.Image(
128
- label="Upload Document Image", type="pil", height=400
129
- )
130
- image_max_tokens = gr.Slider(
131
- minimum=1024,
132
- maximum=8192,
133
- value=4096,
134
- step=512,
135
- label="Max Tokens",
136
- info="Maximum number of tokens to generate",
137
- )
138
- image_extract_btn = gr.Button(
139
- "Extract Text", variant="primary", size="lg"
140
- )
141
-
142
- with gr.Column(scale=2):
143
- image_output_text = gr.Textbox(
144
- label="Extracted Text",
145
- lines=20,
146
- show_copy_button=True,
147
- placeholder="Extracted text will appear here...",
148
- )
149
-
150
- # PDF tab
151
- with gr.TabItem("PDF OCR"):
152
- with gr.Row():
153
- with gr.Column(scale=1):
154
- pdf_input = PDF(label="Upload PDF Document", height=400)
155
- pdf_max_tokens = gr.Slider(
156
- minimum=1024,
157
- maximum=8192,
158
- value=4096,
159
- step=512,
160
- label="Max Tokens per Page",
161
- info="Maximum number of tokens to generate for each page",
162
- )
163
- pdf_extract_btn = gr.Button(
164
- "Extract PDF Text", variant="primary", size="lg"
165
- )
166
-
167
- with gr.Column(scale=2):
168
- pdf_output_text = gr.Textbox(
169
- label="Extracted Text (All Pages)",
170
- lines=20,
171
- show_copy_button=True,
172
- placeholder="Extracted text will appear here...",
173
- )
174
-
175
- # Event handlers for Image tab
176
- image_extract_btn.click(
177
- fn=ocr_image_gradio,
178
- inputs=[image_input, image_max_tokens],
179
- outputs=image_output_text,
180
- show_progress=True,
181
- )
182
-
183
- image_input.change(
184
- fn=ocr_image_gradio,
185
- inputs=[image_input, image_max_tokens],
186
- outputs=image_output_text,
187
- show_progress=True,
188
- )
189
-
190
- # Event handlers for PDF tab
191
- pdf_extract_btn.click(
192
- fn=ocr_pdf_gradio,
193
- inputs=[pdf_input, pdf_max_tokens],
194
- outputs=pdf_output_text,
195
- show_progress=True,
196
- )
197
-
198
- # Add model information section
199
- with gr.Accordion("About Nanonets-OCR-s", open=False):
200
- gr.Markdown("""
201
- ## Nanonets-OCR-s
202
-
203
- Nanonets-OCR-s is a powerful, state-of-the-art image-to-markdown OCR model that goes far beyond traditional text extraction.
204
- It transforms documents into structured markdown with intelligent content recognition and semantic tagging, making it ideal
205
- for downstream processing by Large Language Models (LLMs).
206
-
207
- ### Key Features
208
-
209
- - **LaTeX Equation Recognition**: Automatically converts mathematical equations and formulas into properly formatted LaTeX syntax.
210
- It distinguishes between inline ($...$) and display ($$...$$) equations.
211
-
212
- - **Intelligent Image Description**: Describes images within documents using structured `<img>` tags, making them digestible
213
- for LLM processing. It can describe various image types, including logos, charts, graphs and so on, detailing their content,
214
- style, and context.
215
-
216
- - **Signature Detection & Isolation**: Identifies and isolates signatures from other text, outputting them within a `<signature>` tag.
217
- This is crucial for processing legal and business documents.
218
-
219
- - **Watermark Extraction**: Detects and extracts watermark text from documents, placing it within a `<watermark>` tag.
220
-
221
- - **Smart Checkbox Handling**: Converts form checkboxes and radio buttons into standardized Unicode symbols (☐, ☑, ☒)
222
- for consistent and reliable processing.
223
-
224
- - **Complex Table Extraction**: Accurately extracts complex tables from documents and converts them into both markdown
225
- and HTML table formats.
226
- """)
227
 
228
  if __name__ == "__main__":
229
- demo.queue().launch(ssr_mode=False)
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import json
5
+ import time
6
+ import asyncio
7
+ from threading import Thread
8
+
9
  import gradio as gr
10
  import spaces
11
  import torch
12
+ import numpy as np
13
+ from PIL import Image, ImageOps
14
+ # import cv2 # not needed anymore
15
+
16
+ from transformers import (
17
+ Qwen2_5_VLForConditionalGeneration,
18
+ AutoProcessor,
19
+ TextIteratorStreamer,
 
 
 
 
 
 
20
  )
21
+ from transformers.image_utils import load_image
22
+
23
+ # Optional docling imports (unused now but kept for easy re-enable)
24
+ # from docling_core.types.doc import DoclingDocument, DocTagsDocument
25
+
26
+ import re
27
+ import ast
28
+ import html
29
+
30
+ # ---------------------------
31
+ # Constants & device
32
+ # ---------------------------
33
+ MAX_MAX_NEW_TOKENS = 2048
34
+ DEFAULT_MAX_NEW_TOKENS = 1024
35
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
36
+
37
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38
+
39
+ # ---------------------------
40
+ # Load ONLY Typhoon OCR 20B
41
+ # ---------------------------
42
+ MODEL_ID = "scb10x/typhoon-ocr-20b" # <- 20B model
43
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
44
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
45
+ MODEL_ID,
46
+ trust_remote_code=True,
47
+ torch_dtype=torch.float16
48
+ ).to(device).eval()
49
+
50
+ # ---------------------------
51
+ # (Optional) image helpers
52
+ # ---------------------------
53
+ def add_random_padding(image, min_percent=0.1, max_percent=0.10):
54
+ image = image.convert("RGB")
55
+ width, height = image.size
56
+ pad_w_percent = random.uniform(min_percent, max_percent)
57
+ pad_h_percent = random.uniform(min_percent, max_percent)
58
+ pad_w = int(width * pad_w_percent)
59
+ pad_h = int(height * pad_h_percent)
60
+ corner_pixel = image.getpixel((0, 0))
61
+ padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
62
+ return padded_image
63
+
64
+ def normalize_values(text, target_max=500):
65
+ def normalize_list(values):
66
+ max_value = max(values) if values else 1
67
+ return [round((v / max_value) * target_max) for v in values]
68
+
69
+ def process_match(match):
70
+ num_list = ast.literal_eval(match.group(0))
71
+ normalized = normalize_list(num_list)
72
+ return "".join([f"<loc_{num}>" for num in normalized])
73
+
74
+ pattern = r"\[([\d\.\s,]+)\]"
75
+ return re.sub(pattern, process_match, text)
76
+
77
+ # ---------------------------
78
+ # Image generation only
79
+ # ---------------------------
80
+ @spaces.GPU
81
+ def generate_image(
82
+ text: str,
83
+ image: Image.Image,
84
+ max_new_tokens: int = 2048,
85
+ temperature: float = 0.1,
86
+ top_p: float = 0.9,
87
+ top_k: int = 50,
88
+ repetition_penalty: float = 1.2,
89
+ ):
90
+ """Generate OCR/vision response for a single image with Typhoon OCR 20B."""
91
  if image is None:
92
+ yield "Please upload an image."
93
+ return
94
 
95
+ images = [image]
 
 
 
 
96
 
97
  messages = [
 
98
  {
99
  "role": "user",
100
+ "content": [{"type": "image"} for _ in images] + [
101
+ {"type": "text", "text": text}
102
+ ]
103
+ }
 
104
  ]
105
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
106
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
107
+
108
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
109
+ generation_kwargs = {
110
+ **inputs,
111
+ "streamer": streamer,
112
+ "max_new_tokens": max_new_tokens,
113
+ "temperature": temperature,
114
+ "top_p": top_p,
115
+ "top_k": top_k,
116
+ "repetition_penalty": repetition_penalty,
117
+ }
118
+
119
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
120
+ thread.start()
121
+
122
+ buffer = ""
123
+ for new_text in streamer:
124
+ buffer += new_text.replace("<|im_end|>", "")
125
+ yield buffer
126
+
127
+ # ---------------------------
128
+ # Minimal UI (Image only)
129
+ # ---------------------------
130
+ css = """
131
+ .submit-btn {
132
+ background-color: #2980b9 !important;
133
+ color: white !important;
134
+ }
135
+ .submit-btn:hover {
136
+ background-color: #3498db !important;
137
+ }
138
+ """
139
+
140
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
141
+ gr.Markdown("# **Typhoon OCR 20B**")
142
+
143
+ with gr.Row():
144
+ with gr.Column():
145
+ image_query = gr.Textbox(label="Query Input", placeholder="e.g., \"OCR the image\" or task instruction…")
146
+ image_upload = gr.Image(type="pil", label="Image")
147
+ image_submit = gr.Button("Submit", elem_classes="submit-btn")
148
+
149
+ with gr.Accordion("Advanced options", open=False):
150
+ max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
151
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.1)
152
+ top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
153
+ top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
154
+ repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
155
+
156
+ # Right column: ONLY output (no model info, no radios)
157
+ with gr.Column():
158
+ output = gr.Textbox(label="Output", interactive=False, lines=12, scale=2)
159
+
160
+ image_submit.click(
161
+ fn=generate_image,
162
+ inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
163
+ outputs=output
164
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
+ demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)