File size: 13,025 Bytes
ddbc1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""
train.py — LifeStack Training Loop

Runs a curriculum of episodes at increasing difficulty, logs rewards,
generates a learning curve plot, and compares agent performance
before and after memory accumulation.
"""

import sys, os; sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import random
import shutil
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend β€” safe for headless runs
import matplotlib.pyplot as plt

from scripts.run_episode import run_episode
from agent.memory import LifeStackMemory
from agent.agent import LifeStackAgent


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _difficulty_for_episode(episode: int) -> int:
    """Curriculum schedule: easy β†’ medium β†’ hard β†’ extreme."""
    if episode <= 25:
        return random.randint(1, 2)
    elif episode <= 50:
        return random.randint(2, 3)
    elif episode <= 75:
        return random.randint(3, 4)
    else:
        return random.randint(4, 5)


def _rolling_avg(values: list, window: int = 5) -> list:
    """Compute a simple rolling average with the given window."""
    out = []
    for i in range(len(values)):
        start = max(0, i - window + 1)
        out.append(sum(values[start : i + 1]) / (i - start + 1))
    return out


def _phase_avg(rewards: list, start: int, end: int) -> float:
    """Average reward for 1-indexed episodes [start, end]."""
    subset = rewards[start - 1 : end]
    return round(sum(subset) / len(subset), 3) if subset else 0.0


# ---------------------------------------------------------------------------
# Main training function
# ---------------------------------------------------------------------------

def _project_root() -> str:
    """Return the directory one level above the folder containing this file."""
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def _summarize_comparison_run(run_number: int, result: dict) -> dict:
    """Condense one comparison episode result into the fields the report uses."""
    step_log = result["step_log"]
    first_step = step_log[0] if step_log else {}
    return {
        "run": run_number,
        "total_reward": result["total_reward"],
        "first_action": first_step.get("action", "unknown"),
        "first_domain": first_step.get("domain", "unknown"),
        "has_communication": any(
            s.get("action") == "communicate" for s in step_log
        ),
        "steps": result["steps"],
    }


def _run_comparison_batch(n_runs: int) -> list:
    """Run `n_runs` difficulty-5 episodes (run_episode defaults) and summarize each."""
    return [
        _summarize_comparison_run(i + 1, run_episode(difficulty=5, verbose=False))
        for i in range(n_runs)
    ]


def run_training(n_episodes: int = 50, save_plot: bool = True) -> dict:
    """
    Run the LifeStack curriculum training loop and the before/after comparison.

    Args:
        n_episodes: Number of training episodes to run (episodes are numbered
            from 1). Phases that lie beyond this count are reported as
            "not reached" and keep an average of 0.0 in the returned summary.
        save_plot: When true, write the learning-curve PNG.

    Returns:
        dict with keys:
            "episode_log": list of per-episode records (episode, reward,
                difficulty, person, conflicts_seen, steps),
            "phase_averages": dict with "early", "mid", "late", "final"
                and "overall" average rewards (rounded to 3 decimals),
            "comparison": the with/without-memory comparison that is also
                written to disk.

    Side effects:
        Prints progress and summary tables, creates `<project root>/data` if
        needed and writes `training_log.json`, `reward_curve.png` (optional)
        and `before_after_comparison.json` there. During the "without memory"
        runs the `<project root>/lifestack_memory` folder is temporarily moved
        aside and restored afterwards.
    """
    from collections import Counter

    project_root = _project_root()
    data_dir = os.path.join(project_root, "data")
    # The output folder may not exist on a fresh checkout.
    os.makedirs(data_dir, exist_ok=True)

    # (summary key, table label, first episode, last episode, plot color, legend)
    # Boundaries mirror the schedule in _difficulty_for_episode.
    phases = (
        ("early", "Early", 1, 25, "green", "Easy (diff 1-2)"),
        ("mid", "Mid", 26, 50, "orange", "Mid (diff 2-3)"),
        ("late", "Late", 51, 75, "red", "Hard (diff 3-4)"),
        ("final", "Final", 76, n_episodes, "purple", "Extreme (diff 4-5)"),
    )

    episode_log = []
    rewards = []
    agent_history = []

    banner = "═" * 50
    print(f"\n{banner}")
    print(f"  LIFESTACK TRAINING — {n_episodes} EPISODES")
    print(f"{banner}\n")

    # Initialize shared instances once — avoids reloading model weights each episode
    print("  Initializing shared agent and memory (one-time load)...")
    shared_memory = LifeStackMemory(silent=True)  # suppress per-decision spam
    shared_agent = LifeStackAgent()
    print("  ✅ Ready.\n")

    for ep in range(1, n_episodes + 1):
        difficulty = _difficulty_for_episode(ep)

        # Run episode with shared memory + agent + history tracking
        result = run_episode(difficulty=difficulty, verbose=False,
                             memory=shared_memory, agent=shared_agent,
                             agent_history=agent_history)

        total_reward = result["total_reward"]
        rewards.append(total_reward)
        agent_history.append((result["initial_conflict_id"], total_reward))

        episode_log.append({
            "episode": ep,
            "reward": total_reward,
            "difficulty": difficulty,
            "person": result["person"],
            "conflicts_seen": result["conflicts_seen"],
            "steps": result["steps"],
        })

        # Progress: print every episode
        mem_count = result["memory_stats"]["total_memories"]
        print(
            f"  Episode {ep:>3}/{n_episodes} | "
            f"Reward: {total_reward:.3f} | "
            f"Difficulty: {difficulty} | "
            f"Memories: {mem_count}"
        )

    # ------------------------------------------------------------------
    # Phase averages
    # ------------------------------------------------------------------
    phase_averages = {
        key: _phase_avg(rewards, start, end)
        for key, _label, start, end, _color, _legend in phases
    }
    # Guard against n_episodes == 0, which would otherwise divide by zero.
    overall = round(sum(rewards) / len(rewards), 3) if rewards else 0.0

    rule = "═" * 42
    print(f"\n{rule}")
    print("  TRAINING SUMMARY")
    print(rule)
    print(f"  {'Phase':<10} {'Episodes':<12} {'Avg Reward'}")
    print(f"  {'-'*38}")
    for key, label, start, end, _color, _legend in phases:
        last = min(end, n_episodes)
        if start > last:
            # This phase starts after the final episode of the run.
            print(f"  {label:<10} {'-':<12} n/a (not reached)")
            continue
        span = f"{start}-{last}"
        print(f"  {label:<10} {span:<12} {phase_averages[key]:.3f}")
    print(f"  {'Overall':<10} {'1-' + str(n_episodes):<12} {overall:.3f}")
    print(f"{rule}\n")

    # ------------------------------------------------------------------
    # Save training log
    # ------------------------------------------------------------------
    log_path = os.path.join(data_dir, "training_log.json")
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(episode_log, f, indent=2)
    print(f"  📄 Training log saved → {log_path}")

    # ------------------------------------------------------------------
    # Matplotlib learning curve
    # ------------------------------------------------------------------
    if save_plot:
        ep_nums = [r["episode"] for r in episode_log]
        raw = [r["reward"] for r in episode_log]
        rolling = _rolling_avg(raw, window=5)

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(ep_nums, raw, color="steelblue", alpha=0.6, linewidth=1.2, label="Episode Reward")
        ax.plot(ep_nums, rolling, color="crimson", linewidth=2.0, linestyle="--", label="5-Episode Rolling Avg")
        ax.axhline(y=0, color="gray", linewidth=0.8, linestyle="--", alpha=0.7)

        # Phase boundary shading — only for phases the run actually reached,
        # clipped so a short run never shades past its last episode.
        for _key, _label, start, end, color, legend in phases:
            last = min(end, n_episodes)
            if start <= last:
                ax.axvspan(start, last, alpha=0.04, color=color, label=legend)

        ax.set_title("LifeStack Agent Learning Curve", fontsize=14, fontweight="bold")
        ax.set_xlabel("Episode", fontsize=11)
        ax.set_ylabel("Total Reward", fontsize=11)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        plot_path = os.path.join(data_dir, "reward_curve.png")
        fig.savefig(plot_path, dpi=150)
        plt.close(fig)
        print(f"  📊 Learning curve saved → {plot_path}")

    # ------------------------------------------------------------------
    # BEHAVIORAL COMPARISON — Friday 6PM (5 runs each)
    # ------------------------------------------------------------------
    N_COMPARE = 5
    wide_rule = "═" * 58
    thin_rule = "─" * 54
    print(f"\n{wide_rule}")
    print(f"  BEHAVIORAL COMPARISON — Friday 6PM Crisis ({N_COMPARE} runs each)")
    print(wide_rule)

    memory_dir = os.path.join(project_root, "lifestack_memory")
    memory_backup = memory_dir + "_backup"

    # --- WITHOUT memory: temporarily hide the ChromaDB folder ---
    had_memory = os.path.exists(memory_dir)
    if had_memory:
        shutil.move(memory_dir, memory_backup)

    try:
        no_mem_results = _run_comparison_batch(N_COMPARE)
    finally:
        # Restore memory even if a run failed; drop whatever folder the
        # memory-less runs may have created in the meantime.
        if had_memory and os.path.exists(memory_backup):
            if os.path.exists(memory_dir):
                shutil.rmtree(memory_dir)
            shutil.move(memory_backup, memory_dir)

    # --- WITH memory ---
    with_mem_results = _run_comparison_batch(N_COMPARE)

    # --- Compute stats ---
    avg_no = sum(r["total_reward"] for r in no_mem_results) / N_COMPARE
    avg_yes = sum(r["total_reward"] for r in with_mem_results) / N_COMPARE
    improvement = avg_yes - avg_no
    # Relative change is undefined for a zero baseline; report 0 in that case.
    pct = (improvement / abs(avg_no) * 100) if avg_no != 0 else 0

    # Most common first action / domain per condition
    no_top_action = Counter(r["first_action"] for r in no_mem_results).most_common(1)[0][0]
    yes_top_action = Counter(r["first_action"] for r in with_mem_results).most_common(1)[0][0]
    no_top_domain = Counter(r["first_domain"] for r in no_mem_results).most_common(1)[0][0]
    yes_top_domain = Counter(r["first_domain"] for r in with_mem_results).most_common(1)[0][0]
    no_comm_pct = sum(1 for r in no_mem_results if r["has_communication"]) / N_COMPARE * 100
    yes_comm_pct = sum(1 for r in with_mem_results if r["has_communication"]) / N_COMPARE * 100

    # --- Print table ---
    print(f"\n  {'WITHOUT MEMORY':<28} {'WITH MEMORY':<28}")
    for nr, wr in zip(no_mem_results, with_mem_results):
        print(f"  Run {nr['run']}: {nr['total_reward']:.3f} "
              f"({nr['first_action']:<14})"
              f"  Run {wr['run']}: {wr['total_reward']:.3f} "
              f"({wr['first_action']:<14})")
    print(f"  {thin_rule}")
    print(f"  Avg:   {avg_no:.3f}                    Avg:   {avg_yes:.3f}")
    sign = "+" if improvement >= 0 else ""
    print(f"  Improvement: {sign}{improvement:.3f} ({sign}{pct:.1f}%)")

    print(f"\n  {thin_rule}")
    print(f"  Most common 1st action WITHOUT memory: {no_top_action}")
    print(f"  Most common 1st action WITH memory:    {yes_top_action}")
    print(f"  Most common 1st domain WITHOUT memory: {no_top_domain}")
    print(f"  Most common 1st domain WITH memory:    {yes_top_domain}")
    print(f"  Communication used WITHOUT memory:     {no_comm_pct:.0f}% of runs")
    print(f"  Communication used WITH memory:        {yes_comm_pct:.0f}% of runs")

    # --- Behavioral insight ---
    if yes_top_action != no_top_action:
        print(f"\n  💡 Memory changed the agent's primary strategy from "
              f"'{no_top_action}' to '{yes_top_action}'")
    if yes_comm_pct > no_comm_pct:
        print("  💡 Memory taught the agent to include communication actions more often")
    print(f"{wide_rule}\n")

    # --- Save comparison ---
    comparison = {
        "scenario": "Friday 6PM (difficulty 5)",
        "runs_per_condition": N_COMPARE,
        "without_memory": {
            "results": no_mem_results,
            "avg_reward": round(avg_no, 3),
            "most_common_first_action": no_top_action,
            "most_common_first_domain": no_top_domain,
            "communication_rate": round(no_comm_pct, 1),
        },
        "with_memory": {
            "results": with_mem_results,
            "avg_reward": round(avg_yes, 3),
            "most_common_first_action": yes_top_action,
            "most_common_first_domain": yes_top_domain,
            "communication_rate": round(yes_comm_pct, 1),
        },
        "improvement": {
            "absolute": round(improvement, 3),
            "percentage": round(pct, 1),
        },
    }
    comp_path = os.path.join(data_dir, "before_after_comparison.json")
    with open(comp_path, "w", encoding="utf-8") as f:
        json.dump(comparison, f, indent=2)
    print(f"  📄 Behavioral comparison saved → {comp_path}")

    return {
        "episode_log": episode_log,
        "phase_averages": {
            "early": phase_averages["early"],
            "mid": phase_averages["mid"],
            "late": phase_averages["late"],
            "final": phase_averages["final"],
            "overall": overall,
        },
        "comparison": comparison,
    }


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """Script entry point: run the training loop for 100 episodes."""
    total_episodes = 100
    run_training(n_episodes=total_episodes)


# Start training only when this file is executed as a script (not on import).
if __name__ == "__main__":
    main()