# Hugging Face Spaces app (the scraped page header read "Spaces: Runtime error")
import io
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from paddleocr import PaddleOCR  # fallback OCR engine
from pdf2image import convert_from_path
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel  # Donut
# -----------------------------
# Global model initialization
# (runs once at import time; downloads/loads weights into process memory)
# -----------------------------
# Model id is overridable via env var for easy swapping on Spaces.
DONUT_MODEL_ID = os.getenv(
    "DONUT_MODEL_ID",
    "nielsr/donut-docvqa-demo",  # good general DocVQA Donut model
)
device = "cpu"  # HF Spaces CPU basic
processor = DonutProcessor.from_pretrained(DONUT_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_MODEL_ID).to(device)
model.eval()  # inference mode (disables dropout etc.)
# PaddleOCR as fallback OCR engine (English); angle classifier enabled.
ocr_engine = PaddleOCR(use_angle_cls=True, lang="en")
| # ----------------------------- | |
| # File / image helpers | |
| # ----------------------------- | |
def load_first_page_as_image(filepath: str) -> Image.Image:
    """Open *filepath* as an RGB PIL image; for PDFs, rasterize page 1 only."""
    extension = os.path.splitext(filepath)[1].lower()
    if extension != ".pdf":
        return Image.open(filepath).convert("RGB")
    # Rasterize the first PDF page at 200 dpi.
    first_page = convert_from_path(filepath, dpi=200)[0]
    return first_page.convert("RGB")
| # ----------------------------- | |
| # Donut helpers | |
| # ----------------------------- | |
def run_donut(image: Image.Image) -> Tuple[Optional[Dict[str, Any]], str]:
    """
    Run Donut on an image and try to recover structured output.

    Returns:
        (parsed_dict_or_none, cleaned_sequence_text)

    NOTE(review): DocVQA-style Donut checkpoints normally expect a task
    prompt (e.g. "<s_docvqa><s_question>...") passed as decoder_input_ids;
    generating without one may yield degenerate output — confirm against
    the checkpoint actually deployed.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=512,
            num_beams=3,
            early_stopping=True,
        )
    sequence = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    # Strip EOS/PAD tokens. Guard against a token being None
    # (tokenizer.pad_token may be unset), which made str.replace raise.
    for tok in (processor.tokenizer.eos_token, processor.tokenizer.pad_token):
        if tok:
            sequence = sequence.replace(tok, "")
    seq = sequence.strip()

    json_obj: Optional[Dict[str, Any]] = None

    # Donut emits tag-structured output (<s_key>value</s_key>);
    # processor.token2json is the canonical converter for it.
    try:
        candidate = processor.token2json(seq)
        if isinstance(candidate, dict) and candidate:
            json_obj = candidate
    except Exception:
        json_obj = None

    # Fallback: extract brace-delimited JSON if the model emitted raw JSON.
    if json_obj is None:
        start = seq.find("{")
        end = seq.rfind("}")
        if start != -1 and end != -1 and end > start:
            try:
                json_obj = json.loads(seq[start : end + 1])
            except Exception:
                json_obj = None
    return json_obj, seq
| # ----------------------------- | |
| # PaddleOCR helpers | |
| # ----------------------------- | |
def run_paddle_ocr(image: Image.Image) -> str:
    """
    Run PaddleOCR on the image and concatenate all recognized text lines.

    Returns a newline-joined string ("" when nothing is recognized).
    """
    img_np = np.array(image)
    result = ocr_engine.ocr(img_np, cls=True)
    texts: List[str] = []
    # PaddleOCR returns one entry per page; a page with no detections is
    # returned as None (i.e. result == [None]), which crashed the inner loop.
    for page in result or []:
        if not page:
            continue
        for line in page:
            # line == [box, (text, confidence)]
            texts.append(line[1][0])
    return "\n".join(texts)
| # ----------------------------- | |
| # Parsing helpers | |
| # ----------------------------- | |
def to_int_or_none(value: Optional[str]) -> Optional[int]:
    """Best-effort integer parse: strips every non-digit character first.

    Returns None for None, blank, or digit-free input.
    """
    if value is None:
        return None
    stripped = value.strip()
    if not stripped:
        return None
    digits = re.sub(r"[^\d]", "", stripped)
    try:
        return int(digits)
    except Exception:
        return None
def find_regex(pattern: str, text: str, group: int = 1) -> Optional[str]:
    """Case-insensitive, multiline regex search; return the stripped group or None."""
    match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
    return match.group(group).strip() if match else None
def parse_dimensions(dim_str: str) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """
    Parse 'AxBxC' dimension strings (spaces and either case of 'x' allowed).

    Per spec, '2x6x14' means length=2, height=6, width=14, and the tuple is
    returned as (height, width, length) -> (6, 14, 2).
    Returns (None, None, None) when no dimension pattern is present.
    """
    match = re.search(r"(\d+)\s*[xX]\s*(\d+)\s*[xX]\s*(\d+)", dim_str)
    if match is None:
        return None, None, None
    length, height, width = (int(g) for g in match.groups())
    return height, width, length
def normalize_unit(unit_str: Optional[str]) -> Optional[str]:
    """
    Normalize a raw unit token to the canonical set (PCS/PKG/MBF/MSFT/...).

    Matching is exact-or-prefix against known aliases, checked in declaration
    order; unknown units fall through as the raw uppercased token.
    """
    if not unit_str:
        return None
    token = unit_str.strip().upper()
    # Order matters: earlier aliases win on prefix matches.
    aliases = (
        ("PCS", "PCS"),
        ("PC", "PCS"),
        ("PKG", "PKG"),
        ("PKGS", "PKG"),
        ("PACKAGE", "PKG"),
        ("PACKAGES", "PKG"),
        ("MBF", "MBF"),
        ("MSF", "MSFT"),
        ("MSFT", "MSFT"),
        ("FBM", "FBM"),
        ("SF", "SF"),
        ("SQFT", "SF"),
        ("UNIT", "UNIT"),
        ("UNITS", "UNIT"),
    )
    for alias, canonical in aliases:
        if token == alias or token.startswith(alias):
            return canonical
    return token  # unknown unit: pass through uppercased
def extract_custom_fields(text: str) -> List[str]:
    """
    Pull Mill / Vendor custom fields out of the raw text.

    Returns a list of "Key||Value" strings (empty when neither is present).
    """
    flags = re.IGNORECASE | re.MULTILINE
    fields: List[str] = []
    for label, pattern in (
        ("Mill", r"\bMill[:\-]\s*(.+)"),
        ("Vendor", r"\bVendor[:\-]\s*(.+)"),
    ):
        match = re.search(pattern, text, flags)
        if match:
            value = match.group(1).strip()
            if value:
                fields.append(f"{label}||{value}")
    return fields
def extract_header_fields(full_text: str) -> Dict[str, Any]:
    """
    Pull top-level header fields (PO number, ship-from, carrier, etc.) out of
    the document text. Every field defaults to None when not found.
    """
    flags = re.IGNORECASE | re.MULTILINE

    def grab(pattern: str, group: int = 1) -> Optional[str]:
        # Local equivalent of find_regex, so this block is self-contained.
        match = re.search(pattern, full_text, flags)
        return match.group(group).strip() if match else None

    po_number = grab(r"\bPO(?:\s*#|[:\-])?\s*([A-Z0-9\-]+)")
    ship_from = grab(r"(?:Ship From|Origin)\s*[:\-]\s*(.+)")

    # Carrier type: an explicit "Carrier Type:"/"Mode:" label wins; otherwise
    # fall back to a bare RAIL / TRUCK keyword anywhere in the text.
    labeled = grab(r"\b(Carrier Type|Mode)\s*[:\-]\s*(.+)", group=2)
    if labeled:
        carrier_type: Optional[str] = labeled.upper()
    elif re.search(r"\bRAIL\b", full_text, re.IGNORECASE):
        carrier_type = "RAIL"
    elif re.search(r"\bTRUCK\b", full_text, re.IGNORECASE):
        carrier_type = "TRUCK"
    else:
        carrier_type = None

    origin_carrier = grab(r"(?:Rail Carrier|Carrier)\s*[:\-]\s*([A-Z0-9 &]+)")
    rail_car_num = grab(r"(?:Rail\s*Car|Car\s*No\.?|Railcar)\s*[:\-#]*\s*([A-Z0-9\- ]+)")
    account_name = grab(r"(?:Consignee|Ship To|Customer)\s*[:\-]\s*(.+)")
    # Very rough date capture (mm/dd/yy[yy]-style only).
    date_str = grab(
        r"\b(?:Date|Shipment Date|Ship Date)\s*[:\-]\s*([0-9]{1,2}[\/\-][0-9]{1,2}[\/\-][0-9]{2,4})"
    )

    return {
        "poNumber": po_number,
        "shipFrom": ship_from,
        "carrierType": carrier_type,
        "originCarrier": origin_carrier,
        "railCarNumber": rail_car_num,
        "accountName": account_name,
        "date": date_str,
    }
def extract_line_items(full_text: str) -> List[Dict[str, Any]]:
    """
    Heuristic product-line parser.

    Targets lines shaped like:
        24 2x6x14 SPF #2&BTR KD PKG
    i.e. <qty> <dims> <description> <unit>. Real BOL layouts vary widely,
    so this pattern is expected to need per-customer tuning.
    """
    line_pattern = re.compile(
        r"""^
        (\d+)                                              # quantity (packages)
        \s+
        ([0-9xX\s]+)                                       # dimensions, e.g. 2x6x14
        \s+
        (.+?)                                              # product description
        \s+
        (PCS|PKG|PKGS|MBF|MSF|MSFT|FBM|SF|UNIT|UNITS)\b    # unit token
        """,
        re.IGNORECASE | re.VERBOSE,
    )

    items: List[Dict[str, Any]] = []
    for raw_line in full_text.splitlines():
        candidate = raw_line.strip()
        if not candidate:
            continue
        match = line_pattern.match(candidate)
        if match is None:
            continue

        qty_text, dims_text, description, unit_text = match.groups()
        height, width, length = parse_dimensions(dims_text)
        canonical_unit = normalize_unit(unit_text)

        items.append(
            {
                "quantityShipped": to_int_or_none(qty_text),
                "inventoryUnits": canonical_unit,
                "productName": description.strip(),
                # productCode usually lives elsewhere on the doc; not guessed here.
                "productCode": None,
                "product": {
                    "category": None,  # e.g. Lumber / OSB; could classify from description
                    "unit": canonical_unit,
                    "pcs": None,
                    "mbf": None,
                    "sf": None,
                    "pcsHeight": height,
                    "pcsWidth": width,
                    "pcsLength": length,
                },
                # header-level customFields are attached later in build_schema
                "customFields": [],
            }
        )
    return items
def build_schema(
    full_text: str,
    donut_json: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Assemble the final schema-compliant payload from extracted text.

    Uses regex/heuristic extraction throughout; donut_json is accepted as a
    future hook for mapping a fine-tuned Donut schema directly and is
    currently unused.

    Args:
        full_text: OCR/Donut text of the document.
        donut_json: optional structured Donut output (ignored for now).

    Returns:
        Dict matching the inbound-shipment JSON schema; missing fields are None.
    """
    from collections import Counter  # local import: only needed here

    header = extract_header_fields(full_text)
    line_items = extract_line_items(full_text)

    # totalQuantity: sum of parsed quantities; None when nothing was parsed.
    # (Note: a legitimate all-zero total also collapses to None.)
    total_quantity = sum(
        itm["quantityShipped"] for itm in line_items if isinstance(itm["quantityShipped"], int)
    ) or None

    # totalUnits: most common unit across items. (Was: first item's unit,
    # which contradicted the stated "most common" intent.)
    units = [itm["inventoryUnits"] for itm in line_items if itm["inventoryUnits"]]
    total_units = Counter(units).most_common(1)[0][0] if units else None

    # Header-level custom fields (Mill/Vendor) are copied onto every item.
    header_custom_fields = extract_custom_fields(full_text)
    for itm in line_items:
        itm["customFields"] = header_custom_fields.copy()

    result: Dict[str, Any] = {
        "poNumber": header["poNumber"],
        "shipFrom": header["shipFrom"],
        "carrierType": header["carrierType"],
        "originCarrier": header["originCarrier"],
        "railCarNumber": header["railCarNumber"],
        "totalQuantity": total_quantity,
        "totalUnits": total_units,
        "accountName": header["accountName"],
        "inventories": {
            "items": line_items,
        },
    }
    # NOTE: "Date" was part of the narrative spec but not the final JSON schema.
    # If wanted, add it as a customField or a separate top-level key.
    return result
| # ----------------------------- | |
| # Main prediction function | |
| # ----------------------------- | |
| import torch # after functions to avoid circular issues in spaces | |
def extract_from_document(filepath: str) -> Dict[str, Any]:
    """
    Gradio entry point.

    Pipeline:
        1. Render the document's first page to an image.
        2. Run Donut for structured text.
        3. Fall back to PaddleOCR when Donut yields nothing usable.
        4. Build and return the schema-compliant JSON payload.
    """
    image = load_first_page_as_image(filepath)

    # Step 1: Donut pass.
    donut_json, donut_seq = run_donut(image)

    full_text = ""
    if isinstance(donut_json, dict):
        # Model-dependent: a fine-tuned schema may carry raw text under
        # "text" or "raw_text" — adjust to the actual checkpoint's schema.
        candidate = donut_json.get("text") or donut_json.get("raw_text")
        if isinstance(candidate, str) and candidate.strip():
            full_text = candidate
    if donut_json is not None and not full_text:
        full_text = donut_seq

    # Step 2: PaddleOCR fallback when Donut text is absent or too short.
    if not full_text or len(full_text.strip()) < 10:
        full_text = run_paddle_ocr(image)

    # Step 3: assemble the schema payload.
    final_json = build_schema(full_text=full_text, donut_json=donut_json)

    # Normalize: empty strings become None throughout the payload.
    def _blank_to_none(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: _blank_to_none(val) for key, val in node.items()}
        if isinstance(node, list):
            return [_blank_to_none(val) for val in node]
        if isinstance(node, str) and not node.strip():
            return None
        return node

    return _blank_to_none(final_json)
# -----------------------------
# Gradio UI: single uploaded file in -> extracted JSON out.
# -----------------------------
demo = gr.Interface(
    fn=extract_from_document,
    inputs=gr.File(
        label="Upload PDF or Image (BOL / Shipping Doc)",
        file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff"],
        type="filepath",  # hand the handler a filesystem path, not bytes
    ),
    outputs=gr.JSON(label="Extracted JSON"),
    title="Shipping Document Text Extraction (Donut + PaddleOCR)",
    description=(
        "Upload a shipping document (PDF or image). "
        "The app will run Donut (structured extraction) with PaddleOCR fallback "
        "and return a JSON payload suitable for your inbound shipment form."
    ),
)
# Launch only when run as a script (Spaces also executes this path).
if __name__ == "__main__":
    demo.launch()