File size: 7,028 Bytes
06c11b0 d10d370 06c11b0 d10d370 06c11b0 d10d370 06c11b0 a365309 06c11b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | import json
import os
import random
import threading
from pathlib import Path
from config import TASK_NAME_LIST
from state_manager import clear_task_start_time, get_task_start_time
# pathlib glob pattern selecting the metadata JSON files scanned (non-recursively)
# directly under the metadata root.
METADATA_FILE_GLOB: str = "record_dataset_*_metadata.json"
class UserManager:
    """Track per-session progress through randomly sampled (env, episode) tasks.

    The pool of environments and episode indices is read once, at construction
    time, from metadata JSON files. Session progress lives only in memory
    (``session_progress``) and is never persisted. Every read and write of
    ``session_progress`` happens while holding ``self.lock``; status payloads
    are built under the same lock acquisition as the mutation that precedes
    them, so callers always receive a consistent snapshot.
    """

    def __init__(self):
        self.base_dir = Path(__file__).resolve().parent
        # Non-reentrant lock: methods that already hold it must use the
        # ``*_locked`` / underscore helpers, never the public methods.
        self.lock = threading.Lock()
        # env_id -> sorted list of episode indices found in the metadata.
        self.env_to_episodes = self._load_env_episode_pool()
        # Envs in TASK_NAME_LIST order first, then any extra envs alphabetically.
        self.env_choices = self._build_env_choices()
        # Session-local progress only (no disk persistence):
        # uid -> {"completed_count", "current_env_id", "current_episode_idx"}.
        self.session_progress = {}

    # ------------------------------------------------------------------
    # Metadata loading
    # ------------------------------------------------------------------
    def _resolve_metadata_root(self) -> Path:
        """Return the directory that holds the metadata files.

        ``ROBOMME_METADATA_ROOT`` (with ``~`` expanded) overrides the default
        ``../src/robomme/env_metadata/train`` relative to this file's directory.
        """
        env_root = os.environ.get("ROBOMME_METADATA_ROOT")
        if env_root:
            return Path(env_root).expanduser()
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    @staticmethod
    def _collect_payload_episodes(payload, env_to_episode_set):
        """Merge the episodes listed in one parsed metadata payload.

        Each entry of ``payload["records"]`` contributes ``int(record["episode"])``
        under ``record["task"]``, falling back to ``payload["env_id"]``. Entries
        that are not JSON objects, lack an env id or episode, or whose episode
        cannot be converted to ``int`` are skipped.

        Args:
            payload: Parsed top-level JSON object of one metadata file.
            env_to_episode_set: Accumulator ``{env_id: set[int]}``, updated in place.
        """
        fallback_env = str(payload.get("env_id") or "").strip()
        # ``or []`` tolerates an explicit ``"records": null``.
        records = payload.get("records") or []
        if not isinstance(records, list):
            return
        for record in records:
            if not isinstance(record, dict):
                continue
            env_id = str(record.get("task") or fallback_env or "").strip()
            episode = record.get("episode")
            if not env_id or episode is None:
                continue
            try:
                episode_idx = int(episode)
            except (TypeError, ValueError, OverflowError):
                # OverflowError: json.loads accepts ``Infinity`` by default.
                continue
            env_to_episode_set.setdefault(env_id, set()).add(episode_idx)

    def _load_env_episode_pool(self):
        """Scan the metadata root and build ``{env_id: sorted episode indices}``.

        Returns ``{}`` (with a warning) when the metadata root does not exist.
        Files that cannot be read or parsed, or whose top level is not a JSON
        object, are reported and skipped instead of aborting the whole load.
        """
        env_to_episode_set = {}
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.exists():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}
        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue
            if not isinstance(payload, dict):
                print(f"Warning: skipping metadata file {metadata_path}: top-level JSON is not an object")
                continue
            self._collect_payload_episodes(payload, env_to_episode_set)
        env_to_episodes = {
            env_id: sorted(episodes)
            for env_id, episodes in env_to_episode_set.items()
            if episodes
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _build_env_choices(self):
        """Return available env ids: TASK_NAME_LIST order first, the rest sorted."""
        available_envs = set(self.env_to_episodes)
        # dict.fromkeys de-duplicates while keeping order, so a repeated name in
        # TASK_NAME_LIST cannot bias random.choice over env_choices.
        ordered_choices = [env_id for env_id in dict.fromkeys(TASK_NAME_LIST) if env_id in available_envs]
        remaining_choices = sorted(available_envs - set(ordered_choices))
        return ordered_choices + remaining_choices

    # ------------------------------------------------------------------
    # Session helpers (all require self.lock to be held by the caller)
    # ------------------------------------------------------------------
    def _ensure_session_entry(self, uid):
        """Create an empty progress record for ``uid`` if none exists."""
        if uid not in self.session_progress:
            self.session_progress[uid] = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }

    def _set_current_random_task(self, uid, preferred_env=None):
        """Assign a random episode to ``uid`` as its current task.

        The episode is drawn from ``preferred_env`` when it is in the pool,
        otherwise from a randomly chosen environment. Returns ``False`` when no
        environment or episode is available.
        """
        if not self.env_choices:
            return False
        self._ensure_session_entry(uid)
        env_id = preferred_env if preferred_env in self.env_to_episodes else random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id, [])
        if not episodes:
            return False
        episode_idx = int(random.choice(episodes))
        self.session_progress[uid]["current_env_id"] = env_id
        self.session_progress[uid]["current_episode_idx"] = episode_idx
        return True

    def _status_locked(self, uid):
        """Build the status payload for ``uid``.

        Lazily assigns a random task when the session has none yet. Must be
        called with ``self.lock`` held; the lock is not reentrant, so this
        helper is what mutating methods use instead of ``get_session_status``.
        """
        self._ensure_session_entry(uid)
        progress = self.session_progress[uid]
        if (
            (progress.get("current_env_id") is None or progress.get("current_episode_idx") is None)
            and self.env_choices
        ):
            # Mutates the same ``progress`` dict in place.
            self._set_current_random_task(uid)
        current_task = None
        if progress.get("current_env_id") is not None and progress.get("current_episode_idx") is not None:
            current_task = {
                "env_id": progress["current_env_id"],
                "episode_idx": int(progress["current_episode_idx"]),
            }
        completed_count = int(progress.get("completed_count", 0))
        return {
            "uid": uid,
            "total_tasks": len(self.env_choices),  # compatibility only
            "current_index": completed_count,  # compatibility only
            "completed_count": completed_count,
            "current_task": current_task,
            "is_done_all": False,
            "tasks": [],  # compatibility only
            "env_choices": list(self.env_choices),
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def init_session(self, uid):
        """Create (or reuse) the session for ``uid`` and make sure it has a task.

        Returns:
            ``(ok, message, status)``; ``status`` is ``None`` when ``ok`` is False.
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None
        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if progress.get("current_env_id") is None or progress.get("current_episode_idx") is None:
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None
            return True, "Session initialized", self._status_locked(uid)

    def get_session_status(self, uid):
        """Return the status dict for ``uid`` (creating the session), or ``None`` for an empty uid."""
        if not uid:
            return None
        with self.lock:
            return self._status_locked(uid)

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Increment ``uid``'s completed count and return the updated status.

        When both ``env_id`` and ``episode_idx`` are given, the recorded start
        time for that task is cleared via ``state_manager``. Extra keyword
        arguments are accepted and ignored for caller compatibility.
        """
        if not uid:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            progress["completed_count"] = int(progress["completed_count"]) + 1
            status = self._status_locked(uid)
        if env_id is not None and episode_idx is not None:
            # NOTE(review): the returned start time is unused; the lookup is kept
            # only in case state_manager relies on it — confirm and drop.
            get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)
        return status

    def switch_env_and_random_episode(self, uid, env_id):
        """Move ``uid`` to a random episode of ``env_id``; ``None`` if uid/env is invalid."""
        if not uid or env_id not in self.env_to_episodes:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None
            return self._status_locked(uid)

    def next_episode_same_env(self, uid):
        """Draw a new random episode from ``uid``'s current env (any env if unset/unknown)."""
        if not uid:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            current_env = self.session_progress[uid].get("current_env_id")
            # _set_current_random_task already falls back to a random env when
            # current_env is None or no longer in the pool.
            if not self._set_current_random_task(uid, preferred_env=current_env):
                return None
            return self._status_locked(uid)

    def cleanup_session(self, uid):
        """Forget all in-memory progress for ``uid`` (no-op for an empty or unknown uid)."""
        if not uid:
            return
        with self.lock:
            self.session_progress.pop(uid, None)
# Module-level singleton. Constructing it scans the metadata pool, so that scan
# runs once, when this module is first imported.
user_manager: UserManager = UserManager()
|