File size: 7,028 Bytes
06c11b0
 
 
 
 
 
d10d370
06c11b0
 
 
 
 
 
 
 
 
 
 
 
d10d370
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10d370
 
 
 
 
 
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a365309
 
 
 
 
 
 
06c11b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import json
import os
import random
import threading
from pathlib import Path

from config import TASK_NAME_LIST
from state_manager import clear_task_start_time, get_task_start_time


# pathlib glob pattern used by UserManager._load_env_episode_pool to find the
# metadata JSON files under the metadata root.
METADATA_FILE_GLOB = "record_dataset_*_metadata.json"


class UserManager:
    """Assign random ``(env_id, episode_idx)`` tasks to in-memory sessions.

    The environment/episode pool is read once, at construction time, from the
    metadata files matching ``METADATA_FILE_GLOB`` under the metadata root.
    Per-session progress lives only in ``self.session_progress`` and every
    access to it happens while holding ``self.lock``.
    """

    def __init__(self):
        self.base_dir = Path(__file__).resolve().parent
        # Guards self.session_progress. threading.Lock is not re-entrant, so the
        # public methods release it before calling get_session_status().
        self.lock = threading.Lock()

        # env_id -> sorted list of unique episode indices found in metadata.
        self.env_to_episodes = self._load_env_episode_pool()
        # Env ids in TASK_NAME_LIST order first, then any extras alphabetically.
        self.env_choices = self._build_env_choices()

        # Session-local progress only (no disk persistence)
        self.session_progress = {}

    def _resolve_metadata_root(self) -> Path:
        """Return the metadata directory.

        ``$ROBOMME_METADATA_ROOT`` wins when set and non-empty; otherwise the
        path ``<base_dir>/../src/robomme/env_metadata/train`` is used.
        """
        env_root = os.environ.get("ROBOMME_METADATA_ROOT")
        if env_root:
            return Path(env_root)
        return self.base_dir.parent / "src" / "robomme" / "env_metadata" / "train"

    def _load_env_episode_pool(self):
        """Collect the episode indices available for each environment.

        Every file matching ``METADATA_FILE_GLOB`` under the metadata root is
        read in sorted order.  A file is expected to hold a JSON object with an
        optional top-level ``env_id`` and a ``records`` list; each record adds
        ``int(record["episode"])`` to the env named by ``record["task"]``,
        falling back to the file's ``env_id``.

        Unreadable files, non-object payloads, non-list ``records`` and
        malformed records are skipped instead of raising, because this runs
        from ``__init__`` and a single bad file must not break construction.

        Returns:
            dict mapping env_id to a sorted list of unique episode indices;
            empty when the metadata root does not exist.
        """
        env_to_episode_set = {}
        metadata_root = self._resolve_metadata_root()
        if not metadata_root.exists():
            print(f"Warning: metadata root not found: {metadata_root}")
            return {}

        for metadata_path in sorted(metadata_root.glob(METADATA_FILE_GLOB)):
            try:
                payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Warning: failed to read metadata file {metadata_path}: {exc}")
                continue

            if not isinstance(payload, dict):
                print(f"Warning: metadata file {metadata_path} is not a JSON object; skipping")
                continue

            records = payload.get("records")
            if records is None:
                # Missing or null "records": nothing to contribute.
                continue
            if not isinstance(records, list):
                print(f"Warning: 'records' in metadata file {metadata_path} is not a list; skipping")
                continue

            fallback_env = str(payload.get("env_id") or "").strip()
            for record in records:
                if not isinstance(record, dict):
                    continue
                env_id = str(record.get("task") or fallback_env or "").strip()
                episode = record.get("episode")
                if not env_id or episode is None:
                    continue
                try:
                    episode_idx = int(episode)
                except (TypeError, ValueError, OverflowError):
                    # OverflowError: json.loads accepts Infinity, int() rejects it.
                    continue
                env_to_episode_set.setdefault(env_id, set()).add(episode_idx)

        env_to_episodes = {
            env_id: sorted(episodes)
            for env_id, episodes in env_to_episode_set.items()
            if episodes
        }
        print(f"Loaded random env pool: {len(env_to_episodes)} envs from metadata root {metadata_root}")
        return env_to_episodes

    def _build_env_choices(self):
        """Return the selectable env ids, each listed exactly once.

        Envs named in ``TASK_NAME_LIST`` come first in that list's order,
        followed by any other loaded envs in sorted order.
        """
        available_envs = set(self.env_to_episodes)
        # dict.fromkeys drops repeats in TASK_NAME_LIST while keeping order, so
        # no env is listed (and thus randomly drawn) more than once.
        ordered_choices = list(
            dict.fromkeys(env_id for env_id in TASK_NAME_LIST if env_id in available_envs)
        )
        remaining_choices = sorted(available_envs.difference(ordered_choices))
        return ordered_choices + remaining_choices

    def _ensure_session_entry(self, uid):
        """Create an empty progress record for ``uid`` if none exists.

        Callers in this class hold ``self.lock`` when calling this.
        """
        if uid not in self.session_progress:
            self.session_progress[uid] = {
                "completed_count": 0,
                "current_env_id": None,
                "current_episode_idx": None,
            }

    def _set_current_random_task(self, uid, preferred_env=None):
        """Pick a random episode and store it as ``uid``'s current task.

        Uses ``preferred_env`` when it is in the pool, otherwise a random env
        from ``self.env_choices``.  Callers in this class hold ``self.lock``.

        Returns:
            True when a task was assigned, False when no env/episode is available.
        """
        if not self.env_choices:
            return False
        self._ensure_session_entry(uid)

        env_id = preferred_env if preferred_env in self.env_to_episodes else random.choice(self.env_choices)
        episodes = self.env_to_episodes.get(env_id, [])
        if not episodes:
            return False

        # NOTE(review): sampling is with replacement, so the newly drawn
        # episode may equal the previous one.
        episode_idx = int(random.choice(episodes))
        self.session_progress[uid]["current_env_id"] = env_id
        self.session_progress[uid]["current_episode_idx"] = episode_idx
        return True

    def init_session(self, uid):
        """Ensure ``uid`` has a session with a current task assigned.

        Returns:
            ``(ok, message, status)`` where ``status`` is the
            ``get_session_status`` dict on success and None on failure.
        """
        if not uid:
            return False, "Session uid cannot be empty", None
        if not self.env_choices:
            return False, "No available environments found in metadata.", None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if progress.get("current_env_id") is None or progress.get("current_episode_idx") is None:
                if not self._set_current_random_task(uid):
                    return False, "Failed to assign random task from metadata.", None

        # Called after releasing the (non-re-entrant) lock.
        return True, "Session initialized", self.get_session_status(uid)

    def get_session_status(self, uid):
        """Return a status dict for ``uid`` (None when ``uid`` is falsy).

        Creates the session entry, and assigns a random task if none is set,
        as a side effect.  Several keys are kept only for compatibility with
        older callers; ``is_done_all`` is always False.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            progress = self.session_progress[uid]
            if (
                (progress.get("current_env_id") is None or progress.get("current_episode_idx") is None)
                and self.env_choices
            ):
                self._set_current_random_task(uid)
                progress = self.session_progress[uid]

            current_task = None
            if progress.get("current_env_id") is not None and progress.get("current_episode_idx") is not None:
                current_task = {
                    "env_id": progress["current_env_id"],
                    "episode_idx": int(progress["current_episode_idx"]),
                }

            completed_count = int(progress.get("completed_count", 0))

        return {
            "uid": uid,
            "total_tasks": len(self.env_choices),  # compatibility only
            "current_index": completed_count,  # compatibility only
            "completed_count": completed_count,
            "current_task": current_task,
            "is_done_all": False,
            "tasks": [],  # compatibility only
            "env_choices": list(self.env_choices),
        }

    def complete_current_task(self, uid, env_id=None, episode_idx=None, **_kwargs):
        """Increment ``uid``'s completed count and return the updated status.

        When both ``env_id`` and ``episode_idx`` are given, the stored task
        start time for that task is cleared via ``state_manager``.  Extra
        keyword arguments are accepted and ignored.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            self.session_progress[uid]["completed_count"] = int(self.session_progress[uid]["completed_count"]) + 1

        if env_id is not None and episode_idx is not None:
            # NOTE(review): the fetched start time is discarded; confirm whether
            # this lookup is still needed before clearing.
            _ = get_task_start_time(uid, env_id, episode_idx)
            clear_task_start_time(uid, env_id, episode_idx)

        return self.get_session_status(uid)

    def switch_env_and_random_episode(self, uid, env_id):
        """Move ``uid`` to a random episode of ``env_id``.

        Returns the new status, or None when ``uid`` is falsy or ``env_id`` is
        not in the pool.
        """
        if not uid or env_id not in self.env_to_episodes:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            if not self._set_current_random_task(uid, preferred_env=env_id):
                return None

        return self.get_session_status(uid)

    def next_episode_same_env(self, uid):
        """Draw a new random episode from ``uid``'s current env.

        Falls back to a random env when the session has no current env or it
        is no longer in the pool.  Returns the new status, or None on failure.
        """
        if not uid:
            return None

        with self.lock:
            self._ensure_session_entry(uid)
            current_env = self.session_progress[uid].get("current_env_id")
            # _set_current_random_task already picks a random env when
            # current_env is None or missing from the pool.
            if not self._set_current_random_task(uid, preferred_env=current_env):
                return None

        return self.get_session_status(uid)

    def cleanup_session(self, uid):
        """Forget all in-memory progress for ``uid`` (no-op when falsy/unknown)."""
        if not uid:
            return

        with self.lock:
            self.session_progress.pop(uid, None)


# Shared module-level instance. Constructing it reads the metadata files once,
# at import time, to build the environment/episode pool.
user_manager = UserManager()