File size: 11,427 Bytes
77da5ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
run_episode.py — LifeStack Full Episode Runner

Orchestrates a complete episode:
  1. Generate a Task (with its horizon taken from task.horizon) and a ConflictEvent
  2. Initialize the environment, agent, person, and memory
  3. Loop up to task.horizon steps: agent decides → action applied → reward computed → memory updated
  4. Store the trajectory, print a rich episode summary (when verbose), and return a summary dict
"""

import sys, os; sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
from core.life_state import LifeMetrics, ResourceBudget
from core.lifestack_env import LifeStackEnv, LifeStackAction
from agent.agent import LifeStackAgent
from intake.simperson import SimPerson
from agent.conflict_generator import generate_conflict, escalate_conflict, adaptive_escalate, TaskGenerator
from core.action_space import apply_action, validate_action
from agent.memory import LifeStackMemory
from core.reward import compute_reward
import copy

# Module-level TaskGenerator, created once at import time and reused by every
# run_episode() call to build that episode's Task.
_TASK_GENERATOR = TaskGenerator()


def _scale_metric_changes(metric_changes: dict, target_domain: str, uptake_score: float) -> dict:
    """Normalize metric-change paths and scale each delta by ``uptake_score``.

    Keys without a '.' are treated as bare sub-metric names and prefixed with
    ``target_domain`` so every key has the 'domain.submetric' form. Each delta
    is coerced with ``float()``; entries that cannot be coerced (e.g. a
    non-numeric string or ``None`` produced by an LLM) are dropped.
    """
    scaled = {}
    for path, delta in metric_changes.items():
        if "." not in path:  # the LLM omitted the domain prefix
            path = f"{target_domain}.{path}"
        try:
            scaled[path] = float(delta) * uptake_score
        except (TypeError, ValueError):
            # Malformed delta from the model: skip it instead of aborting the episode.
            continue
    return scaled


def _format_metrics_diff(initial_flat: dict, final_flat: dict, min_abs_delta: float = 1.0) -> str:
    """Summarize metric changes as comma-separated 'name:+delta' entries.

    Only metrics whose absolute change is at least ``min_abs_delta`` are listed;
    metrics missing from ``initial_flat`` are compared against 0.0. The name is
    the last dotted component of the flattened key. Returns "no_change" when
    nothing qualifies.
    """
    diffs = []
    for key, end_value in final_flat.items():
        delta = end_value - initial_flat.get(key, 0.0)
        if abs(delta) >= min_abs_delta:
            name = key.split(".")[-1]
            sign = "+" if delta > 0 else ""
            diffs.append(f"{name}:{sign}{delta:.1f}")
    return ", ".join(diffs) if diffs else "no_change"


def run_episode(
    difficulty: int = None,
    verbose: bool = True,
    memory: "LifeStackMemory" = None,
    agent: "LifeStackAgent" = None,
    agent_history: list = None,
    model_path: str = None,
) -> dict:
    """
    Runs one full LifeStack episode.

    Args:
        difficulty: Difficulty forwarded to the task and conflict generators.
                    When None, the task difficulty is drawn from 1-3 and
                    generate_conflict receives None.
        verbose: If True, print the episode header, per-step details, env
                 renders and the final summary.
        memory: Optional shared LifeStackMemory instance (avoids re-loading the
                sentence-transformer model on every episode).
        agent:  Optional shared LifeStackAgent instance (avoids re-creating the
                Groq client on every episode).
        agent_history: Optional list of (conflict_id, reward) tuples from prior
                       episodes. Used by adaptive_escalate to decide difficulty.
        model_path: Local model path passed to LifeStackAgent; only used when
                    `agent` is None.

    Side effects:
        Calls agent.store_decision once per step and memory.store_trajectory
        once at the end of the episode.

    Returns:
        summary dict with keys: person, initial_conflict_id, total_reward,
        steps, conflicts_seen, critical_metrics (flattened metric names whose
        final value is below 20), thriving_count (number of final metrics
        above 70), step_log, memory_stats.
    """
    # --------------------------------------------------
    # 1. SETUP
    # --------------------------------------------------
    if agent is None:
        agent = LifeStackAgent(local_model_path=model_path)
    if memory is None:
        memory = LifeStackMemory()
    if agent_history is None:
        agent_history = []

    # Pick a SimPerson from a diverse pool
    person_pool = [
        SimPerson(name="Alex (Executive)",    openness=0.4, conscientiousness=0.9, extraversion=0.7,  agreeableness=0.25, neuroticism=0.8),
        SimPerson(name="Chloe (Creative)",    openness=0.9, conscientiousness=0.2, extraversion=0.5,  agreeableness=0.70, neuroticism=0.15),
        SimPerson(name="Sam (Introvert)",     openness=0.5, conscientiousness=0.6, extraversion=0.1,  agreeableness=0.65, neuroticism=0.9),
        SimPerson(name="Maya (Family)",       openness=0.5, conscientiousness=0.7, extraversion=0.5,  agreeableness=0.95, neuroticism=0.3),
        SimPerson(name="Leo (Student)",       openness=0.85,conscientiousness=0.8, extraversion=0.4,  agreeableness=0.4,  neuroticism=0.55),
    ]
    person = random.choice(person_pool)

    # Build a Task object so the episode length follows task.horizon.
    # Domain from difficulty: unset or <= 3 → flight_crisis, harder → code_merge_crisis.
    domain = "flight_crisis" if (difficulty or 2) <= 3 else "code_merge_crisis"
    task = _TASK_GENERATOR.generate(domain=domain, difficulty=difficulty or random.randint(1, 3))

    # Generate starting conflict (legacy ConflictEvent for disruption/budget)
    conflict = generate_conflict(difficulty)
    initial_conflict_id = conflict.id

    # Create the env with the task so its step limit comes from task.horizon
    # rather than a hard-coded value.
    env = LifeStackEnv(task=task)

    # Apply initial disruption to env; pass task= so reset() uses task.horizon
    obs = env.reset(task=task, conflict=conflict, budget=conflict.resource_budget,
                    person=person, agent_history=agent_history)
    done = obs.done

    # --------------------------------------------------
    # 2. EPISODE LOOP
    # --------------------------------------------------
    total_reward = 0.0
    step_log = []
    conflicts_seen = [conflict.title]
    route_taken = []
    initial_metrics_flat = env.state.current_metrics.flatten()

    if verbose:
        print("\n" + "β—†" * 60)
        print(f"  LIFESTACK EPISODE β€” {conflict.title}")
        print(f"  Person  : {person.name}")
        print(f"  Hint    : {person.get_personality_hint()}")
        print(f"  Story   : {conflict.story}")
        print("β—†" * 60)
        env.render()

    while not done:
        step = obs.step

        # Few-shot context retrieved from memory for the current conflict/metrics
        few_shot = memory.build_few_shot_prompt(conflict.title, env.state.current_metrics.flatten())

        # Agent decision
        action = agent.get_action(env.state.current_metrics, env.state.budget, conflict, person, few_shot_context=few_shot)

        # Validate resource cost; an unaffordable action is replaced in place
        # by a free stress-relief "rest".
        is_valid, reason = validate_action(action, env.state.budget)
        if not is_valid:
            if verbose:
                print(f"\n  ⚠️  Step {step+1}: Action unaffordable ({reason}). Forcing rest.")
            action.primary.metric_changes = {"mental_wellbeing.stress_level": -3.0}
            action.primary.resource_cost = {}

        # Scale metric changes by personality uptake
        current_stress = env.state.current_metrics.mental_wellbeing.stress_level
        uptake_score = person.respond_to_action(
            action.primary.action_type,
            action.primary.resource_cost,
            current_stress
        )
        scaled_changes = _scale_metric_changes(
            action.primary.metric_changes, action.primary.target_domain, uptake_score
        )

        # Apply the action through the environment, overriding its raw
        # metric changes with the personality-scaled ones.
        env_action = LifeStackAction.from_agent_action(action)
        env_action.metric_changes = scaled_changes
        obs = env.step(env_action)
        step_reward = obs.reward or 0.0
        done = obs.done
        total_reward += step_reward
        penalties = obs.metadata.get("breakdown", {}).get("penalties_fired", [])

        # Store in transient agent memory
        agent.store_decision(action, step_reward)
        route_taken.append(f"{action.primary.action_type}({action.primary.target_domain})")

        # Log the step
        step_log.append({
            "step": step + 1,
            "action": action.primary.action_type,
            "domain": action.primary.target_domain,
            "description": action.primary.description,
            "reward": round(step_reward, 3),
            "penalties": penalties
        })

        if verbose:
            print(f"\n{'─'*60}")
            print(f"  STEP {step+1} β†’ {action.primary.action_type.upper()} on {action.primary.target_domain}")
            print(f"  \"{action.primary.description}\"")
            if action.communication:
                print(f"  πŸ’¬ [{action.communication.recipient}] ({action.communication.tone}): {action.communication.content}")
            print(f"  Reward: {step_reward:.3f} | Penalties: {penalties or 'none'}")

        # Drift/escalation events arrive in metadata["info"]. Escalations are
        # recorded in conflicts_seen regardless of `verbose`, so the returned
        # summary does not depend on whether output is printed.
        for msg in obs.metadata.get("info", []):
            if msg.startswith("DRIFT:"):
                if verbose:
                    print(f"\n[DRIFT] {msg[6:]}")
            elif msg.startswith("ESCALATION:"):
                # Expected form: "ESCALATION:<reason> -> <new conflict title>".
                parts = msg[11:].split(" -> ")
                esc_reason = parts[0]
                # Guard against a message without the separator instead of
                # raising IndexError mid-episode.
                new_title = parts[1] if len(parts) > 1 else None
                if new_title is not None:
                    conflicts_seen.append(new_title)
                if verbose:
                    print(f"\nπŸ”₯ ADAPTIVE ESCALATION: {esc_reason}")
                    if new_title is not None:
                        print(f"   New conflict: {new_title}")

        if verbose:
            env.render()

    # --------------------------------------------------
    # 3. EPISODE SUMMARY
    # --------------------------------------------------
    final_flat = env.state.current_metrics.flatten()
    metrics_diff_str = _format_metrics_diff(initial_metrics_flat, final_flat)
    critical = [k for k, v in final_flat.items() if v < 20]
    improved = [k for k, v in final_flat.items() if v > 70]

    # Store full trajectory in ChromaDB
    memory.store_trajectory(
        conflict_title=conflict.title,
        route_taken=" -> ".join(route_taken),
        total_reward=total_reward,
        metrics_diff_str=metrics_diff_str,
        reasoning=f"Resolved with {env.state.step_count} steps. End critical: {len(critical)}"
    )
    mem_stats = memory.get_stats()

    if verbose:
        print("\n" + "β–ˆ" * 60)
        print("  EPISODE COMPLETE β€” FINAL SUMMARY")
        print("β–ˆ" * 60)
        print(f"  Person         : {person.name}")
        print(f"  Conflicts Seen : {' β†’ '.join(conflicts_seen)}")
        print(f"  Steps Taken    : {env.state.step_count}")
        print(f"  Total Reward   : {total_reward:.4f}")
        print(f"  Critical (<20) : {critical or 'None'}")
        print(f"  Thriving (>70) : {len(improved)} metrics")
        print(f"\n  Step-by-Step Log:")
        for s in step_log:
            flag = " ⚠️ " if s["penalties"] else "  βœ…"
            print(f"  {flag} Step {s['step']}: [{s['action']}] on {s['domain']} β†’ {s['reward']:.3f}")
        print(f"\n  Memory Bank    : {mem_stats['total_memories']} decisions stored (avg reward: {mem_stats['average_reward']})")
        print("β–ˆ" * 60)

    return {
        "person": person.name,
        "initial_conflict_id": initial_conflict_id,
        "total_reward": round(total_reward, 4),
        "steps": env.state.step_count,
        "conflicts_seen": conflicts_seen,
        "critical_metrics": critical,
        "thriving_count": len(improved),
        "step_log": step_log,
        "memory_stats": mem_stats
    }


if __name__ == "__main__":
    # Demo driver: runs three episodes that share one agent and one memory
    # instance, so the model/client and memory store are set up only once.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=None, help="Path to trained GRPO model (default: auto-detect ./lifestack_model or LIFESTACK_MODEL_PATH)")
    # choices enforces the documented 1-5 range instead of silently accepting
    # out-of-range values (e.g. 0, which the old truthiness check ignored).
    parser.add_argument("--difficulty", type=int, default=None, choices=range(1, 6), help="Fixed difficulty 1-5 (default: varies)")
    args = parser.parse_args()

    shared_agent = LifeStackAgent(local_model_path=args.model)
    shared_memory = LifeStackMemory(silent=True)

    # Fixed difficulty → three episodes at that level; otherwise a 2/3/5 ramp.
    difficulties = [args.difficulty] * 3 if args.difficulty is not None else [2, 3, 5]
    for d in difficulties:
        print(f"\n{'═'*60}")
        print(f"  STARTING EPISODE AT DIFFICULTY {d}")
        print(f"{'═'*60}")
        summary = run_episode(difficulty=d, verbose=True, agent=shared_agent, memory=shared_memory)
        print(f"\n  β†’ Total Reward: {summary['total_reward']}")