Bhuvi13 committed on
Commit
1357bbc
·
verified ·
1 Parent(s): 47c5244

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +95 -87
src/streamlit_app.py CHANGED
@@ -196,7 +196,6 @@ def run_inference_on_image(image: Image.Image, processor, model, device, decoder
196
 
197
  # ---------------------------
198
  # Helper: map donut output to our UI schema
199
- # (kept unchanged from your original)
200
  # ---------------------------
201
  def map_prediction_to_ui(pred):
202
  import json, re
@@ -218,77 +217,6 @@ def map_prediction_to_ui(pred):
218
  except Exception:
219
  return None
220
  return None
221
- def flatten_invoice_to_rows(invoice_data) -> list:
222
- """
223
- Converts nested invoice data into a flat list of rows (one per line item),
224
- with invoice-level and sender/recipient/bank fields repeated in each row.
225
- """
226
- rows = []
227
- line_items = invoice_data.get("Itemized Data", [])
228
- if not line_items:
229
- # If no line items, create one row with invoice info only
230
- row = {
231
- "Invoice Number": invoice_data.get("Invoice Number", ""),
232
- "Invoice Date": invoice_data.get("Invoice Date", ""),
233
- "Due Date": invoice_data.get("Due Date", ""),
234
- "Currency": invoice_data.get("Currency", ""),
235
- "Subtotal": invoice_data.get("Subtotal", 0.0),
236
- "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
237
- "Total Tax": invoice_data.get("Total Tax", 0.0),
238
- "Total Amount": invoice_data.get("Total Amount", 0.0),
239
- "Sender Name": invoice_data.get("Sender", {}).get("Name", ""),
240
- "Sender Address": invoice_data.get("Sender", {}).get("Address", ""),
241
- "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
242
- "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
243
- }
244
- # Flatten bank details
245
- bank = invoice_data.get("Bank Details", {})
246
- for k, v in bank.items():
247
- row[f"bank_{k}"] = v
248
-
249
- # Add empty line item fields
250
- row.update({
251
- "Item Description": "",
252
- "Item Quantity": 0,
253
- "Item Unit Price": 0.0,
254
- "Item Amount": 0.0,
255
- })
256
- rows.append(row)
257
- return rows
258
-
259
- # For each line item, create a row with all invoice context
260
- for item in line_items:
261
- row = {
262
- "Invoice Number": invoice_data.get("Invoice Number", ""),
263
- "Invoice Date": invoice_data.get("Invoice Date", ""),
264
- "Due Date": invoice_data.get("Due Date", ""),
265
- "Currency": invoice_data.get("Currency", ""),
266
- "Subtotal": invoice_data.get("Subtotal", 0.0),
267
- "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
268
- "Total Tax": invoice_data.get("Total Tax", 0.0),
269
- "Total Amount": invoice_data.get("Total Amount", 0.0),
270
- "Sender Name": invoice_data.get("Sender", {}).get("Name", ""),
271
- "Sender Address": invoice_data.get("Sender", {}).get("Address", ""),
272
- "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
273
- "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
274
- }
275
-
276
- # Flatten bank details
277
- bank = invoice_data.get("Bank Details", {})
278
- for k, v in bank.items():
279
- row[f"bank_{k}"] = v
280
-
281
- # Add line item fields
282
- row.update({
283
- "Item Description": item.get("Description", ""),
284
- "Item Quantity": item.get("Quantity", 0),
285
- "Item Unit Price": item.get("Unit Price", 0.0),
286
- "Item Amount": item.get("Amount", 0.0),
287
- })
288
-
289
- rows.append(row)
290
-
291
- return rows
292
 
293
  def clean_number(x):
294
  if x is None:
@@ -474,6 +402,81 @@ def flatten_invoice_to_rows(invoice_data) -> list:
474
 
475
  return ui
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  # Load model once
478
  try:
479
  with st.spinner("Loading model & processor (cached) ..."):
@@ -553,7 +556,11 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
553
  mapped = map_prediction_to_ui(pred)
554
  except Exception as e:
555
  st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
556
- pred, mapped = None, {}
 
 
 
 
557
 
558
  # Save to session state
559
  st.session_state.batch_results[file_hash] = {
@@ -561,7 +568,7 @@ if not st.session_state.is_processing_batch and len(st.session_state.batch_resul
561
  "image": image,
562
  "raw_pred": pred,
563
  "mapped_data": mapped,
564
- "edited_data": mapped.copy() # editable copy
565
  }
566
 
567
  progress_bar.progress((idx + 1) / len(uploaded_files))
@@ -711,6 +718,9 @@ elif len(st.session_state.batch_results) > 0:
711
  data['Bank Details'] = bank_info
712
 
713
  # ---------- Line Items ----------
 
 
 
714
  with tabs[3]:
715
  editor_key = f"item_editor_{selected_hash}"
716
 
@@ -720,29 +730,27 @@ elif len(st.session_state.batch_results) > 0:
720
  if col not in df.columns:
721
  df[col] = ""
722
 
723
- # FIXED: Use on_change to force immediate rerun edited_df is always fresh
 
724
  edited_df = st.data_editor(
725
  df,
726
  num_rows="dynamic",
727
  key=editor_key,
728
  use_container_width=True,
729
- on_change=trigger_rerun # 👈 This is the magic fix
730
  )
731
 
732
- # ✅ Safe to use — edited_df is updated after forced rerun
733
- data['Itemized Data'] = edited_df.to_dict('records')
734
-
735
- # ❗ OPTIONAL: Auto-calculate Amount = Quantity × Unit Price
736
- # Uncomment below if you want auto-calculation:
737
- # if "Quantity" in edited_df.columns and "Unit Price" in edited_df.columns:
738
- # edited_df["Amount"] = (edited_df["Quantity"] * edited_df["Unit Price"]).round(2)
739
- # data['Itemized Data'] = edited_df.to_dict('records')
740
-
741
  if len(edited_df) == 0:
742
  st.info("No line items found in the invoice.")
743
 
744
- # Save button (per file) — OPTIONAL, since edits auto-save via reference
 
 
745
  if st.button("💾 Save Edits for This File", key=f"save_{selected_hash}"):
 
 
 
 
746
  st.session_state.batch_results[selected_hash]["edited_data"] = data
747
  st.success(f"✅ Edits saved for {current['file_name']}")
748
 
@@ -798,7 +806,7 @@ elif len(st.session_state.batch_results) > 0:
798
  json_name = f"{Path(result['file_name']).stem}_extracted.json"
799
  zf.writestr(json_name, json_data)
800
 
801
- # Save FULL CSV (all data)
802
  rows = flatten_invoice_to_rows(result["edited_data"])
803
  full_df = pd.DataFrame(rows)
804
 
 
196
 
197
  # ---------------------------
198
  # Helper: map donut output to our UI schema
 
199
  # ---------------------------
200
  def map_prediction_to_ui(pred):
201
  import json, re
 
217
  except Exception:
218
  return None
219
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  def clean_number(x):
222
  if x is None:
 
402
 
403
  return ui
404
 
405
+ # ---------------------------
406
+ # Helper: flatten invoice to CSV rows
407
+ # ---------------------------
408
+ def flatten_invoice_to_rows(invoice_data) -> list:
409
+ """
410
+ Converts nested invoice data into a flat list of rows (one per line item),
411
+ with invoice-level and sender/recipient/bank fields repeated in each row.
412
+ """
413
+ rows = []
414
+ line_items = invoice_data.get("Itemized Data", [])
415
+ if not line_items:
416
+ # If no line items, create one row with invoice info only
417
+ row = {
418
+ "Invoice Number": invoice_data.get("Invoice Number", ""),
419
+ "Invoice Date": invoice_data.get("Invoice Date", ""),
420
+ "Due Date": invoice_data.get("Due Date", ""),
421
+ "Currency": invoice_data.get("Currency", ""),
422
+ "Subtotal": invoice_data.get("Subtotal", 0.0),
423
+ "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
424
+ "Total Tax": invoice_data.get("Total Tax", 0.0),
425
+ "Total Amount": invoice_data.get("Total Amount", 0.0),
426
+ "Sender Name": invoice_data.get("Sender", {}).get("Name", ""),
427
+ "Sender Address": invoice_data.get("Sender", {}).get("Address", ""),
428
+ "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
429
+ "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
430
+ }
431
+ # Flatten bank details
432
+ bank = invoice_data.get("Bank Details", {})
433
+ for k, v in bank.items():
434
+ row[f"bank_{k}"] = v
435
+
436
+ # Add empty line item fields
437
+ row.update({
438
+ "Item Description": "",
439
+ "Item Quantity": 0,
440
+ "Item Unit Price": 0.0,
441
+ "Item Amount": 0.0,
442
+ })
443
+ rows.append(row)
444
+ return rows
445
+
446
+ # For each line item, create a row with all invoice context
447
+ for item in line_items:
448
+ row = {
449
+ "Invoice Number": invoice_data.get("Invoice Number", ""),
450
+ "Invoice Date": invoice_data.get("Invoice Date", ""),
451
+ "Due Date": invoice_data.get("Due Date", ""),
452
+ "Currency": invoice_data.get("Currency", ""),
453
+ "Subtotal": invoice_data.get("Subtotal", 0.0),
454
+ "Tax Percentage": invoice_data.get("Tax Percentage", 0.0),
455
+ "Total Tax": invoice_data.get("Total Tax", 0.0),
456
+ "Total Amount": invoice_data.get("Total Amount", 0.0),
457
+ "Sender Name": invoice_data.get("Sender", {}).get("Name", ""),
458
+ "Sender Address": invoice_data.get("Sender", {}).get("Address", ""),
459
+ "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
460
+ "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
461
+ }
462
+
463
+ # Flatten bank details
464
+ bank = invoice_data.get("Bank Details", {})
465
+ for k, v in bank.items():
466
+ row[f"bank_{k}"] = v
467
+
468
+ # Add line item fields
469
+ row.update({
470
+ "Item Description": item.get("Description", ""),
471
+ "Item Quantity": item.get("Quantity", 0),
472
+ "Item Unit Price": item.get("Unit Price", 0.0),
473
+ "Item Amount": item.get("Amount", 0.0),
474
+ })
475
+
476
+ rows.append(row)
477
+
478
+ return rows
479
+
480
  # Load model once
481
  try:
482
  with st.spinner("Loading model & processor (cached) ..."):
 
556
  mapped = map_prediction_to_ui(pred)
557
  except Exception as e:
558
  st.warning(f"Error processing {uploaded_file.name}: {str(e)}")
559
+ pred = None
560
+ mapped = {} # 👈 Ensure mapped is always a dict
561
+
562
+ # ✅ SAFETY: Ensure mapped is a dict before copying
563
+ safe_mapped = mapped if isinstance(mapped, dict) else {}
564
 
565
  # Save to session state
566
  st.session_state.batch_results[file_hash] = {
 
568
  "image": image,
569
  "raw_pred": pred,
570
  "mapped_data": mapped,
571
+ "edited_data": safe_mapped.copy() # editable copy — now safe
572
  }
573
 
574
  progress_bar.progress((idx + 1) / len(uploaded_files))
 
718
  data['Bank Details'] = bank_info
719
 
720
  # ---------- Line Items ----------
721
+ # ---------- Line Items ----------
722
+ # ---------- Line Items ----------
723
+ # ---------- Line Items ----------
724
  with tabs[3]:
725
  editor_key = f"item_editor_{selected_hash}"
726
 
 
730
  if col not in df.columns:
731
  df[col] = ""
732
 
733
+ st.write("✏️ Edit line items below. Press Enter or click outside a cell to confirm each edit.")
734
+ # Get the edited DataFrame directly from data_editor
735
  edited_df = st.data_editor(
736
  df,
737
  num_rows="dynamic",
738
  key=editor_key,
739
  use_container_width=True,
740
+ on_change=trigger_rerun
741
  )
742
 
 
 
 
 
 
 
 
 
 
743
  if len(edited_df) == 0:
744
  st.info("No line items found in the invoice.")
745
 
746
+ # ... (previous code) ...
747
+
748
+ # Save button (per file)
749
  if st.button("💾 Save Edits for This File", key=f"save_{selected_hash}"):
750
+ # Update line items from the current edited_df (which is a DataFrame)
751
+ data['Itemized Data'] = edited_df.to_dict('records')
752
+
753
+ # Save the entire data to session state
754
  st.session_state.batch_results[selected_hash]["edited_data"] = data
755
  st.success(f"✅ Edits saved for {current['file_name']}")
756
 
 
806
  json_name = f"{Path(result['file_name']).stem}_extracted.json"
807
  zf.writestr(json_name, json_data)
808
 
809
+ # Save FULL CSV (all data)
810
  rows = flatten_invoice_to_rows(result["edited_data"])
811
  full_df = pd.DataFrame(rows)
812