Deploy 2026-05-08
Browse filesCo-authored-by: OpenAI Codex <codex@openai.com>
- agent/core/agent_loop.py +27 -4
- agent/core/hub_artifacts.py +40 -72
- agent/core/session.py +5 -2
- agent/core/session_resume.py +287 -0
- agent/main.py +118 -0
- agent/tools/sandbox_client.py +15 -71
- agent/utils/terminal_display.py +1 -0
- backend/session_manager.py +0 -29
- tests/unit/test_hub_artifacts.py +74 -60
- tests/unit/test_sandbox_private_spaces.py +51 -59
- tests/unit/test_session_manager_persistence.py +5 -57
- tests/unit/test_session_resume.py +382 -0
agent/core/agent_loop.py
CHANGED
|
@@ -7,6 +7,7 @@ import json
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from dataclasses import dataclass, field
|
|
|
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
from litellm import (
|
|
@@ -26,10 +27,9 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
|
|
| 26 |
from agent.messaging.gateway import NotificationGateway
|
| 27 |
from agent.core import telemetry
|
| 28 |
from agent.core.doom_loop import check_for_doom_loop
|
| 29 |
-
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 30 |
from agent.core.llm_params import _resolve_llm_params
|
| 31 |
from agent.core.prompt_caching import with_prompt_caching
|
| 32 |
-
from agent.core.session import Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
|
|
@@ -1667,6 +1667,20 @@ class Handlers:
|
|
| 1667 |
logger.warning("Undo: no user message found to remove")
|
| 1668 |
await session.send_event(Event(event_type="undo_complete"))
|
| 1669 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1670 |
@staticmethod
|
| 1671 |
async def exec_approval(session: Session, approvals: list[dict]) -> None:
|
| 1672 |
"""Handle batch job execution approval"""
|
|
@@ -1953,6 +1967,16 @@ async def process_submission(session: Session, submission) -> bool:
|
|
| 1953 |
await Handlers.undo(session)
|
| 1954 |
return True
|
| 1955 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1956 |
if op.op_type == OpType.EXEC_APPROVAL:
|
| 1957 |
approvals = op.data.get("approvals", []) if op.data else []
|
| 1958 |
await Handlers.exec_approval(session, approvals)
|
|
@@ -1999,7 +2023,6 @@ async def submission_loop(
|
|
| 1999 |
)
|
| 2000 |
if session_holder is not None:
|
| 2001 |
session_holder[0] = session
|
| 2002 |
-
start_session_artifact_collection_task(session, token=hf_token)
|
| 2003 |
logger.info("Agent loop started")
|
| 2004 |
|
| 2005 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
@@ -2007,7 +2030,7 @@ async def submission_loop(
|
|
| 2007 |
# to publish to the user's HF dataset gets a fresh attempt on next run.
|
| 2008 |
if config and config.save_sessions:
|
| 2009 |
Session.retry_failed_uploads_detached(
|
| 2010 |
-
directory=
|
| 2011 |
repo_id=config.session_dataset_repo,
|
| 2012 |
personal_repo_id=session._personal_trace_repo_id(),
|
| 2013 |
)
|
|
|
|
| 7 |
import logging
|
| 8 |
import time
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
+
from pathlib import Path
|
| 11 |
from typing import Any
|
| 12 |
|
| 13 |
from litellm import (
|
|
|
|
| 27 |
from agent.messaging.gateway import NotificationGateway
|
| 28 |
from agent.core import telemetry
|
| 29 |
from agent.core.doom_loop import check_for_doom_loop
|
|
|
|
| 30 |
from agent.core.llm_params import _resolve_llm_params
|
| 31 |
from agent.core.prompt_caching import with_prompt_caching
|
| 32 |
+
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
|
|
|
|
| 1667 |
logger.warning("Undo: no user message found to remove")
|
| 1668 |
await session.send_event(Event(event_type="undo_complete"))
|
| 1669 |
|
| 1670 |
+
@staticmethod
|
| 1671 |
+
async def resume(session: Session, path: str) -> None:
|
| 1672 |
+
"""Reload context from a saved session log into the active session."""
|
| 1673 |
+
from agent.core.session_resume import restore_session_from_log
|
| 1674 |
+
|
| 1675 |
+
try:
|
| 1676 |
+
result = restore_session_from_log(session, Path(path))
|
| 1677 |
+
except Exception as e:
|
| 1678 |
+
await session.send_event(
|
| 1679 |
+
Event(event_type="error", data={"error": f"Resume failed: {e}"})
|
| 1680 |
+
)
|
| 1681 |
+
return
|
| 1682 |
+
await session.send_event(Event(event_type="resume_complete", data=result))
|
| 1683 |
+
|
| 1684 |
@staticmethod
|
| 1685 |
async def exec_approval(session: Session, approvals: list[dict]) -> None:
|
| 1686 |
"""Handle batch job execution approval"""
|
|
|
|
| 1967 |
await Handlers.undo(session)
|
| 1968 |
return True
|
| 1969 |
|
| 1970 |
+
if op.op_type == OpType.RESUME:
|
| 1971 |
+
path = op.data.get("path") if op.data else None
|
| 1972 |
+
if path:
|
| 1973 |
+
await Handlers.resume(session, path)
|
| 1974 |
+
else:
|
| 1975 |
+
await session.send_event(
|
| 1976 |
+
Event(event_type="error", data={"error": "Resume requires a path"})
|
| 1977 |
+
)
|
| 1978 |
+
return True
|
| 1979 |
+
|
| 1980 |
if op.op_type == OpType.EXEC_APPROVAL:
|
| 1981 |
approvals = op.data.get("approvals", []) if op.data else []
|
| 1982 |
await Handlers.exec_approval(session, approvals)
|
|
|
|
| 2023 |
)
|
| 2024 |
if session_holder is not None:
|
| 2025 |
session_holder[0] = session
|
|
|
|
| 2026 |
logger.info("Agent loop started")
|
| 2027 |
|
| 2028 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
|
|
| 2030 |
# to publish to the user's HF dataset gets a fresh attempt on next run.
|
| 2031 |
if config and config.save_sessions:
|
| 2032 |
Session.retry_failed_uploads_detached(
|
| 2033 |
+
directory=str(DEFAULT_SESSION_LOG_DIR),
|
| 2034 |
repo_id=config.session_dataset_repo,
|
| 2035 |
personal_repo_id=session._personal_trace_repo_id(),
|
| 2036 |
)
|
agent/core/hub_artifacts.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
"""Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
|
| 2 |
|
| 3 |
-
import asyncio
|
| 4 |
import base64
|
| 5 |
import logging
|
| 6 |
import re
|
|
@@ -11,7 +10,7 @@ from datetime import datetime
|
|
| 11 |
from pathlib import Path
|
| 12 |
from typing import Any
|
| 13 |
|
| 14 |
-
from huggingface_hub import
|
| 15 |
from huggingface_hub.repocard import metadata_load, metadata_save
|
| 16 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 17 |
|
|
@@ -29,7 +28,6 @@ _UUID_SESSION_ID_RE = re.compile(
|
|
| 29 |
_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
|
| 30 |
_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
|
| 31 |
_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
|
| 32 |
-
_COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
|
| 33 |
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
|
| 34 |
_USAGE_HEADING_RE = re.compile(
|
| 35 |
r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
|
|
@@ -307,70 +305,6 @@ def _ensure_collection_slug(
|
|
| 307 |
return slug
|
| 308 |
|
| 309 |
|
| 310 |
-
async def ensure_session_artifact_collection(
|
| 311 |
-
session: Any,
|
| 312 |
-
*,
|
| 313 |
-
token: str | bool | None = None,
|
| 314 |
-
) -> str | None:
|
| 315 |
-
"""Create/cache the per-session artifact collection without raising."""
|
| 316 |
-
if session is None or not getattr(session, "session_id", None):
|
| 317 |
-
return None
|
| 318 |
-
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 319 |
-
if not token_value:
|
| 320 |
-
return None
|
| 321 |
-
|
| 322 |
-
try:
|
| 323 |
-
api = HfApi(token=token_value)
|
| 324 |
-
return await asyncio.to_thread(
|
| 325 |
-
_ensure_collection_slug,
|
| 326 |
-
api,
|
| 327 |
-
session,
|
| 328 |
-
token=token_value,
|
| 329 |
-
)
|
| 330 |
-
except Exception as e:
|
| 331 |
-
logger.warning(
|
| 332 |
-
"ML Intern session collection creation failed for %s: %s",
|
| 333 |
-
_safe_session_id(session),
|
| 334 |
-
e,
|
| 335 |
-
)
|
| 336 |
-
return None
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
def start_session_artifact_collection_task(
|
| 340 |
-
session: Any,
|
| 341 |
-
*,
|
| 342 |
-
token: str | bool | None = None,
|
| 343 |
-
) -> asyncio.Task | None:
|
| 344 |
-
"""Schedule best-effort collection creation for a newly started session."""
|
| 345 |
-
if session is None or not getattr(session, "session_id", None):
|
| 346 |
-
return None
|
| 347 |
-
if getattr(session, _COLLECTION_SLUG_ATTR, None):
|
| 348 |
-
return None
|
| 349 |
-
|
| 350 |
-
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 351 |
-
if not token_value:
|
| 352 |
-
return None
|
| 353 |
-
|
| 354 |
-
existing = getattr(session, _COLLECTION_TASK_ATTR, None)
|
| 355 |
-
if isinstance(existing, asyncio.Task) and not existing.done():
|
| 356 |
-
return existing
|
| 357 |
-
|
| 358 |
-
try:
|
| 359 |
-
loop = asyncio.get_running_loop()
|
| 360 |
-
except RuntimeError:
|
| 361 |
-
return None
|
| 362 |
-
|
| 363 |
-
async def _run() -> None:
|
| 364 |
-
await ensure_session_artifact_collection(session, token=token_value)
|
| 365 |
-
|
| 366 |
-
task = loop.create_task(_run())
|
| 367 |
-
try:
|
| 368 |
-
setattr(session, _COLLECTION_TASK_ATTR, task)
|
| 369 |
-
except Exception:
|
| 370 |
-
logger.debug("Could not attach ML Intern collection task to session")
|
| 371 |
-
return task
|
| 372 |
-
|
| 373 |
-
|
| 374 |
def _add_to_collection(
|
| 375 |
api: Any,
|
| 376 |
session: Any,
|
|
@@ -378,10 +312,10 @@ def _add_to_collection(
|
|
| 378 |
repo_type: str,
|
| 379 |
*,
|
| 380 |
token: str | bool | None = None,
|
| 381 |
-
) ->
|
| 382 |
slug = _ensure_collection_slug(api, session, token=token)
|
| 383 |
if not slug:
|
| 384 |
-
return
|
| 385 |
api.add_collection_item(
|
| 386 |
collection_slug=slug,
|
| 387 |
item_id=repo_id,
|
|
@@ -393,6 +327,7 @@ def _add_to_collection(
|
|
| 393 |
exists_ok=True,
|
| 394 |
token=token,
|
| 395 |
)
|
|
|
|
| 396 |
|
| 397 |
|
| 398 |
def register_hub_artifact(
|
|
@@ -436,8 +371,13 @@ def register_hub_artifact(
|
|
| 436 |
logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
|
| 437 |
|
| 438 |
try:
|
| 439 |
-
|
| 440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
except Exception as e:
|
| 442 |
logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
|
| 443 |
|
|
@@ -490,6 +430,13 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 490 |
re.IGNORECASE | re.MULTILINE,
|
| 491 |
)
|
| 492 |
front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
def _token(value=None, api=None):
|
| 495 |
if isinstance(value, str) and value:
|
|
@@ -602,6 +549,15 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 602 |
nonlocal collection_slug
|
| 603 |
if collection_slug:
|
| 604 |
return collection_slug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
collection = api.create_collection(
|
| 606 |
title=collection_title,
|
| 607 |
description=(
|
|
@@ -613,6 +569,13 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 613 |
token=token_value,
|
| 614 |
)
|
| 615 |
collection_slug = getattr(collection, "slug", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
return collection_slug
|
| 617 |
|
| 618 |
def _register(
|
|
@@ -637,6 +600,7 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 637 |
try:
|
| 638 |
token_value = _token(token_value)
|
| 639 |
api = HfApi(token=token_value)
|
|
|
|
| 640 |
try:
|
| 641 |
current = _readme(api, repo_id, repo_type, token_value)
|
| 642 |
updated = _augment(
|
|
@@ -652,8 +616,10 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 652 |
token=token_value,
|
| 653 |
commit_message="Update ML Intern artifact metadata",
|
| 654 |
)
|
|
|
|
| 655 |
except Exception:
|
| 656 |
pass
|
|
|
|
| 657 |
try:
|
| 658 |
slug = _ensure_collection(api, token_value)
|
| 659 |
if slug:
|
|
@@ -668,9 +634,11 @@ def build_hub_artifact_sitecustomize(session: Any) -> str:
|
|
| 668 |
exists_ok=True,
|
| 669 |
token=token_value,
|
| 670 |
)
|
|
|
|
| 671 |
except Exception:
|
| 672 |
pass
|
| 673 |
-
|
|
|
|
| 674 |
finally:
|
| 675 |
registering = False
|
| 676 |
|
|
|
|
| 1 |
"""Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
|
| 2 |
|
|
|
|
| 3 |
import base64
|
| 4 |
import logging
|
| 5 |
import re
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Any
|
| 12 |
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
from huggingface_hub.repocard import metadata_load, metadata_save
|
| 15 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 16 |
|
|
|
|
| 28 |
_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
|
| 29 |
_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
|
| 30 |
_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
|
|
|
|
| 31 |
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
|
| 32 |
_USAGE_HEADING_RE = re.compile(
|
| 33 |
r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
|
|
|
|
| 305 |
return slug
|
| 306 |
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
def _add_to_collection(
|
| 309 |
api: Any,
|
| 310 |
session: Any,
|
|
|
|
| 312 |
repo_type: str,
|
| 313 |
*,
|
| 314 |
token: str | bool | None = None,
|
| 315 |
+
) -> bool:
|
| 316 |
slug = _ensure_collection_slug(api, session, token=token)
|
| 317 |
if not slug:
|
| 318 |
+
return False
|
| 319 |
api.add_collection_item(
|
| 320 |
collection_slug=slug,
|
| 321 |
item_id=repo_id,
|
|
|
|
| 327 |
exists_ok=True,
|
| 328 |
token=token,
|
| 329 |
)
|
| 330 |
+
return True
|
| 331 |
|
| 332 |
|
| 333 |
def register_hub_artifact(
|
|
|
|
| 371 |
logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
|
| 372 |
|
| 373 |
try:
|
| 374 |
+
collection_updated = _add_to_collection(
|
| 375 |
+
api,
|
| 376 |
+
session,
|
| 377 |
+
repo_id,
|
| 378 |
+
repo_type,
|
| 379 |
+
token=token_value,
|
| 380 |
+
)
|
| 381 |
except Exception as e:
|
| 382 |
logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
|
| 383 |
|
|
|
|
| 430 |
re.IGNORECASE | re.MULTILINE,
|
| 431 |
)
|
| 432 |
front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
|
| 433 |
+
collection_cache_path = (
|
| 434 |
+
os.environ.get("ML_INTERN_ARTIFACT_COLLECTION_CACHE")
|
| 435 |
+
or str(
|
| 436 |
+
Path(tempfile.gettempdir())
|
| 437 |
+
/ f"ml-intern-artifacts-{{session_id}}.collection"
|
| 438 |
+
)
|
| 439 |
+
)
|
| 440 |
|
| 441 |
def _token(value=None, api=None):
|
| 442 |
if isinstance(value, str) and value:
|
|
|
|
| 549 |
nonlocal collection_slug
|
| 550 |
if collection_slug:
|
| 551 |
return collection_slug
|
| 552 |
+
try:
|
| 553 |
+
cached_slug = Path(collection_cache_path).read_text(
|
| 554 |
+
encoding="utf-8"
|
| 555 |
+
).strip()
|
| 556 |
+
if cached_slug:
|
| 557 |
+
collection_slug = cached_slug
|
| 558 |
+
return collection_slug
|
| 559 |
+
except Exception:
|
| 560 |
+
pass
|
| 561 |
collection = api.create_collection(
|
| 562 |
title=collection_title,
|
| 563 |
description=(
|
|
|
|
| 569 |
token=token_value,
|
| 570 |
)
|
| 571 |
collection_slug = getattr(collection, "slug", None)
|
| 572 |
+
if collection_slug:
|
| 573 |
+
try:
|
| 574 |
+
cache_path = Path(collection_cache_path)
|
| 575 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
| 576 |
+
cache_path.write_text(collection_slug, encoding="utf-8")
|
| 577 |
+
except Exception:
|
| 578 |
+
pass
|
| 579 |
return collection_slug
|
| 580 |
|
| 581 |
def _register(
|
|
|
|
| 600 |
try:
|
| 601 |
token_value = _token(token_value)
|
| 602 |
api = HfApi(token=token_value)
|
| 603 |
+
card_updated = False
|
| 604 |
try:
|
| 605 |
current = _readme(api, repo_id, repo_type, token_value)
|
| 606 |
updated = _augment(
|
|
|
|
| 616 |
token=token_value,
|
| 617 |
commit_message="Update ML Intern artifact metadata",
|
| 618 |
)
|
| 619 |
+
card_updated = True
|
| 620 |
except Exception:
|
| 621 |
pass
|
| 622 |
+
collection_updated = False
|
| 623 |
try:
|
| 624 |
slug = _ensure_collection(api, token_value)
|
| 625 |
if slug:
|
|
|
|
| 634 |
exists_ok=True,
|
| 635 |
token=token_value,
|
| 636 |
)
|
| 637 |
+
collection_updated = True
|
| 638 |
except Exception:
|
| 639 |
pass
|
| 640 |
+
if card_updated and collection_updated:
|
| 641 |
+
registered.add(key)
|
| 642 |
finally:
|
| 643 |
registering = False
|
| 644 |
|
agent/core/session.py
CHANGED
|
@@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
|
|
| 21 |
_DEFAULT_MAX_TOKENS = 200_000
|
| 22 |
_TURN_COMPLETE_NOTIFICATION_CHARS = 39000
|
| 23 |
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def _get_max_tokens_safe(model_name: str) -> int:
|
| 26 |
"""Return the max input-context tokens for a model.
|
|
@@ -60,6 +62,7 @@ class OpType(Enum):
|
|
| 60 |
INTERRUPT = "interrupt"
|
| 61 |
UNDO = "undo"
|
| 62 |
COMPACT = "compact"
|
|
|
|
| 63 |
SHUTDOWN = "shutdown"
|
| 64 |
|
| 65 |
|
|
@@ -418,7 +421,7 @@ class Session:
|
|
| 418 |
|
| 419 |
def save_trajectory_local(
|
| 420 |
self,
|
| 421 |
-
directory: str =
|
| 422 |
upload_status: str = "pending",
|
| 423 |
dataset_url: Optional[str] = None,
|
| 424 |
) -> Optional[str]:
|
|
@@ -613,7 +616,7 @@ class Session:
|
|
| 613 |
|
| 614 |
@staticmethod
|
| 615 |
def retry_failed_uploads_detached(
|
| 616 |
-
directory: str =
|
| 617 |
repo_id: Optional[str] = None,
|
| 618 |
*,
|
| 619 |
personal_repo_id: Optional[str] = None,
|
|
|
|
| 21 |
_DEFAULT_MAX_TOKENS = 200_000
|
| 22 |
_TURN_COMPLETE_NOTIFICATION_CHARS = 39000
|
| 23 |
|
| 24 |
+
DEFAULT_SESSION_LOG_DIR = Path("session_logs")
|
| 25 |
+
|
| 26 |
|
| 27 |
def _get_max_tokens_safe(model_name: str) -> int:
|
| 28 |
"""Return the max input-context tokens for a model.
|
|
|
|
| 62 |
INTERRUPT = "interrupt"
|
| 63 |
UNDO = "undo"
|
| 64 |
COMPACT = "compact"
|
| 65 |
+
RESUME = "resume"
|
| 66 |
SHUTDOWN = "shutdown"
|
| 67 |
|
| 68 |
|
|
|
|
| 421 |
|
| 422 |
def save_trajectory_local(
|
| 423 |
self,
|
| 424 |
+
directory: str = str(DEFAULT_SESSION_LOG_DIR),
|
| 425 |
upload_status: str = "pending",
|
| 426 |
dataset_url: Optional[str] = None,
|
| 427 |
) -> Optional[str]:
|
|
|
|
| 616 |
|
| 617 |
@staticmethod
|
| 618 |
def retry_failed_uploads_detached(
|
| 619 |
+
directory: str = str(DEFAULT_SESSION_LOG_DIR),
|
| 620 |
repo_id: Optional[str] = None,
|
| 621 |
*,
|
| 622 |
personal_repo_id: Optional[str] = None,
|
agent/core/session_resume.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reload a previously saved session log into the active CLI session."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from litellm import Message
|
| 14 |
+
|
| 15 |
+
from agent.core.model_switcher import is_valid_model_id
|
| 16 |
+
from agent.core.session import DEFAULT_SESSION_LOG_DIR
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
_REDACTED_MARKER = re.compile(r"\[REDACTED_[A-Z_]+\]")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class SessionLogEntry:
|
| 25 |
+
"""Metadata for a locally saved session log."""
|
| 26 |
+
|
| 27 |
+
path: Path
|
| 28 |
+
session_id: str
|
| 29 |
+
session_start_time: str | None
|
| 30 |
+
session_end_time: str | None
|
| 31 |
+
model_name: str | None
|
| 32 |
+
message_count: int
|
| 33 |
+
preview: str
|
| 34 |
+
mtime: float
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _message_preview(content: Any, max_chars: int = 72) -> str:
|
| 38 |
+
"""Return a one-line preview for string or OpenAI-style block content."""
|
| 39 |
+
if isinstance(content, str):
|
| 40 |
+
text = content
|
| 41 |
+
elif isinstance(content, list):
|
| 42 |
+
parts: list[str] = []
|
| 43 |
+
for block in content:
|
| 44 |
+
if isinstance(block, dict):
|
| 45 |
+
value = block.get("text") or block.get("content")
|
| 46 |
+
if isinstance(value, str):
|
| 47 |
+
parts.append(value)
|
| 48 |
+
elif isinstance(block, str):
|
| 49 |
+
parts.append(block)
|
| 50 |
+
text = " ".join(parts)
|
| 51 |
+
else:
|
| 52 |
+
text = ""
|
| 53 |
+
text = " ".join(text.split())
|
| 54 |
+
if len(text) > max_chars:
|
| 55 |
+
return text[: max_chars - 1].rstrip() + "…"
|
| 56 |
+
return text
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _first_user_preview(messages: list[Any]) -> str:
|
| 60 |
+
for raw in messages:
|
| 61 |
+
if isinstance(raw, dict) and raw.get("role") == "user":
|
| 62 |
+
preview = _message_preview(raw.get("content"))
|
| 63 |
+
if preview:
|
| 64 |
+
return preview
|
| 65 |
+
return "(no user prompt preview)"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def list_session_logs(
|
| 69 |
+
directory: Path = DEFAULT_SESSION_LOG_DIR,
|
| 70 |
+
) -> list[SessionLogEntry]:
|
| 71 |
+
"""Return readable session logs under ``directory``, newest first."""
|
| 72 |
+
if not directory.exists():
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
entries: list[SessionLogEntry] = []
|
| 76 |
+
for path in directory.glob("*.json"):
|
| 77 |
+
try:
|
| 78 |
+
with open(path) as f:
|
| 79 |
+
data = json.load(f)
|
| 80 |
+
except Exception:
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
messages = data.get("messages") or []
|
| 84 |
+
if not isinstance(messages, list):
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
session_id = data.get("session_id")
|
| 88 |
+
if not isinstance(session_id, str) or not session_id:
|
| 89 |
+
session_id = path.stem
|
| 90 |
+
|
| 91 |
+
stat = path.stat()
|
| 92 |
+
entries.append(
|
| 93 |
+
SessionLogEntry(
|
| 94 |
+
path=path,
|
| 95 |
+
session_id=session_id,
|
| 96 |
+
session_start_time=data.get("session_start_time"),
|
| 97 |
+
session_end_time=data.get("session_end_time"),
|
| 98 |
+
model_name=data.get("model_name"),
|
| 99 |
+
message_count=len(messages),
|
| 100 |
+
preview=_first_user_preview(messages),
|
| 101 |
+
mtime=stat.st_mtime,
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
entries.sort(key=lambda item: item.mtime, reverse=True)
|
| 106 |
+
return entries
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def format_session_log_entry(index: int, entry: SessionLogEntry) -> str:
|
| 110 |
+
timestamp = entry.session_end_time or entry.session_start_time
|
| 111 |
+
label = "unknown time"
|
| 112 |
+
if isinstance(timestamp, str) and timestamp:
|
| 113 |
+
try:
|
| 114 |
+
label = datetime.fromisoformat(timestamp).strftime("%Y-%m-%d %H:%M")
|
| 115 |
+
except ValueError:
|
| 116 |
+
label = timestamp[:16]
|
| 117 |
+
short_id = entry.session_id[:8]
|
| 118 |
+
model = entry.model_name or "unknown model"
|
| 119 |
+
return (
|
| 120 |
+
f"{index:>2}. {label} {short_id} "
|
| 121 |
+
f"{entry.message_count} msgs {model}\n"
|
| 122 |
+
f" {entry.preview}"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def resolve_session_log_arg(
|
| 127 |
+
arg: str,
|
| 128 |
+
entries: list[SessionLogEntry],
|
| 129 |
+
directory: Path = DEFAULT_SESSION_LOG_DIR,
|
| 130 |
+
) -> Path | None:
|
| 131 |
+
"""Resolve ``/resume <arg>`` as index, path, filename, or session id prefix."""
|
| 132 |
+
value = arg.strip()
|
| 133 |
+
if not value:
|
| 134 |
+
return None
|
| 135 |
+
|
| 136 |
+
if value.isdigit():
|
| 137 |
+
idx = int(value)
|
| 138 |
+
if 1 <= idx <= len(entries):
|
| 139 |
+
return entries[idx - 1].path
|
| 140 |
+
|
| 141 |
+
candidate = Path(value).expanduser()
|
| 142 |
+
candidates = [candidate]
|
| 143 |
+
if not candidate.is_absolute():
|
| 144 |
+
candidates.append(directory / candidate)
|
| 145 |
+
if candidate.suffix != ".json":
|
| 146 |
+
candidates.append(directory / f"{value}.json")
|
| 147 |
+
|
| 148 |
+
for path in candidates:
|
| 149 |
+
if path.exists() and path.is_file():
|
| 150 |
+
return path
|
| 151 |
+
|
| 152 |
+
matches = [
|
| 153 |
+
entry.path
|
| 154 |
+
for entry in entries
|
| 155 |
+
if entry.session_id.startswith(value) or entry.path.name.startswith(value)
|
| 156 |
+
]
|
| 157 |
+
if len(matches) == 1:
|
| 158 |
+
return matches[0]
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _turn_count_from_messages(messages: list[Any]) -> int:
|
| 163 |
+
return sum(
|
| 164 |
+
1 for raw in messages if isinstance(raw, dict) and raw.get("role") == "user"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _has_redacted_content(messages: list[Any]) -> bool:
|
| 169 |
+
"""Whether any message body contains a ``[REDACTED_*]`` marker."""
|
| 170 |
+
for raw in messages:
|
| 171 |
+
if not isinstance(raw, dict):
|
| 172 |
+
continue
|
| 173 |
+
content = raw.get("content")
|
| 174 |
+
if isinstance(content, str) and _REDACTED_MARKER.search(content):
|
| 175 |
+
return True
|
| 176 |
+
if isinstance(content, list):
|
| 177 |
+
for block in content:
|
| 178 |
+
if isinstance(block, dict):
|
| 179 |
+
text = block.get("text") or block.get("content")
|
| 180 |
+
if isinstance(text, str) and _REDACTED_MARKER.search(text):
|
| 181 |
+
return True
|
| 182 |
+
return False
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def restore_session_from_log(session: Any, path: Path) -> dict[str, Any]:
|
| 186 |
+
"""Replace the active session context with messages from ``path``.
|
| 187 |
+
|
| 188 |
+
Continues the saved session (reusing its id and on-disk save path) when
|
| 189 |
+
the log's ``user_id`` matches the current session, and forks otherwise:
|
| 190 |
+
the caller's session id stays put and future heartbeat saves go to a
|
| 191 |
+
fresh file rather than overwriting the source log.
|
| 192 |
+
|
| 193 |
+
Returns metadata for the ``resume_complete`` event.
|
| 194 |
+
"""
|
| 195 |
+
with open(path) as f:
|
| 196 |
+
data = json.load(f)
|
| 197 |
+
|
| 198 |
+
raw_messages = data.get("messages")
|
| 199 |
+
if not isinstance(raw_messages, list):
|
| 200 |
+
raise ValueError("Selected log does not contain a messages array")
|
| 201 |
+
|
| 202 |
+
restored_messages: list[Message] = []
|
| 203 |
+
dropped_count = 0
|
| 204 |
+
for raw in raw_messages:
|
| 205 |
+
if not isinstance(raw, dict) or raw.get("role") == "system":
|
| 206 |
+
continue
|
| 207 |
+
try:
|
| 208 |
+
restored_messages.append(Message.model_validate(raw))
|
| 209 |
+
except Exception as e:
|
| 210 |
+
dropped_count += 1
|
| 211 |
+
logger.warning("Dropping malformed message from %s: %s", path, e)
|
| 212 |
+
|
| 213 |
+
if not restored_messages:
|
| 214 |
+
raise ValueError("Selected log has no restorable non-system messages")
|
| 215 |
+
|
| 216 |
+
cm = session.context_manager
|
| 217 |
+
system_msg = cm.items[0] if cm.items and cm.items[0].role == "system" else None
|
| 218 |
+
cm.items = ([system_msg] if system_msg else []) + restored_messages
|
| 219 |
+
|
| 220 |
+
# Validate the saved model id before switching. ``update_model`` doesn't
|
| 221 |
+
# check availability; an unrecognised id silently sticks and the next LLM
|
| 222 |
+
# call fails with a cryptic routing error. Logs from a different
|
| 223 |
+
# deployment, an older catalog, or a removed model land here.
|
| 224 |
+
saved_model = data.get("model_name")
|
| 225 |
+
invalid_saved_model: str | None = None
|
| 226 |
+
if isinstance(saved_model, str) and saved_model:
|
| 227 |
+
if is_valid_model_id(saved_model):
|
| 228 |
+
session.update_model(saved_model)
|
| 229 |
+
else:
|
| 230 |
+
invalid_saved_model = saved_model
|
| 231 |
+
logger.warning(
|
| 232 |
+
"Saved log model %r failed format validation; keeping %r",
|
| 233 |
+
saved_model,
|
| 234 |
+
session.config.model_name,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
cm._recompute_usage(session.config.model_name)
|
| 238 |
+
|
| 239 |
+
saved_session_id = data.get("session_id")
|
| 240 |
+
saved_user_id = data.get("user_id")
|
| 241 |
+
is_continuation = saved_user_id == session.user_id
|
| 242 |
+
|
| 243 |
+
if is_continuation:
|
| 244 |
+
if isinstance(saved_session_id, str) and saved_session_id:
|
| 245 |
+
session.session_id = saved_session_id
|
| 246 |
+
session.session_start_time = (
|
| 247 |
+
data.get("session_start_time") or session.session_start_time
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Always fork the on-disk save path. The source log is treated as an
|
| 251 |
+
# immutable snapshot: ``logged_events`` is reset to a single
|
| 252 |
+
# ``resumed_from`` marker below for cost accounting, so reusing the
|
| 253 |
+
# source path would let the next heartbeat save destroy the original
|
| 254 |
+
# ``llm_call``/event history on disk. The next save will pick a fresh
|
| 255 |
+
# filename instead.
|
| 256 |
+
session._local_save_path = None
|
| 257 |
+
|
| 258 |
+
saved_event_count = (
|
| 259 |
+
len(data.get("events", [])) if isinstance(data.get("events"), list) else 0
|
| 260 |
+
)
|
| 261 |
+
session.logged_events = [
|
| 262 |
+
{
|
| 263 |
+
"timestamp": datetime.now().isoformat(),
|
| 264 |
+
"event_type": "resumed_from",
|
| 265 |
+
"data": {
|
| 266 |
+
"path": str(path),
|
| 267 |
+
"original_session_id": (
|
| 268 |
+
saved_session_id if isinstance(saved_session_id, str) else None
|
| 269 |
+
),
|
| 270 |
+
"original_event_count": saved_event_count,
|
| 271 |
+
"forked": not is_continuation,
|
| 272 |
+
},
|
| 273 |
+
}
|
| 274 |
+
]
|
| 275 |
+
session.turn_count = _turn_count_from_messages(raw_messages)
|
| 276 |
+
session.last_auto_save_turn = session.turn_count
|
| 277 |
+
session.pending_approval = None
|
| 278 |
+
|
| 279 |
+
return {
|
| 280 |
+
"path": str(path),
|
| 281 |
+
"restored_count": len(restored_messages),
|
| 282 |
+
"dropped_count": dropped_count,
|
| 283 |
+
"model_name": session.config.model_name,
|
| 284 |
+
"invalid_saved_model": invalid_saved_model,
|
| 285 |
+
"forked": not is_continuation,
|
| 286 |
+
"had_redacted_content": _has_redacted_content(raw_messages),
|
| 287 |
+
}
|
agent/main.py
CHANGED
|
@@ -9,6 +9,7 @@ Supports two modes:
|
|
| 9 |
import argparse
|
| 10 |
import asyncio
|
| 11 |
import json
|
|
|
|
| 12 |
import os
|
| 13 |
import signal
|
| 14 |
import sys
|
|
@@ -55,6 +56,7 @@ litellm.drop_params = True
|
|
| 55 |
litellm.suppress_debug_info = True
|
| 56 |
|
| 57 |
CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
|
@@ -368,6 +370,46 @@ async def event_listener(
|
|
| 368 |
elif event.event_type == "undo_complete":
|
| 369 |
console.print("[dim]Undone.[/dim]")
|
| 370 |
turn_complete_event.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
elif event.event_type == "tool_log":
|
| 372 |
tool = event.data.get("tool", "") if event.data else ""
|
| 373 |
log = event.data.get("log", "") if event.data else ""
|
|
@@ -739,12 +781,69 @@ async def get_user_input(prompt_session: PromptSession) -> str:
|
|
| 739 |
# Slash commands are defined in terminal_display
|
| 740 |
|
| 741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
async def _handle_slash_command(
|
| 743 |
cmd: str,
|
| 744 |
config,
|
| 745 |
session_holder: list,
|
| 746 |
submission_queue: asyncio.Queue,
|
| 747 |
submission_id: list[int],
|
|
|
|
| 748 |
) -> Submission | None:
|
| 749 |
"""
|
| 750 |
Handle a slash command. Returns a Submission to enqueue, or None if
|
|
@@ -775,6 +874,24 @@ async def _handle_slash_command(
|
|
| 775 |
operation=Operation(op_type=OpType.COMPACT),
|
| 776 |
)
|
| 777 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 778 |
if command == "/model":
|
| 779 |
console = get_console()
|
| 780 |
if not arg:
|
|
@@ -1136,6 +1253,7 @@ async def main(model: str | None = None):
|
|
| 1136 |
session_holder,
|
| 1137 |
submission_queue,
|
| 1138 |
submission_id,
|
|
|
|
| 1139 |
)
|
| 1140 |
if sub is None:
|
| 1141 |
# Command handled locally, loop back for input
|
|
|
|
| 9 |
import argparse
|
| 10 |
import asyncio
|
| 11 |
import json
|
| 12 |
+
import logging
|
| 13 |
import os
|
| 14 |
import signal
|
| 15 |
import sys
|
|
|
|
| 56 |
litellm.suppress_debug_info = True
|
| 57 |
|
| 58 |
CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.json"
|
| 59 |
+
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
|
| 62 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
|
|
|
| 370 |
elif event.event_type == "undo_complete":
|
| 371 |
console.print("[dim]Undone.[/dim]")
|
| 372 |
turn_complete_event.set()
|
| 373 |
+
elif event.event_type == "resume_complete":
|
| 374 |
+
data = event.data or {}
|
| 375 |
+
path = data.get("path", "?")
|
| 376 |
+
count = data.get("restored_count", 0)
|
| 377 |
+
dropped = int(data.get("dropped_count", 0) or 0)
|
| 378 |
+
model = data.get("model_name", "?")
|
| 379 |
+
invalid_model = data.get("invalid_saved_model")
|
| 380 |
+
forked = bool(data.get("forked", False))
|
| 381 |
+
redacted = bool(data.get("had_redacted_content", False))
|
| 382 |
+
verb = "Forked from" if forked else "Resumed"
|
| 383 |
+
console.print(
|
| 384 |
+
f"[green]{verb}[/green] {path} "
|
| 385 |
+
f"([cyan]{count}[/cyan] messages, "
|
| 386 |
+
f"model [cyan]{model}[/cyan])."
|
| 387 |
+
)
|
| 388 |
+
if dropped:
|
| 389 |
+
console.print(
|
| 390 |
+
f"[yellow]Warning:[/yellow] dropped {dropped} "
|
| 391 |
+
"malformed message(s) while restoring — surrounding "
|
| 392 |
+
"tool-call alignment may be off."
|
| 393 |
+
)
|
| 394 |
+
if invalid_model:
|
| 395 |
+
console.print(
|
| 396 |
+
f"[yellow]Warning:[/yellow] saved model id "
|
| 397 |
+
f"[cyan]{invalid_model}[/cyan] failed validation; "
|
| 398 |
+
f"kept current model [cyan]{model}[/cyan]."
|
| 399 |
+
)
|
| 400 |
+
if forked:
|
| 401 |
+
console.print(
|
| 402 |
+
"[dim]Saved log belongs to a different user — kept "
|
| 403 |
+
"current session id; future saves go to a fresh file.[/dim]"
|
| 404 |
+
)
|
| 405 |
+
if redacted:
|
| 406 |
+
console.print(
|
| 407 |
+
"[yellow]Note:[/yellow] tokens/secrets in restored "
|
| 408 |
+
"messages were scrubbed at save time. Your live tokens "
|
| 409 |
+
"are used for this session; [REDACTED_*] markers in "
|
| 410 |
+
"past messages are not re-injected."
|
| 411 |
+
)
|
| 412 |
+
turn_complete_event.set()
|
| 413 |
elif event.event_type == "tool_log":
|
| 414 |
tool = event.data.get("tool", "") if event.data else ""
|
| 415 |
log = event.data.get("log", "") if event.data else ""
|
|
|
|
| 781 |
# Slash commands are defined in terminal_display
|
| 782 |
|
| 783 |
|
| 784 |
+
async def _resume_picker(
|
| 785 |
+
arg: str,
|
| 786 |
+
prompt_session: PromptSession | None,
|
| 787 |
+
) -> Path | None:
|
| 788 |
+
"""Resolve a session log path via ``arg`` or interactive selection.
|
| 789 |
+
|
| 790 |
+
Returns ``None`` if the user cancels, no logs exist, or the argument
|
| 791 |
+
matches nothing — already prints the explanation in those cases.
|
| 792 |
+
"""
|
| 793 |
+
from agent.core.session_resume import (
|
| 794 |
+
format_session_log_entry,
|
| 795 |
+
list_session_logs,
|
| 796 |
+
resolve_session_log_arg,
|
| 797 |
+
)
|
| 798 |
+
from agent.core.session import DEFAULT_SESSION_LOG_DIR
|
| 799 |
+
|
| 800 |
+
console = get_console()
|
| 801 |
+
directory = DEFAULT_SESSION_LOG_DIR
|
| 802 |
+
entries = list_session_logs(directory)
|
| 803 |
+
if not entries:
|
| 804 |
+
console.print(f"[yellow]No session logs found in ./{directory}.[/yellow]")
|
| 805 |
+
return None
|
| 806 |
+
|
| 807 |
+
if arg:
|
| 808 |
+
selected = resolve_session_log_arg(arg, entries, directory)
|
| 809 |
+
if selected is None:
|
| 810 |
+
console.print(f"[bold red]No matching session log:[/bold red] {arg}")
|
| 811 |
+
return selected
|
| 812 |
+
|
| 813 |
+
console.print()
|
| 814 |
+
console.print("[bold]Saved sessions[/bold]")
|
| 815 |
+
for index, entry in enumerate(entries, start=1):
|
| 816 |
+
console.print(format_session_log_entry(index, entry))
|
| 817 |
+
console.print()
|
| 818 |
+
|
| 819 |
+
if prompt_session is None:
|
| 820 |
+
console.print("[yellow]Cannot prompt for a selection here.[/yellow]")
|
| 821 |
+
return None
|
| 822 |
+
|
| 823 |
+
try:
|
| 824 |
+
choice = await prompt_session.prompt_async(
|
| 825 |
+
"Select session number (blank to cancel): "
|
| 826 |
+
)
|
| 827 |
+
except (EOFError, KeyboardInterrupt):
|
| 828 |
+
console.print("[dim]Resume cancelled.[/dim]")
|
| 829 |
+
return None
|
| 830 |
+
choice = choice.strip()
|
| 831 |
+
if not choice:
|
| 832 |
+
console.print("[dim]Resume cancelled.[/dim]")
|
| 833 |
+
return None
|
| 834 |
+
selected = resolve_session_log_arg(choice, entries, directory)
|
| 835 |
+
if selected is None:
|
| 836 |
+
console.print(f"[bold red]Invalid selection:[/bold red] {choice}")
|
| 837 |
+
return selected
|
| 838 |
+
|
| 839 |
+
|
| 840 |
async def _handle_slash_command(
|
| 841 |
cmd: str,
|
| 842 |
config,
|
| 843 |
session_holder: list,
|
| 844 |
submission_queue: asyncio.Queue,
|
| 845 |
submission_id: list[int],
|
| 846 |
+
prompt_session: PromptSession | None = None,
|
| 847 |
) -> Submission | None:
|
| 848 |
"""
|
| 849 |
Handle a slash command. Returns a Submission to enqueue, or None if
|
|
|
|
| 874 |
operation=Operation(op_type=OpType.COMPACT),
|
| 875 |
)
|
| 876 |
|
| 877 |
+
if command == "/resume":
|
| 878 |
+
session = session_holder[0] if session_holder else None
|
| 879 |
+
if session is None:
|
| 880 |
+
get_console().print(
|
| 881 |
+
"[bold red]No active session to restore into.[/bold red]"
|
| 882 |
+
)
|
| 883 |
+
return None
|
| 884 |
+
selected_path = await _resume_picker(arg, prompt_session)
|
| 885 |
+
if selected_path is None:
|
| 886 |
+
return None
|
| 887 |
+
submission_id[0] += 1
|
| 888 |
+
return Submission(
|
| 889 |
+
id=f"sub_{submission_id[0]}",
|
| 890 |
+
operation=Operation(
|
| 891 |
+
op_type=OpType.RESUME, data={"path": str(selected_path)}
|
| 892 |
+
),
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
if command == "/model":
|
| 896 |
console = get_console()
|
| 897 |
if not arg:
|
|
|
|
| 1253 |
session_holder,
|
| 1254 |
submission_queue,
|
| 1255 |
submission_id,
|
| 1256 |
+
prompt_session,
|
| 1257 |
)
|
| 1258 |
if sub is None:
|
| 1259 |
# Command handled locally, loop back for input
|
agent/tools/sandbox_client.py
CHANGED
|
@@ -65,7 +65,6 @@ MAX_TIMEOUT = 1200
|
|
| 65 |
WAIT_TIMEOUT = 600
|
| 66 |
WAIT_INTERVAL = 5
|
| 67 |
API_WAIT_TIMEOUT = 180
|
| 68 |
-
HARDWARE_REQUEST_TIMEOUT = 60
|
| 69 |
CPU_BASIC_HARDWARE = "cpu-basic"
|
| 70 |
|
| 71 |
|
|
@@ -78,58 +77,6 @@ def _is_transient_space_visibility_error(error: Exception) -> bool:
|
|
| 78 |
return "Repository Not Found" in message or "404 Client Error" in message
|
| 79 |
|
| 80 |
|
| 81 |
-
def _is_transient_space_management_error(error: Exception) -> bool:
|
| 82 |
-
"""Return True when a just-created private Space is not manageable yet."""
|
| 83 |
-
response = getattr(error, "response", None)
|
| 84 |
-
if getattr(response, "status_code", None) in {401, 404}:
|
| 85 |
-
return True
|
| 86 |
-
message = str(error)
|
| 87 |
-
return (
|
| 88 |
-
"Repository Not Found" in message
|
| 89 |
-
or "401 Client Error" in message
|
| 90 |
-
or "404 Client Error" in message
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def _request_space_hardware_with_retry(
|
| 95 |
-
api: HfApi,
|
| 96 |
-
space_id: str,
|
| 97 |
-
*,
|
| 98 |
-
hardware: str,
|
| 99 |
-
sleep_time: int | None,
|
| 100 |
-
log: Callable[[str], object],
|
| 101 |
-
check_cancel: Callable[[], object],
|
| 102 |
-
) -> None:
|
| 103 |
-
"""Request hardware, retrying while Hub permissions propagate for a new Space."""
|
| 104 |
-
deadline = time.time() + HARDWARE_REQUEST_TIMEOUT
|
| 105 |
-
attempt = 0
|
| 106 |
-
while True:
|
| 107 |
-
check_cancel()
|
| 108 |
-
try:
|
| 109 |
-
api.request_space_hardware(
|
| 110 |
-
space_id,
|
| 111 |
-
hardware=hardware,
|
| 112 |
-
sleep_time=sleep_time,
|
| 113 |
-
)
|
| 114 |
-
return
|
| 115 |
-
except Exception as e:
|
| 116 |
-
if not _is_transient_space_management_error(e):
|
| 117 |
-
raise
|
| 118 |
-
|
| 119 |
-
remaining = deadline - time.time()
|
| 120 |
-
if remaining <= 0:
|
| 121 |
-
raise
|
| 122 |
-
|
| 123 |
-
attempt += 1
|
| 124 |
-
status_code = getattr(getattr(e, "response", None), "status_code", None)
|
| 125 |
-
status = f"HTTP {status_code}" if status_code else type(e).__name__
|
| 126 |
-
log(
|
| 127 |
-
f" Hardware request not accepted yet ({status}); "
|
| 128 |
-
f"retrying ({attempt})..."
|
| 129 |
-
)
|
| 130 |
-
time.sleep(min(WAIT_INTERVAL, remaining))
|
| 131 |
-
|
| 132 |
-
|
| 133 |
_DOCKERFILE = """\
|
| 134 |
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
| 135 |
|
|
@@ -679,24 +626,21 @@ class Sandbox:
|
|
| 679 |
|
| 680 |
_check_cancel()
|
| 681 |
|
| 682 |
-
# ``duplicate_space``
|
| 683 |
-
#
|
| 684 |
-
# 401 on that endpoint for a
|
| 685 |
-
#
|
| 686 |
-
#
|
| 687 |
-
#
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
check_cancel=_check_cancel,
|
| 698 |
-
)
|
| 699 |
-
_log(f"Requested hardware: {hardware}")
|
| 700 |
|
| 701 |
# Inject secrets BEFORE uploading server files (which triggers rebuild).
|
| 702 |
# Secrets added after a Space is running aren't available until restart,
|
|
|
|
| 65 |
WAIT_TIMEOUT = 600
|
| 66 |
WAIT_INTERVAL = 5
|
| 67 |
API_WAIT_TIMEOUT = 180
|
|
|
|
| 68 |
CPU_BASIC_HARDWARE = "cpu-basic"
|
| 69 |
|
| 70 |
|
|
|
|
| 77 |
return "Repository Not Found" in message or "404 Client Error" in message
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
_DOCKERFILE = """\
|
| 81 |
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
| 82 |
|
|
|
|
| 626 |
|
| 627 |
_check_cancel()
|
| 628 |
|
| 629 |
+
# ``duplicate_space`` sends hardware and sleepTimeSeconds in the
|
| 630 |
+
# initial create request. Avoid a second /hardware call: deployed HF
|
| 631 |
+
# OAuth tokens can 401 on that endpoint for a just-created private
|
| 632 |
+
# Space even though duplication itself succeeded. We rely on the
|
| 633 |
+
# duplicate endpoint to honor sleepTimeSeconds for upgraded hardware;
|
| 634 |
+
# cpu-basic auto-sleep is fixed by the Hub.
|
| 635 |
+
_log(f"Using duplicated Space hardware: {hardware}")
|
| 636 |
+
if sleep_time is not None:
|
| 637 |
+
if hardware == CPU_BASIC_HARDWARE:
|
| 638 |
+
_log(
|
| 639 |
+
f"Requested duplicated Space sleep time: {sleep_time}s "
|
| 640 |
+
"(cpu-basic auto-sleep is fixed by the Hub)"
|
| 641 |
+
)
|
| 642 |
+
else:
|
| 643 |
+
_log(f"Using duplicated Space sleep time: {sleep_time}s")
|
|
|
|
|
|
|
|
|
|
| 644 |
|
| 645 |
# Inject secrets BEFORE uploading server files (which triggers rebuild).
|
| 646 |
# Secrets added after a Space is running aren't available until restart,
|
agent/utils/terminal_display.py
CHANGED
|
@@ -451,6 +451,7 @@ HELP_TEXT = f"""\
|
|
| 451 |
{_I} [cyan]/help[/cyan] Show this help
|
| 452 |
{_I} [cyan]/undo[/cyan] Undo last turn
|
| 453 |
{_I} [cyan]/compact[/cyan] Compact context window
|
|
|
|
| 454 |
{_I} [cyan]/model[/cyan] [id] Show available models or switch
|
| 455 |
{_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
|
| 456 |
{_I} [cyan]/yolo[/cyan] Toggle auto-approve mode
|
|
|
|
| 451 |
{_I} [cyan]/help[/cyan] Show this help
|
| 452 |
{_I} [cyan]/undo[/cyan] Undo last turn
|
| 453 |
{_I} [cyan]/compact[/cyan] Compact context window
|
| 454 |
+
{_I} [cyan]/resume[/cyan] [index|id|path] Pick up from a log in ./session_logs
|
| 455 |
{_I} [cyan]/model[/cyan] [id] Show available models or switch
|
| 456 |
{_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
|
| 457 |
{_I} [cyan]/yolo[/cyan] Toggle auto-approve mode
|
backend/session_manager.py
CHANGED
|
@@ -12,7 +12,6 @@ from typing import Any, Optional
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
| 15 |
-
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 16 |
from agent.core.session import Event, OpType, Session
|
| 17 |
from agent.core.session_persistence import get_session_store
|
| 18 |
from agent.core.tools import ToolRouter
|
|
@@ -136,7 +135,6 @@ class SessionManager:
|
|
| 136 |
self.sessions: dict[str, AgentSession] = {}
|
| 137 |
self._lock = asyncio.Lock()
|
| 138 |
self.persistence_store = None
|
| 139 |
-
self.enable_hub_artifact_collections = True
|
| 140 |
|
| 141 |
async def start(self) -> None:
|
| 142 |
"""Start shared background resources."""
|
|
@@ -413,28 +411,6 @@ class SessionManager:
|
|
| 413 |
session.sandbox_preload_cancel_event = None
|
| 414 |
self._start_cpu_sandbox_preload(agent_session)
|
| 415 |
|
| 416 |
-
def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
|
| 417 |
-
"""Kick off best-effort Hub collection creation for the session."""
|
| 418 |
-
if not getattr(self, "enable_hub_artifact_collections", False):
|
| 419 |
-
return
|
| 420 |
-
session = agent_session.session
|
| 421 |
-
if not getattr(session, "session_id", None):
|
| 422 |
-
try:
|
| 423 |
-
session.session_id = agent_session.session_id
|
| 424 |
-
except Exception:
|
| 425 |
-
logger.debug("Could not attach session id for Hub artifact collection")
|
| 426 |
-
token = agent_session.hf_token or getattr(session, "hf_token", None)
|
| 427 |
-
if not token:
|
| 428 |
-
return
|
| 429 |
-
try:
|
| 430 |
-
start_session_artifact_collection_task(session, token=token)
|
| 431 |
-
except Exception as e:
|
| 432 |
-
logger.debug(
|
| 433 |
-
"Failed to schedule Hub artifact collection for %s: %s",
|
| 434 |
-
agent_session.session_id,
|
| 435 |
-
e,
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 439 |
try:
|
| 440 |
await self._store().update_session_fields(
|
|
@@ -591,7 +567,6 @@ class SessionManager:
|
|
| 591 |
existing,
|
| 592 |
preload_sandbox=preload_sandbox,
|
| 593 |
)
|
| 594 |
-
self._start_hub_artifact_collection(existing)
|
| 595 |
return existing
|
| 596 |
return None
|
| 597 |
|
|
@@ -613,7 +588,6 @@ class SessionManager:
|
|
| 613 |
existing,
|
| 614 |
preload_sandbox=preload_sandbox,
|
| 615 |
)
|
| 616 |
-
self._start_hub_artifact_collection(existing)
|
| 617 |
return existing
|
| 618 |
return None
|
| 619 |
|
|
@@ -700,9 +674,7 @@ class SessionManager:
|
|
| 700 |
hf_token=hf_token,
|
| 701 |
hf_username=hf_username,
|
| 702 |
)
|
| 703 |
-
self._start_hub_artifact_collection(started)
|
| 704 |
return started
|
| 705 |
-
self._start_hub_artifact_collection(agent_session)
|
| 706 |
if preload_sandbox:
|
| 707 |
self._start_cpu_sandbox_preload(agent_session)
|
| 708 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
@@ -785,7 +757,6 @@ class SessionManager:
|
|
| 785 |
event_queue=event_queue,
|
| 786 |
tool_router=tool_router,
|
| 787 |
)
|
| 788 |
-
self._start_hub_artifact_collection(agent_session)
|
| 789 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 790 |
self._start_cpu_sandbox_preload(agent_session)
|
| 791 |
|
|
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
|
|
|
| 15 |
from agent.core.session import Event, OpType, Session
|
| 16 |
from agent.core.session_persistence import get_session_store
|
| 17 |
from agent.core.tools import ToolRouter
|
|
|
|
| 135 |
self.sessions: dict[str, AgentSession] = {}
|
| 136 |
self._lock = asyncio.Lock()
|
| 137 |
self.persistence_store = None
|
|
|
|
| 138 |
|
| 139 |
async def start(self) -> None:
|
| 140 |
"""Start shared background resources."""
|
|
|
|
| 411 |
session.sandbox_preload_cancel_event = None
|
| 412 |
self._start_cpu_sandbox_preload(agent_session)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 415 |
try:
|
| 416 |
await self._store().update_session_fields(
|
|
|
|
| 567 |
existing,
|
| 568 |
preload_sandbox=preload_sandbox,
|
| 569 |
)
|
|
|
|
| 570 |
return existing
|
| 571 |
return None
|
| 572 |
|
|
|
|
| 588 |
existing,
|
| 589 |
preload_sandbox=preload_sandbox,
|
| 590 |
)
|
|
|
|
| 591 |
return existing
|
| 592 |
return None
|
| 593 |
|
|
|
|
| 674 |
hf_token=hf_token,
|
| 675 |
hf_username=hf_username,
|
| 676 |
)
|
|
|
|
| 677 |
return started
|
|
|
|
| 678 |
if preload_sandbox:
|
| 679 |
self._start_cpu_sandbox_preload(agent_session)
|
| 680 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
|
|
| 757 |
event_queue=event_queue,
|
| 758 |
tool_router=tool_router,
|
| 759 |
)
|
|
|
|
| 760 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 761 |
self._start_cpu_sandbox_preload(agent_session)
|
| 762 |
|
tests/unit/test_hub_artifacts.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
import logging
|
| 3 |
from types import SimpleNamespace
|
| 4 |
|
|
@@ -11,12 +10,10 @@ from agent.core.hub_artifacts import (
|
|
| 11 |
artifact_collection_title,
|
| 12 |
augment_repo_card_content,
|
| 13 |
build_hub_artifact_sitecustomize,
|
| 14 |
-
ensure_session_artifact_collection,
|
| 15 |
is_known_hub_artifact,
|
| 16 |
is_sandbox_hub_repo,
|
| 17 |
register_hub_artifact,
|
| 18 |
remember_hub_artifact,
|
| 19 |
-
start_session_artifact_collection_task,
|
| 20 |
wrap_shell_command_with_hub_artifact_bootstrap,
|
| 21 |
)
|
| 22 |
from agent.tools import local_tools, sandbox_tool
|
|
@@ -207,6 +204,7 @@ def test_register_hub_artifact_retries_after_partial_failure(monkeypatch):
|
|
| 207 |
def add_to_collection(*args, **kwargs):
|
| 208 |
nonlocal collection_attempts
|
| 209 |
collection_attempts += 1
|
|
|
|
| 210 |
|
| 211 |
monkeypatch.setattr(
|
| 212 |
hub_artifacts,
|
|
@@ -238,6 +236,7 @@ def test_register_hub_artifact_retries_after_collection_failure(monkeypatch):
|
|
| 238 |
collection_attempts += 1
|
| 239 |
if collection_attempts == 1:
|
| 240 |
raise RuntimeError("temporary collection failure")
|
|
|
|
| 241 |
|
| 242 |
monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
|
| 243 |
monkeypatch.setattr(
|
|
@@ -271,63 +270,6 @@ def test_session_artifact_set_falls_back_when_session_rejects_attrs(caplog):
|
|
| 271 |
assert "using process-local fallback state" in caplog.text
|
| 272 |
|
| 273 |
|
| 274 |
-
@pytest.mark.asyncio
|
| 275 |
-
async def test_ensure_session_artifact_collection_uses_user_token(monkeypatch):
|
| 276 |
-
session = _session()
|
| 277 |
-
calls = []
|
| 278 |
-
|
| 279 |
-
class FakeApi:
|
| 280 |
-
def __init__(self, token):
|
| 281 |
-
self.token = token
|
| 282 |
-
|
| 283 |
-
def fake_ensure_collection_slug(api, seen_session, **kwargs):
|
| 284 |
-
calls.append((api.token, seen_session, kwargs))
|
| 285 |
-
return "alice/ml-intern-artifacts"
|
| 286 |
-
|
| 287 |
-
monkeypatch.setattr(hub_artifacts, "HfApi", FakeApi)
|
| 288 |
-
monkeypatch.setattr(
|
| 289 |
-
hub_artifacts,
|
| 290 |
-
"_ensure_collection_slug",
|
| 291 |
-
fake_ensure_collection_slug,
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
slug = await ensure_session_artifact_collection(session, token="hf-token")
|
| 295 |
-
|
| 296 |
-
assert slug == "alice/ml-intern-artifacts"
|
| 297 |
-
assert calls == [
|
| 298 |
-
("hf-token", session, {"token": "hf-token"}),
|
| 299 |
-
]
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
@pytest.mark.asyncio
|
| 303 |
-
async def test_start_session_artifact_collection_task_dedupes(monkeypatch):
|
| 304 |
-
session = _session()
|
| 305 |
-
calls = []
|
| 306 |
-
|
| 307 |
-
async def fake_ensure_session_artifact_collection(seen_session, **kwargs):
|
| 308 |
-
calls.append((seen_session, kwargs))
|
| 309 |
-
await asyncio.sleep(0)
|
| 310 |
-
return "alice/ml-intern-artifacts"
|
| 311 |
-
|
| 312 |
-
monkeypatch.setattr(
|
| 313 |
-
hub_artifacts,
|
| 314 |
-
"ensure_session_artifact_collection",
|
| 315 |
-
fake_ensure_session_artifact_collection,
|
| 316 |
-
)
|
| 317 |
-
|
| 318 |
-
task = start_session_artifact_collection_task(session, token="hf-token")
|
| 319 |
-
second = start_session_artifact_collection_task(session, token="hf-token")
|
| 320 |
-
|
| 321 |
-
assert task is not None
|
| 322 |
-
assert second is task
|
| 323 |
-
await task
|
| 324 |
-
assert calls == [(session, {"token": "hf-token"})]
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
def test_start_session_artifact_collection_task_skips_without_token():
|
| 328 |
-
assert start_session_artifact_collection_task(_session()) is None
|
| 329 |
-
|
| 330 |
-
|
| 331 |
@pytest.mark.asyncio
|
| 332 |
async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
|
| 333 |
session = _session()
|
|
@@ -535,6 +477,78 @@ def test_sitecustomize_bootstrap_reuses_existing_collection_slug():
|
|
| 535 |
)
|
| 536 |
|
| 537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
| 539 |
import huggingface_hub as hub
|
| 540 |
from huggingface_hub import HfApi
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
from types import SimpleNamespace
|
| 3 |
|
|
|
|
| 10 |
artifact_collection_title,
|
| 11 |
augment_repo_card_content,
|
| 12 |
build_hub_artifact_sitecustomize,
|
|
|
|
| 13 |
is_known_hub_artifact,
|
| 14 |
is_sandbox_hub_repo,
|
| 15 |
register_hub_artifact,
|
| 16 |
remember_hub_artifact,
|
|
|
|
| 17 |
wrap_shell_command_with_hub_artifact_bootstrap,
|
| 18 |
)
|
| 19 |
from agent.tools import local_tools, sandbox_tool
|
|
|
|
| 204 |
def add_to_collection(*args, **kwargs):
|
| 205 |
nonlocal collection_attempts
|
| 206 |
collection_attempts += 1
|
| 207 |
+
return True
|
| 208 |
|
| 209 |
monkeypatch.setattr(
|
| 210 |
hub_artifacts,
|
|
|
|
| 236 |
collection_attempts += 1
|
| 237 |
if collection_attempts == 1:
|
| 238 |
raise RuntimeError("temporary collection failure")
|
| 239 |
+
return True
|
| 240 |
|
| 241 |
monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card)
|
| 242 |
monkeypatch.setattr(
|
|
|
|
| 270 |
assert "using process-local fallback state" in caplog.text
|
| 271 |
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
@pytest.mark.asyncio
|
| 274 |
async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch):
|
| 275 |
session = _session()
|
|
|
|
| 477 |
)
|
| 478 |
|
| 479 |
|
| 480 |
+
def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
|
| 481 |
+
monkeypatch,
|
| 482 |
+
tmp_path,
|
| 483 |
+
):
|
| 484 |
+
import huggingface_hub as hub
|
| 485 |
+
from huggingface_hub import HfApi
|
| 486 |
+
|
| 487 |
+
readme_path = tmp_path / "README.md"
|
| 488 |
+
readme_path.write_text("# Existing Model\n", encoding="utf-8")
|
| 489 |
+
cache_path = tmp_path / "collection-slug.txt"
|
| 490 |
+
collection_slug = "alice/ml-intern-artifacts-2026-05-05-session-123"
|
| 491 |
+
uploads = []
|
| 492 |
+
downloads = []
|
| 493 |
+
collection_creates = []
|
| 494 |
+
collection_items = []
|
| 495 |
+
|
| 496 |
+
def fake_upload_file(self, **kwargs):
|
| 497 |
+
uploads.append(kwargs)
|
| 498 |
+
return SimpleNamespace()
|
| 499 |
+
|
| 500 |
+
def fake_hf_hub_download(*args, **kwargs):
|
| 501 |
+
downloads.append((args, kwargs))
|
| 502 |
+
return str(readme_path)
|
| 503 |
+
|
| 504 |
+
def fake_create_collection(self, **kwargs):
|
| 505 |
+
collection_creates.append(kwargs)
|
| 506 |
+
return SimpleNamespace(slug=collection_slug)
|
| 507 |
+
|
| 508 |
+
def fake_add_collection_item(self, **kwargs):
|
| 509 |
+
collection_items.append(kwargs)
|
| 510 |
+
|
| 511 |
+
monkeypatch.setenv("ML_INTERN_ARTIFACT_COLLECTION_CACHE", str(cache_path))
|
| 512 |
+
code = build_hub_artifact_sitecustomize(_session())
|
| 513 |
+
|
| 514 |
+
def install_fresh_bootstrap():
|
| 515 |
+
monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
|
| 516 |
+
monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
|
| 517 |
+
monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
|
| 518 |
+
monkeypatch.setattr(hub, "hf_hub_download", fake_hf_hub_download)
|
| 519 |
+
exec(code, {})
|
| 520 |
+
assert HfApi.upload_file is not fake_upload_file
|
| 521 |
+
|
| 522 |
+
install_fresh_bootstrap()
|
| 523 |
+
HfApi(token="hf-token").upload_file(
|
| 524 |
+
path_or_fileobj=b"weights",
|
| 525 |
+
path_in_repo="model.safetensors",
|
| 526 |
+
repo_id="alice/model-a",
|
| 527 |
+
repo_type="model",
|
| 528 |
+
token="hf-token",
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
install_fresh_bootstrap()
|
| 532 |
+
HfApi(token="hf-token").upload_file(
|
| 533 |
+
path_or_fileobj=b"weights",
|
| 534 |
+
path_in_repo="model.safetensors",
|
| 535 |
+
repo_id="alice/model-b",
|
| 536 |
+
repo_type="model",
|
| 537 |
+
token="hf-token",
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
assert cache_path.read_text(encoding="utf-8") == collection_slug
|
| 541 |
+
assert len(collection_creates) == 1
|
| 542 |
+
assert [item["item_id"] for item in collection_items] == [
|
| 543 |
+
"alice/model-a",
|
| 544 |
+
"alice/model-b",
|
| 545 |
+
]
|
| 546 |
+
assert [download[1]["repo_id"] for download in downloads] == [
|
| 547 |
+
"alice/model-a",
|
| 548 |
+
"alice/model-b",
|
| 549 |
+
]
|
| 550 |
+
|
| 551 |
+
|
| 552 |
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
| 553 |
import huggingface_hub as hub
|
| 554 |
from huggingface_hub import HfApi
|
tests/unit/test_sandbox_private_spaces.py
CHANGED
|
@@ -3,8 +3,6 @@ import threading
|
|
| 3 |
import time
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
| 6 |
-
import pytest
|
| 7 |
-
|
| 8 |
from agent.core import telemetry
|
| 9 |
from agent.tools import sandbox_client, sandbox_tool
|
| 10 |
from agent.tools.sandbox_client import Sandbox
|
|
@@ -17,6 +15,7 @@ def _fail_metadata_update(*args, **kwargs):
|
|
| 17 |
|
| 18 |
def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
| 19 |
duplicate_kwargs = {}
|
|
|
|
| 20 |
requested_hardware = []
|
| 21 |
|
| 22 |
class FakeApi:
|
|
@@ -44,11 +43,12 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
|
| 44 |
)
|
| 45 |
monkeypatch.setattr(Sandbox, "_wait_for_api", lambda self, *args, **kwargs: None)
|
| 46 |
|
| 47 |
-
Sandbox.create(owner="alice", token="hf-token", log=
|
| 48 |
|
| 49 |
assert duplicate_kwargs["private"] is True
|
| 50 |
assert duplicate_kwargs["hardware"] == "cpu-basic"
|
| 51 |
assert requested_hardware == []
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
|
@@ -98,32 +98,20 @@ def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
|
| 98 |
assert runtime_calls == 2
|
| 99 |
|
| 100 |
|
| 101 |
-
def
|
| 102 |
-
|
| 103 |
logs: list[str] = []
|
| 104 |
-
|
| 105 |
-
class FakeResponse:
|
| 106 |
-
status_code = 401
|
| 107 |
-
|
| 108 |
-
class FakeHardware401(Exception):
|
| 109 |
-
response = FakeResponse()
|
| 110 |
-
|
| 111 |
-
def __str__(self):
|
| 112 |
-
return "401 Client Error: Repository Not Found"
|
| 113 |
|
| 114 |
class FakeApi:
|
| 115 |
def __init__(self, token=None):
|
| 116 |
self.token = token
|
| 117 |
|
| 118 |
def duplicate_space(self, **kwargs):
|
| 119 |
-
|
| 120 |
|
| 121 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 122 |
-
|
| 123 |
-
hardware_calls += 1
|
| 124 |
-
if hardware_calls == 1:
|
| 125 |
-
raise FakeHardware401()
|
| 126 |
-
return SimpleNamespace(stage="BUILDING", hardware=None)
|
| 127 |
|
| 128 |
def add_space_secret(self, *args, **kwargs):
|
| 129 |
pass
|
|
@@ -144,58 +132,62 @@ def test_sandbox_client_retries_transient_hardware_401(monkeypatch):
|
|
| 144 |
owner="alice",
|
| 145 |
token="hf-token",
|
| 146 |
hardware="t4-small",
|
|
|
|
| 147 |
log=logs.append,
|
| 148 |
)
|
| 149 |
|
| 150 |
assert sandbox.space_id.startswith("alice/sandbox-")
|
| 151 |
-
assert
|
| 152 |
-
assert
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
|
| 155 |
-
def
|
| 156 |
-
|
| 157 |
logs: list[str] = []
|
| 158 |
-
|
| 159 |
|
| 160 |
-
class
|
| 161 |
-
|
|
|
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
|
| 166 |
-
def
|
| 167 |
-
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
hardware="cpu-basic",
|
| 190 |
-
sleep_time=None,
|
| 191 |
-
log=logs.append,
|
| 192 |
-
check_cancel=lambda: None,
|
| 193 |
-
)
|
| 194 |
|
| 195 |
-
assert
|
| 196 |
-
assert
|
| 197 |
-
assert
|
| 198 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
def test_sandbox_tool_forces_private_spaces(monkeypatch):
|
|
|
|
| 3 |
import time
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from agent.core import telemetry
|
| 7 |
from agent.tools import sandbox_client, sandbox_tool
|
| 8 |
from agent.tools.sandbox_client import Sandbox
|
|
|
|
| 15 |
|
| 16 |
def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
| 17 |
duplicate_kwargs = {}
|
| 18 |
+
logs: list[str] = []
|
| 19 |
requested_hardware = []
|
| 20 |
|
| 21 |
class FakeApi:
|
|
|
|
| 43 |
)
|
| 44 |
monkeypatch.setattr(Sandbox, "_wait_for_api", lambda self, *args, **kwargs: None)
|
| 45 |
|
| 46 |
+
Sandbox.create(owner="alice", token="hf-token", log=logs.append)
|
| 47 |
|
| 48 |
assert duplicate_kwargs["private"] is True
|
| 49 |
assert duplicate_kwargs["hardware"] == "cpu-basic"
|
| 50 |
assert requested_hardware == []
|
| 51 |
+
assert not any("sleep time" in log for log in logs)
|
| 52 |
|
| 53 |
|
| 54 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
|
|
|
| 98 |
assert runtime_calls == 2
|
| 99 |
|
| 100 |
|
| 101 |
+
def test_sandbox_client_configures_gpu_at_duplication(monkeypatch):
|
| 102 |
+
duplicate_kwargs = {}
|
| 103 |
logs: list[str] = []
|
| 104 |
+
requested_hardware = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
class FakeApi:
|
| 107 |
def __init__(self, token=None):
|
| 108 |
self.token = token
|
| 109 |
|
| 110 |
def duplicate_space(self, **kwargs):
|
| 111 |
+
duplicate_kwargs.update(kwargs)
|
| 112 |
|
| 113 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 114 |
+
requested_hardware.append((space_id, hardware, sleep_time))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def add_space_secret(self, *args, **kwargs):
|
| 117 |
pass
|
|
|
|
| 132 |
owner="alice",
|
| 133 |
token="hf-token",
|
| 134 |
hardware="t4-small",
|
| 135 |
+
sleep_time=2700,
|
| 136 |
log=logs.append,
|
| 137 |
)
|
| 138 |
|
| 139 |
assert sandbox.space_id.startswith("alice/sandbox-")
|
| 140 |
+
assert duplicate_kwargs["hardware"] == "t4-small"
|
| 141 |
+
assert duplicate_kwargs["sleep_time"] == 2700
|
| 142 |
+
assert requested_hardware == []
|
| 143 |
+
assert "Using duplicated Space hardware: t4-small" in logs
|
| 144 |
+
assert "Using duplicated Space sleep time: 2700s" in logs
|
| 145 |
|
| 146 |
|
| 147 |
+
def test_sandbox_client_logs_cpu_sleep_time_as_hub_fixed(monkeypatch):
|
| 148 |
+
duplicate_kwargs = {}
|
| 149 |
logs: list[str] = []
|
| 150 |
+
requested_hardware = []
|
| 151 |
|
| 152 |
+
class FakeApi:
|
| 153 |
+
def __init__(self, token=None):
|
| 154 |
+
self.token = token
|
| 155 |
|
| 156 |
+
def duplicate_space(self, **kwargs):
|
| 157 |
+
duplicate_kwargs.update(kwargs)
|
| 158 |
|
| 159 |
+
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 160 |
+
requested_hardware.append((space_id, hardware, sleep_time))
|
| 161 |
|
| 162 |
+
def add_space_secret(self, *args, **kwargs):
|
| 163 |
+
pass
|
| 164 |
|
| 165 |
+
def get_space_runtime(self, space_id):
|
| 166 |
+
return SimpleNamespace(stage="RUNNING", hardware="cpu-basic")
|
| 167 |
+
|
| 168 |
+
monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
|
| 169 |
+
monkeypatch.setattr(
|
| 170 |
+
Sandbox,
|
| 171 |
+
"_setup_server",
|
| 172 |
+
staticmethod(lambda *args, **kwargs: None),
|
| 173 |
+
)
|
| 174 |
+
monkeypatch.setattr(Sandbox, "_wait_for_api", lambda self, *args, **kwargs: None)
|
| 175 |
+
|
| 176 |
+
Sandbox.create(
|
| 177 |
+
owner="alice",
|
| 178 |
+
token="hf-token",
|
| 179 |
+
sleep_time=2700,
|
| 180 |
+
log=logs.append,
|
| 181 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
+
assert duplicate_kwargs["hardware"] == "cpu-basic"
|
| 184 |
+
assert duplicate_kwargs["sleep_time"] == 2700
|
| 185 |
+
assert requested_hardware == []
|
| 186 |
+
assert "Using duplicated Space hardware: cpu-basic" in logs
|
| 187 |
+
assert (
|
| 188 |
+
"Requested duplicated Space sleep time: 2700s "
|
| 189 |
+
"(cpu-basic auto-sleep is fixed by the Hub)"
|
| 190 |
+
) in logs
|
| 191 |
|
| 192 |
|
| 193 |
def test_sandbox_tool_forces_private_spaces(monkeypatch):
|
tests/unit/test_session_manager_persistence.py
CHANGED
|
@@ -425,32 +425,9 @@ async def test_create_session_schedules_cpu_sandbox_preload():
|
|
| 425 |
|
| 426 |
assert scheduled == [session_id]
|
| 427 |
assert session_id in manager.sessions
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
@pytest.mark.asyncio
|
| 434 |
-
async def test_create_session_starts_hub_artifact_collection(monkeypatch):
|
| 435 |
-
manager = _manager_with_store(NoopSessionStore())
|
| 436 |
-
manager.enable_hub_artifact_collections = True
|
| 437 |
-
stop = _install_fake_runtime(manager)
|
| 438 |
-
started: list[tuple[str, str]] = []
|
| 439 |
-
|
| 440 |
-
def fake_start_session_artifact_collection_task(session, **kwargs):
|
| 441 |
-
started.append((session.session_id, kwargs["token"]))
|
| 442 |
-
return None
|
| 443 |
-
|
| 444 |
-
monkeypatch.setattr(
|
| 445 |
-
"session_manager.start_session_artifact_collection_task",
|
| 446 |
-
fake_start_session_artifact_collection_task,
|
| 447 |
-
)
|
| 448 |
-
manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
|
| 449 |
-
|
| 450 |
-
try:
|
| 451 |
-
session_id = await manager.create_session(user_id="owner", hf_token="token")
|
| 452 |
-
|
| 453 |
-
assert started == [(session_id, "token")]
|
| 454 |
finally:
|
| 455 |
stop.set()
|
| 456 |
await _cancel_runtime_tasks(manager)
|
|
@@ -475,37 +452,8 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload():
|
|
| 475 |
assert restored is not None
|
| 476 |
assert scheduled == ["persisted-session"]
|
| 477 |
assert "persisted-session" in manager.sessions
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
await _cancel_runtime_tasks(manager)
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
@pytest.mark.asyncio
|
| 484 |
-
async def test_lazy_restore_starts_hub_artifact_collection(monkeypatch):
|
| 485 |
-
manager = _manager_with_store(RestoreStore())
|
| 486 |
-
manager.enable_hub_artifact_collections = True
|
| 487 |
-
stop = _install_fake_runtime(manager)
|
| 488 |
-
started: list[tuple[str, str]] = []
|
| 489 |
-
|
| 490 |
-
def fake_start_session_artifact_collection_task(session, **kwargs):
|
| 491 |
-
started.append((session.session_id, kwargs["token"]))
|
| 492 |
-
return None
|
| 493 |
-
|
| 494 |
-
monkeypatch.setattr(
|
| 495 |
-
"session_manager.start_session_artifact_collection_task",
|
| 496 |
-
fake_start_session_artifact_collection_task,
|
| 497 |
-
)
|
| 498 |
-
manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign]
|
| 499 |
-
|
| 500 |
-
try:
|
| 501 |
-
restored = await manager.ensure_session_loaded(
|
| 502 |
-
"persisted-session",
|
| 503 |
-
user_id="owner",
|
| 504 |
-
hf_token="token",
|
| 505 |
-
)
|
| 506 |
-
|
| 507 |
-
assert restored is not None
|
| 508 |
-
assert started == [("persisted-session", "token")]
|
| 509 |
finally:
|
| 510 |
stop.set()
|
| 511 |
await _cancel_runtime_tasks(manager)
|
|
|
|
| 425 |
|
| 426 |
assert scheduled == [session_id]
|
| 427 |
assert session_id in manager.sessions
|
| 428 |
+
runtime_session = manager.sessions[session_id].session
|
| 429 |
+
assert not hasattr(runtime_session, "_ml_intern_artifact_collection_task")
|
| 430 |
+
assert not hasattr(runtime_session, "_ml_intern_artifact_collection_slug")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
finally:
|
| 432 |
stop.set()
|
| 433 |
await _cancel_runtime_tasks(manager)
|
|
|
|
| 452 |
assert restored is not None
|
| 453 |
assert scheduled == ["persisted-session"]
|
| 454 |
assert "persisted-session" in manager.sessions
|
| 455 |
+
assert not hasattr(restored.session, "_ml_intern_artifact_collection_task")
|
| 456 |
+
assert not hasattr(restored.session, "_ml_intern_artifact_collection_slug")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
finally:
|
| 458 |
stop.set()
|
| 459 |
await _cancel_runtime_tasks(manager)
|
tests/unit/test_session_resume.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for ``agent.core.session_resume``."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from types import SimpleNamespace
|
| 8 |
+
|
| 9 |
+
from litellm import Message
|
| 10 |
+
|
| 11 |
+
from agent.core import session_resume
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _write_session_log(
|
| 15 |
+
directory: Path,
|
| 16 |
+
name: str,
|
| 17 |
+
*,
|
| 18 |
+
session_id: str,
|
| 19 |
+
content: str,
|
| 20 |
+
mtime: float,
|
| 21 |
+
user_id: str | None = "user-a",
|
| 22 |
+
extra_messages: list[dict] | None = None,
|
| 23 |
+
events: list[dict] | None = None,
|
| 24 |
+
) -> Path:
|
| 25 |
+
directory.mkdir(exist_ok=True)
|
| 26 |
+
path = directory / name
|
| 27 |
+
payload = {
|
| 28 |
+
"session_id": session_id,
|
| 29 |
+
"user_id": user_id,
|
| 30 |
+
"session_start_time": "2026-01-01T00:00:00",
|
| 31 |
+
"session_end_time": "2026-01-01T00:05:00",
|
| 32 |
+
"model_name": "openai/gpt-5.5",
|
| 33 |
+
"messages": [
|
| 34 |
+
{"role": "system", "content": "old system"},
|
| 35 |
+
{"role": "user", "content": content},
|
| 36 |
+
*(extra_messages or []),
|
| 37 |
+
],
|
| 38 |
+
"events": events
|
| 39 |
+
if events is not None
|
| 40 |
+
else [{"event_type": "turn_complete", "data": {}}],
|
| 41 |
+
}
|
| 42 |
+
path.write_text(json.dumps(payload))
|
| 43 |
+
os.utime(path, (mtime, mtime))
|
| 44 |
+
return path
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class _FakeContext:
|
| 48 |
+
def __init__(self) -> None:
|
| 49 |
+
self.items = [Message(role="system", content="current system")]
|
| 50 |
+
self.running_context_usage = 0
|
| 51 |
+
self.recompute_calls: list[str] = []
|
| 52 |
+
|
| 53 |
+
def _recompute_usage(self, model_name: str) -> None:
|
| 54 |
+
self.recompute_calls.append(model_name)
|
| 55 |
+
self.running_context_usage = 123
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class _FakeSession:
|
| 59 |
+
def __init__(self, *, user_id: str | None = "user-a") -> None:
|
| 60 |
+
self.context_manager = _FakeContext()
|
| 61 |
+
self.config = SimpleNamespace(model_name="moonshotai/Kimi-K2.6")
|
| 62 |
+
self.session_id = "current-session"
|
| 63 |
+
self.session_start_time = "2026-01-02T00:00:00"
|
| 64 |
+
self.user_id = user_id
|
| 65 |
+
self.logged_events: list[dict] = []
|
| 66 |
+
self._local_save_path: str | None = None
|
| 67 |
+
self.turn_count = 0
|
| 68 |
+
self.last_auto_save_turn = 0
|
| 69 |
+
self.pending_approval: dict | None = {"tool_calls": ["pending"]}
|
| 70 |
+
|
| 71 |
+
def update_model(self, model_name: str) -> None:
|
| 72 |
+
self.config.model_name = model_name
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_session_log_listing_newest_first(tmp_path):
|
| 76 |
+
log_dir = tmp_path / "session_logs"
|
| 77 |
+
older = _write_session_log(
|
| 78 |
+
log_dir,
|
| 79 |
+
"older.json",
|
| 80 |
+
session_id="older-session",
|
| 81 |
+
content="older prompt",
|
| 82 |
+
mtime=time.time() - 10,
|
| 83 |
+
)
|
| 84 |
+
newer = _write_session_log(
|
| 85 |
+
log_dir,
|
| 86 |
+
"newer.json",
|
| 87 |
+
session_id="newer-session",
|
| 88 |
+
content="newer prompt",
|
| 89 |
+
mtime=time.time(),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
entries = session_resume.list_session_logs(log_dir)
|
| 93 |
+
|
| 94 |
+
assert [entry.path for entry in entries] == [newer, older]
|
| 95 |
+
assert entries[0].session_id == "newer-session"
|
| 96 |
+
assert entries[0].preview == "newer prompt"
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def test_restore_continues_when_user_id_matches(tmp_path):
|
| 100 |
+
log_dir = tmp_path / "session_logs"
|
| 101 |
+
path = _write_session_log(
|
| 102 |
+
log_dir,
|
| 103 |
+
"session.json",
|
| 104 |
+
session_id="saved-session",
|
| 105 |
+
content="continue this work",
|
| 106 |
+
mtime=time.time(),
|
| 107 |
+
user_id="user-a",
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
session = _FakeSession(user_id="user-a")
|
| 111 |
+
|
| 112 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 113 |
+
|
| 114 |
+
assert result["restored_count"] == 1
|
| 115 |
+
assert result["dropped_count"] == 0
|
| 116 |
+
assert result["forked"] is False
|
| 117 |
+
assert result["model_name"] == "openai/gpt-5.5"
|
| 118 |
+
assert result["had_redacted_content"] is False
|
| 119 |
+
assert result["invalid_saved_model"] is None
|
| 120 |
+
assert session.config.model_name == "openai/gpt-5.5"
|
| 121 |
+
assert session.session_id == "saved-session"
|
| 122 |
+
# Source log path is never reused: future heartbeat saves write to a
|
| 123 |
+
# fresh file so the snapshot stays intact (regression: see source-log
|
| 124 |
+
# round-trip test below).
|
| 125 |
+
assert session._local_save_path is None
|
| 126 |
+
assert session.turn_count == 1
|
| 127 |
+
assert session.last_auto_save_turn == 1
|
| 128 |
+
assert session.pending_approval is None
|
| 129 |
+
assert [msg.role for msg in session.context_manager.items] == ["system", "user"]
|
| 130 |
+
assert session.context_manager.items[0].content == "current system"
|
| 131 |
+
assert session.context_manager.items[1].content == "continue this work"
|
| 132 |
+
assert session.context_manager.running_context_usage == 123
|
| 133 |
+
assert session.context_manager.recompute_calls == ["openai/gpt-5.5"]
|
| 134 |
+
assert len(session.logged_events) == 1
|
| 135 |
+
marker = session.logged_events[0]
|
| 136 |
+
assert marker["event_type"] == "resumed_from"
|
| 137 |
+
assert marker["data"]["forked"] is False
|
| 138 |
+
assert marker["data"]["original_session_id"] == "saved-session"
|
| 139 |
+
assert marker["data"]["original_event_count"] == 1
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def test_restore_forks_when_user_id_differs(tmp_path):
|
| 143 |
+
log_dir = tmp_path / "session_logs"
|
| 144 |
+
path = _write_session_log(
|
| 145 |
+
log_dir,
|
| 146 |
+
"session.json",
|
| 147 |
+
session_id="saved-session",
|
| 148 |
+
content="someone else's chat",
|
| 149 |
+
mtime=time.time(),
|
| 150 |
+
user_id="user-a",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
session = _FakeSession(user_id="user-b")
|
| 154 |
+
original_session_id = session.session_id
|
| 155 |
+
original_start_time = session.session_start_time
|
| 156 |
+
|
| 157 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 158 |
+
|
| 159 |
+
assert result["forked"] is True
|
| 160 |
+
assert session.session_id == original_session_id
|
| 161 |
+
assert session.session_start_time == original_start_time
|
| 162 |
+
assert session._local_save_path is None
|
| 163 |
+
marker = session.logged_events[0]
|
| 164 |
+
assert marker["event_type"] == "resumed_from"
|
| 165 |
+
assert marker["data"]["forked"] is True
|
| 166 |
+
assert marker["data"]["original_session_id"] == "saved-session"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_restore_forks_when_one_side_is_anonymous(tmp_path):
|
| 170 |
+
log_dir = tmp_path / "session_logs"
|
| 171 |
+
path = _write_session_log(
|
| 172 |
+
log_dir,
|
| 173 |
+
"session.json",
|
| 174 |
+
session_id="saved-session",
|
| 175 |
+
content="anonymous save",
|
| 176 |
+
mtime=time.time(),
|
| 177 |
+
user_id=None,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
session = _FakeSession(user_id="user-a")
|
| 181 |
+
|
| 182 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 183 |
+
|
| 184 |
+
assert result["forked"] is True
|
| 185 |
+
assert session._local_save_path is None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def test_restore_continues_when_both_sides_anonymous(tmp_path):
|
| 189 |
+
log_dir = tmp_path / "session_logs"
|
| 190 |
+
path = _write_session_log(
|
| 191 |
+
log_dir,
|
| 192 |
+
"session.json",
|
| 193 |
+
session_id="saved-session",
|
| 194 |
+
content="local-only chat",
|
| 195 |
+
mtime=time.time(),
|
| 196 |
+
user_id=None,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
session = _FakeSession(user_id=None)
|
| 200 |
+
|
| 201 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 202 |
+
|
| 203 |
+
assert result["forked"] is False
|
| 204 |
+
assert session.session_id == "saved-session"
|
| 205 |
+
assert session._local_save_path is None
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_restore_rejects_invalid_saved_model(tmp_path):
|
| 209 |
+
log_dir = tmp_path / "session_logs"
|
| 210 |
+
path = log_dir / "session.json"
|
| 211 |
+
log_dir.mkdir()
|
| 212 |
+
path.write_text(
|
| 213 |
+
json.dumps(
|
| 214 |
+
{
|
| 215 |
+
"session_id": "saved",
|
| 216 |
+
"user_id": "user-a",
|
| 217 |
+
"model_name": "not a real id with spaces",
|
| 218 |
+
"messages": [{"role": "user", "content": "hello"}],
|
| 219 |
+
"events": [],
|
| 220 |
+
}
|
| 221 |
+
)
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
session = _FakeSession(user_id="user-a")
|
| 225 |
+
original_model = session.config.model_name
|
| 226 |
+
|
| 227 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 228 |
+
|
| 229 |
+
assert result["invalid_saved_model"] == "not a real id with spaces"
|
| 230 |
+
assert result["model_name"] == original_model
|
| 231 |
+
assert session.config.model_name == original_model
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def test_restore_counts_dropped_messages(tmp_path):
|
| 235 |
+
log_dir = tmp_path / "session_logs"
|
| 236 |
+
path = log_dir / "session.json"
|
| 237 |
+
log_dir.mkdir()
|
| 238 |
+
path.write_text(
|
| 239 |
+
json.dumps(
|
| 240 |
+
{
|
| 241 |
+
"session_id": "saved",
|
| 242 |
+
"user_id": "user-a",
|
| 243 |
+
"model_name": "openai/gpt-5.5",
|
| 244 |
+
"messages": [
|
| 245 |
+
{"role": "user", "content": "hi"},
|
| 246 |
+
{"role": "user", "content": 12345}, # invalid content type
|
| 247 |
+
],
|
| 248 |
+
"events": [],
|
| 249 |
+
}
|
| 250 |
+
)
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
session = _FakeSession(user_id="user-a")
|
| 254 |
+
|
| 255 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 256 |
+
|
| 257 |
+
assert result["restored_count"] == 1
|
| 258 |
+
assert result["dropped_count"] == 1
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def test_restore_does_not_overwrite_source_log_on_save(tmp_path, monkeypatch):
|
| 262 |
+
"""Regression: resuming + saving must not destroy the source log on disk.
|
| 263 |
+
|
| 264 |
+
Without the always-fork ``_local_save_path`` reset, the next heartbeat
|
| 265 |
+
save would rewrite the source file with ``events=[resumed_from]`` and
|
| 266 |
+
``total_cost_usd=0``, wiping the original audit trail. This builds a
|
| 267 |
+
real ``Session`` and exercises the round-trip.
|
| 268 |
+
"""
|
| 269 |
+
monkeypatch.chdir(tmp_path)
|
| 270 |
+
|
| 271 |
+
from agent.context_manager.manager import ContextManager
|
| 272 |
+
from agent.core.session import Session
|
| 273 |
+
|
| 274 |
+
log_dir = tmp_path / "session_logs"
|
| 275 |
+
log_dir.mkdir()
|
| 276 |
+
src_path = log_dir / "src.json"
|
| 277 |
+
src_payload = {
|
| 278 |
+
"session_id": "saved-session",
|
| 279 |
+
"user_id": "user-a",
|
| 280 |
+
"session_start_time": "2026-01-01T00:00:00",
|
| 281 |
+
"session_end_time": "2026-01-01T00:05:00",
|
| 282 |
+
"model_name": "openai/gpt-5.5",
|
| 283 |
+
"messages": [
|
| 284 |
+
{"role": "system", "content": "old system"},
|
| 285 |
+
{"role": "user", "content": "earlier work"},
|
| 286 |
+
],
|
| 287 |
+
"events": [
|
| 288 |
+
{"event_type": "llm_call", "data": {"cost_usd": 0.42}},
|
| 289 |
+
{"event_type": "turn_complete", "data": {}},
|
| 290 |
+
],
|
| 291 |
+
}
|
| 292 |
+
src_path.write_text(json.dumps(src_payload, indent=2))
|
| 293 |
+
src_bytes_before = src_path.read_bytes()
|
| 294 |
+
|
| 295 |
+
class _Cfg:
|
| 296 |
+
model_name = "openai/gpt-5.5"
|
| 297 |
+
save_sessions = True
|
| 298 |
+
session_dataset_repo = None
|
| 299 |
+
auto_save_interval = 1
|
| 300 |
+
heartbeat_interval_s = 60
|
| 301 |
+
max_iterations = 10
|
| 302 |
+
yolo_mode = False
|
| 303 |
+
confirm_cpu_jobs = False
|
| 304 |
+
auto_file_upload = False
|
| 305 |
+
reasoning_effort = None
|
| 306 |
+
share_traces = False
|
| 307 |
+
personal_trace_repo_template = None
|
| 308 |
+
mcpServers: dict = {}
|
| 309 |
+
|
| 310 |
+
cm = ContextManager.__new__(ContextManager)
|
| 311 |
+
cm.items = [Message(role="system", content="current system")]
|
| 312 |
+
cm.tool_specs = []
|
| 313 |
+
cm.model_max_tokens = 200_000
|
| 314 |
+
cm.running_context_usage = 0
|
| 315 |
+
cm.compact_size = 0.1
|
| 316 |
+
cm.untouched_messages = 5
|
| 317 |
+
cm.hf_token = None
|
| 318 |
+
cm.local_mode = True
|
| 319 |
+
cm.system_prompt = "current system"
|
| 320 |
+
cm.on_message_added = None
|
| 321 |
+
|
| 322 |
+
import asyncio as _asyncio
|
| 323 |
+
|
| 324 |
+
session = Session(
|
| 325 |
+
event_queue=_asyncio.Queue(),
|
| 326 |
+
config=_Cfg(),
|
| 327 |
+
tool_router=None,
|
| 328 |
+
context_manager=cm,
|
| 329 |
+
hf_token=None,
|
| 330 |
+
user_id="user-a",
|
| 331 |
+
local_mode=True,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
session_resume.restore_session_from_log(session, src_path)
|
| 335 |
+
assert session._local_save_path is None
|
| 336 |
+
|
| 337 |
+
saved_path = session.save_trajectory_local(directory=str(log_dir))
|
| 338 |
+
|
| 339 |
+
assert saved_path is not None
|
| 340 |
+
assert Path(saved_path) != src_path
|
| 341 |
+
assert src_path.read_bytes() == src_bytes_before
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def test_restore_flags_redacted_messages(tmp_path):
|
| 345 |
+
log_dir = tmp_path / "session_logs"
|
| 346 |
+
path = _write_session_log(
|
| 347 |
+
log_dir,
|
| 348 |
+
"session.json",
|
| 349 |
+
session_id="saved-session",
|
| 350 |
+
content="my token is [REDACTED_HF_TOKEN]",
|
| 351 |
+
mtime=time.time(),
|
| 352 |
+
user_id="user-a",
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
session = _FakeSession(user_id="user-a")
|
| 356 |
+
|
| 357 |
+
result = session_resume.restore_session_from_log(session, path)
|
| 358 |
+
|
| 359 |
+
assert result["had_redacted_content"] is True
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def test_resolve_session_log_arg_accepts_index_and_id_prefix(tmp_path):
|
| 363 |
+
log_dir = tmp_path / "session_logs"
|
| 364 |
+
older = _write_session_log(
|
| 365 |
+
log_dir,
|
| 366 |
+
"older.json",
|
| 367 |
+
session_id="abcdef-older",
|
| 368 |
+
content="x",
|
| 369 |
+
mtime=time.time() - 10,
|
| 370 |
+
)
|
| 371 |
+
newer = _write_session_log(
|
| 372 |
+
log_dir,
|
| 373 |
+
"newer.json",
|
| 374 |
+
session_id="123456-newer",
|
| 375 |
+
content="y",
|
| 376 |
+
mtime=time.time(),
|
| 377 |
+
)
|
| 378 |
+
entries = session_resume.list_session_logs(log_dir)
|
| 379 |
+
|
| 380 |
+
assert session_resume.resolve_session_log_arg("1", entries, log_dir) == newer
|
| 381 |
+
assert session_resume.resolve_session_log_arg("abc", entries, log_dir) == older
|
| 382 |
+
assert session_resume.resolve_session_log_arg("nope", entries, log_dir) is None
|