# train_trl.py: GRPO Training Reference

`scripts/train_trl.py`: Curriculum GRPO training via HuggingFace TRL + Unsloth.

---

## Overview

Trains a small LLM (default: `Qwen2.5-1.5B-Instruct`) to resolve LifeStack life conflicts
using **Group Relative Policy Optimization (GRPO)**. Implements a success-based curriculum
that automatically increases difficulty when the agent's average reward exceeds 0.6.

Requires: `unsloth`, `trl`, `datasets`, `transformers`, `accelerate` (Colab / GPU).

---

## Usage

```bash
# Full curriculum training (5 stages × 100 prompts)
python scripts/train_trl.py
```

No CLI args; edit the constants at the top of the file to change stages/prompts/output dir.

---

## Architecture

### Reward Functions (multi-signal GRPO)

| Function | Signal |
|---|---|
| `reward_format_fn` | JSON format compliance |
| `reward_plausibility_fn` | Penalises zero-cost metric changes |
| `reward_task_success_fn` | Core env-step outcome reward |
| `reward_milestone_fn` | Milestone progress bonus |
| `reward_reasoning_fn` | Planning coherence score |
| `reward_human_feedback_fn` | Alignment with past real-world outcome feedback |
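
As an illustration of what one of these signals might look like, the sketch below scores JSON format compliance. It assumes TRL's custom reward function convention for `GRPOTrainer` (a callable that receives `completions` plus extra keyword arguments and returns one float per completion) and plain-string completions. The name `format_reward_sketch` and the 1.0 / 0.0 scoring are hypothetical and not taken from the script.

```python
import json


def format_reward_sketch(completions, **kwargs):
    """Return 1.0 for each completion that parses as a JSON object, else 0.0.

    Hypothetical stand-in for reward_format_fn; the real scoring may differ.
    """
    rewards = []
    for text in completions:
        try:
            parsed = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON (or not a string at all): no format credit.
            rewards.append(0.0)
            continue
        # Only a top-level JSON object counts as a well-formed action here.
        rewards.append(1.0 if isinstance(parsed, dict) else 0.0)
    return rewards
```

Keeping each signal in its own function lets the trainer combine them as a list of reward functions, so a single weak signal can be inspected or reweighted in isolation.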

### `get_lifestack_evaluation(completion, prompt) -> dict`

The central reward computation function. Parses the LLM's JSON completion, reconstructs
the Task from the prompt's `<SYSTEM_METADATA>` block, steps the env, and returns:

```python
{
    "reward": float,
    "breakdown": dict,   # from obs.metadata["breakdown"]
    "action": LifeStackAction
}
```

Returns `{"reward": -0.5, "breakdown": {"error": ...}}` on any parse or env failure.
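
The return contract above can be sketched as follows. This is a minimal illustration, not the script's implementation: `evaluate_sketch`, `step_fn`, and `toy_step` are hypothetical names, and the env step is replaced by a plain callable so the example stays self-contained.

```python
import json

FALLBACK_REWARD = -0.5  # fallback value documented above


def evaluate_sketch(completion, step_fn):
    """Parse a JSON completion and score it, or return the documented fallback.

    step_fn is a hypothetical stand-in for the env step; it takes the parsed
    action and returns (reward, breakdown).
    """
    try:
        action = json.loads(completion)
        reward, breakdown = step_fn(action)
        return {"reward": float(reward), "breakdown": breakdown, "action": action}
    except Exception as exc:
        # Any parse or env failure collapses to the same fallback shape.
        return {"reward": FALLBACK_REWARD, "breakdown": {"error": str(exc)}}


def toy_step(action):
    """Hypothetical env step that always reports the same outcome."""
    return 0.8, {"task_success": 0.8}
```

Note that the fallback dict carries no `"action"` key, so callers should read it with `.get("action")` rather than indexing directly.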

#### Task Construction Hardening (2026-04-23)

The `Task(...)` call inside `get_lifestack_evaluation` is wrapped in its own
`try/except`. On exception, it logs `[reward] Task construction failed: <error>` and
returns the `-0.5` fallback immediately. After a successful construction, a
field-presence check on `(id, goal, constraints, mutable_world, visible_world)` runs.
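
A rough sketch of this hardening pattern is shown below. `TaskSketch` is a hypothetical stand-in for the real `Task` class, and two details are assumptions rather than documented behaviour: that "present" means "not `None`", and that a failed field check also returns the `-0.5` fallback.

```python
from dataclasses import dataclass

REQUIRED_FIELDS = ("id", "goal", "constraints", "mutable_world", "visible_world")


@dataclass
class TaskSketch:
    """Hypothetical stand-in for the real Task class."""
    id: str
    goal: str
    constraints: dict
    mutable_world: dict
    visible_world: dict


def build_task_or_fallback(metadata):
    """Return (task, None) on success, or (None, fallback_dict) on failure."""
    try:
        task = TaskSketch(**metadata)
    except Exception as exc:
        # Log line format taken from the description above.
        print(f"[reward] Task construction failed: {exc}")
        return None, {"reward": -0.5, "breakdown": {"error": str(exc)}}

    # Assumption: a field counts as present when it is not None.
    missing = [name for name in REQUIRED_FIELDS if getattr(task, name, None) is None]
    if missing:
        return None, {"reward": -0.5, "breakdown": {"error": f"missing fields: {missing}"}}
    return task, None
```

Isolating construction in its own `try/except` keeps a malformed metadata block from being confused with a genuine env-step failure in the logs.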

### Curriculum (`train_curriculum`)

```
Stage 1: difficulty=1 → train → eval → if avg_reward > 0.6: difficulty++
Stage 2: difficulty=2 → ...
...
Stage 5: difficulty=5 → final save
```
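
The promotion rule in the diagram can be sketched as a small loop. This is an illustrative reading, not the body of `train_curriculum`: `run_curriculum_sketch` and `train_and_eval` are hypothetical names, and the sketch assumes that a stage which does not beat the threshold keeps the same difficulty for the next stage and that difficulty is capped at the number of stages.

```python
def run_curriculum_sketch(train_and_eval, num_stages=5, threshold=0.6):
    """Run a fixed number of stages, promoting difficulty on success.

    train_and_eval is a hypothetical callable: it trains at the given
    difficulty and returns the average eval reward for that stage.
    """
    difficulty = 1
    history = []
    for stage in range(1, num_stages + 1):
        avg_reward = train_and_eval(difficulty)
        history.append((stage, difficulty, avg_reward))
        # Promote only when the stage clears the threshold (assumed cap: num_stages).
        if avg_reward > threshold and difficulty < num_stages:
            difficulty += 1
    return history
```

With a callable that always returns 0.7, the difficulties visited are 1 through 5; with one that always returns 0.5, every stage stays at difficulty 1.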

### Dataset (`generate_dataset`)

Generates `N` prompts by:
1. Sampling a `TaskGenerator` task (flight_crisis or code_merge_crisis)
2. Merging a legacy `ConflictEvent` disruption for variety
3. Cascading the disruption through the `DependencyGraph`
4. Embedding task metadata in a `<SYSTEM_METADATA>` block for reward reconstruction
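
Step 4 is what lets the reward function rebuild the task later. One way to round-trip metadata through the prompt is sketched below. The JSON payload and the closing `</SYSTEM_METADATA>` tag are assumptions about the block's format, and `embed_metadata` / `extract_metadata` are hypothetical helper names.

```python
import json
import re

# Assumption: the block is delimited by matching open/close tags and holds JSON.
META_RE = re.compile(r"<SYSTEM_METADATA>(.*?)</SYSTEM_METADATA>", re.DOTALL)


def embed_metadata(prompt_text, metadata):
    """Append a metadata block to the prompt text."""
    block = json.dumps(metadata, sort_keys=True)
    return f"{prompt_text}\n<SYSTEM_METADATA>{block}</SYSTEM_METADATA>"


def extract_metadata(prompt):
    """Return the embedded metadata dict, or None if no block is found."""
    match = META_RE.search(prompt)
    if match is None:
        return None
    return json.loads(match.group(1))
```

Carrying the metadata inside the prompt keeps each dataset row self-describing, so the reward function needs no side channel to know which task produced a given completion.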

---

## Outputs

| Path | Contents |
|---|---|
| `./lifestack_model/` | Final saved model + tokenizer |
| `./lifestack_model/stage_N/` | Per-stage checkpoints |
| `training_logs/generations.jsonl` | Sampled generations (every 20 reward calls) |
| `grpo_reward_curve.png` | 50-episode eval reward curve |

---

## Change Log

| Date | Change |
|---|---|
| 2026-04-23 | `Task()` construction wrapped in try/except + field validation; returns -0.5 fallback on failure |