trevorpfiz
fix: unexpected keyword argument 'file_name'
4aa9a45
raw
history blame
28.9 kB
import gc
import hashlib
import json
import math
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
import fitz # PyMuPDF
import gradio as gr
import requests
import torch
from huggingface_hub import snapshot_download
from PIL import Image, ImageDraw, ImageFont
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor
from .utils.constants import IMAGE_FACTOR, MAX_PIXELS, MIN_PIXELS
from .utils.prompts import dict_promptmode_to_prompt
# ============================
# Constants and configuration
# ============================
# Title passed to the Gradio Blocks app.
APP_TITLE = "PreviewSpace — VLM Playground"
# Scratch directory; the download handlers write their export files here.
TMP_DIR = "/tmp/previewspace"
# Parent directory for locally stored model files.
MODELS_DIR = os.path.join(TMP_DIR, "models")
# Hugging Face Hub repository id handed to snapshot_download().
DOTS_REPO_ID = "rednote-hilab/dots.ocr"
# Local directory the snapshot is downloaded into and the model is loaded from.
DOTS_LOCAL_DIR = os.path.join(MODELS_DIR, "dots.ocr")
# Initial prompt for the UI: the "prompt_layout_all_en" entry of the project's
# prompt table, or the inline fallback text below when that key is missing.
DEFAULT_PROMPT = dict_promptmode_to_prompt.get(
"prompt_layout_all_en",
(
"Please output the layout information from the PDF page image. For each element, return: "
'bbox: [x1, y1, x2, y2], category from {"title","header","paragraph","table","figure","footnote"}, and text. '
'Return JSON: {"elements": [{"bbox": [..], "category": "..", "text": ".."}], "page": <number>}'
),
)
# Import-time side effect: make sure both working directories exist.
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)
# ===========
# Utilities
# ===========
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` closest to ``number``.

    The quotient is rounded with the built-in ``round``, so exact halves go
    to the even multiple.
    """
    multiples = round(number / factor)
    return multiples * factor
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Tuple[int, int]:
    """Compute a target ``(height, width)`` whose sides are multiples of ``factor``.

    The result keeps the aspect ratio approximately and aims for a pixel count
    within ``[min_pixels, max_pixels]``.

    Args:
        height: Source height in pixels (must be positive).
        width: Source width in pixels (must be positive).
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound for ``height * width`` of the result.
        max_pixels: Upper bound for ``height * width`` of the result.

    Returns:
        The ``(height, width)`` to resize to.

    Raises:
        ValueError: If a dimension is not positive or the aspect ratio
            exceeds 200.
    """
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got {height}x{width}"
        )
    if max(height, width) / min(height, width) > 200:
        raise ValueError("absolute aspect ratio must be smaller than 200")

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        # Shrink and round *down*: rounding to nearest could land above
        # max_pixels again.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Grow and round *up* so the result cannot fall below min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
        if h_bar * w_bar > max_pixels:
            # The upper bound wins when the two bounds cannot both be met.
            beta = math.sqrt((h_bar * w_bar) / max_pixels)
            h_bar = max(factor, math.floor(h_bar / beta / factor) * factor)
            w_bar = max(factor, math.floor(w_bar / beta / factor) * factor)
    return int(h_bar), int(w_bar)
def fetch_image(
    image_input: Any,
    min_pixels: Optional[int] = None,
    max_pixels: Optional[int] = None,
) -> Image.Image:
    """Load an image from a URL, a local path or a PIL image and return it as RGB.

    Args:
        image_input: An ``http://``/``https://`` URL, a filesystem path, or a
            PIL image.
        min_pixels: Lower pixel-count bound passed to ``smart_resize``.
        max_pixels: Upper pixel-count bound passed to ``smart_resize``.

    When at least one bound is given the image is resized to the size computed
    by ``smart_resize``; a bound that is ``None`` or ``0`` falls back to
    ``MIN_PIXELS`` / ``MAX_PIXELS``. With both bounds ``None`` the image is
    returned at its original size.

    Returns:
        The loaded (and possibly resized) RGB image.

    Raises:
        ValueError: If ``image_input`` is neither a string nor a PIL image.
        requests.HTTPError: If the URL responds with an error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=60)
            # Fail with the HTTP error instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input.convert("RGB")
    else:
        raise ValueError(f"Invalid image input type: {type(image_input)}")

    if min_pixels is not None or max_pixels is not None:
        min_pixels = min_pixels or MIN_PIXELS
        max_pixels = max_pixels or MAX_PIXELS
        new_h, new_w = smart_resize(
            image.height,
            image.width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        # PIL takes (width, height), smart_resize returns (height, width).
        image = image.resize((new_w, new_h), Image.LANCZOS)
    return image
def load_images_from_pdf(pdf_path: str) -> List[Image.Image]:
    """Render every page of the PDF at ``pdf_path`` to an RGB PIL image.

    Pages are rasterised through a ``fitz.Matrix(2.0, 2.0)`` transform and
    returned in page order. The document is closed even if rendering fails.
    """
    zoom = fitz.Matrix(2.0, 2.0)
    rendered: List[Image.Image] = []
    document = fitz.open(pdf_path)
    try:
        for page_number in range(len(document)):
            pixmap = document.load_page(page_number).get_pixmap(matrix=zoom)
            ppm_bytes = pixmap.tobytes("ppm")
            rendered.append(Image.open(BytesIO(ppm_bytes)).convert("RGB"))
    finally:
        document.close()
    return rendered
def file_checksum(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in pieces of ``chunk_size`` bytes so it never has to be
    held in memory as a whole.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def draw_layout_on_image(image: Image.Image, layout_data: List[Dict]) -> Image.Image:
    """Return a copy of ``image`` with an outlined, labelled box per layout item.

    Drawing is best-effort: items that are not dicts, lack ``bbox`` or
    ``category``, or whose ``bbox`` cannot be used as four numbers are skipped
    individually, so one bad item does not suppress the remaining boxes.

    Args:
        image: Page image; it is copied, not modified.
        layout_data: Layout items; each is assumed to carry ``bbox`` as
            ``[x1, y1, x2, y2]`` in the pixel space of ``image`` and a
            ``category`` string.

    Returns:
        The annotated copy.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    colors = {
        "Caption": "#FF6B6B",
        "Footnote": "#4ECDC4",
        "Formula": "#45B7D1",
        "List-item": "#96CEB4",
        "Page-footer": "#FFEAA7",
        "Page-header": "#DDA0DD",
        "Picture": "#FFD93D",
        "Section-header": "#6C5CE7",
        "Table": "#FD79A8",
        "Text": "#74B9FF",
        "Title": "#E17055",
    }

    # Try the known font locations in order, then fall back to PIL's default.
    font = None
    for font_path in (
        "/System/Library/Fonts/Supplemental/Arial Bold.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    ):
        try:
            font = ImageFont.truetype(font_path, 12)
        except (OSError, ImportError):
            continue
        break
    if font is None:
        font = ImageFont.load_default()

    for item in layout_data or []:
        if not isinstance(item, dict):
            continue
        bbox = item.get("bbox")
        category = item.get("category")
        if not bbox or not category:
            continue
        try:
            x1, y1, x2, y2 = bbox
            # Order the corners; rectangle() can reject x2 < x1 or y2 < y1.
            left, right = sorted((x1, x2))
            top, bottom = sorted((y1, y2))
            color = colors.get(category, "#000000")
            draw.rectangle([left, top, right, bottom], outline=color, width=2)

            label = str(category)
            label_bbox = draw.textbbox((0, 0), label, font=font)
            label_w = label_bbox[2] - label_bbox[0]
            label_h = label_bbox[3] - label_bbox[1]
            lx = int(left)
            # Put the label above the box, clamped to the top edge of the image.
            ly = max(0, int(top) - label_h - 2)
            draw.rectangle([lx, ly, lx + label_w + 4, ly + label_h + 2], fill=color)
            draw.text((lx + 2, ly + 1), label, fill="white", font=font)
        except (TypeError, ValueError, OverflowError):
            # Malformed box for this item only; keep drawing the others.
            continue
    return img
def is_arabic_text(text: str) -> bool:
    """Report whether the prose in a markdown string is mostly Arabic script.

    Only heading text and plain paragraph lines are inspected; lines that look
    like images, code fences, table rows or list items are ignored. The text
    counts as Arabic when more than half of its alphabetic characters fall in
    the ranges U+0600-U+06FF, U+0750-U+077F or U+08A0-U+08FF.
    """
    if not text:
        return False

    heading_re = re.compile(r"^#{1,6}\s+(.+)$")
    paragraph_re = re.compile(
        r"^(?!#{1,6}\s|!\[|```|\||\s*[-*+]\s|\s*\d+\.\s)(.+)$"
    )

    kept: List[str] = []
    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        heading = heading_re.match(stripped)
        if heading is not None:
            kept.append(heading.group(1))
        elif paragraph_re.match(stripped):
            kept.append(stripped)
    if not kept:
        return False

    letters = [ch for ch in " ".join(kept) if ch.isalpha()]
    if not letters:
        return False
    arabic_letters = sum(
        1
        for ch in letters
        if "\u0600" <= ch <= "\u06ff"
        or "\u0750" <= ch <= "\u077f"
        or "\u08a0" <= ch <= "\u08ff"
    )
    return (arabic_letters / len(letters)) > 0.5
def extract_json(text: str) -> Optional[Any]:
    """Parse JSON out of ``text``, tolerating surrounding prose or code fences.

    Candidates are tried in this order and the first one that parses wins:

    1. the whole text;
    2. the span from the first ``{`` to the last ``}``;
    3. the contents of each ```` ```json ```` fenced block;
    4. the span from the first ``[`` to the last ``]`` (top-level arrays
       embedded in prose).

    Returns:
        The decoded value (whatever ``json.loads`` yields, so not necessarily
        a dict), or ``None`` when ``text`` is empty or nothing parses.
    """
    if not text:
        return None

    candidates: List[str] = [text]

    brace_start = text.find("{")
    brace_end = text.rfind("}")
    if 0 <= brace_start < brace_end:
        candidates.append(text[brace_start : brace_end + 1])

    candidates.extend(re.findall(r"```json\s*([\s\S]*?)\s*```", text))

    # Tried last so every input that parsed before still yields the same value.
    bracket_start = text.find("[")
    bracket_end = text.rfind("]")
    if 0 <= bracket_start < bracket_end:
        candidates.append(text[bracket_start : bracket_end + 1])

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; try the next candidate.
            continue
    return None
def layoutjson2md(
    image: Image.Image, layout_data: List[Dict], text_key: str = "text"
) -> str:
    """Convert layout items to markdown, ordered top-to-bottom then left-to-right.

    Args:
        image: Unused; kept so the call signature stays the same.
        layout_data: Layout items; ``bbox``, ``category`` and ``text_key`` are
            read from each dict. Non-dict entries are skipped.
        text_key: Key holding an item's text.

    Category mapping: ``Title`` -> ``#``, ``Section-header`` -> ``##``,
    ``List-item`` -> ``-``, ``Caption`` -> italics, ``Table`` is emitted
    verbatim when it starts with ``<`` and prefixed with ``**Table:**``
    otherwise, ``Formula`` is wrapped in ``$$`` when it looks like LaTeX.
    ``Page-header``, ``Page-footer`` and ``Picture`` items are dropped. Any
    other category is emitted as plain text.

    Returns:
        The markdown string. If ``layout_data`` cannot be processed at all, its
        JSON dump is returned instead.
    """

    def _reading_order(item: Dict) -> Tuple[float, float]:
        # Items without a usable bbox sort as (0, 0) instead of aborting the
        # whole conversion.
        bbox = item.get("bbox")
        try:
            return (float(bbox[1]), float(bbox[0]))
        except (TypeError, ValueError, IndexError, KeyError):
            return (0.0, 0.0)

    lines: List[str] = []
    try:
        items = sorted(
            (entry for entry in layout_data if isinstance(entry, dict)),
            key=_reading_order,
        )
        for item in items:
            category = item.get("category", "")
            text = item.get(text_key, "")
            if not isinstance(text, str):
                # Coerce stray numbers etc. so the string checks below are safe.
                text = str(text) if text else ""

            if category == "Title" and text:
                lines.append(f"# {text}\n")
            elif category == "Section-header" and text:
                lines.append(f"## {text}\n")
            elif category == "List-item" and text:
                lines.append(f"- {text}\n")
            elif category == "Table" and text:
                if text.strip().startswith("<"):
                    lines.append(text + "\n")
                else:
                    lines.append(f"**Table:** {text}\n")
            elif category == "Formula" and text:
                if text.strip().startswith("$") or "\\" in text:
                    lines.append(f"$$\n{text}\n$$\n")
                else:
                    lines.append(f"**Formula:** {text}\n")
            elif category == "Caption" and text:
                lines.append(f"*{text}*\n")
            elif category in ["Page-header", "Page-footer"]:
                continue
            elif category == "Picture":
                # Skip embedding image fragments in markdown for now
                continue
            elif text:
                lines.append(f"{text}\n")
            # Blank separator after every item that was not skipped above.
            lines.append("")
    except (TypeError, ValueError, AttributeError):
        return json.dumps(layout_data, ensure_ascii=False)
    return "\n".join(lines)
# =====================
# Model initialization
# =====================
# Populated on first use by ensure_model_loaded(); None until then.
model: Optional[AutoModelForCausalLM] = None
processor: Optional[AutoProcessor] = None

# Device chosen once at import time: CUDA if available, else MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
def get_torch_dtype() -> torch.dtype:
    """Return the dtype used to load the model on the selected ``device``.

    ``bfloat16`` on CUDA, ``float16`` on MPS, ``float32`` for anything else.
    """
    dtype_by_device = {
        "cuda": torch.bfloat16,
        "mps": torch.float16,
    }
    return dtype_by_device.get(device, torch.float32)
def ensure_model_loaded() -> Tuple[AutoModelForCausalLM, AutoProcessor]:
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    On the first call the ``DOTS_REPO_ID`` snapshot is downloaded into
    ``DOTS_LOCAL_DIR`` and both objects are loaded from that directory with
    ``trust_remote_code=True``. Later calls return the module-level instances.
    """
    global model, processor

    already_loaded = model is not None and processor is not None
    if not already_loaded:
        os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
        snapshot_download(
            repo_id=DOTS_REPO_ID,
            local_dir=DOTS_LOCAL_DIR,
            local_dir_use_symlinks=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            DOTS_LOCAL_DIR,
            torch_dtype=get_torch_dtype(),
            device_map="auto",
            trust_remote_code=True,
        )
        processor = AutoProcessor.from_pretrained(
            DOTS_LOCAL_DIR, trust_remote_code=True
        )
    return model, processor
def run_inference(
    image: Image.Image, prompt_text: str, max_new_tokens: int = 24000
) -> str:
    """Run the model on one image and prompt, and return the generated text.

    Args:
        image: Page image, sent as the image part of a single user message.
        prompt_text: Instruction text, sent as the text part of that message.
        max_new_tokens: Generation length limit.

    Returns:
        The decoded completion with the prompt tokens stripped, or ``""`` if
        decoding produced nothing.
    """
    mdl, proc = ensure_model_loaded()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = proc.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = proc(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    with torch.no_grad():
        # Greedy decoding. No temperature is passed: it has no effect when
        # do_sample=False and only triggers a generation-config warning.
        generated_ids = mdl.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )
    # generate() returns prompt + completion; drop the prompt part per sample.
    trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    # Decode with the processor returned above, not the module-level global.
    output_text = proc.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0] if output_text else ""
def process_single_image(
    image: Image.Image,
    prompt_text: str,
    min_pixels: Optional[int],
    max_pixels: Optional[int],
    max_new_tokens: int,
) -> Dict[str, Any]:
    """Resize one page, run the model on it and post-process the output.

    Args:
        image: Page image.
        prompt_text: Prompt sent with the image.
        min_pixels: Lower pixel bound for resizing, or ``None``.
        max_pixels: Upper pixel bound for resizing, or ``None``.
        max_new_tokens: Generation length limit.

    Returns:
        A dict with these keys:

        * ``original_image``: the resized input image.
        * ``processed_image``: the same image, with layout boxes drawn when
          layout items were found.
        * ``raw_output``: the model's text.
        * ``layout_result``: the parsed JSON, or ``None``.
        * ``markdown``: markdown built from the layout items, or the raw
          output when no layout could be parsed.
    """
    img = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
    raw = run_inference(img, prompt_text, max_new_tokens=max_new_tokens)
    result: Dict[str, Any] = {
        "original_image": img,
        "processed_image": img,
        "raw_output": raw,
        "layout_result": None,
        "markdown": None,
    }

    data = extract_json(raw)
    items: Any = None
    if isinstance(data, dict):
        result["layout_result"] = data
        items = data.get("elements", data.get("elements_list", data.get("content", [])))
    elif (
        isinstance(data, list)
        and data
        and all(isinstance(entry, dict) for entry in data)
    ):
        # The model may answer with a bare array of layout items rather than
        # an object wrapping them; treat the array itself as the item list.
        result["layout_result"] = data
        items = data

    if isinstance(items, list):
        result["processed_image"] = draw_layout_on_image(img, items)
        result["markdown"] = layoutjson2md(img, items)
    if result["markdown"] is None:
        result["markdown"] = raw
    return result
# =================
# Gradio Interface
# =================
def create_blocks_app():
    """Build the Gradio Blocks UI for uploading, previewing and parsing documents.

    The layout has three columns: upload and prompt controls, a page preview
    with navigation, and result tabs (rendered markdown, raw markdown, JSON,
    annotated image) plus two download buttons. All event handlers are closures
    defined in this function and wired at its end.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    css = """
.main-container { max-width: 1500px; margin: 0 auto; }
.header-text { text-align: center; color: #1f2937; margin-bottom: 12px; }
.page-info { text-align: center; padding: 8px 16px; border-radius: 20px; font-weight: 600; }
.process-button { border: none !important; color: white !important; font-weight: 700 !important; }
"""
    no_file_html = '<div class="page-info">No file</div>'

    def _empty_state() -> Dict[str, Any]:
        """Return a fresh document state meaning "nothing loaded"."""
        return {
            "images": [],
            "current_page": 0,
            "total_pages": 0,
            "file_type": None,
            "checksum": None,
            "results": [],
            "parsed": False,
        }

    def _page_info_html(page_idx: int, total: int) -> str:
        """Return the "Page i / n" badge for a zero-based page index."""
        return f'<div class="page-info">Page {page_idx + 1} / {total}</div>'

    def _page_outputs(state: Dict[str, Any], placeholder: str):
        """Build the 7-tuple of UI outputs for the state's current page."""
        idx = state["current_page"]
        results = state.get("results") or []
        result = (
            results[idx] if state.get("parsed") and idx < len(results) else None
        )
        md = result.get("markdown") if result else placeholder
        # Always send rtl explicitly so a left-to-right page shown after an
        # Arabic one resets the direction.
        md_out = gr.update(value=md, rtl=is_arabic_text(md))
        proc_img = result.get("processed_image") if result else None
        js = result.get("layout_result") if result else None
        info = _page_info_html(idx, state["total_pages"])
        return state, state["images"][idx], info, md_out, md, proc_img, js

    def _select_pages(selection: Optional[str], current: int, total: int) -> List[int]:
        """Translate the page-selection text into sorted zero-based indices.

        Blank means the current page, ``all`` means every page, anything else
        is a comma-separated list of 1-based pages and ``a-b`` ranges.
        Unparseable parts are ignored.
        """
        wanted = (selection or "").strip()
        if not wanted:
            picked = [current]
        elif wanted.lower() == "all":
            picked = list(range(total))
        else:
            picked = []
            for part in (p.strip() for p in wanted.split(",")):
                if not part:
                    continue
                try:
                    if "-" in part:
                        first, last = part.split("-", 1)
                        start = max(1, int(first))
                        stop = min(total, int(last))
                        picked.extend(range(start - 1, stop))
                    else:
                        picked.append(max(1, min(total, int(part))) - 1)
                except ValueError:
                    continue
        return sorted({i for i in picked if 0 <= i < total})

    def _write_export(filename: str, content: str):
        """Write ``content`` under TMP_DIR and point the download button at it."""
        # NOTE(review): the file name is fixed, so concurrent exports overwrite
        # each other's file.
        out_path = os.path.join(TMP_DIR, filename)
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(content)
        return gr.update(value=out_path)

    with gr.Blocks(theme=gr.themes.Soft(), css=css, title=APP_TITLE) as demo:
        # App state
        doc_state = gr.State(_empty_state())
        cache_state = gr.State({})  # (checksum, page, prompt_hash, ...) -> result
        gr.HTML(
            """
<div class=\"header-text\">
<h2>VLM Playground — dots.ocr</h2>
<p>Upload a PDF or image, preview pages, and parse with a layout-extraction prompt.</p>
</div>
"""
        )
        with gr.Row(elem_classes=["main-container"]):
            # Left: upload + controls
            with gr.Column(scale=4):
                file_input = gr.File(
                    label="Upload PDF or Image",
                    file_types=[
                        ".pdf",
                        ".png",
                        ".jpg",
                        ".jpeg",
                        ".bmp",
                        ".tiff",
                        ".webp",
                    ],
                    type="filepath",
                )
                with gr.Group():
                    template = gr.Dropdown(
                        label="Prompt Template",
                        choices=["Layout Extraction"],
                        value="Layout Extraction",
                    )
                    prompt_text = gr.Textbox(
                        label="Current Prompt",
                        value=DEFAULT_PROMPT,
                        lines=6,
                    )
                with gr.Row():
                    parse_button = gr.Button(
                        "Parse", variant="primary", elem_classes=["process-button"]
                    )
                    clear_button = gr.Button("Clear")
                with gr.Accordion("Advanced", open=False):
                    max_new_tokens = gr.Slider(
                        minimum=512,
                        maximum=32000,
                        value=24000,
                        step=256,
                        label="Max new tokens",
                    )
                    min_pixels_in = gr.Number(value=MIN_PIXELS, label="Min pixels")
                    max_pixels_in = gr.Number(value=MAX_PIXELS, label="Max pixels")
                    page_range = gr.Textbox(
                        label="Page selection",
                        placeholder="e.g., 1-3,5 (blank = current page, 'all' = all pages)",
                    )
            # Center: page preview + nav
            with gr.Column(scale=5):
                preview_image = gr.Image(label="Page Preview", type="pil", height=520)
                with gr.Row():
                    prev_btn = gr.Button("◀ Prev")
                    page_info = gr.HTML(no_file_html)
                    next_btn = gr.Button("Next ▶")
                with gr.Row():
                    page_jump = gr.Number(value=1, label="Page #", precision=0)
                    jump_btn = gr.Button("Go")
            # Right: results
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.Tab("Markdown Render"):
                        md_render = gr.Markdown(
                            value="Upload and parse to view results", height=520
                        )
                    with gr.Tab("Raw Markdown"):
                        md_raw = gr.Textbox(value="", lines=20)
                    with gr.Tab("Current Page JSON"):
                        json_view = gr.JSON(value=None)
                    with gr.Tab("Processed Image"):
                        processed_view = gr.Image(type="pil", height=520)
                with gr.Row():
                    download_jsonl = gr.DownloadButton(label="Download JSONL")
                    download_markdown = gr.DownloadButton(label="Download Markdown")

        # ===== Handlers =====
        def on_template_change(choice: str) -> str:
            """Return the prompt for the chosen template (only one exists)."""
            return DEFAULT_PROMPT

        def on_file_change(path: Optional[str]):
            """Load the uploaded file and reset the document state.

            Returns ``(state, first page image, page-info HTML)``.
            """
            if not path or not os.path.exists(path):
                return _empty_state(), None, no_file_html
            checksum = file_checksum(path)
            ext = os.path.splitext(path)[1].lower()
            if ext == ".pdf":
                images = load_images_from_pdf(path)
                file_type = "pdf"
            else:
                # Close the file handle once the pixels are converted.
                with Image.open(path) as source:
                    images = [source.convert("RGB")]
                file_type = "image"
            state = {
                "images": images,
                "current_page": 0,
                "total_pages": len(images),
                "file_type": file_type,
                "checksum": checksum,
                "results": [None] * len(images),
                "parsed": False,
            }
            return (
                state,
                images[0] if images else None,
                f'<div class="page-info">Page 1 / {len(images)}</div>',
            )

        def nav_page(state: Dict[str, Any], direction: str):
            """Move one page ``"prev"``/``"next"`` (anything else stays) and refresh."""
            if not state.get("images"):
                return state, None, no_file_html, "No results", "", None, None
            if direction == "prev":
                state["current_page"] = max(0, state["current_page"] - 1)
            elif direction == "next":
                state["current_page"] = min(
                    state["total_pages"] - 1, state["current_page"] + 1
                )
            return _page_outputs(state, "Page not processed yet")

        def jump_to_page(state: Dict[str, Any], page_num: Any):
            """Go to the 1-based ``page_num``, clamped to the valid range."""
            if not state.get("images"):
                return state, None, no_file_html, "No results", "", None, None
            try:
                n = int(page_num)
            except (TypeError, ValueError):
                n = 1
            n = max(1, min(state["total_pages"], n))
            state["current_page"] = n - 1
            return nav_page(state, direction="stay")

        def parse_pages(
            state: Dict[str, Any],
            prompt: str,
            max_tokens: int,
            min_pix: Optional[float],
            max_pix: Optional[float],
            selection: Optional[str],
        ):
            """Run the model on the selected pages and show the current page."""
            if not state.get("images"):
                return state, None, "No file", "No content", "", None, None

            total = state["total_pages"]
            indices = _select_pages(selection, state["current_page"], total)

            # NOTE(review): the cache is read and written through the State
            # component's ``.value`` rather than being passed as an event
            # input/output. Presumably this makes it shared by every session
            # and never evicted; confirm that is intended.
            cache = cache_state.value
            # The prompt is the same for every page, so hash it once.
            prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]

            # Process sequentially for stability
            results = state.get("results") or [None] * total
            for i in indices:
                cache_key = (
                    state["checksum"],
                    i,
                    prompt_hash,
                    int(min_pix or 0),
                    int(max_pix or 0),
                    int(max_tokens),
                )
                res = cache.get(cache_key)
                if res is None:
                    res = process_single_image(
                        state["images"][i],
                        prompt_text=prompt,
                        min_pixels=int(min_pix) if min_pix else None,
                        max_pixels=int(max_pix) if max_pix else None,
                        max_new_tokens=int(max_tokens),
                    )
                    cache[cache_key] = res
                results[i] = res
            state["results"] = results
            state["parsed"] = True
            # Return current page outputs
            return _page_outputs(state, "No content")

        def clear_all():
            """Reset the state and every output widget to its initial value."""
            gc.collect()
            return (
                _empty_state(),
                None,
                no_file_html,
                "Upload and parse to view results",
                "",
                None,
                None,
            )

        def download_current_jsonl(state: Dict[str, Any]):
            """Export one JSON line per parsed page that has a layout result."""
            if not state.get("parsed"):
                return gr.update(value=None)
            lines = [
                json.dumps(
                    {"page": i + 1, "layout": res["layout_result"]},
                    ensure_ascii=False,
                )
                for i, res in enumerate(state.get("results", []))
                if res and res.get("layout_result") is not None
            ]
            # gr.update replaces gr.DownloadButton.update, which is not
            # available in Gradio releases that provide DownloadButton.
            return _write_export("results.jsonl", "\n".join(lines))

        def download_current_markdown(state: Dict[str, Any]):
            """Export the markdown of all parsed pages, separated by rules."""
            if not state.get("parsed"):
                return gr.update(value=None)
            chunks = [
                f"## Page {i + 1}\n\n{res['markdown']}"
                for i, res in enumerate(state.get("results", []))
                if res and res.get("markdown")
            ]
            return _write_export("results.md", "\n\n---\n\n".join(chunks))

        # Wire events
        page_outputs = [
            doc_state,
            preview_image,
            page_info,
            md_render,
            md_raw,
            processed_view,
            json_view,
        ]
        template.change(on_template_change, inputs=[template], outputs=[prompt_text])
        file_input.change(
            on_file_change,
            inputs=[file_input],
            outputs=[doc_state, preview_image, page_info],
        )
        prev_btn.click(
            lambda s: nav_page(s, "prev"),
            inputs=[doc_state],
            outputs=page_outputs,
        )
        next_btn.click(
            lambda s: nav_page(s, "next"),
            inputs=[doc_state],
            outputs=page_outputs,
        )
        jump_btn.click(
            jump_to_page,
            inputs=[doc_state, page_jump],
            outputs=page_outputs,
        )
        parse_button.click(
            parse_pages,
            inputs=[
                doc_state,
                prompt_text,
                max_new_tokens,
                min_pixels_in,
                max_pixels_in,
                page_range,
            ],
            outputs=page_outputs,
        )
        clear_button.click(clear_all, outputs=page_outputs)
        download_jsonl.click(
            download_current_jsonl, inputs=[doc_state], outputs=[download_jsonl]
        )
        download_markdown.click(
            download_current_markdown, inputs=[doc_state], outputs=[download_markdown]
        )
    return demo