Emilyxml commited on
Commit
74af1a5
·
verified ·
1 Parent(s): 746eee8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -99
app.py CHANGED
@@ -8,213 +8,276 @@ from pathlib import Path
8
  from huggingface_hub import CommitScheduler
9
 
10
  # --- 1. 配置区域 ---
11
-
12
- # 你的数据集地址 (已修改)
13
- DATASET_REPO_ID = "Emilyxml/moveit"
14
-
15
- # 数据源文件夹 (读取你的图片和txt)
16
- DATA_FOLDER = "data"
17
-
18
- # 临时日志文件夹 (用于存放用户生成的CSV,Scheduler 会监控这里)
19
- LOG_FOLDER = Path("logs")
20
  LOG_FOLDER.mkdir(parents=True, exist_ok=True)
21
-
22
- # 获取 Token (需要在 Space 设置里配置 HF_TOKEN)
23
  TOKEN = os.environ.get("HF_TOKEN")
24
 
25
  # --- 2. 启动同步调度器 ---
26
- # 只要 logs 文件夹里有 CSV 变化,就自动上传到 Dataset 的 data 文件夹下
27
  scheduler = CommitScheduler(
28
  repo_id=DATASET_REPO_ID,
29
  repo_type="dataset",
30
- folder_path=LOG_FOLDER, # 监控本地的 logs 文件夹
31
- path_in_repo="data", # 上传到 Dataset 的 data 目录中
32
- every=1, # 每分钟同步一次
33
  token=TOKEN
34
  )
35
 
36
- # --- 3. 数据加载逻辑 ---
37
  def load_data():
38
  groups = {}
39
-
40
- # 检查 data 文件夹是否存在
41
  if not os.path.exists(DATA_FOLDER):
42
- # 如果不存在,尝试创建(防止报错),但实际应该由你上传文件
43
  os.makedirs(DATA_FOLDER, exist_ok=True)
44
- print(f"Warning: {DATA_FOLDER} not found. Please upload your images.")
45
  return {}, []
46
 
47
- # 遍历文件
48
  for filename in os.listdir(DATA_FOLDER):
49
- if filename.startswith('.'): continue # 跳过隐藏文件
50
-
51
  file_path = os.path.join(DATA_FOLDER, filename)
52
- prefix = filename[:5] # 以前5个字符作为组ID
53
 
54
  if prefix not in groups:
55
  groups[prefix] = {"images": [], "instruction": "暂无说明"}
56
 
57
- # 识别图片
58
- if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
59
  groups[prefix]["images"].append(file_path)
60
- # 识别文本
61
  elif filename.lower().endswith('.txt'):
62
  try:
63
  with open(file_path, "r", encoding="utf-8") as f:
64
  groups[prefix]["instruction"] = f.read()
65
  except:
66
- # 兼容 gbk 编码
67
  with open(file_path, "r", encoding="gbk") as f:
68
  groups[prefix]["instruction"] = f.read()
69
 
70
- # 过滤掉没有图片的组
71
  valid_groups = {k: v for k, v in groups.items() if len(v["images"]) > 0}
72
-
73
- # 生成题目列表,并随机打乱
74
  group_ids = list(valid_groups.keys())
75
  random.shuffle(group_ids)
76
-
77
- print(f"Loaded {len(group_ids)} groups of images.")
78
  return valid_groups, group_ids
79
 
80
- # 全局加载数据
81
  ALL_GROUPS, ALL_GROUP_IDS = load_data()
82
 
83
- # --- 4. 保存逻辑 (每个用户一个独立CSV) ---
84
- def save_user_vote(user_id, group_id, choice_label, method_name):
85
  """
86
- 保存单次投票到 logs/user_{user_id}.csv
 
 
87
  """
88
  user_filename = f"user_{user_id}.csv"
89
  user_file_path = LOG_FOLDER / user_filename
90
 
91
- # 数据行
92
  row = [
93
  user_id,
94
  datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
95
  group_id,
96
- choice_label, # 用户选了 Option A 还是 B
97
- method_name # 真实的方法名
98
  ]
99
 
100
- # 线程安全写入
101
  with scheduler.lock:
102
  file_exists = user_file_path.exists()
103
  with user_file_path.open("a", newline="", encoding="utf-8") as f:
104
  writer = csv.writer(f)
105
- # 如果是新文件,先写表头
106
  if not file_exists:
107
- writer.writerow(["user_id", "timestamp", "group_id", "selected_label", "selected_method"])
108
  writer.writerow(row)
109
 
110
- print(f"Saved vote for {user_id}: {method_name}")
 
 
111
 
112
- # --- 5. 交互逻辑 ---
113
- def get_next_question(user_state):
114
  current_idx = user_state["index"]
115
 
116
- # 1. 检查是否做完
117
  if current_idx >= len(ALL_GROUP_IDS):
118
  return (
 
 
 
119
  gr.update(visible=False),
120
- gr.update(visible=False),
121
- gr.update(visible=False),
122
- gr.update(value="## 🎉 测试结束!\n感谢您的参与,您的选择已保存。", visible=True),
123
  user_state,
124
- []
 
125
  )
126
 
127
- # 2. 获取当前组
128
  group_id = ALL_GROUP_IDS[current_idx]
129
  group_data = ALL_GROUPS[group_id]
130
 
131
- # 3. 准备 Prompt
132
- instruction_text = f"## 任务 ({current_idx + 1} / {len(ALL_GROUP_IDS)})\n\n{group_data['instruction']}"
133
 
134
- # 4. 准备图片 (打乱顺序实现盲测)
135
  original_images = group_data["images"]
136
  shuffled_images = original_images.copy()
137
  random.shuffle(shuffled_images)
138
 
139
- # 构造 Gradio 显示对象
140
  display_list = []
141
  for i, img_path in enumerate(shuffled_images):
142
  label = f"Option {chr(65+i)}" # Option A, Option B...
143
  display_list.append((img_path, label))
144
-
 
 
 
 
145
  return (
146
  gr.update(value=instruction_text, visible=True),
147
- gr.update(value=display_list, visible=True),
148
- gr.update(visible=True),
149
- gr.update(visible=False),
 
150
  user_state,
151
- shuffled_images # 将乱序后的真实路径列表传给 State
 
152
  )
153
 
154
- def on_vote(user_state, current_file_paths, select_data: gr.SelectData = None, is_none=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  user_id = user_state["user_id"]
156
  current_idx = user_state["index"]
157
 
158
- # 防止溢出
159
  if current_idx >= len(ALL_GROUP_IDS):
160
- return get_next_question(user_state)
161
-
162
  group_id = ALL_GROUP_IDS[current_idx]
163
- selected_method = "Unknown"
164
- selected_label = "None"
165
 
166
- # --- 解析选择 ---
167
  if is_none:
168
- selected_method = "None_Satisfied"
169
- selected_label = "Rejected All"
170
- elif select_data is not None:
171
- idx = select_data.index
172
- # 获取真实路径
173
- real_image_path = current_file_paths[idx]
174
- selected_label = select_data.value["caption"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- # 提取方法名 (例如 6180a_omnigen.png -> omnigen)
177
- filename = os.path.basename(real_image_path)
 
178
  name_no_ext = os.path.splitext(filename)[0]
179
  parts = name_no_ext.split('_', 1)
180
- if len(parts) > 1:
181
- selected_method = parts[1]
182
- else:
183
- selected_method = name_no_ext
184
-
185
- # --- 保存 ---
186
- save_user_vote(user_id, group_id, selected_label, selected_method)
187
-
188
- # --- 下一题 ---
 
189
  user_state["index"] += 1
190
- return get_next_question(user_state)
191
 
192
  # --- 6. 界面构建 ---
193
- with gr.Blocks(title="User Study - MoveIt") as demo:
194
- # State 初始化:每次刷新网页生成新的 user_id
 
195
  state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
196
- state_files = gr.State([])
 
197
 
198
- with gr.Column(elem_id="main"):
199
  instruction_md = gr.Markdown("Loading...")
200
 
 
201
  gallery = gr.Gallery(
202
- label="请点击选择最佳图片",
203
- columns=[2],
 
204
  height="auto",
205
  interactive=True
206
  )
207
 
208
- btn_none = gr.Button("🚫 没有任何一张图片符合要求", variant="stop")
 
 
 
 
 
 
209
 
210
  end_msg = gr.Markdown(visible=False)
211
 
212
- # 事件绑定
213
- demo.load(fn=get_next_question, inputs=[state_user], outputs=[instruction_md, gallery, btn_none, end_msg, state_user, state_files])
 
 
 
 
 
 
214
 
215
- gallery.select(fn=lambda s, f, evt: on_vote(s, f, evt, is_none=False), inputs=[state_user, state_files], outputs=[instruction_md, gallery, btn_none, end_msg, state_user, state_files])
 
 
 
 
 
216
 
217
- btn_none.click(fn=lambda s, f: on_vote(s, f, None, is_none=True), inputs=[state_user, state_files], outputs=[instruction_md, gallery, btn_none, end_msg, state_user, state_files])
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  if __name__ == "__main__":
220
- demo.launch()
 
8
  from huggingface_hub import CommitScheduler
9
 
10
  # --- 1. 配置区域 ---
11
+ DATASET_REPO_ID = "Emilyxml/moveit" # 你的数据集
12
+ DATA_FOLDER = "data" # 数据文件夹
13
+ LOG_FOLDER = Path("logs") # 本地日志
 
 
 
 
 
 
14
  LOG_FOLDER.mkdir(parents=True, exist_ok=True)
 
 
15
  TOKEN = os.environ.get("HF_TOKEN")
16
 
17
  # --- 2. 启动同步调度器 ---
 
18
  scheduler = CommitScheduler(
19
  repo_id=DATASET_REPO_ID,
20
  repo_type="dataset",
21
+ folder_path=LOG_FOLDER,
22
+ path_in_repo="data",
23
+ every=1,
24
  token=TOKEN
25
  )
26
 
27
+ # --- 3. 数据加载逻辑 (保持不变) ---
28
  def load_data():
29
  groups = {}
 
 
30
  if not os.path.exists(DATA_FOLDER):
 
31
  os.makedirs(DATA_FOLDER, exist_ok=True)
 
32
  return {}, []
33
 
 
34
  for filename in os.listdir(DATA_FOLDER):
35
+ if filename.startswith('.'): continue
 
36
  file_path = os.path.join(DATA_FOLDER, filename)
37
+ prefix = filename[:5]
38
 
39
  if prefix not in groups:
40
  groups[prefix] = {"images": [], "instruction": "暂无说明"}
41
 
42
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
 
43
  groups[prefix]["images"].append(file_path)
 
44
  elif filename.lower().endswith('.txt'):
45
  try:
46
  with open(file_path, "r", encoding="utf-8") as f:
47
  groups[prefix]["instruction"] = f.read()
48
  except:
 
49
  with open(file_path, "r", encoding="gbk") as f:
50
  groups[prefix]["instruction"] = f.read()
51
 
 
52
  valid_groups = {k: v for k, v in groups.items() if len(v["images"]) > 0}
 
 
53
  group_ids = list(valid_groups.keys())
54
  random.shuffle(group_ids)
55
+ print(f"Loaded {len(group_ids)} groups.")
 
56
  return valid_groups, group_ids
57
 
 
58
  ALL_GROUPS, ALL_GROUP_IDS = load_data()
59
 
60
+ # --- 4. 保存逻辑 (支持多选保存) ---
61
+ def save_user_vote(user_id, group_id, choice_labels, method_names):
62
  """
63
+ 保存投票。
64
+ choice_labels: 字符串,例如 "Option A; Option B"
65
+ method_names: 字符串,例如 "omnigen; sdxl"
66
  """
67
  user_filename = f"user_{user_id}.csv"
68
  user_file_path = LOG_FOLDER / user_filename
69
 
 
70
  row = [
71
  user_id,
72
  datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
73
  group_id,
74
+ choice_labels,
75
+ method_names
76
  ]
77
 
 
78
  with scheduler.lock:
79
  file_exists = user_file_path.exists()
80
  with user_file_path.open("a", newline="", encoding="utf-8") as f:
81
  writer = csv.writer(f)
 
82
  if not file_exists:
83
+ writer.writerow(["user_id", "timestamp", "group_id", "selected_labels", "selected_methods"])
84
  writer.writerow(row)
85
 
86
+ print(f"Saved: User {user_id} selected {method_names}")
87
+
88
+ # --- 5. 交互逻辑 (多选核心) ---
89
 
90
+ def get_current_question_ui(user_state):
91
+ """根据当前索引刷新界面"""
92
  current_idx = user_state["index"]
93
 
94
+ # 1. 检查是否结束
95
  if current_idx >= len(ALL_GROUP_IDS):
96
  return (
97
+ gr.update(visible=False),
98
+ gr.update(visible=False),
99
+ gr.update(visible=False),
100
  gr.update(visible=False),
101
+ gr.update(value="## 🎉 测试结束!\n感谢您的参与,所有结果已保存。", visible=True),
 
 
102
  user_state,
103
+ [], # 清空文件路径
104
+ [] # 清空当前选中的索引
105
  )
106
 
107
+ # 2. 获取数据
108
  group_id = ALL_GROUP_IDS[current_idx]
109
  group_data = ALL_GROUPS[group_id]
110
 
111
+ # 3. 准备文本
112
+ instruction_text = f"### 任务 ({current_idx + 1} / {len(ALL_GROUP_IDS)})\n\n{group_data['instruction']}"
113
 
114
+ # 4. 准备图片 (盲测 + 打乱)
115
  original_images = group_data["images"]
116
  shuffled_images = original_images.copy()
117
  random.shuffle(shuffled_images)
118
 
119
+ # 构造显示列表
120
  display_list = []
121
  for i, img_path in enumerate(shuffled_images):
122
  label = f"Option {chr(65+i)}" # Option A, Option B...
123
  display_list.append((img_path, label))
124
+
125
+ # 动态列数
126
+ num_imgs = len(shuffled_images)
127
+ cols = 2 if num_imgs == 4 else min(num_imgs, 3)
128
+
129
  return (
130
  gr.update(value=instruction_text, visible=True),
131
+ gr.update(value=display_list, columns=cols, visible=True), # Gallery
132
+ gr.update(value="当前未选择任何图片", visible=True), # 状态栏重置
133
+ gr.update(visible=True), # 按钮区可见
134
+ gr.update(visible=False), # 结束语
135
  user_state,
136
+ shuffled_images,
137
+ [] # 重置选中的索引列表
138
  )
139
 
140
+ def toggle_selection(evt: gr.SelectData, current_indices):
141
+ """
142
+ 处理图片点击:
143
+ 点击一次 -> 选中
144
+ 再点一次 -> 取消选中
145
+ """
146
+ clicked_idx = evt.index
147
+
148
+ # 切换状态
149
+ if clicked_idx in current_indices:
150
+ current_indices.remove(clicked_idx)
151
+ else:
152
+ current_indices.append(clicked_idx)
153
+
154
+ # 排序一下,让显示更好看 (Option A, Option B)
155
+ current_indices.sort()
156
+
157
+ # 更新状态文本
158
+ if not current_indices:
159
+ status_text = "当前未选择任何图片"
160
+ else:
161
+ labels = [f"Option {chr(65+i)}" for i in current_indices]
162
+ status_text = "已选中: " + ", ".join(labels)
163
+
164
+ return current_indices, status_text
165
+
166
+ def submit_vote(user_state, current_file_paths, current_indices, is_none=False):
167
+ """
168
+ 提交投票(可能是多选,可能是None)
169
+ """
170
  user_id = user_state["user_id"]
171
  current_idx = user_state["index"]
172
 
 
173
  if current_idx >= len(ALL_GROUP_IDS):
174
+ return get_current_question_ui(user_state)
175
+
176
  group_id = ALL_GROUP_IDS[current_idx]
 
 
177
 
178
+ # --- 场景1: 都没有 ---
179
  if is_none:
180
+ save_user_vote(user_id, group_id, "Rejected All", "None_Satisfied")
181
+ user_state["index"] += 1
182
+ return get_current_question_ui(user_state)
183
+
184
+ # --- 场景2: 提交多选 ---
185
+ if not current_indices:
186
+ # 如果用户没选图片就点了提交,弹窗提示或者不做反应
187
+ # 这里为了简单,返回原样,并提示
188
+ return (
189
+ gr.update(), gr.update(),
190
+ gr.update(value="❌ 请至少选择一张图片,或者点击“都不满意”"),
191
+ gr.update(), gr.update(),
192
+ user_state, current_file_paths, current_indices
193
+ )
194
+
195
+ # 解析所有选中的图片
196
+ selected_labels = []
197
+ selected_methods = []
198
+
199
+ for idx in current_indices:
200
+ # 1. 记录 Option X
201
+ label = f"Option {chr(65+idx)}"
202
+ selected_labels.append(label)
203
 
204
+ # 2. 提取方法名
205
+ real_path = current_file_paths[idx]
206
+ filename = os.path.basename(real_path)
207
  name_no_ext = os.path.splitext(filename)[0]
208
  parts = name_no_ext.split('_', 1)
209
+ method = parts[1] if len(parts) > 1 else name_no_ext
210
+ selected_methods.append(method)
211
+
212
+ # 用分号连接 (CSV友好)
213
+ str_labels = "; ".join(selected_labels)
214
+ str_methods = "; ".join(selected_methods)
215
+
216
+ save_user_vote(user_id, group_id, str_labels, str_methods)
217
+
218
+ # 下一题
219
  user_state["index"] += 1
220
+ return get_current_question_ui(user_state)
221
 
222
  # --- 6. 界面构建 ---
223
+ with gr.Blocks(title="Multi-Select User Study", theme=gr.themes.Soft()) as demo:
224
+
225
+ # 状态变量
226
  state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
227
+ state_files = gr.State([]) # 存当前图片的真实路径
228
+ state_indices = gr.State([]) # 存当前选中的图片索引 [0, 2]
229
 
230
+ with gr.Column():
231
  instruction_md = gr.Markdown("Loading...")
232
 
233
+ # 图片区
234
  gallery = gr.Gallery(
235
+ label="请点击选择图片(可多选)",
236
+ allow_preview=True,
237
+ object_fit="contain",
238
  height="auto",
239
  interactive=True
240
  )
241
 
242
+ # 状态显示区(告诉用户选了啥)
243
+ status_box = gr.Textbox(value="当前未选择任何图片", label="当前选中状态", interactive=False)
244
+
245
+ # 按钮区
246
+ with gr.Row():
247
+ btn_submit = gr.Button("✅ 提交选择 (Confirm Selection)", variant="primary", scale=2)
248
+ btn_none = gr.Button("🚫 都不满意 (None of them)", variant="stop", scale=1)
249
 
250
  end_msg = gr.Markdown(visible=False)
251
 
252
+ # --- 事件流 ---
253
+
254
+ # 1. 启动加载
255
+ demo.load(
256
+ fn=get_current_question_ui,
257
+ inputs=[state_user],
258
+ outputs=[instruction_md, gallery, status_box, btn_submit, end_msg, state_user, state_files, state_indices]
259
+ )
260
 
261
+ # 2. 点击图片 -> 切换选中状态 (不翻页)
262
+ gallery.select(
263
+ fn=toggle_selection,
264
+ inputs=[state_indices],
265
+ outputs=[state_indices, status_box]
266
+ )
267
 
268
+ # 3. 点击提交 -> 保存并下一页
269
+ btn_submit.click(
270
+ fn=lambda s, f, i: submit_vote(s, f, i, is_none=False),
271
+ inputs=[state_user, state_files, state_indices],
272
+ outputs=[instruction_md, gallery, status_box, btn_submit, end_msg, state_user, state_files, state_indices]
273
+ )
274
+
275
+ # 4. 点击都不满意 -> 保存并下一页
276
+ btn_none.click(
277
+ fn=lambda s, f, i: submit_vote(s, f, i, is_none=True),
278
+ inputs=[state_user, state_files, state_indices],
279
+ outputs=[instruction_md, gallery, status_box, btn_submit, end_msg, state_user, state_files, state_indices]
280
+ )
281
 
282
  if __name__ == "__main__":
283
+ demo.launch()