# RoboMME / gradio-web / user_manager.py
# Commit a365309 by HongzeFu: "session manager v1"
import json
import os
import random
import threading
from pathlib import Path
from config import TASK_NAME_LIST
from state_manager import clear_task_start_time, get_task_start_time
# pathlib glob pattern selecting the per-dataset metadata JSON files that
# UserManager scans (non-recursively) under the metadata root.
METADATA_FILE_GLOB: str = "record_dataset_*_metadata.json"
class UserManager:
    """Assign random (env, episode) tasks to sessions from a metadata pool.

    The pool maps each env id to the sorted episode indices found in the
    metadata JSON files under the metadata root. Per-session progress lives
    only in memory (``session_progress``) and is guarded by ``self.lock``.
    """

    def __init__(self):
        self.base_dir = Path(__file__).resolve().parent
        # threading.Lock is NOT reentrant: public methods take it exactly once
        # and use the caller-holds-lock helpers below instead of calling each
        # other (calling get_session_status while holding it would deadlock).
        self.lock = threading.Lock()
        self.env_to_episodes = self._load_env_episode_pool()
        self.env_choices = self._build_env_choices()
        # Session-local progress only (no disk persistence): uid -> progress dict.
        self.session_progress = {}

    def _resolve_metadata_root(self) -> Path:
        """Return the directory holding the metadata JSON files.

        ``ROBOMME_METADATA_ROOT`` (when set and non-empty) overrides the default
        ``<parent of this file's directory>/src/robomme/env_metadata/train``.
        """
        env_root = os.environ.get("ROBOMME_METADATA_ROOT")
        if env_root:
            return Path(env_root)
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    def _load_env_episode_pool(self):
        """Scan the metadata root and build ``{env_id: sorted episode indices}``.

        Every file matching ``METADATA_FILE_GLOB`` is parsed as JSON. Each record
        contributes ``int(record["episode"])`` under ``record["task"]``, falling
        back to the file-level ``"env_id"``. Unreadable files, files whose
        structure is not ``{"records": [ {...}, ... ]}``, and records without a
        usable env id or integer episode are skipped, so one bad file cannot
        abort construction (and with it, import of this module).

        Returns:
            dict mapping env id -> sorted list of unique episode indices; empty
            when the metadata root does not exist.
        """
        env_to_episode_set = {}
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.exists():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}
        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue
            if not isinstance(payload, dict):
                print(f"Warning: metadata file {metadata_path} is not a JSON object; skipping")
                continue
            records = payload.get("records") or []  # tolerate missing / null
            if not isinstance(records, list):
                print(f"Warning: 'records' in metadata file {metadata_path} is not a list; skipping")
                continue
            fallback_env = str(payload.get("env_id") or "").strip()
            for record in records:
                if not isinstance(record, dict):
                    continue
                env_id = str(record.get("task") or fallback_env).strip()
                episode = record.get("episode")
                if not env_id or episode is None:
                    continue
                try:
                    episode_idx = int(episode)
                except (TypeError, ValueError, OverflowError):
                    # OverflowError: json.loads accepts "Infinity" by default.
                    continue
                env_to_episode_set.setdefault(env_id, set()).add(episode_idx)
        # Sets are only created when an episode is added, so none are empty.
        env_to_episodes = {
            env_id: sorted(episodes) for env_id, episodes in env_to_episode_set.items()
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _build_env_choices(self) -> list:
        """Return pool env ids: ``TASK_NAME_LIST`` order first, extras sorted after."""
        available_envs = set(self.env_to_episodes)
        ordered_choices = [env_id for env_id in TASK_NAME_LIST if env_id in available_envs]
        remaining_choices = sorted(available_envs - set(ordered_choices))
        return ordered_choices + remaining_choices

    @staticmethod
    def _has_current_task(progress) -> bool:
        """True when a progress dict has both a current env and episode set."""
        return (
            progress.get("current_env_id") is not None
            and progress.get("current_episode_idx") is not None
        )

    def _ensure_session_entry(self, uid) -> None:
        """Create an empty progress record for ``uid`` if missing (caller holds the lock)."""
        if uid not in self.session_progress:
            self.session_progress[uid] = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }

    def _set_current_random_task(self, uid, preferred_env=None) -> bool:
        """Assign a random episode to ``uid`` (caller holds the lock).

        Uses ``preferred_env`` when it is in the pool, otherwise a random env
        from ``env_choices``. Returns False when the pool is empty.
        """
        if not self.env_choices:
            return False
        self._ensure_session_entry(uid)
        env_id = preferred_env if preferred_env in self.env_to_episodes else random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id, [])
        if not episodes:
            return False
        episode_idx = int(random.choice(episodes))
        self.session_progress[uid]["current_env_id"] = env_id
        self.session_progress[uid]["current_episode_idx"] = episode_idx
        return True

    def _status_locked(self, uid):
        """Build the status dict for ``uid`` (caller holds the lock).

        Creates the session if missing and, when it has no current task and the
        pool is non-empty, assigns a random one first.
        """
        self._ensure_session_entry(uid)
        progress = self.session_progress[uid]
        if not self._has_current_task(progress) and self.env_choices:
            self._set_current_random_task(uid)
        current_task = None
        if self._has_current_task(progress):
            current_task = {
                "env_id": progress["current_env_id"],
                "episode_idx": int(progress["current_episode_idx"]),
            }
        completed_count = int(progress.get("completed_count", 0))
        return {
            "uid": uid,
            "total_tasks": len(self.env_choices),  # compatibility only
            "current_index": completed_count,  # compatibility only
            "completed_count": completed_count,
            "current_task": current_task,
            "is_done_all": False,
            "tasks": [],  # compatibility only
            "env_choices": list(self.env_choices),
        }

    def init_session(self, uid):
        """Ensure ``uid`` has a session with a current task.

        Returns:
            ``(ok, message, status)``; ``status`` is None when ``ok`` is False.
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None
        with self.lock:
            self._ensure_session_entry(uid)
            if not self._has_current_task(self.session_progress[uid]):
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None
            # Snapshot under the same lock so a concurrent cleanup_session
            # cannot slip in and have the session silently re-created.
            return True, "Session initialized", self._status_locked(uid)

    def get_session_status(self, uid):
        """Return the status dict for ``uid`` (creating/assigning if needed), or None for an empty uid."""
        if not uid:
            return None
        with self.lock:
            return self._status_locked(uid)

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Increment ``uid``'s completed count and return the updated status.

        When both ``env_id`` and ``episode_idx`` are given, the task's recorded
        start time is cleared via ``state_manager``. Extra kwargs are ignored.
        """
        if not uid:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            progress["completed_count"] = int(progress["completed_count"]) + 1
            status = self._status_locked(uid)
        if env_id is not None and episode_idx is not None:
            # Run outside self.lock so state_manager code never executes under it.
            # NOTE(review): the looked-up start time is unused; the call is kept
            # in case it has side effects in state_manager — confirm and drop.
            _ = get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)
        return status

    def switch_env_and_random_episode(self, uid, env_id):
        """Move ``uid`` to a random episode of ``env_id``; None if uid empty or env unknown."""
        if not uid or env_id not in self.env_to_episodes:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None
            return self._status_locked(uid)

    def next_episode_same_env(self, uid):
        """Draw a new random episode in ``uid``'s current env (any env if it has none)."""
        if not uid:
            return None
        with self.lock:
            self._ensure_session_entry(uid)
            current_env = self.session_progress[uid].get("current_env_id")
            # _set_current_random_task already falls back to a random env when
            # current_env is None or no longer in the pool.
            if not self._set_current_random_task(uid, preferred_env=current_env):
                return None
            return self._status_locked(uid)

    def cleanup_session(self, uid):
        """Forget all in-memory progress for ``uid`` (no-op for an empty uid)."""
        if not uid:
            return
        with self.lock:
            self.session_progress.pop(uid, None)
# Module-level instance built at import time: constructing it scans the
# metadata root (printing a summary or warning) before any session exists.
user_manager: UserManager = UserManager()