| | import json |
| | import random |
| |
|
| | |
| | |
| | |
| |
|
# Environment ids. Each name becomes the "env_id" of the task entries
# produced by generate_json(). The sixteen ids are kept in their original
# four groups, one group per row.
ENVS = [
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "ButtonUnmask", "VideoUnmaskSwap", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]
| |
|
# Named participants. generate_json() uses these, in order, as the keys of
# its output; once the list is exhausted the remaining slots are keyed
# "user<N>" instead.
REAL_USERS = [
    "Hongyu_Zhou",
    "Wanling_Cai",
    "Xinyi_Wang",
    "Yinpei_Dai",
    "Hongze_Fu",
    "Run_Peng",
    "Haoran_Zhang",
    "Yunqi_Zhao",
    "Yue_Hu",
    "Yiwei_Lyu",
    "Josue_Torres-Fonseca",
    "Jung-Chun_Liu",
    "Jacob_Sansom",
    "Long-Jing_Hsu",
]
| |
|
# Number of user entries generate_json() produces. Slots beyond the end of
# REAL_USERS are keyed "user<N>" (1-based slot number).
NUM_USERS = 20
# Training episode indices for every environment run 0 .. EPISODES_PER_ENV - 1.
EPISODES_PER_ENV = 50
# Episode index used for the one-per-environment test tasks that are placed
# at the front of every user's task list.
TEST_EPISODE_IDX = 98
| |
|
| |
|
def generate_json(
    seed: int = 0,
    *,
    envs=None,
    real_users=None,
    num_users=None,
    episodes_per_env=None,
    test_episode_idx=None,
):
    """Build the per-user task assignment.

    Every user's list starts with one test task per environment (episode
    ``test_episode_idx``), followed by that user's training tasks. Training
    tasks are the ``episodes_per_env`` episodes of every environment, split
    across users without overlap in two phases:

    1. Each user, in order, draws one random unused episode from every
       environment (skipped silently for an environment that has run out).
    2. All episodes still unused are shuffled and dealt out in equal
       contiguous chunks; any leftover tasks go one each to the first users.

    Args:
        seed: Seed for the private ``random.Random`` instance; the same seed
            and configuration always yield the same assignment.
        envs: Environment ids. Defaults to ``ENVS``.
        real_users: Names used, in order, as user keys. Defaults to
            ``REAL_USERS``.
        num_users: Number of users to generate. Users beyond the end of
            ``real_users`` are keyed ``"user<N>"``. Defaults to ``NUM_USERS``.
        episodes_per_env: Training episodes per environment (indices
            ``0 .. episodes_per_env - 1``). Defaults to ``EPISODES_PER_ENV``.
        test_episode_idx: Episode index of the test tasks. Defaults to
            ``TEST_EPISODE_IDX``.

    Returns:
        dict mapping user key -> list of ``{"env_id", "episode_idx"}`` dicts.
        No dict object is shared between users. Empty when ``num_users`` is
        zero or negative.
    """
    # Module constants are looked up only when the caller did not override.
    envs = list(ENVS if envs is None else envs)
    real_users = list(REAL_USERS if real_users is None else real_users)
    num_users = NUM_USERS if num_users is None else num_users
    episodes_per_env = (
        EPISODES_PER_ENV if episodes_per_env is None else episodes_per_env
    )
    test_episode_idx = (
        TEST_EPISODE_IDX if test_episode_idx is None else test_episode_idx
    )

    # Nothing to assign; also avoids dividing by zero further down.
    if num_users <= 0:
        return {}

    rng = random.Random(seed)

    user_keys = [
        real_users[i] if i < len(real_users) else f"user{i + 1}"
        for i in range(num_users)
    ]
    users = {key: [] for key in user_keys}

    # Unused episode indices per environment, always in ascending order.
    unused = {env: list(range(episodes_per_env)) for env in envs}

    # Phase 1: one random episode per environment for every user.
    for key in user_keys:
        for env in envs:
            pool = unused[env]
            if not pool:
                # More users than episodes for this environment.
                continue
            episode = rng.choice(pool)
            pool.remove(episode)
            users[key].append({"env_id": env, "episode_idx": episode})

    # Phase 2: shuffle whatever is left and deal it out evenly.
    leftovers = [
        {"env_id": env, "episode_idx": episode}
        for env in envs
        for episode in unused[env]
    ]
    rng.shuffle(leftovers)

    share, _ = divmod(len(leftovers), num_users)
    for i, key in enumerate(user_keys):
        users[key].extend(leftovers[i * share:(i + 1) * share])
    # Tasks that did not divide evenly: one each to the first users.
    for key, task in zip(user_keys, leftovers[share * num_users:]):
        users[key].append(task)

    # The test tasks are rebuilt for every user so that editing one user's
    # entries in memory can never leak into another user's list.
    return {
        key: [
            {"env_id": env, "episode_idx": test_episode_idx} for env in envs
        ] + users[key]
        for key in user_keys
    }
| |
|
| |
|
if __name__ == "__main__":
    # Generate the assignment with a fixed seed so reruns give the same file.
    data = generate_json(seed=42)

    with open("user_tasks.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

    # Every user's list begins with one test task per environment; subtract
    # those so the summary reports training tasks only.
    counts = {user: len(tasks) - len(ENVS) for user, tasks in data.items()}
    print("Train counts:", counts)
    print("Min/Max:", min(counts.values()), max(counts.values()))
    print("✅ 已生成并保存到 user_tasks.json")
| |
|