| import json |
| import random |
|
|
| |
| |
| |
|
|
# Environment ids handed out by generate_json. The order matters: the
# generator walks environments in this order, which fixes both each user's
# task order and the sequence of random draws for a given seed.
# Listed in four groups of four (the grouping presumably reflects task
# families — TODO confirm).
ENVS = [
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "ButtonUnmask", "VideoUnmaskSwap", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]
|
|
# Named participants. generate_json uses them as the first user keys, in this
# order; any further users (up to NUM_USERS) get placeholder keys "user<N>".
REAL_USERS = [
    "Hongyu_Zhou", "Wanling_Cai", "Xinyi_Wang", "Yinpei_Dai",
    "Hongze_Fu", "Run_Peng", "Haoran_Zhang", "Yunqi_Zhao",
    "Yue_Hu", "Yiwei_Lyu", "Josue_Torres-Fonseca", "Jung-Chun_Liu",
    "Jacob_Sansom", "Long-Jing_Hsu",
]
|
|
# Total number of users to build assignments for (named users first).
NUM_USERS = 20
# Training episodes per environment, indexed 0 .. EPISODES_PER_ENV - 1.
EPISODES_PER_ENV = 50
# Held-out test episode index shared by every environment. It lies outside
# range(EPISODES_PER_ENV), so it never coincides with a training episode.
TEST_EPISODE_IDX = 98
|
|
|
|
def generate_json(
    seed: int = 0,
    *,
    envs: "list[str] | None" = None,
    user_names: "list[str] | None" = None,
    num_users: "int | None" = None,
    episodes_per_env: "int | None" = None,
    test_episode_idx: "int | None" = None,
) -> dict:
    """Assign environment episodes to users for data collection.

    Every user first receives one randomly chosen training episode from each
    environment (so nobody misses an environment). The leftover episodes are
    then shuffled and dealt out as evenly as possible; when they do not divide
    evenly, the first users each receive one extra task. Each user's list is
    prefixed with one held-out test task per environment.

    The order of random draws (users outer, environments inner, pools in
    ascending episode order, then one shuffle) is part of the contract:
    changing it changes the assignment produced for a given seed.

    Args:
        seed: Seed for a private ``random.Random``; the result is fully
            determined by the seed and the other arguments.
        envs: Environment ids (defaults to ``ENVS``). Duplicates are ignored.
        user_names: Names used as the first user keys (defaults to
            ``REAL_USERS``); remaining users are keyed ``"user<i+1>"``.
        num_users: Number of users (defaults to ``NUM_USERS``).
        episodes_per_env: Training episodes ``0 .. n-1`` per environment
            (defaults to ``EPISODES_PER_ENV``).
        test_episode_idx: Held-out episode index (defaults to
            ``TEST_EPISODE_IDX``); it is never handed out for training.

    Returns:
        Mapping of user key -> list of ``{"env_id": ..., "episode_idx": ...}``
        dicts, test tasks first, keys in user order. Empty when there are no
        users.
    """
    # Resolve defaults lazily so module constants are read only when the
    # caller does not override them.
    envs = list(dict.fromkeys(ENVS if envs is None else envs))
    user_names = list(REAL_USERS if user_names is None else user_names)
    num_users = NUM_USERS if num_users is None else num_users
    if episodes_per_env is None:
        episodes_per_env = EPISODES_PER_ENV
    if test_episode_idx is None:
        test_episode_idx = TEST_EPISODE_IDX

    rng = random.Random(seed)

    # Named users first, then synthetic 1-based "userN" keys for the rest.
    user_keys = [
        user_names[i] if i < len(user_names) else f"user{i + 1}"
        for i in range(num_users)
    ]
    if not user_keys:
        # Nothing to assign; also avoids dividing by zero below.
        return {}

    # Per-environment pool of training tasks in ascending episode order. The
    # held-out test episode is excluded so it can never leak into training.
    pools = {
        env: [
            {"env_id": env, "episode_idx": ep}
            for ep in range(episodes_per_env)
            if ep != test_episode_idx
        ]
        for env in envs
    }

    users = {key: [] for key in user_keys}

    # Phase 1: one random task per environment for every user, drawn without
    # replacement. list.remove keeps the pool's ascending order, so each draw
    # sees exactly the candidates (in the same order) as a fresh filter would.
    for key in user_keys:
        for pool in pools.values():
            if pool:  # more users than episodes: later users skip this env
                task = rng.choice(pool)
                pool.remove(task)
                users[key].append(task)

    # Phase 2: shuffle whatever is left and deal it out in equal slices; the
    # few leftover tasks go one each to the first users.
    remaining = [task for pool in pools.values() for task in pool]
    rng.shuffle(remaining)
    share = len(remaining) // len(user_keys)
    for i, key in enumerate(user_keys):
        users[key].extend(remaining[i * share:(i + 1) * share])
    for key, task in zip(user_keys, remaining[share * len(user_keys):]):
        users[key].append(task)

    # Build fresh test dicts per user so that mutating one user's entry cannot
    # silently change every other user's list.
    return {
        key: [
            {"env_id": env, "episode_idx": test_episode_idx} for env in envs
        ] + users[key]
        for key in user_keys
    }
|
|
|
|
if __name__ == "__main__":
    # Fixed seed so the saved assignment is reproducible.
    data = generate_json(seed=42)

    with open("user_tasks.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    # generate_json prepends one held-out test task per environment to every
    # user's list; subtract those so only training tasks are reported.
    num_test = len(ENVS)
    counts = {k: len(v) - num_test for k, v in data.items()}
    print("Train counts:", counts)
    print("Min/Max:", min(counts.values(), default=0), max(counts.values(), default=0))
    # Message: "Generated and saved to user_tasks.json".
    print("✅ 已生成并保存到 user_tasks.json")
|
|