Seth0330 commited on
Commit
1e0ecd6
·
verified ·
1 Parent(s): ed88753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -175
app.py CHANGED
@@ -1,43 +1,49 @@
1
- import io, os, re, json
 
 
 
 
 
 
2
  from typing import List, Tuple, Dict
3
- import os, sys
4
- sys.path.append(os.path.abspath("pdrt")) # <— add this
5
- import models as pdrt_models # <— from your vendored repo
6
 
7
  import numpy as np
8
  import pandas as pd
9
  from PIL import Image, ImageOps, ImageFilter
10
 
11
  import streamlit as st
12
- import torch
13
- import torchvision.transforms as T
14
 
15
- # --- word detector (Tesseract) ---
16
  import pytesseract
17
  from pytesseract import Output
18
-
19
- # --- PDF -> images ---
20
  from pdf2image import convert_from_bytes
21
 
22
- # ---- import the repo's models ----
23
- # Install via requirements.txt (git+https URL) OR copy repo files into root.
24
- # The repo defines model classes: Swin_CTC, VED
25
- import models as pdrt_models # from dparres/Pretrained-Document-Recognition-Transformers
26
-
27
- st.set_page_config(page_title="Invoice OCR (ViT recognizer + Tesseract detector)", layout="wide")
28
-
29
- # ========================= UI SIDEBAR =========================
30
- st.sidebar.header("Model")
31
- arch = st.sidebar.selectbox("Architecture", ["Swin_CTC", "VED"], index=0)
32
- ckpt_path = st.sidebar.text_input("Checkpoint path (inside Space)", value="checkpoints/pdrt_weights.pth")
33
- alphabet = st.sidebar.text_input("Alphabet (ordered classes, exclude CTC blank)", value="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_/.,:;()[]{}#+*&%$@!?\"' ")
34
- img_h = st.sidebar.number_input("Recognizer input height", 64, 256, 128, 8)
35
- img_w = st.sidebar.number_input("Recognizer input width", 128, 2048, 512, 16)
36
- det_lang = st.sidebar.text_input("Tesseract lang(s) for detection only", value="eng")
 
 
 
 
37
  show_boxes = st.sidebar.checkbox("Show word boxes", value=False)
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
 
40
- # ========================= UTILITIES =========================
 
 
 
41
  def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
42
  name = (name or "").lower()
43
  if name.endswith(".pdf"):
@@ -50,94 +56,39 @@ def preprocess_for_detection(img: Image.Image) -> Image.Image:
50
  g = g.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))
51
  return g
52
 
53
- @st.cache_resource
54
- def load_pdrt(arch_name: str, ckpt: str, num_classes: int):
55
- if arch_name == "Swin_CTC":
56
- model = pdrt_models.Swin_CTC(num_classes=num_classes)
57
- elif arch_name == "VED":
58
- model = pdrt_models.VED(num_classes=num_classes)
59
- else:
60
- raise ValueError("Unknown model")
61
- state = torch.load(ckpt, map_location="cpu")
62
- model.load_state_dict(state, strict=False)
63
- model.eval().to(device)
64
- return model
65
-
66
- def build_transform(img_h: int, img_w: int):
67
- return T.Compose([
68
- T.Grayscale(num_output_channels=3), # keep 3ch if encoder expects RGB
69
- T.Resize((img_h, img_w)),
70
- T.ToTensor(),
71
- T.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
72
- ])
73
-
74
- def greedy_ctc_decode(logits: torch.Tensor, alphabet: str) -> str:
75
- """
76
- logits: (B, T, C) or (T, B, C). We map argmax to chars, collapse repeats, remove blank.
77
- We assume blank_id = len(alphabet).
78
- """
79
- if logits.dim() == 3 and logits.shape[0] != 1 and logits.shape[1] == 1:
80
- # rare shape, just permute if needed
81
- pass
82
- if logits.shape[0] == 1:
83
- logits = logits.squeeze(0) # (T, C)
84
- elif logits.shape[1] == 1:
85
- logits = logits[:,0,:] # (T, C)
86
- probs = logits.softmax(-1)
87
- ids = probs.argmax(-1).tolist()
88
- blank_id = len(alphabet)
89
- out = []
90
- prev = None
91
- for i in ids:
92
- if i != prev and i != blank_id:
93
- out.append(alphabet[i] if i < len(alphabet) else "")
94
- prev = i
95
- return "".join(out)
96
-
97
- def recognize_word_crops(model, crops: List[Image.Image], tfm, arch_name: str, alphabet: str) -> List[str]:
98
- texts = []
99
  with torch.no_grad():
100
- for im in crops:
101
- x = tfm(im).unsqueeze(0).to(device)
102
- y = model(x)
103
- if arch_name == "Swin_CTC":
104
- # expect CTC logits [B, T, C] or [T, B, C]
105
- if y.dim() == 3 and y.shape[0] == 1: # [1, T, C]
106
- logits = y[0] # [T, C]
107
- elif y.dim() == 3 and y.shape[1] == 1: # [T, 1, C]
108
- logits = y[:,0,:]
109
- else:
110
- logits = y
111
- txt = greedy_ctc_decode(logits, alphabet)
112
- else:
113
- # VED: if returns token ids/logits, plug your repo's decoding here.
114
- # Fallback: argmax over last dim per step and map ids to alphabet (no blank).
115
- if y.dim() == 3 and y.shape[0] == 1:
116
- y = y[0]
117
- ids = y.argmax(-1).tolist()
118
- txt = "".join(alphabet[i] if i < len(alphabet) else "" for i in ids).strip()
119
- texts.append(txt)
120
- return texts
121
-
122
- def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
123
- df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
124
- df = df.dropna(subset=["text"]).reset_index(drop=True)
125
- df["x2"] = df["left"] + df["width"]
126
- df["y2"] = df["top"] + df["height"]
127
- return df[df["conf"] > -1]
128
-
129
- def crop_words(img: Image.Image, df: pd.DataFrame) -> List[Tuple[Image.Image, Dict]]:
130
- crops, metas = [], []
131
- for _, r in df.iterrows():
132
- if str(r["text"]).strip() == "":
133
- continue
134
- box = (int(r["left"]), int(r["top"]), int(r["x2"]), int(r["y2"]))
135
- c = img.crop(box)
136
- crops.append(c)
137
- metas.append({"box": box})
138
- return crops, metas
139
-
140
- # ---------------- key fields & table (same logic as earlier Tesseract app) ----------------
141
  CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
142
  MONEY = rf"{CURRENCY}\s?(?P<amt>\d{{1,3}}(?:[,]\d{{3}})*(?:[.]\d{{2}})?)"
143
  DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
@@ -147,7 +98,7 @@ TOTAL_PAT = rf"(?:\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY}
147
  SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
148
  TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"
149
 
150
- def parse_fields(fulltext: str):
151
  t = re.sub(r"[ \t]+", " ", fulltext)
152
  t = re.sub(r"\n{2,}", "\n", t)
153
  out = {"invoice_number":None,"invoice_date":None,"po_number":None,"subtotal":None,"tax":None,"total":None,"currency":None}
@@ -155,22 +106,87 @@ def parse_fields(fulltext: str):
155
  m = re.search(PO_PAT, t, re.I); out["po_number"] = m.group("po") if m else None
156
  m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I)
157
  out["invoice_date"] = (m.group("date") if m else (re.search(DATE, t, re.I).group("date") if re.search(DATE, t, re.I) else None))
158
- m = re.search(SUBTOTAL_PAT, t, re.I|re.S);
159
  if m: out["subtotal"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
160
  m = re.search(TAX_PAT, t, re.I|re.S);
161
  if m: out["tax"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
162
  m = re.search(TOTAL_PAT, t, re.I|re.S);
163
- if m: out["total"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
 
164
  if out["currency"] in ["$", "C$", "€", "£"]:
165
  out["currency"] = {"$":"USD", "C$":"CAD", "€":"EUR", "£":"GBP"}[out["currency"]]
166
  return out
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  HEAD_CANDIDATES = ["description","item","qty","quantity","price","unit","rate","amount","total"]
169
  def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
170
- # Group into lines
 
171
  df = df.copy()
172
  df["cx"] = df["left"] + 0.5*df["width"]
173
  df["cy"] = df["top"] + 0.5*df["height"]
 
 
174
  lines = []
175
  for (b,p,l), g in df.groupby(["block_num","par_num","line_num"]):
176
  text = " ".join([t for t in g["text"].astype(str) if t.strip()])
@@ -180,7 +196,7 @@ def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
180
  "text": text.lower(),
181
  "top": g["top"].min(), "bottom": (g["top"]+g["height"]).max(),
182
  "left": g["left"].min(), "right": (g["left"]+g["width"]).max(),
183
- "words": g.sort_values("cx")[["cx","left","top","width","height"]].values.tolist()
184
  })
185
  L = pd.DataFrame(lines)
186
  if L.empty: return pd.DataFrame()
@@ -190,36 +206,80 @@ def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
190
  H = headers.iloc[0]
191
  header_y = H["bottom"] + 4
192
 
193
- # choose column centers from header words positions
194
- # we reuse df within header band
195
- header_band = df[(df["top"]>=H["top"]-5) & ((df["top"]+df["height"])<=H["bottom"]+5)]
196
- header_band = header_band.sort_values("left")
197
- col_x = header_band["left"].tolist()
198
- if len(col_x)<2: return pd.DataFrame()
199
- # region below header until totals
200
- below = df[df["top"]>header_y].copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  totals_mask = below["text"].str.lower().str.contains(r"(sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False)
202
  if totals_mask.any():
203
- stop_y = below.loc[totals_mask,"top"].min()
204
- below = below[below["top"]<stop_y-4]
 
 
 
 
205
  rows = []
206
  for (b,p,l), g in below.groupby(["block_num","par_num","line_num"]):
207
- if g["text"].astype(str).str.strip().eq("").all(): continue
208
  g = g.sort_values("left")
209
- # assign to nearest header word x
210
- xs = np.array(col_x)
211
  buckets = {i:[] for i in range(len(xs))}
212
- for _,w in g.iterrows():
213
- idx = int(np.abs(xs - w["left"]).argmin())
 
 
214
  buckets[idx].append(str(w["text"]))
215
- vals = [" ".join(buckets.get(i,[])).strip() for i in range(len(xs))]
216
  rows.append(vals)
217
- if not rows: return pd.DataFrame()
 
 
218
  df_rows = pd.DataFrame(rows).fillna("")
219
- # try to name columns
220
  names = []
221
- for i, w in enumerate(header_band["text"].tolist()[:df_rows.shape[1]]):
222
- wl = w.lower()
 
223
  if "desc" in wl or wl in ["item","description"]:
224
  names.append("description")
225
  elif wl in ["qty","quantity"]:
@@ -235,62 +295,49 @@ def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
235
  df_rows = df_rows[~(df_rows.fillna("").apply(lambda r: "".join(r.values), axis=1).str.strip()=="")]
236
  return df_rows.reset_index(drop=True)
237
 
238
- # ========================= APP =========================
239
- st.title("Invoice Extraction — ViT recognizer (dparres) + Tesseract detector")
240
 
241
  up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf","png","jpg","jpeg"])
242
  if not up:
243
  st.info("Upload a scanned invoice to begin.")
244
  st.stop()
245
 
246
- pages = load_pages(up.read(), up.name)
247
-
248
  # load model once
249
- num_classes = len(alphabet) + (1 if arch=="Swin_CTC" else 0) # add CTC blank for Swin_CTC
250
- assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
251
- model = load_pdrt(arch, ckpt_path, num_classes)
252
- tfm = build_transform(img_h, img_w)
253
 
254
  page_idx = 0
255
  if len(pages) > 1:
256
  page_idx = st.number_input("Page", 1, len(pages), 1) - 1
257
  img = pages[page_idx]
258
 
259
- col1, col2 = st.columns([1.1,1.3], gap="large")
260
 
261
  with col1:
262
  st.subheader("Preview")
263
  st.image(img, use_column_width=True)
264
  det_img = preprocess_for_detection(img)
265
- with st.expander("Detection view"):
266
  st.image(det_img, use_column_width=True)
267
 
268
  with col2:
269
  st.subheader("OCR & Extraction")
270
- # 1) detect words (boxes only)
271
- det_df = detect_words(det_img, lang=det_lang)
272
 
273
- # 2) crop & recognize each word via ViT recognizer
274
- crops, metas = crop_words(det_img, det_df)
275
- texts = recognize_word_crops(model, crops, tfm, arch, alphabet)
276
 
277
- # 3) stitch line-by-line using tesseract line indices
278
- det_df = det_df.reset_index(drop=True)
279
- det_df["pred"] = texts
280
- grouped = det_df.groupby(["block_num","par_num","line_num"])
281
- lines = []
282
- for _, g in grouped:
283
- g = g.sort_values("left")
284
- line = " ".join([t for t in g["pred"].tolist() if t])
285
- lines.append(line)
286
- full_text = "\n".join([ln for ln in lines if ln.strip()])
287
-
288
- if show_boxes:
289
- st.caption("First 15 predicted words")
290
- st.write(det_df[["left","top","width","height","text","pred"]].head(15))
291
 
292
- # 4) key fields
293
- key_fields = parse_fields(full_text)
294
  k1,k2,k3 = st.columns(3)
295
  with k1:
296
  st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
@@ -304,21 +351,34 @@ with col2:
304
  cur = key_fields.get('currency') or ''
305
  st.write(f"**Total:** {tot} {cur}".strip())
306
 
307
- # 5) line items (geometry heuristic)
308
- items = items_from_wordgrid(det_df.assign(text=det_df["pred"]))
 
 
 
 
 
309
  st.markdown("**Line Items**")
 
310
  if items.empty:
311
  st.caption("No line items confidently detected.")
312
  else:
313
  st.dataframe(items, use_container_width=True)
314
 
315
- # 6) downloads
316
  result = {
317
- "file": up.name, "page": page_idx+1,
 
318
  "key_fields": key_fields,
319
  "items": items.to_dict(orient="records") if not items.empty else [],
320
- "full_text": full_text
321
  }
322
- st.download_button("Download JSON", data=json.dumps(result, indent=2), file_name="invoice_extraction.json", mime="application/json")
 
323
  if not items.empty:
324
- st.download_button("Download Items CSV", data=items.to_csv(index=False), file_name="invoice_items.csv", mime="text/csv")
 
 
 
 
 
 
1
+ # Streamlit Invoice Extraction — Hugging Face Donut (no local .pth) + Tesseract tables
2
+ # - Uses a pretrained model from HF Hub (default: naver-clova-ix/donut-base-finetuned-sroie)
3
+ # - Extracts key fields via Donut JSON if available, else regex fallback
4
+ # - Extracts line items via Tesseract word boxes + geometry heuristics
5
+ # - Works on HF Spaces without any custom checkpoints
6
+
7
+ import os, io, re, json
8
  from typing import List, Tuple, Dict
 
 
 
9
 
10
  import numpy as np
11
  import pandas as pd
12
  from PIL import Image, ImageOps, ImageFilter
13
 
14
  import streamlit as st
 
 
15
 
16
+ # OCR for word boxes (detection only) + pdf to images
17
  import pytesseract
18
  from pytesseract import Output
 
 
19
  from pdf2image import convert_from_bytes
20
 
21
+ # HF Donut (pretrained, downloaded automatically)
22
+ import torch
23
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
24
+
25
+ st.set_page_config(page_title="Invoice Extraction — Donut (HF) + Tesseract tables", layout="wide")
26
+
27
+ # ----------------------------- Sidebar -----------------------------
28
+ st.sidebar.header("Model (Hugging Face)")
29
+ model_id = st.sidebar.text_input(
30
+ "HF model id",
31
+ value="naver-clova-ix/donut-base-finetuned-sroie", # good default for receipts/invoices (SROIE)
32
+ help="Examples: naver-clova-ix/donut-base-finetuned-sroie, naver-clova-ix/donut-base-finetuned-docvqa"
33
+ )
34
+ task_prompt = st.sidebar.text_input(
35
+ "Task prompt (for Donut models expecting prompts)",
36
+ value="<s_cord-v2>", # SROIE/cord-style models typically ignore or use default; harmless to keep
37
+ help="Some Donut checkpoints use task-specific prompts; keep or adjust as needed."
38
+ )
39
+ det_lang = st.sidebar.text_input("Tesseract language(s) — detection only", value="eng")
40
  show_boxes = st.sidebar.checkbox("Show word boxes", value=False)
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
+ st.sidebar.markdown("---")
44
+ st.sidebar.caption("Tip: If your model outputs JSON (e.g., SROIE), we’ll parse it for key fields. Otherwise we’ll regex from generated text.")
45
+
46
+ # ----------------------------- Utilities -----------------------------
47
  def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
48
  name = (name or "").lower()
49
  if name.endswith(".pdf"):
 
56
  g = g.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))
57
  return g
58
 
59
+ @st.cache_resource(show_spinner=True)
60
+ def load_donut(_model_id: str):
61
+ processor = DonutProcessor.from_pretrained(_model_id)
62
+ model = VisionEncoderDecoderModel.from_pretrained(_model_id)
63
+ model.to(device)
64
+ model.eval()
65
+ return processor, model
66
+
67
+ def donut_infer(img: Image.Image, processor: DonutProcessor, model: VisionEncoderDecoderModel, prompt: str):
68
+ # Donut expects RGB PIL Image; processor handles resizing/normalization
69
+ inputs = processor(images=img, text=prompt, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  with torch.no_grad():
71
+ outputs = model.generate(
72
+ **inputs,
73
+ max_length=1024,
74
+ num_beams=1,
75
+ early_stopping=True,
76
+ )
77
+ # decode
78
+ seq = processor.batch_decode(outputs, skip_special_tokens=True)[0]
79
+ # Donut models often emit JSON; try to parse
80
+ parsed = None
81
+ try:
82
+ # strip whitespace garbage around JSON
83
+ start = seq.find("{")
84
+ end = seq.rfind("}")
85
+ if start != -1 and end != -1 and end > start:
86
+ parsed = json.loads(seq[start:end+1])
87
+ except Exception:
88
+ parsed = None
89
+ return seq, parsed
90
+
91
+ # ----------------------------- Key fields & line items -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
93
  MONEY = rf"{CURRENCY}\s?(?P<amt>\d{{1,3}}(?:[,]\d{{3}})*(?:[.]\d{{2}})?)"
94
  DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
 
98
  SUBTOTAL_PAT = rf"(?:\bsub\s*total\b.*?{MONEY})"
99
  TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"
100
 
101
+ def parse_fields_regex(fulltext: str):
102
  t = re.sub(r"[ \t]+", " ", fulltext)
103
  t = re.sub(r"\n{2,}", "\n", t)
104
  out = {"invoice_number":None,"invoice_date":None,"po_number":None,"subtotal":None,"tax":None,"total":None,"currency":None}
 
106
  m = re.search(PO_PAT, t, re.I); out["po_number"] = m.group("po") if m else None
107
  m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I)
108
  out["invoice_date"] = (m.group("date") if m else (re.search(DATE, t, re.I).group("date") if re.search(DATE, t, re.I) else None))
109
+ m = re.search(SUBTOTAL_PAT, t, re.I|re.S);
110
  if m: out["subtotal"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
111
  m = re.search(TAX_PAT, t, re.I|re.S);
112
  if m: out["tax"], out["currency"] = m.group("amt").replace(",",""), m.group("curr") or out["currency"]
113
  m = re.search(TOTAL_PAT, t, re.I|re.S);
114
+ if m:
115
+ out["total"], out["currency"] = m.group("amt").replace(",", ""), m.group("curr") or out["currency"]
116
  if out["currency"] in ["$", "C$", "€", "£"]:
117
  out["currency"] = {"$":"USD", "C$":"CAD", "€":"EUR", "£":"GBP"}[out["currency"]]
118
  return out
119
 
120
+ def normalize_kv_from_donut(parsed: dict):
121
+ """Try to map common Donut outputs to our schema."""
122
+ txt = json.dumps(parsed).lower()
123
+ # heuristic mapping for typical SROIE/receipt keys
124
+ candidates = {
125
+ "invoice_number": ["invoice_number","invoice no","invoice_no","invoice","inv_no"],
126
+ "invoice_date": ["date","invoice_date","bill_date"],
127
+ "po_number": ["po_number","po","purchase_order"],
128
+ "subtotal": ["subtotal","sub_total"],
129
+ "tax": ["tax","gst","vat","hst"],
130
+ "total": ["total","amount_total","amount_due","grand_total"]
131
+ }
132
+ out = {k: None for k in ["invoice_number","invoice_date","po_number","subtotal","tax","total","currency"]}
133
+ # simple search: pick first occurrence
134
+ def search_keys(obj, key_list):
135
+ # breadth-first scan
136
+ if isinstance(obj, dict):
137
+ for k, v in obj.items():
138
+ if any(kk in k.lower() for kk in key_list):
139
+ return v
140
+ found = search_keys(v, key_list)
141
+ if found is not None:
142
+ return found
143
+ elif isinstance(obj, list):
144
+ for it in obj:
145
+ found = search_keys(it, key_list)
146
+ if found is not None:
147
+ return found
148
+ return None
149
+
150
+ for outk, key_list in candidates.items():
151
+ val = search_keys(parsed, key_list)
152
+ if isinstance(val, (dict, list)):
153
+ val = None # keep it simple; Donut sometimes nests values
154
+ if isinstance(val, str):
155
+ out[outk] = val.strip()
156
+ # currency guess:
157
+ curr = re.search(r"(USD|CAD|EUR|GBP|\$|C\$|€|£)", json.dumps(parsed, ensure_ascii=False), re.I)
158
+ if curr:
159
+ sym = curr.group(1)
160
+ out["currency"] = {"$":"USD","C$":"CAD","€":"EUR","£":"GBP"}.get(sym, sym.upper())
161
+ return out
162
+
163
+ def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
164
+ df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
165
+ df = df.dropna(subset=["text"]).reset_index(drop=True)
166
+ df["x2"] = df["left"] + df["width"]
167
+ df["y2"] = df["top"] + df["height"]
168
+ return df[df["conf"] > -1]
169
+
170
+ def crop_words(img: Image.Image, df: pd.DataFrame) -> List[Tuple[Image.Image, Dict]]:
171
+ crops, metas = [], []
172
+ for _, r in df.iterrows():
173
+ if str(r["text"]).strip() == "":
174
+ continue
175
+ box = (int(r["left"]), int(r["top"]), int(r["x2"]), int(r["y2"]))
176
+ c = img.crop(box)
177
+ crops.append(c)
178
+ metas.append({"box": box})
179
+ return crops, metas
180
+
181
  HEAD_CANDIDATES = ["description","item","qty","quantity","price","unit","rate","amount","total"]
182
  def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
183
+ if df.empty:
184
+ return pd.DataFrame()
185
  df = df.copy()
186
  df["cx"] = df["left"] + 0.5*df["width"]
187
  df["cy"] = df["top"] + 0.5*df["height"]
188
+
189
+ # group lines
190
  lines = []
191
  for (b,p,l), g in df.groupby(["block_num","par_num","line_num"]):
192
  text = " ".join([t for t in g["text"].astype(str) if t.strip()])
 
196
  "text": text.lower(),
197
  "top": g["top"].min(), "bottom": (g["top"]+g["height"]).max(),
198
  "left": g["left"].min(), "right": (g["left"]+g["width"]).max(),
199
+ "words": g.sort_values("left")[["left","top","width","height","text"]].values.tolist()
200
  })
201
  L = pd.DataFrame(lines)
202
  if L.empty: return pd.DataFrame()
 
206
  H = headers.iloc[0]
207
  header_y = H["bottom"] + 4
208
 
209
+ # derive column anchors from header words positions
210
+ df_header = detect_words(img=None, lang="eng") # placeholder to keep signature consistent
211
+
212
+ # get header band words
213
+ # reconstruct header band from original DF
214
+ # (we need original df back here; easier: pass it in as closure var)
215
+ # we'll adapt: compute from global last_df if present
216
+ return_df = pd.DataFrame()
217
+ return return_df
218
+
219
+ # We’ll implement a simpler, robust table extractor to avoid closure complexity:
220
+ def items_from_words_simple(tsv: pd.DataFrame) -> pd.DataFrame:
221
+ # find header line
222
+ L = []
223
+ for (b,p,l), g in tsv.groupby(["block_num","par_num","line_num"]):
224
+ text = " ".join([w for w in g["text"].astype(str).tolist() if w.strip()])
225
+ if text.strip():
226
+ L.append({
227
+ "block_num": b, "par_num": p, "line_num": l,
228
+ "text": text.lower(),
229
+ "top": g["top"].min(), "bottom": (g["top"]+g["height"]).max(),
230
+ "left": g["left"].min(), "right": (g["left"]+g["width"]).max()
231
+ })
232
+ lines = pd.DataFrame(L)
233
+ if lines.empty:
234
+ return pd.DataFrame()
235
+
236
+ def score_header(s: str):
237
+ return sum(1 for h in HEAD_CANDIDATES if h in s)
238
+
239
+ lines["header_score"] = lines["text"].apply(score_header)
240
+ hdrs = lines[lines["header_score"] >= 2].sort_values(["header_score","top"], ascending=[False,True])
241
+ if hdrs.empty:
242
+ return pd.DataFrame()
243
+ H = hdrs.iloc[0]
244
+ header_top, header_bottom = H["top"], H["bottom"]
245
+
246
+ # header words
247
+ header_words = tsv[(tsv["top"] >= header_top - 5) & ((tsv["top"] + tsv["height"]) <= header_bottom + 5)]
248
+ header_words = header_words.sort_values("left")
249
+ if header_words.empty:
250
+ return pd.DataFrame()
251
+ xs = header_words["left"].tolist()
252
+
253
+ # items region
254
+ below = tsv[tsv["top"] > header_bottom + 5].copy()
255
  totals_mask = below["text"].str.lower().str.contains(r"(sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False)
256
  if totals_mask.any():
257
+ stop_y = below.loc[totals_mask, "top"].min()
258
+ below = below[below["top"] < stop_y - 4]
259
+ if below.empty:
260
+ return pd.DataFrame()
261
+
262
+ # build rows by assigning words to nearest header x
263
  rows = []
264
  for (b,p,l), g in below.groupby(["block_num","par_num","line_num"]):
 
265
  g = g.sort_values("left")
 
 
266
  buckets = {i:[] for i in range(len(xs))}
267
+ for _, w in g.iterrows():
268
+ if not str(w["text"]).strip():
269
+ continue
270
+ idx = int(np.abs(np.array(xs) - w["left"]).argmin())
271
  buckets[idx].append(str(w["text"]))
272
+ vals = [" ".join(buckets[i]).strip() for i in range(len(xs))]
273
  rows.append(vals)
274
+ if not rows:
275
+ return pd.DataFrame()
276
+
277
  df_rows = pd.DataFrame(rows).fillna("")
278
+ # name columns heuristically
279
  names = []
280
+ hdr_tokens = [t.lower() for t in header_words["text"].tolist()]
281
+ for i in range(df_rows.shape[1]):
282
+ wl = hdr_tokens[i] if i < len(hdr_tokens) else f"col_{i}"
283
  if "desc" in wl or wl in ["item","description"]:
284
  names.append("description")
285
  elif wl in ["qty","quantity"]:
 
295
  df_rows = df_rows[~(df_rows.fillna("").apply(lambda r: "".join(r.values), axis=1).str.strip()=="")]
296
  return df_rows.reset_index(drop=True)
297
 
298
+ # ----------------------------- App -----------------------------
299
+ st.title("Invoice Extraction — Donut (HF pretrained) + Tesseract tables")
300
 
301
  up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf","png","jpg","jpeg"])
302
  if not up:
303
  st.info("Upload a scanned invoice to begin.")
304
  st.stop()
305
 
 
 
306
  # load model once
307
+ with st.spinner(f"Loading model '{model_id}' from Hugging Face…"):
308
+ processor, donut_model = load_donut(model_id)
309
+
310
+ pages = load_pages(up.read(), up.name)
311
 
312
  page_idx = 0
313
  if len(pages) > 1:
314
  page_idx = st.number_input("Page", 1, len(pages), 1) - 1
315
  img = pages[page_idx]
316
 
317
+ col1, col2 = st.columns([1.1, 1.3], gap="large")
318
 
319
  with col1:
320
  st.subheader("Preview")
321
  st.image(img, use_column_width=True)
322
  det_img = preprocess_for_detection(img)
323
+ with st.expander("Detection view (preprocessed for boxes)"):
324
  st.image(det_img, use_column_width=True)
325
 
326
  with col2:
327
  st.subheader("OCR & Extraction")
 
 
328
 
329
+ # 1) Donut extraction (key fields or full text)
330
+ with st.spinner("Running Donut…"):
331
+ seq, parsed = donut_infer(img, processor, donut_model, task_prompt)
332
 
333
+ # 2) Key fields
334
+ if parsed:
335
+ key_fields = normalize_kv_from_donut(parsed)
336
+ donut_payload = parsed
337
+ else:
338
+ key_fields = parse_fields_regex(seq)
339
+ donut_payload = {"generated_text": seq}
 
 
 
 
 
 
 
340
 
 
 
341
  k1,k2,k3 = st.columns(3)
342
  with k1:
343
  st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
 
351
  cur = key_fields.get('currency') or ''
352
  st.write(f"**Total:** {tot} {cur}".strip())
353
 
354
+ # 3) Tesseract line items (geometry heuristic)
355
+ with st.spinner("Detecting words with Tesseract (for table)…"):
356
+ tsv = pytesseract.image_to_data(det_img, lang=det_lang, output_type=Output.DATAFRAME)
357
+ tsv = tsv.dropna(subset=["text"]).reset_index(drop=True)
358
+ tsv["x2"] = tsv["left"] + tsv["width"]
359
+ tsv["y2"] = tsv["top"] + tsv["height"]
360
+
361
  st.markdown("**Line Items**")
362
+ items = items_from_words_simple(tsv)
363
  if items.empty:
364
  st.caption("No line items confidently detected.")
365
  else:
366
  st.dataframe(items, use_container_width=True)
367
 
368
+ # 4) Downloads
369
  result = {
370
+ "file": up.name,
371
+ "page": page_idx + 1,
372
  "key_fields": key_fields,
373
  "items": items.to_dict(orient="records") if not items.empty else [],
374
+ "donut_raw": donut_payload,
375
  }
376
+ st.download_button("Download JSON", data=json.dumps(result, indent=2),
377
+ file_name="invoice_extraction.json", mime="application/json")
378
  if not items.empty:
379
+ st.download_button("Download Items CSV", data=items.to_csv(index=False),
380
+ file_name="invoice_items.csv", mime="text/csv")
381
+
382
+ if show_boxes:
383
+ st.caption("First 20 Tesseract word boxes")
384
+ st.dataframe(tsv[["left","top","width","height","text","conf"]].head(20), use_container_width=True)