# Hugging Face Space: Kraken OCR demo for Samaritan manuscripts.
| import os | |
| import sys | |
| import gradio as gr | |
| from PIL import Image | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| from kraken.lib import vgsl | |
| from kraken.lib import models | |
| from kraken import serialization | |
| import logging | |
| import numpy as np | |
| import cv2 | |
| from kraken import blla, rpred | |
| from kraken.containers import BaselineLine | |
| import json | |
| from jinja2 import Environment, FileSystemLoader | |
| import base64 | |
| import io | |
| from jinja2 import Template | |
| import re | |
| import time | |
| # Configure logging | |
| logging.basicConfig(level=logging.WARNING) | |
| logging.getLogger('kraken').setLevel(logging.WARNING) | |
| logging.getLogger('kraken.serialization').setLevel(logging.WARNING) | |
| logging.getLogger('kraken.blla').setLevel(logging.WARNING) | |
| logging.getLogger('kraken.lib.models').setLevel(logging.WARNING) | |
| logger = logging.getLogger(__name__) | |
| # Constants - Use relative paths for Hugging Face | |
| MODELS_DIR = Path("models") | |
| SEG_MODELS_DIR = MODELS_DIR / "seg" | |
| REC_MODELS_DIR = MODELS_DIR / "rec" | |
| # Create Jinja environment | |
| TEMPLATE_DIR = Path("templates") | |
| TEMPLATE_DIR.mkdir(exist_ok=True) | |
| _ENV = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR))) | |
| # Create template files | |
| def create_templates(): | |
| """Create Jinja templates for visualization.""" | |
| # Image template with SVG for visualization | |
| image_template = """ | |
| <div class="visualization-container"> | |
| <div class="image-container"> | |
| <svg width="{{ width }}" height="{{ height }}" viewBox="0 0 {{ width }} {{ height }}"> | |
| <image href="data:image/png;base64,{{ image_base64 }}" width="{{ width }}" height="{{ height }}"/> | |
| {% for line in lines %} | |
| <a class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});"> | |
| <path class="line-boundary" d="M {{ line.boundary|join(' L ') }} Z" fill="rgba(0, 128, 255, 0.2)" stroke="none"/> | |
| <path class="line-baseline" d="M {{ line.baseline|join(' L ') }}" stroke="red" stroke-width="1" fill="none"/> | |
| </a> | |
| {% endfor %} | |
| </svg> | |
| </div> | |
| <div class="transcription-container"> | |
| {% for line in lines %} | |
| <span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});"> | |
| <span class="line-number">{{ loop.index }}:</span> | |
| <span class="line-text">{{ line.text }}</span> | |
| {% if line.confidence %} | |
| <span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span> | |
| {% endif %} | |
| </span> | |
| <br> | |
| {% endfor %} | |
| </div> | |
| </div> | |
| <style> | |
| .visualization-container { | |
| display: flex; | |
| gap: 20px; | |
| max-height: 1000px; | |
| } | |
| .image-container { | |
| flex: 2; | |
| overflow: auto; | |
| border: 1px solid #ddd; | |
| border-radius: 4px; | |
| } | |
| .image-container svg { | |
| display: block; | |
| width: 100%; | |
| height: auto; | |
| max-width: 100%; | |
| } | |
| .transcription-container { | |
| flex: 1; | |
| overflow-y: auto; | |
| padding: 10px; | |
| border: 1px solid #ddd; | |
| border-radius: 4px; | |
| } | |
| /* Synchronize scrolling between containers */ | |
| .image-container, .transcription-container { | |
| scroll-behavior: smooth; | |
| } | |
| .image-container::-webkit-scrollbar, .transcription-container::-webkit-scrollbar { | |
| width: 8px; | |
| } | |
| .image-container::-webkit-scrollbar-track, .transcription-container::-webkit-scrollbar-track { | |
| background: #f1f1f1; | |
| } | |
| .image-container::-webkit-scrollbar-thumb, .transcription-container::-webkit-scrollbar-thumb { | |
| background: #888; | |
| border-radius: 4px; | |
| } | |
| .image-container::-webkit-scrollbar-thumb:hover, .transcription-container::-webkit-scrollbar-thumb:hover { | |
| background: #555; | |
| } | |
| .textline { | |
| padding: 5px; | |
| cursor: pointer; | |
| display: inline-block; | |
| unicode-bidi: bidi-override; | |
| } | |
| .textline:hover, | |
| .textline.highlighted { | |
| background-color: rgba(0, 128, 255, 0.1); | |
| } | |
| .textline:hover .line-boundary, | |
| .textline.highlighted .line-boundary { | |
| fill: rgba(0, 255, 255, 0.3); | |
| } | |
| .textline:hover .line-baseline, | |
| .textline.highlighted .line-baseline { | |
| stroke: yellow; | |
| } | |
| .line-number { | |
| color: #666; | |
| margin-right: 5px; | |
| } | |
| .line-confidence { | |
| color: #888; | |
| font-size: 0.9em; | |
| margin-left: 5px; | |
| } | |
| /* RTL text support */ | |
| .textline[dir="rtl"] { | |
| text-align: right; | |
| } | |
| .textline[dir="ltr"] { | |
| text-align: left; | |
| } | |
| </style> | |
| <script> | |
| // Synchronize scrolling between containers | |
| const imageContainer = document.querySelector('.image-container'); | |
| const textContainer = document.querySelector('.transcription-container'); | |
| function syncScroll(source, target) { | |
| const ratio = target.scrollHeight / source.scrollHeight; | |
| target.scrollTop = source.scrollTop * ratio; | |
| } | |
| imageContainer.addEventListener('scroll', () => syncScroll(imageContainer, textContainer)); | |
| textContainer.addEventListener('scroll', () => syncScroll(textContainer, imageContainer)); | |
| // Function to detect text direction | |
| function detectTextDirection(text) { | |
| const rtlChars = /[\u0591-\u07FF\u200F\u202B\u202E\uFB1D-\uFDFD\uFE70-\uFEFC]/; | |
| return rtlChars.test(text) ? 'rtl' : 'ltr'; | |
| } | |
| // Add direction attribute to text lines | |
| function updateTextDirections() { | |
| document.querySelectorAll('.textline').forEach(line => { | |
| const text = line.textContent; | |
| line.setAttribute('dir', detectTextDirection(text)); | |
| }); | |
| } | |
| // Update text directions when visualization changes | |
| const observer = new MutationObserver(updateTextDirections); | |
| observer.observe(document.body, { childList: true, subtree: true }); | |
| </script> | |
| """ | |
| # Transcription template | |
| transcription_template = """ | |
| <div class="transcription-container" style="max-height: 600px; overflow-y: auto;"> | |
| {% for line in lines %} | |
| <span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});"> | |
| <span class="line-number">{{ loop.index }}:</span> | |
| <span class="line-text">{{ line.text }}</span> | |
| {% if line.confidence %} | |
| <span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span> | |
| {% endif %} | |
| </span> | |
| <br> | |
| {% endfor %} | |
| </div> | |
| <style> | |
| .textline { | |
| padding: 5px; | |
| cursor: pointer; | |
| display: inline-block; | |
| } | |
| .textline:hover, | |
| .textline.highlighted { | |
| background-color: rgba(0, 128, 255, 0.1); | |
| } | |
| .line-number { | |
| color: #666; | |
| margin-right: 5px; | |
| } | |
| .line-confidence { | |
| color: #888; | |
| font-size: 0.9em; | |
| margin-left: 5px; | |
| } | |
| </style> | |
| """ | |
| # Write templates | |
| with open(TEMPLATE_DIR / "image.html", "w") as f: | |
| f.write(image_template) | |
| with open(TEMPLATE_DIR / "transcription.html", "w") as f: | |
| f.write(transcription_template) | |
| # Embedded template | |
| PAGEXML_TEMPLATE = '''{%+ macro render_line(line) +%} | |
| <TextLine id="{{ line.id }}" {% if line.tags and "type" in line.tags %}custom="structure {type:{{ line.tags["type"] }};}"{% endif %}> | |
| {% if line.boundary %} | |
| <Coords points="{% for point in line.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/> | |
| {% endif %} | |
| {% if line.baseline %} | |
| <Baseline points="{% for point in line.baseline %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/> | |
| {% endif %} | |
| {% if line.text is string %} | |
| <TextEquiv{% if line.confidences|length %} conf="{{ (line.confidences|sum / line.confidences|length)|round(4) }}"{% endif %}><Unicode>{{ line.text|e }}</Unicode></TextEquiv> | |
| {% else %} | |
| {% for segment in line.recognition %} | |
| <Word id="segment_{{ segment.index }}"> | |
| {% if segment.boundary %} | |
| <Coords points="{% for point in segment.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/> | |
| {% else %} | |
| <Coords points="{{ segment.bbox[0] }},{{ segment.bbox[1] }} {{ segment.bbox[0] }},{{ segment.bbox[3] }} {{ segment.bbox[2] }},{{ segment.bbox[3] }} {{ segment.bbox[2] }},{{ segment.bbox[1] }}"/> | |
| {% endif %} | |
| {% for char in segment.recognition %} | |
| <Glyph id="char_{{ char.index }}"> | |
| <Coords points="{% for point in char.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/> | |
| <TextEquiv conf="{{ char.confidence|round(4) }}"><Unicode>{{ char.text|e }}</Unicode></TextEquiv> | |
| </Glyph> | |
| {% endfor %} | |
| <TextEquiv conf="{{ (segment.confidences|sum / segment.confidences|length)|round(4) }}"><Unicode>{{ segment.text|e }}</Unicode></TextEquiv> | |
| </Word> | |
| {% endfor %} | |
| {%+ if line.confidences|length %}<TextEquiv conf="{{ (line.confidences|sum / line.confidences|length)|round(4) }}"><Unicode>{% for segment in line.recognition %}{{ segment.text|e }}{% endfor %}</Unicode></TextEquiv>{% endif +%} | |
| {% endif %} | |
| </TextLine> | |
| {%+ endmacro %} | |
| <?xml version="1.0" encoding="UTF-8"?> | |
| <PcGts xmlns="http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15 http://schema.primaresearch.org/PAGE/gts/pagecontent/2019-07-15/pagecontent.xsd"> | |
| <Metadata> | |
| <Creator>kraken {{ metadata.version }}</Creator> | |
| <Created>{{ page.date }}</Created> | |
| <LastChange>{{ page.date }}</LastChange> | |
| </Metadata> | |
| <Page imageFilename="{{ page.name }}" imageWidth="{{ page.size[0] }}" imageHeight="{{ page.size[1] }}" {% if page.base_dir %}readingDirection="{{ page.base_dir }}"{% endif %}> | |
| {% for entity in page.entities %} | |
| {% if entity.type == "region" %} | |
| {% if loop.previtem and loop.previtem.type == 'line' %} | |
| </TextRegion> | |
| {% endif %} | |
| <TextRegion id="{{ entity.id }}" {% if entity.tags and "type" in entity.tags %}custom="structure {type:{{ entity.tags["type"] }};}"{% endif %}> | |
| {% if entity.boundary %}<Coords points="{% for point in entity.boundary %}{{ point|join(',') }}{% if not loop.last %} {% endif %}{% endfor %}"/>{% endif %} | |
| {%- for line in entity.lines -%} | |
| {{ render_line(line) }} | |
| {%- endfor %} | |
| </TextRegion> | |
| {% else %} | |
| {% if not loop.previtem or loop.previtem.type != 'line' %} | |
| <TextRegion id="textblock_{{ loop.index }}"> | |
| <Coords points="0,0 0,{{ page.size[1] }} {{ page.size[0] }},{{ page.size[1] }} {{ page.size[0] }},0"/> | |
| {% endif %} | |
| {{ render_line(entity) }} | |
| {% if loop.last %} | |
| </TextRegion> | |
| {% endif %} | |
| {% endif %} | |
| {% endfor %} | |
| </Page> | |
| </PcGts>''' | |
def seg_rec_image(image_path, seg_model, rec_model, output_dir=None):
    """Segment and recognize one image, writing a PAGE-XML file beside it.

    Args:
        image_path: Path to the input image file.
        seg_model: Loaded kraken segmentation model (vgsl.TorchVGSLModel).
        rec_model: Loaded kraken recognition model.
        output_dir: Optional directory for the output XML; defaults to the
            image's own directory (same base name, ``.xml`` extension).

    Errors are caught, reported to stdout, and swallowed so batch callers
    can continue with the next image.
    """
    try:
        im = Image.open(image_path)
        baseline_seg = blla.segment(im, model=seg_model)
        # Run recognition and collect full BaselineOCRRecord objects.
        pred_it = rpred.rpred(network=rec_model, im=im, bounds=baseline_seg, pad=16)
        records = list(pred_it)
        # Attach recognition results to the segmentation lines so the
        # serializer can emit text and confidence values.
        for line, rec_line in zip(baseline_seg.lines, records):
            # Lazy %-style args avoid f-string formatting when DEBUG is off.
            logger.debug('Recognition result - Prediction: %s', rec_line.prediction)
            logger.debug('Recognition result - Confidences: %s', rec_line.confidences)
            line.prediction = rec_line.prediction
            line.text = rec_line.prediction  # field read by the serializer
            line.confidences = rec_line.confidences  # per-character confidences
        # Serialize with recognition results using the embedded custom
        # PAGE-XML template (written to disk by ensure_template_exists()).
        # (The previous unused segmentation-only serialization pass was
        # removed: it duplicated work and its result was never read.)
        pagexml = serialization.serialize(baseline_seg,
                                          image_size=im.size,
                                          template='custom_pagexml',
                                          template_source='custom',
                                          sub_line_segmentation=False)
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, base_name + '.xml')
        else:
            output_path = os.path.splitext(image_path)[0] + '.xml'
        # Recognized text is Unicode; write UTF-8 explicitly rather than
        # relying on the locale's default encoding.
        with open(output_path, 'w', encoding='utf-8') as fp:
            fp.write(pagexml)
        print(f"β Segmented/recognized: {os.path.basename(image_path)} β {os.path.basename(output_path)}")
    except Exception as e:
        print(f"β Failed to process {image_path}: {e}")
        import traceback
        traceback.print_exc()
def ensure_template_exists():
    """Write the custom PAGE-XML template next to this script if missing."""
    template_path = os.path.join(os.path.dirname(__file__), 'custom_pagexml')
    if os.path.exists(template_path):
        return
    with open(template_path, 'w', encoding='utf-8') as handle:
        handle.write(PAGEXML_TEMPLATE)
def get_model_files(directory):
    """Return the names of kraken ``.mlmodel`` files in *directory*.

    Args:
        directory: Directory (str or Path) to scan; not recursed into.

    Returns:
        List of bare file names ending in ``.mlmodel``. Returns an empty
        list when the directory does not exist, so callers can probe
        optional model folders without raising FileNotFoundError.
    """
    if not os.path.isdir(directory):
        return []
    return [f for f in os.listdir(directory) if f.endswith('.mlmodel')]
def load_models():
    """Load every available segmentation and recognition model.

    Scans SEG_MODELS_DIR and REC_MODELS_DIR for ``.mlmodel`` files and
    returns a ``(seg_models, rec_models)`` pair of dicts mapping file name
    to loaded model. Files that fail to load are skipped with a message.
    """
    seg_models = {}
    for name in get_model_files(SEG_MODELS_DIR):
        path = os.path.join(SEG_MODELS_DIR, name)
        try:
            seg_models[name] = vgsl.TorchVGSLModel.load_model(path)
        except Exception as e:
            print(f"Error loading segmentation model {name}: {str(e)}")
    rec_models = {}
    for name in get_model_files(REC_MODELS_DIR):
        path = os.path.join(REC_MODELS_DIR, name)
        try:
            rec_models[name] = models.load_any(path)
        except Exception as e:
            print(f"Error loading recognition model {name}: {str(e)}")
    return seg_models, rec_models
def process_image(image, seg_model, rec_model):
    """Process image and return segmentation and recognition results.

    Runs baseline segmentation, then recognition over the detected lines,
    and copies each prediction onto its segmentation line so downstream
    serialization and rendering can read ``text`` and ``confidences``.
    """
    segmentation = blla.segment(image, model=seg_model)
    predictions = list(rpred.rpred(network=rec_model, im=image,
                                   bounds=segmentation, pad=16,
                                   bidi_reordering=True))
    for seg_line, ocr_record in zip(segmentation.lines, predictions):
        seg_line.prediction = ocr_record.prediction
        seg_line.text = ocr_record.prediction
        seg_line.confidences = ocr_record.confidences
    return segmentation
def render_image(image, baseline_seg):
    """Render the page image with an interactive SVG overlay as HTML.

    Args:
        image: PIL image of the page.
        baseline_seg: kraken segmentation whose ``lines`` carry
            ``boundary``/``baseline`` point lists and (after recognition)
            a ``text`` attribute.

    Returns:
        Self-contained HTML (image + overlay + transcription panel with
        synchronized hover highlighting and scrolling).
    """
    # Embed the page image in the document as a base64 data URI.
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode()
    width, height = image.size
    # Compile the RTL-script detection pattern ONCE, outside the per-line
    # loop (it was previously recompiled for every line).
    rtl_chars = re.compile(r'[\u0591-\u07FF\u200F\u202B\u202E\uFB1D-\uFDFD\uFE70-\uFEFC\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]')
    # Prepare per-line data for the template.
    lines = []
    for line in baseline_seg.lines:
        # "x,y" point strings joined into SVG path segments by the template.
        boundary_points = [f"{point[0]},{point[1]}" for point in line.boundary]
        baseline_points = [f"{point[0]},{point[1]}" for point in line.baseline]
        text = line.text if hasattr(line, 'text') else ''
        # Check if text contains RTL characters (Hebrew, Arabic, etc.)
        is_rtl = bool(rtl_chars.search(text))
        lines.append({
            'boundary': boundary_points,
            'baseline': baseline_points,
            'text': text,
            # NOTE(review): the recognition pipeline sets `confidences`
            # (a list), not `confidence`, so this is usually None and the
            # template's confidence span rarely renders — confirm upstream.
            'confidence': line.confidence if hasattr(line, 'confidence') else None,
            'is_rtl': is_rtl
        })
    # Inline Jinja template: SVG overlay + transcription panel + CSS/JS.
    template = """
<div class="visualization-container">
<div class="image-container">
<svg width="{{ width }}" height="{{ height }}" viewBox="0 0 {{ width }} {{ height }}">
<image href="data:image/png;base64,{{ image_base64 }}" width="{{ width }}" height="{{ height }}"/>
{% for line in lines %}
<a class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
<path class="line-boundary" d="M {{ line.boundary|join(' L ') }} Z" fill="rgba(0, 128, 255, 0.2)" stroke="none"/>
<path class="line-baseline" d="M {{ line.baseline|join(' L ') }}" stroke="red" stroke-width="1" fill="none"/>
</a>
{% endfor %}
</svg>
</div>
<div class="transcription-container">
{% for line in lines %}
<div class="textline-container {% if line.is_rtl %}rtl{% else %}ltr{% endif %}">
<span class="textline line{{loop.index}}" onmouseover="document.querySelectorAll('.line{{loop.index}}').forEach(element => {element.classList.add('highlighted')});" onmouseout="document.querySelectorAll('*').forEach(element => {element.classList.remove('highlighted')});">
<span class="line-number">{{ loop.index }}:</span>
<span class="line-text">{{ line.text }}</span>
{% if line.confidence %}
<span class="line-confidence">({{ "%.2f"|format(line.confidence) }})</span>
{% endif %}
</span>
</div>
{% endfor %}
</div>
</div>
<style>
.visualization-container {
display: flex;
gap: 20px;
max-height: 1000px;
}
.image-container {
flex: 2;
overflow: auto;
border: 1px solid #ddd;
border-radius: 4px;
}
.image-container svg {
display: block;
width: 100%;
height: auto;
max-width: 100%;
}
.transcription-container {
flex: 1;
overflow-y: auto;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
}
/* Synchronize scrolling between containers */
.image-container, .transcription-container {
scroll-behavior: smooth;
}
.image-container::-webkit-scrollbar, .transcription-container::-webkit-scrollbar {
width: 8px;
}
.image-container::-webkit-scrollbar-track, .transcription-container::-webkit-scrollbar-track {
background: #f1f1f1;
}
.image-container::-webkit-scrollbar-thumb, .transcription-container::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.image-container::-webkit-scrollbar-thumb:hover, .transcription-container::-webkit-scrollbar-thumb:hover {
background: #555;
}
.textline-container {
padding: 5px;
margin: 2px 0;
border-radius: 4px;
}
.textline-container.rtl {
direction: rtl;
text-align: right;
}
.textline-container.ltr {
direction: ltr;
text-align: left;
}
.textline {
cursor: pointer;
display: inline-block;
width: 100%;
}
.textline:hover,
.textline.highlighted {
background-color: rgba(0, 128, 255, 0.1);
}
.textline:hover .line-boundary,
.textline.highlighted .line-boundary {
fill: rgba(0, 255, 255, 0.3);
}
.textline:hover .line-baseline,
.textline.highlighted .line-baseline {
stroke: yellow;
}
.line-number {
color: #666;
margin-right: 5px;
}
.line-text {
unicode-bidi: bidi-override;
}
.line-confidence {
color: #888;
font-size: 0.9em;
margin-left: 5px;
}
</style>
<script>
// Synchronize scrolling between containers
const imageContainer = document.querySelector('.image-container');
const textContainer = document.querySelector('.transcription-container');
function syncScroll(source, target) {
const ratio = target.scrollHeight / source.scrollHeight;
target.scrollTop = source.scrollTop * ratio;
}
imageContainer.addEventListener('scroll', () => syncScroll(imageContainer, textContainer));
textContainer.addEventListener('scroll', () => syncScroll(textContainer, imageContainer));
</script>
"""
    return Template(template).render(
        width=width,
        height=height,
        image_base64=image_base64,
        lines=lines
    )
def get_example_images():
    """Return paths of example images bundled next to this script.

    Looks for ``*.jpg`` and ``*.png`` files in an ``examples`` directory;
    returns an empty list when that directory is absent.
    """
    examples_dir = Path(__file__).parent / "examples"
    if not examples_dir.exists():
        return []
    matches = list(examples_dir.glob("*.jpg"))
    matches.extend(examples_dir.glob("*.png"))
    return [str(match) for match in matches]
def process_and_visualize(image, seg_model_name, rec_model_name, progress=gr.Progress()):
    """Run segmentation + recognition on *image*, streaming UI updates.

    Generator consumed by Gradio: each yield is a 6-tuple of
    (status message, visualization HTML, PageXML text, vis-tab update,
    xml-tab update, upload-tab update). The ``progress=gr.Progress()``
    default is the Gradio-prescribed way to receive a progress tracker.
    """
    # Tab-state tuples: lock the tabs while working, unlock when finished.
    lock = (gr.update(interactive=False),) * 3
    unlock = (gr.update(interactive=True),) * 3
    try:
        if image is None:
            yield "β Please upload an image first.", None, None, None, None, None
            return
        yield ("π Starting processing...", None, None) + lock
        progress(0.1, desc="Loading models...")
        yield ("π¦ Loading models...", None, None) + lock
        seg_models, rec_models = load_models()
        seg_model = seg_models[seg_model_name]
        rec_model = rec_models[rec_model_name]
        progress(0.3, desc="Running Segmentation...")
        yield ("βοΈ Running segmentation...", None, None) + lock
        baseline_seg = blla.segment(image, model=seg_model)
        progress(0.6, desc="Running Recognition...")
        yield ("π Running text recognition...", None, None) + lock
        records = list(rpred.rpred(network=rec_model, im=image, bounds=baseline_seg, pad=16))
        # Attach recognition results so render_image can show line text.
        for line, rec_line in zip(baseline_seg.lines, records):
            line.prediction = rec_line.prediction
            line.text = rec_line.prediction
            line.confidences = rec_line.confidences
        progress(0.85, desc="Generating PageXML...")
        yield ("π Generating PageXML output...", None, None) + lock
        # TODO(review): seg_rec_image re-runs segmentation and recognition
        # on the saved copy, duplicating the work done above; consider
        # serializing `baseline_seg` directly instead.
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, "temp.png")
            image.save(input_path)
            seg_rec_image(input_path, seg_model, rec_model, temp_dir)
            output_xml = os.path.join(temp_dir, "temp.xml")
            if os.path.exists(output_xml):
                # Read via a context manager (the previous one-liner leaked
                # the file handle) and decode the XML explicitly as UTF-8.
                with open(output_xml, 'r', encoding='utf-8') as fp:
                    xml_content = fp.read()
            else:
                xml_content = "β οΈ Error generating XML output."
        progress(1.0, desc="Rendering results...")
        yield ("β Done! Switch to visualization!", render_image(image, baseline_seg), xml_content) + unlock
    except Exception as e:
        # Log the full traceback instead of silently discarding it.
        logger.exception("Processing failed")
        yield (f"β Error: {str(e)}", None, None) + unlock
| def main(): | |
| # Create necessary directories and templates | |
| SEG_MODELS_DIR.mkdir(parents=True, exist_ok=True) | |
| REC_MODELS_DIR.mkdir(parents=True, exist_ok=True) | |
| ensure_template_exists() | |
| create_templates() | |
| # Load available models | |
| seg_models, rec_models = load_models() | |
| if not seg_models: | |
| print("No segmentation models found in app/models/seg. Please add .mlmodel files.") | |
| return | |
| if not rec_models: | |
| print("No recognition models found in app/models/rec. Please add .mlmodel files.") | |
| return | |
| # Create Gradio interface | |
| with gr.Blocks(title="Kraken OCR on Samaritan manuscripts") as demo: | |
| gr.Markdown("# Kraken OCR on Samaritan manuscripts") | |
| gr.Markdown("Upload an image and select models to process it.") | |
| with gr.Tabs() as tabs: | |
| with gr.Tab("Upload Image") as upload_tab: | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| image_input = gr.Image(type="pil", label="Input Image", height=400) | |
| with gr.Row(): | |
| seg_model = gr.Dropdown(choices=list(seg_models.keys()), label="Segmentation Model", value=list(seg_models.keys())[0]) | |
| rec_model = gr.Dropdown(choices=list(rec_models.keys()), label="Recognition Model", value=list(rec_models.keys())[0]) | |
| process_btn = gr.Button("Process Image") | |
| status_box = gr.Markdown("", visible=True) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Example Images") | |
| examples = gr.Gallery( | |
| get_example_images(), | |
| show_label=False, | |
| interactive=True, | |
| allow_preview=False, | |
| object_fit="cover", | |
| columns=2, | |
| height=400, | |
| elem_classes="example-gallery" | |
| ) | |
| with gr.Tab("Visualization", interactive=False) as vis_tab: | |
| visualization_output = gr.HTML(label="Visualization") | |
| with gr.Tab("PageXML", interactive=False) as xml_tab: | |
| xml_output = gr.Textbox(label="PageXML", lines=20, max_lines=50, show_copy_button=True) | |
| # Add custom CSS for the gallery | |
| gr.HTML(""" | |
| <style> | |
| .example-gallery { | |
| overflow-y: auto !important; | |
| max-height: 400px !important; | |
| } | |
| .example-gallery img { | |
| width: 100% !important; | |
| height: 150px !important; | |
| object-fit: cover !important; | |
| border-radius: 4px !important; | |
| cursor: pointer !important; | |
| transition: transform 0.2s !important; | |
| } | |
| .example-gallery img:hover { | |
| transform: scale(1.05) !important; | |
| } | |
| </style> | |
| """) | |
| process_btn.click( | |
| process_and_visualize, | |
| inputs=[image_input, seg_model, rec_model], | |
| outputs=[status_box, visualization_output, xml_output, vis_tab, xml_tab, upload_tab], | |
| show_progress=True | |
| ).then( | |
| lambda: gr.Tabs(selected="Visualization"), | |
| outputs=tabs | |
| ) | |
| # Example image selection handler | |
| def select_example(evt: gr.SelectData): | |
| if not examples.value: | |
| return None | |
| selected = examples.value[evt.index] | |
| return selected["image"]["path"] | |
| examples.select( | |
| select_example, | |
| None, | |
| image_input | |
| ) | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |
| if __name__ == "__main__": | |
| main() |