File size: 1,892 Bytes
2d83c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""proteus.cli.commands.memory — the ``memory`` subcommand (generate + save checkpoint)."""

from __future__ import annotations

import argparse
import sys

from proteus.game.agents import VanillaAgent
from proteus.game.engine.difficulty import Difficulty
from proteus.game.scenarios.base import list_scenarios
from proteus.providers import make_provider
from proteus.game.runtime.memory import save_checkpoint
from proteus.game.runtime.memory_gen import generate_memory
from proteus.cli.commands.run import _resolve_persona


def _cmd_memory(args: argparse.Namespace) -> int:
    """Generate a memory checkpoint for one scenario and save it to disk.

    Reads ``args.model``, ``args.scenario``, ``args.persona``,
    ``args.difficulty``, ``args.seed``, ``args.memory_turns``, ``args.out``
    and ``args.memory_root``.

    The cheap argument checks (scenario, difficulty, persona) run before the
    model provider is constructed, so a bad invocation fails without building
    a provider. Every validation failure is reported on stderr with exit
    status 2.

    Returns:
        0 once the checkpoint has been written; 2 if any argument is invalid.
    """
    # Materialise once: the result is used both for membership and for the
    # error listing, which would silently break on a one-shot iterable.
    scenarios = list(list_scenarios())
    if args.scenario not in scenarios:
        print(
            f"Unknown scenario {args.scenario!r}. "
            f"Available scenarios: {', '.join(scenarios)}.",
            file=sys.stderr,
        )
        return 2

    # Convert up front so an unrecognised difficulty is reported like the
    # other argument errors instead of escaping as a traceback mid-run.
    try:
        difficulty = Difficulty(args.difficulty)
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2

    persona, persona_err = _resolve_persona(args.persona)
    if persona_err is not None:
        print(persona_err, file=sys.stderr)
        return 2

    try:
        provider = make_provider(args.model)
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2

    agent = VanillaAgent(provider)
    ckpt = generate_memory(
        args.scenario,
        agent,
        difficulty=difficulty,
        seed=args.seed,
        memory_turns=args.memory_turns,
        model_name=provider.model_name,
        persona=persona,
    )

    if args.out:
        from pathlib import Path

        # Explicit destination: create any missing parent directories, then
        # write the checkpoint's JSON serialisation as UTF-8.
        written = Path(args.out)
        written.parent.mkdir(parents=True, exist_ok=True)
        written.write_text(ckpt.model_dump_json(), encoding="utf-8")
    else:
        # Default location under args.memory_root; save_checkpoint's return
        # value is what gets reported as the written location below.
        written = save_checkpoint(ckpt, root=args.memory_root)

    print(
        f"memory {ckpt.scenario} seed={ckpt.seed} {ckpt.difficulty} "
        f"model={ckpt.model} turns={len(ckpt.memory_turns)} -> {ckpt.outcome}"
    )
    print(f"checkpoint written to {written}")
    return 0