GSoumyajit2005 committed on
Commit
f74e17e
·
1 Parent(s): 1144bea

feat: Enhance pipeline with smart PDF handling, Pydantic validation, and semantic hashing, and refactor API to src.

Browse files
Files changed (10) hide show
  1. api.py +0 -35
  2. app.py +63 -75
  3. requirements.txt +19 -2
  4. src/api.py +67 -0
  5. src/pdf_utils.py +50 -0
  6. src/pipeline.py +106 -25
  7. src/schema.py +112 -0
  8. src/utils.py +35 -0
  9. tests/test_full_pipeline.py +2 -2
  10. tests/test_pipeline.py +1 -1
api.py DELETED
@@ -1,35 +0,0 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from src.pipeline import process_invoice
3
- import shutil
4
- import os
5
- import uvicorn
6
-
7
- app = FastAPI(title="Invoice Extraction API", version="1.0")
8
-
9
- @app.post("/extract")
10
- async def extract_invoice(file: UploadFile = File(...), method: str = 'ml'):
11
- """
12
- Endpoint to process an uploaded invoice file.
13
- """
14
- temp_file_path = f"temp_{file.filename}"
15
-
16
- try:
17
- # Save uploaded file temporarily
18
- with open(temp_file_path, "wb") as buffer:
19
- shutil.copyfileobj(file.file, buffer)
20
-
21
- # Run pipeline
22
- result = process_invoice(temp_file_path, method=method, save_results=False)
23
-
24
- return {"status": "success", "data": result}
25
-
26
- except Exception as e:
27
- raise HTTPException(status_code=500, detail=str(e))
28
-
29
- finally:
30
- # Cleanup temp file
31
- if os.path.exists(temp_file_path):
32
- os.remove(temp_file_path)
33
-
34
- if __name__ == "__main__":
35
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -12,16 +12,12 @@ import sys
12
  sys.path.append('src')
13
  from pipeline import process_invoice
14
 
15
- # --- Mock Functions to support the UI without errors ---
16
- # These functions simulate the ones from your example README.
17
- # They allow the UI to render without needing to build a complex format detector today.
18
-
19
  def detect_invoice_format(ocr_text: str):
20
  """
21
  A mock function to simulate format detection.
22
  In a real system, this would analyze the text layout.
23
  """
24
- # Simple heuristic: if it contains "SDN BHD", it's our known format.
25
  if "SDN BHD" in ocr_text:
26
  return {
27
  'name': 'Template A (Retail)',
@@ -44,9 +40,8 @@ def get_format_recommendations(format_info):
44
  else:
45
  return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
46
 
47
- # --- Streamlit App ---
48
 
49
- # Page configuration
50
  st.set_page_config(
51
  page_title="Invoice Processor",
52
  page_icon="📄",
@@ -54,7 +49,7 @@ st.set_page_config(
54
  initial_sidebar_state="expanded"
55
  )
56
 
57
- # Custom CSS for styling
58
  st.markdown("""
59
  <style>
60
  .main-header {
@@ -87,11 +82,10 @@ st.markdown("""
87
  </style>
88
  """, unsafe_allow_html=True)
89
 
90
- # Title
91
  st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
92
  st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
93
 
94
- # Sidebar
95
  with st.sidebar:
96
  st.header("ℹ️ About")
97
  st.info("""
@@ -118,7 +112,7 @@ with st.sidebar:
118
  extraction_method = st.selectbox(
119
  "Choose Extraction Method:",
120
  ('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
121
- help="ML-Based is more robust but may miss fields not in its training data. Rule-Based is faster but more fragile."
122
  )
123
 
124
  # Main content
@@ -128,18 +122,22 @@ with tab1:
128
  st.header("Upload an Invoice")
129
 
130
  uploaded_file = st.file_uploader(
131
- "Choose an invoice image (JPG, PNG)",
132
- type=['jpg', 'jpeg', 'png'],
133
- help="Upload a clear image of an invoice or receipt"
134
  )
135
 
136
  if uploaded_file is not None:
137
  col1, col2 = st.columns([1, 1])
138
 
139
  with col1:
140
- st.subheader("📸 Original Image")
141
- image = Image.open(uploaded_file)
142
- st.image(image, use_container_width=True)
 
 
 
 
143
  st.caption(f"Filename: {uploaded_file.name}")
144
 
145
  with col2:
@@ -148,27 +146,24 @@ with tab1:
148
  if st.button("🚀 Extract Data", type="primary"):
149
  with st.spinner("Executing your custom pipeline..."):
150
  try:
151
- # Save the uploaded file to a temporary path to be used by our pipeline
152
  temp_dir = "temp"
153
  os.makedirs(temp_dir, exist_ok=True)
154
  temp_path = os.path.join(temp_dir, uploaded_file.name)
155
  with open(temp_path, "wb") as f:
156
  f.write(uploaded_file.getbuffer())
157
 
158
- # Step 1: Call YOUR full pipeline function
159
  st.write("✅ Calling `process_invoice`...")
160
- # Map the user-friendly name from the dropdown to the actual method parameter
161
  method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
162
  st.write(f"⚙️ Using **{method.upper()}** extraction method...")
163
 
164
- # Call the pipeline with the selected method
165
- extracted_data = process_invoice(temp_path, method=method)
166
 
167
- # Step 2: Simulate format detection using the extracted data
168
  st.write("✅ Simulating format detection...")
169
  format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
170
 
171
- # Store results in session state to display them
172
  st.session_state.extracted_data = extracted_data
173
  st.session_state.format_info = format_info
174
  st.session_state.processed_count += 1
@@ -178,12 +173,12 @@ with tab1:
178
  except Exception as e:
179
  st.error(f"❌ An error occurred in the pipeline: {str(e)}")
180
 
181
- # Display results if they exist in the session state
182
  if 'extracted_data' in st.session_state:
183
  st.markdown("---")
184
  st.header("📊 Extraction Results")
185
 
186
- # --- Format Detection Section ---
187
  format_info = st.session_state.format_info
188
  st.subheader("📋 Detected Format (Simulated)")
189
  col1_fmt, col2_fmt = st.columns([2, 3])
@@ -199,39 +194,34 @@ with tab1:
199
  for rec in get_format_recommendations(format_info): st.write(rec)
200
  st.markdown("---")
201
 
202
- # --- Main Results Section ---
203
  data = st.session_state.extracted_data
204
 
205
- # Confidence display
206
- confidence = data.get('extraction_confidence', 0)
207
- if confidence >= 80:
208
- st.markdown(f'<div class="success-box">✅ <strong>High Confidence: {confidence}%</strong> - Most key fields were found.</div>', unsafe_allow_html=True)
209
- elif confidence >= 50:
210
- st.markdown(f'<div class="warning-box">⚠️ <strong>Medium Confidence: {confidence}%</strong> - Some fields may be missing.</div>', unsafe_allow_html=True)
 
211
  else:
212
- st.markdown(f'<div class="error-box">❌ <strong>Low Confidence: {confidence}%</strong> - Format likely unsupported.</div>', unsafe_allow_html=True)
213
 
214
- # Validation display
215
- if data.get('validation_passed', False):
216
- st.success("✔️ Validation Passed: Total amount appears consistent with other extracted amounts.")
217
- else:
218
- st.warning("⚠️ Validation Failed: Total amount could not be verified against other numbers.")
219
-
220
- # Key metrics display
221
- # Key metrics display
222
- st.metric("🏢 Vendor", data.get('vendor') or "N/A") # <-- ADD THIS
223
 
224
  res_col1, res_col2, res_col3 = st.columns(3)
225
  res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
226
  res_col2.metric("📅 Date", data.get('date') or "N/A")
227
- res_col3.metric("💵 Total Amount", f"${data.get('total_amount'):.2f}" if data.get('total_amount') is not None else "N/A")
 
 
228
 
229
- # Use an expander for longer text fields like address
230
  with st.expander("Show More Details"):
231
- # Handle receipt_number
232
  st.markdown(f"**🧾 Receipt Number:** {data.get('receipt_number') or 'N/A'}")
233
 
234
- # Handle bill_to (can be string from ML or dict from rules)
235
  bill_to = data.get('bill_to')
236
  if isinstance(bill_to, dict):
237
  bill_to_display = bill_to.get('name') or 'N/A'
@@ -242,16 +232,18 @@ with tab1:
242
  st.markdown(f"**👤 Bill To:** {bill_to_display}")
243
 
244
  st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
 
 
 
245
 
246
- # Line items table
247
  if data.get('items'):
248
  st.subheader("🛒 Line Items")
249
- # Ensure data is in the right format for DataFrame
250
  items_df_data = [{
251
  "Description": item.get("description", "N/A"),
252
  "Qty": item.get("quantity", "N/A"),
253
- "Unit Price": f"${item.get('unit_price', 0.0):.2f}",
254
- "Total": f"${item.get('total', 0.0):.2f}"
255
  } for item in data['items']]
256
  df = pd.DataFrame(items_df_data)
257
  st.dataframe(df, use_container_width=True)
@@ -281,15 +273,17 @@ with tab2:
281
  st.header("📚 Sample Invoices")
282
  st.write("Try the sample invoice below to see how the system performs:")
283
 
284
- sample_dir = "data/samples" # ✅ Points to the correct folder
285
  if os.path.exists(sample_dir):
286
- sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
287
 
288
  if sample_files:
289
- # Display the first sample found
290
- img_path = os.path.join(sample_dir, sample_files[0])
291
- st.image(Image.open(img_path), caption=sample_files[0], use_container_width=True)
292
- st.info("You can download this image and upload it in the 'Upload & Process' tab to test the pipeline.")
 
 
293
  else:
294
  st.warning("No sample invoices found in `data/samples/`.")
295
  else:
@@ -300,26 +294,20 @@ with tab3:
300
  st.markdown("""
301
  This app follows the exact pipeline you built:
302
  ```
303
- 1. 📸 Image Upload
304
-
305
- 2. 🔄 Preprocessing (OpenCV)
306
- Grayscale conversion and noise removal.
307
 
308
- 3. 🔍 OCR (Tesseract)
309
- Optimized with PSM 6 for receipt layouts.
 
310
 
311
- 4. 🎯 Rule-Based Extraction (Regex)
312
- Your custom patterns find specific fields.
313
 
314
- 5. Confidence & Validation
315
- Heuristics to check the quality of the extraction.
316
 
317
- 6. 📊 Output JSON
318
- Presents all extracted data in a structured format.
319
  ```
320
- """)
321
- st.info("This rule-based system is a great foundation. The next step is to replace the extraction logic with an ML model like LayoutLM to handle more diverse formats!")
322
-
323
- # Footer
324
- st.markdown("---")
325
- st.markdown("<div style='text-align: center; color: #666;'>Built with your custom Python pipeline | UI by Streamlit</div>", unsafe_allow_html=True)
 
12
  sys.path.append('src')
13
  from pipeline import process_invoice
14
 
15
+ # --- Mock Functions (KEPT AS IS) ---
 
 
 
16
  def detect_invoice_format(ocr_text: str):
17
  """
18
  A mock function to simulate format detection.
19
  In a real system, this would analyze the text layout.
20
  """
 
21
  if "SDN BHD" in ocr_text:
22
  return {
23
  'name': 'Template A (Retail)',
 
40
  else:
41
  return ["• Results may be incomplete.", "• Consider adding patterns for this format."]
42
 
43
+ # --- Streamlit App (KEPT AS IS) ---
44
 
 
45
  st.set_page_config(
46
  page_title="Invoice Processor",
47
  page_icon="📄",
 
49
  initial_sidebar_state="expanded"
50
  )
51
 
52
+ # Custom CSS (KEPT AS IS)
53
  st.markdown("""
54
  <style>
55
  .main-header {
 
82
  </style>
83
  """, unsafe_allow_html=True)
84
 
85
+ # Title & Sidebar (KEPT AS IS)
86
  st.markdown('<h1 class="main-header">📄 Smart Invoice Processor</h1>', unsafe_allow_html=True)
87
  st.markdown("### Extract structured data from invoices using your custom-built OCR pipeline")
88
 
 
89
  with st.sidebar:
90
  st.header("ℹ️ About")
91
  st.info("""
 
112
  extraction_method = st.selectbox(
113
  "Choose Extraction Method:",
114
  ('ML-Based (LayoutLMv3)', 'Rule-Based (Regex)'),
115
+ help="ML-Based is more robust. Rule-Based is faster."
116
  )
117
 
118
  # Main content
 
122
  st.header("Upload an Invoice")
123
 
124
  uploaded_file = st.file_uploader(
125
+ "Choose an invoice image (JPG, PNG) or PDF",
126
+ type=['jpg', 'jpeg', 'png', 'pdf'], # Added PDF support
127
+ help="Upload a clear image or PDF of an invoice"
128
  )
129
 
130
  if uploaded_file is not None:
131
  col1, col2 = st.columns([1, 1])
132
 
133
  with col1:
134
+ st.subheader("📸 Original Document")
135
+ # Preview Logic updated for PDF support
136
+ if uploaded_file.type == "application/pdf":
137
+ st.info("📄 PDF Uploaded (Preview not supported directly)")
138
+ else:
139
+ image = Image.open(uploaded_file)
140
+ st.image(image, use_container_width=True)
141
  st.caption(f"Filename: {uploaded_file.name}")
142
 
143
  with col2:
 
146
  if st.button("🚀 Extract Data", type="primary"):
147
  with st.spinner("Executing your custom pipeline..."):
148
  try:
149
+ # Save temp file
150
  temp_dir = "temp"
151
  os.makedirs(temp_dir, exist_ok=True)
152
  temp_path = os.path.join(temp_dir, uploaded_file.name)
153
  with open(temp_path, "wb") as f:
154
  f.write(uploaded_file.getbuffer())
155
 
156
+ # Call Pipeline
157
  st.write("✅ Calling `process_invoice`...")
 
158
  method = 'ml' if extraction_method == 'ML-Based (LayoutLMv3)' else 'rules'
159
  st.write(f"⚙️ Using **{method.upper()}** extraction method...")
160
 
161
+ # ⚠️ UPDATE: Pass string path
162
+ extracted_data = process_invoice(str(temp_path), method=method)
163
 
 
164
  st.write("✅ Simulating format detection...")
165
  format_info = detect_invoice_format(extracted_data.get("raw_text", ""))
166
 
 
167
  st.session_state.extracted_data = extracted_data
168
  st.session_state.format_info = format_info
169
  st.session_state.processed_count += 1
 
173
  except Exception as e:
174
  st.error(f"❌ An error occurred in the pipeline: {str(e)}")
175
 
176
+ # Display results
177
  if 'extracted_data' in st.session_state:
178
  st.markdown("---")
179
  st.header("📊 Extraction Results")
180
 
181
+ # --- Format Detection Section (KEPT AS IS) ---
182
  format_info = st.session_state.format_info
183
  st.subheader("📋 Detected Format (Simulated)")
184
  col1_fmt, col2_fmt = st.columns([2, 3])
 
194
  for rec in get_format_recommendations(format_info): st.write(rec)
195
  st.markdown("---")
196
 
197
+ # --- Main Results Section (UPDATED) ---
198
  data = st.session_state.extracted_data
199
 
200
+ # 1. New Validation Display (Replaces old Confidence box)
201
+ status = data.get('validation_status', 'unknown')
202
+ if status == 'passed':
203
+ st.markdown(f'<div class="success-box">✅ <strong>Validation Passed</strong>: Data meets strict quality rules (Pydantic).</div>', unsafe_allow_html=True)
204
+ elif status == 'failed':
205
+ err_count = len(data.get('validation_errors', []))
206
+ st.markdown(f'<div class="error-box">❌ <strong>Validation Failed</strong>: Found {err_count} issues. Check JSON for details.</div>', unsafe_allow_html=True)
207
  else:
208
+ st.markdown(f'<div class="warning-box">⚠️ <strong>Status Unknown</strong>: Validation logic was skipped.</div>', unsafe_allow_html=True)
209
 
210
+ # 2. Key Metrics (Mapped to NEW keys)
211
+ st.metric("🏢 Vendor", data.get('vendor') or "N/A")
 
 
 
 
 
 
 
212
 
213
  res_col1, res_col2, res_col3 = st.columns(3)
214
  res_col1.metric("📄 Receipt Number", data.get('receipt_number') or "N/A")
215
  res_col2.metric("📅 Date", data.get('date') or "N/A")
216
+ # Handle total (it's now a string from the pipeline, but metric handles strings fine)
217
+ total = data.get('total_amount')
218
+ res_col3.metric("💵 Total Amount", f"${total}" if total else "N/A")
219
 
220
+ # 3. Expanded Details
221
  with st.expander("Show More Details"):
 
222
  st.markdown(f"**🧾 Receipt Number:** {data.get('receipt_number') or 'N/A'}")
223
 
224
+ # Handle bill_to
225
  bill_to = data.get('bill_to')
226
  if isinstance(bill_to, dict):
227
  bill_to_display = bill_to.get('name') or 'N/A'
 
232
  st.markdown(f"**👤 Bill To:** {bill_to_display}")
233
 
234
  st.markdown(f"**📍 Vendor Address:** {data.get('address') or 'N/A'}")
235
+
236
+ # New: Show Duplicate Hash
237
+ st.markdown(f"**🔑 Semantic Hash (Duplicate ID):** `{data.get('semantic_hash') or 'N/A'}`")
238
 
239
+ # 4. Line items table
240
  if data.get('items'):
241
  st.subheader("🛒 Line Items")
 
242
  items_df_data = [{
243
  "Description": item.get("description", "N/A"),
244
  "Qty": item.get("quantity", "N/A"),
245
+ "Unit Price": f"${item.get('unit_price', 0.0) if item.get('unit_price') is not None else 0}",
246
+ "Total": f"${item.get('total', 0.0) if item.get('total') is not None else 0}"
247
  } for item in data['items']]
248
  df = pd.DataFrame(items_df_data)
249
  st.dataframe(df, use_container_width=True)
 
273
  st.header("📚 Sample Invoices")
274
  st.write("Try the sample invoice below to see how the system performs:")
275
 
276
+ sample_dir = "data/samples"
277
  if os.path.exists(sample_dir):
278
+ sample_files = [f for f in os.listdir(sample_dir) if f.endswith(('.jpg', '.png', '.jpeg', '.pdf'))]
279
 
280
  if sample_files:
281
+ file_path = os.path.join(sample_dir, sample_files[0])
282
+ st.write(f"**Sample File:** {sample_files[0]}")
283
+ if file_path.endswith('.pdf'):
284
+ st.info("📄 PDF Sample available. Download and upload it to test.")
285
+ else:
286
+ st.image(Image.open(file_path), caption=sample_files[0], use_container_width=True)
287
  else:
288
  st.warning("No sample invoices found in `data/samples/`.")
289
  else:
 
294
  st.markdown("""
295
  This app follows the exact pipeline you built:
296
  ```
297
+ 1. 📸 Input Handling
298
+ Detects JPG vs PDF. Smart Loader extracts text from PDFs instantly.
 
 
299
 
300
+ 2. 🧠 Hybrid Engine
301
+ - Digital PDFs: Direct Text Extraction (Fast)
302
+ - Images/Scans: LayoutLMv3 (ML) + Tesseract (OCR)
303
 
304
+ 3. 🛡️ Validation Gate
305
+ Pydantic Schema ensures data integrity (Decimal precision, Date formats).
306
 
307
+ 4. 🔑 Duplicate Detection
308
+ Generates a unique semantic hash based on content.
309
 
310
+ 5. 📊 Output JSON
311
+ Standardized, validated output ready for API response.
312
  ```
313
+ """)
 
 
 
 
 
requirements.txt CHANGED
@@ -1,14 +1,31 @@
 
1
  streamlit>=1.28.0
 
 
2
  pytesseract>=0.3.10
3
  opencv-python>=4.8.0
4
  Pillow>=10.0.0
 
 
5
  numpy>=1.24.0
6
  pandas>=2.0.0
7
 
8
- # Machine Learning
9
  torch>=2.0.0
10
  torchvision>=0.15.0
11
  transformers>=4.30.0
12
  datasets>=2.14.0
13
  huggingface-hub>=0.17.0
14
- seqeval>=1.2.2
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ----- Streamlit -----
2
  streamlit>=1.28.0
3
+
4
+ # ----- OCR -----
5
  pytesseract>=0.3.10
6
  opencv-python>=4.8.0
7
  Pillow>=10.0.0
8
+
9
+ # ----- Data -----
10
  numpy>=1.24.0
11
  pandas>=2.0.0
12
 
13
+ # ----- Machine Learning -----
14
  torch>=2.0.0
15
  torchvision>=0.15.0
16
  transformers>=4.30.0
17
  datasets>=2.14.0
18
  huggingface-hub>=0.17.0
19
+ seqeval>=1.2.2
20
+
21
+ # ----- Data Validation -----
22
+ pydantic>=2.12.0
23
+
24
+ # ----- PDF Processing -----
25
+ pdfplumber>=0.11.0
26
+ pdf2image>=1.16.0
27
+
28
+ # ----- API Framework -----
29
+ fastapi>=0.126.0
30
+ uvicorn[standard]>=0.38.0
31
+ python-multipart>=0.0.21
src/api.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ # src/api.py
5
+
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
7
+ from fastapi.responses import JSONResponse
8
+ import shutil
9
+ import os
10
+ from pathlib import Path
11
+ import uuid
12
+ import sys
13
+
14
+ # Import modules
15
+ sys.path.append(str(Path(__file__).resolve().parent))
16
+ from pipeline import process_invoice
17
+ from schema import InvoiceData
18
+
19
+ app = FastAPI(
20
+ title="Invoice Extraction API",
21
+ description="Hybrid ML + Rule-Based Pipeline with LayoutLMv3",
22
+ version="2.0"
23
+ )
24
+
25
+ # Create temp folder if not exists
26
+ UPLOAD_DIR = Path("temp_uploads")
27
+ UPLOAD_DIR.mkdir(exist_ok=True)
28
+
29
+ def cleanup_file(path: str):
30
+ """Background task to remove temp file after processing"""
31
+ try:
32
+ if os.path.exists(path):
33
+ os.remove(path)
34
+ except Exception as e:
35
+ print(f"Error cleaning up {path}: {e}")
36
+
37
+ @app.post("/extract", response_model=InvoiceData) # <--- CONTRACT ENFORCED
38
+ async def extract_invoice(
39
+ background_tasks: BackgroundTasks,
40
+ file: UploadFile = File(...)
41
+ ):
42
+ """
43
+ Upload an invoice (PDF/JPG/PNG) and get structured data.
44
+ """
45
+ # 1. Generate unique filename to prevent collisions
46
+ file_ext = Path(file.filename).suffix
47
+ unique_name = f"{uuid.uuid4()}{file_ext}"
48
+ temp_path = UPLOAD_DIR / unique_name
49
+
50
+ try:
51
+ # 2. Save Uploaded File
52
+ with open(temp_path, "wb") as buffer:
53
+ shutil.copyfileobj(file.file, buffer)
54
+
55
+ # 3. Process Logic
56
+ result = process_invoice(str(temp_path), method='ml')
57
+
58
+ # 4. Cleanup
59
+ # We use background_tasks to delete the file AFTER the response is sent
60
+ background_tasks.add_task(cleanup_file, str(temp_path))
61
+
62
+ return result
63
+
64
+ except Exception as e:
65
+ # Cleanup even on error
66
+ cleanup_file(str(temp_path))
67
+ raise HTTPException(status_code=500, detail=str(e))
src/pdf_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdfplumber
2
+ from pdf2image import convert_from_path
3
+ from pathlib import Path
4
+ from typing import List, Union
5
+ import numpy as np
6
+ import cv2
7
+
8
+ def extract_text_from_pdf(pdf_path: str) -> str:
9
+ """Extracts raw text from a digital PDF"""
10
+
11
+ path = Path(pdf_path)
12
+
13
+ if not path.exists():
14
+ raise FileNotFoundError(f"PDF not found: {pdf_path}")
15
+
16
+ try:
17
+ with pdfplumber.open(pdf_path) as pdf:
18
+ full_text = ""
19
+ for page in pdf.pages:
20
+ page_text = page.extract_text() or ""
21
+ full_text += page_text + "\n"
22
+ return full_text.strip()
23
+
24
+ except Exception as e:
25
+ raise ValueError(f"Failed to read PDF {pdf_path}: {str(e)}")
26
+
27
+
28
+ def convert_pdf_to_images(pdf_path: str) -> List[np.ndarray]:
29
+ """
30
+ Converts a PDF into a list of OpenCV images (numpy arrays).
31
+ Required for the ML pipeline (LayoutLM) or Scanned PDFs.
32
+
33
+ Logic:
34
+ 1. Use 'convert_from_path' to get PIL images.
35
+ 2. Convert PIL images to numpy arrays (OpenCV format).
36
+ 3. Return list of arrays.
37
+ """
38
+ # 1. Convert to PIL images
39
+ try:
40
+ pil_images = convert_from_path(pdf_path)
41
+ except Exception as e:
42
+ raise ValueError(f"Error converting PDF to image: {e}")
43
+
44
+ cv_images = []
45
+ for pil_img in pil_images:
46
+
47
+ array = np.array(pil_img)
48
+ cv_images.append(cv2.cvtColor(array, cv2.COLOR_RGB2BGR))
49
+
50
+ return cv_images
src/pipeline.py CHANGED
@@ -6,15 +6,20 @@ Orchestrates preprocessing, OCR, and extraction
6
  from typing import Dict, Any, Optional
7
  from pathlib import Path
8
  import json
 
 
9
 
10
- # Make sure all your modules are imported
11
  from preprocessing import load_image, convert_to_grayscale, remove_noise
12
  from ocr import extract_text
13
  from extraction import structure_output
14
  from ml_extraction import extract_ml_based
 
 
 
15
 
16
  def process_invoice(image_path: str,
17
- method: str = 'ml', # <-- New parameter: 'ml' or 'rules'
18
  save_results: bool = False,
19
  output_dir: str = 'outputs') -> Dict[str, Any]:
20
  """
@@ -29,45 +34,121 @@ def process_invoice(image_path: str,
29
  Returns:
30
  A dictionary with the extracted invoice data.
31
  """
 
32
  if not Path(image_path).exists():
33
- raise FileNotFoundError(f"Image not found at path: {image_path}")
 
 
34
 
35
- print(f"Processing with '{method}' method...")
 
36
 
37
- if method == 'ml':
38
- # --- ML-Based Extraction ---
 
39
  try:
40
- # The ml_extraction function handles everything internally
41
- structured_data = extract_ml_based(image_path)
42
- except Exception as e:
43
- raise ValueError(f"Error during ML-based extraction: {e}")
44
 
45
- elif method == 'rules':
46
- # --- Rule-Based Extraction (Your original logic) ---
47
- try:
48
- image = load_image(image_path)
49
- gray_image = convert_to_grayscale(image)
50
- preprocessed_image = remove_noise(gray_image, kernel_size=3)
51
- text = extract_text(preprocessed_image, config='--psm 6')
52
- structured_data = structure_output(text) # Calls your old extraction.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
- raise ValueError(f"Error during rule-based extraction: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- else:
57
- raise ValueError(f"Unknown extraction method: '{method}'. Choose 'ml' or 'rules'.")
 
 
 
 
 
58
 
59
- # --- Saving Logic (remains the same) ---
 
 
 
 
 
 
 
60
  if save_results:
61
  output_path = Path(output_dir)
62
  output_path.mkdir(parents=True, exist_ok=True)
63
- json_path = output_path / (Path(image_path).stem + f"_{method}.json") # Add method to filename
 
 
 
64
  try:
65
  with open(json_path, 'w', encoding='utf-8') as f:
66
- json.dump(structured_data, f, indent=2, ensure_ascii=False)
 
67
  except Exception as e:
68
  raise IOError(f"Error saving results to {json_path}: {e}")
69
 
70
- return structured_data
71
 
72
 
73
  def process_batch(image_folder: str, output_dir: str = 'outputs') -> list:
 
6
  from typing import Dict, Any, Optional
7
  from pathlib import Path
8
  import json
9
+ from pydantic import ValidationError
10
+ import cv2
11
 
12
+ # --- IMPORTS ---
13
  from preprocessing import load_image, convert_to_grayscale, remove_noise
14
  from ocr import extract_text
15
  from extraction import structure_output
16
  from ml_extraction import extract_ml_based
17
+ from schema import InvoiceData
18
+ from pdf_utils import extract_text_from_pdf, convert_pdf_to_images
19
+ from utils import generate_semantic_hash
20
 
21
  def process_invoice(image_path: str,
22
+ method: str = 'ml',
23
  save_results: bool = False,
24
  output_dir: str = 'outputs') -> Dict[str, Any]:
25
  """
 
34
  Returns:
35
  A dictionary with the extracted invoice data.
36
  """
37
+
38
  if not Path(image_path).exists():
39
+ raise FileNotFoundError(f"Image/PDF not found at path: {image_path}")
40
+
41
+ print(f"Processing: {image_path}")
42
 
43
+ raw_result = {}
44
+ is_digital_pdf = False
45
 
46
+ # --- 1. SMART PDF HANDLING ---
47
+ if image_path.lower().endswith('.pdf'):
48
+ print("📄 PDF detected. Checking type...")
49
  try:
50
+ # Attempt to extract text directly (Fast Path)
51
+ digital_text = extract_text_from_pdf(image_path)
 
 
52
 
53
+ # Heuristic: If we found >50 chars, it's likely a native Digital PDF
54
+ if len(digital_text.strip()) > 50:
55
+ print(" ✅ Digital Text found. Using Rule-Based Engine (Fast Mode).")
56
+ # We bypass the ML model because we have perfect text
57
+ raw_result = structure_output(digital_text)
58
+ is_digital_pdf = True
59
+ method = 'rules (digital)' # Override method for logging
60
+ else:
61
+ print(" ⚠️ Sparse text detected. Treating as Scanned PDF.")
62
+ # Convert first page to image for the ML pipeline
63
+ print(" 🔄 Converting Page 1 to Image...")
64
+ images = convert_pdf_to_images(image_path)
65
+
66
+ # Save as temp jpg so our existing pipeline can read it
67
+ # (In production, you might pass the array directly, but this is safer for now)
68
+ temp_jpg = image_path.replace('.pdf', '.jpg')
69
+ cv2.imwrite(temp_jpg, images[0])
70
+
71
+ # SWAP THE PATH: The rest of the pipeline will now see a JPG!
72
+ image_path = temp_jpg
73
+ print(f" ➡️ Continuing with converted image: {image_path}")
74
+
75
  except Exception as e:
76
+ print(f" PDF Error: {e}. Falling back to standard processing.")
77
+
78
+ # --- 2. STANDARD EXTRACTION (ML / RULES) ---
79
+ # Only run this if we didn't already extract from Digital PDF
80
+ if not is_digital_pdf:
81
+ print(f"⚙️ Using '{method}' method on image...")
82
+
83
+ if method == 'ml':
84
+ try:
85
+ raw_result = extract_ml_based(image_path)
86
+ except Exception as e:
87
+ raise ValueError(f"Error during ML-based extraction: {e}")
88
+
89
+ elif method == 'rules':
90
+ try:
91
+ image = load_image(image_path)
92
+ gray_image = convert_to_grayscale(image)
93
+ preprocessed_image = remove_noise(gray_image, kernel_size=3)
94
+ text = extract_text(preprocessed_image, config='--psm 6')
95
+ raw_result = structure_output(text)
96
+ except Exception as e:
97
+ raise ValueError(f"Error during rule-based extraction: {e}")
98
+
99
+ # Clean up temp file if we created one
100
+ if image_path.endswith('.jpg') and 'sample_pdf' in image_path: # Safety check
101
+ # Optional: os.remove(image_path)
102
+ pass
103
+
104
+ # --- VALIDATION STEP ---
105
+ final_data = raw_result # Default to raw if validation crashes hard
106
+
107
+ if method == 'ml':
108
+ try:
109
+ invoice = InvoiceData(**raw_result)
110
+ final_data = invoice.model_dump(mode='json')
111
+ final_data['validation_status'] = 'passed'
112
+ print("✅ Data Validation Passed")
113
+ except ValidationError as e:
114
+ print(f"❌ Data Validation Failed: {len(e.errors())} errors")
115
+
116
+ # We keep the 'raw_result' data so the user isn't left with nothing,
117
+ # but we attach the error report so they know what to fix.
118
+ final_data = raw_result.copy()
119
+ final_data['validation_status'] = 'failed'
120
 
121
+ # Format errors nicely
122
+ error_list = []
123
+ for err in e.errors():
124
+ field = " -> ".join(str(loc) for loc in err['loc'])
125
+ msg = err['msg']
126
+ print(f" - {field}: {msg}")
127
+ error_list.append(f"{field}: {msg}")
128
 
129
+ final_data['validation_errors'] = error_list
130
+
131
+ # --- DUPLICATE DETECTION ---
132
+ # We calculate the hash based on the final (or raw) data.
133
+ # This gives us a unique fingerprint for this specific business transaction.
134
+ final_data['semantic_hash'] = generate_semantic_hash(final_data)
135
+
136
+ # --- SAVING STEP ---
137
  if save_results:
138
  output_path = Path(output_dir)
139
  output_path.mkdir(parents=True, exist_ok=True)
140
+
141
+ # Helper to serialize Decimals/Dates for JSON (standard json.dump fails on them)
142
+ # You can use 'default=str' in json.dump or convert before saving
143
+ json_path = output_path / (Path(image_path).stem + f"_{method}.json")
144
  try:
145
  with open(json_path, 'w', encoding='utf-8') as f:
146
+ # Use default=str to handle Decimal and Date objects automatically
147
+ json.dump(final_data, f, indent=2, ensure_ascii=False, default=str)
148
  except Exception as e:
149
  raise IOError(f"Error saving results to {json_path}: {e}")
150
 
151
+ return final_data
152
 
153
 
154
  def process_batch(image_folder: str, output_dir: str = 'outputs') -> list:
src/schema.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/schema.py
2
+
3
+ from pydantic import BaseModel, Field, field_validator, model_validator
4
+ from typing import List, Optional, Union, Dict
5
+ from decimal import Decimal, InvalidOperation
6
+ from datetime import date as DateType, datetime
7
+
8
+ # --- 1. Line Item Schema ---
9
+ class LineItem(BaseModel):
10
+ description: str
11
+ quantity: int = Field(default=1, ge=1)
12
+ unit_price: Optional[Decimal] = Field(default=None, ge=0)
13
+ total: Decimal = Field(default=0, ge=0)
14
+
15
+ @field_validator('unit_price', 'total', mode='before')
16
+ @classmethod
17
+ def validate_precision(cls, v):
18
+ """Ensure exactly 2 decimal places for currency."""
19
+ if v is None:
20
+ return None
21
+ try:
22
+ d = Decimal(str(v))
23
+ return d.quantize(Decimal('0.01'))
24
+ except (InvalidOperation, ValueError, TypeError):
25
+ return Decimal('0.00')
26
+
27
+ # --- 2. Invoice Schema ---
28
+ class InvoiceData(BaseModel):
29
+ """
30
+ Strict Data Contract for Invoice Extraction.
31
+ """
32
+ # Core Fields
33
+ receipt_number: Optional[str] = Field(default=None, description="Unique ID")
34
+
35
+ date: Optional[DateType] = Field(default=None, description="Invoice Date")
36
+
37
+ # Financials
38
+ total_amount: Optional[Decimal] = Field(default=None, ge=0)
39
+
40
+ # Entities
41
+ vendor: Optional[str] = None
42
+ address: Optional[str] = None
43
+ bill_to: Optional[Union[str, Dict]] = None
44
+
45
+ # Nested Items
46
+ items: List[LineItem] = Field(default_factory=list)
47
+
48
+ # --- METADATA ---
49
+ validation_status: str = Field(default="unknown", description="passed/failed")
50
+ validation_errors: List[str] = Field(default_factory=list, description="List of validation failure messages")
51
+ semantic_hash: Optional[str] = Field(default=None, description="Unique fingerprint of the invoice content")
52
+
53
+ # --- VALIDATORS ---
54
+
55
+ @field_validator('date', mode='before')
56
+ @classmethod
57
+ def clean_date(cls, v):
58
+ """Logic: Handle None, parse formats, then validate range."""
59
+ if not v:
60
+ return None
61
+
62
+ parsed_date = v
63
+
64
+ if isinstance(v, str):
65
+ try:
66
+ # Try common formats
67
+ for fmt in ("%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y"):
68
+ try:
69
+ parsed_date = datetime.strptime(v, fmt).date()
70
+ break
71
+ except ValueError:
72
+ continue
73
+ except Exception:
74
+ return None
75
+
76
+ if isinstance(parsed_date, DateType):
77
+ today = datetime.now().date()
78
+ if parsed_date > today:
79
+ return None
80
+
81
+ # ⚠️ FIX: Use 'DateType' constructor
82
+ min_date = DateType(today.year - 10, 1, 1)
83
+ if parsed_date < min_date:
84
+ return None
85
+
86
+ return parsed_date
87
+
88
+ return None
89
+
90
+ @field_validator('total_amount', mode='before')
91
+ @classmethod
92
+ def validate_money(cls, v):
93
+ if v is None:
94
+ return None
95
+ try:
96
+ d = Decimal(str(v))
97
+ return d.quantize(Decimal('0.01'))
98
+ except (InvalidOperation, ValueError):
99
+ return None
100
+
101
+ @model_validator(mode='after')
102
+ def validate_math(self):
103
+ if not self.items or self.total_amount is None:
104
+ return self
105
+
106
+ line_items_sum = sum(item.total for item in self.items)
107
+ diff = abs(self.total_amount - line_items_sum)
108
+
109
+ if diff > Decimal('0.05'):
110
+ print(f"⚠️ Validation Warning: Total {self.total_amount} != Sum of items {line_items_sum}")
111
+
112
+ return self
src/utils.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import Dict, Any
3
+ from decimal import Decimal
4
+ from datetime import date
5
+
6
+ def generate_semantic_hash(invoice_data: Dict[str, Any]) -> str:
7
+ """
8
+ Generates a unique fingerprint using a Composite Key strategy.
9
+
10
+ Composite Key = Vendor + Date + Total + Receipt Number
11
def generate_semantic_hash(invoice_data: Dict[str, Any]) -> str:
    """
    Generates a unique fingerprint using a Composite Key strategy.

    Composite Key = Vendor + Date + Total + Receipt Number

    Args:
        invoice_data: Extracted invoice fields. Missing keys are treated
            the same as explicit ``None`` values, so partially extracted
            invoices still hash instead of raising ``KeyError``.

    Returns:
        Hex-encoded SHA256 digest of the normalized composite key.
    """
    # The specific fields that determine uniqueness.
    keys_to_hash = ['vendor', 'date', 'total_amount', 'receipt_number']
    normalized_values = []

    for key in keys_to_hash:
        # .get() keeps the hash stable for partial extraction results
        # (direct indexing raised KeyError when a field was absent).
        value = invoice_data.get(key)

        # Normalize without modifying the original object.
        if value is None:
            norm_val = ""
        elif isinstance(value, (date, Decimal, int, float)):
            # Dates/numbers stringify deterministically as-is.
            norm_val = str(value)
        else:
            # String normalization: case- and surrounding-space-insensitive.
            norm_val = str(value).lower().strip()

        normalized_values.append(norm_val)

    # Create the fingerprint string; '|' keeps field boundaries distinct.
    composite_string = "|".join(normalized_values)

    # Return the SHA256 hash of the string.
    return hashlib.sha256(composite_string.encode()).hexdigest()
35
+
tests/test_full_pipeline.py CHANGED
@@ -37,6 +37,6 @@ print("=" * 60)
37
  print("\n🎉 PIPELINE COMPLETE!")
38
  print("\n📋 Summary:")
39
  print(f" Vendor: {result['vendor']}")
40
- print(f" Invoice #: {result['invoice_number']}")
41
  print(f" Date: {result['date']}")
42
- print(f" Total: ${result['total']}")
 
37
  print("\n🎉 PIPELINE COMPLETE!")
38
  print("\n📋 Summary:")
39
  print(f" Vendor: {result['vendor']}")
40
+ print(f" Invoice #: {result['receipt_number']}")
41
  print(f" Date: {result['date']}")
42
+ print(f" Total: ${result.get('total_amount', '0.00')}")
tests/test_pipeline.py CHANGED
@@ -75,7 +75,7 @@ def test_full_pipeline():
75
  print(" - No line items extracted.")
76
 
77
  # Print total and validation status
78
- print(f"\n💵 Total Amount: ${result.get('total_amount', 0.0):.2f}")
79
 
80
  confidence = result.get('extraction_confidence', 0)
81
  print(f"📈 Confidence: {confidence}%")
 
75
  print(" - No line items extracted.")
76
 
77
  # Print total and validation status
78
+ print(f"\n💵 Total Amount: ${result.get('total_amount', 0.0)}")
79
 
80
  confidence = result.get('extraction_confidence', 0)
81
  print(f"📈 Confidence: {confidence}%")