import gc import types import sys import hashlib import json import math import os import re from io import BytesIO from typing import Any, Dict, List, Optional, Tuple import fitz # PyMuPDF import gradio as gr import requests import torch from huggingface_hub import snapshot_download from PIL import Image, ImageDraw, ImageFont from qwen_vl_utils import process_vision_info from transformers import AutoModelForCausalLM, AutoProcessor from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS from .utils.prompts import dict_promptmode_to_prompt APP_TITLE = "PreviewSpace — VLM Playground (Local)" TMP_DIR = "/tmp/previewspace" MODELS_DIR = os.path.join(TMP_DIR, "models") DOTS_REPO_ID = "rednote-hilab/dots.ocr" DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr") LOCAL_DEFAULT_MAX_NEW_TOKENS = 2048 os.makedirs(TMP_DIR, exist_ok=True) os.makedirs(MODELS_DIR, exist_ok=True) def round_by_factor(number: int, factor: int) -> int: return round(number / factor) * factor def smart_resize( height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS, ) -> Tuple[int, int]: if max(height, width) / min(height, width) > 200: raise ValueError("absolute aspect ratio must be smaller than 200") h_bar = max(factor, round_by_factor(height, factor)) w_bar = max(factor, round_by_factor(width, factor)) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = round_by_factor(height / beta, factor) w_bar = round_by_factor(width / beta, factor) elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = round_by_factor(height * beta, factor) w_bar = round_by_factor(width * beta, factor) return int(h_bar), int(w_bar) def fetch_image(image_input: Any) -> Image.Image: if isinstance(image_input, str): if image_input.startswith(("http://", "https://")): response = requests.get(image_input, timeout=60) image = Image.open(BytesIO(response.content)).convert("RGB") else: image = Image.open(image_input).convert("RGB") elif isinstance(image_input, Image.Image): image = image_input.convert("RGB") else: raise ValueError(f"Invalid image input type: {type(image_input)}") return image def load_images_from_pdf(pdf_path: str) -> List[Image.Image]: images: List[Image.Image] = [] pdf_document = fitz.open(pdf_path) try: for page_idx in range(len(pdf_document)): page = pdf_document.load_page(page_idx) pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) img_data = pix.tobytes("ppm") image = Image.open(BytesIO(img_data)).convert("RGB") images.append(image) finally: pdf_document.close() return images def file_checksum(path: str, chunk_size: int = 1 << 20) -> str: hasher = hashlib.sha256() with open(path, "rb") as f: while True: chunk = f.read(chunk_size) if not chunk: break hasher.update(chunk) return hasher.hexdigest() def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image: img = image.copy() draw = ImageDraw.Draw(img) colors = { "Caption": "#FF6B6B", "Footnote": "#4ECDC4", "Formula": "#45B7D1", "List-item": "#96CEB4", "Page-footer": "#FFEAA7", "Page-header": "#DDA0DD", "Picture": "#FFD93D", "Section-header": "#6C5CE7", "Table": "#FD79A8", "Text": "#74B9FF", "Title": "#E17055", } try: try: font = ImageFont.truetype( "/System/Library/Fonts/Supplemental/Arial Bold.ttf", 12 ) except Exception: try: font = ImageFont.truetype( "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 12 ) except Exception: font = ImageFont.load_default() for item in layout_data: bbox = item.get("bbox") category = item.get("category") if not bbox or not category: continue color = colors.get(category, "#000000") draw.rectangle(bbox, outline=color, width=2) label = str(category) label_bbox = draw.textbbox((0, 0), label, font=font) label_w = label_bbox[2] - label_bbox[0] label_h = label_bbox[3] - label_bbox[1] x1, y1 = int(bbox[0]), int(bbox[1]) lx = x1 ly = max(0, y1 - label_h - 2) draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color) draw.text((lx + 2, ly + 1), label, fill="white", font=font) except Exception: pass return img def is_arabic_text(text: str) -> bool: if not text: return False header_pattern = r"^#{1,6}\s+(.+)$" paragraph_pattern = r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$" content_lines: List[str] = [] for line in text.split("\n"): s = line.strip() if not s: continue m = re.match(header_pattern, s) if m: content_lines.append(m.group(1)) continue if re.match(paragraph_pattern, s): content_lines.append(s) if not content_lines: return False combined = " ".join(content_lines) arabic = 0 total = 0 for ch in combined: if ch.isalpha(): total += 1 if ( ("\u0600" <= ch <= "\u06ff") or ("\u0750" <= ch <= "\u077f") or ("\u08a0" <= ch <= "\u08ff") ): arabic += 1 if total == 0: return False return (arabic / total) > 0.5 def extract_json(text: str) -> Optional[Dict[str, Any]]: if not text: return None try: return json.loads(text) except Exception: pass brace_start = text.find("{") brace_end = text.rfind("}") if 0 <= brace_start < brace_end: snippet = text[brace_start : brace_end + 1] try: return json.loads(snippet) except Exception: pass fenced = re.findall(r"```json\s*([\s\S]*?)\s*```", text) for block in fenced: try: return json.loads(block) except Exception: continue return None model: Optional[AutoModelForCausalLM] = None processor: Optional[AutoProcessor] = None def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]: global model, processor if model is not None and processor is not None: return model, processor os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") snapshot_download( repo_id=DOTS_REPO_ID, local_dir=DOTS_LOCAL_DIR, local_dir_use_symlinks=False, ) # Work around transformers dynamic module parent package issue with repo name containing a dot # Ensure 'transformers_modules' and 'transformers_modules.dots' exist as packages if "transformers_modules" not in sys.modules: pkg = types.ModuleType("transformers_modules") pkg.__path__ = [] # type: ignore[attr-defined] sys.modules["transformers_modules"] = pkg if "transformers_modules.dots" not in sys.modules: subpkg = types.ModuleType("transformers_modules.dots") subpkg.__path__ = [] # type: ignore[attr-defined] sys.modules["transformers_modules.dots"] = subpkg use_mps = torch.backends.mps.is_available() dtype = ( torch.float16 if use_mps else (torch.bfloat16 if torch.cuda.is_available() else torch.float32) ) model = AutoModelForCausalLM.from_pretrained( DOTS_LOCAL_DIR, torch_dtype=dtype, trust_remote_code=True, low_cpu_mem_usage=True, ) if use_mps: model.to("mps") proc = AutoProcessor.from_pretrained(DOTS_LOCAL_DIR, trust_remote_code=True) processor = proc return model, processor def run_inference( image: Image.Image, prompt_text: str, max_new_tokens: int = LOCAL_DEFAULT_MAX_NEW_TOKENS, ) -> str: mdl, proc = ensure_model_loaded() messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt_text}, ], } ] text = proc.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = proc( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) device = ( "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu") ) inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()} with torch.no_grad(): generated_ids = mdl.generate( **inputs, max_new_tokens=int(max_new_tokens), do_sample=False, temperature=0.1, ) trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids) ] output_text = processor.batch_decode( trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) return output_text[0] if output_text else "" def process_single_image( image: Image.Image, prompt_text: str, max_new_tokens: int, ) -> Dict[str, Any]: img = fetch_image(image) raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens) result: Dict[str, Any] = { "original_image": img, "processed_image": img, "raw_output": raw, "layout_result": None, "markdown": None, } data = extract_json(raw) if isinstance(data, dict): result["layout_result"] = data items = data.get("elements", data.get("elements_list", data.get("content", []))) if isinstance(items, list): result["processed_image"] = draw_layout_on_image(img, items) result["markdown"] = layoutjson2md(img, items) if result["markdown"] is None: result["markdown"] = raw return result def layoutjson2md( image: Image.Image, layout_data: List[Dict], text_key: str = "text" ) -> str: lines: List[str] = [] try: items = sorted( layout_data, key=lambda x: ( x.get("bbox", [0, 0, 0, 0])[1], x.get("bbox", [0, 0, 0, 0])[0], ), ) for item in items: category = item.get("category", "") text = item.get(text_key, "") if category == "Title" and text: lines.append(f"# {text}\n") elif category == "Section-header" and text: lines.append(f"## {text}\n") elif category == "List-item" and text: lines.append(f"- {text}\n") elif category == "Table" and text: if text.strip().startswith("<"): lines.append(text + "\n") else: lines.append(f"**Table:** {text}\n") elif category == "Formula" and text: if text.strip().startswith("$") or "\\" in text: lines.append(f"$$\n{text}\n$$\n") else: lines.append(f"**Formula:** {text}\n") elif category == "Caption" and text: lines.append(f"*{text}*\n") elif category in ["Page-header", "Page-footer"]: continue elif category == "Picture": continue elif text: lines.append(f"{text}\n") lines.append("") except Exception: return json.dumps(layout_data, ensure_ascii=False) return "\n".join(lines) def create_blocks_app(): css = """ .main-container { max-width: 1500px; margin: 0 auto; } .header-text { text-align: center; color: #1f2937; margin-bottom: 12px; } .page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; } .process-button { border: none !important; color: white !important; font-weight: 700 !important; } """ with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo: doc_state = gr.State( { "images": [], "current_page": 0, "total_pages": 0, "file_type": None, "checksum": None, "results": [], "parsed": False, } ) cache_state = gr.State({}) gr.HTML( """
Optimized defaults for Apple Silicon / CPU dev.