avlukas committed on
Commit 5bc5483 · 1 Parent(s): d40203a

Add EpisodeTrace import and enhance travel issue checks

- Imported EpisodeTrace in lifeops_env.py.
- Added last_task_id_progressed to LifeOpsState for tracking task progress.
- Updated travel_issues function in reward.py to accept start_location for improved travel feasibility checks.
- Modified action selection logic to consider overlaps and travel issues more effectively, including home location in calculations.

ARCHITECTURE_REVIEW.md ADDED
@@ -0,0 +1,230 @@
1
+ # LifeOps Architecture & Logic Review
2
+
3
+ ## Executive Summary
4
+
5
+ The LifeOps environment is well-structured and mostly correct. Several bugs, edge cases, and design inconsistencies were identified. The most critical issues are: (1) baseline agent creates double-booked focus blocks, (2) dead reward code for conflict resolution, (3) no travel feasibility check for the first event of the day.
6
+
7
+ ---
8
+
9
+ ## 1. Environment Logic
10
+
11
+ ### 1.1 `reset()`
12
+
13
+ **Status:** Generally correct.
14
+
15
+ - Loads scenario, persona, calendar, tasks, pending requests, travel times.
16
+ - `max_steps = max(5, len(pending) + 5)` is reasonable.
17
+ - **Edge case:** `reset("invalid_id")` raises `KeyError` — consider catching and raising a clearer error.
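A clearer error for an unknown scenario id could look like the sketch below. It is illustrative only: `SCENARIOS` and `load_scenario` are hypothetical stand-ins, since the real lookup in the scenario generator is not shown in this review.

```python
from typing import Any, Dict

# Hypothetical registry; the real scenarios come from the scenario generator.
SCENARIOS: Dict[str, Dict[str, Any]] = {"demo_day": {"pending_requests": []}}


def load_scenario(scenario_id: str) -> Dict[str, Any]:
    """Look up a scenario, replacing a bare KeyError with an actionable message."""
    try:
        return SCENARIOS[scenario_id]
    except KeyError:
        known = ", ".join(sorted(SCENARIOS))
        # `from None` hides the original KeyError so the traceback stays short.
        raise ValueError(
            f"Unknown scenario_id {scenario_id!r}; known ids: {known}"
        ) from None
```

Listing the known ids in the message makes a typo in the id easy to spot from the traceback alone.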
18
+
19
+ ### 1.2 `step()`
20
+
21
+ **Status:** Correct flow.
22
+
23
+ - Validates action against `valid_actions()` via `_action_key`.
24
+ - Applies request or focus action, increments step count, computes reward and done.
25
+
26
+ **Potential issue:** `_action_key` does not include all fields that distinguish actions. For `block_focus_time`, two actions with same `(new_start_min, duration_min)` but different `new_end_min` (one None, one computed) could theoretically collide — in practice `new_end_min` is None for focus blocks, so this is fine.
27
+
28
+ ### 1.3 State Transitions
29
+
30
+ **Status:** Correct.
31
+
32
+ - Request actions: accept/reschedule add to calendar; reject/propose do not; all pop the current request.
33
+ - Focus actions: add focus event, progress highest-priority unfinished task.
34
+
35
+ ### 1.4 Termination Conditions (`_is_done`)
36
+
37
+ **Status:** Correct.
38
+
39
+ - Done when: `step_count >= max_steps` OR (no pending requests AND all tasks complete).
40
+ - **Edge case:** If `valid_actions()` returns empty (e.g., hypothetical scenario with no request and no unfinished tasks but `_is_done` False), the demo runner would crash on `valid[0]`. Current scenarios do not hit this.
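The termination rule and the empty-action edge case can be sketched together. This is a minimal illustration, not the environment's code: all three helpers are hypothetical, and tasks are assumed to be dicts with a `remaining_minutes` field.

```python
from typing import Any, Dict, List, Optional


def max_steps_for(pending_count: int) -> int:
    """Step budget quoted in section 1.1: max(5, len(pending) + 5)."""
    return max(5, pending_count + 5)


def is_done_sketch(
    step_count: int,
    max_steps: int,
    pending: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
) -> bool:
    """Done when the step budget is spent, or nothing is left to handle."""
    if step_count >= max_steps:
        return True
    all_tasks_complete = all(int(t["remaining_minutes"]) <= 0 for t in tasks)
    return not pending and all_tasks_complete


def first_action_or_none(valid: List[Any]) -> Optional[Any]:
    """Guard for the demo runner: avoid IndexError when no action is valid."""
    return valid[0] if valid else None
```

A runner that calls `first_action_or_none` can end the episode cleanly when it gets `None` instead of crashing on `valid[0]`.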
41
+
42
+ ---
43
+
44
+ ## 2. Reward Calculation
45
+
46
+ ### 2.1 Correctness
47
+
48
+ **Overlap penalty:** Correct. `-5.0 * len(next_overlaps)`.
49
+
50
+ **Travel penalty:** Correct. `-4.0 * len(issues) * travel_aversion_weight`.
51
+
52
+ **Rejected important penalty:** Correct. `-4.0` when rejecting importance ≥ 3.
53
+
54
+ **Preference penalty:** Correct. Applied to accept/reschedule/propose for meeting-like events.
55
+
56
+ **Focus reward:** Correct. `(1.0 + 0.02 * progress) * focus_time_weight`.
57
+
58
+ **Wasted focus penalty:** `-0.5` when `block_focus_time` with `progress == 0`. In practice this is rare because focus blocks are only generated when `has_unfinished` is true, and progress is always made when there is an unfinished task. Defensive.
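The components listed above combine additively. The sketch below is a worked illustration only: the function name, argument names, and breakdown keys are hypothetical, the preference penalty is omitted because its formula is not given here, and treating the focus reward and the wasted-focus penalty as mutually exclusive is an assumption.

```python
from typing import Dict


def sketch_reward(
    action_type: str,
    n_overlaps: int,
    n_travel_issues: int,
    importance: int = 0,
    progress: int = 0,
    travel_aversion_weight: float = 1.0,
    focus_time_weight: float = 1.0,
) -> Dict[str, float]:
    """Combine the penalties and rewards quoted in section 2.1."""
    breakdown: Dict[str, float] = {
        "overlap_penalty": -5.0 * n_overlaps,
        "travel_penalty": -4.0 * n_travel_issues * travel_aversion_weight,
    }
    if action_type == "reject_event" and importance >= 3:
        breakdown["rejected_important_penalty"] = -4.0
    if action_type == "block_focus_time":
        if progress > 0:
            breakdown["focus_reward"] = (1.0 + 0.02 * progress) * focus_time_weight
        else:
            # Assumption: a focus block with no progress earns only the penalty.
            breakdown["wasted_focus_penalty"] = -0.5
    breakdown["total"] = sum(breakdown.values())
    return breakdown


# A 60-minute focus block (weight 1.5) that creates one overlap:
# -5.0 + (1.0 + 0.02 * 60) * 1.5 = -5.0 + 3.3 = -1.7
example = sketch_reward("block_focus_time", 1, 0, progress=60, focus_time_weight=1.5)
```

The example shows why overlaps dominate: under these numbers, a full hour of focus progress does not offset a single double booking.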
59
+
60
+ ### 2.2 Dead Code: `conflict_resolved_bonus`
61
+
62
+ **Bug:** The reward includes:
63
+
64
+ ```python
65
+ if prev_overlaps and len(next_overlaps) < len(prev_overlaps):
66
+ reward += 3.0
67
+ breakdown["conflict_resolved_bonus"] = 3.0
68
+ ```
69
+
70
+ The calendar is **append-only** — events are never removed. Therefore `next_overlaps` can never have fewer pairs than `prev_overlaps`. This branch is **never executed**.
71
+
72
+ **Fix:** Remove this block, or redesign if you later add event-removal/cancellation.
73
+
74
+ ### 2.3 Missing Penalties / Rewards
75
+
76
+ - **No reward for accepting important requests** — only penalty for rejecting. Consider a small positive reward for accepting high-importance requests.
77
+ - **No explicit penalty for `propose_new_time` that suggests an infeasible time** — the preference penalty applies, but overlap/travel of the proposed time are not penalized (since it is not added to the calendar). This may be intentional (proposal quality is soft).
78
+
79
+ ### 2.4 Unintended Reward Loops
80
+
81
+ - None identified. The reward structure is straightforward.
82
+
83
+ ---
84
+
85
+ ## 3. State Consistency
86
+
87
+ ### 3.1 Calendar Updates
88
+
89
+ **Status:** Correct. Events are appended; no removal or modification.
90
+
91
+ ### 3.2 Task Tracking
92
+
93
+ **Status:** Correct. `remaining_minutes` is decremented in-place for the highest-priority unfinished task during focus blocks.
94
+
95
+ ### 3.3 Message/Request Handling
96
+
97
+ **Status:** Correct. FIFO via `pending_requests.pop(0)`. `current_request` is always the first pending.
98
+
99
+ ---
100
+
101
+ ## 4. Travel Feasibility
102
+
103
+ ### 4.1 Detection of Impossible Travel
104
+
105
+ **Status:** Correct for consecutive events. `travel_issues()` sorts by start time and checks each pair.
106
+
107
+ ### 4.2 Missing: Travel to First Event
108
+
109
+ **Bug:** `travel_issues()` only checks `prev → next` for consecutive events. It never checks whether the user can reach the **first** event of the day. The model assumes the user is already at the location of the first event at its start time.
110
+
111
+ **Example:** First event at 8:00 at Office, persona at Home, travel 25 min. User would need to leave by 7:35. This is not validated.
112
+
113
+ **Fix:** Add an optional `start_location` (e.g., Home) to the persona/state and check travel from that location to the first event.
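A first-leg check along these lines is sketched below. It does not reproduce `reward.py`: the tuple-keyed `travel_times`, the zero cost for staying in one place, the `"start"` label, and the `day_start_min` parameter are all assumptions made for illustration.

```python
from typing import Any, Dict, List, Optional, Tuple

DEFAULT_TRAVEL_MIN = 30  # assumed fallback for location pairs missing from the table


def travel_minutes(travel_times: Dict[Tuple[str, str], int], origin: str, dest: str) -> int:
    """Minutes needed between two locations; staying put is assumed to cost 0."""
    if origin == dest:
        return 0
    return travel_times.get((origin, dest), DEFAULT_TRAVEL_MIN)


def travel_issues_sketch(
    events: List[Dict[str, Any]],
    travel_times: Dict[Tuple[str, str], int],
    start_location: Optional[str] = None,
    day_start_min: int = 0,
) -> List[Tuple[str, str, int]]:
    """Return (from_id, to_id, minutes_needed) for every infeasible leg."""
    ordered = sorted(events, key=lambda e: (e["start_min"], e["end_min"]))
    issues: List[Tuple[str, str, int]] = []

    # First leg: from where the day starts to the first event.
    if start_location is not None and ordered:
        first = ordered[0]
        need = travel_minutes(travel_times, start_location, first["location"])
        if need and first["start_min"] - day_start_min < need:
            issues.append(("start", first["event_id"], need))

    # Remaining legs: each consecutive pair of events.
    for prev, nxt in zip(ordered, ordered[1:]):
        need = travel_minutes(travel_times, prev["location"], nxt["location"])
        if need and nxt["start_min"] - prev["end_min"] < need:
            issues.append((prev["event_id"], nxt["event_id"], need))
    return issues
```

With `day_start_min=0`, the 8:00 Office example from this section is never flagged, because 480 minutes is far more than the 25 needed. The first-leg check only bites when the day is assumed to start later, or when an event sits very close to midnight.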
114
+
115
+ ### 4.3 Overlap Logic
116
+
117
+ **Status:** Correct. `_overlap(a_start, a_end, b_start, b_end)` uses `a_start < b_end and b_start < a_end`. Touching events (a_end == b_start) do not overlap.
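The predicate quoted above is the usual half-open interval test. A standalone copy makes the boundary behavior easy to check; the function is named `overlap` here to mark it as an illustration rather than the project's `_overlap`.

```python
def overlap(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
    """True when two positive-duration events share at least one minute."""
    return a_start < b_end and b_start < a_end


# 9:00-10:00 followed by 10:00-11:00: touching, not overlapping.
touching = overlap(540, 600, 600, 660)

# 9:00-10:00 against 9:50-11:00: ten minutes in common.
clashing = overlap(540, 600, 590, 660)
```

Because both comparisons are strict, back-to-back meetings never count as a conflict; any travel time between them is left to the travel check.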
118
+
119
+ ### 4.4 Rescheduling Edge Cases
120
+
121
+ - Reschedule/propose options use fixed deltas (-30, 30, 60). At day boundaries, `new_start` can clamp to the same value for different deltas, producing duplicate actions. This is harmless (same key).
122
+ - No validation that rescheduled time is free — overlaps are penalized by reward. Acceptable for RL.
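The duplicate options at day boundaries can be reproduced, and removed, with a small sketch. The clamping rule used here (keep the event inside 0..1440 while preserving its duration) is an assumption, as are the function and constant names; `actions.py` may clamp differently.

```python
from typing import List, Sequence, Tuple

DAY_END_MIN = 1440  # minutes in a day


def reschedule_options(
    start_min: int,
    end_min: int,
    deltas: Sequence[int] = (-30, 30, 60),
) -> List[Tuple[int, int]]:
    """Shift an event by each delta, clamp to the day, and drop duplicates."""
    duration = end_min - start_min
    seen = set()
    options: List[Tuple[int, int]] = []
    for delta in deltas:
        # Clamp so the whole event stays within [0, DAY_END_MIN].
        new_start = max(0, min(start_min + delta, DAY_END_MIN - duration))
        key = (new_start, new_start + duration)
        if key not in seen:
            seen.add(key)
            options.append(key)
    return options
```

For a 23:00-24:00 event, the +30 and +60 deltas both clamp back to 23:00, so only two distinct options remain instead of three.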
123
+
124
+ ---
125
+
126
+ ## 5. Action Handling
127
+
128
+ ### 5.1 All Actions Update State Correctly
129
+
130
+ | Action | Calendar | Pending Requests | Tasks |
131
+ |-------------------|-----------------|------------------|------------|
132
+ | accept_event | +1 event | pop | — |
133
+ | reject_event | — | pop | — |
134
+ | reschedule_event | +1 event | pop | — |
135
+ | propose_new_time | — | pop | — |
136
+ | block_focus_time | +1 focus event | — | progress |
137
+
138
+ **Status:** Correct.
139
+
140
+ ### 5.2 Invalid Action Handling
141
+
142
+ **Status:** Correct. `step()` raises `ValueError` if action key is not in `valid_keys`.
143
+
144
+ ### 5.3 Action Constraints
145
+
146
+ **Issue:** `generate_valid_actions()` does **not** filter out:
147
+
148
+ - Focus blocks that overlap with existing calendar events.
149
+ - Reschedule/propose times that would overlap or cause travel issues.
150
+
151
+ This is acceptable for RL (agent learns from penalties) but means the baseline can choose “valid” actions that create overlaps.
152
+
153
+ ---
154
+
155
+ ## 6. Demo Runner Correctness
156
+
157
+ **Note:** There is no separate `play_episode.py`; the demo lives in `env/lifeops_env.py` under `if __name__ == "__main__"`.
158
+
159
+ ### 6.1 Reflects Real Environment Behavior
160
+
161
+ **Status:** Yes. Uses `env.reset()`, `env.observation()`, `env.valid_actions()`, `env.step()`.
162
+
163
+ ### 6.2 Trajectories Exercise Key Logic
164
+
165
+ **Status:** Partially. Tests cover accept, reject, propose, focus, overlap penalty, travel penalty. However:
166
+
167
+ **Bug:** The baseline agent **can create double-booked focus blocks**. Observed in a run:
168
+
169
+ - After handling the request, it scheduled focus blocks at 9:00, 11:00, 14:00, then **again at 9:00** and **again at 11:00**, causing overlaps.
170
+
171
+ **Cause:** `_choose_simple_action` scores focus blocks by simulating each option against the **current** calendar. Once a slot is used (e.g., 9:00), the next time it considers 9:00 vs 11:00 vs 14:00 vs 16:00, they may all overlap with existing focus blocks. When scores tie, the stable sort keeps the original order and the first option (9:00) wins, so the agent reuses an occupied slot.
172
+
173
+ **Fix:** Filter focus blocks to exclude slots that overlap with the current calendar, or improve the baseline to prefer slots with zero overlaps (and handle ties by picking a free slot).
174
+
175
+ ---
176
+
177
+ ## 7. Baseline Agent Logic
178
+
179
+ ### 7.1 Avoids Double Booking?
180
+
181
+ **No.** As above, the baseline can schedule overlapping focus blocks. For **request** actions it minimizes overlaps when choosing accept/reschedule/propose, so it tends to avoid double-booking requests. But for focus blocks it does not.
182
+
183
+ ### 7.2 Respects Travel Constraints?
184
+
185
+ **Yes.** For request actions, it scores by `(overlaps, travel_issues)` and picks the action with the fewest. For focus blocks, it also minimizes travel issues. So it prefers feasible travel.
186
+
187
+ ### 7.3 Prioritizes High-Priority Obligations?
188
+
189
+ **Partially.** It strongly prefers scheduling over rejecting (reject scores (999, 999)), so it rarely rejects important requests. But it does not explicitly prioritize by `importance`. It only minimizes overlaps and travel. For optional low-importance meetings it may still accept if that minimizes violations, instead of rejecting to free time for high-priority tasks.
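The baseline's ranking relies on lexicographic tuple comparison plus a stable sort. The sketch below uses made-up scores and a hypothetical `pick_best` helper to show how the reject score changes the outcome.

```python
from typing import List, Tuple

Score = Tuple[int, int]  # (overlap_count, travel_issue_count), smaller is better


def pick_best(scored: List[Tuple[str, Score]]) -> str:
    """Return the action name with the lexicographically smallest score.

    sorted() is stable, so tied scores keep their original list order.
    """
    return sorted(scored, key=lambda item: item[1])[0][0]


options = [
    ("accept_event", (1, 0)),
    ("reschedule_event", (0, 1)),
    ("propose_new_time", (0, 0)),
]

# Reject as a last resort: (999, 999) loses to every scheduling option.
with_sentinel = pick_best(options + [("reject_event", (999, 999))])

# Reject at (0, 0) ties with the best option; list order then decides.
reject_last = pick_best(options + [("reject_event", (0, 0))])
reject_first = pick_best([("reject_event", (0, 0))] + options)
```

Nothing in the score tuple reflects `importance`, which is why the baseline cannot trade a low-importance meeting against task time. When reject scores `(0, 0)`, whether it beats an equally clean option depends on the order in which `valid_actions()` lists them, which is not shown here.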
190
+
191
+ ---
192
+
193
+ ## 8. Summary of Issues
194
+
195
+ | Severity | Issue | Location | Fix | Status |
196
+ |----------|-------|----------|-----|--------|
197
+ | High | Baseline creates overlapping focus blocks | `lifeops_env.py` `_choose_simple_action` | Filter or re-score focus blocks to avoid already-used slots | **Fixed** – prefer non-overlapping slots; fall back to least-bad when all overlap |
198
+ | Medium | `conflict_resolved_bonus` never triggers | `reward.py` | Remove dead code or add event removal to enable it | **Fixed** – removed dead code |
199
+ | Medium | No travel check to first event of day | `reward.py` `travel_issues` | Add optional check from `start_location` to first event | **Fixed** – added `start_location` param, uses `home_location` |
200
+ | Low | `reset("bad_id")` raises raw `KeyError` | `lifeops_env.py` | Catch and re-raise with clearer message | Not applied (minor) |
201
+ | Low | Duplicate reschedule actions at boundaries | `actions.py` | Optional: deduplicate by `(new_start, new_end)` | Not applied (harmless) |
202
+ | Low | Baseline never rejects (scores reject as 999,999) | `lifeops_env.py` | Consider allowing reject when all scheduling options are bad | **Fixed** – reject now scores (0, 0) so it wins when scheduling causes issues |
203
+
204
+ ---
205
+
206
+ ## 9. Suggested Fixes (Minimal Changes)
207
+
208
+ ### Fix 1: Baseline focus block selection
209
+
210
+ In `_choose_simple_action`, when scoring focus blocks, prefer actions that result in **zero** overlaps. If every option overlaps, pick the one with the smallest overlap count and break remaining ties by total overlap duration.
211
+
212
+ A simpler approach: **filter** `focus_actions` to exclude those whose `(new_start_min, duration_min)` would overlap with any existing calendar event. Use `detect_overlaps` with a simulated calendar including the candidate focus block.
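The filtering approach can be sketched as follows. `detect_overlaps_sketch` is a stand-in that assumes the real `detect_overlaps` returns every conflicting pair in the list it is given; the event fields and the `(start, duration)` slot tuples are likewise assumptions.

```python
from itertools import combinations
from typing import Any, Dict, List, Tuple


def detect_overlaps_sketch(events: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Return the ids of every pair of events that share time."""
    pairs = []
    for a, b in combinations(events, 2):
        if a["start_min"] < b["end_min"] and b["start_min"] < a["end_min"]:
            pairs.append((a["event_id"], b["event_id"]))
    return pairs


def free_focus_slots(
    calendar: List[Dict[str, Any]],
    slots: List[Tuple[int, int]],
) -> List[Tuple[int, int]]:
    """Keep the (start, duration) slots that add no new overlap."""
    baseline = len(detect_overlaps_sketch(calendar))
    free = []
    for start, dur in slots:
        candidate = {"event_id": "focus_sim", "start_min": start, "end_min": start + dur}
        # Compare against the baseline so conflicts already on the calendar
        # do not disqualify every slot.
        if len(detect_overlaps_sketch(calendar + [candidate])) == baseline:
            free.append((start, dur))
    return free
```

Comparing against the baseline count matters once the calendar already holds a double booking. Under the same assumption about `detect_overlaps`, a check that only asks whether the simulated calendar has any overlap at all would reject every slot.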
213
+
214
+ ### Fix 2: Remove dead `conflict_resolved_bonus`
215
+
216
+ Delete or comment out lines 99–101 in `reward.py` until the environment supports event removal.
217
+
218
+ ### Fix 3: Travel to first event (optional)
219
+
220
+ Add a parameter `start_location` (default `None`) to the scenario or persona. If set, prepend a synthetic “start” event at `start_location` with `end_min=0` before the first real event, so `travel_issues` checks the first leg.
221
+
222
+ ---
223
+
224
+ ## 10. Edge Cases Not Handled
225
+
226
+ 1. **Empty valid_actions:** If both `current_request` is None and `has_unfinished` is False, `valid_actions` is empty. The demo would crash on `valid[0]`. Current scenarios avoid this.
227
+ 2. **Event at midnight (0) or end of day (1440):** Logic uses `<= 1440`; should be verified for boundary events.
228
+ 3. **Zero-duration events:** `_overlap` treats two identical zero-duration events, (100, 100) and (100, 100), as non-overlapping because `100 < 100` is false. Zero-duration events are not generated.
229
+ 4. **Multiple events at same start/end:** Sorting by `(start_min, end_min)` is deterministic; `travel_issues` order is stable.
230
+ 5. **Unknown locations in travel_times:** Default 30 minutes is conservative; no explicit handling for missing keys.
env/episode_trace.py ADDED
@@ -0,0 +1,191 @@
+ """
+ Episode tracing and structured logging for LifeOps.
+
+ Provides human-readable step-by-step logs and a timeline view for hackathon demos.
+ """
+
+ from __future__ import annotations
+
+ import copy
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+
+ def _min_to_time(m: int) -> str:
+     """Convert minutes since midnight to HH:MM."""
+     h, mm = divmod(m, 60)
+     return f"{h:02d}:{mm:02d}"
+
+
+ @dataclass
+ class StepRecord:
+     """Record of a single environment step."""
+
+     step: int
+     action: Dict[str, Any]
+     prev_calendar_count: int
+     next_calendar_count: int
+     prev_pending_count: int
+     next_pending_count: int
+     reward: float
+     breakdown: Dict[str, Any]
+     overlaps: List[tuple]
+     travel_issues: List[tuple]
+     done: bool
+
+     # State changes (compact summaries)
+     added_event: Optional[Dict[str, Any]] = None
+     handled_request: Optional[Dict[str, Any]] = None
+     task_progress: Optional[Dict[str, str]] = None  # task_id -> "X min progress"
+
+
+ @dataclass
+ class EpisodeTrace:
+     """Trace of an entire episode for logging and timeline display."""
+
+     scenario_id: str
+     persona_name: str
+     initial_calendar: List[Dict[str, Any]] = field(default_factory=list)
+     initial_tasks: List[Dict[str, Any]] = field(default_factory=list)
+     initial_pending_count: int = 0
+     steps: List[StepRecord] = field(default_factory=list)
+     total_reward: float = 0.0
+
+     def log_step(
+         self,
+         step: int,
+         action: Dict[str, Any],
+         prev_obs: Dict[str, Any],
+         next_obs: Dict[str, Any],
+         reward: float,
+         breakdown: Dict[str, Any],
+         info: Dict[str, Any],
+         done: bool,
+         last_added_event: Optional[Dict[str, Any]] = None,
+         last_handled_request: Optional[Dict[str, Any]] = None,
+         last_task_progress_minutes: int = 0,
+         task_id_progressed: Optional[str] = None,
+     ) -> None:
+         """Record one step."""
+         task_progress = None
+         if last_task_progress_minutes and task_id_progressed:
+             task_progress = {task_id_progressed: f"{last_task_progress_minutes} min"}
+
+         self.steps.append(
+             StepRecord(
+                 step=step,
+                 action=copy.deepcopy(action),
+                 prev_calendar_count=len(prev_obs.get("calendar", [])),
+                 next_calendar_count=len(next_obs.get("calendar", [])),
+                 prev_pending_count=prev_obs.get("pending_request_count", 0),
+                 next_pending_count=next_obs.get("pending_request_count", 0),
+                 reward=reward,
+                 breakdown=copy.deepcopy(breakdown),
+                 overlaps=info.get("overlaps", []),
+                 travel_issues=info.get("travel_issues", []),
+                 done=done,
+                 added_event=copy.deepcopy(last_added_event) if last_added_event else None,
+                 handled_request=copy.deepcopy(last_handled_request) if last_handled_request else None,
+                 task_progress=task_progress,
+             )
+         )
+
+     def _format_action(self, a: Dict[str, Any]) -> str:
+         at = a.get("action_type", "?")
+         if at == "block_focus_time":
+             start = a.get("new_start_min")
+             dur = a.get("duration_min")
+             return f"block_focus_time @ {_min_to_time(start or 0)} for {dur} min"
+         if at == "accept_event":
+             return f"accept_event (request_id={a.get('request_id', '?')})"
+         if at == "reject_event":
+             return f"reject_event (request_id={a.get('request_id', '?')})"
+         if at == "reschedule_event":
+             ns, ne = a.get("new_start_min"), a.get("new_end_min")
+             return f"reschedule_event → {_min_to_time(ns or 0)}–{_min_to_time(ne or 0)}"
+         if at == "propose_new_time":
+             ns, ne = a.get("new_start_min"), a.get("new_end_min")
+             return f"propose_new_time → {_min_to_time(ns or 0)}–{_min_to_time(ne or 0)}"
+         return str(a)
+
+     def _format_breakdown(self, b: Dict[str, Any]) -> str:
+         parts = []
+         for k, v in b.items():
+             if k == "total":
+                 continue
+             if isinstance(v, (int, float)) and v != 0:
+                 parts.append(f"{k}={v:+.1f}")
+         return ", ".join(parts) if parts else "(none)"
+
+     def print_step_log(self, step_record: StepRecord) -> None:
+         """Print a single step in human-readable form."""
+         s = step_record
+         print(f"\n Step {s.step}")
+         print(f" Action: {self._format_action(s.action)}")
+         print(f" Reward: {s.reward:+.2f} ({self._format_breakdown(s.breakdown)})")
+         if s.added_event:
+             e = s.added_event
+             print(f" + Added: {e.get('title', '?')} @ {_min_to_time(e.get('start_min', 0))}–{_min_to_time(e.get('end_min', 0))} ({e.get('location', '?')})")
+         if s.handled_request and s.action.get("action_type") != "block_focus_time":
+             r = s.handled_request
+             at = s.action.get("action_type", "")
+             if at == "reject_event":
+                 outcome = "rejected"
+             elif at == "propose_new_time":
+                 outcome = "proposed new time (not scheduled)"
+             else:
+                 outcome = "accepted/scheduled"
+             print(f" Request {outcome}: {r.get('title', '?')}")
+         if s.task_progress:
+             for tid, prog in s.task_progress.items():
+                 print(f" Task progress: {tid} ({prog})")
+         if s.overlaps:
+             print(f" ⚠ Overlaps: {s.overlaps}")
+         if s.travel_issues:
+             print(f" ⚠ Travel issues: {[(t[0], t[1], f'need {t[2]}min') for t in s.travel_issues]}")
+
+     def print_timeline(self, final_calendar: Optional[List[Dict[str, Any]]] = None) -> None:
+         """Print a readable timeline of the final calendar."""
+         if final_calendar is not None:
+             events = list(final_calendar)
+         else:
+             # Fallback: merge initial + all added events from steps
+             events = list(self.initial_calendar)
+             for s in self.steps:
+                 if s.added_event:
+                     events.append(s.added_event)
+
+         if not events:
+             print("\n (No events on calendar)")
+             return
+
+         ordered = sorted(events, key=lambda e: (int(e["start_min"]), int(e["end_min"])))
+         print("\n Timeline (final calendar):")
+         print(" " + "-" * 60)
+         for e in ordered:
+             start = int(e["start_min"])
+             end = int(e["end_min"])
+             title = e.get("title", e.get("event_id", "?"))
+             loc = e.get("location", "?")
+             kind = e.get("kind", "meeting")
+             print(f" {_min_to_time(start)} – {_min_to_time(end)} {title} @ {loc} [{kind}]")
+         print(" " + "-" * 60)
+
+     def print_full(self, final_calendar: Optional[List[Dict[str, Any]]] = None) -> None:
+         """Print the complete episode trace (header, steps, timeline, summary)."""
+         print("\n" + "=" * 60)
+         print("EPISODE TRACE")
+         print("=" * 60)
+         print(f"Scenario: {self.scenario_id}")
+         print(f"Persona: {self.persona_name}")
+         print(f"Initial: {len(self.initial_calendar)} events, {len(self.initial_tasks)} tasks, {self.initial_pending_count} pending requests")
+         print("-" * 60)
+
+         for s in self.steps:
+             self.print_step_log(s)
+
+         self.print_timeline(final_calendar)
+
+         print("\n" + "-" * 60)
+         print(f"Total reward: {self.total_reward:+.2f}")
+         print("=" * 60)
env/lifeops_env.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Optional, Tuple
22
  try:
23
  # Normal usage (tests / `python -m ...`) expects repo root on sys.path.
24
  from env.actions import Action, ActionType, generate_valid_actions
 
25
  from env.personas import Persona, get_personas
26
  from env.reward import compute_reward, detect_overlaps, travel_issues
27
  from env.scenario_generator import Scenario, get_scenario, list_scenario_ids, sample_scenarios
@@ -30,6 +31,7 @@ except ModuleNotFoundError:
30
  repo_root = Path(__file__).resolve().parent.parent
31
  sys.path.insert(0, str(repo_root))
32
  from env.actions import Action, ActionType, generate_valid_actions
 
33
  from env.personas import Persona, get_personas
34
  from env.reward import compute_reward, detect_overlaps, travel_issues
35
  from env.scenario_generator import Scenario, get_scenario, list_scenario_ids, sample_scenarios
@@ -75,6 +77,7 @@ class LifeOpsState:
75
  last_added_event: Optional[Dict[str, Any]] = None
76
  last_handled_request: Optional[Dict[str, Any]] = None
77
  last_task_progress_minutes: int = 0
 
78
 
79
  def current_request(self) -> Optional[Dict[str, Any]]:
80
  return self.pending_requests[0] if self.pending_requests else None
@@ -168,6 +171,7 @@ class LifeOpsEnv:
168
  self._state.last_added_event = None
169
  self._state.last_handled_request = None
170
  self._state.last_task_progress_minutes = 0
 
171
 
172
  at = str(action_dict.get("action_type"))
173
  if at in {ActionType.accept_event.value, ActionType.reject_event.value, ActionType.reschedule_event.value, ActionType.propose_new_time.value}:
@@ -195,7 +199,15 @@ class LifeOpsEnv:
195
  info: Dict[str, Any] = {
196
  "reward_breakdown": breakdown,
197
  "overlaps": detect_overlaps(next_obs.get("calendar", [])),
198
- "travel_issues": travel_issues(next_obs.get("calendar", []), next_obs.get("travel_times", {})),
199
  }
200
  return next_obs, float(reward), bool(done), info
201
 
@@ -260,8 +272,10 @@ class LifeOpsEnv:
260
  progress = min(duration, int(t["remaining_minutes"]))
261
  t["remaining_minutes"] = int(t["remaining_minutes"]) - progress
262
  self._state.last_task_progress_minutes = int(progress)
 
263
  else:
264
  self._state.last_task_progress_minutes = 0
 
265
 
266
  def _is_done(self) -> bool:
267
  if self._state.step_count >= self._state.max_steps:
@@ -273,26 +287,44 @@ class LifeOpsEnv:
273
  return True
274
 
275
 
276
  def _choose_simple_action(env: LifeOpsEnv) -> Action:
277
  """
278
  Tiny heuristic policy for manual running:
279
- - If accept would cause overlap/travel issues, try reschedule actions next.
280
  - Otherwise accept the request.
281
- - If no request, block focus time.
282
  """
283
 
284
  valid = env.valid_actions()
285
  obs = env.observation()
286
  req = obs.get("current_request")
287
  if req is None:
288
  focus_actions = [a for a in valid if a.action_type == ActionType.block_focus_time]
289
  if not focus_actions:
290
  return valid[0]
291
 
292
  def focus_score(a: Action) -> Tuple[int, int]:
293
  start = int(a.new_start_min or 0)
294
  dur = int(a.duration_min or 0)
295
- sim = list(obs.get("calendar", [])) + [
296
  {
297
  "event_id": "focus_sim",
298
  "start_min": start,
@@ -300,43 +332,58 @@ def _choose_simple_action(env: LifeOpsEnv) -> Action:
300
  "location": obs["persona"].get("primary_work_location", "Home"),
301
  }
302
  ]
303
- return (len(detect_overlaps(sim)), len(travel_issues(sim, obs.get("travel_times", {}))))
304
 
305
- focus_actions.sort(key=focus_score)
306
- return focus_actions[0]
307
 
308
  # Pick the request-handling action that minimizes feasibility violations.
 
309
  def score_action(a: Action) -> Tuple[int, int]:
310
  # (overlap_count, travel_issue_count) — smaller is better
311
  if a.action_type == ActionType.reject_event:
312
- return (999, 999) # prefer scheduling over rejecting (manual runner)
313
  if a.action_type in {ActionType.accept_event, ActionType.reschedule_event, ActionType.propose_new_time}:
314
  added = dict(req)
315
  if a.action_type in {ActionType.reschedule_event, ActionType.propose_new_time}:
316
  added["start_min"] = int(a.new_start_min or added["start_min"])
317
  added["end_min"] = int(a.new_end_min or added["end_min"])
318
- # NOTE: propose_new_time does not actually schedule; don't add it to sim calendar.
319
- sim_events = list(obs.get("calendar", [])) + ([] if a.action_type == ActionType.propose_new_time else [added])
320
- return (len(detect_overlaps(sim_events)), len(travel_issues(sim_events, obs.get("travel_times", {}))))
 
321
  return (500, 500)
322
 
323
- candidates = [a for a in valid if a.action_type in {ActionType.accept_event, ActionType.reschedule_event, ActionType.propose_new_time, ActionType.reject_event}]
324
  candidates.sort(key=score_action)
325
  return candidates[0] if candidates else valid[0]
326
 
327
 
328
  if __name__ == "__main__":
329
- # Simple manual episode runner: `python env/lifeops_env.py`
330
  env = LifeOpsEnv(seed=7)
331
  obs = env.reset()
332
- print("Scenario:", obs["scenario_id"])
333
- print("Persona:", obs["persona"]["name"])
334
 
335
  done = False
336
  total_reward = 0.0
 
 
337
  while not done:
338
- obs = env.observation()
339
- req = obs.get("current_request")
340
  if req is not None:
341
  print(f"\nCurrent request: {req['title']} ({req['start_min']}..{req['end_min']}) @ {req['location']}")
342
  else:
@@ -344,13 +391,30 @@ if __name__ == "__main__":
344
 
345
  action = _choose_simple_action(env)
346
  next_obs, reward, done, info = env.step(action)
 
347
  total_reward += reward
348
- print("Action:", action.to_dict())
349
- print("Reward:", reward)
350
  if info.get("overlaps"):
351
- print("Overlaps:", info["overlaps"])
352
  if info.get("travel_issues"):
353
- print("Travel issues:", info["travel_issues"])
354
 
355
- print("\nEpisode done. Total reward:", total_reward)
 
356
 
 
22
  try:
23
  # Normal usage (tests / `python -m ...`) expects repo root on sys.path.
24
  from env.actions import Action, ActionType, generate_valid_actions
25
+ from env.episode_trace import EpisodeTrace
26
  from env.personas import Persona, get_personas
27
  from env.reward import compute_reward, detect_overlaps, travel_issues
28
  from env.scenario_generator import Scenario, get_scenario, list_scenario_ids, sample_scenarios
 
31
  repo_root = Path(__file__).resolve().parent.parent
32
  sys.path.insert(0, str(repo_root))
33
  from env.actions import Action, ActionType, generate_valid_actions
34
+ from env.episode_trace import EpisodeTrace
35
  from env.personas import Persona, get_personas
36
  from env.reward import compute_reward, detect_overlaps, travel_issues
37
  from env.scenario_generator import Scenario, get_scenario, list_scenario_ids, sample_scenarios
 
77
  last_added_event: Optional[Dict[str, Any]] = None
78
  last_handled_request: Optional[Dict[str, Any]] = None
79
  last_task_progress_minutes: int = 0
80
+ last_task_id_progressed: Optional[str] = None
81
 
82
  def current_request(self) -> Optional[Dict[str, Any]]:
83
  return self.pending_requests[0] if self.pending_requests else None
 
171
  self._state.last_added_event = None
172
  self._state.last_handled_request = None
173
  self._state.last_task_progress_minutes = 0
174
+ self._state.last_task_id_progressed = None
175
 
176
  at = str(action_dict.get("action_type"))
177
  if at in {ActionType.accept_event.value, ActionType.reject_event.value, ActionType.reschedule_event.value, ActionType.propose_new_time.value}:
 
199
  info: Dict[str, Any] = {
200
  "reward_breakdown": breakdown,
201
  "overlaps": detect_overlaps(next_obs.get("calendar", [])),
202
+ "travel_issues": travel_issues(
203
+ next_obs.get("calendar", []),
204
+ next_obs.get("travel_times", {}),
205
+ start_location=next_obs.get("persona", {}).get("home_location"),
206
+ ),
207
+ "last_added_event": copy.deepcopy(self._state.last_added_event),
208
+ "last_handled_request": copy.deepcopy(self._state.last_handled_request),
209
+ "last_task_progress_minutes": int(self._state.last_task_progress_minutes),
210
+ "last_task_id_progressed": self._state.last_task_id_progressed,
211
  }
212
  return next_obs, float(reward), bool(done), info
213
 
 
272
  progress = min(duration, int(t["remaining_minutes"]))
273
  t["remaining_minutes"] = int(t["remaining_minutes"]) - progress
274
  self._state.last_task_progress_minutes = int(progress)
275
+ self._state.last_task_id_progressed = str(t.get("task_id", "?"))
276
  else:
277
  self._state.last_task_progress_minutes = 0
278
+ self._state.last_task_id_progressed = None
279
 
280
  def _is_done(self) -> bool:
281
  if self._state.step_count >= self._state.max_steps:
 
287
  return True
288
 
289
 
290
+ def _focus_overlaps_calendar(a: Action, calendar: List[Dict[str, Any]]) -> bool:
291
+ """True if adding this focus block would overlap with existing calendar events."""
292
+ start = int(a.new_start_min or 0)
293
+ dur = int(a.duration_min or 0)
294
+ sim = list(calendar) + [
295
+ {"event_id": "_", "start_min": start, "end_min": start + dur, "location": "x"},
296
+ ]
297
+ return len(detect_overlaps(sim)) > 0
298
+
299
+
300
  def _choose_simple_action(env: LifeOpsEnv) -> Action:
301
  """
302
  Tiny heuristic policy for manual running:
303
+ - If accept would cause overlap/travel issues, try reschedule/propose, or reject.
304
  - Otherwise accept the request.
305
+ - If no request, block focus time (prefer non-overlapping slots).
306
  """
307
 
308
  valid = env.valid_actions()
309
  obs = env.observation()
310
  req = obs.get("current_request")
311
+ calendar = obs.get("calendar", [])
312
+ travel_times = obs.get("travel_times", {})
313
+ home = obs.get("persona", {}).get("home_location")
314
+
315
  if req is None:
316
  focus_actions = [a for a in valid if a.action_type == ActionType.block_focus_time]
317
  if not focus_actions:
318
  return valid[0]
319
 
320
+ # Prefer focus blocks that don't overlap with existing calendar.
321
+ non_overlapping = [a for a in focus_actions if not _focus_overlaps_calendar(a, calendar)]
322
+ candidates = non_overlapping if non_overlapping else focus_actions
323
+
324
  def focus_score(a: Action) -> Tuple[int, int]:
325
  start = int(a.new_start_min or 0)
326
  dur = int(a.duration_min or 0)
327
+ sim = list(calendar) + [
328
  {
329
  "event_id": "focus_sim",
330
  "start_min": start,
 
332
  "location": obs["persona"].get("primary_work_location", "Home"),
333
  }
334
  ]
335
+ return (len(detect_overlaps(sim)), len(travel_issues(sim, travel_times, home)))
336
 
337
+ candidates.sort(key=focus_score)
338
+ return candidates[0]
339
 
      # Pick the request-handling action that minimizes feasibility violations.
+     # Reject scores (0, 0) so we prefer it when all scheduling options cause issues.
      def score_action(a: Action) -> Tuple[int, int]:
          # (overlap_count, travel_issue_count) — smaller is better
          if a.action_type == ActionType.reject_event:
+             return (0, 0)  # no new overlaps/travel; prefer when scheduling options are bad
          if a.action_type in {ActionType.accept_event, ActionType.reschedule_event, ActionType.propose_new_time}:
              added = dict(req)
              if a.action_type in {ActionType.reschedule_event, ActionType.propose_new_time}:
                  added["start_min"] = int(a.new_start_min or added["start_min"])
                  added["end_min"] = int(a.new_end_min or added["end_min"])
+             sim_events = list(calendar) + (
+                 [] if a.action_type == ActionType.propose_new_time else [added]
+             )
+             return (len(detect_overlaps(sim_events)), len(travel_issues(sim_events, travel_times, home)))
          return (500, 500)

+     candidates = [
+         a
+         for a in valid
+         if a.action_type
+         in {ActionType.accept_event, ActionType.reschedule_event, ActionType.propose_new_time, ActionType.reject_event}
+     ]
      candidates.sort(key=score_action)
      return candidates[0] if candidates else valid[0]


  if __name__ == "__main__":
+     # Simple manual episode runner with tracing: `python env/lifeops_env.py`
      env = LifeOpsEnv(seed=7)
      obs = env.reset()
+
+     trace = EpisodeTrace(
+         scenario_id=obs["scenario_id"],
+         persona_name=obs["persona"]["name"],
+         initial_calendar=copy.deepcopy(obs.get("calendar", [])),
+         initial_tasks=copy.deepcopy(obs.get("tasks", [])),
+         initial_pending_count=obs.get("pending_request_count", 0),
+     )

      done = False
      total_reward = 0.0
+     step_num = 0
+
      while not done:
+         prev_obs = env.observation()
+         req = prev_obs.get("current_request")
          if req is not None:
              print(f"\nCurrent request: {req['title']} ({req['start_min']}..{req['end_min']}) @ {req['location']}")
          else:
 

          action = _choose_simple_action(env)
          next_obs, reward, done, info = env.step(action)
+         step_num += 1
          total_reward += reward
+
+         trace.log_step(
+             step=step_num,
+             action=action.to_dict(),
+             prev_obs=prev_obs,
+             next_obs=next_obs,
+             reward=reward,
+             breakdown=info.get("reward_breakdown", {}),
+             info=info,
+             done=done,
+             last_added_event=info.get("last_added_event"),
+             last_handled_request=info.get("last_handled_request"),
+             last_task_progress_minutes=info.get("last_task_progress_minutes", 0),
+             task_id_progressed=info.get("last_task_id_progressed"),
+         )
+
+         print(f" → Action: {trace._format_action(action.to_dict())} | Reward: {reward:+.2f}")
          if info.get("overlaps"):
+             print(f"Overlaps: {info['overlaps']}")
          if info.get("travel_issues"):
+             print(f"Travel issues: {info['travel_issues']}")

+     trace.total_reward = total_reward
+     trace.print_full(final_calendar=next_obs.get("calendar", []))

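Reviewer note: the request-handling branch above ranks candidate actions by the tuple `(overlap_count, travel_issue_count)` and takes the smallest. The sketch below reproduces that idea in isolation, as a rough illustration only. `count_overlaps`, `count_travel_issues`, and `choose` are hypothetical stand-ins, not the repository's `detect_overlaps`, `travel_issues`, or `_choose_simple_action`, and the nested-dict lookup with a default of 0 minutes is an assumption.

```python
from typing import Any, Dict, List, Optional, Tuple

Event = Dict[str, Any]


def count_overlaps(events: List[Event]) -> int:
    """Count pairs of events whose [start_min, end_min) intervals intersect."""
    total = 0
    for i, a in enumerate(events):
        for b in events[i + 1:]:
            if a["start_min"] < b["end_min"] and b["start_min"] < a["end_min"]:
                total += 1
    return total


def count_travel_issues(events: List[Event], travel_times: Dict[str, Dict[str, int]]) -> int:
    """Count consecutive legs where travel takes longer than the gap between events."""
    ordered = sorted(events, key=lambda e: (e["start_min"], e["end_min"]))
    issues = 0
    for prev, nxt in zip(ordered, ordered[1:]):
        # Assumption: unknown location pairs cost 0 minutes.
        needed = travel_times.get(prev["location"], {}).get(nxt["location"], 0)
        if needed > nxt["start_min"] - prev["end_min"]:
            issues += 1
    return issues


def choose(
    calendar: List[Event],
    options: List[Tuple[str, Optional[Event]]],
    travel_times: Dict[str, Dict[str, int]],
) -> str:
    """Return the label of the option with the smallest (overlaps, travel issues) score.

    An option whose event is None adds nothing to the calendar (models "reject").
    """

    def score(option: Tuple[str, Optional[Event]]) -> Tuple[int, int]:
        _, added = option
        sim = list(calendar) + ([added] if added is not None else [])
        return (count_overlaps(sim), count_travel_issues(sim, travel_times))

    # min() returns the first minimal element, so earlier options win ties.
    return min(options, key=score)[0]


calendar = [{"start_min": 540, "end_min": 600, "location": "Office"}]
request = {"start_min": 570, "end_min": 630, "location": "Office"}
moved = {"start_min": 600, "end_min": 660, "location": "Office"}
options = [("accept", request), ("reschedule", moved), ("reject", None)]
print(choose(calendar, options, {}))
```

Because tuples compare lexicographically, an overlap always outweighs any number of travel issues, and ties fall to whichever option is listed first. In the sketch that makes the order of `options` part of the policy.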
env/reward.py CHANGED
@@ -10,7 +10,7 @@ The reward is intentionally small and readable. It's "shaped" to encourage:

  from __future__ import annotations

- from typing import Any, Dict, List, Tuple
+ from typing import Any, Dict, List, Optional, Tuple

  def _overlap(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
      return a_start < b_end and b_start < a_end
@@ -46,10 +46,14 @@ def _travel_time_minutes(travel_times: Dict[str, Dict[str, int]], a_loc: str, b_
  def travel_issues(
      events: List[Dict[str, Any]],
      travel_times: Dict[str, Dict[str, int]],
+     start_location: Optional[str] = None,
  ) -> List[Tuple[str, str, int, int]]:
      """
      Returns travel feasibility issues between consecutive events.

+     If start_location is provided (e.g. persona home), also checks whether the
+     user can reach the first event of the day in time.
+
      Output tuple: (from_event_id, to_event_id, needed_minutes, available_minutes)
      """

@@ -58,6 +62,15 @@ def travel_issues(

      ordered = sorted(events, key=lambda e: (int(e["start_min"]), int(e["end_min"])))
      issues: List[Tuple[str, str, int, int]] = []
+
+     # Check travel from start_location to first event (if provided).
+     if start_location is not None:
+         first = ordered[0]
+         available = int(first["start_min"])
+         needed = _travel_time_minutes(travel_times, start_location, str(first["location"]))
+         if needed > available:
+             issues.append(("__start__", str(first["event_id"]), needed, available))
+
      for prev, nxt in zip(ordered, ordered[1:]):
          prev_end = int(prev["end_min"])
          nxt_start = int(nxt["start_min"])
@@ -96,11 +109,12 @@ def compute_reward(
      if next_overlaps:
          reward -= 5.0 * len(next_overlaps)
          breakdown["overlap_penalty"] = -5.0 * len(next_overlaps)
-     if prev_overlaps and len(next_overlaps) < len(prev_overlaps):
-         reward += 3.0
-         breakdown["conflict_resolved_bonus"] = 3.0

-     issues = travel_issues(next_events, travel_times)
+     issues = travel_issues(
+         next_events,
+         travel_times,
+         start_location=persona.get("home_location"),
+     )
      if issues:
          # Penalize per infeasible leg.
          travel_pen = -4.0 * len(issues) * float(persona.get("travel_aversion_weight", 1.0))
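Reviewer note: the new `start_location` branch can be read in isolation as the sketch below. This is a minimal illustration under stated assumptions, not the code in `env/reward.py`. `first_leg_issue` is a hypothetical name, the nested-dict lookup with a default of 0 replaces the real `_travel_time_minutes` (whose fallback behaviour is not shown in this diff), and the empty-calendar guard is added here for safety.

```python
from typing import Any, Dict, List, Optional, Tuple


def first_leg_issue(
    events: List[Dict[str, Any]],
    travel_times: Dict[str, Dict[str, int]],
    start_location: Optional[str] = None,
) -> Optional[Tuple[str, str, int, int]]:
    """Return a ("__start__", event_id, needed, available) tuple if the first
    event cannot be reached from start_location in time, otherwise None."""
    if start_location is None or not events:
        return None
    first = min(events, key=lambda e: (int(e["start_min"]), int(e["end_min"])))
    # Assumption: start_min counts minutes from the start of the day, so the
    # time available for the first trip is simply start_min.
    available = int(first["start_min"])
    # Assumption: unknown location pairs cost 0 minutes.
    needed = travel_times.get(start_location, {}).get(str(first["location"]), 0)
    if needed > available:
        return ("__start__", str(first["event_id"]), needed, available)
    return None


travel = {"Home": {"Office": 30}}
early = [{"event_id": "e1", "start_min": 20, "end_min": 60, "location": "Office"}]
normal = [{"event_id": "e2", "start_min": 540, "end_min": 600, "location": "Office"}]
print(first_leg_issue(early, travel, "Home"))
print(first_leg_issue(normal, travel, "Home"))
```

If `start_min` is indeed measured from midnight, as the `divmod(start, 60)` formatting in `training/train_rl.py` suggests, a check of this shape only fires for events that start within the travel time of minute 0. A later day-start anchor such as a persona wake-up time would be a separate design decision.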
training/__init__.py ADDED
@@ -0,0 +1 @@
+ """Training utilities for LifeOps RL."""
training/train_rl.py ADDED
@@ -0,0 +1,217 @@
+ #!/usr/bin/env python3
+ """
+ Minimal RL training loop for LifeOps.
+
+ Runs episodes in the environment, collects trajectories, and prints results.
+ Uses a simple policy (random or heuristic). No external RL frameworks required.
+
+ For learned policies, consider adding HuggingFace TRL or a small PyTorch policy.
+ """
+
+ from __future__ import annotations
+
+ import random
+ import sys
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple
+
+ # Add repo root for imports
+ repo_root = Path(__file__).resolve().parent.parent
+ if str(repo_root) not in sys.path:
+     sys.path.insert(0, str(repo_root))
+
+ from env.actions import Action
+ from env.lifeops_env import LifeOpsEnv, _choose_simple_action
+
+
+ def random_policy(env: LifeOpsEnv) -> Action:
+     """Pick uniformly from valid actions."""
+     valid = env.valid_actions()
+     if not valid:
+         raise RuntimeError("No valid actions")
+     return random.choice(valid)
+
+
+ def collect_trajectory(
+     env: LifeOpsEnv,
+     policy: str = "random",
+     scenario_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], float, int, str]:
+     """
+     Run one episode and collect trajectory.
+
+     Returns:
+         trajectory: list of (obs, action_dict, reward, done) per step
+         total_reward: sum of rewards
+         episode_length: number of steps
+         scenario_id: scenario used
+     """
+     obs = env.reset(scenario_id=scenario_id)
+     scenario_id = obs["scenario_id"]
+     trajectory: List[Dict[str, Any]] = []
+     total_reward = 0.0
+     step_count = 0
+
+     policy_fn = _choose_simple_action if policy == "heuristic" else random_policy
+
+     done = False
+     while not done:
+         action = policy_fn(env)
+         action_dict = action.to_dict()
+
+         next_obs, reward, done, info = env.step(action)
+         step_count += 1
+         total_reward += reward
+
+         trajectory.append({
+             "obs": obs,
+             "action": action_dict,
+             "reward": reward,
+             "next_obs": next_obs,
+             "done": done,
+             "info": info,
+         })
+         obs = next_obs
+
+     return trajectory, total_reward, step_count, scenario_id
+
+
+ def _format_action_short(a: Dict[str, Any]) -> str:
+     """Format action for key decisions summary."""
+     at = a.get("action_type", "?")
+     if at == "block_focus_time":
+         start = a.get("new_start_min", 0)
+         dur = a.get("duration_min", 0)
+         h, m = divmod(start or 0, 60)
+         return f"block_focus @ {h:02d}:{m:02d} ({dur}min)"
+     if at == "accept_event":
+         return f"accept request {a.get('request_id', '?')}"
+     if at == "reject_event":
+         return f"reject request {a.get('request_id', '?')}"
+     if at == "reschedule_event":
+         ns = a.get("new_start_min", 0)
+         h, m = divmod(ns or 0, 60)
+         return f"reschedule → {h:02d}:{m:02d}"
+     if at == "propose_new_time":
+         ns = a.get("new_start_min", 0)
+         h, m = divmod(ns or 0, 60)
+         return f"propose → {h:02d}:{m:02d}"
+     return at
+
+
+ def print_episode_results(
+     episode: int,
+     total_reward: float,
+     episode_length: int,
+     scenario_id: str,
+     trajectory: List[Dict[str, Any]],
+     verbose: bool = False,
+ ) -> None:
+     """Print human-readable episode results."""
+     print(f"\n--- Episode {episode} ---")
+     print(f" Scenario: {scenario_id}")
+     print(f" Steps: {episode_length}")
+     print(f" Total reward: {total_reward:+.2f}")
+
+     # Key decisions taken
+     if trajectory:
+         decisions = [_format_action_short(t["action"]) for t in trajectory]
+         print(f" Key decisions: {', '.join(decisions)}")
+
+     if verbose:
+         for i, t in enumerate(trajectory):
+             a = t["action"]
+             at = a.get("action_type", "?")
+             r = t["reward"]
+             print(f" Step {i + 1}: {at} reward={r:+.2f}")
+
+
+ def train(
+     num_episodes: int = 20,
+     seed: Optional[int] = 42,
+     policy: str = "random",
+     scenario_id: Optional[str] = None,
+     verbose: bool = False,
+ ) -> Dict[str, Any]:
+     """
+     Run RL training loop: collect trajectories and print results.
+
+     Args:
+         num_episodes: number of episodes to run
+         seed: random seed for env (None = random)
+         policy: "random" or "heuristic"
+         scenario_id: fix scenario (None = random each episode)
+         verbose: print per-step details
+
+     Returns:
+         Summary dict with episode rewards and stats
+     """
+     env = LifeOpsEnv(seed=seed)
+
+     all_rewards: List[float] = []
+     all_lengths: List[int] = []
+     all_scenarios: List[str] = []
+
+     print("=" * 50)
+     print("LifeOps RL Training")
+     print("=" * 50)
+     print(f"Episodes: {num_episodes} | Policy: {policy} | Seed: {seed}")
+
+     for ep in range(1, num_episodes + 1):
+         trajectory, total_reward, ep_len, scenario_id_used = collect_trajectory(
+             env, policy=policy, scenario_id=scenario_id
+         )
+
+         all_rewards.append(total_reward)
+         all_lengths.append(ep_len)
+         all_scenarios.append(scenario_id_used)
+
+         print_episode_results(
+             episode=ep,
+             total_reward=total_reward,
+             episode_length=ep_len,
+             scenario_id=scenario_id_used,
+             trajectory=trajectory,
+             verbose=verbose,
+         )
+
+     # Summary
+     avg_reward = sum(all_rewards) / len(all_rewards)
+     avg_len = sum(all_lengths) / len(all_lengths)
+     print("\n" + "=" * 50)
+     print("Training Summary")
+     print("=" * 50)
+     print(f" Episodes: {num_episodes}")
+     print(f" Avg reward: {avg_reward:+.2f}")
+     print(f" Avg length: {avg_len:.1f} steps")
+     print(f" Best reward: {max(all_rewards):+.2f}")
+     print(f" Worst reward: {min(all_rewards):+.2f}")
+     print("=" * 50)
+
+     return {
+         "rewards": all_rewards,
+         "lengths": all_lengths,
+         "scenarios": all_scenarios,
+         "avg_reward": avg_reward,
+         "avg_length": avg_len,
+     }
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Train RL agent on LifeOps")
+     parser.add_argument("-n", "--episodes", type=int, default=10, help="Number of episodes")
+     parser.add_argument("-p", "--policy", choices=["random", "heuristic"], default="random")
+     parser.add_argument("-s", "--seed", type=int, default=42)
+     parser.add_argument("--scenario", type=str, default=None, help="Fix scenario (e.g. s1_basic_conflict)")
+     parser.add_argument("-v", "--verbose", action="store_true")
+     args = parser.parse_args()
+
+     train(
+         num_episodes=args.episodes,
+         seed=args.seed,
+         policy=args.policy,
+         scenario_id=args.scenario,
+         verbose=args.verbose,
+     )
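Reviewer note: the episode loop in `collect_trajectory` can be exercised without the repository using a stand-in environment that follows the same `reset()` / `valid_actions()` / `step()` contract. Everything below is a hypothetical sketch: `ToyEnv` and `rollout` are illustrative names, and the reward rule is invented purely so the loop has something to accumulate.

```python
import random
from typing import Any, Callable, Dict, List, Tuple


class ToyEnv:
    """Hypothetical stand-in for LifeOpsEnv with a fixed-length episode."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> Dict[str, Any]:
        self.t = 0
        return {"t": self.t}

    def valid_actions(self) -> List[str]:
        return ["accept", "reject"]

    def step(self, action: str) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        self.t += 1
        reward = 1.0 if action == "accept" else 0.0  # invented reward rule
        done = self.t >= self.horizon
        return {"t": self.t}, reward, done, {}


def rollout(env: ToyEnv, policy: Callable[[ToyEnv], str]) -> Tuple[List[Dict[str, Any]], float]:
    """Run one episode, recording (obs, action, reward, next_obs, done) per step."""
    obs = env.reset()
    trajectory: List[Dict[str, Any]] = []
    total = 0.0
    done = False
    while not done:
        action = policy(env)
        next_obs, reward, done, _info = env.step(action)
        trajectory.append(
            {"obs": obs, "action": action, "reward": reward, "next_obs": next_obs, "done": done}
        )
        total += reward
        obs = next_obs  # the next step's "obs" is this step's "next_obs"
    return trajectory, total


rng = random.Random(42)  # local RNG keeps the rollout reproducible
trajectory, total = rollout(ToyEnv(), lambda env: rng.choice(env.valid_actions()))
print(len(trajectory), total)
```

Using a local `random.Random` instance rather than the module-level functions is one way to keep policy sampling reproducible independently of the environment's own seed.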