Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import pytesseract | |
| from pdf2image import convert_from_path | |
| from PIL import Image, ImageEnhance, ImageFilter | |
| import os | |
| import json | |
| import re | |
| import config | |
# Load Model
# The tokenizer and model are loaded once at import time and shared by the
# extraction code below. Both names are pre-bound to None so they always
# exist, even when loading fails part-way through.
print(f">>> Loading AI Model: {config.MODEL_ID}...")
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(config.MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        config.MODEL_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
except Exception as load_error:
    # Best-effort: a missing model must not prevent the module from
    # importing. `Exception` (rather than a bare except) lets Ctrl-C and
    # SystemExit propagate, and the reason is reported instead of discarded.
    model = None
    print("❌ Model Failed to Load")
    print(f"    Reason: {load_error}")
| # ===================================================== | |
| # 1. ADVANCED OCR PIPELINE | |
| # ===================================================== | |
def preprocess_image(image):
    """
    Prepare an image for OCR.

    Steps, in the order they are applied:
    1. Convert to grayscale ('L' mode)
    2. Boost contrast by a factor of 2.0
    3. Sharpen (helps with blurry fonts)

    Returns the image produced by the final sharpening step.
    """
    grayscale = image.convert('L')
    boosted = ImageEnhance.Contrast(grayscale).enhance(2.0)
    return boosted.filter(ImageFilter.SHARPEN)
def perform_ocr(file_obj):
    """
    Run OCR on the first page of a PDF, or on an image file.

    Args:
        file_obj: Path to the document, or None. The value is handed to
            os.path.basename / os.path.getsize, so it must be path-like.

    Returns:
        A 3-tuple ``(text, original_img, meta)``:
          * text: the string returned by Tesseract for the preprocessed image
          * original_img: the unprocessed page/image
          * meta: ``{"filename": ..., "size_kb": ...}``
        When ``file_obj`` is None, or anything fails along the way, the
        result is ``("", None, {})`` and the error is printed.
    """
    if file_obj is None:
        return "", None, {}

    try:
        filename = os.path.basename(file_obj)

        if filename.lower().endswith(".pdf"):
            # Only the first page is rendered, at 300 DPI for sharper glyphs.
            pages = convert_from_path(file_obj, first_page=1, last_page=1, dpi=300)
            if not pages:
                # Report a clear reason instead of an opaque IndexError.
                raise ValueError("PDF conversion produced no pages")
            original_img = pages[0]
        else:
            # convert() returns an independent in-memory image, so the file
            # handle can be released as soon as the conversion is done.
            with Image.open(file_obj) as opened:
                original_img = opened.convert("RGB")

        # Preprocess for Tesseract, then run it
        processed_img = preprocess_image(original_img)
        text = pytesseract.image_to_string(processed_img)

        # Metadata extraction
        meta = {
            "filename": filename,
            "size_kb": os.path.getsize(file_obj) / 1024,
        }
        return text, original_img, meta
    except Exception as e:
        # Boundary handler: OCR is best-effort, callers get empty results.
        print(f"OCR Error: {e}")
        return "", None, {}
| # ===================================================== | |
| # 2. REGEX FALLBACKS (The "Generic Name" Fix) | |
| # ===================================================== | |
def regex_extract_vendor(text):
    """
    Heuristic name lookup used when the AI output has no usable vendor.

    Lines are stripped and anything of 3 characters or fewer is discarded.
    The result is, in priority order:
      1. the line following a stand-alone "bill to" / "invoice to" /
         "from" / "vendor" label (optionally ending in a colon);
      2. the first remaining line, unless it contains "invoice", in which
         case the second remaining line;
      3. the string "Unknown".

    NOTE(review): "Bill To" normally labels the recipient rather than the
    sender -- confirm that returning that line as the vendor is intended.
    """
    label_patterns = (
        re.compile(r'^(bill|invoice)\s*to:?$'),
        re.compile(r'^(from|vendor):?$'),
    )

    candidates = []
    for raw_line in text.split('\n'):
        stripped = raw_line.strip()
        if len(stripped) > 3:
            candidates.append(stripped)

    # 1. A label line is only useful when another line follows it, so walk
    #    (line, next_line) pairs; the final line can never yield a result.
    for current, following in zip(candidates, candidates[1:]):
        lowered = current.lower()
        if any(pattern.search(lowered) for pattern in label_patterns):
            return following

    # 2. Otherwise assume the company name sits at the very top.
    if not candidates:
        return "Unknown"
    if "invoice" not in candidates[0].lower():
        return candidates[0]
    return candidates[1] if len(candidates) > 1 else "Unknown"
def regex_extract_total(text):
    """
    Find a monetary total such as "Total $1,234.56" in OCR text.

    The search is case-insensitive and looks for one of the keywords
    total / amount / balance followed, on the same line, by a number with
    exactly two decimals (thousands separators allowed).

    A keyword that stands on its own is preferred, so that a "Subtotal" or
    "Sub Total" line appearing earlier in the document does not shadow the
    real total. If no stand-alone keyword is found, the keyword is accepted
    anywhere (e.g. inside "subtotal") as a last resort.

    Returns:
        The amount as a float, or 0.0 when nothing matches or ``text`` is
        empty/None.
    """
    if not text:
        return 0.0

    lowered = text.lower()
    keywords = r'(?:total|amount|balance)'
    amount = r'.*?([\d,]+\.\d{2})'

    # First pass: keyword must start a word and must not be the second half
    # of "sub total" / "sub-total".
    match = re.search(r'(?<!sub[ \-])\b' + keywords + amount, lowered)
    if match is None:
        # Second pass: keyword anywhere, including inside "subtotal".
        match = re.search(keywords + amount, lowered)
    if match is None:
        return 0.0

    try:
        return float(match.group(1).replace(',', ''))
    except ValueError:
        return 0.0
| # ===================================================== | |
| # 3. AI EXTRACTION | |
| # ===================================================== | |
def repair_json(json_str):
    """
    Pull a JSON object out of text that may have noise around it.

    Everything from the first '{' to the last '}' (inclusive) is handed to
    ``json.loads``. This tolerates leading/trailing chatter such as code
    fences, but not multiple separate objects in the same string.

    Args:
        json_str: Raw text, typically model output.

    Returns:
        The parsed dict, or ``{}`` when the input is empty, not a string,
        contains no brace pair, or the span is not valid JSON.
    """
    if not isinstance(json_str, str) or not json_str:
        return {}

    # Find the first { and the last }
    start = json_str.find('{')
    end = json_str.rfind('}') + 1
    if start == -1 or end == 0:
        return {}

    try:
        return json.loads(json_str[start:end])
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError covers
        # pathologically nested input. Both mean "unusable output".
        return {}
def extract_intelligent_json(text, metadata):
    """
    Turn OCR text into structured invoice JSON using the language model,
    with regex fallbacks for the vendor name and total.

    Args:
        text: OCR text of the document. Only the first 2000 characters are
            placed in the prompt; the regex fallbacks see the full text.
        metadata: Accepted for interface compatibility; not used here.

    Returns:
        ``{}`` when no model is loaded. Otherwise a dict with at least
        ``"doc_type"`` and ``"data"``, where ``data["vendor_name"]`` and
        ``data["total"]`` are filled from the regex helpers whenever the
        model left them empty (or returned "Unknown" for the vendor).
    """
    if model is None:
        return {}

    # Stronger Prompt
    prompt = f"""<|im_start|>system
You are a financial data extractor.
TASK: Convert OCR text into JSON.
MANDATORY RULES:
1. Extract the VENDOR_NAME (Who sent the invoice?)
2. Extract the DOCUMENT_TYPE: ["invoice", "bill", "expense", "estimate"]
3. Extract LINE_ITEMS.
JSON FORMAT:
{{
"doc_type": "invoice",
"data": {{
"vendor_name": "Acme Corp",
"date": "2024-01-01",
"reference_number": "INV-001",
"total": 100.00,
"line_items": [ {{"name": "Service", "description": "...", "rate": 100, "quantity": 1}} ]
}}
}}
<|im_end|>
<|im_start|>user
DOCUMENT TEXT:
{text[:2000]}
<|im_end|>
<|im_start|>assistant
```json
"""
    inputs = tokenizer(prompt, return_tensors="pt")
    # NOTE(review): temperature only has an effect when sampling is enabled;
    # confirm the model's generation config turns on do_sample.
    out = model.generate(**inputs, max_new_tokens=500, temperature=0.1)

    # For a causal LM the returned sequence starts with the prompt tokens.
    # The prompt itself contains an example JSON object, so decoding it too
    # would make repair_json span from the *example's* '{' to the reply's
    # last '}' -- yielding either invalid JSON or the "Acme Corp" example.
    # Decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    raw_output = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    data = repair_json(raw_output)

    # --- FALLBACK LAYER ---
    # If AI returned empty/garbage data, overlay with Regex
    if not isinstance(data, dict) or "data" not in data:
        data = {"doc_type": "invoice", "data": {}}
    inner = data["data"]
    if not isinstance(inner, dict):
        # The model may emit a string/list/null here; start from scratch
        # rather than crash on .get().
        inner = {}

    # Fix Name
    vendor = inner.get("vendor_name")
    if not vendor or vendor == "Unknown":
        inner["vendor_name"] = regex_extract_vendor(text)

    # Fix Total
    if not inner.get("total"):
        inner["total"] = regex_extract_total(text)

    data["data"] = inner
    return data