cmpatino HF Staff lewtun HF Staff commited on
Commit
0321690
·
unverified ·
1 Parent(s): d7637ba

Implement `/resume` command for CLI (#233)

Browse files

* Add /resume command to CLI

* Refine /resume feature

* Refine /resume implementation

* Address review comments

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

agent/core/agent_loop.py CHANGED
@@ -7,6 +7,7 @@ import json
7
  import logging
8
  import time
9
  from dataclasses import dataclass, field
 
10
  from typing import Any
11
 
12
  from litellm import (
@@ -29,7 +30,7 @@ from agent.core.doom_loop import check_for_doom_loop
29
  from agent.core.hub_artifacts import start_session_artifact_collection_task
30
  from agent.core.llm_params import _resolve_llm_params
31
  from agent.core.prompt_caching import with_prompt_caching
32
- from agent.core.session import Event, OpType, Session
33
  from agent.core.tools import ToolRouter
34
  from agent.tools.jobs_tool import CPU_FLAVORS
35
  from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
@@ -1667,6 +1668,20 @@ class Handlers:
1667
  logger.warning("Undo: no user message found to remove")
1668
  await session.send_event(Event(event_type="undo_complete"))
1669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1670
  @staticmethod
1671
  async def exec_approval(session: Session, approvals: list[dict]) -> None:
1672
  """Handle batch job execution approval"""
@@ -1953,6 +1968,16 @@ async def process_submission(session: Session, submission) -> bool:
1953
  await Handlers.undo(session)
1954
  return True
1955
 
 
 
 
 
 
 
 
 
 
 
1956
  if op.op_type == OpType.EXEC_APPROVAL:
1957
  approvals = op.data.get("approvals", []) if op.data else []
1958
  await Handlers.exec_approval(session, approvals)
@@ -2007,7 +2032,7 @@ async def submission_loop(
2007
  # to publish to the user's HF dataset gets a fresh attempt on next run.
2008
  if config and config.save_sessions:
2009
  Session.retry_failed_uploads_detached(
2010
- directory="session_logs",
2011
  repo_id=config.session_dataset_repo,
2012
  personal_repo_id=session._personal_trace_repo_id(),
2013
  )
 
7
  import logging
8
  import time
9
  from dataclasses import dataclass, field
10
+ from pathlib import Path
11
  from typing import Any
12
 
13
  from litellm import (
 
30
  from agent.core.hub_artifacts import start_session_artifact_collection_task
31
  from agent.core.llm_params import _resolve_llm_params
32
  from agent.core.prompt_caching import with_prompt_caching
33
+ from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
34
  from agent.core.tools import ToolRouter
35
  from agent.tools.jobs_tool import CPU_FLAVORS
36
  from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
 
1668
  logger.warning("Undo: no user message found to remove")
1669
  await session.send_event(Event(event_type="undo_complete"))
1670
 
1671
+ @staticmethod
1672
+ async def resume(session: Session, path: str) -> None:
1673
+ """Reload context from a saved session log into the active session."""
1674
+ from agent.core.session_resume import restore_session_from_log
1675
+
1676
+ try:
1677
+ result = restore_session_from_log(session, Path(path))
1678
+ except Exception as e:
1679
+ await session.send_event(
1680
+ Event(event_type="error", data={"error": f"Resume failed: {e}"})
1681
+ )
1682
+ return
1683
+ await session.send_event(Event(event_type="resume_complete", data=result))
1684
+
1685
  @staticmethod
1686
  async def exec_approval(session: Session, approvals: list[dict]) -> None:
1687
  """Handle batch job execution approval"""
 
1968
  await Handlers.undo(session)
1969
  return True
1970
 
1971
+ if op.op_type == OpType.RESUME:
1972
+ path = op.data.get("path") if op.data else None
1973
+ if path:
1974
+ await Handlers.resume(session, path)
1975
+ else:
1976
+ await session.send_event(
1977
+ Event(event_type="error", data={"error": "Resume requires a path"})
1978
+ )
1979
+ return True
1980
+
1981
  if op.op_type == OpType.EXEC_APPROVAL:
1982
  approvals = op.data.get("approvals", []) if op.data else []
1983
  await Handlers.exec_approval(session, approvals)
 
2032
  # to publish to the user's HF dataset gets a fresh attempt on next run.
2033
  if config and config.save_sessions:
2034
  Session.retry_failed_uploads_detached(
2035
+ directory=str(DEFAULT_SESSION_LOG_DIR),
2036
  repo_id=config.session_dataset_repo,
2037
  personal_repo_id=session._personal_trace_repo_id(),
2038
  )
agent/core/session.py CHANGED
@@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
21
  _DEFAULT_MAX_TOKENS = 200_000
22
  _TURN_COMPLETE_NOTIFICATION_CHARS = 39000
23
 
 
 
24
 
25
  def _get_max_tokens_safe(model_name: str) -> int:
26
  """Return the max input-context tokens for a model.
@@ -60,6 +62,7 @@ class OpType(Enum):
60
  INTERRUPT = "interrupt"
61
  UNDO = "undo"
62
  COMPACT = "compact"
 
63
  SHUTDOWN = "shutdown"
64
 
65
 
@@ -418,7 +421,7 @@ class Session:
418
 
419
  def save_trajectory_local(
420
  self,
421
- directory: str = "session_logs",
422
  upload_status: str = "pending",
423
  dataset_url: Optional[str] = None,
424
  ) -> Optional[str]:
@@ -613,7 +616,7 @@ class Session:
613
 
614
  @staticmethod
615
  def retry_failed_uploads_detached(
616
- directory: str = "session_logs",
617
  repo_id: Optional[str] = None,
618
  *,
619
  personal_repo_id: Optional[str] = None,
 
21
  _DEFAULT_MAX_TOKENS = 200_000
22
  _TURN_COMPLETE_NOTIFICATION_CHARS = 39000
23
 
24
+ DEFAULT_SESSION_LOG_DIR = Path("session_logs")
25
+
26
 
27
  def _get_max_tokens_safe(model_name: str) -> int:
28
  """Return the max input-context tokens for a model.
 
62
  INTERRUPT = "interrupt"
63
  UNDO = "undo"
64
  COMPACT = "compact"
65
+ RESUME = "resume"
66
  SHUTDOWN = "shutdown"
67
 
68
 
 
421
 
422
  def save_trajectory_local(
423
  self,
424
+ directory: str = str(DEFAULT_SESSION_LOG_DIR),
425
  upload_status: str = "pending",
426
  dataset_url: Optional[str] = None,
427
  ) -> Optional[str]:
 
616
 
617
  @staticmethod
618
  def retry_failed_uploads_detached(
619
+ directory: str = str(DEFAULT_SESSION_LOG_DIR),
620
  repo_id: Optional[str] = None,
621
  *,
622
  personal_repo_id: Optional[str] = None,
agent/core/session_resume.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reload a previously saved session log into the active CLI session."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import re
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ from litellm import Message
14
+
15
+ from agent.core.model_switcher import is_valid_model_id
16
+ from agent.core.session import DEFAULT_SESSION_LOG_DIR
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _REDACTED_MARKER = re.compile(r"\[REDACTED_[A-Z_]+\]")
21
+
22
+
23
+ @dataclass
24
+ class SessionLogEntry:
25
+ """Metadata for a locally saved session log."""
26
+
27
+ path: Path
28
+ session_id: str
29
+ session_start_time: str | None
30
+ session_end_time: str | None
31
+ model_name: str | None
32
+ message_count: int
33
+ preview: str
34
+ mtime: float
35
+
36
+
37
+ def _message_preview(content: Any, max_chars: int = 72) -> str:
38
+ """Return a one-line preview for string or OpenAI-style block content."""
39
+ if isinstance(content, str):
40
+ text = content
41
+ elif isinstance(content, list):
42
+ parts: list[str] = []
43
+ for block in content:
44
+ if isinstance(block, dict):
45
+ value = block.get("text") or block.get("content")
46
+ if isinstance(value, str):
47
+ parts.append(value)
48
+ elif isinstance(block, str):
49
+ parts.append(block)
50
+ text = " ".join(parts)
51
+ else:
52
+ text = ""
53
+ text = " ".join(text.split())
54
+ if len(text) > max_chars:
55
+ return text[: max_chars - 1].rstrip() + "…"
56
+ return text
57
+
58
+
59
+ def _first_user_preview(messages: list[Any]) -> str:
60
+ for raw in messages:
61
+ if isinstance(raw, dict) and raw.get("role") == "user":
62
+ preview = _message_preview(raw.get("content"))
63
+ if preview:
64
+ return preview
65
+ return "(no user prompt preview)"
66
+
67
+
68
+ def list_session_logs(
69
+ directory: Path = DEFAULT_SESSION_LOG_DIR,
70
+ ) -> list[SessionLogEntry]:
71
+ """Return readable session logs under ``directory``, newest first."""
72
+ if not directory.exists():
73
+ return []
74
+
75
+ entries: list[SessionLogEntry] = []
76
+ for path in directory.glob("*.json"):
77
+ try:
78
+ with open(path) as f:
79
+ data = json.load(f)
80
+ except Exception:
81
+ continue
82
+
83
+ messages = data.get("messages") or []
84
+ if not isinstance(messages, list):
85
+ continue
86
+
87
+ session_id = data.get("session_id")
88
+ if not isinstance(session_id, str) or not session_id:
89
+ session_id = path.stem
90
+
91
+ stat = path.stat()
92
+ entries.append(
93
+ SessionLogEntry(
94
+ path=path,
95
+ session_id=session_id,
96
+ session_start_time=data.get("session_start_time"),
97
+ session_end_time=data.get("session_end_time"),
98
+ model_name=data.get("model_name"),
99
+ message_count=len(messages),
100
+ preview=_first_user_preview(messages),
101
+ mtime=stat.st_mtime,
102
+ )
103
+ )
104
+
105
+ entries.sort(key=lambda item: item.mtime, reverse=True)
106
+ return entries
107
+
108
+
109
+ def format_session_log_entry(index: int, entry: SessionLogEntry) -> str:
110
+ timestamp = entry.session_end_time or entry.session_start_time
111
+ label = "unknown time"
112
+ if isinstance(timestamp, str) and timestamp:
113
+ try:
114
+ label = datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M")
115
+ except ValueError:
116
+ label = timestamp[:16]
117
+ short_id = entry.session_id[:8]
118
+ model = entry.model_name or "unknown model"
119
+ return (
120
+ f"{index:>2}. {label} {short_id} "
121
+ f"{entry.message_count} msgs {model}\n"
122
+ f" {entry.preview}"
123
+ )
124
+
125
+
126
+ def resolve_session_log_arg(
127
+ arg: str,
128
+ entries: list[SessionLogEntry],
129
+ directory: Path = DEFAULT_SESSION_LOG_DIR,
130
+ ) -> Path | None:
131
+ """Resolve ``/resume <arg>`` as index, path, filename, or session id prefix."""
132
+ value = arg.strip()
133
+ if not value:
134
+ return None
135
+
136
+ if value.isdigit():
137
+ idx = int(value)
138
+ if 1 <= idx <= len(entries):
139
+ return entries[idx - 1].path
140
+
141
+ candidate = Path(value).expanduser()
142
+ candidates = [candidate]
143
+ if not candidate.is_absolute():
144
+ candidates.append(directory / candidate)
145
+ if candidate.suffix != ".json":
146
+ candidates.append(directory / f"{value}.json")
147
+
148
+ for path in candidates:
149
+ if path.exists() and path.is_file():
150
+ return path
151
+
152
+ matches = [
153
+ entry.path
154
+ for entry in entries
155
+ if entry.session_id.startswith(value) or entry.path.name.startswith(value)
156
+ ]
157
+ if len(matches) == 1:
158
+ return matches[0]
159
+ return None
160
+
161
+
162
+ def _turn_count_from_messages(messages: list[Any]) -> int:
163
+ return sum(
164
+ 1 for raw in messages if isinstance(raw, dict) and raw.get("role") == "user"
165
+ )
166
+
167
+
168
+ def _has_redacted_content(messages: list[Any]) -> bool:
169
+ """Whether any message body contains a ``[REDACTED_*]`` marker."""
170
+ for raw in messages:
171
+ if not isinstance(raw, dict):
172
+ continue
173
+ content = raw.get("content")
174
+ if isinstance(content, str) and _REDACTED_MARKER.search(content):
175
+ return True
176
+ if isinstance(content, list):
177
+ for block in content:
178
+ if isinstance(block, dict):
179
+ text = block.get("text") or block.get("content")
180
+ if isinstance(text, str) and _REDACTED_MARKER.search(text):
181
+ return True
182
+ return False
183
+
184
+
185
+ def restore_session_from_log(session: Any, path: Path) -> dict[str, Any]:
186
+ """Replace the active session context with messages from ``path``.
187
+
188
+ Continues the saved session (reusing its id and on-disk save path) when
189
+ the log's ``user_id`` matches the current session, and forks otherwise:
190
+ the caller's session id stays put and future heartbeat saves go to a
191
+ fresh file rather than overwriting the source log.
192
+
193
+ Returns metadata for the ``resume_complete`` event.
194
+ """
195
+ with open(path) as f:
196
+ data = json.load(f)
197
+
198
+ raw_messages = data.get("messages")
199
+ if not isinstance(raw_messages, list):
200
+ raise ValueError("Selected log does not contain a messages array")
201
+
202
+ restored_messages: list[Message] = []
203
+ dropped_count = 0
204
+ for raw in raw_messages:
205
+ if not isinstance(raw, dict) or raw.get("role") == "system":
206
+ continue
207
+ try:
208
+ restored_messages.append(Message.model_validate(raw))
209
+ except Exception as e:
210
+ dropped_count += 1
211
+ logger.warning("Dropping malformed message from %s: %s", path, e)
212
+
213
+ if not restored_messages:
214
+ raise ValueError("Selected log has no restorable non-system messages")
215
+
216
+ cm = session.context_manager
217
+ system_msg = cm.items[0] if cm.items and cm.items[0].role == "system" else None
218
+ cm.items = ([system_msg] if system_msg else []) + restored_messages
219
+
220
+ # Validate the saved model id before switching. ``update_model`` doesn't
221
+ # check availability; an unrecognised id silently sticks and the next LLM
222
+ # call fails with a cryptic routing error. Logs from a different
223
+ # deployment, an older catalog, or a removed model land here.
224
+ saved_model = data.get("model_name")
225
+ invalid_saved_model: str | None = None
226
+ if isinstance(saved_model, str) and saved_model:
227
+ if is_valid_model_id(saved_model):
228
+ session.update_model(saved_model)
229
+ else:
230
+ invalid_saved_model = saved_model
231
+ logger.warning(
232
+ "Saved log model %r failed format validation; keeping %r",
233
+ saved_model,
234
+ session.config.model_name,
235
+ )
236
+
237
+ cm._recompute_usage(session.config.model_name)
238
+
239
+ saved_session_id = data.get("session_id")
240
+ saved_user_id = data.get("user_id")
241
+ is_continuation = saved_user_id == session.user_id
242
+
243
+ if is_continuation:
244
+ if isinstance(saved_session_id, str) and saved_session_id:
245
+ session.session_id = saved_session_id
246
+ session.session_start_time = (
247
+ data.get("session_start_time") or session.session_start_time
248
+ )
249
+
250
+ # Always fork the on-disk save path. The source log is treated as an
251
+ # immutable snapshot: ``logged_events`` is reset to a single
252
+ # ``resumed_from`` marker below for cost accounting, so reusing the
253
+ # source path would let the next heartbeat save destroy the original
254
+ # ``llm_call``/event history on disk. The next save will pick a fresh
255
+ # filename instead.
256
+ session._local_save_path = None
257
+
258
+ saved_event_count = (
259
+ len(data.get("events", [])) if isinstance(data.get("events"), list) else 0
260
+ )
261
+ session.logged_events = [
262
+ {
263
+ "timestamp": datetime.now().isoformat(),
264
+ "event_type": "resumed_from",
265
+ "data": {
266
+ "path": str(path),
267
+ "original_session_id": (
268
+ saved_session_id if isinstance(saved_session_id, str) else None
269
+ ),
270
+ "original_event_count": saved_event_count,
271
+ "forked": not is_continuation,
272
+ },
273
+ }
274
+ ]
275
+ session.turn_count = _turn_count_from_messages(raw_messages)
276
+ session.last_auto_save_turn = session.turn_count
277
+ session.pending_approval = None
278
+
279
+ return {
280
+ "path": str(path),
281
+ "restored_count": len(restored_messages),
282
+ "dropped_count": dropped_count,
283
+ "model_name": session.config.model_name,
284
+ "invalid_saved_model": invalid_saved_model,
285
+ "forked": not is_continuation,
286
+ "had_redacted_content": _has_redacted_content(raw_messages),
287
+ }
agent/main.py CHANGED
@@ -9,6 +9,7 @@ Supports two modes:
9
  import argparse
10
  import asyncio
11
  import json
 
12
  import os
13
  import signal
14
  import sys
@@ -55,6 +56,7 @@ litellm.drop_params = True
55
  litellm.suppress_debug_info = True
56
 
57
  CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
 
58
 
59
 
60
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
@@ -368,6 +370,46 @@ async def event_listener(
368
  elif event.event_type == "undo_complete":
369
  console.print("[dim]Undone.[/dim]")
370
  turn_complete_event.set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  elif event.event_type == "tool_log":
372
  tool = event.data.get("tool", "") if event.data else ""
373
  log = event.data.get("log", "") if event.data else ""
@@ -739,12 +781,69 @@ async def get_user_input(prompt_session: PromptSession) -> str:
739
  # Slash commands are defined in terminal_display
740
 
741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742
  async def _handle_slash_command(
743
  cmd: str,
744
  config,
745
  session_holder: list,
746
  submission_queue: asyncio.Queue,
747
  submission_id: list[int],
 
748
  ) -> Submission | None:
749
  """
750
  Handle a slash command. Returns a Submission to enqueue, or None if
@@ -775,6 +874,24 @@ async def _handle_slash_command(
775
  operation=Operation(op_type=OpType.COMPACT),
776
  )
777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
  if command == "/model":
779
  console = get_console()
780
  if not arg:
@@ -1136,6 +1253,7 @@ async def main(model: str | None = None):
1136
  session_holder,
1137
  submission_queue,
1138
  submission_id,
 
1139
  )
1140
  if sub is None:
1141
  # Command handled locally, loop back for input
 
9
  import argparse
10
  import asyncio
11
  import json
12
+ import logging
13
  import os
14
  import signal
15
  import sys
 
56
  litellm.suppress_debug_info = True
57
 
58
  CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
59
+ logger = logging.getLogger(__name__)
60
 
61
 
62
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
 
370
  elif event.event_type == "undo_complete":
371
  console.print("[dim]Undone.[/dim]")
372
  turn_complete_event.set()
373
+ elif event.event_type == "resume_complete":
374
+ data = event.data or {}
375
+ path = data.get("path", "?")
376
+ count = data.get("restored_count", 0)
377
+ dropped = int(data.get("dropped_count", 0) or 0)
378
+ model = data.get("model_name", "?")
379
+ invalid_model = data.get("invalid_saved_model")
380
+ forked = bool(data.get("forked", False))
381
+ redacted = bool(data.get("had_redacted_content", False))
382
+ verb = "Forked from" if forked else "Resumed"
383
+ console.print(
384
+ f"[green]{verb}[/green] {path} "
385
+ f"([cyan]{count}[/cyan] messages, "
386
+ f"model [cyan]{model}[/cyan])."
387
+ )
388
+ if dropped:
389
+ console.print(
390
+ f"[yellow]Warning:[/yellow] dropped {dropped} "
391
+ "malformed message(s) while restoring — surrounding "
392
+ "tool-call alignment may be off."
393
+ )
394
+ if invalid_model:
395
+ console.print(
396
+ f"[yellow]Warning:[/yellow] saved model id "
397
+ f"[cyan]{invalid_model}[/cyan] failed validation; "
398
+ f"kept current model [cyan]{model}[/cyan]."
399
+ )
400
+ if forked:
401
+ console.print(
402
+ "[dim]Saved log belongs to a different user — kept "
403
+ "current session id; future saves go to a fresh file.[/dim]"
404
+ )
405
+ if redacted:
406
+ console.print(
407
+ "[yellow]Note:[/yellow] tokens/secrets in restored "
408
+ "messages were scrubbed at save time. Your live tokens "
409
+ "are used for this session; [REDACTED_*] markers in "
410
+ "past messages are not re-injected."
411
+ )
412
+ turn_complete_event.set()
413
  elif event.event_type == "tool_log":
414
  tool = event.data.get("tool", "") if event.data else ""
415
  log = event.data.get("log", "") if event.data else ""
 
781
  # Slash commands are defined in terminal_display
782
 
783
 
784
+ async def _resume_picker(
785
+ arg: str,
786
+ prompt_session: PromptSession | None,
787
+ ) -> Path | None:
788
+ """Resolve a session log path via ``arg`` or interactive selection.
789
+
790
+ Returns ``None`` if the user cancels, no logs exist, or the argument
791
+ matches nothing — already prints the explanation in those cases.
792
+ """
793
+ from agent.core.session_resume import (
794
+ format_session_log_entry,
795
+ list_session_logs,
796
+ resolve_session_log_arg,
797
+ )
798
+ from agent.core.session import DEFAULT_SESSION_LOG_DIR
799
+
800
+ console = get_console()
801
+ directory = DEFAULT_SESSION_LOG_DIR
802
+ entries = list_session_logs(directory)
803
+ if not entries:
804
+ console.print(f"[yellow]No session logs found in ./{directory}.[/yellow]")
805
+ return None
806
+
807
+ if arg:
808
+ selected = resolve_session_log_arg(arg, entries, directory)
809
+ if selected is None:
810
+ console.print(f"[bold red]No matching session log:[/bold red] {arg}")
811
+ return selected
812
+
813
+ console.print()
814
+ console.print("[bold]Saved sessions[/bold]")
815
+ for index, entry in enumerate(entries, start=1):
816
+ console.print(format_session_log_entry(index, entry))
817
+ console.print()
818
+
819
+ if prompt_session is None:
820
+ console.print("[yellow]Cannot prompt for a selection here.[/yellow]")
821
+ return None
822
+
823
+ try:
824
+ choice = await prompt_session.prompt_async(
825
+ "Select session number (blank to cancel): "
826
+ )
827
+ except (EOFError, KeyboardInterrupt):
828
+ console.print("[dim]Resume cancelled.[/dim]")
829
+ return None
830
+ choice = choice.strip()
831
+ if not choice:
832
+ console.print("[dim]Resume cancelled.[/dim]")
833
+ return None
834
+ selected = resolve_session_log_arg(choice, entries, directory)
835
+ if selected is None:
836
+ console.print(f"[bold red]Invalid selection:[/bold red] {choice}")
837
+ return selected
838
+
839
+
840
  async def _handle_slash_command(
841
  cmd: str,
842
  config,
843
  session_holder: list,
844
  submission_queue: asyncio.Queue,
845
  submission_id: list[int],
846
+ prompt_session: PromptSession | None = None,
847
  ) -> Submission | None:
848
  """
849
  Handle a slash command. Returns a Submission to enqueue, or None if
 
874
  operation=Operation(op_type=OpType.COMPACT),
875
  )
876
 
877
+ if command == "/resume":
878
+ session = session_holder[0] if session_holder else None
879
+ if session is None:
880
+ get_console().print(
881
+ "[bold red]No active session to restore into.[/bold red]"
882
+ )
883
+ return None
884
+ selected_path = await _resume_picker(arg, prompt_session)
885
+ if selected_path is None:
886
+ return None
887
+ submission_id[0] += 1
888
+ return Submission(
889
+ id=f"sub_{submission_id[0]}",
890
+ operation=Operation(
891
+ op_type=OpType.RESUME, data={"path": str(selected_path)}
892
+ ),
893
+ )
894
+
895
  if command == "/model":
896
  console = get_console()
897
  if not arg:
 
1253
  session_holder,
1254
  submission_queue,
1255
  submission_id,
1256
+ prompt_session,
1257
  )
1258
  if sub is None:
1259
  # Command handled locally, loop back for input
agent/utils/terminal_display.py CHANGED
@@ -451,6 +451,7 @@ HELP_TEXT = f"""\
451
  {_I} [cyan]/help[/cyan] Show this help
452
  {_I} [cyan]/undo[/cyan] Undo last turn
453
  {_I} [cyan]/compact[/cyan] Compact context window
 
454
  {_I} [cyan]/model[/cyan] [id] Show available models or switch
455
  {_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
456
  {_I} [cyan]/yolo[/cyan] Toggle auto-approve mode
 
451
  {_I} [cyan]/help[/cyan] Show this help
452
  {_I} [cyan]/undo[/cyan] Undo last turn
453
  {_I} [cyan]/compact[/cyan] Compact context window
454
+ {_I} [cyan]/resume[/cyan] [index|id|path] Pick up from a log in ./session_logs
455
  {_I} [cyan]/model[/cyan] [id] Show available models or switch
456
  {_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
457
  {_I} [cyan]/yolo[/cyan] Toggle auto-approve mode
tests/unit/test_session_resume.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for ``agent.core.session_resume``."""
2
+
3
+ import json
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+ from types import SimpleNamespace
8
+
9
+ from litellm import Message
10
+
11
+ from agent.core import session_resume
12
+
13
+
14
+ def _write_session_log(
15
+ directory: Path,
16
+ name: str,
17
+ *,
18
+ session_id: str,
19
+ content: str,
20
+ mtime: float,
21
+ user_id: str | None = "user-a",
22
+ extra_messages: list[dict] | None = None,
23
+ events: list[dict] | None = None,
24
+ ) -> Path:
25
+ directory.mkdir(exist_ok=True)
26
+ path = directory / name
27
+ payload = {
28
+ "session_id": session_id,
29
+ "user_id": user_id,
30
+ "session_start_time": "2026-01-01T00:00:00",
31
+ "session_end_time": "2026-01-01T00:05:00",
32
+ "model_name": "openai/gpt-5.5",
33
+ "messages": [
34
+ {"role": "system", "content": "old system"},
35
+ {"role": "user", "content": content},
36
+ *(extra_messages or []),
37
+ ],
38
+ "events": events
39
+ if events is not None
40
+ else [{"event_type": "turn_complete", "data": {}}],
41
+ }
42
+ path.write_text(json.dumps(payload))
43
+ os.utime(path, (mtime, mtime))
44
+ return path
45
+
46
+
47
+ class _FakeContext:
48
+ def __init__(self) -> None:
49
+ self.items = [Message(role="system", content="current system")]
50
+ self.running_context_usage = 0
51
+ self.recompute_calls: list[str] = []
52
+
53
+ def _recompute_usage(self, model_name: str) -> None:
54
+ self.recompute_calls.append(model_name)
55
+ self.running_context_usage = 123
56
+
57
+
58
+ class _FakeSession:
59
+ def __init__(self, *, user_id: str | None = "user-a") -> None:
60
+ self.context_manager = _FakeContext()
61
+ self.config = SimpleNamespace(model_name="moonshotai/Kimi-K2.6")
62
+ self.session_id = "current-session"
63
+ self.session_start_time = "2026-01-02T00:00:00"
64
+ self.user_id = user_id
65
+ self.logged_events: list[dict] = []
66
+ self._local_save_path: str | None = None
67
+ self.turn_count = 0
68
+ self.last_auto_save_turn = 0
69
+ self.pending_approval: dict | None = {"tool_calls": ["pending"]}
70
+
71
+ def update_model(self, model_name: str) -> None:
72
+ self.config.model_name = model_name
73
+
74
+
75
+ def test_session_log_listing_newest_first(tmp_path):
76
+ log_dir = tmp_path / "session_logs"
77
+ older = _write_session_log(
78
+ log_dir,
79
+ "older.json",
80
+ session_id="older-session",
81
+ content="older prompt",
82
+ mtime=time.time() - 10,
83
+ )
84
+ newer = _write_session_log(
85
+ log_dir,
86
+ "newer.json",
87
+ session_id="newer-session",
88
+ content="newer prompt",
89
+ mtime=time.time(),
90
+ )
91
+
92
+ entries = session_resume.list_session_logs(log_dir)
93
+
94
+ assert [entry.path for entry in entries] == [newer, older]
95
+ assert entries[0].session_id == "newer-session"
96
+ assert entries[0].preview == "newer prompt"
97
+
98
+
99
+ def test_restore_continues_when_user_id_matches(tmp_path):
100
+ log_dir = tmp_path / "session_logs"
101
+ path = _write_session_log(
102
+ log_dir,
103
+ "session.json",
104
+ session_id="saved-session",
105
+ content="continue this work",
106
+ mtime=time.time(),
107
+ user_id="user-a",
108
+ )
109
+
110
+ session = _FakeSession(user_id="user-a")
111
+
112
+ result = session_resume.restore_session_from_log(session, path)
113
+
114
+ assert result["restored_count"] == 1
115
+ assert result["dropped_count"] == 0
116
+ assert result["forked"] is False
117
+ assert result["model_name"] == "openai/gpt-5.5"
118
+ assert result["had_redacted_content"] is False
119
+ assert result["invalid_saved_model"] is None
120
+ assert session.config.model_name == "openai/gpt-5.5"
121
+ assert session.session_id == "saved-session"
122
+ # Source log path is never reused: future heartbeat saves write to a
123
+ # fresh file so the snapshot stays intact (regression: see source-log
124
+ # round-trip test below).
125
+ assert session._local_save_path is None
126
+ assert session.turn_count == 1
127
+ assert session.last_auto_save_turn == 1
128
+ assert session.pending_approval is None
129
+ assert [msg.role for msg in session.context_manager.items] == ["system", "user"]
130
+ assert session.context_manager.items[0].content == "current system"
131
+ assert session.context_manager.items[1].content == "continue this work"
132
+ assert session.context_manager.running_context_usage == 123
133
+ assert session.context_manager.recompute_calls == ["openai/gpt-5.5"]
134
+ assert len(session.logged_events) == 1
135
+ marker = session.logged_events[0]
136
+ assert marker["event_type"] == "resumed_from"
137
+ assert marker["data"]["forked"] is False
138
+ assert marker["data"]["original_session_id"] == "saved-session"
139
+ assert marker["data"]["original_event_count"] == 1
140
+
141
+
142
+ def test_restore_forks_when_user_id_differs(tmp_path):
143
+ log_dir = tmp_path / "session_logs"
144
+ path = _write_session_log(
145
+ log_dir,
146
+ "session.json",
147
+ session_id="saved-session",
148
+ content="someone else's chat",
149
+ mtime=time.time(),
150
+ user_id="user-a",
151
+ )
152
+
153
+ session = _FakeSession(user_id="user-b")
154
+ original_session_id = session.session_id
155
+ original_start_time = session.session_start_time
156
+
157
+ result = session_resume.restore_session_from_log(session, path)
158
+
159
+ assert result["forked"] is True
160
+ assert session.session_id == original_session_id
161
+ assert session.session_start_time == original_start_time
162
+ assert session._local_save_path is None
163
+ marker = session.logged_events[0]
164
+ assert marker["event_type"] == "resumed_from"
165
+ assert marker["data"]["forked"] is True
166
+ assert marker["data"]["original_session_id"] == "saved-session"
167
+
168
+
169
+ def test_restore_forks_when_one_side_is_anonymous(tmp_path):
170
+ log_dir = tmp_path / "session_logs"
171
+ path = _write_session_log(
172
+ log_dir,
173
+ "session.json",
174
+ session_id="saved-session",
175
+ content="anonymous save",
176
+ mtime=time.time(),
177
+ user_id=None,
178
+ )
179
+
180
+ session = _FakeSession(user_id="user-a")
181
+
182
+ result = session_resume.restore_session_from_log(session, path)
183
+
184
+ assert result["forked"] is True
185
+ assert session._local_save_path is None
186
+
187
+
188
+ def test_restore_continues_when_both_sides_anonymous(tmp_path):
189
+ log_dir = tmp_path / "session_logs"
190
+ path = _write_session_log(
191
+ log_dir,
192
+ "session.json",
193
+ session_id="saved-session",
194
+ content="local-only chat",
195
+ mtime=time.time(),
196
+ user_id=None,
197
+ )
198
+
199
+ session = _FakeSession(user_id=None)
200
+
201
+ result = session_resume.restore_session_from_log(session, path)
202
+
203
+ assert result["forked"] is False
204
+ assert session.session_id == "saved-session"
205
+ assert session._local_save_path is None
206
+
207
+
208
+ def test_restore_rejects_invalid_saved_model(tmp_path):
209
+ log_dir = tmp_path / "session_logs"
210
+ path = log_dir / "session.json"
211
+ log_dir.mkdir()
212
+ path.write_text(
213
+ json.dumps(
214
+ {
215
+ "session_id": "saved",
216
+ "user_id": "user-a",
217
+ "model_name": "not a real id with spaces",
218
+ "messages": [{"role": "user", "content": "hello"}],
219
+ "events": [],
220
+ }
221
+ )
222
+ )
223
+
224
+ session = _FakeSession(user_id="user-a")
225
+ original_model = session.config.model_name
226
+
227
+ result = session_resume.restore_session_from_log(session, path)
228
+
229
+ assert result["invalid_saved_model"] == "not a real id with spaces"
230
+ assert result["model_name"] == original_model
231
+ assert session.config.model_name == original_model
232
+
233
+
234
+ def test_restore_counts_dropped_messages(tmp_path):
235
+ log_dir = tmp_path / "session_logs"
236
+ path = log_dir / "session.json"
237
+ log_dir.mkdir()
238
+ path.write_text(
239
+ json.dumps(
240
+ {
241
+ "session_id": "saved",
242
+ "user_id": "user-a",
243
+ "model_name": "openai/gpt-5.5",
244
+ "messages": [
245
+ {"role": "user", "content": "hi"},
246
+ {"role": "user", "content": 12345}, # invalid content type
247
+ ],
248
+ "events": [],
249
+ }
250
+ )
251
+ )
252
+
253
+ session = _FakeSession(user_id="user-a")
254
+
255
+ result = session_resume.restore_session_from_log(session, path)
256
+
257
+ assert result["restored_count"] == 1
258
+ assert result["dropped_count"] == 1
259
+
260
+
261
+ def test_restore_does_not_overwrite_source_log_on_save(tmp_path, monkeypatch):
262
+ """Regression: resuming + saving must not destroy the source log on disk.
263
+
264
+ Without the always-fork ``_local_save_path`` reset, the next heartbeat
265
+ save would rewrite the source file with ``events=[resumed_from]`` and
266
+ ``total_cost_usd=0``, wiping the original audit trail. This builds a
267
+ real ``Session`` and exercises the round-trip.
268
+ """
269
+ monkeypatch.chdir(tmp_path)
270
+
271
+ from agent.context_manager.manager import ContextManager
272
+ from agent.core.session import Session
273
+
274
+ log_dir = tmp_path / "session_logs"
275
+ log_dir.mkdir()
276
+ src_path = log_dir / "src.json"
277
+ src_payload = {
278
+ "session_id": "saved-session",
279
+ "user_id": "user-a",
280
+ "session_start_time": "2026-01-01T00:00:00",
281
+ "session_end_time": "2026-01-01T00:05:00",
282
+ "model_name": "openai/gpt-5.5",
283
+ "messages": [
284
+ {"role": "system", "content": "old system"},
285
+ {"role": "user", "content": "earlier work"},
286
+ ],
287
+ "events": [
288
+ {"event_type": "llm_call", "data": {"cost_usd": 0.42}},
289
+ {"event_type": "turn_complete", "data": {}},
290
+ ],
291
+ }
292
+ src_path.write_text(json.dumps(src_payload, indent=2))
293
+ src_bytes_before = src_path.read_bytes()
294
+
295
+ class _Cfg:
296
+ model_name = "openai/gpt-5.5"
297
+ save_sessions = True
298
+ session_dataset_repo = None
299
+ auto_save_interval = 1
300
+ heartbeat_interval_s = 60
301
+ max_iterations = 10
302
+ yolo_mode = False
303
+ confirm_cpu_jobs = False
304
+ auto_file_upload = False
305
+ reasoning_effort = None
306
+ share_traces = False
307
+ personal_trace_repo_template = None
308
+ mcpServers: dict = {}
309
+
310
+ cm = ContextManager.__new__(ContextManager)
311
+ cm.items = [Message(role="system", content="current system")]
312
+ cm.tool_specs = []
313
+ cm.model_max_tokens = 200_000
314
+ cm.running_context_usage = 0
315
+ cm.compact_size = 0.1
316
+ cm.untouched_messages = 5
317
+ cm.hf_token = None
318
+ cm.local_mode = True
319
+ cm.system_prompt = "current system"
320
+ cm.on_message_added = None
321
+
322
+ import asyncio as _asyncio
323
+
324
+ session = Session(
325
+ event_queue=_asyncio.Queue(),
326
+ config=_Cfg(),
327
+ tool_router=None,
328
+ context_manager=cm,
329
+ hf_token=None,
330
+ user_id="user-a",
331
+ local_mode=True,
332
+ )
333
+
334
+ session_resume.restore_session_from_log(session, src_path)
335
+ assert session._local_save_path is None
336
+
337
+ saved_path = session.save_trajectory_local(directory=str(log_dir))
338
+
339
+ assert saved_path is not None
340
+ assert Path(saved_path) != src_path
341
+ assert src_path.read_bytes() == src_bytes_before
342
+
343
+
344
+ def test_restore_flags_redacted_messages(tmp_path):
345
+ log_dir = tmp_path / "session_logs"
346
+ path = _write_session_log(
347
+ log_dir,
348
+ "session.json",
349
+ session_id="saved-session",
350
+ content="my token is [REDACTED_HF_TOKEN]",
351
+ mtime=time.time(),
352
+ user_id="user-a",
353
+ )
354
+
355
+ session = _FakeSession(user_id="user-a")
356
+
357
+ result = session_resume.restore_session_from_log(session, path)
358
+
359
+ assert result["had_redacted_content"] is True
360
+
361
+
362
+ def test_resolve_session_log_arg_accepts_index_and_id_prefix(tmp_path):
363
+ log_dir = tmp_path / "session_logs"
364
+ older = _write_session_log(
365
+ log_dir,
366
+ "older.json",
367
+ session_id="abcdef-older",
368
+ content="x",
369
+ mtime=time.time() - 10,
370
+ )
371
+ newer = _write_session_log(
372
+ log_dir,
373
+ "newer.json",
374
+ session_id="123456-newer",
375
+ content="y",
376
+ mtime=time.time(),
377
+ )
378
+ entries = session_resume.list_session_logs(log_dir)
379
+
380
+ assert session_resume.resolve_session_log_arg("1", entries, log_dir) == newer
381
+ assert session_resume.resolve_session_log_arg("abc", entries, log_dir) == older
382
+ assert session_resume.resolve_session_log_arg("nope", entries, log_dir) is None