# RoboMME / gradio-web / AAAuser_generator.py
# Commit 06c11b0 by HongzeFu: "HF Space: code-only (no binary assets)"
import json
import random
# Deprecated runtime path:
# This script is only for offline generation experiments and is not used by
# the current Gradio runtime task assignment flow.
# Environment IDs, listed under their category headers. Order matters:
# generate_json() walks this list in order when drawing episodes.
ENVS = [
    # Counting
    "BinFill",
    "PickXtimes",
    "SwingXtimes",
    "StopCube",
    # Persistence
    "VideoUnmask",
    "ButtonUnmask",
    "VideoUnmaskSwap",
    "ButtonUnmaskSwap",
    # Reference
    "PickHighlight",
    "VideoRepick",
    "VideoPlaceButton",
    "VideoPlaceOrder",
    # Behavior
    "MoveCube",
    "InsertPeg",
    "PatternLock",
    "RouteStick",
]

# Named participants. They occupy the first user slots, in this order.
REAL_USERS = [
    "Hongyu_Zhou",
    "Wanling_Cai",
    "Xinyi_Wang",
    "Yinpei_Dai",
    "Hongze_Fu",
    "Run_Peng",
    "Haoran_Zhang",
    "Yunqi_Zhao",
    "Yue_Hu",
    "Yiwei_Lyu",
    "Josue_Torres-Fonseca",
    "Jung-Chun_Liu",
    "Jacob_Sansom",
    "Long-Jing_Hsu",
]

# Total user slots. Slots past len(REAL_USERS) get generic "user{i+1}" keys.
NUM_USERS = 20
# Training episodes per environment, i.e. episode indices 0..49.
EPISODES_PER_ENV = 50
# Episode index used for every environment's test task.
TEST_EPISODE_IDX = 98
def generate_json(
    seed: int = 0,
    *,
    envs: "list[str] | None" = None,
    real_users: "list[str] | None" = None,
    num_users: "int | None" = None,
    episodes_per_env: "int | None" = None,
    test_episode_idx: "int | None" = None,
) -> "dict[str, list[dict]]":
    """Build a per-user task assignment mapping.

    Each user's list contains, in order:

    1. One test task per environment (``episode_idx == test_episode_idx``).
    2. Phase 1: one randomly chosen training episode from every environment.
       This covers each env for each user as long as
       ``num_users <= episodes_per_env``; once an env runs out of episodes,
       later users silently skip it.
    3. Phase 2: an equal share of all still-unused training episodes,
       shuffled across environments. The first
       ``len(remaining) % num_users`` users receive one extra task.

    Every training episode is assigned to at most one user.

    Args:
        seed: Seed for a private ``random.Random``. The same arguments give
            the same assignment for a given Python version.
        envs: Environment IDs; duplicates are ignored. Defaults to ``ENVS``.
        real_users: Keys for the first user slots. Defaults to
            ``REAL_USERS``. Slots beyond its length are keyed ``user{i+1}``.
        num_users: Number of user slots. Defaults to ``NUM_USERS``.
        episodes_per_env: Training episodes per env, indices
            ``0..episodes_per_env-1``. Defaults to ``EPISODES_PER_ENV``.
        test_episode_idx: Episode index of the test tasks. Defaults to
            ``TEST_EPISODE_IDX``.

    Returns:
        Dict mapping user key to a list of ``{"env_id", "episode_idx"}``
        dicts, test tasks first. Returns ``{}`` when ``num_users <= 0``.
    """
    # Resolve defaults lazily so callers may override any subset.
    # dict.fromkeys drops duplicate env IDs while keeping first-seen order.
    envs = list(dict.fromkeys(ENVS if envs is None else envs))
    real_users = REAL_USERS if real_users is None else real_users
    num_users = NUM_USERS if num_users is None else num_users
    if episodes_per_env is None:
        episodes_per_env = EPISODES_PER_ENV
    if test_episode_idx is None:
        test_episode_idx = TEST_EPISODE_IDX

    if num_users <= 0:
        # Nothing to assign; also avoids dividing by zero in phase 2.
        return {}

    rng = random.Random(seed)

    # Real names fill the first slots; the rest get generic "userN" keys.
    user_keys = [
        real_users[i] if i < len(real_users) else f"user{i + 1}"
        for i in range(num_users)
    ]
    users = {key: [] for key in user_keys}

    # Unused training episodes per env. Kept in ascending order so that
    # rng.choice() consumes randomness and picks exactly as a filtered
    # "all episodes minus used" list would.
    pools = {env: list(range(episodes_per_env)) for env in envs}

    # Phase 1: give every user one random episode from each environment.
    for user_key in user_keys:
        for env in envs:
            pool = pools[env]
            if not pool:
                # Env exhausted (more users than episodes): skip it.
                continue
            episode = rng.choice(pool)
            pool.remove(episode)
            users[user_key].append({"env_id": env, "episode_idx": episode})

    # Phase 2: gather every unused episode (env order, then ascending index),
    # shuffle across envs, and deal contiguous equal-size chunks. Because the
    # shuffle mixes envs, per-env balance within a user is NOT guaranteed.
    remaining = [
        {"env_id": env, "episode_idx": episode}
        for env, pool in pools.items()
        for episode in pool
    ]
    rng.shuffle(remaining)
    per_user = len(remaining) // num_users
    for i, user_key in enumerate(user_keys):
        start = i * per_user
        users[user_key].extend(remaining[start:start + per_user])
    # Leftovers (fewer than num_users) go one each to the first users.
    for user_key, task in zip(user_keys, remaining[per_user * num_users:]):
        users[user_key].append(task)

    # Test tasks (one per env) go before the training tasks. Build fresh dicts
    # per user so users never share and co-mutate the same task objects.
    return {
        user_key: [
            {"env_id": env, "episode_idx": test_episode_idx} for env in envs
        ] + tasks
        for user_key, tasks in users.items()
    }
if __name__ == "__main__":
    # Generate the default assignment and write it next to the script's CWD.
    data = generate_json(seed=42)
    with open("user_tasks.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    # generate_json() prepends exactly one test task per env to every user's
    # list, so subtract those to report training-task counts only.
    counts = {user: len(tasks) - len(ENVS) for user, tasks in data.items()}
    print("Train counts:", counts)
    if counts:
        # min()/max() raise on an empty sequence (e.g. zero users).
        print("Min/Max:", min(counts.values()), max(counts.values()))
    print("✅ 已生成并保存到 user_tasks.json")