import requests
import json
import re
import sys
import time
from typing import List

# IMPORTANT: You need `trl`, `transformers`, and `datasets` to run this locally.
# pip install trl transformers datasets torch
try:
    from trl import GRPOTrainer, GRPOConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import Dataset
    TRL_AVAILABLE = True
except ImportError:
    TRL_AVAILABLE = False
    print("Dependencies missing! Ensure `trl`, `transformers`, and `datasets` are installed.")

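# Base URL of the local CLM environment server
# (start it with: uvicorn server.app:app --port 7860 --reload)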
CLM_SERVER = "http://localhost:7860"

# ==========================================
# PROMPT CONSTRUCTION
# ==========================================

def format_tasks(tasks: list) -> str:
    lines = []
    for t in tasks:
        diff = t.get("difficulty", "medium")
        p = t.get("progress", 0.0)
        pri = t.get("priority", "normal")
        dead = t.get("deadline", "None")
        deps = t.get("depends_on", "None")
        lines.append(
            f"- [{t['id']}] {t['task_type']} | Diff: {diff} | Pri: {pri} "
            f"| Dead: {dead} | Prog: {p:.2f} | Dep: {deps}"
        )
    return "\n".join(lines)

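# For reference, one formatted task line looks like this (illustrative values,
# borrowed from the fallback observation defined in collect_prompts below):
#   - [m1] email | Diff: medium | Pri: critical | Dead: 14 | Prog: 0.00 | Dep: None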

def manager_agent(state: dict) -> str:
    """Multi-Agent Oracle Manager: inspects worker states and issues guidance."""
    workers = state.get("workers", [])
    advice = []

    for w in workers:
        wid = w.get("id", "?")
        if w.get("fatigue_level") == "high":
            advice.append(
                f"Worker {wid} is burning out! MANDATORY: assign a 'break' to recover energy."
            )
        if w.get("stress_level") == "critical":
            advice.append(
                f"Worker {wid} stress is CRITICAL — delay non-critical tasks or use focus mode fast."
            )

    if state.get("upcoming_deadlines"):
        advice.append(
            f"Deadlines imminent: {state['upcoming_deadlines']} — prioritise these NOW."
        )
    if state.get("blocked_tasks"):
        advice.append(
            f"Blocked tasks (skip these): {state['blocked_tasks']}."
        )

    return " ".join(advice) if advice else "State is stable. Maintain a steady work pace."

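# For reference (illustrative): a worker {"id": "w1", "fatigue_level": "high"}
# yields the directive
#   "Worker w1 is burning out! MANDATORY: assign a 'break' to recover energy."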

def build_prompt(observation: dict) -> str:
    """Convert a CLM observation dict into an LLM prompt for the Worker Agent."""
    tasks = observation.get("tasks", [])
    state = observation.get("visible_state", {})
    workers = state.get("workers", [])

    # Pick first worker's summary for the prompt headline
    first_w = workers[0] if workers else {}
    manager_advice = manager_agent(state)

    return f"""You are a productivity AI acting as a worker managed by an Oracle Manager.

Current State:
- Energy Level: {first_w.get('fatigue_level', 'unknown')}
- Stress Level: {first_w.get('stress_level', 'unknown')}
- Focus Mode: {state.get('focus_mode', False)}
- Blocked Tasks: {state.get('blocked_tasks', [])}
- Time Step: {observation.get('time_step', 0)}

MANAGER DIRECTIVE: {manager_advice}

Tasks:
{format_tasks(tasks)}

Choose ONE action.
Available actions:
- work <task_id>: Normal work on task
- focus <task_id>: Deep work (2x progress, 2x energy cost)
- break: Rest to recover energy
- switch <task_id>: Switch focus to another task
- delay: Wait one step

Respond strictly with JSON only: {{"type": "work", "task_id": "m1"}}
"""


def parse_action(response: str) -> dict:
    """Extract the first flat JSON object from the model output; fall back to delay."""
    default_act = {"type": "delay"}
    valid_types = {"work", "focus", "break", "switch", "delay"}
    try:
        match = re.search(r"\{[^{}]*\}", response)
        if match:
            parsed = json.loads(match.group(0))
            if isinstance(parsed, dict) and parsed.get("type") in valid_types:
                return parsed
        return default_act
    except Exception:
        return default_act

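# Quick sanity examples (illustrative inputs and outputs):
#   parse_action('{"type": "break"}')                        -> {"type": "break"}
#   parse_action('Sure! {"type": "work", "task_id": "m1"}')  -> {"type": "work", "task_id": "m1"}
#   parse_action("no json here")                             -> {"type": "delay"}
#   parse_action('{"type": "dance"}')                        -> {"type": "delay"} (unknown type)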

# ==========================================
# REAL REWARD FUNCTION
# ==========================================

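# Assumed CLM server contract, inferred from the calls below (the exact
# response fields are this script's assumptions, not a published spec):
#   POST /reset  {"task_id": "<difficulty>"}  -> {"observation": {...}} (or the raw observation)
#   POST /step   {"action": {...}}            -> {"reward": <float>, ...}
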
def clm_reward_function(completions: List[str], **kwargs) -> List[float]:
    """
    REAL reward function — actually plays episodes in the CLM environment.

    Each completion is an action string the LLM chose. We reset the env,
    step it with that action, and return the real reward the environment gives.
    Note that each completion is scored against a fresh reset, not the exact
    episode its prompt came from; this keeps the reward call stateless at the
    cost of some noise across randomised initial states.

    This is what makes training meaningful: the LLM learns to pick actions
    that score well in the real cognitive-load simulation.
    """
    rewards = []

    for completion in completions:
        try:
            # Start a fresh episode with a medium-difficulty task set
            reset_resp = requests.post(
                f"{CLM_SERVER}/reset",
                json={"task_id": "medium"},
                timeout=10,
            ).json()

            # Extract the fresh observation (used below to fill in a default task_id)
            obs = reset_resp.get("observation", reset_resp)

            # Parse the LLM's action from its text output
            action = parse_action(completion)

            # Ensure work/focus actions have a task_id — default to first available
            if action.get("type") in ("work", "focus") and not action.get("task_id"):
                tasks = obs.get("tasks", [])
                if tasks:
                    action["task_id"] = tasks[0]["id"]

            # Step the environment with the parsed action
            step_resp = requests.post(
                f"{CLM_SERVER}/step",
                json={"action": action},
                timeout=10,
            ).json()

            # The real reward from the environment physics
            real_reward = float(step_resp.get("reward", 0.0))
            rewards.append(real_reward)

        except requests.exceptions.ConnectionError:
            # Server not running: warn loudly and apply a strong penalty
            print(
                f"[CLM] ERROR: Cannot reach {CLM_SERVER}. "
                "Start the server with: uvicorn server.app:app --port 7860 --reload"
            )
            rewards.append(-1.0)
        except Exception as e:
            print(f"[CLM] Env error during reward: {e}")
            rewards.append(-0.1)

    return rewards


# ==========================================
# DATASET COLLECTION
# ==========================================

def collect_prompts(n: int = 50, difficulty: str = "medium") -> List[dict]:
    """
    Collect real environment observations as training prompts.

    Each prompt is a fresh episode state. Running n resets gives the LLM
    diverse starting conditions (random seeds) to learn from.
    """
    prompts = []
    print(f"[CLM] Collecting {n} prompts from environment (difficulty={difficulty})...")

    for i in range(n):
        try:
            resp = requests.post(
                f"{CLM_SERVER}/reset",
                json={"task_id": difficulty},
                timeout=10,
            ).json()
            obs = resp.get("observation", resp)
            prompt = build_prompt(obs)
            prompts.append({"prompt": prompt})
        except requests.exceptions.ConnectionError:
            print(
                f"[CLM] Server offline at {CLM_SERVER} — "
                "using fallback prompts. Real training requires the server."
            )
            # Provide a minimal fallback so the training loop doesn't crash
            fallback_obs = {
                "tasks": [
                    {"id": "m1", "task_type": "email", "priority": "critical",
                     "progress": 0.0, "deadline": 14, "depends_on": None},
                    {"id": "m2", "task_type": "code_review", "priority": "high",
                     "progress": 0.0, "deadline": 20, "depends_on": None},
                ],
                "visible_state": {
                    "workers": [{"id": "w1", "fatigue_level": "low",
                                 "stress_level": "calm", "expertise": "analytical"}],
                    "focus_mode": False,
                    "upcoming_deadlines": [],
                    "blocked_tasks": [],
                },
                "time_step": 0,
            }
            prompts.append({"prompt": build_prompt(fallback_obs)})
        except Exception as e:
            print(f"[CLM] Prompt collection error at step {i}: {e}")
            continue

    print(f"[CLM] Collected {len(prompts)} prompts.")
    return prompts if prompts else [{"prompt": build_prompt({})}]


# ==========================================
# REWARD CURVE LOGGING
# ==========================================

_reward_log: List[dict] = []


def log_reward(step: int, rewards: List[float]) -> None:
    """Record per-step reward stats so we can plot a learning curve later."""
    entry = {
        "step": step,
        "mean": sum(rewards) / len(rewards) if rewards else 0.0,
        "max": max(rewards) if rewards else 0.0,
        "min": min(rewards) if rewards else 0.0,
    }
    _reward_log.append(entry)
    print(
        f"[CLM] Step {step:>4} | "
        f"mean_reward={entry['mean']:+.4f} | "
        f"max={entry['max']:+.4f} | "
        f"min={entry['min']:+.4f}"
    )


def save_reward_curve(path: str = "reward_curve.json") -> None:
    with open(path, "w") as f:
        json.dump(_reward_log, f, indent=2)
    print(f"[CLM] Reward curve saved to {path}")


def plot_reward_curve(path: str = "reward_curve.json") -> None:
    """Print an ASCII reward curve from the saved log. Requires no extra libraries."""
    try:
        with open(path) as f:
            data = json.load(f)
    except FileNotFoundError:
        print("[CLM] No reward curve file found. Run training first.")
        return

    if not data:
        print("[CLM] Reward log is empty.")
        return

    means = [d["mean"] for d in data]
    lo, hi = min(means), max(means)
    span = hi - lo if hi != lo else 1.0
    width = 40

    print("\n[CLM] Reward Learning Curve (ASCII)")
    print(f"  min={lo:+.3f}  max={hi:+.3f}  steps={len(means)}")
    print("  " + "-" * (width + 4))
    for d in data:
        bar_len = int((d["mean"] - lo) / span * width)
        bar = "#" * bar_len
        print(f"  {d['step']:>4} | {bar:<{width}} | {d['mean']:+.4f}")
    print("  " + "-" * (width + 4))

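# Illustrative output shape (values made up, alignment approximate):
#   [CLM] Reward Learning Curve (ASCII)
#     min=-0.120  max=+0.450  steps=2
#     --------------------------------------------
#        0 |                                          | -0.1200
#        1 | ######################################## | +0.4500
#     --------------------------------------------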

# ==========================================
# TRAINING LOOP
# ==========================================

def run_training_loop():
    if not TRL_AVAILABLE:
        print("[CLM] Cannot train: install `trl`, `transformers`, and `datasets` first.")
        return

    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    print(f"[CLM] Loading model: {model_name}")

    try:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        print(f"[CLM] Could not load HuggingFace model. Error: {e}")
        return

    # Collect real prompts from the live environment
    prompts_data = collect_prompts(n=50, difficulty="medium")
    dataset = Dataset.from_list(prompts_data)

    print("[CLM] Configuring GRPO Trainer...")
    config = GRPOConfig(
        output_dir="grpo_clm_model",
        learning_rate=1e-5,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        num_generations=2,  # the global batch size must be divisible by this; trl's default (8) would fail here
        max_prompt_length=1024,
        max_completion_length=128,
        logging_steps=1,
        save_steps=50,
    )

    # Wrap reward function to also log reward curves
    step_counter = [0]

    def tracked_reward(completions: List[str], **kwargs) -> List[float]:
        rewards = clm_reward_function(completions, **kwargs)
        log_reward(step_counter[0], rewards)
        step_counter[0] += 1
        return rewards

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[tracked_reward],
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,  # use the tokenizer loaded above rather than leaving it unused
    )

    print("[CLM] Starting training...")
    start = time.time()
    trainer.train()
    elapsed = time.time() - start

    print(f"[CLM] Training complete in {elapsed:.1f}s. Saving model.")
    trainer.save_model("grpo_clm_model_final")

    save_reward_curve("reward_curve.json")
    plot_reward_curve("reward_curve.json")


if __name__ == "__main__":
    print("--- Cognitive Load Manager: GRPO Training Script ---")
    print("Theme #1 (Multi-Agent): Oracle Manager oversees 3 Worker Agents.")
    print("Theme #2 (OpenEnv):     Real env steps drive the reward signal.")
    print()
    print("Make sure the CLM server is running first:")
    print("  uvicorn server.app:app --port 7860 --reload")
    print()

    if "--train" in sys.argv:
        run_training_loop()
    elif "--plot" in sys.argv:
        plot_reward_curve("reward_curve.json")
    elif "--test-reward" in sys.argv:
        # Quick sanity-check: fire one real reward call against the live server
        print("[CLM] Testing reward function against live server...")
        test_completions = [
            '{"type": "work", "task_id": "m1"}',
            '{"type": "break"}',
            '{"type": "focus", "task_id": "m1"}',
            'invalid json garbage',
        ]
        rewards = clm_reward_function(test_completions)
        for c, r in zip(test_completions, rewards):
            print(f"  action={c!r:50s}  reward={r:+.4f}")
    else:
        print("Usage:")
        print("  python training_loop.py --test-reward   # verify env connection")
        print("  python training_loop.py --train         # run full GRPO training")
        print("  python training_loop.py --plot          # show reward curve")