intersteller2887 committed on
Commit
2d4a3ed
·
verified ·
1 Parent(s): 655f8d4

进一步修改了初始化试题时一系列文件锁相关逻辑,添加了后端相关函数的详细注释,调整了整体结构(还未调整提交试题时的文件锁相关逻辑)

Browse files
Files changed (1) hide show
  1. app.py +94 -59
app.py CHANGED
@@ -13,6 +13,7 @@ from huggingface_hub import HfApi, hf_hub_download
13
  from multiprocessing import TimeoutError
14
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
15
 
 
16
  dataset = load_dataset("intersteller2887/Turing-test-dataset", split="train")
17
  dataset = dataset.cast_column("audio", Audio(decode=False)) # Prevent calling 'torchcodec' from newer version of 'datasets'
18
 
@@ -41,7 +42,7 @@ sample1_audio_path = local_audio_paths[0]
41
  print(sample1_audio_path)
42
 
43
  # ==============================================================================
44
- # 数据定义 (Data Definition)
45
  # ==============================================================================
46
 
47
  DIMENSIONS_DATA = [
@@ -98,11 +99,9 @@ DIMENSION_TITLES = [d["title"] for d in DIMENSIONS_DATA]
98
  MAX_SUB_DIMS = max(len(d['sub_dims']) for d in DIMENSIONS_DATA)
99
 
100
  # ==============================================================================
101
- # Function Definitions
102
  # ==============================================================================
103
 
104
- # Function that load or initialize count.json
105
-
106
  # This version did not place file reading into filelock, concurrent read could happen
107
  """def load_or_initialize_count_json(audio_paths):
108
  try:
@@ -154,28 +153,39 @@ MAX_SUB_DIMS = max(len(d['sub_dims']) for d in DIMENSIONS_DATA)
154
 
155
  return count_data"""
156
 
 
 
 
 
 
 
 
157
  # This version also places file reading into filelock, and modified
158
  def load_or_initialize_count_json(audio_paths):
159
  # Add filelock to /workspace/count.json
160
  lock_path = COUNT_JSON_PATH + ".lock"
161
  with FileLock(lock_path, timeout=10):
162
- # Only try downloading if file doesn't exist yet
163
  if not os.path.exists(COUNT_JSON_PATH):
164
  try:
 
165
  downloaded_path = hf_hub_download(
166
  repo_id="intersteller2887/Turing-test-dataset",
167
  repo_type="dataset",
168
  filename=COUNT_JSON_REPO_PATH,
169
  token=os.getenv("HF_TOKEN")
170
  )
 
 
171
  except Exception:
172
  pass
173
 
174
- # If count.json exists: load into count_data
175
  if os.path.exists(COUNT_JSON_PATH):
176
  with open(COUNT_JSON_PATH, "r", encoding="utf-8") as f:
177
  count_data = json.load(f, object_pairs_hook=collections.OrderedDict)
178
  # Else initialize count_data with orderedDict
 
179
  else:
180
  count_data = collections.OrderedDict()
181
 
@@ -193,65 +203,86 @@ def load_or_initialize_count_json(audio_paths):
193
  count_data[filename] = 0
194
  updated = True
195
 
 
196
  if updated or not os.path.exists(COUNT_JSON_PATH):
197
  with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
198
  json.dump(count_data, f, indent=4, ensure_ascii=False)
199
 
200
- return count_data
201
 
202
  # Shorten the time of playing previous audio when reached next question
203
  def append_cache_buster(audio_path):
204
  return f"{audio_path}?t={int(time.time() * 1000)}"
205
 
206
- # Shorten the time of playing previous audio when reached next question
207
- def append_cache_buster(audio_path):
208
- return f"{audio_path}?t={int(time.time() * 1000)}"
209
 
210
- """def sample_audio_paths(audio_paths, count_data, k=5, max_count=1):
 
211
  eligible_paths = [p for p in audio_paths if count_data.get(os.path.basename(p), 0) < max_count]
212
-
213
  if len(eligible_paths) < k:
214
  raise ValueError(f"可用音频数量不足(只剩 {len(eligible_paths)} 条 count<{max_count} 的音频),无法抽取 {k} 条")
215
 
216
- eligible_paths_copy = eligible_paths.copy()
217
-
218
- random.seed(int(time.time()))
219
-
220
- selected = random.sample(eligible_paths_copy, k)
221
 
 
222
  for path in selected:
223
  filename = os.path.basename(path)
224
  count_data[filename] = count_data.get(filename, 0) + 1
225
 
226
- with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
227
- json.dump(count_data, f, indent=4, ensure_ascii=False)
 
 
 
228
 
229
  return selected, count_data"""
230
 
231
- def sample_audio_paths(audio_paths, count_data, k=5, max_count=1): # k for questions per test; max_count for question limit in total
232
- eligible_paths = [p for p in audio_paths if count_data.get(os.path.basename(p), 0) < max_count]
 
 
233
 
234
- if len(eligible_paths) < k:
235
- raise ValueError(f"可用音频数量不足(只剩 {len(eligible_paths)} 条 count<{max_count} 的音频),无法抽取 {k} 条")
 
 
 
 
 
 
 
236
 
237
- selected = random.sample(eligible_paths, k)
 
238
 
239
- for path in selected:
240
- filename = os.path.basename(path)
241
- count_data[filename] = count_data.get(filename, 0) + 1
242
 
243
- lock_path = COUNT_JSON_PATH + ".lock"
244
- with FileLock(lock_path, timeout=10):
 
 
 
 
245
  with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
246
  json.dump(count_data, f, indent=4, ensure_ascii=False)
247
 
248
- return selected, count_data
 
 
 
 
 
 
249
 
250
  # Save question_set in each user_data_state, preventing global sharing
251
  def start_challenge(user_data_state):
252
 
253
- count_data = load_or_initialize_count_json(all_data_audio_paths)
254
- selected_audio_paths, updated_count_data = sample_audio_paths(all_data_audio_paths, count_data, k=5)
 
 
255
 
256
  question_set = [
257
  {"audio": path, "desc": f"这是音频文件 {os.path.basename(path)} 的描述"}
@@ -259,13 +290,18 @@ def start_challenge(user_data_state):
259
  ]
260
 
261
  user_data_state["question_set"] = question_set
262
- user_data_state["updated_count_data"] = updated_count_data
263
- return gr.update(visible=False), gr.update(visible=True), user_data_state
264
 
 
 
 
 
 
 
265
  def toggle_education_other(choice):
266
  is_other = (choice == "其他(请注明)")
267
  return gr.update(visible=is_other, interactive=is_other, value="")
268
 
 
269
  def check_info_complete(username, age, gender, education, education_other, ai_experience):
270
  if username.strip() and age and gender and education and ai_experience:
271
  if education == "其他(请注明)" and not education_other.strip():
@@ -273,6 +309,7 @@ def check_info_complete(username, age, gender, education, education_other, ai_ex
273
  return gr.update(interactive=True)
274
  return gr.update(interactive=False)
275
 
 
276
  def show_sample_page_and_init(username, age, gender, education, education_other, ai_experience, user_data):
277
  final_edu = education_other if education == "其他(请注明)" else education
278
  user_data.update({
@@ -341,8 +378,6 @@ def update_test_dimension_view(d_idx, selections):
341
 
342
  def init_test_question(user_data, q_idx):
343
  d_idx = 0
344
- # question = QUESTION_SET[q_idx]
345
- # progress_q = f"第 {q_idx + 1} / {len(QUESTION_SET)} 题"
346
  question = user_data["question_set"][q_idx]
347
  progress_q = f"第 {q_idx + 1} / {len(user_data['question_set'])} 题"
348
 
@@ -405,8 +440,30 @@ def navigate_dimensions(direction, q_idx, d_idx, selections, *slider_values):
405
  next_btn_update,
406
  ) + tuple(slider_updates)
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  # ==============================================================================
409
- # 重连函数定义 (Retry Function Definitions)
410
  # ==============================================================================
411
 
412
  # Function for handling connection error
@@ -732,28 +789,6 @@ def save_all_results_to_file(all_results, user_data, count_data=None):
732
  commit_message=f"Update count.json after submission by {username}"
733
  )
734
 
735
- def toggle_reference_view(current):
736
- if current == "参考":
737
- return gr.update(visible=False), gr.update(visible=True), gr.update(value="返回")
738
- else:
739
- return gr.update(visible=True), gr.update(visible=False), gr.update(value="参考")
740
-
741
- def back_to_welcome():
742
- return (
743
- gr.update(visible=True), # welcome_page
744
- gr.update(visible=False), # info_page
745
- gr.update(visible=False), # sample_page
746
- gr.update(visible=False), # pretest_page
747
- gr.update(visible=False), # test_page
748
- gr.update(visible=False), # final_judgment_page
749
- gr.update(visible=False), # result_page
750
- {}, # user_data_state
751
- 0, # current_question_index
752
- 0, # current_test_dimension_index
753
- {}, # current_question_selections
754
- [] # test_results
755
- )
756
-
757
  # ==============================================================================
758
  # Gradio 界面定义 (Gradio UI Definition)
759
  # ==============================================================================
 
13
  from multiprocessing import TimeoutError
14
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
15
 
16
+ # Load dataset from HuggingFace
17
  dataset = load_dataset("intersteller2887/Turing-test-dataset", split="train")
18
  dataset = dataset.cast_column("audio", Audio(decode=False)) # Prevent calling 'torchcodec' from newer version of 'datasets'
19
 
 
42
  print(sample1_audio_path)
43
 
44
  # ==============================================================================
45
+ # Data Definition
46
  # ==============================================================================
47
 
48
  DIMENSIONS_DATA = [
 
99
  MAX_SUB_DIMS = max(len(d['sub_dims']) for d in DIMENSIONS_DATA)
100
 
101
  # ==============================================================================
102
+ # Backend Function Definitions
103
  # ==============================================================================
104
 
 
 
105
  # This version did not place file reading into filelock, concurrent read could happen
106
  """def load_or_initialize_count_json(audio_paths):
107
  try:
 
153
 
154
  return count_data"""
155
 
156
+ # Function that loads or initializes count.json
157
+ # Function is called when user start a challenge, and this will load or initialize count.json to working directory
158
+ # Initialize happens when count.json does not exist in the working directory as well as HuggingFace dataset
159
+ # Load happens when count.json exists in HuggingFace dataset, and it's not loaded to the working directory yet
160
+ # After load/initialize, all newly added audio files will be added to count.json with initial value of 0
161
+ # Load/Initialize will generate count.json in the working directory for all users under this space
162
+
163
  # This version also places file reading into filelock, and modified
164
  def load_or_initialize_count_json(audio_paths):
165
  # Add filelock to /workspace/count.json
166
  lock_path = COUNT_JSON_PATH + ".lock"
167
  with FileLock(lock_path, timeout=10):
168
+ # If count.json does not exist in the working directory, try to download it from HuggingFace dataset
169
  if not os.path.exists(COUNT_JSON_PATH):
170
  try:
171
+ # Save latest count.json to working directory
172
  downloaded_path = hf_hub_download(
173
  repo_id="intersteller2887/Turing-test-dataset",
174
  repo_type="dataset",
175
  filename=COUNT_JSON_REPO_PATH,
176
  token=os.getenv("HF_TOKEN")
177
  )
178
+ with open(downloaded_path, "rb") as src, open(COUNT_JSON_PATH, "wb") as dst:
179
+ dst.write(src.read())
180
  except Exception:
181
  pass
182
 
183
+ # If count.json exists in the working directory: load into count_data for potential update
184
  if os.path.exists(COUNT_JSON_PATH):
185
  with open(COUNT_JSON_PATH, "r", encoding="utf-8") as f:
186
  count_data = json.load(f, object_pairs_hook=collections.OrderedDict)
187
  # Else initialize count_data with orderedDict
188
+ # This happens when there is no count.json (both working directory and HuggingFace dataset)
189
  else:
190
  count_data = collections.OrderedDict()
191
 
 
203
  count_data[filename] = 0
204
  updated = True
205
 
206
+ # Write updated count_data to /home/user/app/count.json
207
  if updated or not os.path.exists(COUNT_JSON_PATH):
208
  with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
209
  json.dump(count_data, f, indent=4, ensure_ascii=False)
210
 
211
+ return
212
 
213
  # Shorten the time of playing previous audio when reached next question
214
  def append_cache_buster(audio_path):
215
  return f"{audio_path}?t={int(time.time() * 1000)}"
216
 
217
+ # Function that samples questions from available question set
 
 
218
 
219
+ # This version utilizes a given count_data to sample audio paths
220
+ """def sample_audio_paths(audio_paths, count_data, k=5, max_count=1): # k for questions per test; max_count for question limit in total
221
  eligible_paths = [p for p in audio_paths if count_data.get(os.path.basename(p), 0) < max_count]
222
+
223
  if len(eligible_paths) < k:
224
  raise ValueError(f"可用音频数量不足(只剩 {len(eligible_paths)} 条 count<{max_count} 的音频),无法抽取 {k} 条")
225
 
226
+ # Shuffle to avoid fixed selections resulting from directory structure
227
+ selected = random.sample(eligible_paths, k)
 
 
 
228
 
229
+ # Once sampled a test, update these questions immediately
230
  for path in selected:
231
  filename = os.path.basename(path)
232
  count_data[filename] = count_data.get(filename, 0) + 1
233
 
234
+ # Add filelock to /workspace/count.json
235
+ lock_path = COUNT_JSON_PATH + ".lock"
236
+ with FileLock(lock_path, timeout=10):
237
+ with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
238
+ json.dump(count_data, f, indent=4, ensure_ascii=False)
239
 
240
  return selected, count_data"""
241
 
242
+ # This version places file reading into filelock to guarantee correct update of count.json
243
+ def sample_audio_paths(audio_paths, k=5, max_count=1):
244
+ # Add filelock to /workspace/count.json
245
+ lock_path = COUNT_JSON_PATH + ".lock"
246
 
247
+ # Load newest count.json
248
+ with FileLock(lock_path, timeout=10):
249
+ with open(COUNT_JSON_PATH, "r", encoding="utf-8") as f:
250
+ count_data = json.load(f)
251
+
252
+ eligible_paths = [
253
+ p for p in audio_paths
254
+ if count_data.get(os.path.basename(p), 0) < max_count
255
+ ]
256
 
257
+ if len(eligible_paths) < k:
258
+ raise ValueError(f"可用音频数量不足(只剩 {len(eligible_paths)} 条 count<{max_count} 的音频),无法抽取 {k} 条")
259
 
260
+ selected = random.sample(eligible_paths, k)
 
 
261
 
262
+ # Update count_data
263
+ for path in selected:
264
+ filename = os.path.basename(path)
265
+ count_data[filename] = count_data.get(filename, 0) + 1
266
+
267
+ # Update count.json
268
  with open(COUNT_JSON_PATH, "w", encoding="utf-8") as f:
269
  json.dump(count_data, f, indent=4, ensure_ascii=False)
270
 
271
+ # return selected, count_data
272
+ # Keep count_data atomic
273
+ return selected
274
+
275
+ # ==============================================================================
276
+ # Frontend Function Definitions
277
+ # ==============================================================================
278
 
279
  # Save question_set in each user_data_state, preventing global sharing
280
  def start_challenge(user_data_state):
281
 
282
+ load_or_initialize_count_json(all_data_audio_paths)
283
+ # selected_audio_paths, updated_count_data = sample_audio_paths(all_data_audio_paths, k=5)
284
+ # Keep count_data atomic
285
+ selected_audio_paths = sample_audio_paths(all_data_audio_paths, k=5)
286
 
287
  question_set = [
288
  {"audio": path, "desc": f"这是音频文件 {os.path.basename(path)} 的描述"}
 
290
  ]
291
 
292
  user_data_state["question_set"] = question_set
 
 
293
 
294
+ # count_data is not needed in the user data
295
+ # user_data_state["updated_count_data"] = updated_count_data
296
+
297
+ return gr.update(visible=False), gr.update(visible=True), user_data_state
298
+
299
+ # This function toggles the visibility of the "其他(请注明)" input field based on the selected education choice
300
  def toggle_education_other(choice):
301
  is_other = (choice == "其他(请注明)")
302
  return gr.update(visible=is_other, interactive=is_other, value="")
303
 
304
+ # This function checks if the user information is complete
305
  def check_info_complete(username, age, gender, education, education_other, ai_experience):
306
  if username.strip() and age and gender and education and ai_experience:
307
  if education == "其他(请注明)" and not education_other.strip():
 
309
  return gr.update(interactive=True)
310
  return gr.update(interactive=False)
311
 
312
+ # This function updates user_data and initializes the sample page (called when user submits their info)
313
  def show_sample_page_and_init(username, age, gender, education, education_other, ai_experience, user_data):
314
  final_edu = education_other if education == "其他(请注明)" else education
315
  user_data.update({
 
378
 
379
  def init_test_question(user_data, q_idx):
380
  d_idx = 0
 
 
381
  question = user_data["question_set"][q_idx]
382
  progress_q = f"第 {q_idx + 1} / {len(user_data['question_set'])} 题"
383
 
 
440
  next_btn_update,
441
  ) + tuple(slider_updates)
442
 
443
+ def toggle_reference_view(current):
444
+ if current == "参考":
445
+ return gr.update(visible=False), gr.update(visible=True), gr.update(value="返回")
446
+ else:
447
+ return gr.update(visible=True), gr.update(visible=False), gr.update(value="参考")
448
+
449
+ def back_to_welcome():
450
+ return (
451
+ gr.update(visible=True), # welcome_page
452
+ gr.update(visible=False), # info_page
453
+ gr.update(visible=False), # sample_page
454
+ gr.update(visible=False), # pretest_page
455
+ gr.update(visible=False), # test_page
456
+ gr.update(visible=False), # final_judgment_page
457
+ gr.update(visible=False), # result_page
458
+ {}, # user_data_state
459
+ 0, # current_question_index
460
+ 0, # current_test_dimension_index
461
+ {}, # current_question_selections
462
+ [] # test_results
463
+ )
464
+
465
  # ==============================================================================
466
+ # Retry Function Definitions
467
  # ==============================================================================
468
 
469
  # Function for handling connection error
 
789
  commit_message=f"Update count.json after submission by {username}"
790
  )
791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  # ==============================================================================
793
  # Gradio 界面定义 (Gradio UI Definition)
794
  # ==============================================================================