Spaces:
Running
Running
| from pathlib import Path | |
| import time | |
| import json | |
| import gradio as gr | |
| import pandas as pd | |
| from gocr import ScreenAIClient | |
| from gocr.models import int_to_hex_rgb | |
| # ========================================================= | |
| # Lazy Singleton OCR Client | |
| # ========================================================= | |
_client = None


def get_client():
    """Return the process-wide OCR client, creating it on first use.

    On the first call the assets are fetched with
    ``huggingface_hub.snapshot_download`` from the ``shethjenil/Google-OCR``
    repository, a ``ScreenAIClient`` is built around the
    ``libchromescreenai.so`` file inside that snapshot, and ``init_ocr()`` is
    called on it. Later calls return the cached instance.

    Returns:
        The cached ``ScreenAIClient`` on which ``init_ocr()`` has completed.

    Raises:
        Any exception raised by the download, the constructor or
        ``init_ocr()``. Nothing is cached in that case, so the next call
        starts over instead of handing out a half-initialised client.
    """
    global _client
    if _client is None:
        # Imported here, so the import only runs on the first call.
        from huggingface_hub import snapshot_download

        assets_dir = Path(snapshot_download(repo_id="shethjenil/Google-OCR"))
        lib_path = assets_dir / "libchromescreenai.so"
        candidate = ScreenAIClient(library_path=lib_path, assets_dir=assets_dir)
        # Publish to the module global only after init_ocr() has succeeded;
        # assigning first would cache a broken client if init_ocr() raised.
        candidate.init_ocr()
        _client = candidate
    return _client
| # ========================================================= | |
| # HTML Templates (loaded once from files) | |
| # ========================================================= | |
_TEMPLATES_DIR = Path(__file__).resolve().parent / "templates"


def _read_template(filename: str) -> str:
    """Return the contents of *filename* from the templates directory.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the locale's default encoding.
    """
    return (_TEMPLATES_DIR / filename).read_text(encoding="utf-8")


# The *_TPL strings are filled in with ``str.format`` by the helpers in this
# file; HEADER_HTML is passed to the UI unchanged.
STATS_GRID_TPL = _read_template("stats_grid.html")
TREE_BLOCK_TPL = _read_template("tree_block.html")
TREE_PARA_TPL = _read_template("tree_paragraph.html")
TREE_LINE_TPL = _read_template("tree_line.html")
TREE_WORD_BADGE_TPL = _read_template("tree_word_badge.html")
EMPTY_STATE_TPL = _read_template("empty_state.html")
ERROR_MSG_TPL = _read_template("error_msg.html")
HEADER_HTML = _read_template("header.html")
| # ========================================================= | |
| # Auxiliary Helper Functions | |
| # ========================================================= | |
def make_stats_html(metrics_dict, latency="N/A") -> str:
    """Render the statistics panel for one OCR run.

    Args:
        metrics_dict: Mapping of OCR metrics. Missing counters fall back to
            ``0`` (``0.0`` for the average confidence) and a missing language
            to ``"und"``. A falsy value produces the empty-state placeholder.
        latency: Text shown as the processing time.

    Returns:
        HTML built from ``STATS_GRID_TPL``, or from ``EMPTY_STATE_TPL`` when
        there are no metrics.
    """
    if metrics_dict:
        lookup = metrics_dict.get
        # Confidence is rendered as a percentage with one decimal place.
        confidence_pct = f"{lookup('average_confidence', 0.0):.1%}"
        fields = {
            "avg_conf": confidence_pct,
            "latency": latency,
            "blocks": lookup("total_blocks", 0),
            "paragraphs": lookup("total_paragraphs", 0),
            "lines": lookup("total_lines", 0),
            "words": lookup("total_words", 0),
            "symbols": lookup("total_symbols", 0),
            "lang": lookup("primary_language", "und"),
        }
        return STATS_GRID_TPL.format(**fields)
    return EMPTY_STATE_TPL.format(message="No stats. Run OCR.")
def make_tree_html(parsed_data) -> str:
    """Render the block / paragraph / line / word hierarchy as nested HTML.

    Every value derived from the OCR result is HTML-escaped before it is
    substituted into a template. Recognised text can contain ``<``, ``&`` or
    quote characters, which would otherwise corrupt the markup, most visibly
    inside the ``title="..."`` attribute of the symbol badge.

    Args:
        parsed_data: Dict with a ``"blocks"`` list. Each block may carry
            ``"paragraphs"``, each paragraph ``"lines"``, each line
            ``"words"`` and each word an optional ``"symbols"`` list.
            A falsy value produces the empty-state placeholder.

    Returns:
        The tree markup wrapped in ``<div class="layout-tree">``, or HTML
        built from ``EMPTY_STATE_TPL`` when there is nothing to show.

    Raises:
        KeyError: If an entry lacks one of the fields read below
            (ids, ``confidence``, ``bounding_box``, colour values, ...).
    """
    # Function-scope import: ``html`` is not among the file's top-level imports.
    from html import escape

    if not parsed_data:
        return EMPTY_STATE_TPL.format(message="No layout structure. Run OCR first!")

    def _bbox(box) -> str:
        """Format a bounding-box dict as ``[x, y, width×height]``."""
        return f"[{box['x']}, {box['y']}, {box['width']}×{box['height']}]"

    def _word_badge(word) -> str:
        """Build the badge for one word, including its tooltip and symbol count."""
        fg_hex = int_to_hex_rgb(word["foreground_rgb_value"])
        bg_hex = int_to_hex_rgb(word["background_rgb_value"])
        # NOTE(review): "\\n" is a literal backslash followed by "n", kept
        # exactly as in the original. Confirm the template or client-side code
        # turns it into a line break; a plain title attribute would need a
        # real newline character instead.
        tooltip = (
            f"Confidence: {word['confidence']:.1%}\\n"
            f"Language: {word['language']}\\n"
            f"Direction: {word['direction']}\\n"
            f"Type: {word['content_type']}\\n"
            f"Colors - FG: {fg_hex} | BG: {bg_hex}\\n"
            f"Space After: {word['has_space_after']}"
        )

        symbols_badge = ""
        symbols = word.get("symbols")
        if symbols:
            sym_title = "\\n".join(
                f"'{sym['utf8_string']}' ({sym['confidence']:.1%})" for sym in symbols
            )
            # The title sits inside double quotes, so it must be escaped:
            # a recognised '"' would otherwise terminate the attribute.
            symbols_badge = (
                f' <span title="{escape(sym_title)}" '
                'class="ml-1 border-l border-emerald-500/30 pl-1.5 text-xs opacity-60">'
                f"{len(symbols)} sym</span>"
            )

        # symbols_badge is markup assembled above and is passed through as-is.
        return TREE_WORD_BADGE_TPL.format(
            tooltip=escape(tooltip),
            text=escape(str(word["utf8_string"])),
            symbols_badge=symbols_badge,
        )

    parts = ['<div class="layout-tree">']
    for block in parsed_data.get("blocks", []):
        b_header = (
            f"Block #{block['block_id']} - Conf: {block['confidence']:.1%}"
            f" | BBox: {_bbox(block['bounding_box'])}"
        )
        paras = []
        for para in block.get("paragraphs", []):
            p_header = (
                f"Paragraph #{para['paragraph_id']} - Conf: {para['confidence']:.1%}"
                f" | BBox: {_bbox(para['bounding_box'])}"
            )
            lines = []
            for line in para.get("lines", []):
                l_header = (
                    f"Line: \"{line['utf8_string']}\" - Conf: {line['confidence']:.1%}"
                    f" | Lang: {line['language']} | Dir: {line['direction']}"
                )
                words_html = [_word_badge(word) for word in line.get("words", [])]
                lines.append(
                    TREE_LINE_TPL.format(
                        header=escape(l_header), content="\n".join(words_html)
                    )
                )
            paras.append(
                TREE_PARA_TPL.format(header=escape(p_header), content="\n".join(lines))
            )
        parts.append(
            TREE_BLOCK_TPL.format(header=escape(b_header), content="\n".join(paras))
        )
    parts.append("</div>")
    return "\n".join(parts)
def _truncate_text(text, limit=100):
    """Return *text* cut to *limit* characters, with ``...`` appended if it was shortened."""
    return text[:limit] + "..." if len(text) > limit else text


def update_explorer_table(parsed_data, level) -> pd.DataFrame:
    """Build the tabular explorer view for one granularity level.

    Args:
        parsed_data: Parsed OCR dict. ``"blocks"``, ``"paragraphs"`` and
            ``"lines"`` are read from its top level; words and symbols are
            read from each line's ``"words"`` and each word's ``"symbols"``.
        level: One of ``"Blocks"``, ``"Paragraphs"``, ``"Lines"``,
            ``"Words"`` or ``"Symbols"``.

    Returns:
        One row per entity at the requested level. An empty ``DataFrame`` is
        returned when ``parsed_data`` or ``level`` is falsy, when ``level``
        is not recognised, or when there are no entities.

    Raises:
        KeyError: If an entity lacks one of the scalar fields shown in the
            table. Missing or ``None`` child lists (``words``, ``symbols``)
            are treated as empty rather than raising.
    """
    if not parsed_data or not level:
        return pd.DataFrame()

    rows = []
    # "or []" also covers keys that are present but set to None.
    lines = parsed_data.get("lines") or []

    if level == "Blocks":
        rows = [
            {
                "Block ID": b["block_id"],
                "Text": _truncate_text(b["utf8_string"]),
                "Confidence": f"{b['confidence']:.2%}",
                "Language": b["language"],
                "Paragraphs": b["paragraph_count"],
                "Lines": b["line_count"],
            }
            for b in (parsed_data.get("blocks") or [])
        ]
    elif level == "Paragraphs":
        rows = [
            {
                "Paragraph ID": p["paragraph_id"],
                "Block ID": p["block_id"],
                "Text": _truncate_text(p["utf8_string"]),
                "Confidence": f"{p['confidence']:.2%}",
                "Language": p["language"],
                "Lines": p["line_count"],
            }
            for p in (parsed_data.get("paragraphs") or [])
        ]
    elif level == "Lines":
        rows = [
            {
                "Line Index": ln["index"],
                "Block ID": ln["block_id"],
                "Paragraph ID": ln["paragraph_id"],
                "Text": ln["utf8_string"],
                "Confidence": f"{ln['confidence']:.2%}",
                "Language": ln["language"],
                "Direction": ln["direction"],
                "Content Type": ln["content_type"],
            }
            for ln in lines
        ]
    elif level == "Words":
        for ln in lines:
            for w in ln.get("words") or []:
                rows.append(
                    {
                        "Word Index": w["index"],
                        "Line Index": ln["index"],
                        "Word Text": w["utf8_string"],
                        "Confidence": f"{w['confidence']:.2%}",
                        "Language": w["language"],
                        "FG Color": int_to_hex_rgb(w["foreground_rgb_value"]),
                        "BG Color": int_to_hex_rgb(w["background_rgb_value"]),
                        "Has Space After": "Yes" if w["has_space_after"] else "No",
                    }
                )
    elif level == "Symbols":
        for ln in lines:
            for w in ln.get("words") or []:
                # A word may have no symbol breakdown; skip it instead of
                # failing on a missing key or on iterating None.
                for s in w.get("symbols") or []:
                    rows.append(
                        {
                            "Symbol Index": s["index"],
                            "Word Index": w["index"],
                            "Line Index": ln["index"],
                            "Char": s["utf8_string"],
                            "Confidence": f"{s['confidence']:.2%}",
                        }
                    )

    return pd.DataFrame(rows)
| # ========================================================= | |
| # OCR Core Pipeline Handler | |
| # ========================================================= | |
def process_image(image, explorer_level):
    """Run OCR on *image* and produce the values for every output widget.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.
        explorer_level: Granularity currently selected for the table view.

    Returns:
        A 7-tuple in the order of the ``run_btn.click`` outputs: the JSON
        payload for the canvas, the plain text, the layout-tree HTML, the
        statistics HTML, the explorer ``DataFrame``, the parsed dict for the
        JSON viewer, and the same parsed dict for the state holder. When
        there is no image, no annotation, or an exception occurs, the payload
        is ``""``, the table is empty and both dicts are ``{}``.
    """

    def _fallback(text, tree_html, stats_html):
        # Shared shape of every non-success result.
        return ("", text, tree_html, stats_html, pd.DataFrame(), {}, {})

    if image is None:
        return _fallback(
            "",
            EMPTY_STATE_TPL.format(message="No image."),
            EMPTY_STATE_TPL.format(message="N/A"),
        )

    started_at = time.time()
    try:
        annotation = get_client().perform_ocr(image)
        if not annotation:
            return _fallback(
                "No text detected.",
                EMPTY_STATE_TPL.format(message="No text found."),
                EMPTY_STATE_TPL.format(message="N/A"),
            )

        elapsed = time.time() - started_at
        latency = f"{elapsed:.2f}s"
        parsed_data = annotation.model_dump()
        plain_text = annotation.text
        tree_html = make_tree_html(parsed_data)
        stats_html = make_stats_html(parsed_data["metrics"], latency)
        df_explorer = update_explorer_table(parsed_data, explorer_level)
        # JSON payload for the interactive canvas in the browser; the image
        # itself is loaded directly from Gradio by client-side JS.
        payload = json.dumps({"annotation": parsed_data})
        return (
            payload,
            plain_text,
            tree_html,
            stats_html,
            df_explorer,
            parsed_data,
            parsed_data,
        )
    except Exception as e:
        # UI boundary: report the failure in the widgets instead of raising.
        err_msg = f"ERROR occurred during processing:\n{e}"
        return _fallback(err_msg, ERROR_MSG_TPL.format(message=err_msg), "Failed")
| # ========================================================= | |
| # Gradio Blocks UI Definition | |
| # ========================================================= | |
# Resolve the canvas assets relative to this file rather than the current
# working directory, the same way the templates directory is located.
# NOTE(review): assumes canvas.html and canvas.js sit next to this file - confirm.
_APP_DIR = Path(__file__).resolve().parent

with gr.Blocks(
    title="Google Chrome Screen AI OCR Dashboard",
) as demo:
    # Holds the most recent parsed OCR dict so the table can be rebuilt when
    # the view level changes without running OCR again.
    parsed_dict_state = gr.State(None)
    gr.HTML(HEADER_HTML)

    with gr.Row():
        with gr.Column(scale=4, elem_classes=["sidebar-panel"]):
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard", "webcam"],
                elem_id="input-image",
            )
            run_btn = gr.Button(
                "Execute OCR", variant="primary", elem_id="run-btn"
            )

        with gr.Column(scale=7):
            stats_output = gr.HTML(
                value=EMPTY_STATE_TPL.format(message="Run OCR to see statistics.")
            )
            with gr.Tabs():
                with gr.Tab("Interactive Canvas"):
                    gr.HTML(
                        value=(_APP_DIR / "canvas.html").read_text(encoding="utf-8"),
                        js_on_load=(_APP_DIR / "canvas.js").read_text(encoding="utf-8"),
                    )
                    # Hidden text transfer element populated by python to update canvas
                    data_transfer_el = gr.Textbox(
                        visible=False, elem_id="ocr-data-transfer"
                    )
                with gr.Tab("Layout Tree"):
                    tree_output = gr.HTML(label="Layout Tree")
                with gr.Tab("Tabular Explorer"):
                    explorer_level = gr.Radio(
                        choices=["Blocks", "Paragraphs", "Lines", "Words", "Symbols"],
                        value="Paragraphs",
                        label="View Level",
                    )
                    table_output = gr.DataFrame(interactive=False, wrap=True)
                with gr.Tab("Plain Text"):
                    text_output = gr.Textbox(
                        lines=18, label="Extracted Text"
                    )
                with gr.Tab("Developer JSON"):
                    json_output = gr.JSON(
                        label="Raw OCR Data"
                    )

    # Switching the view level rebuilds the table from the stored OCR result.
    explorer_level.change(
        fn=update_explorer_table,
        inputs=[parsed_dict_state, explorer_level],
        outputs=table_output,
    )

    # Run OCR, then hand the JSON payload to the canvas script in the browser.
    run_btn.click(
        fn=process_image,
        inputs=[input_image, explorer_level],
        outputs=[
            data_transfer_el,
            text_output,
            tree_output,
            stats_output,
            table_output,
            json_output,
            parsed_dict_state,
        ],
    ).then(
        fn=None,
        inputs=[data_transfer_el],
        outputs=[],
        js="""
        (payload_str) => {
            if (window.loadStreamedData) {
                window.loadStreamedData(payload_str);
            }
        }
        """,
    )
if __name__ == "__main__":
    # The stylesheet is resolved relative to this file, not the current
    # working directory, so the app can be launched from anywhere.
    # NOTE(review): assumes custom.css sits next to this file - confirm.
    #
    # In the Tailwind config, the borderRadius calc() expressions now have
    # spaces around "-": CSS requires whitespace around + and - inside calc(),
    # so "calc(var(--rd)-2px)" was an invalid value.
    # NOTE(review): the inline config runs before the CDN <script> tag;
    # Tailwind's Play CDN examples assign tailwind.config after the script is
    # loaded - confirm `tailwind` is defined at that point.
    demo.launch(
        css=(Path(__file__).resolve().parent / "custom.css").read_text(encoding="utf-8"),
        head='<script>tailwind.config={theme:{extend:{colors:{border:"hsl(var(--bd))",input:"hsl(var(--bd))",ring:"hsl(var(--ri))",background:"hsl(var(--bg))",foreground:"hsl(var(--fg))",primary:{DEFAULT:"hsl(var(--pr))",foreground:"hsl(var(--fg))"},secondary:{DEFAULT:"hsl(var(--bd))",foreground:"hsl(var(--fg))"},destructive:{DEFAULT:"hsl(0 62.8% 30.6%)",foreground:"hsl(var(--fg))"},muted:{DEFAULT:"hsl(var(--mu))",foreground:"hsl(var(--mu-fg))"},accent:{DEFAULT:"hsl(var(--ac))",foreground:"hsl(var(--ac-fg))"},card:{DEFAULT:"hsl(var(--ca))",foreground:"hsl(var(--ca-fg))"},popover:{DEFAULT:"hsl(var(--ca))",foreground:"hsl(var(--ca-fg))"}},borderRadius:{lg:"var(--rd)",md:"calc(var(--rd) - 2px)",sm:"calc(var(--rd) - 4px)"}}}}</script><script src="https://cdn.tailwindcss.com"></script>',
    )