abinash73 commited on
Commit
b3ff38b
·
verified ·
1 Parent(s): 1ca9281

Add main application file

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import base64
3
+ import gradio as gr
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
7
+
8
+ from olmocr.data.renderpdf import render_pdf_to_base64png
9
+ from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
10
+
11
+ # Initialize the model
12
+ print("Loading OlmOCR model...")
13
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
14
+ "allenai/olmOCR-2-7B-1025",
15
+ torch_dtype=torch.bfloat16
16
+ ).eval()
17
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
+ print(f"Model loaded successfully on {device}")
21
+
22
+ def process_pdf(pdf_file, page_number=1, max_new_tokens=50, temperature=0.1):
23
+ """
24
+ Process a PDF file and extract text using OlmOCR
25
+
26
+ Args:
27
+ pdf_file: Path to uploaded PDF file
28
+ page_number: Page number to extract (default: 1)
29
+ max_new_tokens: Maximum tokens to generate
30
+ temperature: Sampling temperature
31
+
32
+ Returns:
33
+ Extracted text from the PDF
34
+ """
35
+ try:
36
+ # Render PDF page to base64 image
37
+ image_base64 = render_pdf_to_base64png(
38
+ pdf_file,
39
+ page_number,
40
+ target_longest_image_dim=1288
41
+ )
42
+
43
+ # Build the prompt
44
+ messages = [
45
+ {
46
+ "role": "user",
47
+ "content": [
48
+ {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
49
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
50
+ ],
51
+ }
52
+ ]
53
+
54
+ # Process inputs
55
+ text = processor.apply_chat_template(
56
+ messages,
57
+ tokenize=False,
58
+ add_generation_prompt=True
59
+ )
60
+ main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
61
+
62
+ inputs = processor(
63
+ text=[text],
64
+ images=[main_image],
65
+ padding=True,
66
+ return_tensors="pt",
67
+ )
68
+ inputs = {key: value.to(device) for (key, value) in inputs.items()}
69
+
70
+ # Generate output
71
+ output = model.generate(
72
+ **inputs,
73
+ temperature=temperature,
74
+ max_new_tokens=max_new_tokens,
75
+ num_return_sequences=1,
76
+ do_sample=True,
77
+ )
78
+
79
+ # Decode output
80
+ prompt_length = inputs["input_ids"].shape[1]
81
+ new_tokens = output[:, prompt_length:]
82
+ text_output = processor.tokenizer.batch_decode(
83
+ new_tokens,
84
+ skip_special_tokens=True
85
+ )
86
+
87
+ return text_output[0] if text_output else "No text extracted"
88
+
89
+ except Exception as e:
90
+ return f"Error processing PDF: {str(e)}"
91
+
92
+ def process_image(image_file, max_new_tokens=50, temperature=0.1):
93
+ """
94
+ Process an image file directly using OlmOCR
95
+
96
+ Args:
97
+ image_file: PIL Image or path to image file
98
+ max_new_tokens: Maximum tokens to generate
99
+ temperature: Sampling temperature
100
+
101
+ Returns:
102
+ Extracted text from the image
103
+ """
104
+ try:
105
+ # Convert image to base64
106
+ if isinstance(image_file, str):
107
+ with open(image_file, 'rb') as f:
108
+ image_bytes = f.read()
109
+ else:
110
+ buffered = BytesIO()
111
+ image_file.save(buffered, format="PNG")
112
+ image_bytes = buffered.getvalue()
113
+
114
+ image_base64 = base64.b64encode(image_bytes).decode('utf-8')
115
+
116
+ # Build the prompt
117
+ messages = [
118
+ {
119
+ "role": "user",
120
+ "content": [
121
+ {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
122
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
123
+ ],
124
+ }
125
+ ]
126
+
127
+ # Process inputs
128
+ text = processor.apply_chat_template(
129
+ messages,
130
+ tokenize=False,
131
+ add_generation_prompt=True
132
+ )
133
+ main_image = Image.open(BytesIO(image_bytes))
134
+
135
+ inputs = processor(
136
+ text=[text],
137
+ images=[main_image],
138
+ padding=True,
139
+ return_tensors="pt",
140
+ )
141
+ inputs = {key: value.to(device) for (key, value) in inputs.items()}
142
+
143
+ # Generate output
144
+ output = model.generate(
145
+ **inputs,
146
+ temperature=temperature,
147
+ max_new_tokens=max_new_tokens,
148
+ num_return_sequences=1,
149
+ do_sample=True,
150
+ )
151
+
152
+ # Decode output
153
+ prompt_length = inputs["input_ids"].shape[1]
154
+ new_tokens = output[:, prompt_length:]
155
+ text_output = processor.tokenizer.batch_decode(
156
+ new_tokens,
157
+ skip_special_tokens=True
158
+ )
159
+
160
+ return text_output[0] if text_output else "No text extracted"
161
+
162
+ except Exception as e:
163
+ return f"Error processing image: {str(e)}"
164
+
165
+ # Create Gradio interface with tabs
166
+ with gr.Blocks(title="OlmOCR API") as demo:
167
+ gr.Markdown("# OlmOCR - PDF & Image Text Extraction")
168
+ gr.Markdown("Extract text from PDFs and images using the OlmOCR model")
169
+
170
+ with gr.Tab("PDF Processing"):
171
+ with gr.Row():
172
+ with gr.Column():
173
+ pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
174
+ pdf_page = gr.Number(label="Page Number", value=1, precision=0)
175
+ pdf_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=500, value=50, step=10)
176
+ pdf_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1)
177
+ pdf_button = gr.Button("Extract Text from PDF", variant="primary")
178
+ with gr.Column():
179
+ pdf_output = gr.Textbox(label="Extracted Text", lines=15)
180
+
181
+ pdf_button.click(
182
+ fn=process_pdf,
183
+ inputs=[pdf_input, pdf_page, pdf_tokens, pdf_temp],
184
+ outputs=pdf_output
185
+ )
186
+
187
+ with gr.Tab("Image Processing"):
188
+ with gr.Row():
189
+ with gr.Column():
190
+ image_input = gr.Image(label="Upload Image", type="pil")
191
+ image_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=500, value=50, step=10)
192
+ image_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1)
193
+ image_button = gr.Button("Extract Text from Image", variant="primary")
194
+ with gr.Column():
195
+ image_output = gr.Textbox(label="Extracted Text", lines=15)
196
+
197
+ image_button.click(
198
+ fn=process_image,
199
+ inputs=[image_input, image_tokens, image_temp],
200
+ outputs=image_output
201
+ )
202
+
203
+ gr.Markdown("""
204
+ ### API Usage
205
+ Once running, you can access the API at:
206
+ - **Web Interface**: http://localhost:7860
207
+ - **API Endpoint**: http://localhost:7860/api/predict
208
+
209
+ ### Python API Client Example:
210
+ ```python
211
+ from gradio_client import Client
212
+
213
+ client = Client("http://localhost:7860")
214
+
215
+ # For PDF
216
+ result = client.predict(
217
+ pdf_file="path/to/file.pdf",
218
+ page_number=1,
219
+ max_new_tokens=50,
220
+ temperature=0.1,
221
+ api_name="/predict"
222
+ )
223
+
224
+ # For Image
225
+ result = client.predict(
226
+ image_file="path/to/image.png",
227
+ max_new_tokens=50,
228
+ temperature=0.1,
229
+ api_name="/predict_1"
230
+ )
231
+ ```
232
+ """)
233
+
234
+ # Launch the app
235
+ if __name__ == "__main__":
236
+ demo.launch(
237
+ server_name="0.0.0.0",
238
+ server_port=7860,
239
+ share=False, # Set to True to create a public link
240
+ show_api=True # Enable API documentation
241
+ )