# DataFlow-VQA / app.py
# aaron1141's picture
# feat: embed extracted images as base64 data URIs in preview cards
# 09fc1b5
import base64
import html
import json
import os
import re
import shutil
import sys
import tempfile
import traceback
import zipfile

import gradio as gr
# Maps a lowercase file extension to the subtype used in an "image/<subtype>" data URI.
_MIME = {"png": "png", "jpg": "jpeg", "jpeg": "jpeg", "gif": "gif", "webp": "webp"}

try:
    import markdown as _md

    def _md2html(text: str, img_dir: str = None) -> str:
        """Convert Markdown to HTML, inlining referenced images where possible.

        Args:
            text: Markdown source.
            img_dir: Directory searched for image files. Only the base name of
                each ``<img src>`` is used, so ``vqa_images/x.png`` resolves to
                ``img_dir/x.png``.

        Returns:
            HTML in which every ``<img>`` tag is replaced either by an ``<img>``
            carrying a base64 data URI (file readable in ``img_dir``) or by a
            grey badge placeholder showing the alt text (no ``img_dir``, or the
            file cannot be read).
        """
        rendered = _md.markdown(text, extensions=["nl2br", "tables"])

        def _img_handler(m):
            tag = m.group(0)
            src_m = re.search(r'src="([^"]*)"', tag)
            alt_m = re.search(r'alt="([^"]*)"', tag)
            src = src_m.group(1) if src_m else ""
            # An empty alt would give a label-less badge; fall back to "image".
            alt = (alt_m.group(1) if alt_m else "") or "image"
            if img_dir and src:
                # The attribute value is HTML-escaped (e.g. "&amp;"); undo that
                # before using it as a file name.
                img_name = os.path.basename(html.unescape(src))
                img_path = os.path.join(img_dir, img_name)
                try:
                    with open(img_path, "rb") as f:
                        b64 = base64.b64encode(f.read()).decode()
                except OSError:
                    # Missing file, directory, permission problem: show the badge
                    # instead of aborting the whole preview.
                    b64 = None
                if b64 is not None:
                    ext = img_name.rsplit(".", 1)[-1].lower()
                    mime = _MIME.get(ext, "png")
                    return (
                        f'<img src="data:image/{mime};base64,{b64}" alt="{alt}" '
                        f'style="max-width:100%;border-radius:6px;margin:6px 0;display:block;">'
                    )
            # Fallback badge when image file is not found
            return (
                '<span style="display:inline-flex;align-items:center;gap:4px;'
                'background:#f3f4f6;border:1px solid #d1d5db;border-radius:4px;'
                f'padding:1px 7px;font-size:12px;color:#6b7280">📷 {alt}</span>'
            )

        return re.sub(r'<img\b[^>]*/?>', _img_handler, rendered)

except ImportError:

    def _md2html(text: str, img_dir: str = None) -> str:
        """Plain-text fallback used when the ``markdown`` package is unavailable.

        The text is HTML-escaped (so ``<``, ``>`` and ``&`` display literally
        instead of being parsed as markup) and newlines become ``<br>``.
        ``img_dir`` is accepted for signature compatibility and ignored.
        """
        return html.escape(text, quote=False).replace("\n", "<br>")
# Absolute path of the directory that contains this file.
_REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
# Put that directory first on the import path; presumably this is what lets the
# later `from pipelines... import ...` resolve — confirm `pipelines/` lives here.
sys.path.insert(0, _REPO_ROOT)
# Log the installed Gradio version once at import time (flushed immediately).
print(f"[startup] Gradio: {gr.__version__}", flush=True)
# ── i18n ──────────────────────────────────────────────────────────────────────
# UI strings keyed by language code, then by widget/message identifier.
# Both language tables expose the same set of keys.
T = {
    "zh": {
        "lang_btn": "English",
        "subtitle": "多模态知识提取 Pipeline Demo",
        "desc": (
            "上传教材或试卷 PDF,用 [MinerU](https://mineru.net) 解析版面、再用 LLM 提取结构化 QA 对,"
            "输出 `raw_vqa.jsonl`。\n\n"
            "**流程:** PDF 上传 → MinerU 解析 → LLM 提取 QA → 下载结果\n\n"
            "> 所有 API 调用均通过您提供的密钥完成,本 Space 不存储任何数据或密钥。"
        ),
        "sec_upload": "📄 上传 PDF",
        "upload_label": "PDF 文件(单文件:题答混排;双文件:第1个题目,第2个答案)",
        "task_label": "任务名称",
        "sec_examples": "📋 内置示例 PDF(点击加载)",
        "ex1_label": "示例 1:单文件题答混排",
        "ex2_label": "示例 2:双文件(题目 + 答案)",
        "sec_llm": "⚙️ LLM 配置",
        "api_url_label": "API Base URL",
        "llm_key_label": "LLM API Key(DF_API_KEY)",
        "llm_key_ph": "sk-... / AIzaSy...",
        "model_label": "模型名称",
        "model_ph": "gemini-2.5-pro / gpt-4o / deepseek-r1",
        "sec_mineru": "🏗️ MinerU 配置",
        "mineru_key_label": "MinerU API Key(MINERU_API_KEY)",
        "mineru_key_info": "⚠️ 独立于 LLM 的第二个 Key,去 https://mineru.net/apiManage/token 免费申请",
        "workers_label": "并发 Worker 数",
        "run_btn": "▶ 开始提取",
        "stop_btn": "⏹ 中止运行",
        "sec_output": "📤 输出",
        "status_label": "运行状态",
        "status_ph": "点击「开始提取」后进度显示在这里(运行需数分钟,请耐心等待)…",
        "output_label": "下载结果(vqa_output.zip,含 JSONL + 图片)",
        "preview_label": "结果预览",
    },
    "en": {
        "lang_btn": "中文",
        "subtitle": "Multimodal Knowledge Extraction Pipeline Demo",
        "desc": (
            "Upload textbook or exam PDFs. [MinerU](https://mineru.net) parses the layout and an LLM "
            "extracts structured QA pairs, outputting `raw_vqa.jsonl`.\n\n"
            "**Pipeline:** PDF Upload → MinerU Parsing → LLM QA Extraction → Download Results\n\n"
            "> All API calls use your own keys. This Space does not store any data or keys."
        ),
        "sec_upload": "📄 Upload PDF",
        "upload_label": "PDF File(s) — single: Q&A interleaved; two files: 1st questions, 2nd answers",
        "task_label": "Task Name",
        "sec_examples": "📋 Example PDFs (click to load)",
        "ex1_label": "Example 1: Single file (Q&A mixed)",
        "ex2_label": "Example 2: Two files (questions + answers)",
        "sec_llm": "⚙️ LLM Configuration",
        "api_url_label": "API Base URL",
        "llm_key_label": "LLM API Key (DF_API_KEY)",
        "llm_key_ph": "sk-... / AIzaSy...",
        "model_label": "Model Name",
        "model_ph": "gemini-2.5-pro / gpt-4o / deepseek-r1",
        "sec_mineru": "🏗️ MinerU Configuration",
        "mineru_key_label": "MinerU API Key (MINERU_API_KEY)",
        "mineru_key_info": "⚠️ Independent from LLM key. Get yours at https://mineru.net/apiManage/token",
        "workers_label": "Max Workers",
        "run_btn": "▶ Start Extraction",
        "stop_btn": "⏹ Stop",
        "sec_output": "📤 Output",
        "status_label": "Status",
        "status_ph": "Click 'Start Extraction' to begin (may take several minutes)…",
        "output_label": "Download Result (vqa_output.zip — JSONL + images)",
        "preview_label": "Result Preview",
    },
}

# Language used for the first render of the page.
_DEFAULT_LANG = "en"

# Bundled example inputs, as paths relative to the repository root:
# one single-PDF example and one (questions, answers) pair.
EXAMPLES = [
    ("examples/VQA/questionextract_test.pdf",),
    ("examples/VQA/math_question.pdf", "examples/VQA/math_answer.pdf"),
]
# ── Helpers ───────────────────────────────────────────────────────────────────
def _render_preview(jsonl_path: str, lang: str = "en", output_dir: str = None) -> str:
    """Render up to 3 QA items as styled HTML cards with real image rendering.

    Args:
        jsonl_path: Path to a JSONL file holding one QA record (a JSON object)
            per line. Blank lines, unparsable lines and lines whose JSON value
            is not an object are skipped.
        lang: ``"zh"`` selects Chinese labels; any other value selects English.
        output_dir: When given, images for a record are looked up in
            ``output_dir/{name}/vqa_images/``; when ``None`` no image directory
            is passed to the Markdown renderer.

    Returns:
        ``""`` if ``jsonl_path`` is empty or does not exist; a short "no data"
        notice if no usable record was found; otherwise an HTML fragment with
        the cards, an optional "N items total" hint, and a script that loads
        MathJax for ``$…$`` / ``$$…$$`` math.
    """
    if not jsonl_path or not os.path.exists(jsonl_path):
        return ""

    max_cards = 3
    items = []
    total = 0  # number of non-blank lines, used for the "N items total" hint
    # errors="replace": a stray undecodable byte must not abort the preview.
    with open(jsonl_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            total += 1
            if len(items) >= max_cards:
                continue  # keep counting, but stop parsing
            try:
                record = json.loads(line)
            except (ValueError, RecursionError):
                continue  # best effort: ignore malformed lines
            if isinstance(record, dict):
                items.append(record)

    if not items:
        label_empty = "(无 QA 数据)" if lang == "zh" else "(No QA data)"
        return f'<div style="padding:16px;color:#666">{label_empty}</div>'

    label_q = "题目" if lang == "zh" else "Question"
    label_a = "答案" if lang == "zh" else "Answer"
    label_s = "解题过程" if lang == "zh" else "Solution"

    def _as_text(value) -> str:
        # Missing / null fields render as empty text, not the literal "None".
        return "" if value is None else str(value)

    cards = []
    for idx, item in enumerate(items, start=1):
        name = _as_text(item.get("name"))
        # Images live at output_dir/{name}/vqa_images/
        img_dir = os.path.join(output_dir, name, "vqa_images") if output_dir else None
        answer_raw = _as_text(item.get("answer"))
        q_html = _md2html(_as_text(item.get("question")), img_dir)
        a_html = _md2html(answer_raw, img_dir)
        sol_raw = _as_text(item.get("solution"))

        sol_block = ""
        # Show the solution only when it adds something beyond the answer.
        if sol_raw and sol_raw != answer_raw:
            # Truncate solution before converting (avoids cutting mid-tag)
            sol_short = (sol_raw[:400] + "\n\n…") if len(sol_raw) > 400 else sol_raw
            sol_html = _md2html(sol_short, img_dir)
            sol_block = (
                f'<div style="margin-top:12px;padding-top:10px;border-top:1px solid #e5e7eb">'
                f'<span style="font-weight:600;color:#374151">{label_s}:</span>'
                f'<div class="md-body" style="margin-top:6px;font-size:13px;color:#4b5563">{sol_html}</div>'
                f'</div>'
            )

        # The name comes from user/pipeline data, so escape it before embedding.
        name_html = html.escape(name)
        cards.append(f"""
<div style="border:1px solid #e5e7eb;border-radius:12px;padding:18px;margin-bottom:12px;
background:#ffffff;box-shadow:0 1px 4px rgba(0,0,0,.06);">
<div style="font-size:11px;color:#9ca3af;margin-bottom:10px">#{idx} &nbsp;·&nbsp; {name_html}</div>
<div style="margin-bottom:12px">
<span style="font-weight:600;color:#111827">{label_q}:</span>
<div class="md-body" style="margin-top:6px;font-size:14px">{q_html}</div>
</div>
<div style="background:#f0fdf4;border-radius:8px;padding:12px">
<span style="font-weight:600;color:#15803d">{label_a}:</span>
<div class="md-body" style="margin-top:6px;font-size:14px;color:#166534">{a_html}</div>
</div>
{sol_block}
</div>""")

    total_hint = ""
    if total > max_cards:
        more = (
            f"(共 {total} 条,仅展示前 {max_cards} 条)"
            if lang == "zh"
            else f"{total} items total — showing first {max_cards}"
        )
        total_hint = f'<div style="font-size:12px;color:#6b7280;margin-bottom:10px">{more}</div>'

    inner = total_hint + "".join(cards)
    # Wrap in a container that loads MathJax for $…$ / $$…$$ rendering
    # NOTE(review): scripts inserted via innerHTML are normally not executed by
    # browsers — confirm this script actually runs inside gr.HTML.
    return (
        '<div id="vqa-preview" style="background:#f9fafb;border:1px solid #e5e7eb;'
        'border-radius:12px;padding:16px;max-height:580px;overflow-y:auto;">'
        + inner
        + "</div>"
        + """
<script>
(function(){
if(window.MathJax){window.MathJax.typesetPromise&&MathJax.typesetPromise([document.getElementById('vqa-preview')]);return;}
var s=document.createElement('script');
s.src='https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml.js';
s.async=true;
window.MathJax={tex:{inlineMath:[['$','$'],['\\\\(','\\\\)']],displayMath:[['$$','$$'],['\\\\[','\\\\]']]},startup:{ready:function(){MathJax.startup.defaultReady();MathJax.typesetPromise([document.getElementById('vqa-preview')]);}},options:{skipHtmlTags:['script','noscript','style','textarea','pre']}};
document.head.appendChild(s);
})();
</script>"""
    )
# ── Backend (generator → stop button works) ───────────────────────────────────
def run_vqa_extraction(
    pdf_files, task_name, api_url, llm_api_key, mineru_api_key, model_name, max_workers, lang,
):
    """Run the PDF → MinerU → LLM QA-extraction pipeline, yielding progress.

    This is a generator so the caller can stream status text and cancel the
    run between steps. Every yielded value is a 3-tuple
    ``(zip_path_or_None, status_message, preview_html)``; only the final,
    successful yield carries a zip path and a preview.

    Args:
        pdf_files: One uploaded file or a list of them; each entry is either a
            path string or an object with a ``name`` attribute holding the path.
        task_name: Name stored with the input record; blank falls back to
            ``"task1"``.
        api_url: LLM API base URL (a trailing ``/`` is stripped).
        llm_api_key: Exported as ``DF_API_KEY``.
        mineru_api_key: Exported as ``MINERU_API_KEY``.
        model_name: LLM model identifier passed to the pipeline.
        max_workers: Pipeline concurrency (converted with ``int``).
        lang: ``"zh"`` for Chinese messages; any other value gives English.

    Notes:
        ``os.environ`` and the working directory are process-wide, so runs
        executing at the same time share them. The temporary workspace is kept
        after a successful run (the returned zip lives in it) and removed when
        the run fails or is cancelled.
    """

    def _t(en: str, zh: str) -> str:
        # Single place deciding the message language.
        return zh if lang == "zh" else en

    if pdf_files is None or (isinstance(pdf_files, list) and not pdf_files):
        yield None, _t("❌ Please upload a PDF file first.", "❌ 请先上传 PDF 文件。"), ""
        return

    # `or ""` so that a missing value (None) is not turned into the string "None".
    llm_api_key = str(llm_api_key or "").strip()
    if not llm_api_key:
        yield None, _t("❌ Please enter your LLM API Key.", "❌ 请填写 LLM API Key。"), ""
        return

    mineru_api_key = str(mineru_api_key or "").strip()
    if not mineru_api_key:
        yield None, _t(
            "❌ Please enter your MinerU API Key (get it at https://mineru.net/apiManage/token).",
            "❌ 请填写 MinerU API Key(独立于 LLM Key,去 https://mineru.net/apiManage/token 申请)。",
        ), ""
        return

    task_name = str(task_name or "").strip() or "task1"
    os.environ["DF_API_KEY"] = llm_api_key
    os.environ["MINERU_API_KEY"] = mineru_api_key

    workspace = tempfile.mkdtemp(prefix="dataflow_vqa_")
    cache_dir = os.path.join(workspace, "cache")
    os.makedirs(cache_dir, exist_ok=True)
    original_cwd = os.getcwd()
    succeeded = False
    try:
        os.chdir(workspace)

        yield None, _t("⏳ [1/4] Preparing PDF files…", "⏳ [1/4] 整理 PDF 文件…"), ""
        if not isinstance(pdf_files, list):
            pdf_files = [pdf_files]
        pdf_paths = []
        for i, f in enumerate(pdf_files):
            src = f if isinstance(f, str) else (f.name if hasattr(f, "name") else str(f))
            dst = os.path.join(workspace, f"input_{i}.pdf")
            shutil.copy(src, dst)
            pdf_paths.append(dst)

        input_jsonl = os.path.join(workspace, "input.jsonl")
        # Explicit UTF-8: the record is dumped with ensure_ascii=False.
        with open(input_jsonl, "w", encoding="utf-8") as fout:
            entry = {
                "input_pdf_paths": pdf_paths if len(pdf_paths) > 1 else pdf_paths[0],
                "name": task_name,
            }
            fout.write(json.dumps(entry, ensure_ascii=False) + "\n")

        yield None, _t("⏳ [2/4] Loading pipeline module…", "⏳ [2/4] 加载 Pipeline 模块…"), ""
        try:
            from pipelines.vqa_extract_optimized_pipeline import PDF_VQA_extract_optimized_pipeline
        except Exception:
            yield None, f"❌ Failed to import pipeline:\n{traceback.format_exc()}", ""
            return

        try:
            pipeline = PDF_VQA_extract_optimized_pipeline(
                input_file=input_jsonl,
                api_url=str(api_url).rstrip("/"),
                model_name=str(model_name),
                max_workers=int(max_workers),
            )
            pipeline.compile()
        except ValueError as e:
            msg = str(e)
            if "DF_API_KEY" in msg:
                yield None, _t("❌ Failed to read LLM API Key.", "❌ LLM API Key 读取失败。"), ""
            elif "MINERU_API_KEY" in msg:
                yield None, _t("❌ Failed to read MinerU API Key.", "❌ MinerU API Key 读取失败。"), ""
            else:
                yield None, f"❌ {msg}", ""
            return

        yield None, _t(
            "⏳ [3/4] MinerU parsing + LLM QA extraction (may take several minutes)…",
            "⏳ [3/4] MinerU 解析 PDF + LLM 提取 QA(可能需要数分钟)…",
        ), ""
        try:
            pipeline.forward()
        except RuntimeError as e:
            msg = str(e)
            if "no api found" in msg.lower() or "Apply upload urls failed" in msg:
                err = _t(
                    "❌ MinerU API Key invalid or expired. Get a new one at https://mineru.net/apiManage/token\n\n",
                    "❌ MinerU API Key 无效或已过期。请到 https://mineru.net/apiManage/token 重新申请。\n\n",
                ) + msg
            elif "Cannot connect to LLM server" in msg:
                err = _t(
                    "❌ Cannot connect to LLM API. Check Base URL.\n\n",
                    "❌ 无法连接 LLM API,请检查 Base URL。\n\n",
                ) + msg
            else:
                err = f"❌ {msg}"
            yield None, err, ""
            return

        yield None, _t("⏳ [4/4] Collecting output…", "⏳ [4/4] 整理输出结果…"), ""
        # fullmatch: ignore look-alikes such as "vqa_step3.jsonl.bak".
        step_pattern = re.compile(r"vqa_step(\d+)\.jsonl")
        step_matches = [m for m in map(step_pattern.fullmatch, os.listdir(cache_dir)) if m]
        if not step_matches:
            yield None, _t(
                "❌ Pipeline finished but no output file found.",
                "❌ Pipeline 完成但未找到输出文件。",
            ), ""
            return
        # The highest-numbered step file holds the final pipeline output.
        last_step = max(step_matches, key=lambda m: int(m.group(1)))
        max_step_file = os.path.join(cache_dir, last_step.group(0))

        # ── Collect QA pairs & copy per-task image directories ────────────────
        output_dir = os.path.join(workspace, "output")
        os.makedirs(output_dir, exist_ok=True)
        jsonl_path = os.path.join(output_dir, "raw_vqa.jsonl")
        count = 0
        image_dirs_found = 0
        with open(max_step_file, encoding="utf-8") as f_in, \
                open(jsonl_path, "w", encoding="utf-8") as f_out:
            for line in f_in:
                if not line.strip():
                    continue
                data = json.loads(line)
                qa_item = data.get("vqa_pair")
                if not qa_item:
                    continue
                name = data.get("name", task_name)
                out = {"name": name, **qa_item, "image_basedir": "."}
                if not out.get("solution"):
                    out["solution"] = out.get("answer", "")
                f_out.write(json.dumps(out, ensure_ascii=False) + "\n")
                count += 1
                # Copy cache/{name}/ → output/{name}/ (contains vqa_images/)
                src_task_dir = os.path.join(cache_dir, name)
                dst_task_dir = os.path.join(output_dir, name)
                if os.path.isdir(src_task_dir) and not os.path.exists(dst_task_dir):
                    shutil.copytree(src_task_dir, dst_task_dir)
                    image_dirs_found += 1

        # ── Pack into a zip so images + JSONL download together ──────────────
        zip_path = os.path.join(workspace, "vqa_output.zip")
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _dirs, files in os.walk(output_dir):
                for fname in files:
                    full = os.path.join(root, fname)
                    zf.write(full, os.path.relpath(full, output_dir))

        img_note = _t(
            f" ({image_dirs_found} image folder(s) bundled)",
            f"(含 {image_dirs_found} 个图片文件夹)",
        )
        done = _t(
            f"✅ Done! Extracted {count} QA pairs{img_note}. Download the zip to get images + JSONL.",
            f"✅ 完成!共提取 {count} 条 QA 对{img_note}。下载 zip 可获得 JSONL 和图片。",
        )
        preview = _render_preview(jsonl_path, lang, output_dir)
        succeeded = True
        yield zip_path, done, preview
    except Exception:
        yield None, f"❌ Unexpected error:\n{traceback.format_exc()}", ""
    finally:
        os.chdir(original_cwd)
        if not succeeded:
            # Nothing downloadable was produced: do not leave uploaded PDFs and
            # intermediate files behind.
            shutil.rmtree(workspace, ignore_errors=True)
# ── UI ────────────────────────────────────────────────────────────────────────
# Custom stylesheet handed to gr.Blocks(css=...).
# - #title-row / #lang-btn match the elem_id values set on the header row and
#   the language button.
# - .example-btn matches the elem_classes set on the two example buttons.
# - .md-body styles the rendered Markdown inside the preview cards (paragraphs,
#   lists, code, tables).
CSS = """
#title-row { align-items: center; }
#lang-btn { min-width: 90px; }
.example-btn { margin: 4px 0 !important; }
.md-body p { margin: 0 0 6px; }
.md-body ul, .md-body ol { margin: 4px 0 4px 18px; padding: 0; }
.md-body li { margin-bottom: 2px; }
.md-body code { background:#f3f4f6; border-radius:3px; padding:1px 4px; font-size:12px; }
.md-body pre { background:#f3f4f6; border-radius:6px; padding:8px; overflow-x:auto; font-size:12px; }
.md-body table { border-collapse:collapse; width:100%; font-size:13px; }
.md-body th, .md-body td { border:1px solid #e5e7eb; padding:4px 8px; }
.md-body th { background:#f9fafb; }
"""
# NOTE(review): the nesting below is reconstructed from the section comments
# (header row, left input column, right output column) — confirm against the
# intended layout.
_L = _DEFAULT_LANG  # shorthand for initial render
with gr.Blocks(
    title="FlipVQA-Miner",
    theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
    css=CSS,
) as demo:
    # Current UI language code; passed to the handlers and flipped by toggle_lang.
    lang_state = gr.State(_DEFAULT_LANG)
    # ── Header ────────────────────────────────────────────────────────────────
    with gr.Row(elem_id="title-row"):
        with gr.Column(scale=5):
            gr.Markdown("# FlipVQA-Miner: Multimodal Knowledge Extraction")
        with gr.Column(scale=0, min_width=110):
            lang_btn = gr.Button(T[_L]["lang_btn"], elem_id="lang-btn", size="sm")
    subtitle_md = gr.Markdown(T[_L]["subtitle"])
    desc_md = gr.Markdown(T[_L]["desc"])
    # ── Main layout ───────────────────────────────────────────────────────────
    with gr.Row():
        # ── Left column: inputs ───────────────────────────────────────────────
        with gr.Column(scale=1):
            # 1. Upload PDF
            sec_upload_md = gr.Markdown(f"### {T[_L]['sec_upload']}")
            pdf_files = gr.File(
                label=T[_L]["upload_label"],
                file_types=[".pdf"],
                file_count="multiple",
            )
            task_name = gr.Textbox(label=T[_L]["task_label"], value="task1")
            # 2. Example PDFs (between upload and LLM config)
            sec_examples_md = gr.Markdown(f"### {T[_L]['sec_examples']}")
            with gr.Row():
                ex1_btn = gr.Button(T[_L]["ex1_label"], elem_classes="example-btn", scale=1)
                ex2_btn = gr.Button(T[_L]["ex2_label"], elem_classes="example-btn", scale=1)
            # 3. LLM config
            sec_llm_md = gr.Markdown(f"### {T[_L]['sec_llm']}")
            api_url = gr.Textbox(
                label=T[_L]["api_url_label"],
                placeholder="https://api.openai.com/v1",
            )
            llm_api_key = gr.Textbox(
                label=T[_L]["llm_key_label"],
                type="password",
                placeholder=T[_L]["llm_key_ph"],
            )
            model_name = gr.Textbox(
                label=T[_L]["model_label"],
                value="gemini-2.5-pro",
                placeholder=T[_L]["model_ph"],
            )
            # 4. MinerU config
            sec_mineru_md = gr.Markdown(f"### {T[_L]['sec_mineru']}")
            mineru_api_key = gr.Textbox(
                label=T[_L]["mineru_key_label"],
                type="password",
                placeholder="sk2-...",
                info=T[_L]["mineru_key_info"],
            )
            max_workers = gr.Slider(label=T[_L]["workers_label"], minimum=1, maximum=30, value=5, step=1)
            with gr.Row():
                run_btn = gr.Button(T[_L]["run_btn"], variant="primary", scale=4)
                stop_btn = gr.Button(T[_L]["stop_btn"], variant="stop", scale=1)
        # ── Right column: outputs ─────────────────────────────────────────────
        with gr.Column(scale=1):
            sec_output_md = gr.Markdown(f"### {T[_L]['sec_output']}")
            status_box = gr.Textbox(
                label=T[_L]["status_label"],
                interactive=False,
                lines=6,
                placeholder=T[_L]["status_ph"],
            )
            output_file = gr.File(label=T[_L]["output_label"], interactive=False)
            preview_md = gr.Markdown(f"#### {T[_L]['preview_label']}")
            preview_box = gr.HTML(value="")
    # ── Event: Run ────────────────────────────────────────────────────────────
    # run_vqa_extraction is a generator: each yielded (file, status, preview)
    # tuple updates the three output components in this order.
    run_event = run_btn.click(
        fn=run_vqa_extraction,
        inputs=[pdf_files, task_name, api_url, llm_api_key, mineru_api_key,
                model_name, max_workers, lang_state],
        outputs=[output_file, status_box, preview_box],
        api_name="run_vqa_extraction",
    )
    # The stop button runs no function of its own; it only cancels the run event.
    stop_btn.click(fn=None, cancels=[run_event])

    # ── Event: Example buttons ────────────────────────────────────────────────
    def load_example(ex_index):
        """Return the absolute paths of the files in EXAMPLES[ex_index]."""
        return [os.path.join(_REPO_ROOT, p) for p in EXAMPLES[ex_index]]

    ex1_btn.click(fn=lambda: load_example(0), inputs=[], outputs=[pdf_files])
    ex2_btn.click(fn=lambda: load_example(1), inputs=[], outputs=[pdf_files])

    # ── Event: Language toggle ────────────────────────────────────────────────
    def toggle_lang(current_lang):
        """Switch between "zh" and "en" and return the matching UI updates.

        The returned tuple is positional: its first element is the new language
        code and the remaining 22 updates line up one-to-one with the `outputs`
        list of `lang_btn.click` below. Keep both in the same order.
        """
        new = "en" if current_lang == "zh" else "zh"
        t = T[new]
        return (
            new,
            gr.update(value=t["lang_btn"]),
            gr.update(value=t["subtitle"]),
            gr.update(value=t["desc"]),
            gr.update(value=f"### {t['sec_upload']}"),
            gr.update(label=t["upload_label"]),
            gr.update(label=t["task_label"]),
            gr.update(value=f"### {t['sec_examples']}"),
            gr.update(value=t["ex1_label"]),
            gr.update(value=t["ex2_label"]),
            gr.update(value=f"### {t['sec_llm']}"),
            gr.update(label=t["api_url_label"]),
            gr.update(label=t["llm_key_label"], placeholder=t["llm_key_ph"]),
            gr.update(label=t["model_label"], placeholder=t["model_ph"]),
            gr.update(value=f"### {t['sec_mineru']}"),
            gr.update(label=t["mineru_key_label"], info=t["mineru_key_info"]),
            gr.update(label=t["workers_label"]),
            gr.update(value=t["run_btn"]),
            gr.update(value=t["stop_btn"]),
            gr.update(value=f"### {t['sec_output']}"),
            gr.update(label=t["status_label"], placeholder=t["status_ph"]),
            gr.update(label=t["output_label"]),
            gr.update(value=f"#### {t['preview_label']}"),
        )

    lang_btn.click(
        fn=toggle_lang,
        inputs=[lang_state],
        outputs=[
            lang_state, lang_btn,
            subtitle_md, desc_md,
            sec_upload_md, pdf_files, task_name,
            sec_examples_md, ex1_btn, ex2_btn,
            sec_llm_md, api_url, llm_api_key, model_name,
            sec_mineru_md, mineru_api_key, max_workers,
            run_btn, stop_btn,
            sec_output_md, status_box, output_file,
            preview_md,
        ],
    )
if __name__ == "__main__":
    # allowed_paths covers the bundled examples directory, which is where the
    # example buttons' files come from.
    demo.launch(allowed_paths=[os.path.join(_REPO_ROOT, "examples")])