video_ranking / src /streamlit_app.py
Alexhe101's picture
Update src/streamlit_app.py
d43b86d verified
import streamlit as st
import json
import os
import random
import yaml
import uuid
from datetime import datetime
from filelock import FileLock
from collections import defaultdict
from huggingface_hub import HfApi, login
# Hub dataset repo that receives the synced evaluation log.
DATASET_REPO_ID = "Alexhe101/video_ranking_results"
# Token read from the Space secret / environment; None when not configured.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub client used for syncing results. Always bound so later code never hits an
# undefined name: it stays None when no token is configured or login fails.
api = None
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi(token=HF_TOKEN)
    except Exception as e:
        st.warning(f"HF Login Failed: {e}")
from huggingface_hub import snapshot_download
# Local copy of the evaluation data (downloaded from DATA_SOURCE_REPO when missing).
# DATA_ROOT is deliberately defined only once: JSON_PATH is derived from it, and a
# second assignment could silently point the two at different directories.
DATA_ROOT = "./web_data_new"
JSON_PATH = os.path.join(DATA_ROOT, "dataset.json")  # case manifest read by load_full_data
# Append-only JSON-lines log of submissions, guarded by a sibling lock file.
LOG_FILE = "final_eval_log.txt"
LOCK_FILE = LOG_FILE + ".lock"
# Sampling: each session gets up to BATCH_SIZE_PER_SCENE cases from each of up to
# MAX_SCENES randomly chosen scenes.
BATCH_SIZE_PER_SCENE = 2
MAX_SCENES = 5
# Hub dataset that holds the videos and dataset.json.
DATA_SOURCE_REPO = "Alexhe101/video_eval_data"
# Fetch the evaluation data from the Hub when it is not available locally.
# We key on the manifest (JSON_PATH) rather than the directory: an interrupted
# download can leave DATA_ROOT behind without dataset.json, which would otherwise be
# mistaken for a finished download and surface as an empty batch.
if not os.path.exists(JSON_PATH):
    st.info(f"正在从 Dataset ({DATA_SOURCE_REPO}) 下载评测视频,请稍候...")
    try:
        snapshot_download(
            repo_id=DATA_SOURCE_REPO,
            repo_type="dataset",
            local_dir=DATA_ROOT,
            token=HF_TOKEN,  # needed only if the dataset is private
        )
    except Exception as e:
        st.error(f"数据下载失败: {e}")
        st.stop()
    if not os.path.exists(JSON_PATH):
        # Rerunning would just download again and loop; stop with an explanation.
        st.error(f"数据下载失败: 未在 {DATA_SOURCE_REPO} 中找到 {os.path.basename(JSON_PATH)}")
        st.stop()
    st.success("数据下载完成!")
    st.rerun()  # restart the script so the freshly downloaded data gets loaded
# ================= Page setup =================
# Wide layout so two videos fit side by side in each row.
st.set_page_config(page_title="Video Eval Platform", layout="wide")
# --- Scoring rubrics ---
# Markdown rendered in the sidebar: the 1-5 physical-plausibility scale that
# annotators apply to each video's "Physical Score" radio.
PHYSICAL_RUBRIC = """
### ⚛️ 物理评分标准 (Physical Score)
- **5 (Perfect)**: 物理交互完美,重力、碰撞、接触点真实。
- **4 (Good)**: 物理规律基本正确,轻微瑕疵不影响理解。
- **3 (Fair)**: 有明显漂浮或穿模,但动作逻辑连贯。
- **2 (Poor)**: 严重物理错误(物体瞬移、穿透)。
- **1 (Fail)**: 完全崩坏,不符合物理规律。
"""
# Markdown rendered in the sidebar: when a subgoal checkbox may be ticked. The
# ">= 4 (Good)" threshold ties subgoal credit to the physical scale above.
TASK_RUBRIC = """
### ✅ 子目标判定标准 (Subgoal Criteria)
勾选某个子目标 (Subgoal) 需同时满足:
1. **动作执行**: 视频中明确展示了该步骤。
2. **物理达标**: 该动作片段的物理质量 **≥ 4 (Good)**。
*(如果动作发生了但穿模严重,请勿勾选)*
"""
# ================= 工具函数 =================
@st.cache_data
def _read_dataset(path):
    """Parse the JSON manifest at ``path``; Streamlit caches the result per path."""
    # Explicit UTF-8: the manifest contains non-ASCII (Chinese) text and must not
    # depend on the container's default locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_full_data():
    """Return the parsed case manifest from JSON_PATH, or [] if the file is absent.

    The existence check lives outside the cached helper on purpose: caching the
    whole function would memoise an early "file missing" ``[]`` for the lifetime
    of the process, even after the data has been downloaded.
    """
    if not os.path.exists(JSON_PATH):
        return []
    return _read_dataset(JSON_PATH)
def get_session_user():
    """Return the anonymous annotator id for this session, creating it on first use.

    The id is ``"u_"`` plus the first 8 characters of a random UUID4. It is stored
    in ``st.session_state['user_id']`` so later calls in the same session reuse it.
    """
    state = st.session_state
    if 'user_id' in state:
        return state['user_id']
    state['user_id'] = "u_" + str(uuid.uuid4())[:8]
    return state['user_id']
def parse_yaml_content(yaml_str):
    """Extract ``(intention, subgoals)`` from a case's YAML description.

    Markdown code fences (```yaml / ```) around the YAML are stripped first. Both
    the ``intention`` key and its misspelling ``intension`` are accepted.

    Returns:
        ``(intent, subgoals)``. ``intent`` falls back to ``"Unknown"``.
        ``subgoals`` is always a list: ``[]`` when absent, null or of an
        unexpected type, and a single string is wrapped as a one-item list.
        Input that is not a string, is not valid YAML, or does not parse to a
        mapping yields ``("Unknown", [])``.
    """
    if not isinstance(yaml_str, str):
        return "Unknown", []
    clean_str = yaml_str.replace("```yaml", "").replace("```", "").strip()
    try:
        data = yaml.safe_load(clean_str)
    except (yaml.YAMLError, ValueError):
        # ValueError: PyYAML raises it for malformed scalars such as impossible dates.
        return "Unknown", []
    if not isinstance(data, dict):
        return "Unknown", []
    intent = data.get('intention') or data.get('intension') or 'Unknown'
    raw_subgoals = data.get('subgoals')
    # Normalise so callers can always enumerate / take len(): a null entry
    # (`subgoals:`) would otherwise crash later, and a bare string would be
    # iterated character by character.
    if isinstance(raw_subgoals, list):
        subgoals = raw_subgoals
    elif isinstance(raw_subgoals, str) and raw_subgoals.strip():
        subgoals = [raw_subgoals]
    else:
        subgoals = []
    return intent, subgoals
def save_log(record):
    """Append ``record`` as one JSON line to LOG_FILE and mirror the log to the Hub.

    The local append happens under a file lock (LOCK_FILE, 5 s timeout). When
    HF_TOKEN is set, the whole log is then uploaded to DATASET_REPO_ID as
    ``final_eval_log.txt``. Failures are reported with ``st.error``, not raised.

    Args:
        record: JSON-serialisable dict; ``record['case_id']`` (if present) is used
            in the Hub commit message.

    Returns:
        True if the record reached the local log, False if the local write failed.
        A failed upload still returns True because the record is already on disk.
    """
    try:
        line = json.dumps(record, ensure_ascii=False) + "\n"
        # Append and snapshot under the same lock, so the upload below sees a
        # consistent file even while another session is appending.
        with FileLock(LOCK_FILE).acquire(timeout=5):
            with open(LOG_FILE, "a", encoding='utf-8') as f:
                f.write(line)
            with open(LOG_FILE, "rb") as f:
                snapshot = f.read()
    except Exception as e:
        st.error(f"Save/Sync failed: {e}")
        return False

    if not HF_TOKEN:
        return True
    # NOTE(review): upload_file replaces the remote file with this local copy. If
    # LOG_FILE lives on storage that gets reset (e.g. on a Space restart/rebuild),
    # the first upload afterwards would overwrite earlier results on the Hub.
    # Consider seeding LOG_FILE from the Hub on startup or uploading one file per
    # record. Confirm the storage setup.
    try:
        # Build a client here instead of relying on the module-level `api`, which
        # may not have been created (no token at startup, or login failed).
        HfApi(token=HF_TOKEN).upload_file(
            path_or_fileobj=snapshot,
            path_in_repo="final_eval_log.txt",  # file name inside the dataset repo
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            commit_message=f"Sync data: {record.get('case_id', 'unknown')}",
        )
        print("Cloud sync success.")
    except Exception as e:
        st.error(f"Cloud sync failed (record saved locally): {e}")
    return True
def get_my_batch(all_data):
    """Return this session's evaluation batch, sampling it on first use.

    Cases are grouped by scene. The scene is the ``case_id`` prefix before the
    first ``_``; ids without an underscore fall into ``"misc"``. Up to MAX_SCENES
    scenes are chosen at random, and up to BATCH_SIZE_PER_SCENE cases are sampled
    from each. The batch and ``current_index = 0`` are stored in
    ``st.session_state``, so later calls return the same batch.
    """
    if 'my_batch' in st.session_state:
        return st.session_state['my_batch']

    # Stratify by scene so one large scene cannot dominate a batch.
    by_scene = defaultdict(list)
    for case in all_data:
        head, sep, _ = case['case_id'].partition('_')
        by_scene[head if sep else "misc"].append(case)

    scenes = list(by_scene)
    random.shuffle(scenes)
    picked = []
    for scene in scenes[:MAX_SCENES]:
        pool = by_scene[scene]
        picked.extend(random.sample(pool, min(len(pool), BATCH_SIZE_PER_SCENE)))

    st.session_state['my_batch'] = picked
    st.session_state['current_index'] = 0
    return st.session_state['my_batch']
# ================= Main page =================
# Order matters: get_my_batch initialises current_index on first sampling.
user_id = get_session_user()
full_data = load_full_data()
my_batch = get_my_batch(full_data)
curr_idx = st.session_state.get('current_index', 0)

# --- Sidebar: progress plus both rubrics, kept visible while scoring ---
with st.sidebar:
    total = len(my_batch)
    st.title("📹 视频评估系统")
    st.write(f"User: `{user_id}`")
    fraction_done = 0 if total <= 0 else curr_idx / total
    st.progress(fraction_done)
    st.write(f"当前进度: {curr_idx} / {total}")
    for rubric in (PHYSICAL_RUBRIC, TASK_RUBRIC):
        st.divider()
        st.markdown(rubric)
# --- Empty batch / completion ---
if not my_batch:
    # Nothing was sampled (no cases loaded). Without this branch the `0 >= 0`
    # check below would celebrate an empty batch as "all tasks completed".
    # Drop the cached empty batch so the next rerun samples again.
    st.session_state.pop('my_batch', None)
    st.error("未加载到任何评测数据,请检查 dataset.json。")
    st.stop()
if curr_idx >= len(my_batch):
    st.balloons()
    st.success("🎉 所有任务已完成!")
    if st.button("开始新的一组 (New Batch)"):
        # Clearing both keys makes get_my_batch sample a fresh batch on rerun.
        del st.session_state['my_batch']
        del st.session_state['current_index']
        st.rerun()
    st.stop()
# --- Current case ---
current_case = my_batch[curr_idx]
c_id, videos, yaml_text = (
    current_case['case_id'],
    current_case['videos'],
    current_case['yaml_text'],
)
intention, subgoals = parse_yaml_content(yaml_text)

# Blind test: shuffle the method order once per case and keep it stable across
# reruns, so the A/B/C/D labels do not reveal which method made which video.
order_is_current = (
    "curr_case_id_final" in st.session_state
    and st.session_state["curr_case_id_final"] == c_id
)
if not order_is_current:
    shuffled = list(videos)
    random.shuffle(shuffled)
    st.session_state["curr_case_id_final"] = c_id
    st.session_state["curr_methods_order"] = shuffled
methods_order = st.session_state["curr_methods_order"]
labels = list("ABCD")

# --- Page header ---
st.subheader(f"📌 Case: {c_id}")
st.markdown(f"**🎯 Goal (Intention):** `{intention}`")
# --- Videos & per-video scoring: first row of the grid ---
col1, col2 = st.columns(2, gap="large")
# Helper: render one video tile (player, physical score, subgoal checklist).
def render_video_block(col, idx):
    """Render the video at blind position ``idx`` inside column ``col``.

    Reads the module-level state set up for the current case: ``methods_order``,
    ``labels``, ``videos``, ``c_id`` and ``subgoals``. Nothing is returned. Widget
    values live in ``st.session_state`` under ``score_{c_id}_{method}`` and
    ``sub_{c_id}_{method}_{i}``, where the submit handler reads them back.
    """
    method_name = methods_order[idx]
    label = labels[idx]
    video_path = os.path.join(DATA_ROOT, videos[method_name])
    with col:
        st.markdown(f"#### 📺 Video {label}")
        if os.path.exists(video_path):
            st.video(video_path, autoplay=True, loop=True, muted=True)
        else:
            st.warning("Video missing")

        # 1. Physical score (1-5); index=None so nothing is preselected.
        st.caption("1. Physical Score (1-5)")
        st.radio(
            f"phy_score_{label}",
            [1, 2, 3, 4, 5],
            index=None,
            horizontal=True,
            key=f"score_{c_id}_{method_name}",
            label_visibility="collapsed",
        )

        # 2. Subgoals, listed directly as checkboxes.
        st.caption("2. Subgoals & Completion")
        if subgoals:
            st.markdown(
                "<small style='color: #FF4B4B;'>"
                "⚠️ 若动作伴随严重缺陷(如严重穿模、物体幻觉等),请勿勾选。"
                "</small>",
                unsafe_allow_html=True,
            )
            # Expanded by default; the expander only lets annotators fold long lists.
            with st.expander("Subgoals Checklist", expanded=True):
                for i, sg in enumerate(subgoals):
                    # str(): checkbox labels must be text, and YAML items may not be.
                    st.checkbox(str(sg), key=f"sub_{c_id}_{method_name}_{i}")
        else:
            st.caption("No subgoals defined in YAML.")
# Lay the four videos out as a 2x2 grid: A/B on the first row, C/D on the second.
for column, position in ((col1, 0), (col2, 1)):
    render_video_block(column, position)
st.divider()
col3, col4 = st.columns(2, gap="large")
for column, position in ((col3, 2), (col4, 3)):
    render_video_block(column, position)
st.divider()
# --- 4. Overall comparison (best & worst) ---
st.markdown("### 🏆 Overall Comparison")
st.markdown("请基于整体质量(物理 + 意图完成度)选出最好和最差的视频。")
bw_col1, bw_col2 = st.columns(2)
# index=None: start with nothing selected. A preselected "A" would bias the blind
# comparison towards position A, and the submit-time "please choose Best and
# Worst" check could never fire.
with bw_col1:
    best_choice = st.radio("🌟 Best Video", labels, index=None, horizontal=True, key=f"best_{c_id}")
with bw_col2:
    worst_choice = st.radio("💩 Worst Video", labels, index=None, horizontal=True, key=f"worst_{c_id}")
st.write("")
st.divider()
# --- 5. Case-level error report ---
is_case_error = st.checkbox("🚫 无法标注 (Case Error): 首帧设置不合理或任务无法完成", key=f"error_{c_id}")
st.write("")
# --- Submit ---
if st.button("🚀 提交 (Submit & Next)", type="primary", use_container_width=True):
    if is_case_error:
        # A reported case error takes priority: log a minimal record (keeping the
        # blind order for traceability) and skip per-video validation.
        final_record = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "user": user_id,
            "case_id": c_id,
            "is_error": True,
            "error_reason": "User reported impossible setting",
            "bws": {
                "order": methods_order
            }
        }
        save_log(final_record)
        st.warning("已标记为异常 Case,正在切换下一个...")
        st.session_state['current_index'] += 1
        st.rerun()
    else:
        # 1. Validate the annotation (normal flow).
        errors = []
        if not best_choice or not worst_choice:
            errors.append("请选择 Best 和 Worst 视频!")
        elif best_choice == worst_choice:
            errors.append("Best 和 Worst 不能是同一个视频!")

        # `subgoals` can be None when the YAML has an empty `subgoals:` entry;
        # treat that as "no subgoals" instead of crashing on len(None).
        subgoal_list = subgoals or []
        results = {}
        for pos, m in enumerate(methods_order):
            score = st.session_state.get(f"score_{c_id}_{m}")
            completed_subs = [
                sg
                for i, sg in enumerate(subgoal_list)
                if st.session_state.get(f"sub_{c_id}_{m}_{i}", False)
            ]
            if score is None:
                # Name the video only by its blind label: printing the method name
                # here would reveal which method produced it.
                errors.append(f"请为 Video {labels[pos]} 打物理分!")
            results[m] = {
                "physical_score": score,
                "completed_subgoals": completed_subs,  # subgoals ticked for this method
                "subgoal_rate": len(completed_subs) / len(subgoal_list) if subgoal_list else 0.0,
            }

        if errors:
            for e in errors:
                st.error(e)
        else:
            # 2. Map the blind labels back to real method names and build the record.
            real_best = methods_order[labels.index(best_choice)]
            real_worst = methods_order[labels.index(worst_choice)]
            final_record = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "user": user_id,
                "case_id": c_id,
                "is_error": False,
                "details": results,
                "bws": {
                    "best": real_best,
                    "worst": real_worst,
                    "order": methods_order
                }
            }
            # 3. Persist.
            save_log(final_record)
            st.success("保存成功!")
            # 4. Advance to the next case.
            st.session_state['current_index'] += 1
            st.rerun()