# OCR_Smart / smart_ocr_pipeline_final.py
# (Hugging Face upload metadata: uploaded by Mariem-Daha, "Upload 10 files", commit 3d7858b verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
smart_ocr_pipeline_final.py
---------------------------------
Production-ready merge of your v3 + v1 pipelines with:
- Secure OpenAI setup (no hard-coded key)
- Global DocTR model cache (faster)
- Strong preprocessing (deskew, CLAHE, sharpen)
- Geometry-aware line grouping
- GPT-4o-mini Vision post-processing (cost-aware)
- Validation & auto-correction (math checks, type normalization)
- Lightweight fallback rerun on large mismatches
- Optional EasyOCR/Tesseract fallback if DocTR fails
- Structured logging
Usage:
python smart_ocr_pipeline_final.py <path/to/invoice.jpg> [output_dir]
Default output_dir is "." (kept from your first code).
"""
import os
import sys
import json
import base64
import time
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional
# Image processing
import cv2
import numpy as np
from PIL import Image
# OCR engines
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
# OpenAI
from openai import OpenAI
# Optional: dotenv for local development (no-op if .env absent)
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
# ============================================================
# Logging
# ============================================================
def setup_logger() -> logging.Logger:
    """Create (or reuse) the module-wide "smart_ocr" stdout logger.

    The handler is attached only when none exists yet, so repeated calls
    (or re-imports) never duplicate log lines.
    """
    lgr = logging.getLogger("smart_ocr")
    lgr.setLevel(logging.INFO)
    if not lgr.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
        lgr.addHandler(handler)
    return lgr
log = setup_logger()
# ============================================================
# 1) SETUP & CONFIGURATION
# ============================================================
def setup_environment() -> OpenAI:
"""
Initialize OpenAI client with a reliable API key source.
Uses env var OPENAI_API_KEY. Fail fast if missing.
"""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError(
"OPENAI_API_KEY not found. Set it in your environment, e.g.\n"
"Windows (PowerShell): $env:OPENAI_API_KEY='sk-...'\n"
"macOS/Linux (bash): export OPENAI_API_KEY='sk-...'"
)
log.info("OpenAI client initialized")
return OpenAI(api_key=api_key)
# Global cache for the DocTR model (avoids reloading weights on repeated runs)
_DOCTR_MODEL = None
def get_doctr_model():
    """Return the process-wide DocTR predictor, loading it lazily on first use."""
    global _DOCTR_MODEL
    if _DOCTR_MODEL is not None:
        return _DOCTR_MODEL
    start = time.time()
    _DOCTR_MODEL = ocr_predictor(pretrained=True)
    log.info(f"DocTR model loaded in {time.time() - start:.2f}s")
    return _DOCTR_MODEL
# ============================================================
# 2) IMAGE PREPROCESSING
# ============================================================
def preprocess_image(input_path: str, output_dir: str = ".") -> Tuple[str, str]:
    """Clean up an invoice photo for OCR and emit a downscaled preview copy.

    Pipeline: grayscale -> bilateral denoise -> deskew -> CLAHE contrast ->
    min/max normalize -> light sharpen.

    Args:
        input_path: Path to the source invoice image.
        output_dir: Directory that receives the generated PNGs.

    Returns:
        (processed_path, preview_path) for the full-resolution and preview PNGs.

    Raises:
        ValueError: when OpenCV cannot decode the input file.
    """
    log.info("Loading image for preprocessing…")
    raw = cv2.imread(input_path)
    if raw is None:
        raise ValueError(f"Could not load image: {input_path}")
    log.info("Cleaning image (grayscale → denoise → deskew → CLAHE → normalize → sharpen)…")
    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
    smooth = cv2.bilateralFilter(gray, 9, 75, 75)
    straightened = deskew_image(smooth)
    # Contrast enhancement + full-range normalization
    contrasted = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(straightened)
    stretched = cv2.normalize(contrasted, None, 0, 255, cv2.NORM_MINMAX)
    # Light sharpen with a 3x3 high-pass kernel
    sharpen_kernel = np.array([[-1, -1, -1],
                               [-1, 9, -1],
                               [-1, -1, -1]])
    final = cv2.filter2D(stretched, -1, sharpen_kernel)
    processed_path = os.path.join(output_dir, "processed_invoice.png")
    cv2.imwrite(processed_path, final)
    log.info(f"Processed image saved: {processed_path}")
    preview_path = create_preview(final, output_dir)
    log.info(f"Preview image saved: {preview_path}")
    return processed_path, preview_path
def deskew_image(image: np.ndarray) -> np.ndarray:
    """Rotate a grayscale page so its dominant Hough lines sit square.

    The skew is the median of the detected line angles (degrees, offset by
    -90); rotations of 0.5° or less are skipped as noise. Any failure —
    including no lines found — returns the input untouched.
    """
    try:
        edge_map = cv2.Canny(image, 50, 150, apertureSize=3)
        hough = cv2.HoughLines(edge_map, 1, np.pi / 180, 200)
        if hough is None:
            return image
        median_angle = np.median([np.degrees(theta) - 90 for rho, theta in hough[:, 0]])
        if abs(median_angle) <= 0.5:
            return image
        h, w = image.shape[:2]
        rotation = cv2.getRotationMatrix2D((w // 2, h // 2), median_angle, 1.0)
        corrected = cv2.warpAffine(image, rotation, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
        log.info(f"Deskewed by {median_angle:.2f}°")
        return corrected
    except Exception as e:
        log.warning(f"Deskew failed, using original: {e}")
        return image
def create_preview(image: np.ndarray, output_dir: str) -> str:
    """Write a preview PNG (max side 1024 px) and return its path.

    1024 px keeps enough detail for the vision model (as in v3) while
    keeping the base64 payload small.
    """
    preview = Image.fromarray(image)
    preview.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
    path = os.path.join(output_dir, "preview_invoice.png")
    preview.save(path)
    return path
# ============================================================
# 3) OCR EXTRACTION + LINE GROUPING
# ============================================================
# Words that mark table headers / footers (Italian + English); any OCR word
# containing one of these is treated as layout, not data.
HEADER_KEYWORDS = [
    "quantità", "prezzo", "sconto", "importo", "iva",
    "descrizione", "codice",
    "tot.", "tot,", "tot", "totale", "merce", "conforme",
    "trasporto", "porto", "peso", "colli",
    "quantity", "price", "discount", "amount", "description", "code", "total",
]
def clean_blocks(blocks: List[Dict]) -> List[Dict]:
    """Drop OCR word blocks that are single characters or header keywords."""
    def keep(block: Dict) -> bool:
        stripped = block.get("text", "").strip()
        if len(stripped) <= 1:
            return False
        lowered = stripped.lower()
        return not any(kw in lowered for kw in HEADER_KEYWORDS)
    return [b for b in blocks if keep(b)]
def group_by_y(blocks: List[Dict], y_threshold: float = 0.015) -> List[str]:
    """Group word blocks into text rows by their top-left y coordinate.

    A block joins the current row when its y is within ``y_threshold`` of the
    row anchor (the first block of that row); each completed row is rendered
    left-to-right. Coordinates are DocTR-normalized (0..1).
    """
    if not blocks:
        return []
    def reading_order(b: Dict) -> Tuple[float, float]:
        x, y = b["geometry"][0]
        return y, x
    ordered = sorted(blocks, key=reading_order)
    rows: List[str] = []
    row: List[Dict] = []
    anchor_y = None
    def flush() -> None:
        # Render the pending row left-to-right; skip rows that collapse to "".
        joined = " ".join(w["text"] for w in sorted(row, key=lambda w: w["geometry"][0][0])).strip()
        if joined:
            rows.append(joined)
    for b in ordered:
        y = b["geometry"][0][1]
        if row and abs(y - anchor_y) <= y_threshold:
            row.append(b)
        else:
            if row:
                flush()
            row = [b]
            anchor_y = y
    flush()
    return rows
def extract_text_with_doctr(image_path: str, output_dir: str = ".") -> Tuple[str, Dict, List[str]]:
    """Run DocTR on one image and return (raw_text, ocr_json, numbered_lines).

    Two line groupings are computed — DocTR's native lines and a geometry
    regrouping of the cleaned word blocks — and whichever yields at least as
    many rows wins. Debug artifacts (ocr_result.json, ocr_lines.txt) are
    written into ``output_dir``.
    """
    log.info("Running DocTR OCR with geometry-based line grouping…")
    predictor = get_doctr_model()
    result = predictor(DocumentFile.from_images(image_path))
    all_blocks: List[Dict] = []
    pages: List[Dict] = []
    for page_idx, page in enumerate(result.pages):
        word_blocks: List[Dict] = []
        native_lines: List[str] = []
        for block in page.blocks:
            for line in block.lines:
                word_blocks.extend(
                    {
                        "text": word.value,
                        "confidence": float(word.confidence),
                        "geometry": word.geometry,  # [[x1,y1], [x2,y2]] normalized 0..1
                    }
                    for word in line.words
                )
                joined = " ".join(w.value for w in line.words).strip()
                if joined:
                    native_lines.append(joined)
        pages.append({"page_number": page_idx + 1, "blocks": word_blocks, "lines": native_lines})
        all_blocks.extend(word_blocks)
    confidences = [b["confidence"] for b in all_blocks if "confidence" in b]
    avg_conf = float(np.mean(confidences)) if confidences else 0.0
    ocr_json = {"pages": pages, "average_confidence": avg_conf}
    # Clean + regroup; prefer the geometry grouping when it recovers >= rows
    y_lines = group_by_y(clean_blocks(all_blocks), y_threshold=0.01)
    doctr_lines = [ln for p in pages for ln in p["lines"]]
    chosen_lines = y_lines if len(y_lines) >= len(doctr_lines) else doctr_lines
    formatted_lines = [f"{i+1}. {ln}" for i, ln in enumerate(chosen_lines)]
    # Save debug artifacts
    with open(os.path.join(output_dir, "ocr_result.json"), "w", encoding="utf-8") as f:
        json.dump(ocr_json, f, indent=2, ensure_ascii=False)
    with open(os.path.join(output_dir, "ocr_lines.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(formatted_lines))
    log.info(f"DocTR complete (confidence: {avg_conf:.2f}; lines: {len(formatted_lines)})")
    return "\n".join(chosen_lines), ocr_json, formatted_lines
# ============================================================
# 4) AI POST-PROCESSING (GPT-4o-mini Vision by default)
# ============================================================
def extract_structured_data(
    client: OpenAI,
    formatted_lines: List[str],
    preview_path: str,
    model_name: str = "gpt-4o-mini"
) -> Dict:
    """
    Use GPT Vision to parse structured JSON from numbered, grouped lines + image.

    Args:
        client: Initialized OpenAI client.
        formatted_lines: Numbered OCR lines ("1. …") from the OCR stage.
        preview_path: Path to the downscaled preview PNG sent with the prompt.
        model_name: Vision-capable chat model to call.

    Returns:
        Parsed invoice dict, with token "usage" attached when the API reports
        it, or {"error": "json_parse_error", ...} when the reply is not JSON.
    """
    log.info(f"Processing with {model_name} …")
    with open(preview_path, "rb") as img_file:
        img_b64 = base64.b64encode(img_file.read()).decode("utf-8")

    def is_header(line: str) -> bool:
        # Header rows would be misread as data rows; filter them out.
        low = line.lower()
        return any(k in low for k in HEADER_KEYWORDS)

    filtered_lines = [ln for ln in formatted_lines if not is_header(ln)]
    system_message = """
You are a professional invoice/receipt parser for ChefCode.
You receive:
(1) Numbered OCR lines (already grouped horizontally by row).
(2) The invoice image.
Return ONLY valid JSON with this schema:
{
  "supplier": "string",
  "invoice_number": "string",
  "date": "YYYY-MM-DD or null",
  "line_items": [
    {
      "lot_number": "string",
      "item_name": "string",
      "unit": "string",
      "quantity": number,
      "unit_price": number or null,
      "line_total": number or null,
      "type": "string"
    }
  ],
  "total_amount": number or null,
  "confidence": "high | medium | low"
}
Extraction rules (critical):
- The table is horizontal: Lot → Item → Unit → Quantity → Unit Price → Line Total.
- The quantity is the number DIRECTLY AFTER the unit.
- If numbers for a line appear missing, check up to TWO lines BELOW that line in OCR_LINES,
  ignoring header words (Quantità, Prezzo, Sconto, Importo, IVA).
- Do not skip any visible row; compare OCR row count with extracted items and recover missing lines.
- Verify math: quantity × unit_price ≈ line_total (±3%). If off, re-read digits from the image.
- If two adjacent rows share identical numbers, re-check both in the image; do not merge distinct items.
- Use "." as decimal separator and strip any currency symbols.
- Keep supplier and item names exactly as printed; do not translate them.
- Infer "type" (meat/vegetable/dairy/grain/condiment/beverage/grocery). If invoice language is Italian,
  output these category words in Italian (carne, verdura, latticini, cereali, condimento, bevanda, drogheria).
- Output ONLY JSON — no prose, no markdown.
""".strip()
    user_message = f"""Extract structured data from this invoice.
OCR_LINES (count={len(filtered_lines)}):
{chr(10).join(filtered_lines)}
"""
    resp = client.chat.completions.create(
        model=model_name,
        temperature=0.1,
        max_completion_tokens=2000,
        messages=[
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_message},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "high"},
                    },
                ],
            },
        ],
    )
    # Capture real token usage directly from the API response.
    # Routed through the module logger (not print) for consistent structured logs.
    usage = None
    try:
        if getattr(resp, "usage", None):
            usage = {
                "prompt_tokens": resp.usage.prompt_tokens,
                "completion_tokens": resp.usage.completion_tokens,
                "total_tokens": resp.usage.total_tokens,
            }
            log.info(f"Token usage: {usage}")
        else:
            log.warning("No token usage info found in response.")
    except Exception as e:
        log.warning(f"Couldn't read token usage: {e}")
    raw = resp.choices[0].message.content.strip()
    # Trim a markdown code fence only at the edges of the reply, so literal
    # backticks inside JSON string values can no longer corrupt the payload.
    if raw.startswith("```"):
        raw = raw[7:] if raw.startswith("```json") else raw[3:]
        if raw.endswith("```"):
            raw = raw[:-3]
        raw = raw.strip()
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as e:
        log.error(f"JSON parse error: {e}")
        return {"error": "json_parse_error", "raw_response": raw, "confidence": "low"}
    log.info("GPT response parsed")
    # Persist token usage in the structured data so it lands in smart_output.json.
    if usage:
        data["usage"] = usage
    return data
# ============================================================
# 5) VALIDATION & AUTO-CORRECTION
# ============================================================
def _coerce_number(x):
if x is None:
return None
if isinstance(x, (int, float)):
return float(x)
try:
s = str(x).replace("€", "").replace("EUR", "").replace(",", ".").strip()
return float(s)
except Exception:
return None
def detect_invoice_language(structured: Dict) -> str:
    """Guess "it" vs "en" from the supplier name plus the first three item names.

    Heuristic: count Italian-market indicator substrings; two or more hits
    classify the invoice as Italian.
    """
    italian_hints = ("srl", "spa", "via", "roma", "milano", "kg", "lt")
    supplier = structured.get("supplier", "").lower()
    names = " ".join(it.get("item_name", "").lower() for it in structured.get("line_items", [])[:3])
    haystack = supplier + " " + names
    hits = sum(hint in haystack for hint in italian_hints)
    return "it" if hits >= 2 else "en"
def normalize_item_types(structured: Dict) -> Dict:
    """Translate English item-type labels to Italian for Italian invoices.

    Non-Italian invoices (and unknown type labels) are left untouched.
    """
    if detect_invoice_language(structured) != "it":
        return structured
    translations = {
        "grain": "cereali",
        "meat": "carne",
        "fish": "pesce",
        "vegetable": "verdura",
        "fruit": "frutta",
        "dairy": "latticini",
        "condiment": "condimento",
        "beverage": "bevanda",
        "grocery": "alimentari",
        "other": "altro",
    }
    for item in structured.get("line_items", []):
        label = (item.get("type") or "").lower()
        if label in translations:
            item["type"] = translations[label]
    return structured
def reconcile_and_validate(structured: Dict, ocr_json: Dict) -> Dict:
    """Sanity-check and auto-correct the GPT-extracted invoice data.

    Per item: coerce quantity/unit_price/line_total to floats, treat zeros as
    missing, and — whenever both factors are present — overwrite line_total
    with quantity * unit_price, logging a note whose severity depends on the
    mismatch size. Then the header total is reconciled against the sum of the
    corrected line totals, item types are normalized, overall "confidence" is
    downgraded based on note severity, and all notes are appended to
    structured["validation_notes"].

    Args:
        structured: Parsed invoice dict (mutated in place and returned).
        ocr_json: OCR result; only pages[*].lines is read, to compare row
            counts against extracted item counts.

    Returns:
        The same ``structured`` dict, corrected and annotated.
    """
    notes = []
    items = structured.get("line_items", []) or []
    fixed_items = []
    for it in items:
        q = _coerce_number(it.get("quantity"))
        p = _coerce_number(it.get("unit_price"))
        t = _coerce_number(it.get("line_total"))
        # Zeros are treated as "value not read" rather than real amounts.
        if q == 0: q = None
        if p == 0: p = None
        if t == 0: t = None
        if q is not None and p is not None:
            calc = round(q * p, 2)
            # More than 10% off the stated total: note loudly, then trust q*p.
            # ("t if t else 1" guards division-by-zero style issues, though
            # t > 0 already holds in this branch.)
            if t is not None and t > 0 and abs(calc - t) > 0.1 * (t if t else 1):
                notes.append(
                    f"⚠️ Large mismatch (>10%) for '{it.get('item_name','')}': q={q}, p={p}, expected={calc}, got={t}. Auto-correcting to {calc}."
                )
                t = calc
            # Missing total or within 5 cents: silently adopt the computed value.
            elif t is None or abs(calc - t) <= 0.05:
                t = calc
            # Within 15 cents: small OCR slip — correct and note it.
            elif abs(calc - t) <= 0.15:
                notes.append(f"✓ Corrected line_total from {t} to {calc} for '{it.get('item_name','')}'.")
                t = calc
            # Any other mismatch: still correct to q*p, with a warning note.
            else:
                notes.append(f"⚠️ Line math mismatch for '{it.get('item_name','')}' (q*p={calc}, got {t}). Corrected.")
                t = calc
        it["quantity"] = q
        it["unit_price"] = p
        it["line_total"] = t
        fixed_items.append(it)
    structured["line_items"] = fixed_items
    structured = normalize_item_types(structured)
    # Reconcile the header total against the sum of corrected line totals.
    line_sum = round(sum(it.get("line_total") or 0 for it in fixed_items), 2)
    ta = _coerce_number(structured.get("total_amount"))
    if ta is None:
        structured["total_amount"] = line_sum
        notes.append(f"Set total_amount from sum(line_totals) = {line_sum}.")
    else:
        if ta > 0:
            diff_percent = abs(line_sum - ta) / ta * 100
            if diff_percent <= 1.0:
                # Within 1%: prefer the recomputed sum over the header value.
                notes.append(f"✓ Total validated: sum={line_sum}, header={ta}, diff={diff_percent:.2f}%")
                structured["total_amount"] = line_sum
            elif diff_percent <= 5.0:
                notes.append(f"⚠️ Total mismatch (±{diff_percent:.2f}%): sum={line_sum}, header={ta}")
                structured["confidence"] = "medium"
            else:
                notes.append(f"❌ Large total mismatch ({diff_percent:.2f}%): sum={line_sum}, header={ta}")
                structured["confidence"] = "low"
        else:
            # Non-positive header total is treated as unreadable.
            structured["total_amount"] = line_sum
            notes.append(f"Set total_amount from sum(line_totals) = {line_sum}.")
    # Flag probable skipped rows: far fewer items than OCR lines.
    ocr_line_count = sum(len(p["lines"]) for p in ocr_json.get("pages", []))
    if len(fixed_items) < max(3, int(0.5 * ocr_line_count)):
        notes.append(f"Only {len(fixed_items)}/{ocr_line_count} OCR lines became items; possible skips.")
    # Derive overall confidence from the severity markers in the notes.
    if any("❌" in n for n in notes):
        structured["confidence"] = "low"
    elif any("⚠️" in n for n in notes):
        if structured.get("confidence") != "low":
            structured["confidence"] = "medium"
    elif not any("mismatch" in n.lower() for n in notes):
        structured["confidence"] = structured.get("confidence", "high")
    if notes:
        existing = structured.get("validation_notes")
        structured["validation_notes"] = ("; ".join(notes) if not existing else (existing + "; " + "; ".join(notes)))
    return structured
# ============================================================
# 5B) LIGHTWEIGHT FALLBACK
# ============================================================
def extract_structured_data_lightweight(
    client: OpenAI, filtered_lines: List[str], preview_path: str, model_name: str = "gpt-4o-mini"
) -> Dict:
    """Fallback GPT pass that prioritizes numeric columns over item names.

    Triggered when the first extraction's line totals disagree badly with
    the header total; returns the same schema as extract_structured_data,
    or an {"error": ..., "confidence": "low"} dict on failure.
    """
    log.info("Re-running with lightweight prompt (numeric focus)…")
    with open(preview_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    system_message = """You are a precise invoice data extractor.
FOCUS: Extract ONLY the numeric columns accurately. Do not worry about perfect item names.
Return valid JSON with this schema:
{
  "supplier": "string",
  "invoice_number": "string",
  "date": "string",
  "line_items": [
    {
      "lot_number": "string",
      "item_name": "string",
      "unit": "string",
      "quantity": number,
      "unit_price": number,
      "line_total": number,
      "type": "string"
    }
  ],
  "total_amount": number,
  "confidence": "high|medium|low"
}
CRITICAL RULES:
1. For each line, extract: quantity, unit_price, line_total in that exact order
2. Verify: quantity × unit_price ≈ line_total (±2%)
3. Count ALL visible rows in the table - don't skip any
4. Sum all line_totals and verify against invoice total
5. If a row has numbers, include it - better to have all rows than miss some
Return ONLY valid JSON, no markdown."""
    user_message = f"""Extract ALL line items from this invoice. Focus on getting every row with numbers.
OCR_LINES (count={len(filtered_lines)}):
{chr(10).join(filtered_lines)}
Extract EVERY line item visible in the table."""
    response = client.chat.completions.create(
        model=model_name,
        max_completion_tokens=3000,
        messages=[
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_message},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "high"}},
                ],
            },
        ],
    )
    if not response.choices:
        log.error("No choices in response")
        return {"error": "no_choices", "confidence": "low"}
    first_choice = response.choices[0]
    payload = (first_choice.message.content or "").strip()
    if not payload:
        log.error(f"Empty response from GPT (finish_reason={first_choice.finish_reason})")
        return {"error": "empty_response", "finish_reason": first_choice.finish_reason, "confidence": "low"}
    # Strip any markdown code fence the model may have wrapped around the JSON.
    if payload.startswith("```"):
        if payload.startswith("```json"):
            payload = payload.replace("```json", "")
        payload = payload.replace("```", "").strip()
    try:
        data = json.loads(payload)
    except json.JSONDecodeError as e:
        log.error(f"JSON parse error: {e}")
        return {"error": "json_parse_error", "raw_response": payload[:500], "confidence": "low"}
    log.info(f"Lightweight extraction: {len(data.get('line_items', []))} items")
    return data
def should_rerun_lightweight(structured: Dict) -> bool:
    """Return True when the summed line totals diverge >30% from the header total.

    Returns False when there are no items or no usable header total — those
    cases cannot be improved by a numeric-focused rerun.
    """
    items = structured.get("line_items", [])
    if not items:
        return False
    header_total = _coerce_number(structured.get("total_amount"))
    if not header_total:
        return False
    line_sum = sum(_coerce_number(it.get("line_total")) or 0 for it in items)
    diff_percent = abs(line_sum - header_total) / header_total * 100
    if diff_percent <= 30:
        return False
    log.warning(f"Large total mismatch: {diff_percent:.1f}% (line_sum={line_sum}, header={header_total})")
    return True
# ============================================================
# 6) OPTIONAL FALLBACK OCR (Tesseract / EasyOCR)
# ============================================================
def fallback_ocr_plain(image_path: str, output_dir: str) -> Tuple[str, Dict, List[str]]:
    """
    Fallback if DocTR throws: try pytesseract, then EasyOCR.

    Returns:
        (raw_text, minimal_ocr_json, numbered_line_list); all-empty values
        with engine "none" when every fallback fails.
    """
    # Attempt 1: Tesseract (optional dependency; import failure falls through)
    try:
        import pytesseract
        log.info("Running Tesseract OCR (fallback)…")
        page_text = pytesseract.image_to_string(cv2.imread(image_path)) or ""
        lines = [ln.strip() for ln in page_text.splitlines() if ln.strip()]
        minimal = {
            "pages": [{"page_number": 1, "blocks": [], "lines": lines}],
            "average_confidence": 0.7,
            "engine": "tesseract_fallback",
        }
        return page_text, minimal, [f"{i+1}. {ln}" for i, ln in enumerate(lines)]
    except Exception:
        pass  # best-effort: fall through to EasyOCR
    # Attempt 2: EasyOCR (optional dependency)
    try:
        import easyocr
        log.info("Running EasyOCR (fallback)…")
        reader = easyocr.Reader(["it", "en"], gpu=False)
        detections = reader.readtext(image_path, detail=1, paragraph=False)
        lines = [det[1] for det in detections if len(det) >= 2 and det[1].strip()]
        minimal = {
            "pages": [{"page_number": 1, "blocks": [], "lines": lines}],
            "average_confidence": 0.75,
            "engine": "easyocr_fallback",
        }
        return "\n".join(lines), minimal, [f"{i+1}. {ln}" for i, ln in enumerate(lines)]
    except Exception as e:
        log.error(f"All OCR fallbacks failed: {e}")
        return "", {"pages": [], "average_confidence": 0.0, "engine": "none"}, []
# ============================================================
# 7) MAIN PIPELINE
# ============================================================
def main(invoice_path: str, output_dir: str = "."):
    """Run the full pipeline on one invoice image and write smart_output.json.

    Steps: OpenAI setup -> preprocessing -> OCR (DocTR with plain-OCR
    fallback) -> GPT extraction -> validation -> optional lightweight rerun.

    Returns:
        The final output dict that was also written to disk.
    """
    print("\n" + "="*60)
    print("🧠 SMART OCR PIPELINE (final, gpt-4o-mini by default)")
    print("="*60 + "\n")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # 1) Setup
    client = setup_environment()
    # 2) Preprocess
    started = time.time()
    processed_path, preview_path = preprocess_image(invoice_path, output_dir)
    # 3) OCR — DocTR first, generic OCR engines as a safety net
    try:
        ocr_text, ocr_json, formatted_lines = extract_text_with_doctr(processed_path, output_dir)
    except Exception as e:
        log.error(f"DocTR OCR failed: {e}")
        ocr_text, ocr_json, formatted_lines = fallback_ocr_plain(processed_path, output_dir)
    # 4) AI post-processing
    structured = extract_structured_data(client, formatted_lines, preview_path, model_name="gpt-4o-mini")
    # 5) Validation & auto-correction
    structured = reconcile_and_validate(structured, ocr_json)
    # 6) Lightweight rerun when totals disagree badly; keep whichever pass
    #    recovered more line items.
    if should_rerun_lightweight(structured):
        log.info("Triggering lightweight fallback extraction…")
        retry = extract_structured_data_lightweight(client, formatted_lines, preview_path, model_name="gpt-4o-mini")
        retry_items = len(retry.get("line_items", []))
        original_items = len(structured.get("line_items", []))
        if retry_items > original_items:
            log.info(f"Using lightweight result: {retry_items} items vs {original_items} items")
            structured = reconcile_and_validate(retry, ocr_json)
            structured["rerun_applied"] = "lightweight_fallback"
        else:
            log.info(f"Keeping original result: {original_items} items vs {retry_items} items")
            structured["rerun_attempted"] = "lightweight_fallback_not_better"
    page_list = ocr_json.get("pages")
    final_output = {
        "status": "success",
        "pipeline_version": "3.1_final_gpt4o-mini",
        "input_file": Path(invoice_path).name,
        "ocr_confidence": ocr_json.get("average_confidence", 0.0),
        "lines_detected": sum(len(p["lines"]) for p in page_list) if page_list else 0,
        "data": structured,
        "elapsed_sec": round(time.time() - started, 2),
        "usage": structured.get("usage", None),
    }
    out_path = os.path.join(output_dir, "smart_output.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(final_output, f, indent=2, ensure_ascii=False)
    log.info(f"Final output saved: {out_path}")
    log.info(f" • OCR Confidence: {final_output['ocr_confidence']:.2f}")
    log.info(f" • Items parsed: {len(structured.get('line_items', []))}")
    log.info(f" • Total: {structured.get('total_amount')}")
    log.info(f" • Elapsed: {final_output['elapsed_sec']}s")
    print("\nDone.\n")
    return final_output
if __name__ == "__main__":
    # CLI entry point: the invoice path is required, output directory optional.
    if len(sys.argv) < 2:
        print("Usage: python smart_ocr_pipeline_final.py <path/to/invoice.jpg> [output_dir]")
        sys.exit(1)
    cli_args = sys.argv[1:]
    invoice_path = cli_args[0]
    output_dir = cli_args[1] if len(cli_args) > 1 else "."
    main(invoice_path, output_dir)