# NOTE(review): the three lines below are Hugging Face web-UI residue (author,
# commit message, commit hash) accidentally pasted into the source; kept as
# comments so the file parses.
# kaier111's picture
# Return None for file outputs on failure to avoid /app path error
# 03546eb verified
#!/usr/bin/env python3
"""Minimal Hugging Face Spaces app for VideoMAE A/B + optional LLM judge."""
from __future__ import annotations
import json
import os
import subprocess
import sys
import tempfile
from typing import List, Optional
import gradio as gr
DEFAULT_JUDGE_MODEL = os.environ.get("JUDGE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
VARIANT_KEYS: List[str] = [
"A_baseline_locked",
"B_replay_priority",
"C_replay_plus_cv_shake_static",
"D_replay_plus_compound_rescue",
]
def _build_summary_md(payload: dict) -> str:
summary = payload.get("summary", {})
lines = [
"### Benchmark Summary",
"",
"| Variant | Strict Hit Rate | Strict Hits | Scored Shots | LLM Mean | LLM Pass Rate | LLM Judged |",
"|---|---:|---:|---:|---:|---:|---:|",
]
for name in VARIANT_KEYS:
v = summary.get(name, {})
lines.append(
(
f"| `{name}` | {v.get('strict_hit_rate', 0)} | {v.get('strict_hits', 0)}"
f" | {v.get('scored_shots', 0)} | {v.get('llm_overall_mean', 0)}"
f" | {v.get('llm_pass_rate', 0)} | {v.get('llm_judged', 0)} |"
)
)
lines.append("")
lines.append("```json")
lines.append(json.dumps(summary, ensure_ascii=False, indent=2))
lines.append("```")
return "\n".join(lines)
def run_eval(
    mode: str,
    hf_token: str,
    enable_llm_judge: bool,
    judge_model: str,
    judge_token: str,
    builtin_cases: str,
    max_shots: int,
    video_path: Optional[str],
    shots_jsonl_path: Optional[str],
    gt_json_path: Optional[str],
    sample_ids: str,
) -> tuple[str, Optional[str], Optional[str], str]:
    """Run the A/B benchmark script in a subprocess and collect its outputs.

    Returns a 4-tuple for the Gradio outputs: (summary markdown or error
    message, report JSON path or None, report CSV path or None, tail of the
    combined stdout/stderr log). File outputs are None on any failure so
    Gradio never receives a nonexistent path.
    """
    # Gradio API can pass `/app` placeholder (a directory) for empty File
    # inputs; treat any directory path as "no file provided".
    if video_path and os.path.isdir(video_path):
        video_path = None
    if shots_jsonl_path and os.path.isdir(shots_jsonl_path):
        shots_jsonl_path = None
    if gt_json_path and os.path.isdir(gt_json_path):
        gt_json_path = None
    # UI token takes precedence; fall back to the Space secret.
    token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "").strip()
    if not token:
        return ("HF token 为空。请在输入框填 `hf_...` 或在 Space Secret 设置 `HF_TOKEN`。", None, None, "")
    # Output files live in a fresh temp dir so concurrent runs never collide.
    tmpdir = tempfile.mkdtemp(prefix="videomae_ab_")
    output_json = os.path.join(tmpdir, "ab_report.json")
    output_csv = os.path.join(tmpdir, "ab_report.csv")
    cmd = [
        sys.executable,
        "run_videomae_ab_test.py",
        "--hf-token",
        token,
        "--output-json",
        output_json,
        "--output-csv",
        output_csv,
        "--max-shots",
        str(max(0, int(max_shots))),  # clamp negatives; slider gives an int-valued number
    ]
    if mode == "builtin":
        if (builtin_cases or "").strip():
            cmd.extend(["--cases", builtin_cases.strip()])
    else:
        # Custom mode requires a video and a ground-truth file; shots are optional.
        if not video_path:
            return ("自定义模式缺少视频文件。", None, None, "")
        if not gt_json_path:
            return ("自定义模式缺少 GT 文件(json)。", None, None, "")
        cmd.extend(["--video", video_path, "--gt-json", gt_json_path, "--case-name", "space_custom"])
        if shots_jsonl_path:
            cmd.extend(["--shots-jsonl", shots_jsonl_path])
    if (sample_ids or "").strip():
        cmd.extend(["--sample-ids", sample_ids.strip()])
    if enable_llm_judge:
        cmd.append("--llm-judge")
        cmd.extend(["--judge-model", (judge_model or DEFAULT_JUDGE_MODEL).strip()])
        # A dedicated judge token is optional; fall back to the Space secret.
        jt = (judge_token or "").strip() or os.environ.get("JUDGE_TOKEN", "").strip()
        if jt:
            cmd.extend(["--judge-token", jt])
    proc = subprocess.run(cmd, capture_output=True, text=True)
    logs = ((proc.stdout or "") + "\n" + (proc.stderr or "")).strip()
    if proc.returncode != 0:
        msg = "运行失败,请检查日志。"
        return (msg, None, None, logs[-12000:])
    # Guard against a zero-exit run that still failed to write a valid report,
    # so the user sees the logs instead of an opaque traceback.
    try:
        with open(output_json, "r", encoding="utf-8") as f:
            payload = json.load(f)
    except (OSError, json.JSONDecodeError):
        return ("运行失败,请检查日志。", None, None, logs[-12000:])
    # The CSV is optional output of the script; never hand Gradio a missing path.
    csv_path = output_csv if os.path.exists(output_csv) else None
    return (_build_summary_md(payload), output_json, csv_path, logs[-12000:])
def run_eval_api(
    mode: str,
    hf_token: str,
    enable_llm_judge: bool,
    judge_model: str,
    judge_token: str,
    builtin_cases: str,
    max_shots: int,
    video_remote_path: str,
    shots_jsonl_remote_path: str,
    gt_json_remote_path: str,
    sample_ids: str,
) -> tuple[str, Optional[str], Optional[str], str]:
    """API entry point taking plain string paths instead of File inputs.

    Blank or whitespace-only path strings are normalized to None before
    delegating to run_eval, so the queued remote call bypasses the File
    preprocessor cleanly.
    """

    def _path_or_none(raw: str) -> Optional[str]:
        # Treat "" / whitespace from the API as "no file provided".
        cleaned = (raw or "").strip()
        return cleaned if cleaned else None

    return run_eval(
        mode=mode,
        hf_token=hf_token,
        enable_llm_judge=enable_llm_judge,
        judge_model=judge_model,
        judge_token=judge_token,
        builtin_cases=builtin_cases,
        max_shots=max_shots,
        video_path=_path_or_none(video_remote_path),
        shots_jsonl_path=_path_or_none(shots_jsonl_remote_path),
        gt_json_path=_path_or_none(gt_json_remote_path),
        sample_ids=sample_ids,
    )
# UI layout and event wiring. Indentation reconstructed; component/call order
# is unchanged from the original.
with gr.Blocks(title="VideoMAE Camera Motion A/B") as demo:
    gr.Markdown(
        "# VideoMAE 运镜 A/B + LLM 评委\n"
        "默认是内置样本快速跑;切到自定义模式可上传你的视频/镜头边界/GT。"
    )
    # Row 1: run mode, builtin-case filter, per-case shot cap.
    with gr.Row():
        mode = gr.Radio(
            choices=[("内置样本", "builtin"), ("自定义上传", "custom")],
            value="builtin",
            label="运行模式",
        )
        builtin_cases = gr.Textbox(
            value="baseus,runner,vertical",
            label="内置 case 过滤",
            info="留空=全部内置样本;逗号分隔",
        )
        max_shots = gr.Slider(0, 20, value=3, step=1, label="每个 case 最大镜头数")
    # Row 2: credentials and LLM-judge configuration.
    with gr.Row():
        hf_token = gr.Textbox(label="HF Token", type="password", placeholder="hf_xxx")
        enable_llm_judge = gr.Checkbox(value=True, label="启用 LLM 评委")
        judge_model = gr.Textbox(value=DEFAULT_JUDGE_MODEL, label="Judge Model")
        judge_token = gr.Textbox(label="Judge Token(可选)", type="password")
    # Row 3: file uploads for custom mode (filepath type feeds run_eval directly).
    with gr.Row():
        video_path = gr.File(label="视频文件", type="filepath")
        shots_jsonl_path = gr.File(label="镜头边界 JSONL(可选)", type="filepath")
        gt_json_path = gr.File(label="GT JSON", type="filepath")
    # API-only string paths to bypass File preprocessor in queued remote calls.
    video_remote_path = gr.Textbox(visible=False)
    shots_jsonl_remote_path = gr.Textbox(visible=False)
    gt_json_remote_path = gr.Textbox(visible=False)
    # Hidden trigger so remote clients can invoke the string-path variant.
    api_run_btn = gr.Button("API_RUN", visible=False)
    sample_ids = gr.Textbox(label="sample ids(可选)", placeholder="如: 1,2,3")
    run_btn = gr.Button("开始评测", variant="primary")
    # Outputs: markdown summary, downloadable report files, raw log tail.
    summary_md = gr.Markdown()
    out_json = gr.File(label="输出 JSON")
    out_csv = gr.File(label="输出 CSV")
    logs = gr.Textbox(label="运行日志", lines=16)
    # UI path: File components feed run_eval.
    run_btn.click(
        fn=run_eval,
        inputs=[
            mode,
            hf_token,
            enable_llm_judge,
            judge_model,
            judge_token,
            builtin_cases,
            max_shots,
            video_path,
            shots_jsonl_path,
            gt_json_path,
            sample_ids,
        ],
        outputs=[summary_md, out_json, out_csv, logs],
    )
    # API path: hidden textboxes feed run_eval_api under a named endpoint.
    api_run_btn.click(
        fn=run_eval_api,
        inputs=[
            mode,
            hf_token,
            enable_llm_judge,
            judge_model,
            judge_token,
            builtin_cases,
            max_shots,
            video_remote_path,
            shots_jsonl_remote_path,
            gt_json_remote_path,
            sample_ids,
        ],
        outputs=[summary_md, out_json, out_csv, logs],
        api_name="run_eval_api",
    )
if __name__ == "__main__":
    # Bind all interfaces so the Space container can route external traffic;
    # port comes from the platform-provided PORT env var (default 7860).
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, show_error=True)