File size: 11,427 Bytes
77da5ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 | """
run_episode.py — LifeStack Full Episode Runner
Orchestrates a complete episode:
1. Generate a Task (with correct horizon from task.horizon) and a ConflictEvent
2. Initialize environment, agent, person, and memory
3. Loop up to task.horizon steps: agent decides → action applied → reward computed → memory updated
4. Print a rich episode summary at the end
"""
import sys, os; sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import random
from core.life_state import LifeMetrics, ResourceBudget
from core.lifestack_env import LifeStackEnv, LifeStackAction
from agent.agent import LifeStackAgent
from intake.simperson import SimPerson
from agent.conflict_generator import generate_conflict, escalate_conflict, adaptive_escalate, TaskGenerator
from core.action_space import apply_action, validate_action
from agent.memory import LifeStackMemory
from core.reward import compute_reward
import copy
# Module-level TaskGenerator shared by every run_episode() call, so it is
# constructed once per process rather than once per episode.
_TASK_GENERATOR = TaskGenerator()
def _scale_metric_changes(primary, uptake_score):
    """Return ``primary.metric_changes`` with every delta multiplied by *uptake_score*.

    Keys lacking a ``domain.`` prefix are qualified with ``primary.target_domain``
    (the LLM sometimes omits it). Deltas that cannot be converted to ``float``
    (e.g. non-numeric strings or ``None``) are dropped instead of aborting the step.
    """
    scaled = {}
    for path, delta in primary.metric_changes.items():
        if "." not in path:
            path = f"{primary.target_domain}.{path}"
        try:
            scaled[path] = float(delta) * uptake_score
        except (TypeError, ValueError):
            continue
    return scaled


def _parse_escalation(message):
    """Split an ``"ESCALATION:<reason> -> <new title>"`` info message.

    Returns ``(reason, new_title)``. ``new_title`` is ``None`` when the message
    has no ``" -> "`` separator, instead of raising IndexError.
    """
    reason, separator, new_title = message[len("ESCALATION:"):].partition(" -> ")
    return reason, (new_title if separator else None)


def _format_metric_diffs(initial_flat, final_flat):
    """Describe metrics that moved by at least 1.0, e.g. ``"stress_level:-3.0"``.

    Metrics missing from *initial_flat* are treated as starting at 0.0.
    Returns ``"no_change"`` when nothing moved by the threshold.
    """
    diffs = []
    for key, end_value in final_flat.items():
        delta = end_value - initial_flat.get(key, 0.0)
        if abs(delta) >= 1.0:
            sign = "+" if delta > 0 else ""
            diffs.append(f"{key.split('.')[-1]}:{sign}{delta:.1f}")
    return ", ".join(diffs) if diffs else "no_change"


def run_episode(
    difficulty: int = None,
    verbose: bool = True,
    memory: "LifeStackMemory" = None,
    agent: "LifeStackAgent" = None,
    agent_history: list = None,
    model_path: str = None,
) -> dict:
    """Run one full LifeStack episode and return a summary of it.

    Args:
        difficulty: Conflict difficulty. With ``None`` the task difficulty is
            drawn uniformly from 1-3 and ``generate_conflict`` receives ``None``.
            Values <= 3 (or ``None``) use the ``flight_crisis`` task domain;
            larger values use ``code_merge_crisis``.
        verbose: Print a narrated trace of the episode to stdout.
        memory: Optional shared LifeStackMemory instance (avoids re-loading the
            sentence-transformer model on every episode).
        agent: Optional shared LifeStackAgent instance (avoids re-creating the
            Groq client on every episode).
        agent_history: Optional list of (conflict_id, reward) tuples from prior
            episodes. Used by adaptive_escalate to decide difficulty.
        model_path: Local model path forwarded to ``LifeStackAgent`` when no
            ``agent`` is supplied.

    Returns:
        dict with keys ``person``, ``initial_conflict_id``, ``total_reward``,
        ``steps``, ``conflicts_seen``, ``critical_metrics``, ``thriving_count``,
        ``step_log`` and ``memory_stats``.
    """
    # --------------------------------------------------
    # 1. SETUP
    # --------------------------------------------------
    if agent is None:
        agent = LifeStackAgent(local_model_path=model_path)
    if memory is None:
        memory = LifeStackMemory()
    if agent_history is None:
        agent_history = []

    # Pick a SimPerson from a diverse pool (fresh instances every episode).
    person_pool = [
        SimPerson(name="Alex (Executive)", openness=0.4, conscientiousness=0.9, extraversion=0.7, agreeableness=0.25, neuroticism=0.8),
        SimPerson(name="Chloe (Creative)", openness=0.9, conscientiousness=0.2, extraversion=0.5, agreeableness=0.70, neuroticism=0.15),
        SimPerson(name="Sam (Introvert)", openness=0.5, conscientiousness=0.6, extraversion=0.1, agreeableness=0.65, neuroticism=0.9),
        SimPerson(name="Maya (Family)", openness=0.5, conscientiousness=0.7, extraversion=0.5, agreeableness=0.95, neuroticism=0.3),
        SimPerson(name="Leo (Student)", openness=0.85, conscientiousness=0.8, extraversion=0.4, agreeableness=0.4, neuroticism=0.55),
    ]
    person = random.choice(person_pool)

    # Build a Task so the environment honours task.horizon rather than a
    # hard-coded step limit. Easy conflicts -> flight_crisis, harder -> code_merge_crisis.
    domain = "flight_crisis" if (difficulty or 2) <= 3 else "code_merge_crisis"
    task = _TASK_GENERATOR.generate(domain=domain, difficulty=difficulty or random.randint(1, 3))

    # Legacy ConflictEvent supplies the story, initial disruption and resource budget.
    conflict = generate_conflict(difficulty)
    initial_conflict_id = conflict.id

    # Passing task= to both the constructor and reset() ties max steps to task.horizon.
    env = LifeStackEnv(task=task)
    obs = env.reset(task=task, conflict=conflict, budget=conflict.resource_budget,
                    person=person, agent_history=agent_history)
    done = obs.done

    # --------------------------------------------------
    # 2. EPISODE LOOP
    # --------------------------------------------------
    total_reward = 0.0
    step_log = []
    conflicts_seen = [conflict.title]
    route_taken = []
    initial_metrics_flat = env.state.current_metrics.flatten()

    if verbose:
        print("\n" + "═" * 60)
        print(f" LIFESTACK EPISODE — {conflict.title}")
        print(f" Person : {person.name}")
        print(f" Hint : {person.get_personality_hint()}")
        print(f" Story : {conflict.story}")
        print("═" * 60)
        env.render()

    while not done:
        step = obs.step

        # Prime the agent with similar past trajectories from memory.
        # NOTE(review): `conflict` is never replaced on escalation, so the agent and
        # the few-shot lookup keep seeing the initial conflict — confirm intended.
        few_shot = memory.build_few_shot_prompt(conflict.title, env.state.current_metrics.flatten())
        action = agent.get_action(
            env.state.current_metrics,
            env.state.budget,
            conflict,
            person,
            few_shot_context=few_shot,
        )

        # Replace an unaffordable action's effects with a zero-cost rest.
        is_valid, invalid_reason = validate_action(action, env.state.budget)
        if not is_valid:
            if verbose:
                print(f"\n ⚠️ Step {step+1}: Action unaffordable ({invalid_reason}). Forcing rest.")
            action.primary.metric_changes = {"mental_wellbeing.stress_level": -3.0}
            action.primary.resource_cost = {}

        # Scale metric changes by how well this personality takes up the action.
        current_stress = env.state.current_metrics.mental_wellbeing.stress_level
        uptake_score = person.respond_to_action(
            action.primary.action_type,
            action.primary.resource_cost,
            current_stress,
        )
        scaled_changes = _scale_metric_changes(action.primary, uptake_score)

        env_action = LifeStackAction.from_agent_action(action)
        env_action.metric_changes = scaled_changes
        obs = env.step(env_action)
        step_reward = obs.reward or 0.0
        done = obs.done
        total_reward += step_reward

        # Store in transient agent memory.
        agent.store_decision(action, step_reward)
        route_taken.append(f"{action.primary.action_type}({action.primary.target_domain})")

        breakdown = obs.metadata.get("breakdown") or {}
        penalties = breakdown.get("penalties_fired", [])
        step_log.append({
            "step": step + 1,
            "action": action.primary.action_type,
            "domain": action.primary.target_domain,
            "description": action.primary.description,
            "reward": round(step_reward, 3),
            "penalties": penalties,
        })

        if verbose:
            print(f"\n{'─' * 60}")
            print(f" STEP {step+1} — {action.primary.action_type.upper()} on {action.primary.target_domain}")
            print(f" \"{action.primary.description}\"")
            if action.communication:
                comm = action.communication
                print(f" 💬 [{comm.recipient}] ({comm.tone}): {comm.content}")
            print(f" Reward: {step_reward:.3f} | Penalties: {penalties or 'none'}")

        # Escalations are recorded regardless of verbosity so the returned
        # `conflicts_seen` does not depend on whether anything was printed.
        for msg in obs.metadata.get("info", []):
            if msg.startswith("DRIFT:"):
                if verbose:
                    print(f"\n[DRIFT] {msg[6:]}")  # strip the 6-char "DRIFT:" tag
            elif msg.startswith("ESCALATION:"):
                escalation_reason, new_title = _parse_escalation(msg)
                if new_title is not None:
                    conflicts_seen.append(new_title)
                if verbose:
                    print(f"\n🔥 ADAPTIVE ESCALATION: {escalation_reason}")
                    if new_title is not None:
                        print(f" New conflict: {new_title}")

        if verbose:
            env.render()

    # --------------------------------------------------
    # 3. EPISODE SUMMARY
    # --------------------------------------------------
    final_flat = env.state.current_metrics.flatten()
    critical = [k for k, v in final_flat.items() if v < 20]
    improved = [k for k, v in final_flat.items() if v > 70]

    # Store the full trajectory in long-term memory.
    memory.store_trajectory(
        conflict_title=conflict.title,
        route_taken=" -> ".join(route_taken),
        total_reward=total_reward,
        metrics_diff_str=_format_metric_diffs(initial_metrics_flat, final_flat),
        reasoning=f"Resolved with {env.state.step_count} steps. End critical: {len(critical)}",
    )
    mem_stats = memory.get_stats()

    if verbose:
        print("\n" + "═" * 60)
        print(" EPISODE COMPLETE — FINAL SUMMARY")
        print("═" * 60)
        print(f" Person : {person.name}")
        print(f" Conflicts Seen : {' → '.join(conflicts_seen)}")
        print(f" Steps Taken : {env.state.step_count}")
        print(f" Total Reward : {total_reward:.4f}")
        print(f" Critical (<20) : {critical or 'None'}")
        print(f" Thriving (>70) : {len(improved)} metrics")
        print("\n Step-by-Step Log:")
        for entry in step_log:
            flag = " ⚠️ " if entry["penalties"] else " ✅ "
            print(f" {flag} Step {entry['step']}: [{entry['action']}] on {entry['domain']} → {entry['reward']:.3f}")
        print(f"\n Memory Bank : {mem_stats['total_memories']} decisions stored (avg reward: {mem_stats['average_reward']})")
        print("═" * 60)

    return {
        "person": person.name,
        "initial_conflict_id": initial_conflict_id,
        "total_reward": round(total_reward, 4),
        "steps": env.state.step_count,
        "conflicts_seen": conflicts_seen,
        "critical_metrics": critical,
        "thriving_count": len(improved),
        "step_log": step_log,
        "memory_stats": mem_stats,
    }
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default=None,
        help="Path to trained GRPO model (default: auto-detect ./lifestack_model or LIFESTACK_MODEL_PATH)",
    )
    parser.add_argument("--difficulty", type=int, default=None, help="Fixed difficulty 1-5 (default: varies)")
    args = parser.parse_args()

    # One agent and one memory are shared by all episodes, so run_episode does
    # not re-create them for every run.
    shared_agent = LifeStackAgent(local_model_path=args.model)
    shared_memory = LifeStackMemory(silent=True)

    # Three episodes at the fixed difficulty, or an escalating 2 -> 3 -> 5 sweep.
    difficulties = [args.difficulty] * 3 if args.difficulty else [2, 3, 5]

    # NOTE(review): agent_history is not passed between these episodes, so each
    # run starts adaptive escalation from an empty history — confirm intended.
    for d in difficulties:
        print(f"\n{'═' * 60}")
        print(f" STARTING EPISODE AT DIFFICULTY {d}")
        print(f"{'═' * 60}")
        summary = run_episode(difficulty=d, verbose=True, agent=shared_agent, memory=shared_memory)
        print(f"\n → Total Reward: {summary['total_reward']}")
|