Bhuvi13 commited on
Commit
098f047
·
verified ·
1 Parent(s): 55e21c4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +780 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,782 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ # --- Fix: ensure HOME is writable before Streamlit initializes ---
3
+ from pathlib import Path
4
+
5
+ _home = os.environ.get("HOME", "")
6
+ if _home in ("", "/", None):
7
+ # Prefer the repo working directory if writable, otherwise use /tmp
8
+ repo_dir = os.getcwd()
9
+ safe_home = repo_dir if os.access(repo_dir, os.W_OK) else "/tmp"
10
+ os.environ["HOME"] = safe_home
11
+ print(f"[startup] HOME not set or unwritable — setting HOME={safe_home}")
12
+
13
+ # Ensure the .streamlit folder exists under HOME so Streamlit won't try to write to '/'
14
+ streamlit_dir = Path(os.environ["HOME"]) / ".streamlit"
15
+ try:
16
+ streamlit_dir.mkdir(parents=True, exist_ok=True)
17
+ print(f"[startup] ensured {streamlit_dir}")
18
+ except Exception as e:
19
+ print(f"[startup] WARNING: could not create {streamlit_dir}: {e}")
20
+
21
+ import json
22
+ from io import BytesIO
23
+ from datetime import datetime
24
+ from pathlib import Path
25
+ import hashlib
26
+
27
  import streamlit as st
28
+ import pandas as pd
29
+ from PIL import Image
30
+ from huggingface_hub import login
31
+
32
+ # ---------------------------
33
+ # UI: main
34
+ # ---------------------------
35
+ st.set_page_config(page_title="Invoice Extractor (Donut)", layout="wide")
36
+ st.title("Invoice Extraction")
37
+ # Reduce top margin and tighten layout
38
+ st.markdown(
39
+ """
40
+ <style>
41
+ /* Reduce top padding of main block */
42
+ .stApp {
43
+ background-color: #E8E8E8 !important;
44
+ }
45
+ div.block-container {
46
+ padding-top: 2rem;
47
+ padding-bottom: 1rem;
48
+ }
49
+ /* Tighten title spacing */
50
+ h1 {
51
+ margin-top: 0.4rem !important;
52
+ margin-bottom: 0.4rem !important;
53
+ background-color: #E8E8E8 !important;
54
+ }
55
+ /* Reduce gap between columns */
56
+ [data-testid="column"] {
57
+ padding-top: 0.5rem;
58
+ background-color: #E8E8E8 !important;
59
+ }
60
+ </style>
61
+ """,
62
+ unsafe_allow_html=True
63
+ )
64
+
65
+ # --- Secure token handling: prefer env var or Streamlit secrets; never hardcode or save the token ---
66
+ # Safe token retrieval: prefer env var, then Streamlit secrets if available, else None
67
+ # --- Secure token handling: prefer session-state -> env var -> Streamlit secrets; never hardcode or commit token ---
68
+ from pathlib import Path
69
+
70
+ # --- Robust token retrieval (session -> env -> secrets-if-file-exists) ---
71
+ def _get_hf_token():
72
+ # 0) In-memory token from an earlier interactive login (preferred during dev)
73
+ if st.session_state.get("_hf_token"):
74
+ return st.session_state.get("_hf_token"), "session"
75
+
76
+ # 1) Environment variable (preferred for deployments)
77
+ env_tok = os.getenv("HF_TOKEN")
78
+ if env_tok:
79
+ return env_tok, "env"
80
+
81
+ # 2) Only try Streamlit secrets if a secrets.toml actually exists (avoids noisy message)
82
+ try:
83
+ project_secrets = Path(".streamlit/secrets.toml")
84
+ user_secrets = Path.home() / ".streamlit" / "secrets.toml"
85
+ if project_secrets.exists() or user_secrets.exists():
86
+ sec = st.secrets.get("HF_TOKEN")
87
+ if sec:
88
+ return sec, "secrets"
89
+ except Exception:
90
+ pass
91
+
92
+ # nothing found
93
+ return None, None
94
+
95
+ # get token and its source
96
+ hf_token, hf_token_source = _get_hf_token()
97
+
98
+
99
+ # --- Interactive login fallback (development) ---
100
+ if hf_token is None:
101
+ st.subheader("Login Token🔑")
102
+ token_input = st.text_input("Enter your Login token (starts with 'hf_'):", type="password")
103
+ if token_input:
104
+ if not token_input.startswith("hf_"):
105
+ st.error("Invalid token format. Token must start with 'hf_'.")
106
+ st.stop()
107
+ try:
108
+ login(token_input)
109
+ # store only in-memory for this session (not on disk)
110
+ st.session_state["_hf_token"] = token_input
111
+ st.session_state.logged_in = True
112
+ st.success("Logged in successfully. Loading model...")
113
+ st.rerun()
114
+ except Exception as e:
115
+ st.error(f"Failed to log in: {e}")
116
+ st.stop()
117
+ else:
118
+ st.warning("Provide a token via the UI or set HF_TOKEN as an environment variable.")
119
+ st.stop()
120
+ else:
121
+ # ensure HF client is logged-in for env/secrets/session tokens
122
+ try:
123
+ login(hf_token)
124
+ st.session_state.logged_in = True
125
+ # OPTIONAL debug: show token source (no token value)
126
+ _ = st.query_params # touch the query params (no-op) to keep UI in sync without using deprecated API
127
+ # noop, but keeps UI in sync
128
+ # st.info(f"Token source: {hf_token_source}") # un-comment for debugging
129
+ except Exception as e:
130
+ st.error(f"Failed to log in with {hf_token_source or 'unknown'} token: {e}")
131
+ st.stop()
132
+
133
+
134
+
135
+
136
+ # ---------------------------
137
+ # Configuration (edit these)
138
+ # ---------------------------
139
+ HF_MODEL_ID = "Bhuvi13/model-V7" # your HF model id
140
+ TASK_PROMPT = "<s_cord-v2>" # your decoder prompt used during training
141
+
142
+ # ---------------------------
143
+ # Helper: load model & processor (cached)
144
+ # ---------------------------
145
+ @st.cache_resource(show_spinner=False)
146
+ def load_model_and_processor(hf_model_id: str, task_prompt: str):
147
+ """
148
+ Lazily import torch/transformers/donut and load model + processor.
149
+ This prevents Streamlit's watcher from touching torch internals during import-time.
150
+ """
151
+ try:
152
+ # lazy imports
153
+ import torch
154
+ from transformers import VisionEncoderDecoderModel, DonutProcessor
155
+ except Exception as e:
156
+ raise RuntimeError(f"Failed to import ML libraries: {e}")
157
+
158
+ try:
159
+ processor = DonutProcessor.from_pretrained(hf_model_id)
160
+ model = VisionEncoderDecoderModel.from_pretrained(hf_model_id)
161
+ except Exception as e:
162
+ raise RuntimeError(
163
+ f"Failed to load model/processor from Hugging Face ({hf_model_id}). "
164
+ "Make sure your HF token is available and model id is correct.\n"
165
+ f"Original error: {e}"
166
+ )
167
+
168
+ model.eval()
169
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170
+ model.to(device)
171
+
172
+ with torch.no_grad():
173
+ decoder_input_ids = processor.tokenizer(
174
+ task_prompt,
175
+ add_special_tokens=False,
176
+ return_tensors="pt"
177
+ ).input_ids.to(device)
178
+
179
+ return processor, model, device, decoder_input_ids
180
+
181
+
182
+ def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids):
183
+ """
184
+ Lazily uses torch to run inference on a single PIL.Image.
185
+ Meant to be called after the model/processor are loaded.
186
+ """
187
+ import torch # lazy import ensures torch isn't touched at module-import time
188
+
189
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
190
+
191
+ gen_kwargs = dict(
192
+ pixel_values=pixel_values,
193
+ decoder_input_ids=decoder_input_ids,
194
+ max_length=1536,
195
+ num_beams=1,
196
+ early_stopping=False,
197
+ )
198
+
199
+ with torch.no_grad():
200
+ generated_ids = model.generate(**gen_kwargs)
201
+
202
+ raw_pred = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
203
+ cleaned = (raw_pred
204
+ .replace(processor.tokenizer.eos_token or "", "")
205
+ .replace(processor.tokenizer.pad_token or "", "")
206
+ .strip())
207
+ token2json_out = processor.token2json(cleaned)
208
+
209
+ if isinstance(token2json_out, str):
210
+ try:
211
+ pred_dict = json.loads(token2json_out)
212
+ except Exception:
213
+ pred_dict = token2json_out
214
+ else:
215
+ pred_dict = token2json_out
216
+
217
+ return pred_dict
218
+
219
+
220
+ # ---------------------------
221
+ # Helper: map donut output to our UI schema
222
+ # (kept unchanged from your original)
223
+ # ---------------------------
224
+ def map_prediction_to_ui(pred):
225
+ import json, re
226
+ from datetime import datetime
227
+
228
+ def safe_json_load(s):
229
+ if s is None:
230
+ return None
231
+ if isinstance(s, (dict, list)):
232
+ return s
233
+ if isinstance(s, str):
234
+ try:
235
+ return json.loads(s)
236
+ except Exception:
237
+ try:
238
+ t = s.strip()
239
+ t = t.replace("\\'", "'").replace('\"{', '{').replace('}\"', '}')
240
+ return json.loads(t)
241
+ except Exception:
242
+ return None
243
+ return None
244
+
245
+ def clean_number(x):
246
+ if x is None:
247
+ return 0.0
248
+ if isinstance(x, (int, float)):
249
+ return float(x)
250
+ s = str(x).strip()
251
+ if s == "":
252
+ return 0.0
253
+ s = re.sub(r"[,\s]", "", s)
254
+ s = re.sub(r"[^\d\.\-]", "", s)
255
+ if s in ("", ".", "-", "-."):
256
+ return 0.0
257
+ try:
258
+ return float(s)
259
+ except Exception:
260
+ return 0.0
261
+
262
+ def parse_date(s):
263
+ if not s:
264
+ return ""
265
+ s = str(s).strip()
266
+ for fmt in ("%Y-%m-%d", "%d-%m-%Y", "%d/%m/%Y", "%m/%d/%Y", "%d.%m.%Y"):
267
+ try:
268
+ return datetime.strptime(s, fmt).strftime("%Y-%m-%d")
269
+ except Exception:
270
+ pass
271
+ m = re.match(r"^(\d{1,2})/(\d{1,2})/(\d{4})$", s)
272
+ if m:
273
+ a, b, y = int(m.group(1)), int(m.group(2)), int(m.group(3))
274
+ if a > 12:
275
+ d, mo = a, b
276
+ else:
277
+ mo, d = a, b
278
+ try:
279
+ return datetime(year=y, month=mo, day=d).strftime("%Y-%m-%d")
280
+ except Exception:
281
+ return s
282
+ return s
283
+
284
+ ui = {
285
+ "Invoice Number": "",
286
+ "Invoice Date": "",
287
+ "Due Date": "",
288
+ "Currency": "",
289
+ "Subtotal": 0.0,
290
+ "Tax Percentage": 0.0,
291
+ "Total Tax": 0.0,
292
+ "Total Amount": 0.0,
293
+ "Sender": {"Name": "", "Address": ""},
294
+ "Recipient": {"Name": "", "Address": ""},
295
+ "Sender Name": "",
296
+ "Sender Address": "",
297
+ "Recipient Name": "",
298
+ "Recipient Address": "",
299
+ "Bank Details": {},
300
+ "Itemized Data": []
301
+ }
302
+
303
+ if pred is None:
304
+ return ui
305
+
306
+ if isinstance(pred, str):
307
+ parsed = safe_json_load(pred)
308
+ if parsed is not None:
309
+ pred = parsed
310
+
311
+ gt = None
312
+ if isinstance(pred, dict):
313
+ if "gt_parse" in pred:
314
+ gp = pred["gt_parse"]
315
+ gp_parsed = safe_json_load(gp)
316
+ gt = gp_parsed if gp_parsed is not None else (gp if isinstance(gp, dict) else {})
317
+ else:
318
+ gt = pred
319
+ else:
320
+ return ui
321
+
322
+ header = gt.get("header") or {}
323
+ items = gt.get("items") or []
324
+ summary = gt.get("summary") or {}
325
+
326
+ ui["Invoice Number"] = header.get("invoice_no") or header.get("invoice_number") or ui["Invoice Number"]
327
+ ui["Invoice Date"] = str(header.get("invoice_date") or header.get("inv_date") or "")
328
+ ui["Due Date"] = str(header.get("due_date") or header.get("due") or "")
329
+
330
+ ui["Sender Name"] = header.get("sender_name") or header.get("seller_name") or header.get("from_name") or ui["Sender Name"]
331
+ ui["Sender Address"] = header.get("sender_addr") or header.get("sender_address") or header.get("seller_addr") or ui["Sender Address"]
332
+ ui["Recipient Name"] = header.get("rcpt_name") or header.get("recipient_name") or header.get("to_name") or ui["Recipient Name"]
333
+ ui["Recipient Address"] = header.get("rcpt_addr") or header.get("rcpt_address") or header.get("recipient_address") or ui["Recipient Address"]
334
+
335
+ ui["Sender"] = {"Name": ui["Sender Name"], "Address": ui["Sender Address"]}
336
+ ui["Recipient"] = {"Name": ui["Recipient Name"], "Address": ui["Recipient Address"]}
337
+
338
+ bank = {}
339
+ if header.get("bank_name"):
340
+ bank["bank_name"] = str(header.get("bank_name")).strip()
341
+ if header.get("bank_acc_no"):
342
+ bank["bank_account_number"] = str(header.get("bank_acc_no")).strip()
343
+ if header.get("bank_account_number"):
344
+ bank["bank_account_number"] = bank.get("bank_account_number") or str(header.get("bank_account_number")).strip()
345
+ if header.get("bank_iban"):
346
+ bank["bank_iban"] = str(header.get("bank_iban")).strip()
347
+ if header.get("bank_routing"):
348
+ bank["bank_routing"] = str(header.get("bank_routing")).strip()
349
+ if header.get("bank_swift"):
350
+ bank["bank_swift"] = str(header.get("bank_swift")).strip()
351
+ if header.get("bank_branch"):
352
+ bank["bank_branch"] = str(header.get("bank_branch")).strip()
353
+ if header.get("bank_acc_name"):
354
+ bank["bank_acc_name"] = str(header.get("bank_acc_name")).strip()
355
+ hb = header.get("bank")
356
+ if isinstance(hb, dict):
357
+ for k, v in hb.items():
358
+ if not v:
359
+ continue
360
+ lk = k.lower()
361
+ if "iban" in lk:
362
+ bank["bank_iban"] = bank.get("bank_iban") or str(v).strip()
363
+ elif "swift" in lk:
364
+ bank["bank_swift"] = bank.get("bank_swift") or str(v).strip()
365
+ elif "acc" in lk or "account" in lk:
366
+ bank["bank_account_number"] = bank.get("bank_account_number") or str(v).strip()
367
+ elif "name" in lk and "bank" in lk:
368
+ bank["bank_name"] = bank.get("bank_name") or str(v).strip()
369
+ elif "branch" in lk:
370
+ bank["bank_branch"] = bank.get("bank_branch") or str(v).strip()
371
+ elif "acc_name" in lk or "account_name" in lk:
372
+ bank["bank_acc_name"] = bank.get("bank_acc_name") or str(v).strip()
373
+
374
+ ui["Bank Details"] = bank
375
+
376
+ ui["Subtotal"] = clean_number(summary.get("subtotal") or summary.get("sub_total") or summary.get("subTotal"))
377
+ ui["Tax Percentage"] = clean_number(summary.get("tax_rate") or summary.get("taxRate") or summary.get("tax_percentage"))
378
+ ui["Total Tax"] = clean_number(summary.get("tax_amount") or summary.get("tax") or summary.get("taxAmount"))
379
+ ui["Total Amount"] = clean_number(summary.get("total_amount") or summary.get("grand_total") or summary.get("total") or summary.get("amount_total"))
380
+ ui["Currency"] = summary.get("currency") or header.get("currency") or ui["Currency"] or ""
381
+
382
+ normalized_items = []
383
+
384
+ if isinstance(items, str):
385
+ parsed_items = safe_json_load(items)
386
+ if parsed_items is not None:
387
+ items = parsed_items
388
+
389
+ if isinstance(items, dict):
390
+ if any(isinstance(v, list) for v in items.values()):
391
+ list_cols = {k: v for k, v in items.items() if isinstance(v, list)}
392
+ max_len = max((len(v) for v in list_cols.values()), default=0)
393
+ for i in range(max_len):
394
+ row = {}
395
+ for k, v in items.items():
396
+ if isinstance(v, list):
397
+ row[k] = v[i] if i < len(v) else ""
398
+ else:
399
+ row[k] = v
400
+ normalized_items.append(row)
401
+ else:
402
+ normalized_items.append(items)
403
+ elif isinstance(items, list):
404
+ normalized_items = items
405
+ else:
406
+ normalized_items = []
407
+
408
+ item_rows = []
409
+ for it in normalized_items:
410
+ if not isinstance(it, dict):
411
+ item_rows.append({"Description": str(it), "Quantity": 1, "Unit Price": 0.0, "Amount": 0.0})
412
+ continue
413
+ desc = it.get("descriptions") or it.get("description") or it.get("desc") or it.get("item") or it.get("name") or ""
414
+ qty = it.get("quantity") or it.get("qty") or it.get("Quantity") or ""
415
+ unit = it.get("unit_price") or it.get("unitPrice") or it.get("price") or ""
416
+ amt = it.get("amount") or it.get("Line_total") or it.get("line_total") or it.get("total") or ""
417
+
418
+ item_rows.append({
419
+ "Description": str(desc).strip(),
420
+ "Quantity": float(clean_number(qty)),
421
+ "Unit Price": float(clean_number(unit)),
422
+ "Amount": float(clean_number(amt))
423
+ })
424
+
425
+ ui["Itemized Data"] = item_rows
426
+
427
+ return ui
428
+
429
+
430
+ # show model load status and try to load model lazily
431
+ try:
432
+ with st.spinner("Loading model & processor (cached) ..."):
433
+ processor, model, device, decoder_input_ids = load_model_and_processor(HF_MODEL_ID, TASK_PROMPT)
434
+ #st.success("Model loaded (cached).")
435
+ except Exception as e:
436
+ st.error("Could not load model automatically. See details below.")
437
+ st.exception(e)
438
+ st.stop()
439
+
440
+ # initialize session state variables
441
+ if "extracted_data" not in st.session_state:
442
+ st.session_state.extracted_data = None
443
+ st.session_state.raw_prediction = None
444
+ if "uploaded_file_hash" not in st.session_state:
445
+ st.session_state.uploaded_file_hash = None
446
+ if "show_results" not in st.session_state:
447
+ st.session_state.show_results = False
448
+ if "last_image" not in st.session_state:
449
+ st.session_state.last_image = None
450
+ if "is_running_inference" not in st.session_state:
451
+ st.session_state.is_running_inference = False
452
+
453
+ # ---------------------------
454
+ # SHOW UPLOAD UI ONLY IF NOT RUNNING INFERENCE AND NOT IN RESULTS
455
+ # ---------------------------
456
+ if (not st.session_state.show_results and
457
+ not st.session_state.is_running_inference and
458
+ st.session_state.uploaded_file_hash is None):
459
+ st.markdown(
460
+ """
461
+ Upload an invoice image (png/jpg/jpeg). The app will run your Donut model and map detected fields into
462
+ an editable UI. After editing you can download the extracted JSONL / CSV.
463
+ """
464
+ )
465
+
466
+ st.header("📤 Upload Invoice")
467
+
468
+ uploaded_file = st.file_uploader("Upload invoice image (png/jpg/jpeg/pdf)", type=["png", "jpg", "jpeg", "pdf"], accept_multiple_files=False)
469
+
470
+ # allow user to optionally paste a local path or sample file (for debug)
471
+ col_top_1, col_top_2 = st.columns([1, 3])
472
+ #with col_top_1:
473
+ #if st.button("Use example image (if available)"):
474
+ #st.info("No example included. Please upload an image.")
475
+ with col_top_2:
476
+ st.write(" ")
477
+
478
+ if uploaded_file is not None:
479
+ # Read bytes and compute hash
480
+ uploaded_bytes = uploaded_file.read()
481
+ file_hash = hashlib.sha256(uploaded_bytes).hexdigest()
482
+
483
+ # Render image or first PDF page
484
+ image = None
485
+ is_pdf = uploaded_file.name.lower().endswith('.pdf') or (hasattr(uploaded_file, 'type') and uploaded_file.type == 'application/pdf')
486
+ if is_pdf:
487
+ try:
488
+ from pdf2image import convert_from_bytes
489
+ pages = convert_from_bytes(uploaded_bytes, dpi=200)
490
+ if len(pages) > 0:
491
+ image = pages[0].convert("RGB")
492
+ st.session_state.last_image = image
493
+ else:
494
+ st.error("PDF has no pages.")
495
+ image = None
496
+ except Exception as e:
497
+ st.error("Could not render PDF. Ensure 'pdf2image' and poppler are installed.")
498
+ image = None
499
+ else:
500
+ try:
501
+ image = Image.open(BytesIO(uploaded_bytes)).convert("RGB")
502
+ st.session_state.last_image = image
503
+ except Exception as e:
504
+ st.error("Failed to open uploaded image.")
505
+ image = None
506
+
507
+ if image is not None:
508
+ # ✅ SET FLAG TO HIDE UPLOAD UI
509
+ st.session_state.is_running_inference = True
510
+
511
+ # ✅ RENDER THE SAME LAYOUT AS RESULTS PAGE — RIGHT COLUMN = LOADING TABS
512
+ left_col, right_col = st.columns([1, 1])
513
+
514
+ with left_col:
515
+ st.image(image, caption="Uploaded Invoice", use_container_width=True)
516
+ st.write(f"**File Hash:** {file_hash[:8]}...")
517
+
518
+ with right_col:
519
+ #st.subheader("📄 Extracted Invoice Details")
520
+ #st.caption(f"File Hash: {file_hash[:8]}... | Model: {HF_MODEL_ID}")
521
+
522
+ # Show identical tab structure during loading
523
+ placeholder_tabs = st.tabs([
524
+ "Invoice Details",
525
+ "Sender/Recipient info",
526
+ "Bank Details",
527
+ "Line Items"
528
+ ])
529
+
530
+ #for tab in placeholder_tabs:
531
+ #with tab:
532
+ #st.info("⏳ Extracting Invoice Details... Please wait.")
533
+ # Optional: show spinner inside each tab
534
+ #st.spinner("Processing...")
535
+
536
+ # ACTUALLY RUN INFERENCE
537
+ with st.spinner("⏳ Extracting Invoice Details... Please wait."):
538
+ try:
539
+ pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
540
+ except Exception as e:
541
+ st.session_state.inference_error = str(e)
542
+ pred = None
543
+
544
+ # Store results
545
+ st.session_state.uploaded_file_hash = file_hash
546
+ st.session_state.raw_prediction = pred
547
+
548
+ try:
549
+ mapped = map_prediction_to_ui(pred)
550
+ except Exception as e:
551
+ st.session_state.mapping_error = str(e)
552
+ mapped = {}
553
+
554
+ st.session_state.extracted_data = mapped
555
+ st.session_state.show_results = True
556
+ st.session_state.is_running_inference = False # 👈 RESET FLAG
557
+
558
+ st.success("✅ Extraction complete!")
559
+
560
+ # Rerun to show real editable form
561
+ st.rerun()
562
+ else:
563
+ st.error("Could not process uploaded file into an image.")
564
+
565
+ # If inference is running (e.g., after rerun or error), show only the layout
566
+ # ---------------------------
567
+ # INFERENCE IN PROGRESS — Show only left/right columns with loading UI
568
+ # ---------------------------
569
+ if not st.session_state.show_results and st.session_state.is_running_inference:
570
+ if st.session_state.last_image is not None:
571
+ left_col, right_col = st.columns([1, 1])
572
+ with left_col:
573
+ st.image(st.session_state.last_image, caption="Uploaded Invoice", use_container_width=True)
574
+ if st.session_state.uploaded_file_hash:
575
+ st.write(f"**File Hash:** {st.session_state.uploaded_file_hash[:8]}...")
576
+ with right_col:
577
+ #st.subheader("📄 Extracted Invoice Details")
578
+ #st.caption(f"File Hash: {st.session_state.uploaded_file_hash[:8]}... | Model: {HF_MODEL_ID}")
579
+ placeholder_tabs = st.tabs([
580
+ "Invoice Details",
581
+ "Sender/Recipient info",
582
+ "Bank Details",
583
+ "Line Items"
584
+ ])
585
+ for tab in placeholder_tabs:
586
+ with tab:
587
+ st.info("⏳ Still processing... Please wait.")
588
+ else:
589
+ st.warning("Inference in progress, but no image available. Please re-upload.")
590
+
591
+
592
+ # ---------------------------
593
+ # RESULTS READY — Show editable form + back button
594
+ # ---------------------------
595
+ elif st.session_state.show_results:
596
+ # ✅ Back Button — ONLY shown when results are ready
597
+ if st.button("⬅️ Back to Upload"):
598
+ st.session_state.show_results = False
599
+ st.session_state.extracted_data = None
600
+ st.session_state.raw_prediction = None
601
+ st.session_state.uploaded_file_hash = None
602
+ st.session_state.last_image = None
603
+ st.session_state.is_running_inference = False # 👈 Also reset this
604
+ st.rerun()
605
+
606
+ # Layout: two columns, image on left, form on right
607
+ left_col, right_col = st.columns([1, 1])
608
+
609
+ # LEFT: Show image
610
+ with left_col:
611
+ if st.session_state.last_image is not None:
612
+ st.image(st.session_state.last_image, caption="Uploaded Invoice", use_container_width=True)
613
+ st.write(f"**File Hash:** {st.session_state.uploaded_file_hash[:8]}...")
614
+
615
+ # 👇 RAW MODEL OUTPUT NOW APPEARS HERE, BELOW IMAGE
616
+ if st.session_state.get('raw_prediction') is not None:
617
+ with st.expander("🔍 Show raw model output"):
618
+ st.json(st.session_state.raw_prediction)
619
+
620
+ else:
621
+ st.warning("Image preview not available. Please re-upload.")
622
+
623
+ # provide a way to force re-run if needed
624
+ #if st.button("Re-run extraction for this file"):
625
+ #st.session_state.raw_prediction = None
626
+ #st.session_state.extracted_data = None
627
+ #st.rerun()
628
+
629
+ # RIGHT: Editable form
630
+ with right_col:
631
+ data = st.session_state.extracted_data
632
+
633
+ if data is None:
634
+ st.error("No data extracted. Something went wrong.")
635
+ else:
636
+ st.subheader("Editable Invoice Form")
637
+ tabs = st.tabs(["Invoice Details", "Sender/Recipient info", "Bank Details", "Line Items"])
638
+
639
+ st.markdown(
640
+ """
641
+ <style>
642
+ div[data-testid="stTabs"] > div > div {
643
+ padding-bottom: 5px !important;
644
+ margin-top: -5px !important;
645
+ background-color: #E8E8E8 !important;
646
+ }
647
+ .stTextInput, .stNumberInput, .stSelectbox, .stTextArea, .stDateInput {
648
+ margin-bottom: -10px !important;
649
+ padding-bottom: 5px !important;
650
+ }
651
+ div[data-testid="stTabs"] {
652
+ background-color: #E8E8E8 !important;
653
+ }
654
+ h3:first-of-type {
655
+ margin-top: -50px !important;
656
+ }
657
+ </style>
658
+ """,
659
+ unsafe_allow_html=True,
660
+ )
661
+
662
+ # ---------- Invoice Details ----------
663
+ with tabs[0]:
664
+ with st.container():
665
+ data['Invoice Number'] = st.text_input("Invoice Number", value=data.get('Invoice Number', ''), key="invoice_number")
666
+
667
+ # Invoice Date with calendar and callback
668
+ data['Invoice Date'] = st.text_input(
669
+ "Invoice Date",
670
+ value=str(data.get('Invoice Date', '')).strip(),
671
+ key="invoice_date_text"
672
+ )
673
+
674
+ # Due Date — preserve original format
675
+ data['Due Date'] = st.text_input(
676
+ "Due Date",
677
+ value=str(data.get('Due Date', '')).strip(),
678
+ key="due_date_text"
679
+ )
680
+
681
+ curr_options = ['USD', 'EUR', 'GBP', 'INR', 'Other']
682
+ curr_value = data.get('Currency', 'USD')
683
+ curr_index = curr_options.index(curr_value) if curr_value in curr_options else (len(curr_options) - 1)
684
+ new_curr = st.selectbox("Currency", options=curr_options, index=curr_index, key="currency_select")
685
+ if new_curr == 'Other':
686
+ new_curr = st.text_input("Specify Currency", value=data.get('Currency', ''), key="custom_currency")
687
+ data['Currency'] = new_curr
688
+
689
+ # numeric fields - safe conversion
690
+ def safe_number_input(label, value, key):
691
+ try:
692
+ v = float(value)
693
+ except Exception:
694
+ v = 0.0
695
+ return st.number_input(label, value=v, key=key)
696
+
697
+ data['Subtotal'] = safe_number_input("Subtotal", data.get('Subtotal', 0.0), "subtotal")
698
+ data['Tax Percentage'] = safe_number_input("Tax Percentage", data.get('Tax Percentage', 0.0), "tax_pct")
699
+ data['Total Tax'] = safe_number_input("Total Tax", data.get('Total Tax', 0.0), "total_tax")
700
+ data['Total Amount'] = safe_number_input("Total Amount", data.get('Total Amount', 0.0), "total_amount")
701
+
702
+ # ---------- Sender / Recipient ----------
703
+ with tabs[1]:
704
+ if 'Sender' not in data:
705
+ data['Sender'] = {'Name': '', 'Address': ''}
706
+ if 'Recipient' not in data:
707
+ data['Recipient'] = {'Name': '', 'Address': ''}
708
+ sender_info = data['Sender']
709
+ recipient_info = data['Recipient']
710
+
711
+ with st.container():
712
+ sender_info['Name'] = st.text_input("Sender Name*", value=sender_info.get('Name', ''), key="sender_name")
713
+ sender_info['Address'] = st.text_area("Sender Address*", value=sender_info.get('Address', ''), key="sender_address")
714
+
715
+ recipient_info['Name'] = st.text_input("Recipient Name*", value=recipient_info.get('Name', ''), key="recipient_name")
716
+ recipient_info['Address'] = st.text_area("Recipient Address*", value=recipient_info.get('Address', ''), key="recipient_address")
717
+
718
+ if st.button("⇄ Swap", help="Swap sender and recipient information"):
719
+ data['Sender'], data['Recipient'] = data['Recipient'], data['Sender']
720
+ st.session_state.extracted_data['Sender'] = data['Sender']
721
+ st.session_state.extracted_data['Recipient'] = data['Recipient']
722
+ st.rerun()
723
+
724
+ # ---------- Bank Details ----------
725
+ with tabs[2]:
726
+ bank_info = data.get('Bank Details', {}) or {}
727
+ with st.container():
728
+ bank_info['bank_name'] = st.text_input("Bank Name", value=bank_info.get('bank_name', ''), key="bank_name")
729
+ bank_info['bank_account_number'] = st.text_input("Account Number", value=bank_info.get('bank_account_number', '') or bank_info.get('bank_acc_no',''), key="bank_account")
730
+ bank_info['bank_acc_name'] = st.text_input("Bank Account Name", value=bank_info.get('bank_acc_name', '') or bank_info.get('bank_acc_name', ''), key="bank_acc_name")
731
+ bank_info['bank_iban'] = st.text_input("IBAN", value=bank_info.get('bank_iban', ''), key="iban")
732
+ bank_info['bank_swift'] = st.text_input("SWIFT Code", value=bank_info.get('bank_swift', ''), key="swift_code")
733
+ bank_info['bank_routing'] = st.text_input("Routing Number", value=bank_info.get('bank_routing', ''), key="routing")
734
+ bank_info['bank_branch'] = st.text_input("Branch", value=bank_info.get('bank_branch', ''), key="branch")
735
+ data['Bank Details'] = bank_info
736
+
737
+ # ---------- Line Items ----------
738
+ with tabs[3]:
739
+ file_hash = st.session_state.get("uploaded_file_hash", "")
740
+ editor_key = f"item_editor_{file_hash}"
741
+
742
+ if "extracted_data" in st.session_state and "Itemized Data" in st.session_state.extracted_data:
743
+ item_rows = st.session_state.extracted_data["Itemized Data"]
744
+ else:
745
+ item_rows = data.get('Itemized Data') or []
746
+
747
+ df = pd.DataFrame(item_rows)
748
+ for col in ["Description", "Quantity", "Unit Price", "Amount"]:
749
+ if col not in df.columns:
750
+ df[col] = ""
751
+
752
+ edited_df = st.data_editor(df, num_rows="dynamic", key=editor_key, use_container_width=True)
753
+
754
+ if len(edited_df) == 0:
755
+ st.info("No line items found in the invoice.")
756
+
757
+ # ---------- Save / Export ----------
758
+ st.markdown("---")
759
+ col_a, col_b, col_c = st.columns([1, 1, 1])
760
+ with col_a:
761
+ if st.button("Save to session"):
762
+ st.session_state.extracted_data = data
763
+ st.success("Saved to session_state.extracted_data")
764
+
765
+
766
+ # ---------------------------
767
+ # DEFAULT STATE — Show upload UI
768
+ # ---------------------------
769
+ else:
770
+ # This is the initial state: nothing running, no results
771
+ ## Upload UI is already rendered above — nothing more needed here
772
+ pass
773
+ #with col_b:
774
+ #jsonl_str = json.dumps(data, ensure_ascii=False)
775
+ #st.download_button("Download JSONL", jsonl_str.encode("utf-8"), file_name="extracted_invoice.jsonl", mime="application/json")
776
+ #with col_c:
777
+ #items_df = pd.DataFrame(data.get("Itemized Data", []))
778
+ #csv_bytes = items_df.to_csv(index=False).encode("utf-8")
779
+ #st.download_button("Download line-items CSV", csv_bytes, file_name="invoice_items.csv", mime="text/csv")
780
 
781
+ #with st.expander("Preview mapped data (for quick check)"):
782
+ #st.json(data)