mahmoudmosa2877 committed on
Commit
8723ac4
·
verified ·
1 Parent(s): 1b21e80

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +437 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ import re
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+ from pdf2image import convert_from_path
12
+ from transformers import DonutProcessor, VisionEncoderDecoderModel # Donut
13
+ from paddleocr import PaddleOCR # PaddleOCR
14
+
15
+
16
+ # -----------------------------
17
+ # Global model initialization
18
+ # -----------------------------
19
+
20
# Checkpoint to load; can be overridden through the DONUT_MODEL_ID
# environment variable. The default is a general DocVQA Donut model.
DONUT_MODEL_ID = os.environ.get("DONUT_MODEL_ID", "nielsr/donut-docvqa-demo")

# Everything runs on CPU (the app targets the basic CPU tier of HF Spaces).
device = "cpu"

# Loaded once at import time and shared by every request.
processor = DonutProcessor.from_pretrained(DONUT_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_MODEL_ID)
model = model.to(device)
model.eval()  # evaluation mode: the model is only used for generation

# PaddleOCR is the fallback OCR engine (English, angle classifier enabled).
ocr_engine = PaddleOCR(lang="en", use_angle_cls=True)
32
+
33
+
34
+ # -----------------------------
35
+ # File / image helpers
36
+ # -----------------------------
37
+
38
def load_first_page_as_image(filepath: str) -> Image.Image:
    """Load the first page of a document as an RGB PIL image.

    Args:
        filepath: Path to a PDF or to any image format Pillow can open.
            The type is decided from the file extension only.

    Returns:
        The first PDF page (rendered at 200 dpi) or the image itself,
        converted to RGB.

    Raises:
        ValueError: If the PDF produced no pages.
    """
    ext = os.path.splitext(filepath)[1].lower()
    if ext == ".pdf":
        # Only page 1 is ever used, so ask pdf2image to render just that page
        # instead of rasterising the whole document.
        pages = convert_from_path(filepath, dpi=200, first_page=1, last_page=1)
        if not pages:
            raise ValueError(f"No pages could be rendered from PDF: {filepath}")
        return pages[0].convert("RGB")

    # convert() returns a new, fully loaded image, so the underlying file
    # handle can be released as soon as the conversion is done.
    with Image.open(filepath) as img:
        return img.convert("RGB")
47
+
48
+
49
+ # -----------------------------
50
+ # Donut helpers
51
+ # -----------------------------
52
+
53
def run_donut(image: Image.Image) -> Tuple[Optional[Dict[str, Any]], str]:
    """Run the Donut model on an image.

    Args:
        image: The page to read, as a PIL image.

    Returns:
        A ``(parsed_json_or_none, sequence_text)`` tuple. ``sequence_text`` is
        the decoded model output with the EOS/PAD markers removed and
        surrounding whitespace stripped. ``parsed_json_or_none`` is the dict
        obtained by JSON-decoding the span from the first ``{`` to the last
        ``}`` of that text, or ``None`` when there is no such span or it is
        not a valid JSON object.
    """
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # NOTE(review): no decoder prompt is passed to generate(); Donut
    # checkpoints are usually driven by a task prompt - confirm this is
    # intended for the configured model.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=512,
            num_beams=3,
            early_stopping=True,
        )

    sequence = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    # Remove the EOS/PAD markers but keep every other token. A tokenizer may
    # leave either marker undefined (None), which str.replace would reject.
    seq = sequence
    for special in (processor.tokenizer.eos_token, processor.tokenizer.pad_token):
        if special:
            seq = seq.replace(special, "")
    seq = seq.strip()

    # Try to pull a JSON object out of the generated text.
    json_obj: Optional[Dict[str, Any]] = None
    start = seq.find("{")
    end = seq.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(seq[start : end + 1])
        except ValueError:  # includes json.JSONDecodeError
            parsed = None
        if isinstance(parsed, dict):
            json_obj = parsed

    return json_obj, seq
87
+
88
+
89
+ # -----------------------------
90
+ # PaddleOCR helpers
91
+ # -----------------------------
92
+
93
def run_paddle_ocr(image: Image.Image) -> str:
    """Run PaddleOCR on the image and return all recognised text.

    Args:
        image: The page to read, as a PIL image.

    Returns:
        The recognised text lines joined with newlines, in the order the
        engine reported them. An empty string when nothing was recognised.
    """
    img_np = np.array(image)
    result = ocr_engine.ocr(img_np, cls=True)

    texts: List[str] = []
    # The engine may report "nothing found" as None (for the whole result or
    # for an individual page) rather than as an empty list.
    for page in result or []:
        if not page:
            continue
        # Each entry is read as (box, (text, score)); only the text is kept.
        texts.extend(line[1][0] for line in page)
    return "\n".join(texts)
105
+
106
+
107
+ # -----------------------------
108
+ # Parsing helpers
109
+ # -----------------------------
110
+
111
def to_int_or_none(value: Optional[str]) -> Optional[int]:
    """Parse the first integer found in ``value``.

    Thousands separators (``,``) inside the number are ignored. A leading
    ``+``/``-`` is honoured only when it is not glued to a preceding word
    character (so the dash in ``"PO-123"`` is not read as a minus sign).
    Anything after the integer part, such as a decimal fraction or a unit,
    is ignored instead of being merged into the number.

    Args:
        value: Text that may contain a number, or ``None``.

    Returns:
        The parsed integer, or ``None`` when ``value`` is ``None`` or
        contains no digits.
    """
    if value is None:
        return None
    match = re.search(r"(?:(?<!\w)[-+])?\d[\d,]*", value)
    if match is None:
        return None
    return int(match.group(0).replace(",", ""))
121
+
122
+
123
def find_regex(pattern: str, text: str, group: int = 1) -> Optional[str]:
    """Search ``text`` for ``pattern`` and return one captured group.

    The search is case-insensitive and multi-line (``^``/``$`` match at line
    boundaries).

    Args:
        pattern: Regular expression to search for.
        text: Text to search in.
        group: Index of the capture group to return.

    Returns:
        The captured text with surrounding whitespace stripped, or ``None``
        when the pattern does not match or the requested group did not take
        part in the match (e.g. an optional group).
    """
    m = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
    if m is None:
        return None
    captured = m.group(group)
    # An optional group that did not participate yields None, not "".
    return captured.strip() if captured is not None else None
128
+
129
+
130
def parse_dimensions(dim_str: str) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Pull an ``AxBxC`` dimension triple out of ``dim_str``.

    Accepts forms such as '2x6x14' or '2 x 6 X 14'. The three numbers are
    read, in order, as length, height and width, so '2x6x14' means
    length=2, height=6, width=14.

    Returns:
        ``(height, width, length)`` as integers, or ``(None, None, None)``
        when no such triple is present.
    """
    found = re.search(r"(\d+)\s*[xX]\s*(\d+)\s*[xX]\s*(\d+)", dim_str)
    if found is None:
        return None, None, None

    # Order in the text is length, height, width; the return order differs.
    length, height, width = (int(part) for part in found.groups())
    return height, width, length
148
+
149
+
150
def normalize_unit(unit_str: Optional[str]) -> Optional[str]:
    """Map a unit label onto the canonical set (PCS, PKG, MBF, MSFT, ...).

    The label is stripped and upper-cased, then compared against the known
    aliases in order; the first alias the label starts with decides the
    result (an exact match is just the simplest prefix match).

    Returns:
        The canonical unit, the stripped upper-cased label when no alias
        applies, or ``None`` for a missing/empty label.
    """
    if not unit_str:
        return None

    # (alias, canonical) pairs; order matters because the first hit wins.
    aliases = (
        ("PCS", "PCS"),
        ("PC", "PCS"),
        ("PKG", "PKG"),
        ("PKGS", "PKG"),
        ("PACKAGE", "PKG"),
        ("PACKAGES", "PKG"),
        ("MBF", "MBF"),
        ("MSF", "MSFT"),
        ("MSFT", "MSFT"),
        ("FBM", "FBM"),
        ("SF", "SF"),
        ("SQFT", "SF"),
        ("UNIT", "UNIT"),
        ("UNITS", "UNIT"),
    )

    label = unit_str.strip().upper()
    for alias, canonical in aliases:
        if label.startswith(alias):
            return canonical
    return label  # unknown unit: hand back the cleaned-up label
178
+
179
+
180
def extract_custom_fields(text: str) -> List[str]:
    """Collect the Mill and Vendor custom fields found in ``text``.

    Returns:
        A list of "Key||Value" strings, Mill first and Vendor second, with
        an entry only for the labels that were found with a non-empty value.
    """
    labelled_patterns = (
        ("Mill", r"\bMill[:\-]\s*(.+)"),
        ("Vendor", r"\bVendor[:\-]\s*(.+)"),
    )

    collected: List[str] = []
    for label, pattern in labelled_patterns:
        value = find_regex(pattern, text)
        if value:
            collected.append(f"{label}||{value}")
    return collected
196
+
197
+
198
def extract_header_fields(full_text: str) -> Dict[str, Any]:
    """Extract the document-level header fields from OCR text.

    Every lookup is a case-insensitive regular-expression heuristic; a field
    that cannot be found is returned as ``None``.

    Args:
        full_text: The text of the document.

    Returns:
        A dict with the keys ``poNumber``, ``shipFrom``, ``carrierType``,
        ``originCarrier``, ``railCarNumber``, ``accountName`` and ``date``.
    """
    # "PO" must not be the start of a longer word (otherwise "Portland"
    # would yield the PO number "rtland"). An optional "Number"/"Num"/"No."
    # label and any mix of "#", ":" and "-" may precede the value.
    po_number = find_regex(
        r"\bPO(?![A-Z])(?:\s*(?:Number|Num|No)(?![A-Z])\.?)?\s*[#:\-]*\s*"
        r"([A-Z0-9][A-Z0-9\-]*)",
        full_text,
    )
    ship_from = find_regex(r"(?:Ship From|Origin)\s*[:\-]\s*(.+)", full_text)

    # Carrier type (RAIL/TRUCK/etc): prefer an explicit label, otherwise fall
    # back to spotting the literal words anywhere in the text.
    carrier_type = None
    carrier_type_match = find_regex(
        r"\b(Carrier Type|Mode)\s*[:\-]\s*(.+)", full_text, group=2
    )
    if carrier_type_match:
        carrier_type = carrier_type_match.upper()
    elif re.search(r"\bRAIL\b", full_text, re.IGNORECASE):
        carrier_type = "RAIL"
    elif re.search(r"\bTRUCK\b", full_text, re.IGNORECASE):
        carrier_type = "TRUCK"

    origin_carrier = find_regex(
        r"(?:Rail Carrier|Carrier)\s*[:\-]\s*([A-Z0-9 &]+)", full_text
    )

    # The label must end where it appears to end: without the look-ahead,
    # "Rail Carrier: BNSF" would match "Rail Car" and report "rier" as the
    # rail car number. An optional "Number"/"Num"/"No." may follow "Rail Car".
    rail_car_num = find_regex(
        r"(?:Rail\s*Car(?:\s*(?:Number|Num|No)\.?)?|Car\s*(?:Number|Num|No)\.?)"
        r"(?![A-Z])\s*[:\-#]*\s*([A-Z0-9\- ]+)",
        full_text,
    )

    account_name = find_regex(
        r"(?:Consignee|Ship To|Customer)\s*[:\-]\s*(.+)", full_text
    )

    # Date: only numeric d/m/y-style dates after an explicit label (rough).
    date_str = find_regex(
        r"\b(?:Date|Shipment Date|Ship Date)\s*[:\-]\s*([0-9]{1,2}[\/\-][0-9]{1,2}[\/\-][0-9]{2,4})",
        full_text,
    )

    return {
        "poNumber": po_number,
        "shipFrom": ship_from,
        "carrierType": carrier_type,
        "originCarrier": origin_carrier,
        "railCarNumber": rail_car_num,
        "accountName": account_name,
        "date": date_str,
    }
240
+
241
+
242
# One product line: quantity, dimension token, description, unit.
# Compiled once at import time instead of on every call.
_LINE_ITEM_PATTERN = re.compile(
    r"""^
    (\d+)               # quantity (packages)
    \s+
    ([0-9xX/.\s]+)      # dimensions, e.g. 2x6x14 or a thickness like 7/16
    \s+
    (.+?)               # product description
    \s+
    (PCS|PKG|PKGS|MBF|MSF|MSFT|FBM|SF|UNIT|UNITS)\b   # unit
    """,
    re.IGNORECASE | re.VERBOSE,
)


def extract_line_items(full_text: str) -> List[Dict[str, Any]]:
    """Heuristically parse product lines out of the document text.

    A line is recognised when it looks like
    ``<quantity> <dimensions> <description> <unit>``, for example::

        24 2x6x14 SPF #2&BTR KD PKG
        30 7/16 OSB T&G 4x8 MSF

    The dimension token may contain digits, ``x``, ``/`` and ``.`` so that
    fractional thicknesses such as ``7/16`` are accepted; height/width/length
    are only filled in when the token holds a full ``AxBxC`` triple.

    This is a heuristic and will need tuning for real BOL layouts.

    Args:
        full_text: The text of the document, one candidate item per line.

    Returns:
        One dict per recognised line with the keys ``quantityShipped``,
        ``inventoryUnits``, ``productName``, ``productCode``, ``product`` and
        ``customFields`` (left empty here; the caller fills it in).
    """
    items: List[Dict[str, Any]] = []

    for raw_line in full_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        m = _LINE_ITEM_PATTERN.match(line)
        if m is None:
            continue

        qty_str, dims_str, desc, unit_str = m.groups()

        quantity_shipped = to_int_or_none(qty_str)
        height, width, length = parse_dimensions(dims_str)
        inventory_units = normalize_unit(unit_str)

        # pcs / mbf / sf and the category are not guessed; they stay null.
        product_obj: Dict[str, Any] = {
            "category": None,
            "unit": inventory_units,
            "pcs": None,
            "mbf": None,
            "sf": None,
            "pcsHeight": height,
            "pcsWidth": width,
            "pcsLength": length,
        }

        items.append(
            {
                "quantityShipped": quantity_shipped,
                "inventoryUnits": inventory_units,
                "productName": desc.strip(),
                # The product code is usually printed separately; not guessed.
                "productCode": None,
                "product": product_obj,
                "customFields": [],
            }
        )

    return items
309
+
310
+
311
def build_schema(
    full_text: str,
    donut_json: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the final JSON document from the extracted text.

    All values come from the regex/heuristic extractors applied to
    ``full_text``. ``donut_json`` is accepted as a hook for future use and is
    currently ignored.

    Args:
        full_text: The text of the document.
        donut_json: Structured Donut output, if any (unused for now).

    Returns:
        A dict with the header fields, ``totalQuantity`` (sum of the integer
        line quantities, ``None`` when that sum is 0), ``totalUnits`` (the
        most frequent line unit, first seen wins ties, ``None`` without
        units) and ``inventories.items`` (the line items, each carrying its
        own copy of the document-level custom fields).
    """
    # Local import so this function does not depend on the file header.
    from collections import Counter

    header = extract_header_fields(full_text)
    line_items = extract_line_items(full_text)

    quantities = [
        itm["quantityShipped"]
        for itm in line_items
        if isinstance(itm["quantityShipped"], int)
    ]
    total_quantity = sum(quantities) or None

    # Most common unit among the items; Counter keeps first-seen order for
    # equal counts, so a single-unit document behaves as before.
    unit_counts = Counter(
        itm["inventoryUnits"] for itm in line_items if itm["inventoryUnits"]
    )
    total_units = unit_counts.most_common(1)[0][0] if unit_counts else None

    # Document-level custom fields are attached to every item; each item
    # gets its own list so later edits to one item do not leak to others.
    header_custom_fields = extract_custom_fields(full_text)
    for itm in line_items:
        itm["customFields"] = list(header_custom_fields)

    # NOTE: header["date"] is extracted but is not part of this schema.
    return {
        "poNumber": header["poNumber"],
        "shipFrom": header["shipFrom"],
        "carrierType": header["carrierType"],
        "originCarrier": header["originCarrier"],
        "railCarNumber": header["railCarNumber"],
        "totalQuantity": total_quantity,
        "totalUnits": total_units,
        "accountName": header["accountName"],
        "inventories": {
            "items": line_items,
        },
    }
360
+
361
+
362
+ # -----------------------------
363
+ # Main prediction function
364
+ # -----------------------------
365
+
366
+ import torch # after functions to avoid circular issues in spaces
367
+
368
+
369
def _clean_nulls(obj: Any) -> Any:
    """Return ``obj`` with every empty/blank string replaced by ``None``.

    Dicts and lists are rebuilt recursively; other values pass through.
    """
    if isinstance(obj, dict):
        return {k: _clean_nulls(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_clean_nulls(v) for v in obj]
    if isinstance(obj, str) and obj.strip() == "":
        return None
    return obj


def extract_from_document(filepath: str) -> Dict[str, Any]:
    """Extract the shipment JSON from an uploaded document (Gradio handler).

    Steps:
        1. Load the first page as an image.
        2. Run Donut and take its text when it produced usable JSON.
        3. Fall back to PaddleOCR when that text is missing or very short.
        4. Build the schema and replace blank strings with ``None``.

    Args:
        filepath: Path of the uploaded file.

    Returns:
        The schema-shaped dict produced by ``build_schema``.

    Raises:
        gr.Error: If no file was provided.
    """
    # The handler receives an empty value when nothing was uploaded; report
    # that to the user instead of failing inside the path handling.
    if not filepath:
        raise gr.Error("Please upload a PDF or image file first.")

    image = load_first_page_as_image(filepath)

    # 1) Try Donut
    donut_json, donut_seq = run_donut(image)

    # NOTE(review): the raw Donut sequence is used only when Donut returned
    # parseable JSON - confirm this matches the intended nesting.
    full_text = ""
    if donut_json is not None:
        if isinstance(donut_json, dict):
            # Model-dependent: adjust the keys to the fine-tuned schema.
            text_candidate = donut_json.get("text") or donut_json.get("raw_text")
            if isinstance(text_candidate, str) and text_candidate.strip():
                full_text = text_candidate
        if not full_text:
            full_text = donut_seq

    # 2) If Donut gave no usable text (under 10 characters), use PaddleOCR
    if not full_text or len(full_text.strip()) < 10:
        full_text = run_paddle_ocr(image)

    # 3) Build the final JSON and never return "" where null is expected
    final_json = build_schema(full_text=full_text, donut_json=donut_json)
    return _clean_nulls(final_json)
413
+
414
+
415
+ # -----------------------------
416
+ # Gradio UI
417
+ # -----------------------------
418
+
419
# Upload widget: one PDF or image, handed to the handler as a file path.
_document_input = gr.File(
    label="Upload PDF or Image (BOL / Shipping Doc)",
    file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff"],
    type="filepath",
)

# The handler's dict is rendered as JSON.
_json_output = gr.JSON(label="Extracted JSON")

demo = gr.Interface(
    fn=extract_from_document,
    inputs=_document_input,
    outputs=_json_output,
    title="Shipping Document Text Extraction (Donut + PaddleOCR)",
    description=(
        "Upload a shipping document (PDF or image). "
        "The app will run Donut (structured extraction) with PaddleOCR fallback "
        "and return a JSON payload suitable for your inbound shipment form."
    ),
)


# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
gradio>=4.0.0
transformers>=4.38.0
torch>=2.1.0
paddlepaddle==2.5.2
paddleocr==2.7.0.3
# pdf2image needs the poppler utilities installed on the system as well.
pdf2image
pillow
# The pinned paddlepaddle/paddleocr releases target NumPy 1.x; stay below 2.
numpy<2