Seth0330 committed on
Commit
11d644c
·
verified ·
1 Parent(s): 1403715

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +258 -276
app.py CHANGED
@@ -1,304 +1,294 @@
1
- import io
2
- import re
3
- import json
4
  import numpy as np
5
  import pandas as pd
6
  from PIL import Image, ImageOps, ImageFilter
 
7
  import streamlit as st
 
 
 
 
8
  import pytesseract
9
  from pytesseract import Output
10
 
11
- # PDF images
12
- try:
13
- from pdf2image import convert_from_bytes
14
- PDF_OK = True
15
- except Exception:
16
- PDF_OK = False
17
-
18
- st.set_page_config(page_title="Invoice OCR (Tesseract) · Streamlit", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # --------------------------- Image utils ---------------------------
21
def preprocess(img: Image.Image) -> Image.Image:
    """Clean up a page for Tesseract and return a black-and-white image.

    Steps: grayscale, auto-contrast, a mild unsharp mask to crisp up text
    edges, then a single global threshold.
    """
    sharpen = ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3)
    gray = ImageOps.autocontrast(ImageOps.grayscale(img)).filter(sharpen)

    pixels = np.array(gray)
    # Heuristic cut-off: 90% of the mean brightness, kept inside [110, 180].
    cutoff = np.clip(pixels.mean() * 0.9, 110, 180)
    binary = (pixels > cutoff).astype(np.uint8) * 255
    return Image.fromarray(binary)
32
-
33
def load_pages(file_bytes: bytes, name: str):
    """Decode an upload into a list of PIL images, one per page.

    Files whose name ends in ".pdf" are rasterized at 300 dpi; anything else
    is opened as a single RGB image. When PDF support is unavailable an error
    is shown in the UI and an empty list is returned.
    """
    is_pdf = (name or "").lower().endswith(".pdf")
    if not is_pdf:
        return [Image.open(io.BytesIO(file_bytes)).convert("RGB")]
    if PDF_OK:
        return convert_from_bytes(file_bytes, dpi=300)
    st.error("pdf2image not available. Did you add poppler in apt.txt?")
    return []
45
-
46
- # --------------------------- OCR ---------------------------
47
def ocr_tsv(img: Image.Image, lang="eng") -> pd.DataFrame:
    """Run Tesseract on *img* and return its word table as a DataFrame.

    Rows without text are dropped. Four geometry columns are appended:
    ``x2``/``y2`` (right and bottom edges) and ``cx``/``cy`` (box centers).
    The image is passed through unscaled so the boxes keep page geometry.
    """
    words = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
    words = words.dropna(subset=["text"]).reset_index(drop=True)
    return words.assign(
        x2=words["left"] + words["width"],
        y2=words["top"] + words["height"],
        cx=words["left"] + words["width"] / 2,
        cy=words["top"] + words["height"] / 2,
    )
59
-
60
def ocr_text(img: Image.Image, lang="eng") -> str:
    """Return the plain text Tesseract reads from *img* using *lang*."""
    plain_text = pytesseract.image_to_string(img, lang=lang)
    return plain_text
62
-
63
# --------------------------- Key-field parsing ---------------------------
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
# Amount: comma-grouped thousands OR a plain digit run, optional cents.
# The trailing lookahead rejects percentages ("5%") so tax rates are not
# mistaken for tax amounts.
MONEY = rf"{CURRENCY}\s?(?P<amt>(?:\d{{1,3}}(?:,\d{{3}})+|\d+)(?:\.\d{{2}})?)(?![\d.,]*\s?%)"
DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
# The identifier must contain a digit, otherwise "Invoice Date" yields "Date".
INV_PAT = r"(?:\binvoice(?![A-Za-z])\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{4,}))"
# "PO" must be a standalone token ("Deposit" is not a PO) followed by an
# identifier containing a digit ("PO Box" is not a PO number).
PO_PAT = r"(?:\bp\.?o(?![A-Za-z])\.?\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{3,}))"
# The lookbehinds keep "Sub Total"/"Sub-Total" from being read as the total.
TOTAL_PAT = rf"(?:(?<!sub )(?<!sub-)\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
SUBTOTAL_PAT = rf"(?:\bsub[\s\-]*total\b.*?{MONEY})"
TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"


def find_first(pattern, text, flags=re.IGNORECASE | re.DOTALL):
    """Search *text* for *pattern*.

    Returns a ``(groupdict, match)`` pair; both are ``None``-ish when nothing
    matches (``groupdict`` is ``None`` and ``match`` is ``None``).
    """
    m = re.search(pattern, text, flags)
    return (m.groupdict() if m else None), m


def parse_fields(fulltext: str):
    """Extract invoice key fields from OCR text with regex heuristics.

    Args:
        fulltext: Full OCR text of the page (``None`` is treated as empty).

    Returns:
        A dict with keys ``invoice_number``, ``invoice_date``, ``po_number``,
        ``subtotal``, ``tax``, ``total`` and ``currency``. Fields that are not
        found stay ``None``. Amounts are strings with thousands separators
        removed; currency symbols are normalized to ISO codes.
    """
    # Normalize runs of spaces/tabs and collapse blank lines.
    t = re.sub(r"[ \t]+", " ", fulltext or "")
    t = re.sub(r"\n{2,}", "\n", t)

    out = {
        "invoice_number": None,
        "invoice_date": None,
        "po_number": None,
        "subtotal": None,
        "tax": None,
        "total": None,
        "currency": None,
    }

    g, _ = find_first(INV_PAT, t)
    if g and g.get("inv"):
        out["invoice_number"] = g["inv"].strip()

    g, _ = find_first(PO_PAT, t)
    if g and g.get("po"):
        out["po_number"] = g["po"].strip()

    # Prefer a date that directly follows "invoice date"; otherwise take the
    # first date-looking token anywhere in the text.
    near_date = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.IGNORECASE)
    if near_date:
        out["invoice_date"] = near_date.group("date")
    else:
        g, _ = find_first(DATE, t)
        if g and g.get("date"):
            out["invoice_date"] = g["date"]

    # Later matches win for the currency, mirroring subtotal -> tax -> total.
    for key, pattern in (("subtotal", SUBTOTAL_PAT), ("tax", TAX_PAT), ("total", TOTAL_PAT)):
        g, _ = find_first(pattern, t)
        if g and g.get("amt"):
            out[key] = g["amt"].replace(",", "")
            out["currency"] = g.get("curr") or out["currency"]

    sym_map = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}
    out["currency"] = sym_map.get(out["currency"], out["currency"])
    return out
137
 
138
- # --------------------------- Line item parsing ---------------------------
139
HEAD_CANDIDATES = ["description", "item", "qty", "quantity", "price", "unit price", "rate", "amount", "total"]


def guess_header_rows(tsv: pd.DataFrame) -> pd.DataFrame:
    """Find OCR lines that look like a line-item table header.

    Words are grouped into lines by (block, paragraph, line). Each line gets a
    ``header_score`` equal to the number of HEAD_CANDIDATES found as
    substrings of its lowercased text. Lines scoring at least 2 are returned,
    best score first and topmost first among ties. The result may be empty.
    """
    records = []
    grouped = tsv.groupby(["block_num", "par_num", "line_num"], as_index=False)
    for (block, par, line), words in grouped:
        tokens = [tok for tok in words["text"].astype(str).tolist() if tok.strip()]
        joined = " ".join(tokens)
        if not joined.strip():
            continue
        records.append(
            {
                "block_num": block,
                "par_num": par,
                "line_num": line,
                "text": joined.lower(),
                "top": words["top"].min(),
                "bottom": words["y2"].max(),
                "left": words["left"].min(),
                "right": words["x2"].max(),
            }
        )

    table = pd.DataFrame(records)
    if table.empty:
        return table

    table["header_score"] = [
        sum(1 for head in HEAD_CANDIDATES if head in line_text)
        for line_text in table["text"]
    ]
    strong = table[table["header_score"] >= 2]
    return strong.sort_values(["header_score", "top"], ascending=[False, True])
171
-
172
def extract_table(tsv: pd.DataFrame) -> pd.DataFrame:
    """Extract line items from a Tesseract word table using page geometry.

    Approach:
      1. pick the best header line (see ``guess_header_rows``);
      2. take the x-centers of the recognized header words as column anchors;
      3. assign every word below the header to the nearest anchor;
      4. stop at the first totals/footer word so summary rows are excluded.

    Args:
        tsv: Word-level OCR table with ``text``, ``top``, ``y2``, ``cx`` and
            the ``block_num``/``par_num``/``line_num`` columns.

    Returns:
        A DataFrame with whichever of ``description``, ``quantity``,
        ``unit_price`` and ``line_total`` were detected (empty when no table
        is found). Header words that map to the same canonical column are
        merged into one column instead of producing duplicate labels.
    """
    header_lines = guess_header_rows(tsv)
    if header_lines.empty:
        return pd.DataFrame()

    # Highest-scoring header line defines the vertical header band.
    header = header_lines.iloc[0]
    band_top, band_bottom = header["top"], header["bottom"]

    # Header words: inside the band (5px tolerance) and a known single token.
    lowered = tsv["text"].astype(str).str.lower()
    single_word_heads = [h for h in HEAD_CANDIDATES if " " not in h]
    in_band = (tsv["top"] >= band_top - 5) & (tsv["y2"] <= band_bottom + 5)
    header_words = tsv[in_band & lowered.isin(single_word_heads)].sort_values("cx")
    if header_words.empty:
        return pd.DataFrame()

    col_names = header_words["text"].astype(str).str.lower().tolist()
    col_x = header_words["cx"].to_numpy(dtype=float)

    # Item region: below the header, above the first totals-looking word.
    below = tsv[tsv["top"] > band_bottom + 5].copy()
    # Non-capturing group: a capturing one makes pandas emit a UserWarning.
    totals_mask = below["text"].astype(str).str.lower().str.contains(
        r"(?:sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False
    )
    if totals_mask.any():
        footer_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < footer_y - 4]
    if below.empty:
        return pd.DataFrame()

    items = []
    for _, line in below.groupby(["block_num", "par_num", "line_num"]):
        ordered = line.sort_values("cx")
        cells = {name: [] for name in col_names}
        for word, cx in zip(ordered["text"].astype(str), ordered["cx"]):
            if not word.strip():
                continue
            # argmin returns the first minimum, i.e. the left-most column on ties.
            nearest = int(np.abs(col_x - cx).argmin())
            cells[col_names[nearest]].append(word)
        row = {name: " ".join(parts).strip() for name, parts in cells.items()}
        if any(row.values()):
            items.append(row)

    df = pd.DataFrame(items)

    def canonical(name: str):
        """Map a raw header token to its canonical output column (or None)."""
        if "desc" in name or name == "item":
            return "description"
        if name in ("qty", "quantity"):
            return "quantity"
        if "unit" in name or "rate" in name or "price" in name:
            return "unit_price"
        if "amount" in name or "total" in name:
            return "line_total"
        return None

    # Build the output in canonical order; join the text of header tokens that
    # share a target (e.g. "amount" and "total") rather than duplicating labels.
    merged = {}
    for target in ["description", "quantity", "unit_price", "line_total"]:
        sources = [c for c in df.columns if canonical(c) == target]
        if sources:
            merged[target] = df[sources].fillna("").agg(
                lambda r: " ".join(v for v in r if v), axis=1
            )
    result = pd.DataFrame(merged)

    # Drop rows that ended up with no text at all.
    if not result.empty:
        blank = result.fillna("").apply(lambda r: "".join(r.values), axis=1).str.strip() == ""
        result = result[~blank]
    return result.reset_index(drop=True)
258
-
259
- # --------------------------- App UI ---------------------------
260
- st.title("Invoice Extraction (Tesseract · Streamlit)")
261
-
262
- st.sidebar.header("Settings")
263
- lang = st.sidebar.text_input("Tesseract language(s)", value="eng")
264
- show_tsv = st.sidebar.checkbox("Show raw OCR TSV", value=False)
265
- show_fulltext = st.sidebar.checkbox("Show full OCR text", value=False)
266
-
267
- up = st.file_uploader("Upload an invoice (PDF, PNG, JPG)", type=["pdf","png","jpg","jpeg"], accept_multiple_files=False)
268
-
269
  if not up:
270
- st.info("Upload a scanned invoice PDF or an image to begin.")
271
  st.stop()
272
 
273
  pages = load_pages(up.read(), up.name)
274
- if not pages:
275
- st.stop()
276
 
277
- # Page selector (for multi-page PDFs)
 
 
 
 
 
 
278
  if len(pages) > 1:
279
- idx = st.number_input("Page", min_value=1, max_value=len(pages), value=1)
280
- img = pages[idx-1]
281
- else:
282
- img = pages[0]
283
 
284
- col_prev, col_data = st.columns([1.1, 1.3], gap="large")
285
 
286
- with col_prev:
287
  st.subheader("Preview")
288
- st.image(img, use_column_width=True, caption="Original page")
289
- pre = preprocess(img)
290
- with st.expander("Preprocessed (for OCR)"):
291
- st.image(pre, use_column_width=True)
292
-
293
- with col_data:
294
- st.subheader("Extraction")
295
- with st.spinner("Running Tesseract..."):
296
- tsv = ocr_tsv(pre, lang=lang)
297
- text = ocr_text(pre, lang=lang)
298
-
299
- key_fields = parse_fields(text)
300
- st.markdown("**Key Fields (heuristic)**")
301
- k1, k2, k3 = st.columns(3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  with k1:
303
  st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
304
  st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
@@ -311,29 +301,21 @@ with col_data:
311
  cur = key_fields.get('currency') or ''
312
  st.write(f"**Total:** {tot} {cur}".strip())
313
 
314
- st.markdown("**Line Items (auto-detected)**")
315
- items = extract_table(tsv)
 
316
  if items.empty:
317
- st.caption("No line items confidently detected. You can still download full OCR text.")
318
  else:
319
  st.dataframe(items, use_container_width=True)
320
 
321
- # Downloads
322
  result = {
323
- "file": up.name,
324
  "key_fields": key_fields,
325
  "items": items.to_dict(orient="records") if not items.empty else [],
326
- "full_text": text,
327
  }
328
- j = json.dumps(result, indent=2)
329
- st.download_button("Download JSON", data=j, file_name="invoice_extraction.json", mime="application/json")
330
  if not items.empty:
331
- csv = items.to_csv(index=False)
332
- st.download_button("Download Line Items CSV", data=csv, file_name="invoice_items.csv", mime="text/csv")
333
-
334
- # Optional raw views
335
- with st.expander("Advanced · Raw Outputs"):
336
- if show_fulltext:
337
- st.text_area("OCR Full Text", value=text, height=220)
338
- if show_tsv:
339
- st.dataframe(tsv.head(100), use_container_width=True)
 
1
+ import io, os, re, json
2
+ from typing import List, Tuple, Dict
3
+
4
  import numpy as np
5
  import pandas as pd
6
  from PIL import Image, ImageOps, ImageFilter
7
+
8
  import streamlit as st
9
+ import torch
10
+ import torchvision.transforms as T
11
+
12
+ # --- word detector (Tesseract) ---
13
  import pytesseract
14
  from pytesseract import Output
15
 
16
+ # --- PDF -> images ---
17
+ from pdf2image import convert_from_bytes
18
+
19
+ # ---- import the repo's models ----
20
+ # Install via requirements.txt (git+https URL) OR copy repo files into root.
21
+ # The repo defines model classes: Swin_CTC, VED
22
+ import models as pdrt_models # from dparres/Pretrained-Document-Recognition-Transformers
23
+
24
st.set_page_config(page_title="Invoice OCR (ViT recognizer + Tesseract detector)", layout="wide")

# ========================= UI SIDEBAR =========================
# All recognizer/detector settings live in the sidebar; the resulting values
# are module-level names read by the rest of the script.
sidebar = st.sidebar
sidebar.header("Model")
arch = sidebar.selectbox("Architecture", ["Swin_CTC", "VED"], index=0)
ckpt_path = sidebar.text_input(
    "Checkpoint path (inside Space)",
    value="checkpoints/pdrt_weights.pth",
)
alphabet = sidebar.text_input(
    "Alphabet (ordered classes, exclude CTC blank)",
    value="0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_/.,:;()[]{}#+*&%$@!?\"' ",
)
img_h = sidebar.number_input("Recognizer input height", min_value=64, max_value=256, value=128, step=8)
img_w = sidebar.number_input("Recognizer input width", min_value=128, max_value=2048, value=512, step=16)
det_lang = sidebar.text_input("Tesseract lang(s) for detection only", value="eng")
show_boxes = sidebar.checkbox("Show word boxes", value=False)
# Run the recognizer on GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+ # ========================= UTILITIES =========================
38
def load_pages(file_bytes: bytes, name: str) -> List[Image.Image]:
    """Turn an uploaded file into a list of page images.

    Names ending in ".pdf" (case-insensitive) are rasterized at 300 dpi;
    every other upload is opened as one RGB image.
    """
    filename = (name or "").lower()
    if not filename.endswith(".pdf"):
        single_page = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        return [single_page]
    return convert_from_bytes(file_bytes, dpi=300)
43
 
44
def preprocess_for_detection(img: Image.Image) -> Image.Image:
    """Return a grayscale, contrast-stretched, lightly sharpened copy of *img*.

    Unlike a binarizing cleanup, the result keeps its gray levels.
    """
    sharpen = ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3)
    return ImageOps.autocontrast(ImageOps.grayscale(img)).filter(sharpen)
49
+
50
@st.cache_resource
def load_pdrt(arch_name: str, ckpt: str, num_classes: int):
    """Build the requested recognizer, load its weights and return it in eval mode.

    Args:
        arch_name: "Swin_CTC" or "VED" (class names looked up on the models module).
        ckpt: Path to the checkpoint file.
        num_classes: Size of the output vocabulary passed to the constructor.

    Raises:
        ValueError: If *arch_name* is not one of the supported architectures.
    """
    if arch_name not in ("Swin_CTC", "VED"):
        raise ValueError("Unknown model")
    model_cls = getattr(pdrt_models, arch_name)
    model = model_cls(num_classes=num_classes)

    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints. strict=False also means mismatched keys are ignored silently.
    weights = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(weights, strict=False)
    model.eval().to(device)
    return model
62
+
63
def build_transform(img_h: int, img_w: int):
    """Compose the crop-to-tensor pipeline used by the recognizer.

    Crops are converted to 3-channel grayscale, resized to (img_h, img_w),
    turned into tensors and normalized with mean 0.5 / std 0.5 per channel.
    """
    steps = [
        T.Grayscale(num_output_channels=3),  # keep 3ch if encoder expects RGB
        T.Resize((img_h, img_w)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return T.Compose(steps)
70
+
71
def greedy_ctc_decode(logits: torch.Tensor, alphabet: str) -> str:
    """Greedy (best-path) CTC decoding of a single sequence.

    Takes the argmax class per time step, collapses consecutive repeats and
    drops blanks. The blank id is assumed to be ``len(alphabet)``; any id
    beyond the alphabet is ignored.

    Args:
        logits: Scores shaped (T, C), (1, T, C) or (T, 1, C).
        alphabet: Characters in class-id order, without the blank.

    Returns:
        The decoded string.

    Raises:
        ValueError: If *logits* is not 2-D, or is 3-D with a batch larger than one.
    """
    if logits.dim() == 3:
        # Strip the singleton batch axis. This must only happen for 3-D input:
        # squeezing a (1, C) sequence would collapse the time axis instead.
        if logits.shape[0] == 1:
            logits = logits[0]
        elif logits.shape[1] == 1:
            logits = logits[:, 0, :]
        else:
            raise ValueError(f"expected a single sequence, got shape {tuple(logits.shape)}")
    elif logits.dim() != 2:
        raise ValueError(f"expected 2-D or 3-D logits, got shape {tuple(logits.shape)}")

    # argmax is invariant under softmax, so no normalization is needed.
    ids = logits.argmax(-1).tolist()
    blank_id = len(alphabet)
    chars = []
    prev = None
    for i in ids:
        # i < blank_id excludes both the blank and any out-of-range id.
        if i != prev and i < blank_id:
            chars.append(alphabet[i])
        prev = i
    return "".join(chars)
93
+
94
def recognize_word_crops(model, crops: List[Image.Image], tfm, arch_name: str, alphabet: str) -> List[str]:
    """Run the recognizer on each crop and return one decoded string per crop.

    "Swin_CTC" output is decoded with greedy CTC. Any other architecture uses
    a fallback: per-step argmax mapped straight onto *alphabet* (no blank),
    then stripped; plug the repo's own decoding in there if it returns token
    ids/logits in another form.
    """
    is_ctc = arch_name == "Swin_CTC"
    n_chars = len(alphabet)
    results = []
    with torch.no_grad():
        for crop in crops:
            batch = tfm(crop).unsqueeze(0).to(device)
            output = model(batch)
            is_3d = output.dim() == 3
            if is_ctc:
                # Expect CTC logits as [B, T, C] or [T, B, C]; drop the batch axis.
                if is_3d and output.shape[0] == 1:
                    step_logits = output[0]
                elif is_3d and output.shape[1] == 1:
                    step_logits = output[:, 0, :]
                else:
                    step_logits = output
                decoded = greedy_ctc_decode(step_logits, alphabet)
            else:
                if is_3d and output.shape[0] == 1:
                    output = output[0]
                token_ids = output.argmax(-1).tolist()
                decoded = "".join(
                    alphabet[i] if i < n_chars else "" for i in token_ids
                ).strip()
            results.append(decoded)
    return results
118
+
119
def detect_words(img: Image.Image, lang="eng") -> pd.DataFrame:
    """Detect word boxes with Tesseract.

    Args:
        img: Page image to analyze.
        lang: Tesseract language code(s).

    Returns:
        One row per detected word with Tesseract's columns plus ``x2``/``y2``
        (right and bottom edges). Structural rows (``conf == -1``), rows with
        missing text and rows whose text is only whitespace are removed, and
        the index is reset to 0..n-1.
    """
    df = pytesseract.image_to_data(img, lang=lang, output_type=Output.DATAFRAME)
    df = df.dropna(subset=["text"])

    # Whitespace-only "words" must go here: crop_words skips them, so keeping
    # them would leave this table longer than the list of recognized crops.
    has_text = df["text"].astype(str).str.strip().ne("")
    df = df[has_text & (df["conf"] > -1)].copy()

    df["x2"] = df["left"] + df["width"]
    df["y2"] = df["top"] + df["height"]
    return df.reset_index(drop=True)
125
+
126
def crop_words(img: Image.Image, df: pd.DataFrame) -> Tuple[List[Image.Image], List[Dict]]:
    """Cut one image crop per word box.

    Args:
        img: Page image the boxes refer to.
        df: Word table with ``text``, ``left``, ``top``, ``x2`` and ``y2``.

    Returns:
        ``(crops, metas)``: two parallel lists. ``metas[i]`` is
        ``{"box": (left, top, right, bottom)}`` for ``crops[i]``. Rows whose
        text is blank are skipped.
    """
    crops, metas = [], []
    columns = (df["text"], df["left"], df["top"], df["x2"], df["y2"])
    for text, left, top, right, bottom in zip(*columns):
        if not str(text).strip():
            continue
        box = (int(left), int(top), int(right), int(bottom))
        crops.append(img.crop(box))
        metas.append({"box": box})
    return crops, metas
136
+
137
# ---------------- key fields & table (same logic as earlier Tesseract app) ----------------
CURRENCY = r"(?P<curr>USD|CAD|EUR|GBP|\$|C\$|€|£)?"
# Amount: comma-grouped thousands OR a plain digit run, optional cents.
# The trailing lookahead rejects percentages ("5%") so tax rates are not
# mistaken for tax amounts.
MONEY = rf"{CURRENCY}\s?(?P<amt>(?:\d{{1,3}}(?:,\d{{3}})+|\d+)(?:\.\d{{2}})?)(?![\d.,]*\s?%)"
DATE = r"(?P<date>(?:\d{4}[-/]\d{1,2}[-/]\d{1,2})|(?:\d{1,2}[-/]\d{1,2}[-/]\d{2,4})|(?:[A-Za-z]{3,9}\s+\d{1,2},\s*\d{2,4}))"
# The identifier must contain a digit, otherwise "Invoice Date" yields "Date".
INV_PAT = r"(?:\binvoice(?![A-Za-z])\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<inv>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{4,}))"
# "PO" must be a standalone token ("Deposit" is not a PO) followed by an
# identifier containing a digit ("PO Box" is not a PO number).
PO_PAT = r"(?:\bp\.?o(?![A-Za-z])\.?\s*(?:no\.?|#|number)?\s*[:\-]?\s*(?P<po>(?=[A-Z0-9\-_/]*\d)[A-Z0-9\-_/]{3,}))"
# The lookbehinds keep "Sub Total"/"Sub-Total" from being read as the total.
TOTAL_PAT = rf"(?:(?<!sub )(?<!sub-)\b(total(?:\s*amount)?|amount\s*due|grand\s*total)\b.*?{MONEY})"
SUBTOTAL_PAT = rf"(?:\bsub[\s\-]*total\b.*?{MONEY})"
TAX_PAT = rf"(?:\b(tax|gst|vat|hst)\b.*?{MONEY})"


def parse_fields(fulltext: str):
    """Extract invoice key fields from recognized text with regex heuristics.

    Args:
        fulltext: Full recognized text of the page (``None`` is treated as empty).

    Returns:
        A dict with keys ``invoice_number``, ``invoice_date``, ``po_number``,
        ``subtotal``, ``tax``, ``total`` and ``currency``. Fields that are not
        found stay ``None``. Amounts are strings with thousands separators
        removed; currency symbols are normalized to ISO codes.
    """
    # Normalize runs of spaces/tabs and collapse blank lines.
    t = re.sub(r"[ \t]+", " ", fulltext or "")
    t = re.sub(r"\n{2,}", "\n", t)

    out = {
        "invoice_number": None,
        "invoice_date": None,
        "po_number": None,
        "subtotal": None,
        "tax": None,
        "total": None,
        "currency": None,
    }

    m = re.search(INV_PAT, t, re.I)
    if m:
        out["invoice_number"] = m.group("inv")

    m = re.search(PO_PAT, t, re.I)
    if m:
        out["po_number"] = m.group("po")

    # Prefer a date right after "invoice date"; otherwise the first date found.
    m = re.search(rf"(invoice\s*date[:\-\s]*){DATE}", t, re.I) or re.search(DATE, t, re.I)
    if m:
        out["invoice_date"] = m.group("date")

    # Later matches win for the currency, mirroring subtotal -> tax -> total.
    for key, pattern in (("subtotal", SUBTOTAL_PAT), ("tax", TAX_PAT), ("total", TOTAL_PAT)):
        m = re.search(pattern, t, re.I | re.S)
        if m:
            out[key] = m.group("amt").replace(",", "")
            out["currency"] = m.group("curr") or out["currency"]

    symbol_to_code = {"$": "USD", "C$": "CAD", "€": "EUR", "£": "GBP"}
    out["currency"] = symbol_to_code.get(out["currency"], out["currency"])
    return out
164
 
165
HEAD_CANDIDATES = ["description", "item", "qty", "quantity", "price", "unit", "rate", "amount", "total"]


def _canonical_column(word: str, position: int) -> str:
    """Map a header word to a canonical column name, or ``col_<position>``."""
    lowered = word.lower()
    if "desc" in lowered or lowered == "item":
        return "description"
    if lowered in ("qty", "quantity"):
        return "quantity"
    if "unit" in lowered or "rate" in lowered or "price" in lowered:
        return "unit_price"
    if "amount" in lowered or "total" in lowered:
        return "line_total"
    return f"col_{position}"


def items_from_wordgrid(df: pd.DataFrame) -> pd.DataFrame:
    """Reconstruct a line-item table from word boxes using page geometry.

    Approach:
      1. group words into lines and pick the line containing the most
         HEAD_CANDIDATES (at least two) as the table header;
      2. use the left edges of the header words as column anchors, merging
         adjacent words that name the same column (e.g. "Unit" + "Price");
      3. assign each word below the header to the nearest anchor;
      4. stop at the first totals/footer word.

    Args:
        df: Word table with ``text``, ``left``, ``top``, ``width``, ``height``
            and the ``block_num``/``par_num``/``line_num`` columns.

    Returns:
        One row per item line, with unique column names; an empty DataFrame
        when no header with at least two columns is found.
    """
    if df is None or df.empty:
        return pd.DataFrame()

    df = df.copy()
    df["text"] = df["text"].fillna("").astype(str)
    df["bottom"] = df["top"] + df["height"]
    line_keys = ["block_num", "par_num", "line_num"]

    # --- 1) header line: highest score, topmost among equals ---
    best = None
    for _, words in df.groupby(line_keys):
        line_text = " ".join(t for t in words["text"] if t.strip()).lower()
        score = sum(1 for head in HEAD_CANDIDATES if head in line_text)
        if score < 2:
            continue
        rank = (-score, words["top"].min())
        if best is None or rank < best[0]:
            best = (rank, words["top"].min(), words["bottom"].max())
    if best is None:
        return pd.DataFrame()
    _, header_top, header_bottom = best

    # --- 2) column anchors from non-blank words inside the header band ---
    has_text = df["text"].str.strip().ne("")
    in_band = (df["top"] >= header_top - 5) & (df["bottom"] <= header_bottom + 5)
    header_words = df[in_band & has_text].sort_values("left")

    col_names, col_x = [], []
    last_canonical = None
    for word, left in zip(header_words["text"], header_words["left"]):
        canonical = _canonical_column(word, len(col_names))
        if canonical == last_canonical:
            # Two-word header such as "Unit Price": keep a single column.
            continue
        last_canonical = canonical
        # Non-adjacent repeats (e.g. "Amount ... Total") still need unique labels.
        name = canonical if canonical not in col_names else f"{canonical}_{len(col_names)}"
        col_names.append(name)
        col_x.append(float(left))
    if len(col_x) < 2:
        return pd.DataFrame()
    anchors = np.asarray(col_x)

    # --- 3) item region: below the header, above the first totals word ---
    below = df[has_text & (df["top"] > header_bottom + 4)]
    # Non-capturing group: a capturing one makes pandas emit a UserWarning.
    totals_mask = below["text"].str.lower().str.contains(
        r"(?:sub\s*total|amount\s*due|total|grand\s*total|balance)", regex=True, na=False
    )
    if totals_mask.any():
        stop_y = below.loc[totals_mask, "top"].min()
        below = below[below["top"] < stop_y - 4]

    # --- 4) bucket each line's words by nearest column anchor ---
    rows = []
    for _, line in below.groupby(line_keys):
        line = line.sort_values("left")
        lefts = line["left"].to_numpy(dtype=float)
        nearest = np.abs(lefts[:, None] - anchors[None, :]).argmin(axis=1)
        cells = [[] for _ in col_names]
        for col_idx, word in zip(nearest, line["text"]):
            cells[int(col_idx)].append(word)
        rows.append([" ".join(parts) for parts in cells])
    if not rows:
        return pd.DataFrame()
    return pd.DataFrame(rows, columns=col_names)
234
+
235
+ # ========================= APP =========================
236
+ st.title("Invoice Extraction ViT recognizer (dparres) + Tesseract detector")
237
+
238
+ up = st.file_uploader("Upload an invoice (PDF/JPG/PNG)", type=["pdf","png","jpg","jpeg"])
 
 
 
 
 
 
 
 
 
 
 
 
239
  if not up:
240
+ st.info("Upload a scanned invoice to begin.")
241
  st.stop()
242
 
243
  pages = load_pages(up.read(), up.name)
 
 
244
 
245
+ # load model once
246
+ num_classes = len(alphabet) + (1 if arch=="Swin_CTC" else 0) # add CTC blank for Swin_CTC
247
+ assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
248
+ model = load_pdrt(arch, ckpt_path, num_classes)
249
+ tfm = build_transform(img_h, img_w)
250
+
251
+ page_idx = 0
252
  if len(pages) > 1:
253
+ page_idx = st.number_input("Page", 1, len(pages), 1) - 1
254
+ img = pages[page_idx]
 
 
255
 
256
+ col1, col2 = st.columns([1.1,1.3], gap="large")
257
 
258
+ with col1:
259
  st.subheader("Preview")
260
+ st.image(img, use_column_width=True)
261
+ det_img = preprocess_for_detection(img)
262
+ with st.expander("Detection view"):
263
+ st.image(det_img, use_column_width=True)
264
+
265
+ with col2:
266
+ st.subheader("OCR & Extraction")
267
+ # 1) detect words (boxes only)
268
+ det_df = detect_words(det_img, lang=det_lang)
269
+
270
+ # 2) crop & recognize each word via ViT recognizer
271
+ crops, metas = crop_words(det_img, det_df)
272
+ texts = recognize_word_crops(model, crops, tfm, arch, alphabet)
273
+
274
+ # 3) stitch line-by-line using tesseract line indices
275
+ det_df = det_df.reset_index(drop=True)
276
+ det_df["pred"] = texts
277
+ grouped = det_df.groupby(["block_num","par_num","line_num"])
278
+ lines = []
279
+ for _, g in grouped:
280
+ g = g.sort_values("left")
281
+ line = " ".join([t for t in g["pred"].tolist() if t])
282
+ lines.append(line)
283
+ full_text = "\n".join([ln for ln in lines if ln.strip()])
284
+
285
+ if show_boxes:
286
+ st.caption("First 15 predicted words")
287
+ st.write(det_df[["left","top","width","height","text","pred"]].head(15))
288
+
289
+ # 4) key fields
290
+ key_fields = parse_fields(full_text)
291
+ k1,k2,k3 = st.columns(3)
292
  with k1:
293
  st.write(f"**Invoice #:** {key_fields.get('invoice_number') or '—'}")
294
  st.write(f"**Invoice Date:** {key_fields.get('invoice_date') or '—'}")
 
301
  cur = key_fields.get('currency') or ''
302
  st.write(f"**Total:** {tot} {cur}".strip())
303
 
304
+ # 5) line items (geometry heuristic)
305
+ items = items_from_wordgrid(det_df.assign(text=det_df["pred"]))
306
+ st.markdown("**Line Items**")
307
  if items.empty:
308
+ st.caption("No line items confidently detected.")
309
  else:
310
  st.dataframe(items, use_container_width=True)
311
 
312
+ # 6) downloads
313
  result = {
314
+ "file": up.name, "page": page_idx+1,
315
  "key_fields": key_fields,
316
  "items": items.to_dict(orient="records") if not items.empty else [],
317
+ "full_text": full_text
318
  }
319
+ st.download_button("Download JSON", data=json.dumps(result, indent=2), file_name="invoice_extraction.json", mime="application/json")
 
320
  if not items.empty:
321
+ st.download_button("Download Items CSV", data=items.to_csv(index=False), file_name="invoice_items.csv", mime="text/csv")