Alexhe101 commited on
Commit
d43b86d
·
verified ·
1 Parent(s): 5ac2070

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +335 -34
src/streamlit_app.py CHANGED
@@ -1,40 +1,341 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import json
3
+ import os
4
+ import random
5
+ import yaml
6
+ import uuid
7
+ from datetime import datetime
8
+ from filelock import FileLock
9
+ from collections import defaultdict
10
+ from huggingface_hub import HfApi, login
11
from huggingface_hub import snapshot_download

# Page configuration is issued before any other Streamlit command: the login
# and download steps below may emit st.warning / st.info / st.error, and
# Streamlit documents set_page_config as the first command of a page.
st.set_page_config(layout="wide", page_title="Video Eval Platform")

# ================= Configuration =================
# Dataset repo that receives the annotation log.
DATASET_REPO_ID = "Alexhe101/video_ranking_results"
# Dataset repo that holds the videos to be evaluated.
DATA_SOURCE_REPO = "Alexhe101/video_eval_data"
# Access token, read from the environment (Space secret).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local storage layout.
DATA_ROOT = "./web_data_new"
JSON_PATH = os.path.join(DATA_ROOT, "dataset.json")
LOG_FILE = "final_eval_log.txt"
LOCK_FILE = "final_eval_log.txt.lock"

# Sampling configuration.
BATCH_SIZE_PER_SCENE = 2
MAX_SCENES = 5

# `api` is always bound so later code can test it against None instead of
# hitting a NameError when no token is configured or the login fails.
api = None
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi()
    except Exception as e:
        st.warning(f"HF Login Failed: {e}")

# Download the evaluation data on first start, when the local directory
# does not exist yet.
if not os.path.exists(DATA_ROOT):
    st.info(f"正在从 Dataset ({DATA_SOURCE_REPO}) 下载评测视频,请稍候...")
    try:
        snapshot_download(
            repo_id=DATA_SOURCE_REPO,
            repo_type="dataset",
            local_dir=DATA_ROOT,
            token=HF_TOKEN,  # needed when the dataset is private
        )
    except Exception as e:
        st.error(f"数据下载失败: {e}")
        st.stop()
    # Kept outside the try so the broad handler above only ever covers the
    # download itself.
    st.success("数据下载完成!")
    st.rerun()  # refresh the page so the freshly downloaded data is loaded
48
 
 
 
49
 
50
+ # 检查目录是否存在,不存在则下载
51
+ # --- 评分标准 ---
52
# Markdown text describing what each physical score from 1 to 5 means.
# The text is user-facing and is passed to st.markdown in the sidebar.
PHYSICAL_RUBRIC = """
### ⚛️ 物理评分标准 (Physical Score)
- **5 (Perfect)**: 物理交互完美,重力、碰撞、接触点真实。
- **4 (Good)**: 物理规律基本正确,轻微瑕疵不影响理解。
- **3 (Fair)**: 有明显漂浮或穿模,但动作逻辑连贯。
- **2 (Poor)**: 严重物理错误(物体瞬移、穿透)。
- **1 (Fail)**: 完全崩坏,不符合物理规律。
"""

# Markdown text describing when a subgoal checkbox may be ticked.
# The text is user-facing and is passed to st.markdown in the sidebar.
TASK_RUBRIC = """
### ✅ 子目标判定标准 (Subgoal Criteria)
勾选某个子目标 (Subgoal) 需同时满足:
1. **动作执行**: 视频中明确展示了该步骤。
2. **物理达标**: 该动作片段的物理质量 **≥ 4 (Good)**。
*(如果动作发生了但穿模严重,请勿勾选)*
"""
68
 
69
+ # ================= 工具函数 =================
70
+
71
@st.cache_data
def load_full_data():
    """Load the evaluation cases from ``JSON_PATH``.

    Returns:
        The parsed JSON content of the dataset file, or an empty list when
        the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    try:
        # Read explicitly as UTF-8 (the same encoding save_log writes with)
        # so the result does not depend on the platform's default locale.
        with open(JSON_PATH, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return []
77
+
78
def get_session_user():
    """Return the anonymous user id of the current Streamlit session.

    The id has the form ``u_`` followed by the first eight hex digits of a
    random UUID4. It is created on first use and stored in
    ``st.session_state['user_id']`` so later calls return the same value.
    """
    state = st.session_state
    if 'user_id' in state:
        return state['user_id']
    # The first eight characters of str(uuid4()) are its first eight hex digits.
    state['user_id'] = f"u_{uuid.uuid4().hex[:8]}"
    return state['user_id']
82
+
83
def parse_yaml_content(yaml_str):
    """Extract the intention and the subgoal list from a YAML snippet.

    Markdown code fences (```yaml ... ```) around the snippet are removed
    before parsing.

    Args:
        yaml_str: YAML text, optionally wrapped in a Markdown code fence.

    Returns:
        A ``(intention, subgoals)`` tuple. ``intention`` falls back to
        ``"Unknown"`` and ``subgoals`` to ``[]`` when the input is not a
        string, is not valid YAML, does not parse to a mapping, or lacks the
        corresponding key. ``subgoals`` is never ``None``.
    """
    fallback = ("Unknown", [])
    if not isinstance(yaml_str, str):
        return fallback

    clean_str = yaml_str.replace("```yaml", "").replace("```", "").strip()
    try:
        data = yaml.safe_load(clean_str)
    except yaml.YAMLError:
        return fallback

    # Empty documents parse to None and scalars/lists have no keys to read.
    if not isinstance(data, dict):
        return fallback

    # Accept both spellings of the key (intention vs intension).
    intent = data.get('intention') or data.get('intension') or 'Unknown'
    # An explicit `subgoals: null` must not leak None to callers that call
    # len() on the result.
    subgoals = data.get('subgoals') or []
    return intent, subgoals
92
+
93
def save_log(record):
    """Append one annotation record to the local log and sync it to the Hub.

    The record is written as a single JSON line to ``LOG_FILE`` while holding
    the file lock. When a token is configured and a Hub client is available,
    the log content captured under that same lock is then uploaded to
    ``DATASET_REPO_ID``.

    Args:
        record: JSON-serialisable dict describing one annotation.

    Returns:
        ``True`` when the record was written to the local log (even if the
        cloud sync failed afterwards), ``False`` when the local write failed.
        Failures are reported through ``st.error`` rather than raised.
    """
    # The Hub client only exists when the module-level login succeeded;
    # look it up defensively instead of assuming the global is bound.
    uploader = globals().get("api") if HF_TOKEN else None
    snapshot = None

    # 1. Save locally.
    lock = FileLock(LOCK_FILE)
    try:
        with lock.acquire(timeout=5):
            with open(LOG_FILE, "a", encoding='utf-8') as f:
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
            if uploader is not None:
                # Capture the file while still holding the lock so the upload
                # never reads a line that another writer is halfway through.
                with open(LOG_FILE, "rb") as f:
                    snapshot = f.read()
    except Exception as e:
        st.error(f"Save/Sync failed: {e}")
        return False

    # 2. Sync to Hugging Face, outside the lock so a slow upload does not
    #    block other writers.
    if snapshot is not None:
        try:
            uploader.upload_file(
                path_or_fileobj=snapshot,
                path_in_repo="final_eval_log.txt",  # file name inside the dataset repo
                repo_id=DATASET_REPO_ID,
                repo_type="dataset",
                commit_message=f"Sync data: {record.get('case_id', 'unknown')}"
            )
            print("Cloud sync success.")
        except Exception as e:
            st.error(f"Save/Sync failed: {e}")

    return True
114
def get_my_batch(all_data):
    """Return this session's batch of cases, sampling it on first use.

    Cases are grouped by scene, where the scene is the part of ``case_id``
    before the first underscore (``"misc"`` when there is no underscore).
    Up to ``MAX_SCENES`` scenes are chosen at random and up to
    ``BATCH_SIZE_PER_SCENE`` cases are drawn from each. The batch is stored
    in ``st.session_state['my_batch']`` and ``current_index`` is reset to 0
    when a new batch is created.

    Args:
        all_data: Iterable of case dicts, each with a ``'case_id'`` string.

    Returns:
        The list of cases assigned to the current session.
    """
    state = st.session_state
    if 'my_batch' in state:
        return state['my_batch']

    # Stratify the cases by scene name.
    by_scene = defaultdict(list)
    for entry in all_data:
        prefix, underscore, _ = entry['case_id'].partition('_')
        by_scene[prefix if underscore else "misc"].append(entry)

    scene_names = list(by_scene)
    random.shuffle(scene_names)

    picked = []
    for name in scene_names[:MAX_SCENES]:
        pool = by_scene[name]
        picked.extend(random.sample(pool, min(len(pool), BATCH_SIZE_PER_SCENE)))

    state['my_batch'] = picked
    state['current_index'] = 0
    return state['my_batch']
135
+
136
+ # ================= 主界面逻辑 =================
137
+
138
# Session bootstrap: who is annotating, what data exists, which batch and
# which position inside it.
user_id = get_session_user()
full_data = load_full_data()
my_batch = get_my_batch(full_data)
curr_idx = st.session_state.get('current_index', 0)
total = len(my_batch)

# --- Sidebar: progress and rubrics ---
with st.sidebar:
    st.title("📹 视频评估系统")
    st.write(f"User: `{user_id}`")

    # Clamp the ratio: st.progress rejects values above 1.0.
    st.progress(min(curr_idx / total, 1.0) if total > 0 else 0)
    st.write(f"当前进度: {curr_idx} / {total}")

    st.divider()
    st.markdown(PHYSICAL_RUBRIC)
    st.divider()
    st.markdown(TASK_RUBRIC)

# --- No data at all ---
# An empty batch means nothing could be sampled (e.g. the dataset file is
# missing). Report that instead of congratulating the user on finishing.
if total == 0:
    st.error("未找到可评测的数据 (No evaluation data found).")
    # Drop the empty batch so a later rerun samples again.
    st.session_state.pop('my_batch', None)
    st.session_state.pop('current_index', None)
    st.stop()

# --- Batch finished ---
if curr_idx >= total:
    st.balloons()
    st.success("🎉 所有任务已完成!")
    if st.button("开始新的一组 (New Batch)"):
        st.session_state.pop('my_batch', None)
        st.session_state.pop('current_index', None)
        st.rerun()
    st.stop()

# --- Current case ---
current_case = my_batch[curr_idx]
c_id = current_case['case_id']
videos = current_case['videos']
yaml_text = current_case['yaml_text']
intention, subgoals = parse_yaml_content(yaml_text)
# Later code calls len() on the subgoals, so make sure it is a sequence.
subgoals = subgoals or []

# Blind test: shuffle the method order once per case and keep it stable
# across reruns of the same case.
if st.session_state.get("curr_case_id_final") != c_id:
    st.session_state["curr_case_id_final"] = c_id
    methods = list(videos.keys())
    random.shuffle(methods)
    st.session_state["curr_methods_order"] = methods

methods_order = st.session_state["curr_methods_order"]
# Only offer as many labels as there are videos, so a label can never point
# past the end of methods_order.
labels = ["A", "B", "C", "D"][:len(methods_order)]

# --- Page header ---
st.subheader(f"📌 Case: {c_id}")
st.markdown(f"**🎯 Goal (Intention):** `{intention}`")

# --- Video display and scoring: first row of columns ---
col1, col2 = st.columns(2, gap="large")
190
+
191
+ # 辅助函数:渲染单个视频块
192
def render_video_block(col, idx):
    """Render one anonymised video with its scoring widgets.

    Reads ``methods_order``, ``labels``, ``videos``, ``subgoals``, ``c_id``
    and ``DATA_ROOT`` from module scope. Widget values are stored in
    ``st.session_state`` under ``score_{c_id}_{method}`` and
    ``sub_{c_id}_{method}_{i}``.

    Args:
        col: Streamlit column (context manager) to render into.
        idx: Position of the video in the shuffled ``methods_order``.

    Returns:
        None. Nothing is rendered when the case has no video (or no label)
        at position ``idx``.
    """
    # A case may provide fewer videos than there are slots on the page.
    if idx >= len(methods_order) or idx >= len(labels):
        return

    method_name = methods_order[idx]
    label = labels[idx]
    video_path = os.path.join(DATA_ROOT, videos[method_name])

    with col:
        st.markdown(f"#### 📺 Video {label}")
        if os.path.exists(video_path):
            st.video(video_path, autoplay=True, loop=True, muted=True)
        else:
            st.warning("Video missing")

        # 1. Physical score (1-5); nothing is preselected.
        st.caption("1. Physical Score (1-5)")
        st.radio(
            f"phy_score_{label}",
            [1, 2, 3, 4, 5],
            index=None,
            horizontal=True,
            key=f"score_{c_id}_{method_name}",
            label_visibility="collapsed"
        )

        # 2. Subgoal checklist.
        st.caption("2. Subgoals & Completion")
        if subgoals:
            st.markdown(
                "<small style='color: #FF4B4B;'>"
                "⚠️ 若动作伴随严重缺陷(如严重穿模、物体幻觉等),请勿勾选。"
                "</small>",
                unsafe_allow_html=True
            )
            # An expander keeps the list compact; it is open by default.
            with st.expander("Subgoals Checklist", expanded=True):
                for i, sg in enumerate(subgoals):
                    # Widget labels must be strings; YAML items may not be.
                    st.checkbox(str(sg), key=f"sub_{c_id}_{method_name}_{i}")
        else:
            st.caption("No subgoals defined in YAML.")
230
+
231
+ # 渲染上半部分 (A, B)
232
# First row (A, B).
render_video_block(col1, 0)
render_video_block(col2, 1)

st.divider()

# Second row (C, D).
col3, col4 = st.columns(2, gap="large")
render_video_block(col3, 2)
render_video_block(col4, 3)

st.divider()

# --- Overall comparison (best & worst) ---
st.markdown("### 🏆 Overall Comparison")
st.markdown("请基于整体质量(物理 + 意图完成度)选出最好和最差的视频。")

# Nothing is preselected (index=None), matching the score radios: a default
# of "A" would bias the annotation and make the "please choose" check below
# unreachable.
bw_col1, bw_col2 = st.columns(2)
with bw_col1:
    best_choice = st.radio(
        "🌟 Best Video", labels, index=None, horizontal=True, key=f"best_{c_id}"
    )
with bw_col2:
    worst_choice = st.radio(
        "💩 Worst Video", labels, index=None, horizontal=True, key=f"worst_{c_id}"
    )

st.write("")
st.divider()

# --- Case error report ---
is_case_error = st.checkbox(
    "🚫 无法标注 (Case Error): 首帧设置不合理或任务无法完成",
    key=f"error_{c_id}",
)

st.write("")

# --- Submit ---
if st.button("🚀 提交 (Submit & Next)", type="primary", use_container_width=True):

    if is_case_error:
        # An error report takes priority: no scores are required.
        final_record = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "user": user_id,
            "case_id": c_id,
            "is_error": True,
            "error_reason": "User reported impossible setting",
            "bws": {
                "order": methods_order
            }
        }
        save_log(final_record)
        st.warning("已标记为异常 Case,正在切换下一个...")
        st.session_state['current_index'] += 1
        st.rerun()

    else:
        # 1. Validate the submission.
        errors = []
        # Maps the anonymous label shown on screen to the real method name;
        # only videos that actually have a label take part.
        label_to_method = dict(zip(labels, methods_order))
        subgoal_list = subgoals or []

        if not best_choice or not worst_choice:
            errors.append("请选择 Best 和 Worst 视频!")
        elif best_choice == worst_choice:
            errors.append("Best 和 Worst 不能是同一个视频!")
        elif best_choice not in label_to_method or worst_choice not in label_to_method:
            errors.append("所选视频不存在,请重新选择!")

        # Collect the per-video answers.
        results = {}
        for label, m in label_to_method.items():
            score = st.session_state.get(f"score_{c_id}_{m}")

            completed_subs = [
                sg
                for i, sg in enumerate(subgoal_list)
                if st.session_state.get(f"sub_{c_id}_{m}_{i}", False)
            ]

            if score is None:
                # Refer to the video by its label only: showing the method
                # name here would break the blind test.
                errors.append(f"请为 Video {label} 打物理分!")

            results[m] = {
                "physical_score": score,
                "completed_subgoals": completed_subs,
                "subgoal_rate": len(completed_subs) / len(subgoal_list) if subgoal_list else 0.0
            }

        if errors:
            for e in errors:
                st.error(e)
        else:
            # 2. Build the record, translating labels back to method names.
            final_record = {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "user": user_id,
                "case_id": c_id,
                "is_error": False,
                "details": results,
                "bws": {
                    "best": label_to_method[best_choice],
                    "worst": label_to_method[worst_choice],
                    "order": methods_order
                }
            }

            # 3. Save.
            save_log(final_record)
            st.success("保存成功!")

            # 4. Move on to the next case.
            st.session_state['current_index'] += 1
            st.rerun()