| | import json |
| | import os |
| | import random |
| | import threading |
| | from pathlib import Path |
| |
|
| | from config import TASK_NAME_LIST |
| | from state_manager import clear_task_start_time, get_task_start_time |
| |
|
| |
|
# Glob pattern (relative to the metadata root) selecting the per-dataset
# metadata JSON files that UserManager reads to build its episode pool.
METADATA_FILE_GLOB = "record_dataset_*_metadata.json"
| |
|
| |
|
class UserManager:
    """Assign random (env, episode) tasks to sessions and track their progress.

    The episode pool is loaded once from dataset metadata in ``__init__`` and
    is read-only afterwards; only ``session_progress`` is mutated, always while
    holding ``self.lock``.
    """

    def __init__(self):
        self.base_dir = Path(__file__).resolve().parent
        # Guards session_progress. threading.Lock is NOT re-entrant, so public
        # methods must release it before calling get_session_status().
        self.lock = threading.Lock()

        # env_id -> sorted list of episode indices (read-only after init).
        self.env_to_episodes = self._load_env_episode_pool()
        # Env ids in TASK_NAME_LIST order first, then any others alphabetically.
        self.env_choices = self._build_env_choices()

        # uid -> {"completed_count", "current_env_id", "current_episode_idx"}
        self.session_progress = {}

    def _resolve_metadata_root(self) -> Path:
        """Return the metadata directory.

        Uses ``$ROBOMME_METADATA_ROOT`` when it is set and non-empty. Otherwise
        uses ``<base_dir>/../src/robomme/env_metadata/train``.
        """
        env_root = os.environ.get("ROBOMME_METADATA_ROOT")
        if env_root:
            return Path(env_root)
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    @staticmethod
    def _iter_payload_episodes(payload):
        """Yield ``(env_id, episode_idx)`` pairs from one parsed metadata payload.

        A record's ``task`` takes precedence over the file-level ``env_id``.
        Blank values count as missing. Anything malformed is skipped rather
        than raised: a non-object payload, a non-list ``records``, non-object
        records, or a missing env id or episode. So is an episode that cannot
        be converted to ``int``, including ``Infinity``, which ``json.loads``
        accepts.
        """
        if not isinstance(payload, dict):
            return
        fallback_env = str(payload.get("env_id") or "").strip()
        records = payload.get("records")
        if not isinstance(records, list):
            return
        for record in records:
            if not isinstance(record, dict):
                continue
            env_id = str(record.get("task") or "").strip() or fallback_env
            episode = record.get("episode")
            if not env_id or episode is None:
                continue
            try:
                episode_idx = int(episode)
            except (TypeError, ValueError, OverflowError):
                continue
            yield env_id, episode_idx

    def _load_env_episode_pool(self):
        """Build ``{env_id: sorted unique episode indices}`` from metadata files.

        Files matching METADATA_FILE_GLOB under the metadata root are read in
        sorted order. Unreadable files and malformed content are skipped with
        a warning: this runs from ``__init__``, so raising here would abort
        construction of the manager.

        Returns:
            dict: env id -> sorted list of episode indices. Envs with no valid
            episode are absent. The dict is empty when the root is missing.
        """
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.exists():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}

        env_to_episode_set = {}
        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue
            if not isinstance(payload, dict):
                print(f"Warning: skipping metadata file {metadata_path}: top-level JSON is not an object")
                continue
            for env_id, episode_idx in self._iter_payload_episodes(payload):
                env_to_episode_set.setdefault(env_id, set()).add(episode_idx)

        # Each set is created by setdefault immediately before an add, so none is empty.
        env_to_episodes = {
            env_id: sorted(episodes) for env_id, episodes in env_to_episode_set.items()
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _build_env_choices(self):
        """Order available env ids: TASK_NAME_LIST order first, then the rest sorted."""
        available_envs = set(self.env_to_episodes)
        ordered_choices = [env_id for env_id in TASK_NAME_LIST if env_id in available_envs]
        remaining_choices = sorted(available_envs - set(ordered_choices))
        return ordered_choices + remaining_choices

    @staticmethod
    def _has_current_task(progress):
        """Return True when ``progress`` holds both an env id and an episode index."""
        return (
            progress.get("current_env_id") is not None
            and progress.get("current_episode_idx") is not None
        )

    def _ensure_session_entry(self, uid):
        """Create an empty progress record for ``uid`` if none exists. Caller holds the lock."""
        if uid not in self.session_progress:
            self.session_progress[uid] = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }

    def _set_current_random_task(self, uid, preferred_env=None):
        """Assign ``uid`` a random episode of ``preferred_env``. Caller holds the lock.

        If ``preferred_env`` is not a known env, a random env is chosen instead.

        Returns:
            bool: False if no envs are available, True once a task is assigned.
        """
        if not self.env_choices:
            return False
        self._ensure_session_entry(uid)

        env_id = preferred_env if preferred_env in self.env_to_episodes else random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id, [])
        if not episodes:
            return False

        episode_idx = int(random.choice(episodes))
        self.session_progress[uid]["current_env_id"] = env_id
        self.session_progress[uid]["current_episode_idx"] = episode_idx
        return True

    def init_session(self, uid):
        """Ensure ``uid`` has a session with a current task.

        Returns:
            tuple: ``(ok, message, status)``. ``status`` is the
            get_session_status() dict on success and None on failure.
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None

        with self.lock:
            self._ensure_session_entry(uid)
            if not self._has_current_task(self.session_progress[uid]):
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None

        # Called after the lock is released: get_session_status re-acquires it.
        return True, "Session initialized", self.get_session_status(uid)

    def get_session_status(self, uid):
        """Return a status snapshot for ``uid``, or None for an empty uid.

        Creates the session entry on first use, and assigns a random task when
        none is set and envs are available. The snapshot is taken while holding
        the lock so ``current_task`` never pairs one assignment's env with
        another's episode.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if self.env_choices and not self._has_current_task(progress):
                self._set_current_random_task(uid)

            current_task = None
            if self._has_current_task(progress):
                current_task = {
                    "env_id": progress["current_env_id"],
                    "episode_idx": int(progress["current_episode_idx"]),
                }
            completed_count = int(progress.get("completed_count", 0))

        return {
            "uid": uid,
            "total_tasks": len(self.env_choices),
            "current_index": completed_count,
            "completed_count": completed_count,
            "current_task": current_task,
            "is_done_all": False,
            "tasks": [],
            "env_choices": list(self.env_choices),
        }

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Increment ``uid``'s completed count and return the updated status.

        The current task is not changed here. When both ``env_id`` and
        ``episode_idx`` are given, the recorded start time for that task is
        cleared via state_manager. Extra keyword arguments are accepted and
        ignored.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            progress["completed_count"] = int(progress["completed_count"]) + 1

        if env_id is not None and episode_idx is not None:
            # NOTE(review): the start time is read and discarded; the call is kept
            # in case get_task_start_time has side effects callers rely on — confirm.
            get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)

        return self.get_session_status(uid)

    def switch_env_and_random_episode(self, uid, env_id):
        """Move ``uid`` to a random episode of ``env_id``.

        Returns None for an empty uid or an unknown env. Otherwise returns the
        updated status.
        """
        if not uid or env_id not in self.env_to_episodes:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None

        return self.get_session_status(uid)

    def next_episode_same_env(self, uid):
        """Pick a new random episode in ``uid``'s current env and return the status.

        Falls back to a random env when the current env is unset or unknown;
        _set_current_random_task already handles that fallback.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            current_env = self.session_progress[uid].get("current_env_id")
            if not self._set_current_random_task(uid, preferred_env=current_env):
                return None

        return self.get_session_status(uid)

    def cleanup_session(self, uid):
        """Forget all progress for ``uid``. An empty uid or unknown session is a no-op."""
        if not uid:
            return

        with self.lock:
            self.session_progress.pop(uid, None)
| |
|
| |
|
# Module-level shared instance, created when this module is first imported;
# constructing it loads the metadata episode pool from disk.
user_manager = UserManager()
| |
|