Bhuvi13 committed on
Commit
5b3d6b6
·
verified ·
1 Parent(s): e77e890

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +118 -55
src/streamlit_app.py CHANGED
@@ -162,6 +162,7 @@ def load_model_and_processor(hf_model_id: str, task_prompt: str):
162
 
163
  return processor, model, device, decoder_input_ids
164
 
 
165
  def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids):
166
  import torch
167
 
@@ -385,18 +386,26 @@ def map_prediction_to_ui(pred):
385
  item_rows = []
386
  for it in normalized_items:
387
  if not isinstance(it, dict):
388
- item_rows.append({"Description": str(it), "Quantity": 1, "Unit Price": 0.0, "Amount": 0.0})
389
  continue
390
  desc = it.get("descriptions") or it.get("description") or it.get("desc") or it.get("item") or it.get("name") or ""
391
  qty = it.get("quantity") or it.get("qty") or it.get("Quantity") or ""
392
  unit = it.get("unit_price") or it.get("unitPrice") or it.get("price") or ""
393
  amt = it.get("amount") or it.get("Line_total") or it.get("line_total") or it.get("total") or ""
394
 
 
 
 
 
 
 
395
  item_rows.append({
396
  "Description": str(desc).strip(),
397
  "Quantity": float(clean_number(qty)),
398
  "Unit Price": float(clean_number(unit)),
399
- "Amount": float(clean_number(amt))
 
 
400
  })
401
 
402
  ui["Itemized Data"] = item_rows
@@ -413,6 +422,7 @@ def flatten_invoice_to_rows(invoice_data) -> list:
413
  """
414
  rows = []
415
  line_items = invoice_data.get("Itemized Data", [])
 
416
  if not line_items:
417
  # If no line items, create one row with invoice info only
418
  row = {
@@ -429,10 +439,13 @@ def flatten_invoice_to_rows(invoice_data) -> list:
429
  "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
430
  "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
431
  }
 
432
  # Flatten bank details
433
  bank = invoice_data.get("Bank Details", {})
434
  for k, v in bank.items():
435
- row[f"bank_{k}"] = v
 
 
436
 
437
  # Add empty line item fields
438
  row.update({
@@ -440,6 +453,8 @@ def flatten_invoice_to_rows(invoice_data) -> list:
440
  "Item Quantity": 0,
441
  "Item Unit Price": 0.0,
442
  "Item Amount": 0.0,
 
 
443
  })
444
  rows.append(row)
445
  return rows
@@ -464,7 +479,9 @@ def flatten_invoice_to_rows(invoice_data) -> list:
464
  # Flatten bank details
465
  bank = invoice_data.get("Bank Details", {})
466
  for k, v in bank.items():
467
- row[f"bank_{k}"] = v
 
 
468
 
469
  # Add line item fields
470
  row.update({
@@ -472,12 +489,15 @@ def flatten_invoice_to_rows(invoice_data) -> list:
472
  "Item Quantity": item.get("Quantity", 0),
473
  "Item Unit Price": item.get("Unit Price", 0.0),
474
  "Item Amount": item.get("Amount", 0.0),
 
 
475
  })
476
 
477
  rows.append(row)
478
 
479
  return rows
480
 
 
481
  # Load model once
482
  try:
483
  with st.spinner("Loading model & processor (cached) ..."):
@@ -501,7 +521,7 @@ if "is_processing_batch" not in st.session_state:
501
  if not st.session_state.is_processing_batch and len(st.session_state.batch_results) == 0:
502
  st.markdown("Upload one or more invoice images (png/jpg/jpeg/pdf). The app will process them one by one.")
503
 
504
- st.header("📤 Upload Invoices (Batch)")
505
 
506
  uploaded_files = st.file_uploader(
507
  "Upload invoice images (png/jpg/jpeg/pdf)",
@@ -624,6 +644,26 @@ elif len(st.session_state.batch_results) > 0:
624
  # RIGHT: Editable Form
625
  with right_col:
626
  st.subheader(f"Editable Invoice: {current['file_name']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  tabs = st.tabs(["Invoice Details", "Sender/Recipient info", "Bank Details", "Line Items"])
628
 
629
  st.markdown(
@@ -728,7 +768,7 @@ elif len(st.session_state.batch_results) > 0:
728
 
729
  item_rows = data.get('Itemized Data', [])
730
  df = pd.DataFrame(item_rows)
731
- for col in ["Description", "Quantity", "Unit Price", "Amount"]:
732
  if col not in df.columns:
733
  df[col] = ""
734
 
@@ -759,15 +799,15 @@ elif len(st.session_state.batch_results) > 0:
759
  # Download buttons (per file)
760
  st.markdown("---")
761
  col_a, col_b, col_c = st.columns([1, 1, 1])
762
- with col_a:
763
- jsonl_str = json.dumps(data, ensure_ascii=False, indent=2)
764
- st.download_button(
765
- "📥 Download JSON",
766
- jsonl_str.encode("utf-8"),
767
- file_name=f"{Path(current['file_name']).stem}_extracted.json",
768
- mime="application/json",
769
- key=f"dl_json_{selected_hash}"
770
- )
771
  with col_b:
772
  # ✅ Flatten entire invoice into rows (one per line item)
773
  rows = flatten_invoice_to_rows(data)
@@ -779,7 +819,7 @@ elif len(st.session_state.batch_results) > 0:
779
  "Sender Name", "Sender Address", "Recipient Name", "Recipient Address",
780
  "Subtotal", "Tax Percentage", "Total Tax", "Total Amount",
781
  "bank_name", "bank_account_number", "bank_iban", "bank_swift", "bank_routing", "bank_branch", "bank_acc_name",
782
- "Item Description", "Item Quantity", "Item Unit Price", "Item Amount"
783
  ]
784
  # Keep only columns that exist
785
  existing_cols = [col for col in desired_col_order if col in full_df.columns]
@@ -797,46 +837,69 @@ elif len(st.session_state.batch_results) > 0:
797
  mime="text/csv",
798
  key=f"dl_csv_{selected_hash}"
799
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
800
 
801
- # Global Download All
802
- if st.button("📦 Download All Results (ZIP)", key="download_all"):
803
- zip_buffer = BytesIO()
804
- with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
805
- for file_hash, result in st.session_state.batch_results.items():
806
- # Save JSON
807
- json_data = json.dumps(result["edited_data"], ensure_ascii=False, indent=2)
808
- json_name = f"{Path(result['file_name']).stem}_extracted.json"
809
- zf.writestr(json_name, json_data)
810
-
811
- # Save FULL CSV (all data)
812
- rows = flatten_invoice_to_rows(result["edited_data"])
813
- full_df = pd.DataFrame(rows)
814
-
815
- # Optional: reorder columns (same as above)
816
- desired_col_order = [
817
- "Invoice Number", "Invoice Date", "Due Date", "Currency",
818
- "Sender Name", "Sender Address", "Recipient Name", "Recipient Address",
819
- "Subtotal", "Tax Percentage", "Total Tax", "Total Amount",
820
- "bank_name", "bank_account_number", "bank_iban", "bank_swift", "bank_routing", "bank_branch", "bank_acc_name",
821
- "Item Description", "Item Quantity", "Item Unit Price", "Item Amount"
822
- ]
823
- existing_cols = [col for col in desired_col_order if col in full_df.columns]
824
- remaining_cols = [col for col in full_df.columns if col not in existing_cols]
825
- final_col_order = existing_cols + remaining_cols
826
- full_df = full_df[final_col_order]
827
-
828
- csv_data = full_df.to_csv(index=False)
829
- csv_name = f"{Path(result['file_name']).stem}_full.csv"
830
- zf.writestr(csv_name, csv_data)
831
-
832
- zip_buffer.seek(0)
833
- st.download_button(
834
- label="⬇️ Download ZIP",
835
- data=zip_buffer,
836
- file_name="all_extracted_invoices.zip",
837
- mime="application/zip",
838
- key="final_download_button"
839
- )
840
 
841
  # ---------------------------
842
  # PROCESSING STATE — Show progress
 
162
 
163
  return processor, model, device, decoder_input_ids
164
 
165
+
166
  def run_inference_on_image(image: Image.Image, processor, model, device, decoder_input_ids):
167
  import torch
168
 
 
386
  item_rows = []
387
  for it in normalized_items:
388
  if not isinstance(it, dict):
389
+ item_rows.append({"Description": str(it), "Quantity": 1, "Unit Price": 0.0, "Amount": 0.0, "Tax": 0.0, "Line Total": 0.0})
390
  continue
391
  desc = it.get("descriptions") or it.get("description") or it.get("desc") or it.get("item") or it.get("name") or ""
392
  qty = it.get("quantity") or it.get("qty") or it.get("Quantity") or ""
393
  unit = it.get("unit_price") or it.get("unitPrice") or it.get("price") or ""
394
  amt = it.get("amount") or it.get("Line_total") or it.get("line_total") or it.get("total") or ""
395
 
396
+ # Extract item-level tax if available under common keys
397
+ tax_val = it.get("tax") or it.get("tax_amount") or it.get("line_tax") or it.get("item_tax") or it.get("taxAmount") or ""
398
+
399
+ # Extract explicit line total if present; otherwise fall back to amount
400
+ line_total_val = it.get("Line_total") or it.get("line_total") or it.get("lineTotal") or amt
401
+
402
  item_rows.append({
403
  "Description": str(desc).strip(),
404
  "Quantity": float(clean_number(qty)),
405
  "Unit Price": float(clean_number(unit)),
406
+ "Amount": float(clean_number(amt)),
407
+ "Tax": float(clean_number(tax_val)),
408
+ "Line Total": float(clean_number(line_total_val))
409
  })
410
 
411
  ui["Itemized Data"] = item_rows
 
422
  """
423
  rows = []
424
  line_items = invoice_data.get("Itemized Data", [])
425
+
426
  if not line_items:
427
  # If no line items, create one row with invoice info only
428
  row = {
 
439
  "Recipient Name": invoice_data.get("Recipient", {}).get("Name", ""),
440
  "Recipient Address": invoice_data.get("Recipient", {}).get("Address", ""),
441
  }
442
+
443
  # Flatten bank details
444
  bank = invoice_data.get("Bank Details", {})
445
  for k, v in bank.items():
446
+ # Avoid double-prefixing if key already contains 'bank_'
447
+ key_name = k if str(k).startswith("bank_") else f"bank_{k}"
448
+ row[key_name] = v
449
 
450
  # Add empty line item fields
451
  row.update({
 
453
  "Item Quantity": 0,
454
  "Item Unit Price": 0.0,
455
  "Item Amount": 0.0,
456
+ "Item Tax": 0.0,
457
+ "Item Line Total": 0.0,
458
  })
459
  rows.append(row)
460
  return rows
 
479
  # Flatten bank details
480
  bank = invoice_data.get("Bank Details", {})
481
  for k, v in bank.items():
482
+ # Avoid double-prefixing if key already contains 'bank_'
483
+ key_name = k if str(k).startswith("bank_") else f"bank_{k}"
484
+ row[key_name] = v
485
 
486
  # Add line item fields
487
  row.update({
 
489
  "Item Quantity": item.get("Quantity", 0),
490
  "Item Unit Price": item.get("Unit Price", 0.0),
491
  "Item Amount": item.get("Amount", 0.0),
492
+ "Item Tax": item.get("Tax", 0.0),
493
+ "Item Line Total": item.get("Line Total", item.get("Amount", 0.0)),
494
  })
495
 
496
  rows.append(row)
497
 
498
  return rows
499
 
500
+
501
  # Load model once
502
  try:
503
  with st.spinner("Loading model & processor (cached) ..."):
 
521
  if not st.session_state.is_processing_batch and len(st.session_state.batch_results) == 0:
522
  st.markdown("Upload one or more invoice images (png/jpg/jpeg/pdf). The app will process them one by one.")
523
 
524
+ st.header("📤 Upload Invoices")
525
 
526
  uploaded_files = st.file_uploader(
527
  "Upload invoice images (png/jpg/jpeg/pdf)",
 
644
  # RIGHT: Editable Form
645
  with right_col:
646
  st.subheader(f"Editable Invoice: {current['file_name']}")
647
+
648
+ # ---------- Re-run (per-file) ----------
649
+ if st.button("🔁 Re-Run", key=f"rerun_{selected_hash}"):
650
+ # Re-run inference only for the selected file's image, update stored predictions and editable copy
651
+ with st.spinner("Re-running inference for selected file..."):
652
+ try:
653
+ pred = run_inference_on_image(image, processor, model, device, decoder_input_ids)
654
+ mapped = map_prediction_to_ui(pred)
655
+ safe_mapped = mapped if isinstance(mapped, dict) else {}
656
+
657
+ # Save updated results for this single file
658
+ st.session_state.batch_results[selected_hash]["raw_pred"] = pred
659
+ st.session_state.batch_results[selected_hash]["mapped_data"] = mapped
660
+ st.session_state.batch_results[selected_hash]["edited_data"] = safe_mapped.copy()
661
+
662
+ st.success("✅ Re-run complete — predictions updated for this file.")
663
+ # Refresh the UI so the new values appear in the form
664
+ st.rerun()
665
+ except Exception as e:
666
+ st.error(f"Re-run failed: {e}")
667
  tabs = st.tabs(["Invoice Details", "Sender/Recipient info", "Bank Details", "Line Items"])
668
 
669
  st.markdown(
 
768
 
769
  item_rows = data.get('Itemized Data', [])
770
  df = pd.DataFrame(item_rows)
771
+ for col in ["Description", "Quantity", "Unit Price", "Amount", "Tax", "Line Total"]:
772
  if col not in df.columns:
773
  df[col] = ""
774
 
 
799
  # Download buttons (per file)
800
  st.markdown("---")
801
  col_a, col_b, col_c = st.columns([1, 1, 1])
802
+ #with col_a:
803
+ #jsonl_str = json.dumps(data, ensure_ascii=False, indent=2)
804
+ #st.download_button(
805
+ # "📥 Download JSON",
806
+ #jsonl_str.encode("utf-8"),
807
+ #file_name=f"{Path(current['file_name']).stem}_extracted.json",
808
+ #mime="application/json",
809
+ #key=f"dl_json_{selected_hash}"
810
+ #)
811
  with col_b:
812
  # ✅ Flatten entire invoice into rows (one per line item)
813
  rows = flatten_invoice_to_rows(data)
 
819
  "Sender Name", "Sender Address", "Recipient Name", "Recipient Address",
820
  "Subtotal", "Tax Percentage", "Total Tax", "Total Amount",
821
  "bank_name", "bank_account_number", "bank_iban", "bank_swift", "bank_routing", "bank_branch", "bank_acc_name",
822
+ "Item Description", "Item Quantity", "Item Unit Price", "Item Amount", "Item Tax", "Item Line Total"
823
  ]
824
  # Keep only columns that exist
825
  existing_cols = [col for col in desired_col_order if col in full_df.columns]
 
837
  mime="text/csv",
838
  key=f"dl_csv_{selected_hash}"
839
  )
840
+ # Global Download All — produce a single Excel file (concatenated rows) and trigger direct download
841
+ if st.button("📦 Download All Results (Excel)", key="download_all"):
842
+ # Collect rows from all invoices and concatenate into one DataFrame
843
+ all_rows = []
844
+ for file_hash, result in st.session_state.batch_results.items():
845
+ rows = flatten_invoice_to_rows(result["edited_data"])
846
+ # Annotate rows with source file name so user can identify which invoice each row came from
847
+ for r in rows:
848
+ r["Source File"] = result.get("file_name", file_hash)
849
+ all_rows.extend(rows)
850
+
851
+ if len(all_rows) == 0:
852
+ st.warning("No invoice data available to download.")
853
+ else:
854
+ full_df = pd.DataFrame(all_rows)
855
+
856
+ # Reorder columns to put Source File first
857
+ cols = list(full_df.columns)
858
+ if "Source File" in cols:
859
+ cols = ["Source File"] + [c for c in cols if c != "Source File"]
860
+ full_df = full_df[cols]
861
+
862
+ # Try to write XLSX (preferred). If engine not available, fall back to CSV.
863
+ buffer = BytesIO()
864
+ dl_filename = "all_extracted_invoices.xlsx"
865
+ tried_xlsx = False
866
+ try:
867
+ with pd.ExcelWriter(buffer, engine="openpyxl") as writer:
868
+ full_df.to_excel(writer, index=False, sheet_name="Invoices")
869
+ tried_xlsx = True
870
+ buffer.seek(0)
871
+ file_bytes = buffer.read()
872
+ mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
873
+ except Exception:
874
+ # Fallback to CSV
875
+ buffer = BytesIO()
876
+ csv_data = full_df.to_csv(index=False).encode("utf-8")
877
+ buffer.write(csv_data)
878
+ buffer.seek(0)
879
+ file_bytes = buffer.read()
880
+ dl_filename = "all_extracted_invoices.csv"
881
+ mime = "text/csv"
882
+
883
+ # Trigger immediate download via a data URI and small HTML snippet
884
+ import base64
885
+ import streamlit.components.v1 as components
886
+ b64 = base64.b64encode(file_bytes).decode()
887
+ data_uri = f"data:{mime};base64,{b64}"
888
+
889
+ auto_dl_html = f'''<html>
890
+ <body>
891
+ <a id="dlLink" href="{data_uri}" download="{dl_filename}"></a>
892
+ <script>
893
+ const a = document.getElementById('dlLink');
894
+ a.click();
895
+ </script>
896
+ </body>
897
+ </html>'''
898
+
899
+ components.html(auto_dl_html, height=0)
900
 
901
+ # ---------------------------
902
+ # PROCESSING STATE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903
 
904
  # ---------------------------
905
  # PROCESSING STATE — Show progress