import gc
import hashlib
import json
import math
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import fitz  # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

# Plain (non-relative) imports: on Spaces this file runs as a top-level
# script, so relative imports would fail at startup.
from utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from utils.prompts import dict_promptmode_to_prompt

# ============================
# Constants and configuration
# ============================
APP_TITLE = "PreviewSpace — VLM Playground"
TMP_DIR = "/tmp/previewspace"
MODELS_DIR = os.path.join(TMP_DIR, "models")
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")

DEFAULT_PROMPT = dict_promptmode_to_prompt.get(
    "prompt_layout_all_en",
    (
        "Please output the layout information from the PDF page image. For each element, return: "
        'bbox: [x1, y1, x2, y2], category from {"title","header","paragraph","table","figure","footnote"}, and text. '
        'Return JSON: {"elements": [{"bbox": [..], "category": "..", "text": ".."}], "page": <number>}'
    ),
)

os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)


# ===========
# Utilities
# ===========
def round_by_factor(number: float, factor: int) -> int:
    """Round `number` to the nearest integer multiple of `factor`."""
    return round(number / factor) * factor
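

# Floor/ceil companions to round_by_factor, mirroring the canonical
# qwen_vl_utils helpers (an assumption based on that library; these were not
# in the original file). smart_resize uses them below so that shrinking can
# never overshoot max_pixels and growing can never undershoot min_pixels,
# which plain rounding does not guarantee.
def floor_by_factor(number: float, factor: int) -> int:
    """Largest multiple of `factor` that is <= `number`."""
    return math.floor(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Smallest multiple of `factor` that is >= `number`."""
    return math.ceil(number / factor) * factor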


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Flooring (not rounding) guarantees the result stays within max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Ceiling guarantees the result reaches at least min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return int(h_bar), int(w_bar)
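

# Quick sanity check on the bounds (hypothetical sizes and limits; the real
# ones come from utils.constants):
#   smart_resize(1080, 1920, factor=28,
#                min_pixels=4 * 28 * 28, max_pixels=1280 * 28 * 28)
#   -> (728, 1316), since 728 * 1316 = 958_048 <= 1_003_520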


def fetch_image(
    image_input: Any,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=60)
            # Fail fast on HTTP errors instead of letting PIL choke on an
            # error page body.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")
    if min_pixels is not None or max_pixels is not None:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        new_h, new_w = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        image = image.resize((new_w, new_h), Image.LANCZOS)
    return image


def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    images: List[Image.Image] = []
    pdf_document = fitz.open(pdf_path)
    try:
        for page_idx in range(len(pdf_document)):
            page = pdf_document.load_page(page_idx)
            # A 2x zoom matrix renders at roughly 144 DPI (PDF default is 72).
            pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
            img_data = pix.tobytes("ppm")
            image = Image.open(BytesIO(img_data)).convert("RGB")
            images.append(image)
    finally:
        pdf_document.close()
    return images


def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()
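

# The digest is a 64-character hex string. It is used as part of the parse
# cache key below, so re-uploading an identical file can reuse earlier results.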


def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    img = image.copy()
    draw = ImageDraw.Draw(img)
    colors = {
        "Caption": "#FF6B6B",
        "Footnote": "#4ECDC4",
        "Formula": "#45B7D1",
        "List-item": "#96CEB4",
        "Page-footer": "#FFEAA7",
        "Page-header": "#DDA0DD",
        "Picture": "#FFD93D",
        "Section-header": "#6C5CE7",
        "Table": "#FD79A8",
        "Text": "#74B9FF",
        "Title": "#E17055",
    }
    try:
        try:
            font = ImageFont.truetype(
                "/System/Library/Fonts/Supplemental/Arial Bold.ttf", 12
            )
        except Exception:
            try:
                font = ImageFont.truetype(
                    "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12
                )
            except Exception:
                font = ImageFont.load_default()
        for item in layout_data:
            bbox = item.get("bbox")
            category = item.get("category")
            if not bbox or not category:
                continue
            color = colors.get(category, "#000000")
            draw.rectangle(bbox, outline=color, width=2)
            label = str(category)
            label_bbox = draw.textbbox((0, 0), label, font=font)
            label_w = label_bbox[2] - label_bbox[0]
            label_h = label_bbox[3] - label_bbox[1]
            x1, y1 = int(bbox[0]), int(bbox[1])
            lx = x1
            ly = max(0, y1 - label_h - 2)
            draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color)
            draw.text((lx + 2, ly + 1), label, fill="white", font=font)
    except Exception:
        pass
    return img


def is_arabic_text(text: str) -> bool:
    if not text:
        return False
    header_pattern = r"^#{1,6}\s+(.+)$"
    paragraph_pattern = r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$"
    content_lines: List[str] = []
    for line in text.split("\n"):
        s = line.strip()
        if not s:
            continue
        m = re.match(header_pattern, s)
        if m:
            content_lines.append(m.group(1))
            continue
        if re.match(paragraph_pattern, s):
            content_lines.append(s)
    if not content_lines:
        return False
    combined = " ".join(content_lines)
    arabic = 0
    total = 0
    for ch in combined:
        if ch.isalpha():
            total += 1
            if (
                ("\u0600" <= ch <= "\u06ff")
                or ("\u0750" <= ch <= "\u077f")
                or ("\u08a0" <= ch <= "\u08ff")
            ):
                arabic += 1
    if total == 0:
        return False
    return (arabic / total) > 0.5
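

# Illustration of the heuristic (headers and paragraphs count; code fences,
# tables, and list markers are skipped):
#   is_arabic_text("# مرحبا بالعالم")  -> True   (all letters in Arabic blocks)
#   is_arabic_text("# Hello world")   -> False  (no Arabic letters)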


def extract_json(text: str) -> Optional[Dict[str, Any]]:
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    # Try to extract JSON block
    brace_start = text.find("{")
    brace_end = text.rfind("}")
    if 0 <= brace_start < brace_end:
        snippet = text[brace_start : brace_end + 1]
        try:
            return json.loads(snippet)
        except Exception:
            pass
    fenced = re.findall(r"```json\s*([\s\S]*?)\s*```", text)
    for block in fenced:
        try:
            return json.loads(block)
        except Exception:
            continue
    return None
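

# The fallbacks run in order: the whole string, then the outermost {...}
# span, then any ```json fenced block. For example:
#   extract_json('noise {"a": 1} noise')  -> {"a": 1}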


def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    # Note: `image` is currently unused; the Picture branch below skips
    # embedding image fragments.
    lines: List[str] = []
    try:
        items = sorted(
            layout_data,
            key=lambda x: (
                x.get("bbox", [0, 0, 0, 0])[1],
                x.get("bbox", [0, 0, 0, 0])[0],
            ),
        )
        for item in items:
            category = item.get("category", "")
            text = item.get(text_key, "")
            if category == "Title" and text:
                lines.append(f"# {text}\n")
            elif category == "Section-header" and text:
                lines.append(f"## {text}\n")
            elif category == "List-item" and text:
                lines.append(f"- {text}\n")
            elif category == "Table" and text:
                if text.strip().startswith("<"):
                    lines.append(text + "\n")
                else:
                    lines.append(f"**Table:** {text}\n")
            elif category == "Formula" and text:
                if text.strip().startswith("$") or "\\" in text:
                    lines.append(f"$$\n{text}\n$$\n")
                else:
                    lines.append(f"**Formula:** {text}\n")
            elif category == "Caption" and text:
                lines.append(f"*{text}*\n")
            elif category in ["Page-header", "Page-footer"]:
                continue
            elif category == "Picture":
                # Skip embedding image fragments in markdown for now
                continue
            elif text:
                lines.append(f"{text}\n")
            lines.append("")
    except Exception:
        return json.dumps(layout_data, ensure_ascii=False)
    return "\n".join(lines)
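

# Example: items are emitted in reading order (top-to-bottom, then
# left-to-right), so
#   [{"bbox": [0, 90, 50, 99], "category": "Text", "text": "body"},
#    {"bbox": [0, 10, 50, 30], "category": "Title", "text": "Doc"}]
# renders as "# Doc" followed by "body".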


# =====================
# Model initialization
# =====================
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)


def get_torch_dtype() -> torch.dtype:
    if device == "cuda":
        # bf16 needs Ampere or newer; fall back to fp16 on older GPUs
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device == "mps":
        return torch.float16
    return torch.float32


def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    global model, processor
    if model is not None and processor is not None:
        return model, processor
    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
    snapshot_download(
        repo_id=DOTS_REPO_ID,
        local_dir=DOTS_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )
    dtype = get_torch_dtype()
    # device_map="auto" requires the `accelerate` package
    model = AutoModelForCausalLM.from_pretrained(
        DOTS_LOCAL_DIR,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    proc = AutoProcessor.from_pretrained(DOTS_LOCAL_DIR, trust_remote_code=True)
    processor = proc
    return model, processor


def run_inference(
    image: Image.Image, prompt_text: str, max_new_tokens: int = 24000
) -> str:
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    with torch.no_grad():
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,  # greedy decoding; a temperature would be ignored here
        )
    # Strip the prompt tokens so only newly generated text is decoded.
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = proc.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""


def process_single_image(
    image: Image.Image,
    prompt_text: str,
    min_pixels: Optional[int],
    max_pixels: Optional[int],
    max_new_tokens: int,
) -> Dict[str, Any]:
    img = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }
    data = extract_json(raw)
    if isinstance(data, dict):
        result["layout_result"] = data
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
        if isinstance(items, list):
            result["processed_image"] = draw_layout_on_image(img, items)
            result["markdown"] = layoutjson2md(img, items)
    if result["markdown"] is None:
        result["markdown"] = raw
    return result
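

# Minimal smoke test of the pipeline (hypothetical paths; the first call
# downloads the model, so this needs network access and ideally a GPU):
#   page = load_images_from_pdf("/tmp/sample.pdf")[0]
#   out = process_single_image(page, DEFAULT_PROMPT,
#                              min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS,
#                              max_new_tokens=4096)
#   print(out["markdown"][:200])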


# =================
# Gradio Interface
# =================
def create_blocks_app():
    css = """
    .main-container { max-width: 1500px; margin: 0 auto; }
    .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
    .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
    .process-button { border: none !important; color: white !important; font-weight: 700 !important; }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        # App state (per session)
        doc_state = gr.State(
            {
                "images": [],
                "current_page": 0,
                "total_pages": 0,
                "file_type": None,
                "checksum": None,
                "results": [],
                "parsed": False,
            }
        )
        # (checksum, page, prompt_hash, min_pix, max_pix, max_tokens) -> result
        cache_state = gr.State({})
        gr.HTML(
            """
            <div class="header-text">
                <h2>VLM Playground — dots.ocr</h2>
                <p>Upload a PDF or image, preview pages, and parse with a layout-extraction prompt.</p>
            </div>
            """
        )

        with gr.Row(elem_classes=["main-container"]):
            # Left: upload + controls
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[
                        ".pdf",
                        ".png",
                        ".jpg",
                        ".jpeg",
                        ".bmp",
                        ".tiff",
                        ".webp",
                    ],
                    type="filepath",
                )
                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=DEFAULT_PROMPT,
                        lines=6,
                    )
                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=512,
                        maximum=32000,
                        value=24000,
                        step=256,
                        label="Max new tokens",
                    )
                    min_pixels_in = gr.Number(value=MIN_PIXELS, label="Min pixels")
                    max_pixels_in = gr.Number(value=MAX_PIXELS, label="Max pixels")
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )

            # Center: page preview + nav
            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML('<div class="page-info">No file</div>')
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")
            # Right: results
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(
                            value="Upload and parse to view results", height=520
                        )
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)
                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

        # ===== Handlers =====
        def on_template_change(choice: str) -> str:
            # Only one template for now, so always return the default prompt.
            return DEFAULT_PROMPT

        def on_file_change(path: Optional[str]):
            if not path or not os.path.exists(path):
                return (
                    {
                        "images": [],
                        "current_page": 0,
                        "total_pages": 0,
                        "file_type": None,
                        "checksum": None,
                        "results": [],
                        "parsed": False,
                    },
                    None,
                    '<div class="page-info">No file</div>',
                )
            checksum = file_checksum(path)
            ext = os.path.splitext(path)[1].lower()
            if ext == ".pdf":
                images = load_images_from_pdf(path)
                state = {
                    "images": images,
                    "current_page": 0,
                    "total_pages": len(images),
                    "file_type": "pdf",
                    "checksum": checksum,
                    "results": [None] * len(images),
                    "parsed": False,
                }
                return (
                    state,
                    images[0] if images else None,
                    f'<div class="page-info">Page 1 / {len(images)}</div>',
                )
            else:
                image = Image.open(path).convert("RGB")
                state = {
                    "images": [image],
                    "current_page": 0,
                    "total_pages": 1,
                    "file_type": "image",
                    "checksum": checksum,
                    "results": [None],
                    "parsed": False,
                }
                return state, image, '<div class="page-info">Page 1 / 1</div>'

        def nav_page(state: Dict[str, Any], direction: str):
            if not state.get("images"):
                return (
                    state,
                    None,
                    '<div class="page-info">No file</div>',
                    "No results",
                    "",
                    None,
                    None,
                )
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
            # Any other direction (e.g. "stay") keeps the current page.
            idx = state["current_page"]
            img = state["images"][idx]
            info = (
                f'<div class="page-info">Page {idx + 1} / {state["total_pages"]}</div>'
            )
            result = (
                state["results"][idx]
                if state.get("parsed") and idx < len(state["results"])
                else None
            )
            md = result.get("markdown") if result else "Page not processed yet"
            md_out = gr.update(value=md, rtl=True) if is_arabic_text(md) else md
            md_raw_text = md
            proc_img = result.get("processed_image") if result else None
            js = result.get("layout_result") if result else None
            return state, img, info, md_out, md_raw_text, proc_img, js

        def jump_to_page(state: Dict[str, Any], page_num: Any):
            if not state.get("images"):
                return (
                    state,
                    None,
                    '<div class="page-info">No file</div>',
                    "No results",
                    "",
                    None,
                    None,
                )
            try:
                n = int(page_num)
            except Exception:
                n = 1
            n = max(1, min(state["total_pages"], n))
            state["current_page"] = n - 1
            return nav_page(state, direction="stay")

        def parse_pages(
            state: Dict[str, Any],
            prompt: str,
            max_tokens: int,
            min_pix: Optional[float],
            max_pix: Optional[float],
            selection: Optional[str],
        ):
            if not state.get("images"):
                return state, None, "No file", "No content", "", None, None
            # Determine pages to process
            indices: List[int] = []
            if not selection or selection.strip() == "":
                indices = [state["current_page"]]
            elif selection.strip().lower() == "all":
                indices = list(range(state["total_pages"]))
            else:
                # Parse selections like "1-3,5"
                parts = [p.strip() for p in selection.split(",") if p.strip()]
                for p in parts:
                    if "-" in p:
                        a, b = p.split("-", 1)
                        try:
                            a_i = max(1, int(a))
                            b_i = min(state["total_pages"], int(b))
                            for i in range(a_i - 1, b_i):
                                indices.append(i)
                        except Exception:
                            continue
                    else:
                        try:
                            i = max(1, min(state["total_pages"], int(p)))
                            indices.append(i - 1)
                        except Exception:
                            continue
                indices = sorted(
                    set(i for i in indices if 0 <= i < state["total_pages"])
                )
            # Process sequentially for stability
            results = state.get("results") or [None] * state["total_pages"]
            for i in indices:
                img = state["images"][i]
                prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
                cache_key = (
                    state["checksum"],
                    i,
                    prompt_hash,
                    int(min_pix or 0),
                    int(max_pix or 0),
                    int(max_tokens),
                )
                # Note: reading `cache_state.value` directly makes this a
                # process-wide cache shared across sessions, not per-session state.
                cached = cache_state.value.get(cache_key)
                if cached:
                    results[i] = cached
                    continue
                res = process_single_image(
                    img,
                    prompt_text=prompt,
                    min_pixels=int(min_pix) if min_pix else None,
                    max_pixels=int(max_pix) if max_pix else None,
                    max_new_tokens=int(max_tokens),
                )
                results[i] = res
                cache_state.value[cache_key] = res
            state["results"] = results
            state["parsed"] = True
            # Return current page outputs
            idx = state["current_page"]
            curr = results[idx]
            md = curr.get("markdown") if curr else "No content"
            md_out = gr.update(value=md, rtl=True) if is_arabic_text(md) else md
            md_raw_text = md
            proc_img = curr.get("processed_image") if curr else None
            js = curr.get("layout_result") if curr else None
            info = (
                f'<div class="page-info">Page {idx + 1} / {state["total_pages"]}</div>'
            )
            prev = state["images"][idx]
            return state, prev, info, md_out, md_raw_text, proc_img, js

        def clear_all():
            gc.collect()
            return (
                {
                    "images": [],
                    "current_page": 0,
                    "total_pages": 0,
                    "file_type": None,
                    "checksum": None,
                    "results": [],
                    "parsed": False,
                },
                None,
                '<div class="page-info">No file</div>',
                "Upload and parse to view results",
                "",
                None,
                None,
            )

        def download_current_jsonl(state: Dict[str, Any]):
            if not state.get("parsed"):
                # Nothing to download yet; leave the button unchanged.
                return gr.update()
            lines: List[str] = []
            for i, res in enumerate(state.get("results", [])):
                if res and res.get("layout_result") is not None:
                    obj = {"page": i + 1, "layout": res["layout_result"]}
                    lines.append(json.dumps(obj, ensure_ascii=False))
            content = "\n".join(lines) if lines else ""
            out_path = os.path.join(TMP_DIR, "results.jsonl")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(content)
            # Gradio 4 removed the per-component `.update()` classmethods;
            # return a generic gr.update with the file path instead.
            return gr.update(value=out_path)

        def download_current_markdown(state: Dict[str, Any]):
            if not state.get("parsed"):
                return gr.update()
            chunks: List[str] = []
            for i, res in enumerate(state.get("results", [])):
                if res and res.get("markdown"):
                    chunks.append(f"## Page {i + 1}\n\n{res['markdown']}")
            content = "\n\n---\n\n".join(chunks) if chunks else ""
            out_path = os.path.join(TMP_DIR, "results.md")
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.update(value=out_path)

        # Wire events
        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change,
            inputs=[file_input],
            outputs=[doc_state, preview_image, page_info],
        )
        prev_btn.click(
            lambda s: nav_page(s, "prev"),
            inputs=[doc_state],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        next_btn.click(
            lambda s: nav_page(s, "next"),
            inputs=[doc_state],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        jump_btn.click(
            jump_to_page,
            inputs=[doc_state, page_jump],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        parse_button.click(
            parse_pages,
            inputs=[
                doc_state,
                prompt_text,
                max_new_tokens,
                min_pixels_in,
                max_pixels_in,
                page_range,
            ],
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        clear_button.click(
            clear_all,
            outputs=[
                doc_state,
                preview_image,
                page_info,
                md_render,
                md_raw,
                processed_view,
                json_view,
            ],
        )
        download_jsonl.click(
            download_current_jsonl, inputs=[doc_state], outputs=[download_jsonl]
        )
        download_markdown.click(
            download_current_markdown, inputs=[doc_state], outputs=[download_markdown]
        )
    return demo
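

# The module above only defines the app. A minimal entry point (assuming this
# is the Space's top-level app.py) would be:
if __name__ == "__main__":
    demo = create_blocks_app()
    # queue() serializes GPU inference requests; host/port match the usual
    # Spaces container setup.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)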