Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +283 -192
src/streamlit_app.py
CHANGED
|
@@ -39,7 +39,8 @@ DOCUMENT_TEMPLATES = {
|
|
| 39 |
"name": "ID Card / Passport",
|
| 40 |
"description": "Extract structured data from identity documents",
|
| 41 |
"prompt": """Extract structured data from this identity document.
|
| 42 |
-
|
|
|
|
| 43 |
{
|
| 44 |
"document_type": "",
|
| 45 |
"full_name": "",
|
|
@@ -48,8 +49,11 @@ Output ONLY valid JSON:
|
|
| 48 |
"date_of_expiry": "",
|
| 49 |
"nationality": "",
|
| 50 |
"document_number": "",
|
| 51 |
-
"
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
| 53 |
"icon": "π"
|
| 54 |
},
|
| 55 |
|
|
@@ -57,6 +61,7 @@ Output ONLY valid JSON:
|
|
| 57 |
"name": "Receipt",
|
| 58 |
"description": "Extract items, prices, and totals from receipts",
|
| 59 |
"prompt": """Extract information from this receipt.
|
|
|
|
| 60 |
Output ONLY valid JSON:
|
| 61 |
{
|
| 62 |
"merchant_name": "",
|
|
@@ -77,6 +82,7 @@ Output ONLY valid JSON:
|
|
| 77 |
"name": "Invoice",
|
| 78 |
"description": "Extract invoice details and line items",
|
| 79 |
"prompt": """Extract information from this invoice.
|
|
|
|
| 80 |
Output ONLY valid JSON:
|
| 81 |
{
|
| 82 |
"invoice_number": "",
|
|
@@ -106,6 +112,7 @@ Output ONLY valid JSON:
|
|
| 106 |
"name": "Business Card",
|
| 107 |
"description": "Extract contact information",
|
| 108 |
"prompt": """Extract contact information from this business card.
|
|
|
|
| 109 |
Output ONLY valid JSON:
|
| 110 |
{
|
| 111 |
"name": "",
|
|
@@ -125,6 +132,7 @@ Output ONLY valid JSON:
|
|
| 125 |
"name": "Form",
|
| 126 |
"description": "Extract filled form data",
|
| 127 |
"prompt": """Extract all fields and values from this form.
|
|
|
|
| 128 |
Output ONLY valid JSON with field names as keys and filled values:
|
| 129 |
{
|
| 130 |
"field_name": "value"
|
|
@@ -133,7 +141,7 @@ Output ONLY valid JSON with field names as keys and filled values:
|
|
| 133 |
},
|
| 134 |
|
| 135 |
DocumentType.HANDWRITTEN: {
|
| 136 |
-
"name": "
|
| 137 |
"description": "Extract text from handwritten documents",
|
| 138 |
"prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
|
| 139 |
"icon": "βοΈ"
|
|
@@ -192,19 +200,21 @@ def preprocess_image(
|
|
| 192 |
) -> Image.Image:
|
| 193 |
"""
|
| 194 |
Preprocess image with optional enhancements
|
|
|
|
| 195 |
Args:
|
| 196 |
image: PIL Image
|
| 197 |
enhance_contrast: Apply CLAHE contrast enhancement
|
| 198 |
denoise: Apply denoising
|
| 199 |
sharpen: Apply sharpening
|
| 200 |
auto_rotate: Attempt to auto-rotate text to horizontal
|
|
|
|
| 201 |
Returns:
|
| 202 |
Preprocessed PIL Image
|
| 203 |
"""
|
| 204 |
if prevent_cropping and not auto_rotate:
|
| 205 |
raise Exception(f"Auto-Rotate must be enabled when Prevent-Cropping is active")
|
| 206 |
|
| 207 |
-
|
| 208 |
img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 209 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 210 |
|
|
@@ -284,10 +294,12 @@ def extract_text(
|
|
| 284 |
) -> tuple[str, int]:
|
| 285 |
"""
|
| 286 |
Extract text from image using GLM-OCR
|
|
|
|
| 287 |
Args:
|
| 288 |
image: PIL Image
|
| 289 |
prompt: Extraction prompt
|
| 290 |
max_tokens: Maximum tokens to generate
|
|
|
|
| 291 |
Returns:
|
| 292 |
Tuple of (extracted_text, processing_time_ms)
|
| 293 |
"""
|
|
@@ -353,6 +365,20 @@ st.set_page_config(
|
|
| 353 |
initial_sidebar_state="expanded"
|
| 354 |
)
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
# Header
|
| 357 |
st.title("π Universal OCR Scanner")
|
| 358 |
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")
|
|
@@ -387,7 +413,7 @@ with st.sidebar:
|
|
| 387 |
auto_rotate = st.checkbox("Auto-Rotate", value=False,
|
| 388 |
help="Automatically straighten tilted documents")
|
| 389 |
prevent_cropping = st.checkbox("Prevent-Cropping", value=False,
|
| 390 |
-
|
| 391 |
|
| 392 |
st.markdown("---")
|
| 393 |
|
|
@@ -426,11 +452,17 @@ with col1:
|
|
| 426 |
)
|
| 427 |
if uploaded_file is not None:
|
| 428 |
image = Image.open(uploaded_file).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
with camera_tab:
|
| 431 |
camera_picture = st.camera_input("Take a photo")
|
| 432 |
if camera_picture is not None:
|
| 433 |
image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
|
|
|
|
|
|
|
| 434 |
|
| 435 |
# Show original image
|
| 436 |
if image is not None:
|
|
@@ -445,7 +477,8 @@ with col2:
|
|
| 445 |
"Custom Extraction Prompt:",
|
| 446 |
value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
|
| 447 |
height=200,
|
| 448 |
-
help="Customize how the OCR extracts data"
|
|
|
|
| 449 |
)
|
| 450 |
else:
|
| 451 |
prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
|
|
@@ -453,18 +486,23 @@ with col2:
|
|
| 453 |
|
| 454 |
# Process button
|
| 455 |
if image is not None:
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
| 461 |
else:
|
| 462 |
st.info("π Upload or capture an image to begin")
|
| 463 |
-
process_button = False
|
| 464 |
|
| 465 |
-
# Processing
|
| 466 |
-
if image is not None and
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
| 468 |
try:
|
| 469 |
# Preprocess image
|
| 470 |
if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
|
|
@@ -495,192 +533,245 @@ if image is not None and process_button:
|
|
| 495 |
max_tokens=max_tokens
|
| 496 |
)
|
| 497 |
|
| 498 |
-
#
|
| 499 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
except json.JSONDecodeError:
|
| 519 |
-
is_json = False
|
| 520 |
-
|
| 521 |
-
# Display based on type
|
| 522 |
-
st.markdown("---")
|
| 523 |
-
st.subheader("π Extracted Data")
|
| 524 |
-
|
| 525 |
-
if is_json and parsed_data:
|
| 526 |
-
# Structured data display
|
| 527 |
-
col_display, col_download = st.columns([2, 1])
|
| 528 |
-
|
| 529 |
-
with col_display:
|
| 530 |
-
# Format display based on document type
|
| 531 |
-
if doc_type == DocumentType.RECEIPT:
|
| 532 |
-
st.markdown("### π§Ύ Receipt Details")
|
| 533 |
-
|
| 534 |
-
# Merchant info
|
| 535 |
-
if "merchant_name" in parsed_data:
|
| 536 |
-
st.markdown(f"**Merchant:** {parsed_data['merchant_name']}")
|
| 537 |
-
if "date" in parsed_data:
|
| 538 |
-
st.markdown(f"**Date:** {parsed_data['date']}")
|
| 539 |
-
if "time" in parsed_data:
|
| 540 |
-
st.markdown(f"**Time:** {parsed_data['time']}")
|
| 541 |
-
|
| 542 |
-
# Items table
|
| 543 |
-
if "items" in parsed_data and parsed_data["items"]:
|
| 544 |
-
st.markdown("**Items:**")
|
| 545 |
-
items_df = pd.DataFrame(parsed_data["items"])
|
| 546 |
-
st.dataframe(items_df, width="content", hide_index=True)
|
| 547 |
-
|
| 548 |
-
# Totals
|
| 549 |
-
st.markdown("---")
|
| 550 |
-
if "subtotal" in parsed_data:
|
| 551 |
-
st.markdown(f"**Subtotal:** ${parsed_data['subtotal']:.2f}")
|
| 552 |
-
if "tax" in parsed_data:
|
| 553 |
-
st.markdown(f"**Tax:** ${parsed_data['tax']:.2f}")
|
| 554 |
-
if "total" in parsed_data:
|
| 555 |
-
st.markdown(f"**Total:** ${parsed_data['total']:.2f}")
|
| 556 |
-
|
| 557 |
-
elif doc_type == DocumentType.INVOICE:
|
| 558 |
-
st.markdown("### π Invoice Details")
|
| 559 |
-
|
| 560 |
-
col_inv1, col_inv2 = st.columns(2)
|
| 561 |
-
|
| 562 |
-
with col_inv1:
|
| 563 |
-
st.markdown("**Invoice Info:**")
|
| 564 |
-
if "invoice_number" in parsed_data:
|
| 565 |
-
st.text(f"Number: {parsed_data['invoice_number']}")
|
| 566 |
-
if "date" in parsed_data:
|
| 567 |
-
st.text(f"Date: {parsed_data['date']}")
|
| 568 |
-
if "due_date" in parsed_data:
|
| 569 |
-
st.text(f"Due: {parsed_data['due_date']}")
|
| 570 |
-
|
| 571 |
-
with col_inv2:
|
| 572 |
-
if "vendor" in parsed_data:
|
| 573 |
-
st.markdown("**Vendor:**")
|
| 574 |
-
vendor = parsed_data["vendor"]
|
| 575 |
-
if isinstance(vendor, dict):
|
| 576 |
-
for k, v in vendor.items():
|
| 577 |
-
if v:
|
| 578 |
-
st.text(f"{k.title()}: {v}")
|
| 579 |
-
|
| 580 |
-
# Line items
|
| 581 |
-
if "line_items" in parsed_data and parsed_data["line_items"]:
|
| 582 |
-
st.markdown("**Line Items:**")
|
| 583 |
-
items_df = pd.DataFrame(parsed_data["line_items"])
|
| 584 |
-
st.dataframe(items_df, width="content", hide_index=True)
|
| 585 |
-
|
| 586 |
-
# Total
|
| 587 |
-
if "total" in parsed_data:
|
| 588 |
-
st.markdown(f"### **Total: ${parsed_data['total']:.2f}**")
|
| 589 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
else:
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
)
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
st.
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
else:
|
| 661 |
-
#
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
-
# Download
|
| 670 |
-
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 671 |
st.download_button(
|
| 672 |
-
label="
|
| 673 |
-
data=
|
| 674 |
-
file_name=f"
|
| 675 |
-
mime="text/
|
|
|
|
| 676 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
|
| 682 |
-
|
| 683 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
|
| 685 |
# Footer
|
| 686 |
st.markdown("---")
|
|
|
|
| 39 |
"name": "ID Card / Passport",
|
| 40 |
"description": "Extract structured data from identity documents",
|
| 41 |
"prompt": """Extract structured data from this identity document.
|
| 42 |
+
|
| 43 |
+
Output ONLY valid JSON with these exact fields, no nested objects:
|
| 44 |
{
|
| 45 |
"document_type": "",
|
| 46 |
"full_name": "",
|
|
|
|
| 49 |
"date_of_expiry": "",
|
| 50 |
"nationality": "",
|
| 51 |
"document_number": "",
|
| 52 |
+
"place_of_birth": "",
|
| 53 |
+
"personal_number": ""
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
IMPORTANT: Do NOT create nested or recursive structures. Keep it flat and simple.""",
|
| 57 |
"icon": "π"
|
| 58 |
},
|
| 59 |
|
|
|
|
| 61 |
"name": "Receipt",
|
| 62 |
"description": "Extract items, prices, and totals from receipts",
|
| 63 |
"prompt": """Extract information from this receipt.
|
| 64 |
+
|
| 65 |
Output ONLY valid JSON:
|
| 66 |
{
|
| 67 |
"merchant_name": "",
|
|
|
|
| 82 |
"name": "Invoice",
|
| 83 |
"description": "Extract invoice details and line items",
|
| 84 |
"prompt": """Extract information from this invoice.
|
| 85 |
+
|
| 86 |
Output ONLY valid JSON:
|
| 87 |
{
|
| 88 |
"invoice_number": "",
|
|
|
|
| 112 |
"name": "Business Card",
|
| 113 |
"description": "Extract contact information",
|
| 114 |
"prompt": """Extract contact information from this business card.
|
| 115 |
+
|
| 116 |
Output ONLY valid JSON:
|
| 117 |
{
|
| 118 |
"name": "",
|
|
|
|
| 132 |
"name": "Form",
|
| 133 |
"description": "Extract filled form data",
|
| 134 |
"prompt": """Extract all fields and values from this form.
|
| 135 |
+
|
| 136 |
Output ONLY valid JSON with field names as keys and filled values:
|
| 137 |
{
|
| 138 |
"field_name": "value"
|
|
|
|
| 141 |
},
|
| 142 |
|
| 143 |
DocumentType.HANDWRITTEN: {
|
| 144 |
+
"name": "Handwritten Note",
|
| 145 |
"description": "Extract text from handwritten documents",
|
| 146 |
"prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.",
|
| 147 |
"icon": "βοΈ"
|
|
|
|
| 200 |
) -> Image.Image:
|
| 201 |
"""
|
| 202 |
Preprocess image with optional enhancements
|
| 203 |
+
|
| 204 |
Args:
|
| 205 |
image: PIL Image
|
| 206 |
enhance_contrast: Apply CLAHE contrast enhancement
|
| 207 |
denoise: Apply denoising
|
| 208 |
sharpen: Apply sharpening
|
| 209 |
auto_rotate: Attempt to auto-rotate text to horizontal
|
| 210 |
+
|
| 211 |
Returns:
|
| 212 |
Preprocessed PIL Image
|
| 213 |
"""
|
| 214 |
if prevent_cropping and not auto_rotate:
|
| 215 |
raise Exception(f"Auto-Rotate must be enabled when Prevent-Cropping is active")
|
| 216 |
|
| 217 |
+
# Convert to OpenCV format
|
| 218 |
img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
| 219 |
gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 220 |
|
|
|
|
| 294 |
) -> tuple[str, int]:
|
| 295 |
"""
|
| 296 |
Extract text from image using GLM-OCR
|
| 297 |
+
|
| 298 |
Args:
|
| 299 |
image: PIL Image
|
| 300 |
prompt: Extraction prompt
|
| 301 |
max_tokens: Maximum tokens to generate
|
| 302 |
+
|
| 303 |
Returns:
|
| 304 |
Tuple of (extracted_text, processing_time_ms)
|
| 305 |
"""
|
|
|
|
| 365 |
initial_sidebar_state="expanded"
|
| 366 |
)
|
| 367 |
|
| 368 |
+
# Initialize session state
|
| 369 |
+
if 'should_process' not in st.session_state:
|
| 370 |
+
st.session_state.should_process = False
|
| 371 |
+
if 'has_results' not in st.session_state:
|
| 372 |
+
st.session_state.has_results = False
|
| 373 |
+
if 'output_text' not in st.session_state:
|
| 374 |
+
st.session_state.output_text = ""
|
| 375 |
+
if 'processing_time' not in st.session_state:
|
| 376 |
+
st.session_state.processing_time = 0
|
| 377 |
+
if 'doc_type' not in st.session_state:
|
| 378 |
+
st.session_state.doc_type = DocumentType.GENERAL
|
| 379 |
+
if 'current_file' not in st.session_state:
|
| 380 |
+
st.session_state.current_file = None
|
| 381 |
+
|
| 382 |
# Header
|
| 383 |
st.title("π Universal OCR Scanner")
|
| 384 |
st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!")
|
|
|
|
| 413 |
auto_rotate = st.checkbox("Auto-Rotate", value=False,
|
| 414 |
help="Automatically straighten tilted documents")
|
| 415 |
prevent_cropping = st.checkbox("Prevent-Cropping", value=False,
|
| 416 |
+
help="Prevent cropping when rotate")
|
| 417 |
|
| 418 |
st.markdown("---")
|
| 419 |
|
|
|
|
| 452 |
)
|
| 453 |
if uploaded_file is not None:
|
| 454 |
image = Image.open(uploaded_file).convert("RGB")
|
| 455 |
+
# Clear previous results when new image uploaded
|
| 456 |
+
if 'current_file' not in st.session_state or st.session_state.current_file != uploaded_file.name:
|
| 457 |
+
st.session_state.current_file = uploaded_file.name
|
| 458 |
+
st.session_state.has_results = False
|
| 459 |
|
| 460 |
with camera_tab:
|
| 461 |
camera_picture = st.camera_input("Take a photo")
|
| 462 |
if camera_picture is not None:
|
| 463 |
image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB")
|
| 464 |
+
# Clear previous results when new photo taken
|
| 465 |
+
st.session_state.has_results = False
|
| 466 |
|
| 467 |
# Show original image
|
| 468 |
if image is not None:
|
|
|
|
| 477 |
"Custom Extraction Prompt:",
|
| 478 |
value=DOCUMENT_TEMPLATES[doc_type]['prompt'],
|
| 479 |
height=200,
|
| 480 |
+
help="Customize how the OCR extracts data",
|
| 481 |
+
key="custom_prompt_text"
|
| 482 |
)
|
| 483 |
else:
|
| 484 |
prompt = DOCUMENT_TEMPLATES[doc_type]['prompt']
|
|
|
|
| 486 |
|
| 487 |
# Process button
|
| 488 |
if image is not None:
|
| 489 |
+
if st.button(
|
| 490 |
+
"π Extract Text",
|
| 491 |
+
type="primary",
|
| 492 |
+
width="content",
|
| 493 |
+
key="extract_button"
|
| 494 |
+
):
|
| 495 |
+
# Trigger processing by setting session state
|
| 496 |
+
st.session_state.should_process = True
|
| 497 |
else:
|
| 498 |
st.info("π Upload or capture an image to begin")
|
|
|
|
| 499 |
|
| 500 |
+
# Processing (only run when button is clicked)
|
| 501 |
+
if image is not None and st.session_state.get('should_process', False):
|
| 502 |
+
# Clear the flag immediately to prevent re-processing on next rerun
|
| 503 |
+
st.session_state.should_process = False
|
| 504 |
+
|
| 505 |
+
with st.spinner("π Processing document..."):
|
| 506 |
try:
|
| 507 |
# Preprocess image
|
| 508 |
if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping:
|
|
|
|
| 533 |
max_tokens=max_tokens
|
| 534 |
)
|
| 535 |
|
| 536 |
+
# Store results in session state
|
| 537 |
+
st.session_state.output_text = output_text
|
| 538 |
+
st.session_state.processing_time = processing_time
|
| 539 |
+
st.session_state.doc_type = doc_type
|
| 540 |
+
st.session_state.preprocessed_image = preprocessed_image
|
| 541 |
+
st.session_state.has_results = True
|
| 542 |
|
| 543 |
+
except Exception as e:
|
| 544 |
+
st.error(f"β Error during extraction: {str(e)}")
|
| 545 |
+
import traceback
|
| 546 |
+
|
| 547 |
+
with st.expander("Show Error Details"):
|
| 548 |
+
st.code(traceback.format_exc())
|
| 549 |
+
st.session_state.has_results = False
|
| 550 |
+
|
| 551 |
+
# Display results (separate from processing)
|
| 552 |
+
if st.session_state.get('has_results', False):
|
| 553 |
+
output_text = st.session_state.output_text
|
| 554 |
+
processing_time = st.session_state.processing_time
|
| 555 |
+
doc_type = st.session_state.doc_type
|
| 556 |
+
preprocessed_image = st.session_state.get('preprocessed_image', image)
|
| 557 |
+
|
| 558 |
+
# Display success message
|
| 559 |
+
st.success(f"β
Extraction complete! ({processing_time}ms)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
|
| 561 |
+
# Try to parse as JSON for structured documents
|
| 562 |
+
is_json = False
|
| 563 |
+
parsed_data = None
|
| 564 |
+
|
| 565 |
+
if doc_type in [DocumentType.ID_CARD, DocumentType.RECEIPT,
|
| 566 |
+
DocumentType.INVOICE, DocumentType.BUSINESS_CARD,
|
| 567 |
+
DocumentType.FORM]:
|
| 568 |
+
try:
|
| 569 |
+
# Clean JSON from markdown
|
| 570 |
+
clean_text = output_text
|
| 571 |
+
if "```json" in clean_text:
|
| 572 |
+
clean_text = clean_text.split("```json")[1].split("```")[0].strip()
|
| 573 |
+
elif "```" in clean_text:
|
| 574 |
+
clean_text = clean_text.split("```")[1].split("```")[0].strip()
|
| 575 |
+
|
| 576 |
+
# Truncate if too long (likely recursive)
|
| 577 |
+
if len(clean_text) > 50000: # Reasonable JSON should be much smaller
|
| 578 |
+
st.warning("β οΈ Detected recursive JSON structure. Truncating...")
|
| 579 |
+
clean_text = clean_text[:50000]
|
| 580 |
+
|
| 581 |
+
parsed_data = json.loads(clean_text)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
# Flatten recursive structures
|
| 585 |
+
def flatten_dict(d, max_depth=2, current_depth=0):
|
| 586 |
+
"""Remove recursive nested structures"""
|
| 587 |
+
if current_depth >= max_depth:
|
| 588 |
+
return {}
|
| 589 |
+
|
| 590 |
+
if not isinstance(d, dict):
|
| 591 |
+
return d
|
| 592 |
+
|
| 593 |
+
flattened = {}
|
| 594 |
+
for key, value in d.items():
|
| 595 |
+
if isinstance(value, dict):
|
| 596 |
+
# Only keep first level of nesting
|
| 597 |
+
if current_depth < max_depth - 1:
|
| 598 |
+
flattened[key] = flatten_dict(value, max_depth, current_depth + 1)
|
| 599 |
+
# Skip deeply nested structures
|
| 600 |
+
elif isinstance(value, list):
|
| 601 |
+
# Keep lists but limit depth
|
| 602 |
+
flattened[key] = value
|
| 603 |
else:
|
| 604 |
+
flattened[key] = value
|
| 605 |
+
|
| 606 |
+
return flattened
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
# Flatten the parsed data
|
| 610 |
+
parsed_data = flatten_dict(parsed_data, max_depth=2)
|
| 611 |
+
|
| 612 |
+
is_json = True
|
| 613 |
+
except json.JSONDecodeError:
|
| 614 |
+
is_json = False
|
| 615 |
+
except Exception as e:
|
| 616 |
+
st.warning(f"β οΈ JSON parsing issue: {str(e)}")
|
| 617 |
+
is_json = False
|
| 618 |
+
|
| 619 |
+
# Display based on type
|
| 620 |
+
st.markdown("---")
|
| 621 |
+
st.subheader("π Extracted Data")
|
| 622 |
+
|
| 623 |
+
if is_json and parsed_data:
|
| 624 |
+
# Structured data display
|
| 625 |
+
col_display, col_download = st.columns([2, 1])
|
| 626 |
+
|
| 627 |
+
with col_display:
|
| 628 |
+
# Format display based on document type
|
| 629 |
+
if doc_type == DocumentType.RECEIPT:
|
| 630 |
+
st.markdown("### π§Ύ Receipt Details")
|
| 631 |
+
|
| 632 |
+
# Merchant info
|
| 633 |
+
if "merchant_name" in parsed_data:
|
| 634 |
+
st.markdown(f"**Merchant:** {parsed_data['merchant_name']}")
|
| 635 |
+
if "date" in parsed_data:
|
| 636 |
+
st.markdown(f"**Date:** {parsed_data['date']}")
|
| 637 |
+
if "time" in parsed_data:
|
| 638 |
+
st.markdown(f"**Time:** {parsed_data['time']}")
|
| 639 |
+
|
| 640 |
+
# Items table
|
| 641 |
+
if "items" in parsed_data and parsed_data["items"]:
|
| 642 |
+
st.markdown("**Items:**")
|
| 643 |
+
items_df = pd.DataFrame(parsed_data["items"])
|
| 644 |
+
st.dataframe(items_df, width="content", hide_index=True)
|
| 645 |
+
|
| 646 |
+
# Totals
|
| 647 |
+
st.markdown("---")
|
| 648 |
+
if "subtotal" in parsed_data:
|
| 649 |
+
st.markdown(f"**Subtotal:** ${parsed_data['subtotal']:.2f}")
|
| 650 |
+
if "tax" in parsed_data:
|
| 651 |
+
st.markdown(f"**Tax:** ${parsed_data['tax']:.2f}")
|
| 652 |
+
if "total" in parsed_data:
|
| 653 |
+
st.markdown(f"**Total:** ${parsed_data['total']:.2f}")
|
| 654 |
+
|
| 655 |
+
elif doc_type == DocumentType.INVOICE:
|
| 656 |
+
st.markdown("### π Invoice Details")
|
| 657 |
+
|
| 658 |
+
col_inv1, col_inv2 = st.columns(2)
|
| 659 |
+
|
| 660 |
+
with col_inv1:
|
| 661 |
+
st.markdown("**Invoice Info:**")
|
| 662 |
+
if "invoice_number" in parsed_data:
|
| 663 |
+
st.text(f"Number: {parsed_data['invoice_number']}")
|
| 664 |
+
if "date" in parsed_data:
|
| 665 |
+
st.text(f"Date: {parsed_data['date']}")
|
| 666 |
+
if "due_date" in parsed_data:
|
| 667 |
+
st.text(f"Due: {parsed_data['due_date']}")
|
| 668 |
+
|
| 669 |
+
with col_inv2:
|
| 670 |
+
if "vendor" in parsed_data:
|
| 671 |
+
st.markdown("**Vendor:**")
|
| 672 |
+
vendor = parsed_data["vendor"]
|
| 673 |
+
if isinstance(vendor, dict):
|
| 674 |
+
for k, v in vendor.items():
|
| 675 |
+
if v:
|
| 676 |
+
st.text(f"{k.title()}: {v}")
|
| 677 |
+
|
| 678 |
+
# Line items
|
| 679 |
+
if "line_items" in parsed_data and parsed_data["line_items"]:
|
| 680 |
+
st.markdown("**Line Items:**")
|
| 681 |
+
items_df = pd.DataFrame(parsed_data["line_items"])
|
| 682 |
+
st.dataframe(items_df, width="content", hide_index=True)
|
| 683 |
+
|
| 684 |
+
# Total
|
| 685 |
+
if "total" in parsed_data:
|
| 686 |
+
st.markdown(f"### **Total: ${parsed_data['total']:.2f}**")
|
| 687 |
|
| 688 |
else:
|
| 689 |
+
# Generic structured data display
|
| 690 |
+
for key, value in parsed_data.items():
|
| 691 |
+
if isinstance(value, dict):
|
| 692 |
+
st.markdown(f"**{key.replace('_', ' ').title()}:**")
|
| 693 |
+
for k, v in value.items():
|
| 694 |
+
st.text(f" {k}: {v}")
|
| 695 |
+
elif isinstance(value, list):
|
| 696 |
+
st.markdown(f"**{key.replace('_', ' ').title()}:**")
|
| 697 |
+
if value and isinstance(value[0], dict):
|
| 698 |
+
df = pd.DataFrame(value)
|
| 699 |
+
st.dataframe(df, width="content", hide_index=True)
|
| 700 |
+
else:
|
| 701 |
+
for item in value:
|
| 702 |
+
st.text(f" β’ {item}")
|
| 703 |
+
else:
|
| 704 |
+
st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")
|
| 705 |
+
|
| 706 |
+
with col_download:
|
| 707 |
+
st.subheader("πΎ Downloads")
|
| 708 |
+
|
| 709 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 710 |
+
|
| 711 |
+
# JSON download
|
| 712 |
+
json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2)
|
| 713 |
+
st.download_button(
|
| 714 |
+
label="π JSON",
|
| 715 |
+
data=json_str,
|
| 716 |
+
file_name=f"{doc_type.value}_{timestamp}.json",
|
| 717 |
+
mime="application/json",
|
| 718 |
+
width="content"
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
# CSV download (flattened)
|
| 722 |
+
try:
|
| 723 |
+
# Flatten nested structures
|
| 724 |
+
flat_data = {}
|
| 725 |
+
for k, v in parsed_data.items():
|
| 726 |
+
if isinstance(v, (dict, list)):
|
| 727 |
+
flat_data[k] = json.dumps(v, ensure_ascii=False)
|
| 728 |
+
else:
|
| 729 |
+
flat_data[k] = v
|
| 730 |
+
|
| 731 |
+
df = pd.DataFrame([flat_data])
|
| 732 |
+
csv_buffer = StringIO()
|
| 733 |
+
df.to_csv(csv_buffer, index=False, encoding='utf-8')
|
| 734 |
|
|
|
|
|
|
|
| 735 |
st.download_button(
|
| 736 |
+
label="π CSV",
|
| 737 |
+
data=csv_buffer.getvalue(),
|
| 738 |
+
file_name=f"{doc_type.value}_{timestamp}.csv",
|
| 739 |
+
mime="text/csv",
|
| 740 |
+
width="content"
|
| 741 |
)
|
| 742 |
+
except:
|
| 743 |
+
pass
|
| 744 |
+
|
| 745 |
+
# Raw text download
|
| 746 |
+
st.download_button(
|
| 747 |
+
label="π TXT",
|
| 748 |
+
data=output_text,
|
| 749 |
+
file_name=f"{doc_type.value}_{timestamp}.txt",
|
| 750 |
+
mime="text/plain",
|
| 751 |
+
width="content"
|
| 752 |
+
)
|
| 753 |
|
| 754 |
+
# Show raw JSON in expander
|
| 755 |
+
with st.expander("π View Raw JSON"):
|
| 756 |
+
st.json(parsed_data)
|
| 757 |
|
| 758 |
+
else:
|
| 759 |
+
# Plain text display
|
| 760 |
+
st.text_area(
|
| 761 |
+
"Extracted Text:",
|
| 762 |
+
value=output_text,
|
| 763 |
+
height=400,
|
| 764 |
+
label_visibility="collapsed"
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# Download
|
| 768 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 769 |
+
st.download_button(
|
| 770 |
+
label="πΎ Download as TXT",
|
| 771 |
+
data=output_text,
|
| 772 |
+
file_name=f"extracted_text_{timestamp}.txt",
|
| 773 |
+
mime="text/plain"
|
| 774 |
+
)
|
| 775 |
|
| 776 |
# Footer
|
| 777 |
st.markdown("---")
|