import requests
import json
import re
import os
import time
from typing import List
# IMPORTANT: You need `trl`, `transformers`, and `datasets` to run this locally.
# pip install trl transformers datasets torch
try:
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
except ImportError:
print("Dependencies missing! Ensure `trl` and `transformers` are installed.")
CLM_SERVER = "http://localhost:7860"
# ==========================================
# PROMPT CONSTRUCTION
# ==========================================
def format_tasks(tasks: list) -> str:
    """Render a list of task dicts as one human-readable bullet line each.

    Each task must carry 'id' and 'task_type'; 'priority', 'progress',
    'deadline', and 'depends_on' fall back to defaults when absent, so
    partially-populated observations still render.

    Returns the lines joined with newlines ("" for an empty list).
    """
    # Fix: dropped the unused `diff = t.get("difficulty", "medium")` local —
    # difficulty was read but never rendered.
    lines = []
    for t in tasks:
        p = t.get("progress", 0.0)
        pri = t.get("priority", "normal")
        dead = t.get("deadline", "None")
        deps = t.get("depends_on", "None")
        lines.append(
            f"- [{t['id']}] {t['task_type']} | Pri: {pri} | Dead: {dead} "
            f"| Prog: {p:.2f} | Dep: {deps}"
        )
    return "\n".join(lines)
def manager_agent(state: dict) -> str:
    """Multi-Agent Oracle Manager: inspects worker states and issues guidance.

    Scans every worker for high fatigue / critical stress, then appends
    state-wide warnings about imminent deadlines and blocked tasks. Returns
    all directives joined by spaces, or a calm default when nothing fires.
    """
    directives: list = []
    for worker in state.get("workers", []):
        worker_id = worker.get("id", "?")
        if worker.get("fatigue_level") == "high":
            directives.append(
                f"Worker {worker_id} is burning out! MANDATORY: assign a 'break' to recover energy."
            )
        if worker.get("stress_level") == "critical":
            directives.append(
                f"Worker {worker_id} stress is CRITICAL — delay non-critical tasks or use focus mode fast."
            )
    deadlines = state.get("upcoming_deadlines")
    if deadlines:
        directives.append(f"Deadlines imminent: {deadlines} — prioritise these NOW.")
    blocked = state.get("blocked_tasks")
    if blocked:
        directives.append(f"Blocked tasks (skip these): {blocked}.")
    if not directives:
        return "State is stable. Maintain a steady work pace."
    return " ".join(directives)
def build_prompt(observation: dict) -> str:
    """Convert a CLM observation dict into an LLM prompt for the Worker Agent.

    Summarises the first worker's vitals in the headline, embeds the Oracle
    Manager's directive, lists the tasks, and spells out the action grammar.
    """
    task_list = observation.get("tasks", [])
    visible = observation.get("visible_state", {})
    roster = visible.get("workers", [])
    # Headline the first worker's state; an empty dict yields 'unknown' fields.
    lead_worker = roster[0] if roster else {}
    energy = lead_worker.get("fatigue_level", "unknown")
    stress = lead_worker.get("stress_level", "unknown")
    focus = visible.get("focus_mode", False)
    blocked = visible.get("blocked_tasks", [])
    tick = observation.get("time_step", 0)
    directive = manager_agent(visible)
    return f"""You are a productivity AI acting as a worker managed by an Oracle Manager.
Current State:
- Energy Level: {energy}
- Stress Level: {stress}
- Focus Mode: {focus}
- Blocked Tasks: {blocked}
- Time Step: {tick}
MANAGER DIRECTIVE: {directive}
Tasks:
{format_tasks(task_list)}
Choose ONE action.
Available actions:
- work <task_id>: Normal work on task
- focus <task_id>: Deep work (2x progress, 2x energy cost)
- break: Rest to recover energy
- switch <task_id>: Switch focus to another task
- delay: Wait one step
Respond strictly with JSON only: {{"type": "work", "task_id": "m1"}}
"""
def parse_action(response: str) -> dict:
    """Extract the first flat JSON object from an LLM response as an action.

    Any parse failure, missing object, or object without a "type" key falls
    back to the safe no-op action {"type": "delay"}.
    """
    fallback = {"type": "delay"}
    try:
        found = re.search(r"\{[^{}]*\}", response)
        if not found:
            return fallback
        candidate = json.loads(found.group(0))
        if "type" in candidate:
            return candidate
    except Exception:
        pass
    return fallback
# ==========================================
# REAL REWARD FUNCTION
# ==========================================
def clm_reward_function(completions: List[str], **kwargs) -> List[float]:
    """
    REAL reward function — plays one step per completion in the live CLM env.

    For each completion: reset a fresh medium-difficulty episode, parse the
    LLM's chosen action out of the text, then step the environment and take
    the reward it returns. The LLM therefore learns from the environment's
    actual physics rather than a proxy score.

    Failure modes:
    - server unreachable  -> loud message + strong -1.0 penalty
    - any other env error -> small -0.1 penalty
    """
    scores: List[float] = []
    for text in completions:
        try:
            # Fresh episode so every completion is judged from the same start.
            episode = requests.post(
                f"{CLM_SERVER}/reset",
                json={"task_id": "medium"},
                timeout=10,
            ).json()
            observation = episode.get("observation", episode)
            chosen = parse_action(text)
            # work/focus need a target task — default to the first one offered.
            if chosen.get("type") in ("work", "focus") and not chosen.get("task_id"):
                available = observation.get("tasks", [])
                if available:
                    chosen["task_id"] = available[0]["id"]
            outcome = requests.post(
                f"{CLM_SERVER}/step",
                json={"action": chosen},
                timeout=10,
            ).json()
            # Reward straight from the environment physics.
            scores.append(float(outcome.get("reward", 0.0)))
        except requests.exceptions.ConnectionError:
            print(
                f"[CLM] ERROR: Cannot reach {CLM_SERVER}. "
                "Start the server with: uvicorn server.app:app --port 7860 --reload"
            )
            scores.append(-1.0)
        except Exception as e:
            print(f"[CLM] Env error during reward: {e}")
            scores.append(-0.1)
    return scores
# ==========================================
# DATASET COLLECTION
# ==========================================
def collect_prompts(n: int = 50, difficulty: str = "medium") -> List[dict]:
    """
    Collect real environment observations as training prompts.

    Performs n independent resets so the LLM sees diverse (randomly seeded)
    starting states. If the server is down, a hand-written fallback
    observation keeps the training loop alive; other per-reset errors are
    logged and skipped. Always returns at least one prompt.
    """
    collected: List[dict] = []
    print(f"[CLM] Collecting {n} prompts from environment (difficulty={difficulty})...")
    for attempt in range(n):
        try:
            payload = requests.post(
                f"{CLM_SERVER}/reset",
                json={"task_id": difficulty},
                timeout=10,
            ).json()
            observation = payload.get("observation", payload)
            collected.append({"prompt": build_prompt(observation)})
        except requests.exceptions.ConnectionError:
            print(
                f"[CLM] Server offline at {CLM_SERVER} — "
                "using fallback prompts. Real training requires the server."
            )
            # Minimal synthetic observation so training can proceed offline.
            stub = {
                "tasks": [
                    {"id": "m1", "task_type": "email", "priority": "critical",
                     "progress": 0.0, "deadline": 14, "depends_on": None},
                    {"id": "m2", "task_type": "code_review", "priority": "high",
                     "progress": 0.0, "deadline": 20, "depends_on": None},
                ],
                "visible_state": {
                    "workers": [{"id": "w1", "fatigue_level": "low",
                                 "stress_level": "calm", "expertise": "analytical"}],
                    "focus_mode": False,
                    "upcoming_deadlines": [],
                    "blocked_tasks": [],
                },
                "time_step": 0,
            }
            collected.append({"prompt": build_prompt(stub)})
        except Exception as e:
            print(f"[CLM] Prompt collection error at step {attempt}: {e}")
            continue
    print(f"[CLM] Collected {len(collected)} prompts.")
    return collected if collected else [{"prompt": build_prompt({})}]
# ==========================================
# REWARD CURVE LOGGING
# ==========================================
_reward_log: list[dict] = []
def log_reward(step: int, rewards: list[float]) -> None:
"""Record per-step reward stats so we can plot a learning curve later."""
entry = {
"step": step,
"mean": sum(rewards) / len(rewards) if rewards else 0.0,
"max": max(rewards) if rewards else 0.0,
"min": min(rewards) if rewards else 0.0,
}
_reward_log.append(entry)
print(
f"[CLM] Step {step:>4} | "
f"mean_reward={entry['mean']:+.4f} | "
f"max={entry['max']:+.4f} | "
f"min={entry['min']:+.4f}"
)
def save_reward_curve(path: str = "reward_curve.json") -> None:
with open(path, "w") as f:
json.dump(_reward_log, f, indent=2)
print(f"[CLM] Reward curve saved to {path}")
def plot_reward_curve(path: str = "reward_curve.json") -> None:
    """Print an ASCII reward curve from the saved log. Requires no extra libraries."""
    try:
        with open(path) as fh:
            entries = json.load(fh)
    except FileNotFoundError:
        print("[CLM] No reward curve file found. Run training first.")
        return
    if not entries:
        print("[CLM] Reward log is empty.")
        return
    mean_values = [entry["mean"] for entry in entries]
    low = min(mean_values)
    high = max(mean_values)
    # Guard against a flat curve: a zero span would divide by zero below.
    spread = (high - low) or 1.0
    width = 40
    print("\n[CLM] Reward Learning Curve (ASCII)")
    print(f" min={low:+.3f} max={high:+.3f} steps={len(mean_values)}")
    print(" " + "-" * (width + 4))
    for entry in entries:
        filled = int((entry["mean"] - low) / spread * width)
        print(f" {entry['step']:>4} | {'#' * filled:<{width}} | {entry['mean']:+.4f}")
    print(" " + "-" * (width + 4))
# ==========================================
# TRAINING LOOP
# ==========================================
def run_training_loop():
    """Run a full GRPO fine-tune whose rewards come from live CLM env steps.

    Pipeline: load the base model -> collect real prompts from the server ->
    configure GRPO -> train with a reward function that logs a learning
    curve -> save the final model and the reward curve.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    print(f"[CLM] Loading model: {model_name}")
    try:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # NOTE(review): `tokenizer` is loaded but never passed to GRPOTrainer;
        # presumably the trainer re-loads it from `model_name` itself — confirm,
        # or pass it explicitly (e.g. as processing_class).
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        # Abort early if weights can't be fetched — nothing else can proceed.
        print(f"[CLM] Could not load HuggingFace model. Error: {e}")
        return
    # Collect real prompts from the live environment
    prompts_data = collect_prompts(n=50, difficulty="medium")
    dataset = Dataset.from_list(prompts_data)
    print("[CLM] Configuring GRPO Trainer...")
    config = GRPOConfig(
        output_dir="grpo_clm_model",
        learning_rate=1e-5,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        max_prompt_length=1024,
        max_completion_length=128,
        logging_steps=1,
        save_steps=50,
    )
    # Wrap reward function to also log reward curves
    # One-element list acts as a mutable cell so the closure can increment
    # the step counter without a `nonlocal` declaration.
    step_counter = [0]
    def tracked_reward(completions: List[str], **kwargs) -> List[float]:
        # Delegate to the real env-backed reward, then record stats per call.
        rewards = clm_reward_function(completions, **kwargs)
        log_reward(step_counter[0], rewards)
        step_counter[0] += 1
        return rewards
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[tracked_reward],
        args=config,
        train_dataset=dataset,
    )
    print("[CLM] Starting training...")
    start = time.time()
    trainer.train()
    elapsed = time.time() - start
    print(f"[CLM] Training complete in {elapsed:.1f}s. Saving model.")
    trainer.save_model("grpo_clm_model_final")
    # Persist and display the reward curve gathered by tracked_reward.
    save_reward_curve("reward_curve.json")
    plot_reward_curve("reward_curve.json")
if __name__ == "__main__":
    # CLI entry: banner first, then dispatch on the flag (train > plot > test).
    print("--- Cognitive Load Manager: GRPO Training Script ---")
    print("Theme #1 (Multi-Agent): Oracle Manager oversees 3 Worker Agents.")
    print("Theme #2 (OpenEnv): Real env steps drive the reward signal.")
    print()
    print("Make sure the CLM server is running first:")
    print(" uvicorn server.app:app --port 7860 --reload")
    print()
    import sys
    flags = set(sys.argv)
    if "--train" in flags:
        run_training_loop()
    elif "--plot" in flags:
        plot_reward_curve("reward_curve.json")
    elif "--test-reward" in flags:
        # Sanity-check: run the reward function once against the live server.
        print("[CLM] Testing reward function against live server...")
        sample_actions = [
            '{"type": "work", "task_id": "m1"}',
            '{"type": "break"}',
            '{"type": "focus", "task_id": "m1"}',
            'invalid json garbage',
        ]
        scores = clm_reward_function(sample_actions)
        for act, score in zip(sample_actions, scores):
            print(f" action={act!r:50s} reward={score:+.4f}")
    else:
        print("Usage:")
        print(" python training_loop.py --test-reward # verify env connection")
        print(" python training_loop.py --train # run full GRPO training")
        print(" python training_loop.py --plot # show reward curve")
|