# Hugging Face Space app (scrape header "Spaces: Sleeping" converted to a comment)
import spaces  # NOTE: must be imported before torch on ZeroGPU Spaces

import datetime
import json
from enum import Enum
from io import BytesIO, StringIO
from typing import Optional, Tuple

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
)
| # ---------------------------- | |
| # Optional dependency (recommended for dots.ocr) | |
| # ---------------------------- | |
| try: | |
| from qwen_vl_utils import process_vision_info # type: ignore | |
| except Exception: | |
| process_vision_info = None | |
| # ======================================== | |
| # MODEL TYPES | |
| # ======================================== | |
| class OCRModel(str, Enum): | |
| GLM_OCR = "glm_ocr" | |
| DOTS_OCR = "dots_ocr" | |
| MODEL_UI = { | |
| OCRModel.GLM_OCR: {"name": "GLM-OCR", "icon": "🟦", "hf_id": "zai-org/GLM-OCR"}, | |
| OCRModel.DOTS_OCR: {"name": "dots.ocr", "icon": "🟩", "hf_id": "rednote-hilab/dots.ocr"}, | |
| } | |
| # ======================================== | |
| # DOCUMENT TYPES & TEMPLATES | |
| # ======================================== | |
| class DocumentType(str, Enum): | |
| """Supported document types""" | |
| GENERAL = "general" | |
| FULL_JSON_SCHEMA = "full_json_schema" | |
| SIMPLE_TITLE_JSON = "simple_title_json" | |
| LOCALIZED_TITLE_JSON = "localized_title_json" | |
| GROUNDED_TITLE_JSON = "grounded_title_json" | |
| HANDWRITTEN = "handwritten" | |
| DOCUMENT_TEMPLATES = { | |
| DocumentType.GENERAL: { | |
| "name": "General Text", | |
| "description": "Extract all text from any document", | |
| "prompt": "Extract all text from this image. Preserve the layout and structure. Output plain text.", | |
| "icon": "📄" | |
| }, | |
| DocumentType.FULL_JSON_SCHEMA: { | |
| "name": "Full Json Schema", | |
| "description": "Extract structured data from this cover page", | |
| "prompt": """Analyze this thesis/dissertation cover image and extract ONLY visible information. | |
| CRITICAL: Only extract information that is CLEARLY VISIBLE on the page. | |
| DO NOT invent, guess, or hallucinate any data. If a field is not visible, use null. | |
| Return ONLY valid JSON with this exact structure: | |
| { | |
| "title": "Main title of the thesis or dissertation as it appears on the title page", | |
| "subtitle": "Subtitle or remainder of the title, usually following a colon; null if not present", | |
| "author": "Full name of the author (student) who wrote the thesis or dissertation", | |
| "degree_type": "Academic degree sought by the author (e.g. PhD, Doctorate, Master's degree, Master's thesis)", | |
| "discipline": "Academic field or discipline of the thesis if explicitly stated; null if not present. Possible values: Mathématiques|Physics|Biology|others", | |
| "granting_institution": "Institution where the thesis was submitted and the degree is granted (degree-granting institution)", | |
| "doctoral_school": "Doctoral school or graduate program, if explicitly mentioned; null if not present", | |
| "co_tutelle_institutions": "List of institutions involved in a joint supervision or co-tutelle agreement; empty list if none", | |
| "partner_institutions": "List of partner institutions associated with the thesis but not granting the degree; empty list if none", | |
| "defense_year": "Year the thesis or dissertation was defended, in YYYY format; null if not visible", | |
| "defense_place": "City or place where the defense took place, if stated; null if not present", | |
| "thesis_advisor": "Main thesis advisor or supervisor (director of thesis); full name; null if not present", | |
| "co_advisors": "List of co-advisors or co-supervisors if explicitly mentioned; full names; empty list if none", | |
| "jury_president": "President or chair of the thesis examination committee, if specified; null if not present", | |
| "reviewers": "List of reviewers or rapporteurs of the thesis, if specified; full names; empty list if none", | |
| "committee_members": "List of other thesis committee or jury members, excluding advisor and reviewers; full names; empty list if none", | |
| "language": "Language in which the thesis is written, if explicitly stated; null if not present", | |
| "confidence": "Confidence score between 0.0 and 1.0 indicating reliability of the extracted metadata" | |
| } | |
| IMPORTANT: Return null for any field where information is NOT clearly visible. | |
| Return ONLY the JSON, no explanation.""", | |
| "icon": "🆔" | |
| }, | |
| DocumentType.SIMPLE_TITLE_JSON: { | |
| "name": "Simple Title Json", | |
| "description": "Extract title from this cover page", | |
| "prompt": """Extract the document title from this cover page. | |
| Output ONLY valid JSON: | |
| { | |
| "title": "" | |
| }""", | |
| "icon": "🧾" | |
| }, | |
| DocumentType.LOCALIZED_TITLE_JSON: { | |
| "name": "Localized Title Json", | |
| "description": "Extract localized title from this cover page", | |
| "prompt": """Extract the document title from the middle central block of this cover page. | |
| Output ONLY valid JSON: | |
| { | |
| "title": "" | |
| }""", | |
| "icon": "🧾" | |
| }, | |
| DocumentType.GROUNDED_TITLE_JSON: { | |
| "name": "Grounded Title Json", | |
| "description": "Extract localized title from this cover page", | |
| "prompt": """Extract the document title usually located around (0.5015,0.442) from this cover page. | |
| Output ONLY valid JSON: | |
| { | |
| "title": "" | |
| }""", | |
| "icon": "📋" | |
| }, | |
| DocumentType.HANDWRITTEN: { | |
| "name": "Handwritten Note", | |
| "description": "Extract text from handwritten documents", | |
| "prompt": "Extract all handwritten text from this image. Output plain text, preserving line breaks.", | |
| "icon": "✍️" | |
| } | |
| } | |
| # ======================================== | |
| # MODEL LOADING | |
| # ======================================== | |
| def load_glm_ocr(): | |
| """Load GLM-OCR model (cached)""" | |
| model_name = MODEL_UI[OCRModel.GLM_OCR]["hf_id"] | |
| with st.spinner("🔄 Loading GLM-OCR... (first time may take 1–3 minutes)"): | |
| processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
| model = AutoModelForImageTextToText.from_pretrained( | |
| model_name, | |
| torch_dtype=dtype, | |
| device_map="auto" if torch.cuda.is_available() else None, | |
| low_cpu_mem_usage=True, | |
| trust_remote_code=True | |
| ) | |
| if not torch.cuda.is_available(): | |
| model = model.to(device) | |
| model.eval() | |
| return processor, model, device | |
| def load_dots_ocr(): | |
| """Load dots.ocr model (cached)""" | |
| repo_id = MODEL_UI[OCRModel.DOTS_OCR]["hf_id"] | |
| # dots.ocr recommends avoiding '.' in local directory names (workaround mentioned in their docs) | |
| model_path = "./models/DotsOCR" | |
| snapshot_download( | |
| repo_id=repo_id, | |
| local_dir=model_path, | |
| local_dir_use_symlinks=False, | |
| ) | |
| with st.spinner("🔄 Loading dots.ocr... (first time may take 1–3 minutes)"): | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| #dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 | |
| dtype = torch.float32 | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| # attn_implementation="flash_attention_2", # optional if available | |
| torch_dtype=dtype, | |
| device_map="auto" if torch.cuda.is_available() else None, | |
| trust_remote_code=True, | |
| low_cpu_mem_usage=True, | |
| ) | |
| processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) | |
| if not torch.cuda.is_available(): | |
| model = model.to(device) | |
| model.eval() | |
| return processor, model, device | |
| def get_loaded_model(selected: OCRModel): | |
| """Return (processor, model, device) for selected model, cached by Streamlit.""" | |
| if selected == OCRModel.GLM_OCR: | |
| return load_glm_ocr() | |
| return load_dots_ocr() | |
| # ======================================== | |
| # IMAGE PREPROCESSING | |
| # ======================================== | |
| def preprocess_image( | |
| image: Image.Image, | |
| enhance_contrast: bool = False, | |
| denoise: bool = False, | |
| sharpen: bool = False, | |
| auto_rotate: bool = False, | |
| prevent_cropping: bool = False | |
| ) -> Image.Image: | |
| if prevent_cropping and not auto_rotate: | |
| raise Exception("Auto-Rotate must be enabled when Prevent-Cropping is active") | |
| img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) | |
| gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY) | |
| if denoise: | |
| gray = cv2.fastNlMeansDenoising(gray, None, 10, 7, 21) | |
| if enhance_contrast: | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| gray = clahe.apply(gray) | |
| if sharpen: | |
| kernel = np.array([[-1, -1, -1], | |
| [-1, 9, -1], | |
| [-1, -1, -1]]) | |
| gray = cv2.filter2D(gray, -1, kernel) | |
| if auto_rotate: | |
| blurred = cv2.GaussianBlur(gray, (5, 5), 0) | |
| edges = cv2.Canny(blurred, 50, 150) | |
| lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 100, minLineLength=80, maxLineGap=10) | |
| if lines is not None and len(lines) > 0: | |
| angles = [] | |
| for line in lines: | |
| x1, y1, x2, y2 = line[0] | |
| angle = np.degrees(np.arctan2(y2 - y1, x2 - x1)) | |
| if -45 < angle < 45: | |
| angles.append(angle) | |
| if angles: | |
| median_angle = float(np.median(angles)) | |
| if abs(median_angle) > 45: | |
| median_angle -= 90 * np.sign(median_angle) | |
| (h0, w0) = gray.shape[:2] | |
| center = (w0 // 2, h0 // 2) | |
| M = cv2.getRotationMatrix2D(center, median_angle, 1.0) | |
| out_w, out_h = w0, h0 | |
| if prevent_cropping: | |
| cos = np.abs(M[0, 0]) | |
| sin = np.abs(M[0, 1]) | |
| out_w = int((h0 * sin) + (w0 * cos)) | |
| out_h = int((h0 * cos) + (w0 * sin)) | |
| M[0, 2] += (out_w / 2) - center[0] | |
| M[1, 2] += (out_h / 2) - center[1] | |
| gray = cv2.warpAffine( | |
| gray, | |
| M, | |
| (out_w, out_h), | |
| flags=cv2.INTER_CUBIC, | |
| borderMode=cv2.BORDER_REPLICATE | |
| ) | |
| return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)) | |
| # ======================================== | |
| # OCR EXTRACTION | |
| # ======================================== | |
| def _now_ms() -> int: | |
| return int(datetime.datetime.now().timestamp() * 1000) | |
| def extract_text_glm( | |
| image: Image.Image, | |
| prompt: str, | |
| max_tokens: int, | |
| processor, | |
| model, | |
| device: str | |
| ) -> Tuple[str, int]: | |
| start_ms = _now_ms() | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "image", "image": image}, | |
| {"type": "text", "text": prompt} | |
| ] | |
| } | |
| ] | |
| inputs = processor.apply_chat_template( | |
| messages, | |
| tokenize=True, | |
| add_generation_prompt=True, | |
| return_dict=True, | |
| return_tensors="pt" | |
| ) | |
| inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"} | |
| with torch.no_grad(): | |
| generated_ids = model.generate( | |
| **inputs, | |
| max_new_tokens=max_tokens, | |
| do_sample=False, | |
| temperature=0.0 | |
| ) | |
| output_text = processor.decode( | |
| generated_ids[0][inputs["input_ids"].shape[1]:], | |
| skip_special_tokens=True | |
| ) | |
| del inputs, generated_ids | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return output_text, _now_ms() - start_ms | |
| def extract_text_dots( | |
| image: Image.Image, | |
| prompt: str, | |
| max_tokens: int, | |
| processor, | |
| model, | |
| device: str | |
| ) -> Tuple[str, int]: | |
| """ | |
| dots.ocr transformers inference (matches their model-card approach): | |
| - apply_chat_template(tokenize=False) | |
| - process_vision_info(messages) to get image_inputs/video_inputs | |
| - processor(text=[...], images=..., videos=..., return_tensors="pt") | |
| - generate, then trim input tokens, then decode | |
| """ | |
| start_ms = _now_ms() | |
| # dots.ocr examples use {"type":"image","image": <path>} but PIL works in practice with most processors | |
| messages = [ | |
| { | |
| "role": "user", | |
| "content": [ | |
| {"type": "image", "image": image}, | |
| {"type": "text", "text": prompt} | |
| ] | |
| } | |
| ] | |
| text = processor.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| if process_vision_info is not None: | |
| image_inputs, video_inputs = process_vision_info(messages) | |
| else: | |
| # Fallback: no video, single image | |
| image_inputs, video_inputs = [image], None | |
| inputs = processor( | |
| text=[text], | |
| images=image_inputs, | |
| videos=video_inputs, | |
| padding=True, | |
| return_tensors="pt", | |
| ) | |
| # Some processors return a BatchEncoding with .to(...) | |
| inputs = inputs.to(device) | |
| # some processors add keys that this model doesn't use | |
| unused_keys = ["mm_token_type_ids", "token_type_ids"] | |
| for k in unused_keys: | |
| if k in inputs: | |
| inputs.pop(k) | |
| with torch.no_grad(): | |
| generated_ids = model.generate( | |
| **inputs, | |
| max_new_tokens=max_tokens, | |
| do_sample=False, | |
| temperature=0.0 | |
| ) | |
| # Trim prompt tokens | |
| in_ids = inputs["input_ids"] | |
| trimmed = [] | |
| for i in range(generated_ids.shape[0]): | |
| trimmed.append(generated_ids[i][in_ids.shape[1]:]) | |
| output_text = processor.batch_decode( | |
| trimmed, | |
| skip_special_tokens=True, | |
| clean_up_tokenization_spaces=False | |
| )[0] | |
| del inputs, generated_ids, trimmed | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return output_text, _now_ms() - start_ms | |
| def extract_text( | |
| selected_model: OCRModel, | |
| image: Image.Image, | |
| prompt: str, | |
| max_tokens: int, | |
| ) -> Tuple[str, int]: | |
| processor, model, device = get_loaded_model(selected_model) | |
| if selected_model == OCRModel.GLM_OCR: | |
| return extract_text_glm(image, prompt, max_tokens, processor, model, device) | |
| return extract_text_dots(image, prompt, max_tokens, processor, model, device) | |
| # ======================================== | |
| # STREAMLIT UI | |
| # ======================================== | |
| st.set_page_config( | |
| page_title="Universal OCR Scanner", | |
| page_icon="🔍", | |
| layout="wide", | |
| initial_sidebar_state="expanded" | |
| ) | |
| # Initialize session state | |
| for k, v in { | |
| "should_process": False, | |
| "has_results": False, | |
| "output_text": "", | |
| "processing_time": 0, | |
| "doc_type": DocumentType.GENERAL, | |
| "selected_model": OCRModel.GLM_OCR, | |
| "current_file": None, | |
| }.items(): | |
| if k not in st.session_state: | |
| st.session_state[k] = v | |
| st.title("🔍 Universal OCR Scanner") | |
| st.markdown("Extract text and structured data from **any document** - receipts, IDs, invoices, forms, and more!") | |
| with st.sidebar: | |
| st.header("🧠 Model") | |
| selected_model = st.radio( | |
| "Select OCR model:", | |
| options=list(OCRModel), | |
| format_func=lambda x: f"{MODEL_UI[x]['icon']} {MODEL_UI[x]['name']}", | |
| index=list(OCRModel).index(st.session_state.selected_model), | |
| ) | |
| st.session_state.selected_model = selected_model | |
| st.header("📋 Document Type") | |
| doc_type = st.radio( | |
| "Select document type:", | |
| options=list(DocumentType), | |
| format_func=lambda x: f"{DOCUMENT_TEMPLATES[x]['icon']} {DOCUMENT_TEMPLATES[x]['name']}", | |
| index=list(DocumentType).index(st.session_state.doc_type), | |
| ) | |
| st.session_state.doc_type = doc_type | |
| st.info(DOCUMENT_TEMPLATES[doc_type]['description']) | |
| st.markdown("---") | |
| st.header("⚙️ Image Enhancement") | |
| with st.expander("🎨 Preprocessing Options", expanded=False): | |
| enhance_contrast = st.checkbox("Enhance Contrast", value=False, help="Improve visibility of faded text") | |
| denoise = st.checkbox("Reduce Noise", value=False, help="Remove image noise and artifacts") | |
| sharpen = st.checkbox("Sharpen Text", value=False, help="Make text edges crisper") | |
| auto_rotate = st.checkbox("Auto-Rotate", value=False, help="Automatically straighten tilted documents") | |
| prevent_cropping = st.checkbox("Prevent-Cropping", value=False, help="Prevent cropping when rotate") | |
| st.markdown("---") | |
| with st.expander("🔧 Advanced Options", expanded=False): | |
| show_preprocessed = st.checkbox("Show Preprocessed Image", value=False) | |
| max_tokens = st.slider("Max Output Tokens", 512, 4096, 2048, 256, help="Increase for longer documents") | |
| custom_prompt = st.checkbox("Use Custom Prompt", value=False) | |
| st.markdown("---") | |
| st.caption("💡 **Tips:**") | |
| st.caption("• Use good lighting") | |
| st.caption("• Avoid shadows") | |
| st.caption("• Keep text horizontal") | |
| st.caption("• Use high resolution images") | |
| col1, col2 = st.columns([1, 1]) | |
| with col1: | |
| st.subheader("📤 Upload Document") | |
| upload_tab, camera_tab = st.tabs(["📁 Upload File", "📸 Take Photo"]) | |
| image: Optional[Image.Image] = None | |
| with upload_tab: | |
| uploaded_file = st.file_uploader( | |
| "Choose an image...", | |
| type=["jpg", "jpeg", "png", "webp"], | |
| help="Supported formats: JPG, PNG, WEBP" | |
| ) | |
| if uploaded_file is not None: | |
| image = Image.open(uploaded_file).convert("RGB") | |
| if st.session_state.current_file != uploaded_file.name: | |
| st.session_state.current_file = uploaded_file.name | |
| st.session_state.has_results = False | |
| with camera_tab: | |
| camera_picture = st.camera_input("Take a photo") | |
| if camera_picture is not None: | |
| image = Image.open(BytesIO(camera_picture.getvalue())).convert("RGB") | |
| st.session_state.has_results = False | |
| if image is not None: | |
| st.image(image, caption="Original Image", width="content") | |
| with col2: | |
| st.subheader("📋 Extraction Settings") | |
| if custom_prompt: | |
| prompt = st.text_area( | |
| "Custom Extraction Prompt:", | |
| value=DOCUMENT_TEMPLATES[doc_type]['prompt'], | |
| height=200, | |
| help="Customize how the OCR extracts data", | |
| key="custom_prompt_text" | |
| ) | |
| else: | |
| prompt = DOCUMENT_TEMPLATES[doc_type]['prompt'] | |
| st.code(prompt, language="text") | |
| if image is not None: | |
| if st.button("🚀 Extract Text", type="primary", width="content", key="extract_button"): | |
| st.session_state.should_process = True | |
| else: | |
| st.info("👆 Upload or capture an image to begin") | |
| # Processing | |
| if image is not None and st.session_state.get('should_process', False): | |
| st.session_state.should_process = False | |
| with st.spinner("🔄 Processing document..."): | |
| try: | |
| if enhance_contrast or denoise or sharpen or auto_rotate or prevent_cropping: | |
| preprocessed_image = preprocess_image( | |
| image, | |
| enhance_contrast=enhance_contrast, | |
| denoise=denoise, | |
| sharpen=sharpen, | |
| auto_rotate=auto_rotate, | |
| prevent_cropping=prevent_cropping | |
| ) | |
| else: | |
| preprocessed_image = image | |
| if show_preprocessed and preprocessed_image != image: | |
| st.subheader("🔧 Preprocessed Image") | |
| col_a, col_b = st.columns(2) | |
| with col_a: | |
| st.image(image, caption="Original", width="content") | |
| with col_b: | |
| st.image(preprocessed_image, caption="Enhanced", width="content") | |
| output_text, processing_time = extract_text( | |
| selected_model=st.session_state.selected_model, | |
| image=preprocessed_image, | |
| prompt=prompt, | |
| max_tokens=max_tokens | |
| ) | |
| st.session_state.output_text = output_text | |
| st.session_state.processing_time = processing_time | |
| st.session_state.preprocessed_image = preprocessed_image | |
| st.session_state.has_results = True | |
| except Exception as e: | |
| st.error(f"❌ Error during extraction: {str(e)}") | |
| import traceback | |
| with st.expander("Show Error Details"): | |
| st.code(traceback.format_exc()) | |
| st.session_state.has_results = False | |
| # Results | |
| if st.session_state.get('has_results', False): | |
| output_text = st.session_state.output_text | |
| processing_time = st.session_state.processing_time | |
| preprocessed_image = st.session_state.get('preprocessed_image', image) | |
| st.success(f"✅ Extraction complete! ({processing_time}ms)") | |
| is_json = False | |
| parsed_data = None | |
| if doc_type in [ | |
| DocumentType.FULL_JSON_SCHEMA, | |
| DocumentType.SIMPLE_TITLE_JSON, | |
| DocumentType.LOCALIZED_TITLE_JSON, | |
| DocumentType.GROUNDED_TITLE_JSON | |
| ]: | |
| try: | |
| clean_text = output_text | |
| if "```json" in clean_text: | |
| clean_text = clean_text.split("```json")[1].split("```")[0].strip() | |
| elif "```" in clean_text: | |
| clean_text = clean_text.split("```")[1].split("```")[0].strip() | |
| if len(clean_text) > 50000: | |
| st.warning("⚠️ Detected unusually large JSON output. Truncating...") | |
| clean_text = clean_text[:50000] | |
| parsed_data = json.loads(clean_text) | |
| def flatten_dict(d, max_depth=2, current_depth=0): | |
| if current_depth >= max_depth: | |
| return {} | |
| if not isinstance(d, dict): | |
| return d | |
| flattened = {} | |
| for key, value in d.items(): | |
| if isinstance(value, dict): | |
| if current_depth < max_depth - 1: | |
| flattened[key] = flatten_dict(value, max_depth, current_depth + 1) | |
| elif isinstance(value, list): | |
| flattened[key] = value | |
| else: | |
| flattened[key] = value | |
| return flattened | |
| parsed_data = flatten_dict(parsed_data, max_depth=2) | |
| is_json = True | |
| except json.JSONDecodeError: | |
| is_json = False | |
| except Exception as e: | |
| st.warning(f"⚠️ JSON parsing issue: {str(e)}") | |
| is_json = False | |
| st.markdown("---") | |
| st.subheader("📄 Extracted Data") | |
| if is_json and parsed_data: | |
| col_display, col_download = st.columns([2, 1]) | |
| with col_display: | |
| for key, value in parsed_data.items(): | |
| if isinstance(value, dict): | |
| st.markdown(f"**{key.replace('_', ' ').title()}:**") | |
| for k, v in value.items(): | |
| st.text(f" {k}: {v}") | |
| elif isinstance(value, list): | |
| st.markdown(f"**{key.replace('_', ' ').title()}:**") | |
| if value and isinstance(value[0], dict): | |
| df = pd.DataFrame(value) | |
| st.dataframe(df, width="content", hide_index=True) | |
| else: | |
| for item in value: | |
| st.text(f" • {item}") | |
| else: | |
| st.markdown(f"**{key.replace('_', ' ').title()}:** {value}") | |
| with col_download: | |
| st.subheader("💾 Downloads") | |
| timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | |
| json_str = json.dumps(parsed_data, ensure_ascii=False, indent=2) | |
| st.download_button( | |
| label="📄 JSON", | |
| data=json_str, | |
| file_name=f"{doc_type.value}_{timestamp}.json", | |
| mime="application/json", | |
| width="content" | |
| ) | |
| try: | |
| flat_data = {} | |
| for k, v in parsed_data.items(): | |
| if isinstance(v, (dict, list)): | |
| flat_data[k] = json.dumps(v, ensure_ascii=False) | |
| else: | |
| flat_data[k] = v | |
| df = pd.DataFrame([flat_data]) | |
| csv_buffer = StringIO() | |
| df.to_csv(csv_buffer, index=False, encoding='utf-8') | |
| st.download_button( | |
| label="📊 CSV", | |
| data=csv_buffer.getvalue(), | |
| file_name=f"{doc_type.value}_{timestamp}.csv", | |
| mime="text/csv", | |
| width="content" | |
| ) | |
| except Exception: | |
| pass | |
| st.download_button( | |
| label="📝 TXT", | |
| data=output_text, | |
| file_name=f"{doc_type.value}_{timestamp}.txt", | |
| mime="text/plain", | |
| width="content" | |
| ) | |
| with st.expander("🔍 View Raw JSON"): | |
| st.json(parsed_data) | |
| else: | |
| st.text_area( | |
| "Extracted Text:", | |
| value=output_text, | |
| height=400, | |
| label_visibility="collapsed" | |
| ) | |
| timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") | |
| st.download_button( | |
| label="💾 Download as TXT", | |
| data=output_text, | |
| file_name=f"extracted_text_{timestamp}.txt", | |
| mime="text/plain" | |
| ) | |
| # Footer | |
| st.markdown("---") | |
| col_footer1, col_footer2, col_footer3 = st.columns(3) | |
| with col_footer1: | |
| m = st.session_state.selected_model | |
| st.caption(f"⚡ Powered by {MODEL_UI[m]['name']}") | |
| with col_footer2: | |
| # device depends on selected model (they’ll both use same device typically) | |
| _, _, device = get_loaded_model(st.session_state.selected_model) | |
| st.caption(f"🖥️ Device: {device.upper()}") | |
| with col_footer3: | |
| st.caption("🌟 Universal Document Scanner") |