Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| smart_ocr_pipeline_final.py | |
| --------------------------------- | |
| Production-ready merge of your v3 + v1 pipelines with: | |
| - Secure OpenAI setup (no hard-coded key) | |
| - Global DocTR model cache (faster) | |
| - Strong preprocessing (deskew, CLAHE, sharpen) | |
| - Geometry-aware line grouping | |
| - GPT-4o-mini Vision post-processing (cost-aware) | |
| - Validation & auto-correction (math checks, type normalization) | |
| - Lightweight fallback rerun on large mismatches | |
| - Optional EasyOCR/Tesseract fallback if DocTR fails | |
| - Structured logging | |
| Usage: | |
| python smart_ocr_pipeline_final.py <path/to/invoice.jpg> [output_dir] | |
| Default output_dir is "." (kept from your first code). | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import base64 | |
| import time | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional | |
| # Image processing | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| # OCR engines | |
| from doctr.io import DocumentFile | |
| from doctr.models import ocr_predictor | |
| # OpenAI | |
| from openai import OpenAI | |
| # Optional: dotenv for local development (no-op if .env absent) | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except Exception: | |
| pass | |
| # ============================================================ | |
| # Logging | |
| # ============================================================ | |
def setup_logger() -> logging.Logger:
    """Create (or reuse) the module-wide "smart_ocr" logger writing to stdout.

    Idempotent: the stream handler is attached only on the first call, so
    repeated imports/calls do not duplicate log lines.
    """
    logger = logging.getLogger("smart_ocr")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
        logger.addHandler(handler)
    return logger


log = setup_logger()
| # ============================================================ | |
| # 1) SETUP & CONFIGURATION | |
| # ============================================================ | |
def setup_environment() -> OpenAI:
    """
    Initialize OpenAI client with a reliable API key source.
    Uses env var OPENAI_API_KEY. Fail fast if missing.

    Raises:
        ValueError: when the variable is unset, with per-OS setup hints.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        log.info("OpenAI client initialized")
        return OpenAI(api_key=key)
    raise ValueError(
        "OPENAI_API_KEY not found. Set it in your environment, e.g.\n"
        "Windows (PowerShell): $env:OPENAI_API_KEY='sk-...'\n"
        "macOS/Linux (bash): export OPENAI_API_KEY='sk-...'"
    )
# Module-level cache so the (slow) DocTR weights are loaded once per process.
_DOCTR_MODEL = None


def get_doctr_model():
    """Return the cached DocTR predictor, loading pretrained weights on first use."""
    global _DOCTR_MODEL
    if _DOCTR_MODEL is None:
        start = time.time()
        _DOCTR_MODEL = ocr_predictor(pretrained=True)
        log.info(f"DocTR model loaded in {time.time() - start:.2f}s")
    return _DOCTR_MODEL
| # ============================================================ | |
| # 2) IMAGE PREPROCESSING | |
| # ============================================================ | |
def preprocess_image(input_path: str, output_dir: str = ".") -> Tuple[str, str]:
    """Clean up an invoice photo for OCR and emit a downscaled preview.

    Pipeline: grayscale → bilateral denoise → deskew → CLAHE contrast →
    min-max normalize → light sharpen. Writes processed_invoice.png and a
    preview into ``output_dir`` and returns (processed_path, preview_path).

    Raises:
        ValueError: if the image cannot be read from ``input_path``.
    """
    log.info("Loading image for preprocessing…")
    img = cv2.imread(input_path)
    if img is None:
        raise ValueError(f"Could not load image: {input_path}")
    log.info("Cleaning image (grayscale → denoise → deskew → CLAHE → normalize → sharpen)…")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Bilateral filter smooths noise while keeping glyph edges.
    smoothed = cv2.bilateralFilter(gray, 9, 75, 75)
    straightened = deskew_image(smoothed)
    # Local contrast boost, then stretch the full range back to 0..255.
    contrasted = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(straightened)
    rescaled = cv2.normalize(contrasted, None, 0, 255, cv2.NORM_MINMAX)
    # Mild sharpening kernel to crisp up character strokes.
    sharpen_kernel = np.array([
        [-1, -1, -1],
        [-1, 9, -1],
        [-1, -1, -1],
    ])
    final = cv2.filter2D(rescaled, -1, sharpen_kernel)
    processed_path = os.path.join(output_dir, "processed_invoice.png")
    cv2.imwrite(processed_path, final)
    log.info(f"Processed image saved: {processed_path}")
    preview_path = create_preview(final, output_dir)
    log.info(f"Preview image saved: {preview_path}")
    return processed_path, preview_path
def deskew_image(image: np.ndarray) -> np.ndarray:
    """Estimate page skew from Hough lines and rotate the image to correct it.

    Returns the input unchanged when no lines are found, when the estimated
    skew is below 0.5°, or when anything in the estimation fails.
    """
    try:
        edges = cv2.Canny(image, 50, 150, apertureSize=3)
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
        if lines is None:
            return image
        angles = [np.degrees(theta) - 90 for rho, theta in lines[:, 0]]
        # Keep only near-horizontal candidates: vertical table rules map to
        # angles near ±90° and would drag the median far from the true skew.
        horizontal = [a for a in angles if abs(a) < 45]
        if not horizontal:
            return image
        median_angle = float(np.median(horizontal))
        if abs(median_angle) > 0.5:
            (h, w) = image.shape[:2]
            M = cv2.getRotationMatrix2D((w // 2, h // 2), median_angle, 1.0)
            rot = cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
            log.info(f"Deskewed by {median_angle:.2f}°")
            return rot
        return image
    except Exception as e:
        # Best-effort step: a failed deskew must never abort the pipeline.
        log.warning(f"Deskew failed, using original: {e}")
        return image
def create_preview(image: np.ndarray, output_dir: str) -> str:
    """Save a LANCZOS-downscaled copy (max side 1024 px) for the vision model.

    1024 px keeps enough detail for digit re-reading while bounding the
    base64 payload size. Returns the saved preview path.
    """
    preview_path = os.path.join(output_dir, "preview_invoice.png")
    preview = Image.fromarray(image)
    preview.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
    preview.save(preview_path)
    return preview_path
| # ============================================================ | |
| # 3) OCR EXTRACTION + LINE GROUPING | |
| # ============================================================ | |
# Column-header / footer words (Italian + English) that appear on invoice
# tables but never belong to a line item; used to filter OCR output.
HEADER_KEYWORDS = [
    "quantità", "prezzo", "sconto", "importo", "iva",
    "descrizione", "codice",
    "tot.", "tot,", "tot", "totale", "merce", "conforme",
    "trasporto", "porto", "peso", "colli",
    "quantity", "price", "discount", "amount", "description", "code", "total",
]


def clean_blocks(blocks: List[Dict]) -> List[Dict]:
    """Drop OCR word blocks that are single characters or table-header words."""
    def keep(block: Dict) -> bool:
        text = block.get("text", "").strip()
        if len(text) <= 1:
            return False
        lowered = text.lower()
        return not any(kw in lowered for kw in HEADER_KEYWORDS)

    return [b for b in blocks if keep(b)]
def group_by_y(blocks: List[Dict], y_threshold: float = 0.015) -> List[str]:
    """Group word blocks into text rows by vertical proximity.

    Blocks are ordered top-to-bottom then left-to-right; a block joins the
    current row while its top-left y is within ``y_threshold`` of the row's
    anchor y (the first block's y, not a running value). Each row is emitted
    as its words joined left-to-right; empty rows are skipped.
    """
    if not blocks:
        return []
    ordered = sorted(blocks, key=lambda b: (b["geometry"][0][1], b["geometry"][0][0]))

    def flush(row: List[Dict], out: List[str]) -> None:
        # Join the row's words in reading order (by x coordinate).
        joined = " ".join(
            w["text"] for w in sorted(row, key=lambda w: w["geometry"][0][0])
        ).strip()
        if joined:
            out.append(joined)

    rows: List[str] = []
    current_row = [ordered[0]]
    anchor_y = ordered[0]["geometry"][0][1]
    for blk in ordered[1:]:
        y = blk["geometry"][0][1]
        if abs(y - anchor_y) <= y_threshold:
            current_row.append(blk)
        else:
            flush(current_row, rows)
            current_row = [blk]
            anchor_y = y
    flush(current_row, rows)
    return rows
def extract_text_with_doctr(image_path: str, output_dir: str = ".") -> Tuple[str, Dict, List[str]]:
    """Run DocTR on one image and return (raw_text, ocr_json, numbered_lines).

    Words are collected with normalized geometry and confidence, then grouped
    into rows two ways: DocTR's native line segmentation and a custom
    y-proximity grouping over header-filtered blocks. Whichever grouping
    yields more lines wins. Debug artifacts (ocr_result.json, ocr_lines.txt)
    are written to ``output_dir``.
    """
    log.info("Running DocTR OCR with geometry-based line grouping…")
    model = get_doctr_model()
    doc = DocumentFile.from_images(image_path)
    result = model(doc)
    all_blocks: List[Dict] = []
    pages = []
    for page_idx, page in enumerate(result.pages):
        page_blocks = []
        line_strings = []
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    page_blocks.append({
                        "text": word.value,
                        "confidence": float(word.confidence),
                        "geometry": word.geometry,  # [[x1,y1], [x2,y2]] normalized 0..1
                    })
                line_text = " ".join([w.value for w in line.words]).strip()
                if line_text:
                    line_strings.append(line_text)
        pages.append({"page_number": page_idx + 1, "blocks": page_blocks, "lines": line_strings})
        all_blocks.extend(page_blocks)
    confs = [b["confidence"] for b in all_blocks if "confidence" in b]
    avg_conf = float(np.mean(confs)) if confs else 0.0
    ocr_json = {"pages": pages, "average_confidence": avg_conf}
    # Clean + group: pick whichever grouping recovers more rows.
    cleaned_blocks = clean_blocks(all_blocks)
    y_lines = group_by_y(cleaned_blocks, y_threshold=0.01)
    # Flat comprehension instead of sum(..., []), which re-copies the
    # accumulator list for every page (quadratic on multi-page docs).
    doctr_lines = [ln for p in pages for ln in p["lines"]]
    chosen_lines = y_lines if len(y_lines) >= len(doctr_lines) else doctr_lines
    formatted_lines = [f"{i+1}. {ln}" for i, ln in enumerate(chosen_lines)]
    # Save debug artifacts for offline inspection.
    ocr_result_path = os.path.join(output_dir, "ocr_result.json")
    with open(ocr_result_path, "w", encoding="utf-8") as f:
        json.dump(ocr_json, f, indent=2, ensure_ascii=False)
    lines_path = os.path.join(output_dir, "ocr_lines.txt")
    with open(lines_path, "w", encoding="utf-8") as f:
        f.write("\n".join(formatted_lines))
    log.info(f"DocTR complete (confidence: {avg_conf:.2f}; lines: {len(formatted_lines)})")
    return "\n".join(chosen_lines), ocr_json, formatted_lines
| # ============================================================ | |
| # 4) AI POST-PROCESSING (GPT-4o-mini Vision by default) | |
| # ============================================================ | |
def extract_structured_data(
    client: OpenAI,
    formatted_lines: List[str],
    preview_path: str,
    model_name: str = "gpt-4o-mini"
) -> Dict:
    """
    Use GPT Vision to parse structured JSON from numbered, grouped lines + image.

    Sends the header-filtered OCR lines plus the base64-encoded preview image
    in one chat-completion call and parses the model's JSON reply. On a JSON
    parse failure returns {"error": "json_parse_error", ...} with low
    confidence instead of raising. Token usage, when reported by the API, is
    attached to the result under "usage".
    """
    log.info(f"Processing with {model_name} …")
    # Inline the preview image as base64 so it fits in a data: URL below.
    with open(preview_path, "rb") as img_file:
        img_b64 = base64.b64encode(img_file.read()).decode("utf-8")

    def is_header(line: str) -> bool:
        # True when the OCR line contains any known table-header word.
        low = line.lower()
        return any(k in low for k in HEADER_KEYWORDS)

    # Header rows carry no item data; dropping them shortens the prompt.
    filtered_lines = [ln for ln in formatted_lines if not is_header(ln)]
    system_message = """
You are a professional invoice/receipt parser for ChefCode.
You receive:
(1) Numbered OCR lines (already grouped horizontally by row).
(2) The invoice image.
Return ONLY valid JSON with this schema:
{
"supplier": "string",
"invoice_number": "string",
"date": "YYYY-MM-DD or null",
"line_items": [
{
"lot_number": "string",
"item_name": "string",
"unit": "string",
"quantity": number,
"unit_price": number or null,
"line_total": number or null,
"type": "string"
}
],
"total_amount": number or null,
"confidence": "high | medium | low"
}
Extraction rules (critical):
- The table is horizontal: Lot → Item → Unit → Quantity → Unit Price → Line Total.
- The quantity is the number DIRECTLY AFTER the unit.
- If numbers for a line appear missing, check up to TWO lines BELOW that line in OCR_LINES,
ignoring header words (Quantità, Prezzo, Sconto, Importo, IVA).
- Do not skip any visible row; compare OCR row count with extracted items and recover missing lines.
- Verify math: quantity × unit_price ≈ line_total (±3%). If off, re-read digits from the image.
- If two adjacent rows share identical numbers, re-check both in the image; do not merge distinct items.
- Use "." as decimal separator and strip any currency symbols.
- Keep supplier and item names exactly as printed; do not translate them.
- Infer "type" (meat/vegetable/dairy/grain/condiment/beverage/grocery). If invoice language is Italian,
output these category words in Italian (carne, verdura, latticini, cereali, condimento, bevanda, drogheria).
- Output ONLY JSON — no prose, no markdown.
""".strip()
    user_message = f"""Extract structured data from this invoice.
OCR_LINES (count={len(filtered_lines)}):
{chr(10).join(filtered_lines)}
"""
    resp = client.chat.completions.create(
        model=model_name,
        temperature=0.1,
        max_completion_tokens=2000,
        messages=[
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_message},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "high"},
                    },
                ],
            },
        ],
    )
    # Capture real token usage directly from the API response (best-effort;
    # never let accounting problems break the extraction itself).
    usage = None
    try:
        if hasattr(resp, "usage") and resp.usage:
            usage = {
                "prompt_tokens": resp.usage.prompt_tokens,
                "completion_tokens": resp.usage.completion_tokens,
                "total_tokens": resp.usage.total_tokens,
            }
            print(f"🔢 Token usage: {usage}")
        else:
            print("⚠️ No token usage info found in response.")
    except Exception as e:
        print(f"⚠️ Couldn't read token usage: {e}")
    raw = resp.choices[0].message.content.strip()
    # Strip Markdown code fences the model sometimes wraps around the JSON.
    if raw.startswith("```json"):
        raw = raw.replace("```json", "").replace("```", "").strip()
    elif raw.startswith("```"):
        raw = raw.replace("```", "").strip()
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as e:
        log.error(f"JSON parse error: {e}")
        return {"error": "json_parse_error", "raw_response": raw, "confidence": "low"}
    log.info("GPT response parsed")
    # Attach token usage so it surfaces in smart_output.json downstream.
    if usage:
        data["usage"] = usage
    return data
| # ============================================================ | |
| # 5) VALIDATION & AUTO-CORRECTION | |
| # ============================================================ | |
| def _coerce_number(x): | |
| if x is None: | |
| return None | |
| if isinstance(x, (int, float)): | |
| return float(x) | |
| try: | |
| s = str(x).replace("€", "").replace("EUR", "").replace(",", ".").strip() | |
| return float(s) | |
| except Exception: | |
| return None | |
def detect_invoice_language(structured: Dict) -> str:
    """Heuristically classify the invoice as Italian ("it") or English ("en").

    Counts Italian indicator tokens in the supplier name plus the first three
    item names; two or more hits mean Italian. Guards against None values in
    "supplier"/"item_name"/"line_items" (the model may emit nulls), which
    previously raised AttributeError.
    """
    supplier = (structured.get("supplier") or "").lower()
    items = structured.get("line_items") or []
    italian_indicators = ["srl", "spa", "via", "roma", "milano", "kg", "lt"]
    text_to_check = supplier + " " + " ".join(
        (it.get("item_name") or "").lower() for it in items[:3]
    )
    italian_count = sum(1 for word in italian_indicators if word in text_to_check)
    return "it" if italian_count >= 2 else "en"
def normalize_item_types(structured: Dict) -> Dict:
    """Translate English item-type labels into Italian for Italian invoices.

    No-op when the invoice is detected as English. Mutates ``structured``
    in place and returns it; unknown type labels are left untouched.
    """
    if detect_invoice_language(structured) != "it":
        return structured
    type_mapping = {
        "grain": "cereali",
        "meat": "carne",
        "fish": "pesce",
        "vegetable": "verdura",
        "fruit": "frutta",
        "dairy": "latticini",
        "condiment": "condimento",
        "beverage": "bevanda",
        "grocery": "alimentari",
        "other": "altro",
    }
    for item in structured.get("line_items", []):
        key = (item.get("type") or "").lower()
        translated = type_mapping.get(key)
        if translated is not None:
            item["type"] = translated
    return structured
def reconcile_and_validate(structured: Dict, ocr_json: Dict) -> Dict:
    """Validate and auto-correct GPT output against simple arithmetic.

    Per item: coerce quantity/unit_price/line_total to floats, treat 0 as
    "not read", and when both quantity and price are known force line_total
    to q*p (recording how far off the model's value was). Then normalize
    item types, reconcile the header total against the summed line totals,
    flag possible skipped rows, and roll the collected notes up into an
    overall confidence. Mutates and returns ``structured``.
    """
    notes = []
    items = structured.get("line_items", []) or []
    fixed_items = []
    for it in items:
        q = _coerce_number(it.get("quantity"))
        p = _coerce_number(it.get("unit_price"))
        t = _coerce_number(it.get("line_total"))
        # Zero means the digit wasn't read, not a real value.
        if q == 0: q = None
        if p == 0: p = None
        if t == 0: t = None
        if q is not None and p is not None:
            calc = round(q * p, 2)
            # NOTE(review): every branch below ends in t = calc — they differ
            # only in which note (if any) is recorded about the correction.
            if t is not None and t > 0 and abs(calc - t) > 0.1 * (t if t else 1):
                notes.append(
                    f"⚠️ Large mismatch (>10%) for '{it.get('item_name','')}': q={q}, p={p}, expected={calc}, got={t}. Auto-correcting to {calc}."
                )
                t = calc
            elif t is None or abs(calc - t) <= 0.05:
                # Missing or essentially exact: silently fill in q*p.
                t = calc
            elif abs(calc - t) <= 0.15:
                notes.append(f"✓ Corrected line_total from {t} to {calc} for '{it.get('item_name','')}'.")
                t = calc
            else:
                notes.append(f"⚠️ Line math mismatch for '{it.get('item_name','')}' (q*p={calc}, got {t}). Corrected.")
                t = calc
        it["quantity"] = q
        it["unit_price"] = p
        it["line_total"] = t
        fixed_items.append(it)
    structured["line_items"] = fixed_items
    structured = normalize_item_types(structured)
    # Reconcile header total against the sum of corrected line totals.
    line_sum = round(sum(it.get("line_total") or 0 for it in fixed_items), 2)
    ta = _coerce_number(structured.get("total_amount"))
    if ta is None:
        structured["total_amount"] = line_sum
        notes.append(f"Set total_amount from sum(line_totals) = {line_sum}.")
    else:
        if ta > 0:
            diff_percent = abs(line_sum - ta) / ta * 100
            if diff_percent <= 1.0:
                # Close enough: trust the (corrected) line sum over the header.
                notes.append(f"✓ Total validated: sum={line_sum}, header={ta}, diff={diff_percent:.2f}%")
                structured["total_amount"] = line_sum
            elif diff_percent <= 5.0:
                notes.append(f"⚠️ Total mismatch (±{diff_percent:.2f}%): sum={line_sum}, header={ta}")
                structured["confidence"] = "medium"
            else:
                notes.append(f"❌ Large total mismatch ({diff_percent:.2f}%): sum={line_sum}, header={ta}")
                structured["confidence"] = "low"
        else:
            # Non-positive header total is unusable — fall back to the sum.
            structured["total_amount"] = line_sum
            notes.append(f"Set total_amount from sum(line_totals) = {line_sum}.")
    # Heuristic skip check: far fewer items than OCR lines suggests missed rows.
    ocr_line_count = sum(len(p["lines"]) for p in ocr_json.get("pages", []))
    if len(fixed_items) < max(3, int(0.5 * ocr_line_count)):
        notes.append(f"Only {len(fixed_items)}/{ocr_line_count} OCR lines became items; possible skips.")
    # Confidence rollup: any ❌ note ⇒ low; otherwise any ⚠️ ⇒ at most medium.
    if any("❌" in n for n in notes):
        structured["confidence"] = "low"
    elif any("⚠️" in n for n in notes):
        if structured.get("confidence") != "low":
            structured["confidence"] = "medium"
    elif not any("mismatch" in n.lower() for n in notes):
        structured["confidence"] = structured.get("confidence", "high")
    if notes:
        # Append to any pre-existing notes rather than overwriting them.
        existing = structured.get("validation_notes")
        structured["validation_notes"] = ("; ".join(notes) if not existing else (existing + "; " + "; ".join(notes)))
    return structured
| # ============================================================ | |
| # 5B) LIGHTWEIGHT FALLBACK | |
| # ============================================================ | |
def extract_structured_data_lightweight(
    client: OpenAI, filtered_lines: List[str], preview_path: str, model_name: str = "gpt-4o-mini"
) -> Dict:
    """Cheaper second-pass extraction focused on numeric accuracy.

    Used when the first pass shows a large total mismatch; trades item-name
    quality for row/number completeness. Returns the parsed dict, or an
    {"error": ...} dict with "confidence": "low" on any failure mode
    (no choices, empty reply, unparseable JSON).
    """
    log.info("Re-running with lightweight prompt (numeric focus)…")
    # Inline the preview image as base64 for the data: URL below.
    with open(preview_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    system_message = """You are a precise invoice data extractor.
FOCUS: Extract ONLY the numeric columns accurately. Do not worry about perfect item names.
Return valid JSON with this schema:
{
"supplier": "string",
"invoice_number": "string",
"date": "string",
"line_items": [
{
"lot_number": "string",
"item_name": "string",
"unit": "string",
"quantity": number,
"unit_price": number,
"line_total": number,
"type": "string"
}
],
"total_amount": number,
"confidence": "high|medium|low"
}
CRITICAL RULES:
1. For each line, extract: quantity, unit_price, line_total in that exact order
2. Verify: quantity × unit_price ≈ line_total (±2%)
3. Count ALL visible rows in the table - don't skip any
4. Sum all line_totals and verify against invoice total
5. If a row has numbers, include it - better to have all rows than miss some
Return ONLY valid JSON, no markdown."""
    user_message = f"""Extract ALL line items from this invoice. Focus on getting every row with numbers.
OCR_LINES (count={len(filtered_lines)}):
{chr(10).join(filtered_lines)}
Extract EVERY line item visible in the table."""
    resp = client.chat.completions.create(
        model=model_name,
        max_completion_tokens=3000,
        messages=[
            {"role": "system", "content": system_message},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_message},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "high"}},
                ],
            },
        ],
    )
    if not resp.choices:
        log.error("No choices in response")
        return {"error": "no_choices", "confidence": "low"}
    choice = resp.choices[0]
    raw = (choice.message.content or "").strip()
    if not raw:
        # Surface finish_reason (e.g. length/content_filter) for debugging.
        log.error(f"Empty response from GPT (finish_reason={choice.finish_reason})")
        return {"error": "empty_response", "finish_reason": choice.finish_reason, "confidence": "low"}
    # Strip Markdown code fences the model sometimes wraps around the JSON.
    if raw.startswith("```json"):
        raw = raw.replace("```json", "").replace("```", "").strip()
    elif raw.startswith("```"):
        raw = raw.replace("```", "").strip()
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as e:
        log.error(f"JSON parse error: {e}")
        # Truncate the raw payload so error output stays bounded.
        return {"error": "json_parse_error", "raw_response": raw[:500], "confidence": "low"}
    log.info(f"Lightweight extraction: {len(data.get('line_items', []))} items")
    return data
def should_rerun_lightweight(structured: Dict) -> bool:
    """Decide whether the cheap numeric-focused re-extraction should run.

    Triggers only when line items exist, a usable header total is present,
    and the summed line totals differ from it by more than 30%.
    """
    items = structured.get("line_items", [])
    if not items:
        return False
    line_sum = sum(_coerce_number(it.get("line_total")) or 0 for it in items)
    header_total = _coerce_number(structured.get("total_amount"))
    if not header_total:
        return False
    diff_percent = abs(line_sum - header_total) / header_total * 100
    if diff_percent <= 30:
        return False
    log.warning(f"Large total mismatch: {diff_percent:.1f}% (line_sum={line_sum}, header={header_total})")
    return True
| # ============================================================ | |
| # 6) OPTIONAL FALLBACK OCR (Tesseract / EasyOCR) | |
| # ============================================================ | |
def fallback_ocr_plain(image_path: str, output_dir: str) -> Tuple[str, Dict, List[str]]:
    """
    Fallback if DocTR throws: try pytesseract or EasyOCR.
    Returns raw text, json (minimal), and naive line list.
    """
    # Attempt 1: Tesseract. Imported lazily; any failure (missing package,
    # unreadable image, missing binary) silently moves on to EasyOCR.
    try:
        import pytesseract
        log.info("Running Tesseract OCR (fallback)…")
        text = pytesseract.image_to_string(cv2.imread(image_path)) or ""
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        minimal_json = {
            "pages": [{"page_number": 1, "blocks": [], "lines": lines}],
            "average_confidence": 0.7,
            "engine": "tesseract_fallback",
        }
        return text, minimal_json, [f"{i+1}. {ln}" for i, ln in enumerate(lines)]
    except Exception:
        pass
    # Attempt 2: EasyOCR (Italian + English, CPU only).
    try:
        import easyocr
        log.info("Running EasyOCR (fallback)…")
        reader = easyocr.Reader(["it", "en"], gpu=False)
        detections = reader.readtext(image_path, detail=1, paragraph=False)
        lines = [d[1] for d in detections if len(d) >= 2 and d[1].strip()]
        minimal_json = {
            "pages": [{"page_number": 1, "blocks": [], "lines": lines}],
            "average_confidence": 0.75,
            "engine": "easyocr_fallback",
        }
        return "\n".join(lines), minimal_json, [f"{i+1}. {ln}" for i, ln in enumerate(lines)]
    except Exception as e:
        log.error(f"All OCR fallbacks failed: {e}")
        return "", {"pages": [], "average_confidence": 0.0, "engine": "none"}, []
| # ============================================================ | |
| # 7) MAIN PIPELINE | |
| # ============================================================ | |
def main(invoice_path: str, output_dir: str = "."):
    """Run the full pipeline on one invoice image and write smart_output.json.

    Steps: preprocess → OCR (DocTR with plain-OCR fallback) → GPT structured
    extraction → validation/auto-correction → optional lightweight re-run on
    large total mismatches. Returns the final output dict.
    """
    print("\n" + "="*60)
    print("🧠 SMART OCR PIPELINE (final, gpt-4o-mini by default)")
    print("="*60 + "\n")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # 1) Setup (fails fast if OPENAI_API_KEY is missing)
    client = setup_environment()
    # 2) Preprocess — elapsed_sec is measured from here, after client setup
    t0 = time.time()
    processed_path, preview_path = preprocess_image(invoice_path, output_dir)
    # 3) OCR — DocTR first; any failure falls back to Tesseract/EasyOCR
    try:
        ocr_text, ocr_json, formatted_lines = extract_text_with_doctr(processed_path, output_dir)
    except Exception as e:
        log.error(f"DocTR OCR failed: {e}")
        ocr_text, ocr_json, formatted_lines = fallback_ocr_plain(processed_path, output_dir)
    # 4) AI post-processing
    structured = extract_structured_data(client, formatted_lines, preview_path, model_name="gpt-4o-mini")
    # 5) Validation & save
    structured = reconcile_and_validate(structured, ocr_json)
    # 6) Lightweight fallback rerun if needed — keep whichever pass found more items
    if should_rerun_lightweight(structured):
        log.info("Triggering lightweight fallback extraction…")
        structured_retry = extract_structured_data_lightweight(client, formatted_lines, preview_path, model_name="gpt-4o-mini")
        retry_items = len(structured_retry.get("line_items", []))
        original_items = len(structured.get("line_items", []))
        if retry_items > original_items:
            log.info(f"Using lightweight result: {retry_items} items vs {original_items} items")
            # Retry output goes through the same validation pass before use.
            structured = reconcile_and_validate(structured_retry, ocr_json)
            structured["rerun_applied"] = "lightweight_fallback"
        else:
            log.info(f"Keeping original result: {original_items} items vs {retry_items} items")
            structured["rerun_attempted"] = "lightweight_fallback_not_better"
    final_output = {
        "status": "success",
        "pipeline_version": "3.1_final_gpt4o-mini",
        "input_file": Path(invoice_path).name,
        "ocr_confidence": ocr_json.get("average_confidence", 0.0),
        "lines_detected": sum(len(p["lines"]) for p in ocr_json.get("pages", [])) if ocr_json.get("pages") else 0,
        "data": structured,
        "elapsed_sec": round(time.time() - t0, 2),
        "usage": structured.get("usage", None),
    }
    out_path = os.path.join(output_dir, "smart_output.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(final_output, f, indent=2, ensure_ascii=False)
    log.info(f"Final output saved: {out_path}")
    log.info(f" • OCR Confidence: {final_output['ocr_confidence']:.2f}")
    log.info(f" • Items parsed: {len(structured.get('line_items', []))}")
    log.info(f" • Total: {structured.get('total_amount')}")
    log.info(f" • Elapsed: {final_output['elapsed_sec']}s")
    print("\nDone.\n")
    return final_output
| if __name__ == "__main__": | |
| if len(sys.argv) < 2: | |
| print("Usage: python smart_ocr_pipeline_final.py <path/to/invoice.jpg> [output_dir]") | |
| sys.exit(1) | |
| invoice_path = sys.argv[1] | |
| output_dir = sys.argv[2] if len(sys.argv) > 2 else "." | |
| main(invoice_path, output_dir) | |