| import json |
| import os |
| import random |
| import threading |
| from pathlib import Path |
|
|
| from config import TASK_NAME_LIST |
| from state_manager import clear_task_start_time, get_task_start_time |
|
|
|
|
# Glob pattern, relative to the metadata root, selecting the per-dataset
# metadata JSON files that UserManager scans for (env, episode) records.
METADATA_FILE_GLOB: str = "record_dataset_*_metadata.json"
|
|
|
|
class UserManager:
    """Hand out random (env, episode) tasks to sessions and track their progress.

    The task pool is read once, when the manager is constructed, from the JSON
    files matching ``METADATA_FILE_GLOB`` under the metadata root. Per-session
    state lives in ``session_progress`` (uid -> progress dict); the public
    methods only read or write it while holding ``self.lock``.
    """

    def __init__(self):
        self.base_dir = Path(__file__).resolve().parent
        # Guards ``session_progress``. This is a plain, non-reentrant Lock, so
        # helpers documented as "caller must hold self.lock" never re-acquire it.
        self.lock = threading.Lock()

        # env_id -> sorted list of episode indices found in the metadata.
        self.env_to_episodes = self._load_env_episode_pool()
        # Env ids in TASK_NAME_LIST order first, then any unlisted ones sorted.
        self.env_choices = self._build_env_choices()

        # uid -> {"completed_count": int, "current_env_id": env id or None,
        #         "current_episode_idx": int or None}
        self.session_progress = {}

    def _resolve_metadata_root(self) -> Path:
        """Return the directory holding the dataset metadata JSON files.

        A non-blank ``ROBOMME_METADATA_ROOT`` environment variable overrides
        the default (a leading ``~`` is expanded). Otherwise the default is
        ``<base_dir>/../src/robomme/env_metadata/train``.
        """
        env_root = os.environ.get("ROBOMME_METADATA_ROOT", "").strip()
        if env_root:
            return Path(env_root).expanduser()
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    def _load_env_episode_pool(self):
        """Scan the metadata root and return ``{env_id: sorted episode indices}``.

        Each file matching ``METADATA_FILE_GLOB`` should be a JSON object with
        an optional top-level ``env_id`` and a ``records`` list of objects that
        carry ``task`` and ``episode``. A record's non-blank ``task`` takes
        precedence over the file-level ``env_id``.

        Loading is best-effort: unreadable or malformed files are reported and
        skipped, and so are individual records without a usable env id or an
        integer-convertible episode. A bad file therefore cannot abort
        construction of the manager.
        """
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.is_dir():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}

        env_to_episode_set = {}
        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue

            if not isinstance(payload, dict):
                print(f"Warning: ignoring metadata file {metadata_path}: expected a JSON object")
                continue
            records = payload.get("records") or []
            if not isinstance(records, list):
                print(f"Warning: ignoring metadata file {metadata_path}: 'records' is not a list")
                continue

            fallback_env = str(payload.get("env_id") or "").strip()
            for record in records:
                if not isinstance(record, dict):
                    continue
                # A blank/whitespace-only task falls back to the file-level env id.
                env_id = str(record.get("task") or "").strip() or fallback_env
                episode = record.get("episode")
                if not env_id or episode is None:
                    continue
                try:
                    episode_idx = int(episode)
                except (TypeError, ValueError, OverflowError):
                    # OverflowError: json.loads accepts Infinity, and int(inf) raises it.
                    continue
                env_to_episode_set.setdefault(env_id, set()).add(episode_idx)

        # Every set here holds at least one index (created on first add).
        env_to_episodes = {
            env_id: sorted(episodes) for env_id, episodes in env_to_episode_set.items()
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _build_env_choices(self):
        """Return available env ids: TASK_NAME_LIST order first, extras sorted after.

        Duplicates in TASK_NAME_LIST are collapsed, so each env appears once.
        """
        available_envs = set(self.env_to_episodes)
        ordered_choices = [
            env_id for env_id in dict.fromkeys(TASK_NAME_LIST) if env_id in available_envs
        ]
        remaining_choices = sorted(available_envs.difference(ordered_choices))
        return ordered_choices + remaining_choices

    def _ensure_session_entry(self, uid):
        """Create an empty progress record for ``uid`` if missing and return it.

        Caller must hold ``self.lock``.
        """
        entry = self.session_progress.get(uid)
        if entry is None:
            entry = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }
            self.session_progress[uid] = entry
        return entry

    def _set_current_random_task(self, uid, preferred_env=None):
        """Assign ``uid`` a random episode and return True on success.

        ``preferred_env`` is used when it is a known env. Otherwise an env is
        chosen uniformly from ``env_choices``. Caller must hold ``self.lock``.
        """
        if not self.env_choices:
            return False
        entry = self._ensure_session_entry(uid)

        if preferred_env in self.env_to_episodes:
            env_id = preferred_env
        else:
            env_id = random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id)
        if not episodes:
            return False

        entry["current_env_id"] = env_id
        entry["current_episode_idx"] = int(random.choice(episodes))
        return True

    def _status_snapshot_locked(self, uid):
        """Build the status dict for ``uid``. Caller must hold ``self.lock``.

        A session without a current task is lazily assigned a random one first,
        so the snapshot reflects the same state the caller just mutated.
        """
        progress = self._ensure_session_entry(uid)
        if progress.get("current_env_id") is None or progress.get("current_episode_idx") is None:
            # Best-effort: on failure current_task is simply reported as None.
            self._set_current_random_task(uid)

        current_task = None
        if progress.get("current_env_id") is not None and progress.get("current_episode_idx") is not None:
            current_task = {
                "env_id": progress["current_env_id"],
                "episode_idx": int(progress["current_episode_idx"]),
            }

        completed_count = int(progress.get("completed_count", 0))
        return {
            "uid": uid,
            "total_tasks": len(self.env_choices),
            "current_index": completed_count,
            "completed_count": completed_count,
            "current_task": current_task,
            # NOTE(review): fixed values; presumably kept for callers that
            # expect a task-list style payload — confirm before removing.
            "is_done_all": False,
            "tasks": [],
            "env_choices": list(self.env_choices),
        }

    def init_session(self, uid):
        """Ensure ``uid`` has a current task.

        Returns ``(ok, message, status_or_None)``.
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None

        with self.lock:
            progress = self._ensure_session_entry(uid)
            if progress.get("current_env_id") is None or progress.get("current_episode_idx") is None:
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None
            # Snapshot under the same lock so a concurrent cleanup cannot interleave.
            return True, "Session initialized", self._status_snapshot_locked(uid)

    def get_session_status(self, uid):
        """Return the status dict for ``uid`` (assigning a task if needed), or None for an empty uid."""
        if not uid:
            return None

        with self.lock:
            return self._status_snapshot_locked(uid)

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Increment ``uid``'s completion counter and return its status.

        The current task is NOT advanced. When ``env_id`` and ``episode_idx``
        are both given, the recorded start time for that task is cleared.
        Extra keyword arguments are accepted and ignored.
        """
        if not uid:
            return None

        with self.lock:
            progress = self._ensure_session_entry(uid)
            progress["completed_count"] = int(progress["completed_count"]) + 1
            status = self._status_snapshot_locked(uid)

        # External state_manager calls are made outside the lock.
        if env_id is not None and episode_idx is not None:
            # NOTE(review): the returned start time is unused; the lookup is
            # kept only in case state_manager depends on it — confirm.
            get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)

        return status

    def switch_env_and_random_episode(self, uid, env_id):
        """Move ``uid`` to a random episode of ``env_id``.

        Returns the new status, or None if ``env_id`` is unknown.
        """
        if not uid or env_id not in self.env_to_episodes:
            return None

        with self.lock:
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None
            return self._status_snapshot_locked(uid)

    def next_episode_same_env(self, uid):
        """Draw a new random episode for ``uid`` and return the new status.

        The episode is drawn from the session's current env, or from a random
        env if that env is unset or unknown. The draw may repeat the current
        episode.
        """
        if not uid:
            return None

        with self.lock:
            current_env = self._ensure_session_entry(uid).get("current_env_id")
            # _set_current_random_task already falls back to a random env when
            # current_env is None or not in the pool.
            if not self._set_current_random_task(uid, preferred_env=current_env):
                return None
            return self._status_snapshot_locked(uid)

    def cleanup_session(self, uid):
        """Forget all progress for ``uid``; a no-op for empty or unknown uids."""
        if not uid:
            return

        with self.lock:
            self.session_progress.pop(uid, None)
|
|
|
|
# Shared module-level manager. Creating it reads the episode pool from the
# metadata files at import time.
user_manager: "UserManager" = UserManager()
|
|