Spaces:
Running
Running
Luis J Camargo
feat: Add a JSON output tab to display raw inference results for both Document and Visual Language tasks.
3141d0f | import os | |
| import io | |
| import json | |
| import base64 | |
| import re | |
| import logging | |
| import sys | |
| import yaml | |
| import traceback | |
| import subprocess | |
| from typing import Dict, List, Tuple, Any, Optional | |
| import time | |
| import gradio as gr | |
| from PIL import Image | |
| import requests | |
| from urllib.parse import urlparse | |
| from huggingface_hub import snapshot_download | |
# --- Configuration ---
# Hugging Face repository holding the fine-tuned recognition weights.
REPO_ID = "tachiwin/Tachiwin-OCR-1.5"
# Pipeline YAML shipped with the app; used as-is when present.
CONFIG_FILE = "default.yaml"
# File name the `paddlex --get_pipeline_config` CLI produces when CONFIG_FILE is absent.
GENERATED_CONFIG = "PaddleOCR-VL.yaml"
# Root folder under which every inference run gets its own sub-directory.
OUTPUT_DIR = "output"

LOGGING_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=LOGGING_FORMAT,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("TachiwinDocOCR")

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Math delimiters handed to gr.Markdown(latex_delimiters=...): (left, right, display mode).
LATEX_DELIMS = [
    {"left": left, "right": right, "display": display}
    for left, right, display in (
        ("$$", "$$", True),
        ("$", "$", False),
        ("\\(", "\\)", False),
        ("\\[", "\\]", True),
    )
]

# --- Paddle imports and Diagnostic ---
# Probe the optional Paddle stack. Failures are logged and recorded in
# PADDLE_AVAILABLE instead of aborting the import of this module.
PADDLE_AVAILABLE = False
try:
    import paddle
    import paddlex
    from paddlex import create_pipeline
except ImportError as e:
    logger.error(f"Import Error: {e}")
except Exception as e:
    logger.error(f"Unexpected error during import: {e}")
else:
    PADDLE_AVAILABLE = True
    logger.info(f"Paddle libraries loaded. PaddleX version: {getattr(paddlex, '__version__', 'Unknown')}")
# --- Model Initialization ---
# Module-level pipeline handle; stays None until setup_pipeline() succeeds.
pipeline = None


def _resolve_pipeline_config():
    """Return the path of the pipeline YAML to build from, or None if none is available.

    Prefers CONFIG_FILE when it exists. Otherwise runs the ``paddlex`` CLI to write
    the stock ``PaddleOCR-VL-1.5`` config into the working directory and expects it
    to appear as GENERATED_CONFIG. Failures are logged, not raised.
    """
    if os.path.exists(CONFIG_FILE):
        logger.info(f"β Found existing configuration: {CONFIG_FILE}")
        return CONFIG_FILE

    logger.info(f"β οΈ {CONFIG_FILE} not found. Generating default configuration via paddlex CLI...")
    try:
        subprocess.run(
            ["paddlex", "--get_pipeline_config", "PaddleOCR-VL-1.5", "--save_path", "./"],
            capture_output=True, text=True, check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"β Failed to run paddlex CLI: {e}")
        # The exception text only carries the exit status; the real diagnostics were
        # captured (capture_output=True) and would otherwise be lost.
        cli_output = (e.stderr or e.stdout or "").strip()
        if cli_output:
            logger.error(f"paddlex CLI output:\n{cli_output}")
        return None
    except Exception as e:  # e.g. the paddlex executable is not on PATH
        logger.error(f"β Failed to run paddlex CLI: {e}")
        return None

    if not os.path.exists(GENERATED_CONFIG):
        logger.error(f"β CLI generation failed to produce {GENERATED_CONFIG}")
        logger.info(f"Directory contents: {os.listdir('.')}")
        return None
    logger.info(f"β Generated {GENERATED_CONFIG}")
    return GENERATED_CONFIG


def _point_vl_recognition_at(node, model_path):
    """Set ``model_dir`` of every ``VLRecognition`` mapping nested in *node* to *model_path*.

    Recurses through nested dicts (a ``VLRecognition`` dict itself is not descended
    into) and mutates them in place. Returns True if any value was changed.
    """
    changed = False
    for key, value in node.items():
        if not isinstance(value, dict):
            continue
        if key == 'VLRecognition':
            if value.get('model_dir') != model_path:
                logger.info(f"π§ Updating VLRecognition model_dir to local path: {model_path}")
                value['model_dir'] = model_path
                changed = True
        elif _point_vl_recognition_at(value, model_path):
            changed = True
    return changed


def setup_pipeline():
    """Download the Tachiwin weights and build the module-level ``pipeline``.

    1. ``snapshot_download`` the ``REPO_ID`` repository from the Hugging Face Hub.
    2. Resolve the pipeline YAML (``CONFIG_FILE`` or a freshly generated
       ``GENERATED_CONFIG``).
    3. Point its ``VLRecognition`` ``model_dir`` at the local snapshot, rewriting
       the YAML file in place when that changes anything.
    4. ``create_pipeline`` from that YAML and store the result in ``pipeline``.

    Does nothing when Paddle could not be imported. Any failure is logged and
    leaves ``pipeline`` unchanged instead of propagating.
    """
    global pipeline
    if not PADDLE_AVAILABLE:
        logger.error("Skipping pipeline setup because Paddle is not available.")
        return
    try:
        logger.info("π Starting Tachiwin Doc OCR Pipeline Setup...")
        # 1. Download Model from Hugging Face Hub
        logger.info(f"π¦ Downloading custom model from HF: {REPO_ID}...")
        local_model_path = snapshot_download(repo_id=REPO_ID)
        logger.info(f"β Model downloaded to: {local_model_path}")

        # 2. Pick (or generate) the pipeline YAML.
        target_config = _resolve_pipeline_config()
        if target_config is None:
            return

        # 3. Load it and make the VL recognition stage use the downloaded weights.
        logger.info(f"π Loading YAML from {target_config}...")
        with open(target_config, 'r', encoding='utf-8') as f:
            config_data = yaml.safe_load(f)
        if not isinstance(config_data, dict):
            # An empty or non-mapping file would otherwise fail below with an opaque AttributeError.
            logger.error(f"CRITICAL: {target_config} does not contain a YAML mapping; cannot build the pipeline.")
            return
        if _point_vl_recognition_at(config_data, local_model_path):
            # Persist the edit: create_pipeline() below receives the file path, not the dict.
            with open(target_config, 'w', encoding='utf-8') as f:
                yaml.dump(config_data, f, default_flow_style=False)
            logger.info(f"πΎ Updated configuration saved to {target_config}")

        # Log the config being used
        logger.info(f"--- [START] {target_config} CONTENT ---")
        print(yaml.dump(config_data, default_flow_style=False))
        logger.info(f"--- [END] {target_config} CONTENT ---")

        # 4. Initialize pipeline using the recommended PaddleX way
        logger.info(f"βοΈ Initializing pipeline with create_pipeline(pipeline={target_config})")
        pipeline = create_pipeline(pipeline=target_config)
        logger.info("β¨ Pipeline initialized successfully!")
    except Exception:
        # Keep the module importable; logger.exception records the traceback as well.
        logger.exception("π₯ CRITICAL: Pipeline Setup Failed")


# Initial setup
if PADDLE_AVAILABLE:
    setup_pipeline()
# --- Helper Functions ---
def image_to_base64_data_url(filepath: str) -> str:
    """Inline the image at *filepath* as a ``data:`` URL usable in an HTML ``<img src>``.

    The MIME type comes from the case-insensitive file extension. Unknown
    extensions are labelled ``image/jpeg``. On any error the problem is logged
    and an empty string is returned instead of raising.
    """
    suffix_to_mime = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    try:
        suffix = os.path.splitext(filepath)[1].lower()
        with open(filepath, "rb") as handle:
            raw_bytes = handle.read()
        payload = base64.b64encode(raw_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to Base64: {e}")
        return ""
    mime = suffix_to_mime.get(suffix, "image/jpeg")
    return f"data:{mime};base64,{payload}"
| def _escape_inequalities_in_math(md: str) -> str: | |
| if not md: | |
| return "" | |
| # Safety: Only process if math delimiters are likely present | |
| if "$" not in md and "\\[" not in md and "\\(" not in md: | |
| return md | |
| _MATH_PATTERNS = [ | |
| re.compile(r"\$$([\s\S]+?)\$$"), | |
| re.compile(r"\$([^\$]+?)\$"), | |
| re.compile(r"\\\[([\s\S]+?)\\\]"), | |
| re.compile(r"\\\(([\s\S]+?)\\\)"), | |
| ] | |
| def fix(s: str) -> str: | |
| s = s.replace("<=", r" \le ").replace(">=", r" \ge ") | |
| s = s.replace("β€", r" \le ").replace("β₯", r" \ge ") | |
| s = s.replace("<", r" \lt ").replace(">", r" \gt ") | |
| return s | |
| for pat in _MATH_PATTERNS: | |
| md = pat.sub(lambda m: m.group(0).replace(m.group(1), fix(m.group(1))), md) | |
| return md | |
# Removed update_preview_visibility as gr.Image handles previews natively.
# --- Inference Logic ---
def run_inference(img_path, task_type="ocr", progress=gr.Progress()):
    """Run the global PaddleX pipeline on one image and stream its outputs.

    Generator for Gradio event handlers. Every ``yield`` is a 4-tuple
    ``(md_preview, md_raw, vis_html, json_text)``:

    * ``md_preview`` - accumulated Markdown with operators inside math spans
      escaped via ``_escape_inequalities_in_math``;
    * ``md_raw`` - the same Markdown, unmodified;
    * ``vis_html`` - HTML cards inlining saved PNG/JPEG files whose names
      contain "res" or "vis";
    * ``json_text`` - the saved JSON result files, separated by blank lines.

    One tuple is yielded after each result segment. Status and error messages
    are yielded in the first slot with the others empty; if no Markdown was
    produced, a final warning tuple keeps any visualisation/JSON already shown.
    ``task_type`` is only logged. Exceptions are logged and reported, not raised.

    Each call writes its files to a fresh sub-directory of OUTPUT_DIR.
    """
    import tempfile  # local import: only needed for a collision-free run directory

    if not PADDLE_AVAILABLE:
        yield "β Paddle backend not installed.", "", "", ""
        return
    if pipeline is None:
        yield "β Pipeline is not initialized. Check server logs for error details.", "", "", ""
        return
    if not img_path:
        yield "β οΈ No image provided.", "", "", ""
        return
    try:
        logger.info(f"--- Inference Start: {task_type} ---")
        progress(0, desc="Initializing...")
        output = pipeline.predict(input=img_path)
        md_content = ""
        json_content = ""
        vis_html = ""
        # A plain `run_<seconds>` name collides when two requests (or two quick clicks)
        # start within the same second, and the read-back below would then mix both
        # runs' files. mkdtemp always creates a fresh, empty directory.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        run_output_dir = tempfile.mkdtemp(prefix=f"run_{int(time.time())}_", dir=OUTPUT_DIR)
        # All segments save into the same directory and it is re-scanned after each one,
        # so remember what was already merged. Otherwise earlier JSON files and images
        # are appended again for every later segment. Keyed on content too, so a file
        # rewritten by a later segment is still picked up.
        merged = set()  # (file name, payload) pairs already added to the outputs
        logger.info("will iterate")
        for i, res in enumerate(output):
            logger.info(f"Processing segment {i+1}...")
            progress(None, desc=f"Processing segment {i+1}...")
            # Save results
            res.save_to_json(save_path=run_output_dir)
            res.save_to_markdown(save_path=run_output_dir)
            res.print()
            # Read back generated files (sorted so the merge order is deterministic)
            for fname in sorted(os.listdir(run_output_dir)):
                fpath = os.path.join(run_output_dir, fname)
                if fname.endswith((".md", ".json")):
                    with open(fpath, 'r', encoding='utf-8') as f:
                        payload = f.read()
                elif fname.endswith((".png", ".jpg", ".jpeg")) and ("res" in fname or "vis" in fname):
                    payload = image_to_base64_data_url(fpath)  # "" if encoding failed
                else:
                    continue
                if not payload or (fname, payload) in merged:
                    continue
                merged.add((fname, payload))
                logger.info(f"Processing file {fname}...")
                if fname.endswith(".md"):
                    logger.debug(f"MD content: {payload}")
                    if payload not in md_content:
                        md_content += payload + "\n\n"
                elif fname.endswith(".json"):
                    json_content += payload + "\n\n"
                else:
                    vis_html += (
                        '<div style="margin-bottom:20px; border: 2px solid #10b981; border-radius: 12px; overflow: hidden; background:white;">'
                        f'<img src="{payload}" alt="Vis {i+1}" style="width:100%;"></div>'
                    )
            logger.info(f"Finished processing segment {i+1}")
            md_preview = _escape_inequalities_in_math(md_content)
            yield md_preview, md_content, vis_html, json_content
        if not md_content:
            md_content = "β οΈ Finished but no content was recognized."
            # Only the Markdown is missing; do not blank visualisation/JSON already produced.
            yield md_content, md_content, vis_html, json_content
        logger.info("--- Inference Finished Successfully ---")
        progress(1.0, desc="β Complete")
    except Exception as e:
        logger.exception(f"β Inference Error: {e}")
        yield f"β Error: {str(e)}", "", "", ""
        return
# --- UI Components ---
# Stylesheet passed to demo.launch(css=...) in the __main__ guard. Its selectors
# target the markup and elem_classes used in the layout: .app-header (banner
# HTML), .notice (status row) and .output-box (Markdown result panes).
custom_css = """
body, .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
.app-header {
text-align: center;
padding: 2.5rem;
background: linear-gradient(135deg, #0284c7 0%, #10b981 100%);
color: white;
border-radius: 1.5rem;
margin-bottom: 2rem;
box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
}
.app-header h1 { color: white !important; font-weight: 800; font-size: 2.5rem; }
.notice { background: #f0fdf4; border: 1px solid #bbf7d0; color: #166534; padding: 1rem; border-radius: 1rem; margin-bottom: 2rem; font-weight: 500;}
.output-box { border: 1px solid #e2e8f0 !important; border-radius: 1rem !important; }
"""
# Build the Gradio interface. Component creation order inside each `with` block
# defines the on-screen layout. Every button streams results from run_inference()
# through a small generator wrapper that reorders/filters its 4-tuple
# (md_preview, md_raw, vis_html, json_text) for that tab's output components.
# NOTE(review): several emoji in the UI strings below look mis-encoded
# (e.g. "π", "β", "οΏ½"); confirm the file is saved and served as UTF-8.
with gr.Blocks() as demo:
    gr.HTML(
        """
<div class="app-header">
<h1>π Tachiwin Document Parsing OCR π¦‘</h1>
<p>Advancing linguistic rights with state-of-the-art document parsing</p>
</div>
"""
    )
    # Status banner, computed once while the UI is built. setup_pipeline() has
    # already run at import time (when Paddle is available), so this reflects its outcome.
    with gr.Row(elem_classes=["notice"]):
        status_text = "Initialized" if pipeline else "Initializing/Failed"
        gr.Markdown(f"**β‘ Status:** {status_text} | **Model:** `{REPO_ID}` | **Hardware:** CPU")
    with gr.Tabs():
        # Document Parsing Tab
        with gr.Tab("π Full Document Parsing"):
            with gr.Row():
                with gr.Column(scale=5):
                    file_doc = gr.Image(label="Upload Image", type="filepath")
                    # NOTE(review): this label appears to contain a replacement/mis-encoded character.
                    btn_parse = gr.Button("οΏ½ Start Parsing", variant="primary")
                    with gr.Row():
                        chart_switch = gr.Checkbox(label="Chart OCR", value=True)
                        unwarp_switch = gr.Checkbox(label="Unwarping", value=False)
                with gr.Column(scale=7):
                    with gr.Tabs():
                        with gr.Tab("π Markdown View"):
                            md_preview_doc = gr.Markdown(latex_delimiters=LATEX_DELIMS, elem_classes="output-box")
                        with gr.Tab("πΌοΈ Visual Results"):
                            vis_image_doc = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Results will appear here.</div>')
                        with gr.Tab("π Raw Source"):
                            md_raw_doc = gr.Code(language="markdown")
                        with gr.Tab("πΎ JSON Feed"):
                            json_doc = gr.Code(language="json")

            def parse_doc_wrapper(fp, ch, uw, progress=gr.Progress()):
                """Stream full-document results as (preview, visual HTML, raw Markdown, JSON).

                NOTE(review): `ch` (Chart OCR) and `uw` (Unwarping) are received but never
                forwarded to run_inference, so those two checkboxes currently have no effect.
                """
                if not fp:
                    yield "β οΈ Please upload an image.", "", "", ""
                    return
                # Initial yield to force loading indicators on all tabs
                yield "β Initializing...", gr.update(value="<p>β Processing...</p>"), "β Initializing...", "{}"
                # run_inference yields (preview, raw, vis, json) while the outputs list
                # below is ordered (preview, vis, raw, json), hence the swap.
                for res_preview, res_raw, res_vis, res_json in run_inference(fp, task_type="Document", progress=progress):
                    yield res_preview, res_vis, res_raw, res_json

            btn_parse.click(
                parse_doc_wrapper,
                [file_doc, chart_switch, unwarp_switch],
                [md_preview_doc, vis_image_doc, md_raw_doc, json_doc],
                show_progress="full"
            )
        # Element Recognition Tab
        with gr.Tab("π§© Specific Recognition"):
            with gr.Row():
                with gr.Column(scale=5):
                    file_vl = gr.Image(label="Upload Element", type="filepath")
                    with gr.Row():
                        btn_ocr = gr.Button("Text", variant="secondary")
                        btn_formula = gr.Button("Formula", variant="secondary")
                        btn_table = gr.Button("Table", variant="secondary")
                with gr.Column(scale=7):
                    with gr.Tabs():
                        with gr.Tab("π Result"):
                            md_preview_vl = gr.Markdown(latex_delimiters=LATEX_DELIMS, elem_classes="output-box")
                        with gr.Tab("π Source"):
                            md_raw_vl = gr.Code(language="markdown")
                        with gr.Tab("πΎ JSON Feed"):
                            json_vl = gr.Code(language="json")

            def run_vl_wrapper(fp, prompt, progress=gr.Progress()):
                """Stream element-recognition results as (preview, raw Markdown, JSON).

                The visualisation slot of run_inference's tuple is dropped for this tab.
                NOTE(review): `prompt` reaches run_inference only as task_type, which the
                current run_inference merely logs, so Text/Formula/Table run the same call.
                """
                if not fp:
                    yield "β οΈ Please upload an image.", "", ""
                    return
                yield "β Initializing...", "β Initializing...", "{}"
                for res_preview, res_raw, _, res_json in run_inference(fp, task_type=prompt, progress=progress):
                    yield res_preview, res_raw, res_json

            # One handler per button; gr.State(prompt) captures this iteration's label as
            # a constant input, so each button sends its own task type.
            for btn, prompt in [(btn_ocr, "Text"), (btn_formula, "Formula"), (btn_table, "Table")]:
                btn.click(
                    run_vl_wrapper,
                    [file_vl, gr.State(prompt)],
                    [md_preview_vl, md_raw_vl, json_vl],
                    show_progress="full"
                )
        # Spotting Tab
        with gr.Tab("π Feature Spotting"):
            with gr.Row():
                with gr.Column(scale=5):
                    file_spot = gr.Image(label="Target Image", type="filepath")
                    btn_run_spot = gr.Button("π― Run Spotting", variant="primary")
                with gr.Column(scale=7):
                    with gr.Tabs():
                        with gr.Tab("πΌοΈ Detection"):
                            vis_image_spot = gr.HTML('<div style="text-align:center; color:#94a3b8; padding: 50px;">Bboxes view.</div>')
                        with gr.Tab("πΎ JSON Feed"):
                            json_spot = gr.Code(label="JSON", language="json")

            def run_spotting_wrapper(fp, progress=gr.Progress()):
                """Stream only (visual HTML, JSON); the Markdown slots are discarded.

                With no image both panes are cleared silently.
                NOTE(review): run_inference reports errors in its first (Markdown) slot,
                which this wrapper drops, so failures show up here as empty panes.
                """
                if not fp:
                    yield "", ""
                    return
                for _, _, vis, js in run_inference(fp, task_type="Spotting", progress=progress):
                    yield vis, js

            btn_run_spot.click(
                run_spotting_wrapper,
                file_spot,
                [vis_image_spot, json_spot],
                show_progress="full"
            )
    gr.Markdown("--- \n *Tachiwin Project: Indigenous Languages of Mexico.*")
if __name__ == "__main__":
    # Enable request queuing, then serve the app with the Ocean theme and the custom stylesheet.
    queued_app = demo.queue()
    queued_app.launch(theme=gr.themes.Ocean(), css=custom_css)