abinash73 committed on
Commit
bc87f48
·
verified ·
1 Parent(s): dbd1f1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +320 -210
app.py CHANGED
@@ -1,241 +1,351 @@
1
- import torch
2
- import base64
3
  import gradio as gr
4
- from io import BytesIO
 
 
 
5
  from PIL import Image
6
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
7
 
8
- from olmocr.data.renderpdf import render_pdf_to_base64png
9
- from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
10
 
11
- # Initialize the model
12
- print("Loading OlmOCR model...")
13
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
14
- "allenai/olmOCR-2-7B-1025",
15
- torch_dtype=torch.bfloat16
16
- ).eval()
17
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model.to(device)
20
- print(f"Model loaded successfully on {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- def process_pdf(pdf_file, page_number=1, max_new_tokens=50, temperature=0.1):
23
- """
24
- Process a PDF file and extract text using OlmOCR
25
-
26
- Args:
27
- pdf_file: Path to uploaded PDF file
28
- page_number: Page number to extract (default: 1)
29
- max_new_tokens: Maximum tokens to generate
30
- temperature: Sampling temperature
31
-
32
- Returns:
33
- Extracted text from the PDF
34
- """
35
- try:
36
- # Render PDF page to base64 image
37
- image_base64 = render_pdf_to_base64png(
38
- pdf_file,
39
- page_number,
40
- target_longest_image_dim=1288
41
- )
42
-
43
- # Build the prompt
44
- messages = [
45
- {
46
- "role": "user",
47
- "content": [
48
- {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
49
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
50
- ],
51
- }
52
- ]
53
-
54
- # Process inputs
55
- text = processor.apply_chat_template(
56
- messages,
57
- tokenize=False,
58
- add_generation_prompt=True
59
- )
60
- main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
61
-
62
- inputs = processor(
63
- text=[text],
64
- images=[main_image],
65
- padding=True,
66
- return_tensors="pt",
67
- )
68
- inputs = {key: value.to(device) for (key, value) in inputs.items()}
69
-
70
- # Generate output
71
- output = model.generate(
72
- **inputs,
73
- temperature=temperature,
74
- max_new_tokens=max_new_tokens,
75
- num_return_sequences=1,
76
- do_sample=True,
77
- )
78
-
79
- # Decode output
80
- prompt_length = inputs["input_ids"].shape[1]
81
- new_tokens = output[:, prompt_length:]
82
- text_output = processor.tokenizer.batch_decode(
83
- new_tokens,
84
- skip_special_tokens=True
85
- )
86
-
87
- return text_output[0] if text_output else "No text extracted"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- except Exception as e:
90
- return f"Error processing PDF: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- def process_image(image_file, max_new_tokens=50, temperature=0.1):
93
- """
94
- Process an image file directly using OlmOCR
95
-
96
- Args:
97
- image_file: PIL Image or path to image file
98
- max_new_tokens: Maximum tokens to generate
99
- temperature: Sampling temperature
100
-
101
- Returns:
102
- Extracted text from the image
103
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  try:
105
- # Convert image to base64
106
- if isinstance(image_file, str):
107
- with open(image_file, 'rb') as f:
108
- image_bytes = f.read()
109
  else:
110
- buffered = BytesIO()
111
- image_file.save(buffered, format="PNG")
112
- image_bytes = buffered.getvalue()
113
-
114
- image_base64 = base64.b64encode(image_bytes).decode('utf-8')
115
-
116
- # Build the prompt
117
- messages = [
118
- {
119
- "role": "user",
120
- "content": [
121
- {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
122
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
123
- ],
124
- }
125
- ]
126
-
127
- # Process inputs
128
- text = processor.apply_chat_template(
129
- messages,
130
- tokenize=False,
131
- add_generation_prompt=True
132
- )
133
- main_image = Image.open(BytesIO(image_bytes))
134
-
135
- inputs = processor(
136
- text=[text],
137
- images=[main_image],
138
- padding=True,
139
- return_tensors="pt",
140
- )
141
- inputs = {key: value.to(device) for (key, value) in inputs.items()}
142
-
143
- # Generate output
144
- output = model.generate(
145
- **inputs,
146
- temperature=temperature,
147
- max_new_tokens=max_new_tokens,
148
- num_return_sequences=1,
149
- do_sample=True,
150
- )
151
-
152
- # Decode output
153
- prompt_length = inputs["input_ids"].shape[1]
154
- new_tokens = output[:, prompt_length:]
155
- text_output = processor.tokenizer.batch_decode(
156
- new_tokens,
157
- skip_special_tokens=True
158
- )
159
-
160
- return text_output[0] if text_output else "No text extracted"
161
 
162
  except Exception as e:
163
- return f"Error processing image: {str(e)}"
 
 
 
164
 
165
- # Create Gradio interface with tabs
166
- with gr.Blocks(title="OlmOCR API") as demo:
167
- gr.Markdown("# OlmOCR - PDF & Image Text Extraction")
168
- gr.Markdown("Extract text from PDFs and images using the OlmOCR model")
169
-
170
- with gr.Tab("PDF Processing"):
171
- with gr.Row():
172
- with gr.Column():
173
- pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
174
- pdf_page = gr.Number(label="Page Number", value=1, precision=0)
175
- pdf_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=500, value=50, step=10)
176
- pdf_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1)
177
- pdf_button = gr.Button("Extract Text from PDF", variant="primary")
178
- with gr.Column():
179
- pdf_output = gr.Textbox(label="Extracted Text", lines=15)
180
-
181
- pdf_button.click(
182
- fn=process_pdf,
183
- inputs=[pdf_input, pdf_page, pdf_tokens, pdf_temp],
184
- outputs=pdf_output
185
- )
186
-
187
- with gr.Tab("Image Processing"):
188
- with gr.Row():
189
- with gr.Column():
190
- image_input = gr.Image(label="Upload Image", type="pil")
191
- image_tokens = gr.Slider(label="Max New Tokens", minimum=10, maximum=500, value=50, step=10)
192
- image_temp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1)
193
- image_button = gr.Button("Extract Text from Image", variant="primary")
194
- with gr.Column():
195
- image_output = gr.Textbox(label="Extracted Text", lines=15)
196
-
197
- image_button.click(
198
- fn=process_image,
199
- inputs=[image_input, image_tokens, image_temp],
200
- outputs=image_output
201
- )
202
 
203
  gr.Markdown("""
204
- ### API Usage
205
- Once running, you can access the API at:
206
- - **Web Interface**: http://localhost:7860
207
- - **API Endpoint**: http://localhost:7860/api/predict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- ### Python API Client Example:
 
 
210
  ```python
211
  from gradio_client import Client
212
 
213
  client = Client("http://localhost:7860")
214
-
215
- # For PDF
216
  result = client.predict(
217
- pdf_file="path/to/file.pdf",
218
- page_number=1,
219
- max_new_tokens=50,
220
- temperature=0.1,
221
  api_name="/predict"
222
  )
 
 
223
 
224
- # For Image
225
- result = client.predict(
226
- image_file="path/to/image.png",
227
- max_new_tokens=50,
228
- temperature=0.1,
229
- api_name="/predict_1"
230
- )
231
  ```
232
  """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- # Launch the app
235
  if __name__ == "__main__":
236
  demo.launch(
237
  server_name="0.0.0.0",
238
  server_port=7860,
239
- share=False, # Set to True to create a public link
240
- show_api=True # Enable API documentation
241
  )
 
 
 
1
  import gradio as gr
2
+ import json
3
+ import re
4
+ from datetime import datetime
5
+ from paddleocr import PaddleOCR
6
  from PIL import Image
7
+ import pdf2image
8
+ import numpy as np
9
 
10
+ # Initialize PaddleOCR
11
+ ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False)
12
 
13
def extract_text_from_image(image):
    """Run the module-level PaddleOCR engine on an image.

    Args:
        image: A PIL image (converted to a NumPy array before OCR) or any
            other value, which is forwarded to ``ocr.ocr`` unchanged.

    Returns:
        A list of dicts, one per detected text line, with the keys
        ``text``, ``y``, ``x`` (centre of the bounding box, derived from its
        first and third corner points) and ``confidence``. The list is
        empty when the engine reports no detections.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    result = ocr.ocr(image, cls=True)

    # For a page without any detected text the engine's first-page entry is
    # None (or the result itself is empty); iterating it would raise
    # TypeError, so report "no text" instead.
    if not result or not result[0]:
        return []

    text_blocks = []
    for line in result[0]:
        bbox = line[0]
        text = line[1][0]
        confidence = line[1][1]

        # Centre of the box, used later for top-to-bottom ordering.
        y_center = (bbox[0][1] + bbox[2][1]) / 2
        x_center = (bbox[0][0] + bbox[2][0]) / 2

        text_blocks.append({
            'text': text,
            'y': y_center,
            'x': x_center,
            'confidence': confidence,
        })

    return text_blocks
39
 
40
def pdf_to_images(pdf_file):
    """Render the PDF at path *pdf_file* with ``pdf2image.convert_from_path``.

    Returns whatever that call returns, without post-processing.
    """
    return pdf2image.convert_from_path(pdf_file)
44
+
45
def extract_gstin(text):
    """Return the first GSTIN-shaped substring of *text*, or ``None``.

    The pattern accepts 15 characters: two digits, five capital letters,
    four digits, one capital letter, one capital letter or digit, a literal
    ``Z`` and a final capital letter or digit.
    """
    found = re.search(
        r'\d{2}[A-Z]{5}\d{4}[A-Z]{1}[A-Z\d]{1}[Z]{1}[A-Z\d]{1}',
        text,
    )
    if found is None:
        return None
    return found.group(0)
50
+
51
def extract_pincode(text):
    """Return the first stand-alone run of exactly six digits, or ``None``."""
    found = re.search(r'\b\d{6}\b', text)
    if found is None:
        return None
    return found.group(0)
56
+
57
def extract_mobile(text):
    """Return the first stand-alone 10-digit number starting with 6-9.

    Returns ``None`` when *text* contains no such number.
    """
    found = re.search(r'\b[6-9]\d{9}\b', text)
    if found is None:
        return None
    return found.group(0)
62
+
63
def extract_date(text):
    """Return the first date-like token in *text*, or ``None``.

    Recognised shapes, in priority order (``-`` or ``/`` as separator):

    1. ``DD-MM-YYYY``
    2. ``YYYY-MM-DD``
    3. ``DD-MM-YY``

    Each pattern must not be preceded or followed by another digit.
    Without that guard the two-digit-year pattern would match the tail of
    an ISO date (``2024-01-15`` -> ``24-01-15``) before the four-digit-year
    pattern was ever tried.

    Args:
        text: Free text to scan.

    Returns:
        The matched substring exactly as it appears in *text*, or ``None``.
    """
    date_patterns = (
        r'(?<!\d)\d{2}[-/]\d{2}[-/]\d{4}(?!\d)',
        r'(?<!\d)\d{4}[-/]\d{2}[-/]\d{2}(?!\d)',
        r'(?<!\d)\d{2}[-/]\d{2}[-/]\d{2}(?!\d)',
    )
    for pattern in date_patterns:
        match = re.search(pattern, text)
        if match:
            return match.group(0)
    return None
75
+
76
def extract_invoice_number(text_blocks):
    """Return the invoice/bill number that follows an "Invoice No"-style label.

    A block qualifies when its text contains ``invoice`` or ``bill``
    followed by ``number``, ``no`` or ``#`` (case-insensitive). The number
    is looked up in the text after that label and, if nothing is found
    there, in the next block of the list.

    The search deliberately starts after the label: scanning the whole
    block would match the capital ``I`` of "Invoice" first.

    Args:
        text_blocks: Sequence of dicts that each carry a ``'text'`` entry.

    Returns:
        The first token made of letters, digits, ``/`` and ``-`` that
        contains at least one digit, or ``None`` when no label or number
        is found.
    """
    # "number" is listed before "no" so the longer word is consumed whole.
    label_re = re.compile(
        r'(?:invoice|bill)\s*(?:number|no|#)\.?\s*[:#-]?\s*',
        re.IGNORECASE,
    )
    # Require a digit somewhere in the token so that a following word such
    # as "Date" is not mistaken for the number.
    number_re = re.compile(r'(?=[A-Za-z0-9/-]*\d)[A-Za-z0-9][A-Za-z0-9/-]*')

    for index, block in enumerate(text_blocks):
        text = block['text']
        label = label_re.search(text)
        if label is None:
            continue

        candidates = [text[label.end():]]
        if index + 1 < len(text_blocks):
            candidates.append(text_blocks[index + 1]['text'])

        for candidate in candidates:
            found = number_re.search(candidate)
            if found:
                return found.group(0)
    return None
86
+
87
def extract_amounts(text):
    """Return every number-like token in *text* as a float.

    Tokens may carry an optional rupee sign, comma group separators and an
    optional two-digit decimal part; commas are dropped before conversion.
    """
    values = []
    for raw in re.findall(r'₹?\s*(\d+(?:,\d+)*(?:\.\d{2})?)', text):
        values.append(float(raw.replace(',', '')))
    return values
92
+
93
def find_header_info(text_blocks):
    """Collect supplier and invoice header fields from OCR text blocks.

    PIN code, GSTIN, contact number and invoice date are searched in the
    space-joined text of all blocks. The invoice number is searched per
    block. The supplier name is the first of the five top-most blocks
    (smallest ``y``) whose stripped text is longer than three characters
    and is not made up solely of digits, whitespace, dots and commas.

    Returns:
        A dict with the keys ``supplier_name``, ``supplier_pincode``,
        ``gstin``, ``contact_no``, ``invoice_no`` and ``invoice_date``;
        fields that could not be found are ``None``.
    """
    joined = ' '.join(entry['text'] for entry in text_blocks)

    header = {
        "supplier_name": None,
        "supplier_pincode": extract_pincode(joined),
        "gstin": extract_gstin(joined),
        "contact_no": extract_mobile(joined),
        "invoice_no": extract_invoice_number(text_blocks),
        "invoice_date": extract_date(joined),
    }

    topmost = sorted(text_blocks, key=lambda entry: entry['y'])[:5]
    for entry in topmost:
        name = entry['text'].strip()
        if len(name) <= 3 or re.match(r'^[\d\s.,]+$', name):
            continue
        # First acceptable candidate wins.
        header['supplier_name'] = name
        break

    return header
118
+
119
def find_line_items(text_blocks):
    """Heuristically assemble invoice line items from OCR text blocks.

    Blocks are visited top to bottom (ascending ``y``). For each block the
    function fills, if not already set on the item under construction:

    * ``hsn``: first stand-alone run of 4 to 8 digits,
    * ``qty`` / ``unit``: first number, with an optional unit word
      (``pcs``, ``nos``, ``kg``, ``ltr``, ``box``, ``unit``); the unit
      defaults to ``'Nos'``,
    * ``rate``: first amount-like number,
    * ``gst_percent``: first number followed by ``%``.

    As soon as the item under construction holds at least three fields it
    is emitted, named after the first 50 characters of the current block,
    and a fresh item is started.

    Returns:
        A list of dicts with the keys ``item_name``, ``hsn``, ``qty``,
        ``unit``, ``rate``, ``discount`` and ``gst_percent``.
    """
    line_items = []
    pending = {}

    for entry in sorted(text_blocks, key=lambda blk: blk['y']):
        text = entry['text'].strip()

        hsn_found = re.search(r'\b\d{4,8}\b', text)
        if hsn_found and not pending.get('hsn'):
            pending['hsn'] = hsn_found.group(0)

        qty_found = re.search(
            r'\b(\d+(?:\.\d+)?)\s*(pcs|nos|kg|ltr|box|unit)?',
            text.lower(),
        )
        if qty_found and not pending.get('qty'):
            pending['qty'] = float(qty_found.group(1))
            pending['unit'] = qty_found.group(2) or 'Nos'

        raw_amounts = re.findall(r'₹?\s*(\d+(?:,\d+)*(?:\.\d{2})?)', text)
        if raw_amounts and not pending.get('rate'):
            pending['rate'] = float(raw_amounts[0].replace(',', ''))

        gst_found = re.search(r'(\d+(?:\.\d+)?)\s*%', text)
        if gst_found and not pending.get('gst_percent'):
            pending['gst_percent'] = float(gst_found.group(1))

        # Not enough fields gathered yet: keep accumulating.
        if len(pending) < 3:
            continue

        pending.setdefault('item_name', text[:50])
        line_items.append({
            'item_name': pending.get('item_name', 'Item'),
            'hsn': pending.get('hsn', ''),
            'qty': pending.get('qty', 0),
            'unit': pending.get('unit', 'Nos'),
            'rate': pending.get('rate', 0),
            'discount': pending.get('discount', 0),
            'gst_percent': pending.get('gst_percent', 0),
        })
        pending = {}

    return line_items
171
 
172
def calculate_totals(items):
    """Compute per-item amounts and invoice totals.

    Each item dict is updated in place with ``gross_amount`` (qty * rate),
    ``taxable_amount`` (gross minus discount), ``gst_amount`` (taxable *
    gst_percent / 100) and ``total_amount`` (taxable + GST), all rounded to
    two decimals. Missing inputs count as 0.

    Returns:
        A dict with ``total_gross``, ``total_taxable``, ``total_gst`` and
        ``grand_total``, each rounded to two decimals from the unrounded
        running sums.
    """
    gross_sum = 0
    taxable_sum = 0
    gst_sum = 0

    for entry in items:
        gross = entry.get('qty', 0) * entry.get('rate', 0)
        taxable = gross - entry.get('discount', 0)
        gst_amount = (taxable * entry.get('gst_percent', 0)) / 100

        entry.update({
            'gross_amount': round(gross, 2),
            'taxable_amount': round(taxable, 2),
            'gst_amount': round(gst_amount, 2),
            'total_amount': round(taxable + gst_amount, 2),
        })

        gross_sum += gross
        taxable_sum += taxable
        gst_sum += gst_amount

    totals = {}
    totals['total_gross'] = round(gross_sum, 2)
    totals['total_taxable'] = round(taxable_sum, 2)
    totals['total_gst'] = round(gst_sum, 2)
    totals['grand_total'] = round(taxable_sum + gst_sum, 2)
    return totals
203
+
204
def extract_invoice_data(file):
    """Extract header, line items and totals from an uploaded invoice.

    Args:
        file: The upload handed over by the UI: either a filesystem path
            string or an object exposing the path as ``.name``. ``None``
            (nothing uploaded) is reported as an error.

    Returns:
        A JSON string. On success it holds ``header``, ``details`` and
        ``footer``; on any failure it holds ``error`` (the exception text)
        and ``message``. The function never raises.
    """
    try:
        if file is None:
            raise ValueError("No file provided")

        # Depending on Gradio version/configuration the upload arrives as a
        # plain path string or as a temp-file wrapper with a ``.name``.
        path = file if isinstance(file, str) else file.name

        if path.lower().endswith('.pdf'):
            pages = pdf_to_images(path)
            if not pages:
                raise ValueError("PDF contains no pages")
            image = pages[0]  # Only the first page is processed.
        else:
            image = Image.open(path)

        text_blocks = extract_text_from_image(image)

        header = find_header_info(text_blocks)
        details = find_line_items(text_blocks)
        # calculate_totals also adds the per-item amounts to ``details``.
        footer = calculate_totals(details)

        result = {
            "header": header,
            "details": details,
            "footer": footer,
        }

        return json.dumps(result, indent=2, ensure_ascii=False)

    except Exception as e:
        # UI boundary: surface every failure as JSON instead of a traceback.
        return json.dumps({
            "error": str(e),
            "message": "Failed to process invoice"
        }, indent=2)
236
 
237
+ # Create Gradio Interface
238
with gr.Blocks(title="Purchase Invoice Data Extraction", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧾 Purchase Invoice Data Extraction API

    Upload purchase invoices (PDF or Image) to automatically extract structured data including:
    - Supplier details (Name, PIN, GSTIN, Contact)
    - Invoice information (Number, Date)
    - Line items (Name, HSN, Qty, Rate, Discounts, GST%)
    - Calculated totals (Gross, Taxable, Tax, Grand Total)
    """)

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload Invoice (PDF or Image)",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"]
            )
            extract_btn = gr.Button("Extract Data", variant="primary", size="lg")

            # The OCR engine in this file is created with lang='en', so only
            # English is advertised here.
            gr.Markdown("""
            ### Supported Formats:
            - PDF documents
            - PNG, JPG, JPEG images
            - English text
            """)

        with gr.Column():
            output_json = gr.Code(
                label="Extracted Data (JSON)",
                language="json",
                lines=25
            )

    # NOTE(review): the cURL snippet below assumes a multipart upload to
    # /api/predict; confirm the request shape against the installed Gradio
    # version before relying on it.
    gr.Markdown("""
    ### Output Structure:
    ```json
    {
      "header": {
        "supplier_name": "...",
        "supplier_pincode": "...",
        "gstin": "...",
        "contact_no": "...",
        "invoice_no": "...",
        "invoice_date": "..."
      },
      "details": [
        {
          "item_name": "...",
          "hsn": "...",
          "qty": 0,
          "unit": "...",
          "rate": 0,
          "discount": 0,
          "gst_percent": 0,
          "gross_amount": 0,
          "taxable_amount": 0,
          "gst_amount": 0,
          "total_amount": 0
        }
      ],
      "footer": {
        "total_gross": 0,
        "total_taxable": 0,
        "total_gst": 0,
        "grand_total": 0
      }
    }
    ```

    ---

    ### API Usage:

    **Python Client:**
    ```python
    from gradio_client import Client

    client = Client("http://localhost:7860")
    result = client.predict(
        file="path/to/invoice.pdf",
        api_name="/predict"
    )
    print(result)
    ```

    **cURL:**
    ```bash
    curl -X POST http://localhost:7860/api/predict \\
      -F "file=@invoice.pdf"
    ```
    """)

    # Name the endpoint explicitly so it matches the "/predict" name used in
    # the usage examples above regardless of how Gradio would auto-name it.
    extract_btn.click(
        fn=extract_invoice_data,
        inputs=[file_input],
        outputs=[output_json],
        api_name="predict"
    )

    # Example usage
    gr.Examples(
        examples=[],
        inputs=[file_input],
        outputs=[output_json],
        fn=extract_invoice_data,
        cache_examples=False
    )
344
 
 
345
if __name__ == "__main__":
    # Start the Gradio server on port 7860 without a public share link and
    # with the API documentation page enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
    }
    demo.launch(**launch_options)