PaperShow / app.py
JaceWei's picture
update: push latest content
0d563bd
raw
history blame
5.83 kB
import datetime
import os
import shutil
import subprocess
import sys
import zipfile
from pathlib import Path

import gradio as gr
# Every path below is absolute, derived from the directory that contains this file.
ROOT = Path(__file__).resolve().parent

# Inputs staged for the pipeline, and the directory whose contents get zipped.
INPUT_DIR = ROOT / "input"
LOGO_DIR = INPUT_DIR / "logo"
OUTPUT_DIR = ROOT / "output"
POSTER_LATEX_DIR = ROOT.joinpath("posterbuilder", "latex_proj")

# Artifacts rewritten on every run: the zipped output and the persisted log.
ZIP_PATH = ROOT.joinpath("output.zip")
LOG_PATH = ROOT.joinpath("last_run.log")
# Command-line flags whose values are secrets and must never be written to logs.
_SECRET_FLAGS = ("--openai_key", "--gemini_key")


def _fail(logs, msg):
    """Append *msg*, persist the log, and return the UI's failure tuple (text, no file)."""
    logs.append(msg)
    _write_logs(logs)
    return "\n".join(logs), None


def _upload_path(upload):
    """Return the on-disk path of a Gradio upload.

    Depending on the Gradio version, uploads arrive either as temp-file objects
    exposing ``.name`` (what the original code relied on) or as plain path
    strings; accept both.
    """
    return Path(getattr(upload, "name", upload))


def _as_upload_list(files):
    """Normalize a Gradio file input (None, a single upload, or a list) to a list of non-empty uploads."""
    if files is None:
        return []
    if not isinstance(files, (list, tuple)):
        files = [files]
    return [f for f in files if f]


def _clear_files(directory):
    """Delete the regular files directly inside *directory*; subdirectories are left alone."""
    for item in directory.iterdir():
        if item.is_file():
            item.unlink()


def _reset_output():
    """Empty OUTPUT_DIR and delete the previous zip archive."""
    for item in OUTPUT_DIR.iterdir():
        # rmtree refuses symlinks, so a symlinked directory is unlinked instead.
        if item.is_dir() and not item.is_symlink():
            shutil.rmtree(item)
        else:
            item.unlink()
    if ZIP_PATH.exists():
        ZIP_PATH.unlink()


def _unique_path(directory, name):
    """Return ``directory / name``, appending ``_1``, ``_2``, ... if that file already exists."""
    candidate = directory / name
    stem, suffix = candidate.stem, candidate.suffix
    counter = 1
    while candidate.exists():
        candidate = directory / f"{stem}_{counter}{suffix}"
        counter += 1
    return candidate


def _save_logos(logo_files):
    """Replace the files in LOGO_DIR with the uploaded logos and return the saved paths."""
    _clear_files(LOGO_DIR)
    saved = []
    for upload in logo_files:
        src = _upload_path(upload)
        # Two uploads sharing a basename must not overwrite each other.
        dest = _unique_path(LOGO_DIR, src.name)
        shutil.copy(src, dest)
        saved.append(dest)
    return saved


def _stage_pdf(pdf_file, logs):
    """Copy the uploaded PDF into INPUT_DIR, or drop a stale copy when none was uploaded.

    Returns:
        The path to pass via ``--pdf_path``, or None when no PDF was uploaded.
    """
    pdf_dir = INPUT_DIR / "pdf"
    canonical_pdf = INPUT_DIR / "paper.pdf"
    if not pdf_file:
        # A paper.pdf left from an earlier upload could otherwise be used by the
        # pipeline's Step 1.5 on an arXiv-only run.
        # NOTE(review): confirm how pipeline.py treats input/paper.pdf.
        if canonical_pdf.exists():
            canonical_pdf.unlink()
            logs.append(f"🧹 Removed stale {canonical_pdf}\n")
        return None

    pdf_dir.mkdir(parents=True, exist_ok=True)
    _clear_files(pdf_dir)  # keep only the current upload
    src = _upload_path(pdf_file)
    pdf_path = pdf_dir / src.name
    shutil.copy(src, pdf_path)
    logs.append(f"📄 Uploaded PDF saved to: {pdf_path}\n")
    # Extra copy at input/paper.pdf for compatibility with the pipeline's Step 1.5.
    shutil.copy(src, canonical_pdf)
    logs.append(f"🔁 Also copied PDF to: {canonical_pdf}\n")
    return pdf_path


def _build_command(arxiv_url, pdf_path, openai_key, gemini_key):
    """Build the argv for pipeline.py.

    Relative paths in the argv are resolved against ROOT, because run_pipeline
    launches the subprocess with ``cwd=ROOT``.
    """
    cmd = [
        # Same interpreter (and virtualenv) as this app, not whatever "python" is on PATH.
        sys.executable, "pipeline.py",
        "--model_name_t", "gpt-5",
        "--model_name_v", "gpt-5",
        "--result_dir", "output",
        "--paper_latex_root", "input/latex_proj",
        "--openai_key", openai_key,
        "--logo_dir", str(LOGO_DIR),
    ]
    if gemini_key:
        cmd += ["--gemini_key", gemini_key]
    if arxiv_url:
        cmd += ["--arxiv_url", arxiv_url]
    if pdf_path:
        cmd += ["--pdf_path", str(pdf_path)]
    return cmd


def _redact(cmd):
    """Return *cmd* joined for display, with the value after each secret flag masked."""
    shown = list(cmd)
    for i, arg in enumerate(shown[:-1]):
        if arg in _SECRET_FLAGS:
            shown[i + 1] = "***"
    return " ".join(shown)


def _zip_output():
    """Write every file under OUTPUT_DIR into ZIP_PATH, with paths relative to OUTPUT_DIR."""
    with zipfile.ZipFile(ZIP_PATH, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(OUTPUT_DIR):
            for file in files:
                file_path = Path(root) / file
                zipf.write(file_path, arcname=file_path.relative_to(OUTPUT_DIR))


def run_pipeline(arxiv_url, pdf_file, openai_key, logo_files):
    """Run pipeline.py on an arXiv URL or an uploaded PDF, then zip its output.

    The Gemini key is read from the ``GEMINI_API_KEY`` environment variable
    rather than being embedded in source.

    Args:
        arxiv_url: arXiv link; may be empty when a PDF is uploaded.
        pdf_file: Uploaded PDF from Gradio, or None.
        openai_key: OpenAI API key, forwarded via ``--openai_key``.
        logo_files: One or more uploaded institution logos (required).

    Returns:
        ``(log_text, zip_path)``: the accumulated log, and ZIP_PATH on success
        or None on any failure. The log is also written to LOG_PATH.
    """
    start_time = datetime.datetime.now()
    logs = [f"🚀 Starting pipeline at {start_time.strftime('%Y-%m-%d %H:%M:%S')}\n"]

    for d in (OUTPUT_DIR, LOGO_DIR, POSTER_LATEX_DIR, INPUT_DIR):
        d.mkdir(parents=True, exist_ok=True)

    # Validate every input before touching any state, so that a rejected
    # request does not wipe the previous run's results.
    logo_files = _as_upload_list(logo_files)
    if not logo_files:
        return _fail(logs, "❌ 必须上传作者所属机构 Logo(可多张)。")
    arxiv_url = (arxiv_url or "").strip()
    if not arxiv_url and not pdf_file:
        return _fail(logs, "❌ 请提供 arXiv 链接或上传 PDF 文件(二选一)。")
    # None would crash the join in _redact; stray whitespace from pasting breaks auth.
    openai_key = (openai_key or "").strip()

    _reset_output()
    logs.append("🧹 Cleaned previous output.\n")

    saved_logo_paths = _save_logos(logo_files)
    logs.append(f"🏷️ Saved {len(saved_logo_paths)} logo file(s) to: {LOGO_DIR}\n")

    pdf_path = _stage_pdf(pdf_file, logs)

    gemini_key = os.environ.get("GEMINI_API_KEY", "").strip()
    if not gemini_key:
        logs.append("⚠️ GEMINI_API_KEY is not set; running without --gemini_key.\n")

    cmd = _build_command(arxiv_url, pdf_path, openai_key, gemini_key)
    logs.append(f"🧠 Running command:\n{_redact(cmd)}\n")

    try:
        result = subprocess.run(
            cmd, cwd=ROOT, capture_output=True, text=True, timeout=1800
        )
    except subprocess.TimeoutExpired:
        return _fail(logs, "❌ Pipeline timed out (30 min limit).")
    except Exception as e:  # e.g. OSError when the interpreter cannot be started
        return _fail(logs, f"❌ Pipeline error: {e}")

    logs.append("\n======= STDOUT =======\n")
    logs.append(result.stdout)
    logs.append("\n======= STDERR =======\n")
    logs.append(result.stderr)
    logs.append(f"\n↩️ Exit code: {result.returncode}\n")

    if not any(OUTPUT_DIR.iterdir()):
        return _fail(logs, "❌ No output generated. Please check the logs above.")

    _zip_output()
    logs.append(f"✅ Zipped output folder to {ZIP_PATH}\n")

    end_time = datetime.datetime.now()
    duration = int((end_time - start_time).total_seconds())
    logs.append(
        f"🏁 Completed at {end_time.strftime('%Y-%m-%d %H:%M:%S')} (Duration: {duration}s)\n"
    )
    _write_logs(logs)
    return "\n".join(logs), ZIP_PATH
def _write_logs(logs):
    """Overwrite LOG_PATH with *logs* joined by newlines, encoded as UTF-8."""
    LOG_PATH.write_text("\n".join(logs), encoding="utf-8")
# ===================== Gradio UI =====================
# Input order must match run_pipeline(arxiv_url, pdf_file, openai_key, logo_files).
_ui_inputs = [
    gr.Textbox(label="📘 ArXiv URL(二选一)", placeholder="https://arxiv.org/abs/2505.xxxxx"),
    gr.File(label="📄 上传 PDF(二选一)"),
    gr.Textbox(label="🔑 OpenAI API Key", placeholder="sk-...", type="password"),
    gr.File(
        label="🏷️ 上传作者所属机构 Logo(必选,可多文件)",
        file_count="multiple",
        file_types=["image"],
    ),
]
# Output order matches run_pipeline's (log_text, zip_path) return tuple.
_ui_outputs = [
    gr.Textbox(label="🧾 Logs", lines=30, max_lines=50),
    gr.File(label="📦 下载生成结果 (.zip)"),
]
_ui_description = (
    "必须上传机构 Logo(可多张)。\n"
    "可输入 arXiv 链接或上传 PDF(二选一),系统将生成 Poster 并打包下载。"
)

iface = gr.Interface(
    fn=run_pipeline,
    inputs=_ui_inputs,
    outputs=_ui_outputs,
    title="📄 PaperShow Pipeline",
    description=_ui_description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Bind to all network interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)