Emilyxml committed on
Commit
ba063cc
·
verified ·
1 Parent(s): 5ad29be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -122
app.py CHANGED
@@ -14,11 +14,10 @@ LOG_FOLDER = Path("logs")
14
  LOG_FOLDER.mkdir(parents=True, exist_ok=True)
15
  TOKEN = os.environ.get("HF_TOKEN")
16
 
17
- # --- 1.5 自动下载数据 (关键修复) ---
18
- # 只有当本地没有数据时才尝试下载
19
  if not os.path.exists(DATA_FOLDER) or not os.listdir(DATA_FOLDER):
20
  try:
21
- print("正在从 Dataset 下载图片数据...")
22
  snapshot_download(
23
  repo_id=DATASET_REPO_ID,
24
  repo_type="dataset",
@@ -26,25 +25,24 @@ if not os.path.exists(DATA_FOLDER) or not os.listdir(DATA_FOLDER):
26
  token=TOKEN,
27
  allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"]
28
  )
29
- print("下载完成!")
30
  except Exception as e:
31
- print(f"下载数据失败 (可能是Token问题或Repo不存在): {e}")
32
 
33
- # --- 2. 启动同步 ---
34
  scheduler = CommitScheduler(
35
  repo_id=DATASET_REPO_ID,
36
  repo_type="dataset",
37
  folder_path=LOG_FOLDER,
38
- path_in_repo="logs", # 建议把日志分文件夹存
39
  every=1,
40
  token=TOKEN
41
  )
42
 
43
- # --- 3. 数据加载 ---
44
  def load_data():
45
  groups = {}
46
  if not os.path.exists(DATA_FOLDER):
47
- os.makedirs(DATA_FOLDER, exist_ok=True)
48
  return {}, []
49
 
50
  for filename in os.listdir(DATA_FOLDER):
@@ -80,145 +78,180 @@ def load_data():
80
 
81
  ALL_GROUPS, ALL_GROUP_IDS = load_data()
82
 
83
- # --- 4. 保存逻辑 ---
84
- def save_user_vote(user_id, group_id, choice_labels, method_names):
85
- user_filename = f"user_{user_id}.csv"
86
- user_file_path = LOG_FOLDER / user_filename
87
- row = [user_id, datetime.now().strftime("%Y-%m-%d %H:%M:%S"), group_id, choice_labels, method_names]
88
-
89
- with scheduler.lock:
90
- file_exists = user_file_path.exists()
91
- with user_file_path.open("a", newline="", encoding="utf-8") as f:
92
- writer = csv.writer(f)
93
- if not file_exists:
94
- writer.writerow(["user_id", "timestamp", "group_id", "selected_labels", "selected_methods"])
95
- writer.writerow(row)
96
- print(f"Saved: {user_id} -> {choice_labels}")
97
-
98
  # --- 5. 核心逻辑 ---
99
 
100
- def init_state():
101
- return {"user_id": str(uuid.uuid4())[:8], "index": 0, "is_finished": False}
102
-
103
- def next_question_data(user_state):
104
  idx = user_state["index"]
 
 
105
  if idx >= len(ALL_GROUP_IDS):
106
- user_state["is_finished"] = True
107
- return user_state, None, [], []
108
-
 
 
 
 
 
 
 
 
 
 
109
  group_id = ALL_GROUP_IDS[idx]
110
  group_data = ALL_GROUPS[group_id]
111
 
 
112
  origin_path = group_data["origin"]
 
 
113
  candidates = group_data["candidates"].copy()
114
  random.shuffle(candidates)
115
 
116
- candidate_info = []
 
 
 
 
117
  for i, path in enumerate(candidates):
118
- label = f"Option {chr(65+i)}"
119
- candidate_info.append({"path": path, "label": label})
120
-
121
- return user_state, origin_path, candidate_info, group_data["instruction"]
122
-
123
- def submit_logic(user_state, current_candidates, selected_indices, is_none=False):
124
- if user_state["is_finished"]:
125
- return user_state, []
 
 
 
 
 
 
 
 
 
 
126
 
 
 
127
  current_idx = user_state["index"]
128
  group_id = ALL_GROUP_IDS[current_idx]
129
-
 
130
  if is_none:
131
- save_user_vote(user_state["user_id"], group_id, "Rejected All", "None_Satisfied")
 
132
  else:
133
- if not selected_indices:
134
- raise gr.Error("请至少选择一张图片,或点击“都不满意”")
135
- labels = []
136
- methods = []
137
- for idx in selected_indices:
138
- info = current_candidates[idx]
139
- labels.append(info["label"])
140
- filename = os.path.basename(info["path"])
141
- name_no_ext = os.path.splitext(filename)[0]
142
- parts = name_no_ext.split('_', 1)
143
- method = parts[1] if len(parts) > 1 else name_no_ext
144
- methods.append(method)
145
 
146
- save_user_vote(user_state["user_id"], group_id, "; ".join(labels), "; ".join(methods))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  user_state["index"] += 1
149
- return user_state, []
 
 
150
 
151
- # --- 6. 界面构建 (Gradio 5.0+) ---
152
  with gr.Blocks(title="User Study") as demo:
153
 
154
- state_main = gr.State(init_state) # 修正:使用 init_state 函数引用
155
- state_origin = gr.State()
156
- state_candidates = gr.State([])
157
- state_instruction = gr.State("")
158
- state_selection = gr.State([])
159
-
160
- # @gr.render 需要 Gradio 5.0+
161
- @gr.render(inputs=[state_main, state_origin, state_candidates, state_selection, state_instruction])
162
- def render_content(main_st, origin, candidates, selection, instruction):
163
- if main_st["is_finished"]:
164
- gr.Markdown("## 🎉 测试结束!\n感谢您的参与,所有结果已保存。")
165
- return
166
-
167
- idx = main_st["index"]
168
- total = len(ALL_GROUP_IDS)
169
- gr.Markdown(f"### 任务 ({idx + 1} / {total if total > 0 else 0})\n\n{instruction}")
170
 
171
- if total == 0:
172
- gr.Markdown("⚠️ **错误**:未找到任何数据。请检查 Dataset 设置。")
173
- return
174
-
175
- if origin:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  with gr.Row():
177
- with gr.Column(scale=1):
178
- gr.Image(origin, label="Reference (参考原图)", interactive=False, height=300)
179
- with gr.Column(scale=2):
180
- gr.Markdown("👈 **请参考左侧原图**,并在下方选择您认为质量最好的图片(可多选)。")
181
-
182
- with gr.Row(wrap=True):
183
- for i, item in enumerate(candidates):
184
- is_selected = i in selection
185
- with gr.Column(min_width=200):
186
- gr.Image(item["path"], show_label=False, interactive=False)
187
- btn_text = f"✅ {item['label']} (已选)" if is_selected else f"⬜️ {item['label']} (点击选择)"
188
- btn_variant = "primary" if is_selected else "secondary"
189
-
190
- btn = gr.Button(btn_text, variant=btn_variant)
191
-
192
- def toggle(idx, current_sel):
193
- if idx in current_sel: current_sel.remove(idx)
194
- else: current_sel.append(idx)
195
- current_sel.sort()
196
- return current_sel
197
-
198
- btn.click(fn=toggle, inputs=[gr.Number(i, visible=False), state_selection], outputs=[state_selection])
199
-
200
- with gr.Row():
201
- btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary", scale=2)
202
- btn_none = gr.Button("🚫 都不满意 (None)", variant="stop", scale=1)
203
-
204
- def load_first(main_st):
205
- return next_question_data(main_st)
206
-
207
- demo.load(load_first, inputs=[state_main], outputs=[state_main, state_origin, state_candidates, state_instruction])
208
-
209
- def on_submit(main_st, cands, sel):
210
- new_main, new_sel = submit_logic(main_st, cands, sel, is_none=False)
211
- updated_main, origin, new_cands, instr = next_question_data(new_main)
212
- return updated_main, new_sel, origin, new_cands, instr
213
-
214
- btn_submit.click(on_submit, inputs=[state_main, state_candidates, state_selection], outputs=[state_main, state_selection, state_origin, state_candidates, state_instruction])
215
 
216
- def on_none(main_st, cands, sel):
217
- new_main, new_sel = submit_logic(main_st, cands, sel, is_none=True)
218
- updated_main, origin, new_cands, instr = next_question_data(new_main)
219
- return updated_main, new_sel, origin, new_cands, instr
220
-
221
- btn_none.click(on_none, inputs=[state_main, state_candidates, state_selection], outputs=[state_main, state_selection, state_origin, state_candidates, state_instruction])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  if __name__ == "__main__":
224
  demo.launch()
 
14
  LOG_FOLDER.mkdir(parents=True, exist_ok=True)
15
  TOKEN = os.environ.get("HF_TOKEN")
16
 
17
+ # --- 2. 自动下载数据 (保证不为空) ---
 
18
  if not os.path.exists(DATA_FOLDER) or not os.listdir(DATA_FOLDER):
19
  try:
20
+ print("🚀 正在从 Dataset 下载数据...")
21
  snapshot_download(
22
  repo_id=DATASET_REPO_ID,
23
  repo_type="dataset",
 
25
  token=TOKEN,
26
  allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"]
27
  )
28
+ print("✅ 数据下载完成!")
29
  except Exception as e:
30
+ print(f"⚠️ 下载失败 (请检查 Token 或网络): {e}")
31
 
32
+ # --- 3. 启动同步调度器 ---
33
  scheduler = CommitScheduler(
34
  repo_id=DATASET_REPO_ID,
35
  repo_type="dataset",
36
  folder_path=LOG_FOLDER,
37
+ path_in_repo="logs",
38
  every=1,
39
  token=TOKEN
40
  )
41
 
42
+ # --- 4. 数据加载 ---
43
  def load_data():
44
  groups = {}
45
  if not os.path.exists(DATA_FOLDER):
 
46
  return {}, []
47
 
48
  for filename in os.listdir(DATA_FOLDER):
 
78
 
79
  ALL_GROUPS, ALL_GROUP_IDS = load_data()
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # --- 5. 核心逻辑 ---
82
 
83
def get_next_question(user_state):
    """Build the component updates for the question at ``user_state["index"]``.

    Returns a 9-tuple in the order the UI event outputs expect: updates for
    the reference image, the candidate gallery, the checkbox group, the
    instruction markdown, the submit button, the "none" button and the
    end-of-study markdown, followed by ``user_state`` itself and a list of
    ``{"label", "path"}`` dicts describing the candidates now on screen.
    """
    position = user_state["index"]
    total = len(ALL_GROUP_IDS)

    # No questions left: hide every interactive component and reveal only
    # the closing message.
    if position >= total:
        hidden = [gr.update(visible=False) for _ in range(6)]
        closing = gr.update(value="## 🎉 测试结束!感谢您的参与。", visible=True)
        return (*hidden, closing, user_state, [])

    group = ALL_GROUPS[ALL_GROUP_IDS[position]]
    reference = group["origin"]

    # Shuffle a copy so the stored candidate list keeps its original order.
    shuffled = group["candidates"].copy()
    random.shuffle(shuffled)

    # Labels are "Option A", "Option B", ... in display order.
    labels = [f"Option {chr(65 + n)}" for n in range(len(shuffled))]
    # Gallery entries are (path, caption) pairs.
    gallery_items = list(zip(shuffled, labels))
    # Kept in state so a chosen label can be mapped back to its file later.
    on_screen = [{"label": label, "path": path} for path, label in gallery_items]

    heading = f"### 任务 ({position + 1} / {len(ALL_GROUP_IDS)})\n\n{group['instruction']}"

    return (
        gr.update(value=reference, visible=bool(reference)),
        gr.update(value=gallery_items, visible=True),
        gr.update(choices=labels, value=[], visible=True),  # clear previous ticks
        gr.update(value=heading, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        user_state,
        on_screen,
    )
136
 
137
def save_and_next(user_state, candidates_info, selected_options, is_none=False):
    """Record the answer for the current question, then load the next one.

    Args:
        user_state: Session dict with ``"user_id"`` and ``"index"``; the
            index is incremented in place once the answer is written.
        candidates_info: List of ``{"label", "path"}`` dicts for the
            candidates currently displayed.
        selected_options: Labels ticked by the participant.
        is_none: When true, the answer is stored as a rejection of all
            candidates and ``selected_options`` is ignored.

    Returns:
        Whatever ``get_next_question`` returns for the updated state.

    Raises:
        gr.Error: If ``is_none`` is false and nothing was selected.
    """
    current_idx = user_state["index"]

    # A duplicate click can arrive after the final answer was stored, and
    # the dataset may be empty; there is nothing to record in either case,
    # so re-render instead of failing with IndexError on the lookup below.
    if current_idx >= len(ALL_GROUP_IDS):
        return get_next_question(user_state)

    group_id = ALL_GROUP_IDS[current_idx]

    if is_none:
        choice_str = "Rejected All"
        method_str = "None_Satisfied"
    else:
        if not selected_options:
            raise gr.Error("请至少勾选一个选项,或点击“都不满意”")

        choice_str = "; ".join(selected_options)

        # One pass to index paths by label; setdefault keeps the first
        # entry for a label, matching a first-match linear scan.
        path_by_label = {}
        for info in candidates_info:
            path_by_label.setdefault(info["label"], info["path"])

        selected_methods = []
        for option in selected_options:
            if option not in path_by_label:
                continue  # unknown label: nothing to map back
            stem = os.path.splitext(os.path.basename(path_by_label[option]))[0]
            # The method name is whatever follows the first underscore;
            # a stem without an underscore is used as-is.
            _, sep, method = stem.partition("_")
            selected_methods.append(method if sep else stem)
        method_str = "; ".join(selected_methods)

    # Append to this participant's CSV while holding the scheduler lock,
    # so the write does not overlap with a scheduled upload of the folder.
    user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
    with scheduler.lock:
        needs_header = not user_file.exists()
        with open(user_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
            writer.writerow([
                user_state["user_id"],
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                group_id,
                choice_str,
                method_str,
            ])

    user_state["index"] += 1
    return get_next_question(user_state)
190
 
191
# --- 6. UI construction ---
# NOTE(review): the nesting below is reconstructed from a flattened diff
# view; confirm that the button row and end message belong at Blocks level.
with gr.Blocks(title="User Study") as demo:

    # Per-session state. The dict literal here is evaluated once at import,
    # so its user_id would be shared by every visitor; start_session()
    # below replaces it with a fresh one on each page load.
    state_user = gr.State({"user_id": str(uuid.uuid4())[:8], "index": 0})
    # Candidates shown on the current page, used to map "Option A" back to
    # the underlying file.
    state_candidates_info = gr.State([])

    with gr.Row():
        md_instruction = gr.Markdown("Loading...")

    with gr.Row():
        # Left: reference image.
        with gr.Column(scale=1):
            img_origin = gr.Image(label="Reference (参考原图)", interactive=False, height=400)

        # Right: candidates and the selection area.
        with gr.Column(scale=2):
            gallery_candidates = gr.Gallery(
                label="Candidates (候选结果)",
                columns=[2],
                height="auto",
                object_fit="contain",
                interactive=False,  # selection happens via the checkboxes below
            )

            gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")

            checkbox_options = gr.CheckboxGroup(
                choices=[],
                label="您的选择",
                info="对应上方图片的标签 (Option A, B...)"
            )

    with gr.Row():
        btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
        btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")

    md_end = gr.Markdown(visible=False)

    # --- Event wiring ---

    # Every handler returns the tuple produced by get_next_question, so all
    # three events share the same output list.
    question_outputs = [
        img_origin,
        gallery_candidates,
        checkbox_options,
        md_instruction,
        btn_submit,
        btn_none,
        md_end,
        state_user,
        state_candidates_info,
    ]

    def start_session():
        """Give this page load its own participant id and show question 1."""
        fresh_state = {"user_id": str(uuid.uuid4())[:8], "index": 0}
        return get_next_question(fresh_state)

    def submit_selection(user_state, candidates_info, selected_options):
        """Store the ticked options and move on."""
        return save_and_next(user_state, candidates_info, selected_options, is_none=False)

    def reject_all(user_state, candidates_info, selected_options):
        """Store a "none satisfied" answer and move on."""
        return save_and_next(user_state, candidates_info, selected_options, is_none=True)

    # 1. On page load, start a session and load the first question.
    demo.load(fn=start_session, inputs=None, outputs=question_outputs)

    # 2. Submit button.
    btn_submit.click(
        fn=submit_selection,
        inputs=[state_user, state_candidates_info, checkbox_options],
        outputs=question_outputs,
    )

    # 3. "None satisfied" button.
    btn_none.click(
        fn=reject_all,
        inputs=[state_user, state_candidates_info, checkbox_options],
        outputs=question_outputs,
    )

if __name__ == "__main__":
    demo.launch()