invoice-annotator / src /streamlit_app.py
Ankushbl6's picture
Update src/streamlit_app.py
a3bb747 verified
raw
history blame
22.5 kB
import os
from io import BytesIO
import json
import streamlit as st
from PIL import Image, ImageEnhance
from streamlit_drawable_canvas import st_canvas
import pytesseract
# Tesseract comes from packages.txt on HuggingFace Spaces (Linux) and is on the
# system PATH, so pytesseract is used without configuring a binary location.
_APP_TITLE = "Remittance GT Annotator - Interactive OCR"

# The same text is used for the browser tab title and the on-page heading.
st.set_page_config(page_title=_APP_TITLE, layout="wide")
st.title(_APP_TITLE)
# ---- Define fields ----
# Header-level fields: one value per remittance image.
SINGLE_FIELDS = [
    "Remittance Advice Number",
    "Remittance Advice Date",
    "Payment Method",
    "FCY",
    "Total Payment Amount in FCY",
    "Payment Date",
    "Payment Reference Number/Check Number",
    "Customer Name",
    "Customer Address",
    "Customer Contact Information",
    "Supplier Name",
    "Supplier Address",
    "Supplier Contact Information",
    "Bank Name",
    "Bank Account Number",
    "Bank Routing Number",
    "SWIFT/BIC Code",
]

# Per-row fields: stored once per line item under "Line <n>: <field>" keys.
LINE_ITEM_FIELDS = [
    "PO number",
    "Invoice number",
    "Other document reference number",
    "Invoice Date",
    "Invoice Amount in FCY",
    "Amount Paid for Each Invoice in FCY",
    "Outstanding Balance in FCY",
    "Discounts Taken in FCY",
    "Adjustments(Withholding Tax) in FCY",
    "Description",
]

# Stroke colours for rectangles; assigned to fields in declaration order.
COLOR_PALETTE = [
    "#e6194b",
    "#3cb44b",
    "#ffe119",
    "#4363d8",
    "#f58231",
    "#911eb4",
    "#46f0f0",
    "#f032e6",
    "#bcf60c",
    "#fabebe",
    "#008080",
    "#e6beff",
    "#9a6324",
    "#fffac8",
    "#800000",
    "#aaffc3",
    "#808000",
    "#ffd8b1",
    "#000075",
    "#808080",
    "#ffe4e1",
    "#40e0d0",
    "#ff1493",
    "#7fffd4",
    "#b0e0e6",
    "#ffb6c1",
    "#add8e6",
]

ALL_BASE_FIELDS = [*SINGLE_FIELDS, *LINE_ITEM_FIELDS]

# Field name -> colour; the palette wraps around if there are more fields
# than colours.
FIELD_COLORS = dict(
    (name, COLOR_PALETTE[position % len(COLOR_PALETTE)])
    for position, name in enumerate(ALL_BASE_FIELDS)
)
# ----- JSONL schema helper mappings -----
# Each mapping translates a UI field label into the key used in the exported
# JSON; header fields are grouped into three named sections.
_REMITTANCE_ADVICE_KEYS = {
    "Remittance Advice Number": "remittance_advice_number",
    "Remittance Advice Date": "remittance_advice_date",
    "Payment Method": "payment_method",
    "FCY": "fcy",
    "Total Payment Amount in FCY": "total_payment_amount_in_fcy",
    "Payment Date": "payment_date",
    "Payment Reference Number/Check Number": "payment_reference_number_check_number",
}

_CUSTOMER_SUPPLIER_KEYS = {
    "Customer Name": "customer_name",
    "Customer Address": "customer_address",
    "Customer Contact Information": "customer_contact_information",
    "Supplier Name": "supplier_name",
    "Supplier Address": "supplier_address",
    "Supplier Contact Information": "supplier_contact_information",
}

_BANK_KEYS = {
    "Bank Name": "bank_name",
    "Bank Account Number": "bank_account_number",
    "Bank Routing Number": "bank_routing_number",
    "SWIFT/BIC Code": "swift_bic_code",
}

HEADER_GROUPS = {
    "remittance_advice_details": _REMITTANCE_ADVICE_KEYS,
    "customer_supplier_details": _CUSTOMER_SUPPLIER_KEYS,
    "bank_details": _BANK_KEYS,
}

LINE_ITEM_FIELD_KEY_MAP = {
    "PO number": "po_number",
    "Invoice number": "invoice_number",
    "Other document reference number": "other_document_reference_number",
    "Invoice Date": "invoice_date",
    "Invoice Amount in FCY": "invoice_amount_in_fcy",
    "Amount Paid for Each Invoice in FCY": "amount_paid_for_each_invoice_in_fcy",
    "Outstanding Balance in FCY": "outstanding_balance_in_fcy",
    "Discounts Taken in FCY": "discounts_taken_in_fcy",
    "Adjustments(Withholding Tax) in FCY": "adjustments_withholding_tax_in_fcy",
    "Description": "description",
}

# Fixed zoom options (percent): 5-point steps up to 95, then 10-point steps
# from 100 to 150.
ZOOM_OPTIONS = [*range(25, 100, 5), *range(100, 151, 10)]
# ---- Session state ----
# Per-image stores: each is a dict keyed by the uploaded image's file name.
for _state_key in (
    "field_values",
    "field_rects_orig",
    "num_line_items",
    "zoom_values",
    "rect_version",
    "image_data",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = {}

# Scalar slots: the currently selected image and a queued rectangle deletion.
for _state_key in ("selected_image", "pending_delete"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Pending delete - process at start before UI
if st.session_state.pending_delete is not None:
    doomed_image, doomed_field = st.session_state.pending_delete
    rects_by_image = st.session_state.field_rects_orig
    if doomed_image in rects_by_image:
        rects_by_image[doomed_image].pop(doomed_field, None)
    versions_by_image = st.session_state.rect_version
    if doomed_image in versions_by_image:
        # The version is part of the canvas key, so bumping it forces a
        # fresh canvas without the deleted rectangle.
        versions_by_image[doomed_image] += 1
    st.session_state.pending_delete = None
# --- Helper functions ---
@st.cache_data
def load_image(file_content):
    """Decode raw image bytes into an RGB PIL image.

    The result is memoised by ``st.cache_data``, keyed on ``file_content``.
    """
    buffer = BytesIO(file_content)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")
def get_display_image_from_bytes(image_bytes, width, height):
    """Decode ``image_bytes`` and return a ``width`` x ``height`` display copy.

    The image is decoded afresh on every call (no caching), resized with the
    LANCZOS filter, then sharpened by a factor of 1.2 and contrast-boosted by
    a factor of 1.1.
    """
    source = Image.open(BytesIO(image_bytes)).convert("RGB")
    scaled = source.resize((width, height), Image.LANCZOS)
    # Apply the enhancers in order: sharpness first, then contrast.
    for enhancer_cls, factor in (
        (ImageEnhance.Sharpness, 1.2),
        (ImageEnhance.Contrast, 1.1),
    ):
        scaled = enhancer_cls(scaled).enhance(factor)
    return scaled
def get_default_zoom(pil_image):
    """Return the ZOOM_OPTIONS entry closest to the image's best-fit zoom.

    The best fit is the largest scale, capped at 100%, at which the image
    fits inside an 850 x 900 pixel area. On a tie the earlier (smaller)
    option wins.
    """
    max_width, max_height = 850, 900
    fit_scale = min(
        max_width / pil_image.width,
        max_height / pil_image.height,
        1.0,
    )
    target_percent = int(fit_scale * 100)

    def distance(option):
        return abs(option - target_percent)

    return min(ZOOM_OPTIONS, key=distance)
def build_gt_record_for_file(file_name: str) -> dict:
    """Assemble the JSONL ground-truth record for one remittance image.

    The record has the shape::

        {
            "file_name": "<image>",
            "gt_parse": {
                "remittance_advice_details": {...},
                "customer_supplier_details": {...},
                "bank_details": {...},
                "line_items": [...]
            }
        }

    Every header key is always present (empty string when unlabelled).
    A line item is included only if at least one of its fields is non-empty
    after stripping whitespace.
    """
    labelled = st.session_state.field_values.get(file_name, {})
    line_count = st.session_state.num_line_items.get(file_name, 1)

    def cleaned(label: str) -> str:
        # Missing labels become "", and every value is stringified and stripped.
        return str(labelled.get(label, "")).strip()

    # Header sections
    gt_parse: dict = {
        section_name: {
            json_key: cleaned(ui_label) for ui_label, json_key in mapping.items()
        }
        for section_name, mapping in HEADER_GROUPS.items()
    }

    # Line items
    line_items = []
    for number in range(1, line_count + 1):
        row = {
            json_key: cleaned(f"Line {number}: {ui_label}")
            for ui_label, json_key in LINE_ITEM_FIELD_KEY_MAP.items()
        }
        if any(row.values()):
            line_items.append(row)
    gt_parse["line_items"] = line_items

    return {"file_name": file_name, "gt_parse": gt_parse}
def has_any_label(fname: str) -> bool:
    """Return True if ``fname`` has at least one non-blank stored value."""
    stored = st.session_state.field_values.get(fname, {})
    for value in stored.values():
        if str(value).strip():
            return True
    return False
# --- Upload ---
uploaded_files = st.file_uploader(
    "Upload remittance images",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)
if not uploaded_files:
    st.info("Upload at least one image to begin.")
    st.stop()

images = []
for f in uploaded_files:
    # The first bytes seen for a file name are kept in session state so the
    # image stays stable across reruns. Only those cached bytes are ever used,
    # so the upload is read just once (when the name is new) rather than on
    # every rerun.
    if f.name not in st.session_state.image_data:
        f.seek(0)
        st.session_state.image_data[f.name] = f.read()
    cached_bytes = st.session_state.image_data[f.name]
    images.append(
        {"name": f.name, "image": load_image(cached_bytes), "bytes": cached_bytes}
    )

file_names = [img["name"] for img in images]
selected_name = st.selectbox("Select image", file_names)
# Callbacks read the selection from session state rather than from a local.
st.session_state.selected_image = selected_name
selected_img_data = next(img for img in images if img["name"] == selected_name)
pil_image = selected_img_data["image"]
image_bytes = selected_img_data["bytes"]

# Init for this image
st.session_state.field_values.setdefault(selected_name, {})
st.session_state.field_rects_orig.setdefault(selected_name, {})
st.session_state.num_line_items.setdefault(selected_name, 1)
st.session_state.rect_version.setdefault(selected_name, 0)
# The zoom default is computed from the image, so only do it when missing.
if selected_name not in st.session_state.zoom_values:
    st.session_state.zoom_values[selected_name] = get_default_zoom(pil_image)

# ========== FIELD SELECTION ==========
st.markdown("---")
def add_line_item():
    """Button callback: add one line item to the currently selected image."""
    current = st.session_state.selected_image
    if not current:
        return
    st.session_state.num_line_items[current] += 1
def remove_line_item():
    """Button callback: drop the last line item of the selected image.

    Does nothing when no image is selected or only one line item remains.
    Otherwise the stored values and rectangles of the last line are removed,
    the line count is decremented, and the image's rectangle version is
    bumped.
    """
    current = st.session_state.selected_image
    if not current or st.session_state.num_line_items[current] <= 1:
        return

    counts = st.session_state.num_line_items
    prefix = f"Line {counts[current]}: "
    for base_name in LINE_ITEM_FIELDS:
        st.session_state.field_values[current].pop(prefix + base_name, None)
        st.session_state.field_rects_orig[current].pop(prefix + base_name, None)
    counts[current] -= 1
    st.session_state.rect_version[current] += 1
# Initialize field variables with defaults so they are always bound, whichever
# widget branch runs below:
#   display_field_name   - text shown in the "Current" indicator
#   storage_field_name   - key used in field_values / field_rects_orig
#   base_field_for_color - field name used to look up FIELD_COLORS
display_field_name = SINGLE_FIELDS[0]
storage_field_name = SINGLE_FIELDS[0]
base_field_for_color = SINGLE_FIELDS[0]
# Selector row: type radio | field (or line) picker | line-item field | +/- buttons
sel_col1, sel_col2, sel_col3, sel_col4 = st.columns([1.5, 1.5, 2, 2])
with sel_col1:
field_type = st.radio("Type", ["Single", "Line Item"], horizontal=True, label_visibility="collapsed")
with sel_col2:
if field_type == "Single":
# Single fields use the plain field name for display, storage and colour.
field_name = st.selectbox("Field", SINGLE_FIELDS, label_visibility="collapsed")
display_field_name = field_name
storage_field_name = field_name
base_field_for_color = field_name
else:
# Line-item mode: this column picks which line ("Line 1" .. "Line N").
num_items = st.session_state.num_line_items[selected_name]
line_item_options = [f"Line {i+1}" for i in range(num_items)]
selected_line_item = st.selectbox("Line", line_item_options, label_visibility="collapsed")
# Recover the 1-based line number from the "Line <n>" option text.
line_item_num = int(selected_line_item.split()[1])
with sel_col3:
if field_type == "Line Item":
# Both names come out as "Line <n>: <field>"; the colour is looked up by the
# bare field name.
base_field = st.selectbox("Field", LINE_ITEM_FIELDS, label_visibility="collapsed")
display_field_name = f"{selected_line_item}: {base_field}"
storage_field_name = f"Line {line_item_num}: {base_field}"
base_field_for_color = base_field
with sel_col4:
if field_type == "Line Item":
# Line items +/- buttons next to line item dropdown
# NOTE(review): the button labels look like mis-encoded emoji - confirm
# against the deployed file.
add_col, rem_col, info_col = st.columns([1, 1, 2])
with add_col:
st.button("โž•", key=f"addli_{selected_name}", on_click=add_line_item, help="Add line item")
with rem_col:
# The remove button is only offered while more than one line exists.
if st.session_state.num_line_items[selected_name] > 1:
st.button("โž–", key=f"remli_{selected_name}", on_click=remove_line_item, help="Remove line item")
with info_col:
st.write(f"Lines: **{st.session_state.num_line_items[selected_name]}**")
# Guard in case something weird happens: never leave the storage key empty.
if not storage_field_name:
storage_field_name = display_field_name
# Colour for the active field; falls back to red for unknown field names.
field_color = FIELD_COLORS.get(base_field_for_color or display_field_name, "#FF0000")
st.markdown(f"**Current:** <span style='color:{field_color}'>โ—</span> {display_field_name}", unsafe_allow_html=True)
# ========== MAIN COLUMNS ==========
# Two columns in a 3:2 width ratio: col1 holds the zoom controls and the
# drawing canvas, col2 the values, OCR editor and export.
col1, col2 = st.columns([3, 2])
with col1:
# Zoom controls - selectbox + buttons
current_zoom = st.session_state.zoom_values[selected_name]
# Position of the stored zoom within ZOOM_OPTIONS; falls back to the first
# option if the stored value is not one of the presets.
zoom_index = ZOOM_OPTIONS.index(current_zoom) if current_zoom in ZOOM_OPTIONS else 0
# Zoom callbacks
def do_zoom_out():
    """Button callback: move the selected image to the next smaller zoom preset.

    A stored zoom that is not in ZOOM_OPTIONS is treated as the first preset,
    in which case nothing changes.
    """
    name = st.session_state.selected_image
    zooms = st.session_state.zoom_values
    try:
        position = ZOOM_OPTIONS.index(zooms[name])
    except ValueError:
        position = 0
    if position > 0:
        zooms[name] = ZOOM_OPTIONS[position - 1]
def do_zoom_in():
    """Button callback: move the selected image to the next larger zoom preset.

    A stored zoom that is not in ZOOM_OPTIONS is treated as the first preset.
    Nothing changes once the largest preset is reached.
    """
    name = st.session_state.selected_image
    level = st.session_state.zoom_values[name]
    position = ZOOM_OPTIONS.index(level) if level in ZOOM_OPTIONS else 0
    following = position + 1
    if following < len(ZOOM_OPTIONS):
        st.session_state.zoom_values[name] = ZOOM_OPTIONS[following]
def do_zoom_fit():
    """Button callback: reset the selected image's zoom to its best-fit preset.

    Does nothing when no bytes are stored for the selected image.
    """
    name = st.session_state.selected_image
    raw = st.session_state.image_data.get(name)
    if not raw:
        return
    st.session_state.zoom_values[name] = get_default_zoom(load_image(raw))
# Zoom row: preset selectbox | zoom out | zoom in | fit
zoom_row1, zoom_row2, zoom_row3, zoom_row4 = st.columns([2, 1, 1, 1])
with zoom_row1:
# The selectbox is keyed per image and pre-selects the stored zoom.
zoom = st.selectbox(
"๐Ÿ” Zoom",
options=ZOOM_OPTIONS,
index=zoom_index,
format_func=lambda x: f"{x}%",
key=f"zoom_select_{selected_name}",
label_visibility="collapsed"
)
# NOTE(review): this writes the selectbox value back over whatever the zoom
# button callbacks stored. Whether a changed `index` resets a keyed selectbox
# depends on the Streamlit version - confirm the buttons still take effect.
st.session_state.zoom_values[selected_name] = zoom
with zoom_row2:
st.button("โž–", key="zoom_out", help="Zoom out", on_click=do_zoom_out)
with zoom_row3:
st.button("โž•", key="zoom_in", help="Zoom in", on_click=do_zoom_in)
with zoom_row4:
st.button("Fit", key="zoom_fit", help="Fit to screen", on_click=do_zoom_fit)
# Get current zoom value and convert the percentage to a scale factor
zoom = st.session_state.zoom_values[selected_name]
scale = zoom / 100.0
# Display size in pixels, truncated to integers.
disp_w = int(pil_image.width * scale)
disp_h = int(pil_image.height * scale)
# Get display image - fresh PIL object each time from stable bytes
display_image = get_display_image_from_bytes(image_bytes, disp_w, disp_h)
st.caption(f"Original: {pil_image.width}ร—{pil_image.height} | Display: {disp_w}ร—{disp_h}")
# Tell the user whether the active field already has a stored rectangle.
has_rect = storage_field_name in st.session_state.field_rects_orig[selected_name]
if has_rect:
st.success(f"โœ… Has rectangle. Draw again to replace.")
else:
st.warning(f"โฌœ Draw rectangle for this field")
def orig_to_display(rect_orig, s):
    """Scale a stored rectangle (original-image pixels) to display coordinates.

    Returns a canvas "rect" object with a transparent fill. Position and size
    are multiplied by ``s``; stroke colour and width are carried over,
    defaulting to red and 2.
    """
    scaled = {
        side: rect_orig[side] * s for side in ("left", "top", "width", "height")
    }
    return {
        "type": "rect",
        **scaled,
        "fill": "rgba(0,0,0,0)",
        "stroke": rect_orig.get("stroke", "#FF0000"),
        "strokeWidth": rect_orig.get("strokeWidth", 2),
        "scaleX": 1,
        "scaleY": 1,
    }
def display_to_orig(rect_display, s):
    """Convert a canvas rectangle (display pixels) to original-image coordinates.

    The object's width and height are first multiplied by its ``scaleX`` and
    ``scaleY`` factors, then everything is divided by ``s``. Missing geometry
    defaults to 0, missing scale factors to 1, stroke to red and stroke width
    to 2.
    """
    read = rect_display.get
    effective_w = read("width", 0) * read("scaleX", 1)
    effective_h = read("height", 0) * read("scaleY", 1)

    converted = {
        "left": read("left", 0) / s,
        "top": read("top", 0) / s,
    }
    converted["width"] = effective_w / s
    converted["height"] = effective_h / s
    converted["stroke"] = read("stroke", "#FF0000")
    converted["strokeWidth"] = read("strokeWidth", 2)
    return converted
# Build display objects from stored rectangles
stored_rects = st.session_state.field_rects_orig[selected_name]
all_display_objects = []
for stored_key, stored_rect in stored_rects.items():
    shape = orig_to_display(stored_rect, scale)
    # "Line <n>: <field>" keys are coloured by the bare field name.
    _line_prefix, separator, suffix = stored_key.partition(": ")
    colour_key = suffix if separator else stored_key
    shape["stroke"] = FIELD_COLORS.get(colour_key, "#FF0000")
    # The rectangle of the active field is drawn with a thicker outline.
    shape["strokeWidth"] = 2
    if stored_key == storage_field_name:
        shape["strokeWidth"] = 3
    all_display_objects.append(shape)

# The number of pre-drawn objects; any extra object returned by the canvas is
# treated as a newly drawn rectangle.
expected_count = len(all_display_objects)
initial_drawing = {"version": "4.4.0", "objects": all_display_objects}

# Canvas key: includes rect count to force refresh when rectangles change
rect_ver = st.session_state.rect_version[selected_name]
num_rects = len(stored_rects)
canvas_key = f"canvas_{selected_name}_z{zoom}_rv{rect_ver}_n{num_rects}"

# Render canvas
canvas_result = st_canvas(
    fill_color="rgba(255,0,0,0.1)",
    stroke_width=3,
    stroke_color=field_color,
    background_image=display_image,
    update_streamlit=True,
    height=disp_h,
    width=disp_w,
    drawing_mode="rect",
    initial_drawing=initial_drawing,
    key=canvas_key,
)
# Detect new rectangle: the canvas was seeded with `expected_count` objects, so
# any additional object is one the user has just drawn.
if canvas_result.json_data is not None:
    objs = canvas_result.json_data.get("objects", [])
    if len(objs) > expected_count:
        new_rect_display = objs[-1]
        new_rect_orig = display_to_orig(new_rect_display, scale)
        new_rect_orig["stroke"] = field_color
        st.session_state.field_rects_orig[selected_name][storage_field_name] = new_rect_orig
        # Bump the rectangle version so the canvas key changes on the next
        # rerun. Replacing an existing rectangle leaves the rectangle count
        # (and therefore the old key) unchanged, so without this the canvas
        # could report the same extra object again and it would be stored
        # against whichever field is selected at that point.
        st.session_state.rect_version[selected_name] += 1

        # Auto-run OCR on the new region, clamped to the image bounds.
        x1 = max(0, int(new_rect_orig["left"]))
        y1 = max(0, int(new_rect_orig["top"]))
        x2 = min(pil_image.width, int(new_rect_orig["left"] + new_rect_orig["width"]))
        y2 = min(pil_image.height, int(new_rect_orig["top"] + new_rect_orig["height"]))
        if x2 > x1 and y2 > y1:
            crop = pil_image.crop((x1, y1, x2, y2))
            try:
                text = pytesseract.image_to_string(crop, config="--psm 6").strip()
                if text:
                    st.session_state.field_values[selected_name][storage_field_name] = text
                    st.toast(f"โœ… OCR: {text[:50]}{'...' if len(text) > 50 else ''}")
                else:
                    st.toast(f"โœ… Rectangle saved (no text detected)")
            except Exception:
                # Best effort: the rectangle is already stored, so an OCR
                # failure must not interrupt the annotation flow.
                st.toast(f"โœ… Rectangle saved")
with col2:
    # ========== ALL VALUES SECTION (MOVED UP) ==========
    st.markdown("---")

    # Progress counters: rectangles drawn for single fields vs. line-item
    # fields (stored under keys that start with "Line ").
    image_rects = st.session_state.field_rects_orig[selected_name]
    image_values = st.session_state.field_values[selected_name]
    num_items = st.session_state.num_line_items[selected_name]
    line_rects = sum(1 for key in image_rects if key.startswith("Line "))
    single_rects = len(image_rects) - line_rects
    st.write(f"**Single:** {single_rects}/{len(SINGLE_FIELDS)} | **Lines ({num_items}):** {line_rects}/{num_items * len(LINE_ITEM_FIELDS)}")

    with st.expander("๐Ÿ“‹ All Values"):
        # Single fields: list only those with a non-blank value.
        for f in SINGLE_FIELDS:
            v = image_values.get(f, "")
            if not v.strip():
                continue
            st.write(f"**{f}:** {v}")

        # Line items: a line is shown only if it has a non-blank value.
        for i in range(1, num_items + 1):
            vals = []
            for lif in LINE_ITEM_FIELDS:
                v = image_values.get(f"Line {i}: {lif}", "")
                if v.strip():
                    vals.append((lif, v))
            if not vals:
                continue
            st.write(f"**Line {i}:**")
            for lif, v in vals:
                st.write(f" {lif}: {v}")
# ========== OCR & VALUE SECTION (MOVED DOWN) ==========
# Editor for the active field: shows its stored rectangle (if any), a crop
# preview, an editable value, and Save / Re-OCR / Delete actions.
st.markdown("---")
st.subheader("OCR & Value")
current_rect_orig = st.session_state.field_rects_orig[selected_name].get(storage_field_name)
current_val = st.session_state.field_values[selected_name].get(storage_field_name, "")
if current_rect_orig:
# Caption shows the rectangle's top-left corner and size in original pixels.
st.caption(f"๐Ÿ“ ({current_rect_orig['left']:.0f}, {current_rect_orig['top']:.0f}) - {current_rect_orig['width']:.0f}ร—{current_rect_orig['height']:.0f}")
# Clamp the rectangle to the image bounds before cropping.
x1 = max(0, int(current_rect_orig["left"]))
y1 = max(0, int(current_rect_orig["top"]))
x2 = min(pil_image.width, int(current_rect_orig["left"] + current_rect_orig["width"]))
y2 = min(pil_image.height, int(current_rect_orig["top"] + current_rect_orig["height"]))
if x2 > x1 and y2 > y1:
crop = pil_image.crop((x1, y1, x2, y2))
st.image(crop, caption="Selected Region", width=200)
# The text area is pre-filled with the stored value; edits are only persisted
# by the Save button below.
new_val = st.text_area("Value (auto-filled by OCR)", value=current_val, height=80)
col_btn1, col_btn2, col_btn3 = st.columns(3)
with col_btn1:
if st.button("๐Ÿ’พ Save"):
st.session_state.field_values[selected_name][storage_field_name] = new_val
st.success("Saved!")
with col_btn2:
# Re-OCR is only offered when a rectangle exists for the active field.
# NOTE(review): the text area above was already rendered with the old value
# in this run - confirm the refreshed OCR text shows up after the next rerun.
if current_rect_orig and st.button("๐Ÿ”„ Re-OCR"):
# Same clamp-and-crop computation as the preview above.
x1 = max(0, int(current_rect_orig["left"]))
y1 = max(0, int(current_rect_orig["top"]))
x2 = min(pil_image.width, int(current_rect_orig["left"] + current_rect_orig["width"]))
y2 = min(pil_image.height, int(current_rect_orig["top"] + current_rect_orig["height"]))
if x2 > x1 and y2 > y1:
crop = pil_image.crop((x1, y1, x2, y2))
try:
text = pytesseract.image_to_string(crop, config="--psm 6").strip()
if text:
st.session_state.field_values[selected_name][storage_field_name] = text
st.success(f"OCR: {text}")
else:
st.warning("Empty result")
except Exception as e:
# Unlike the automatic OCR on draw, a manual Re-OCR reports the failure.
st.error(f"OCR failed: {e}")
with col_btn3:
# The callback only queues the deletion in session state; the actual removal
# happens where pending_delete is processed.
def delete_rect():
st.session_state.pending_delete = (selected_name, storage_field_name)
if current_rect_orig:
st.button("๐Ÿ—‘๏ธ Delete", on_click=delete_rect)
# ========== EXPORT SECTION ==========
st.markdown("---")
st.subheader("๐Ÿ“ค JSONL Export")

# Export ALL labeled remittances: one record per uploaded image that has at
# least one non-blank value.
records_all = [
    build_gt_record_for_file(img["name"])
    for img in images
    if has_any_label(img["name"])
]
if records_all:
    # Every record, including the last, is terminated with a newline so the
    # bulk file matches the single-record export below and exported files can
    # be concatenated without two records merging onto one line.
    all_jsonl_str = "".join(
        json.dumps(rec, ensure_ascii=False) + "\n" for rec in records_all
    )
    st.download_button(
        "โฌ‡๏ธ Export ALL labeled remittances (JSONL)",
        data=all_jsonl_str.encode("utf-8"),
        file_name="remittances_ground_truth.jsonl",
        mime="application/json",
    )
else:
    st.caption("No labeled remittances yet to export in bulk.")

# Export CURRENT remittance (always offered, even when nothing is labeled yet)
current_record = build_gt_record_for_file(selected_name)
with st.expander("Preview CURRENT remittance JSON"):
    st.json(current_record)
current_jsonl_str = json.dumps(current_record, ensure_ascii=False) + "\n"
st.download_button(
    "โฌ‡๏ธ Export CURRENT remittance (JSONL)",
    data=current_jsonl_str.encode("utf-8"),
    # The download is named after the image file with its extension removed.
    file_name=f"{os.path.splitext(selected_name)[0]}_remittance.jsonl",
    mime="application/json",
)