A reasoning-enhanced version of Qwen3.5-0.8B, trained with an SFT warmup followed by GRPO to reason inside `<think>` tags.

## Results

| Eval Setting | GSM8K Accuracy | Notes |
|---|:-:|---|
| Baseline 8-shot CoT | 53.5% | Pre-trained, no fine-tuning |
| Baseline zero-shot | 52.1% | Pre-trained, no fine-tuning |
| **GRPO zero-shot** | **58.0% (+5.9pp)** | Best result – model reasons autonomously |
| GRPO 8-shot (plain format) | 50.4% (-3.1pp) | Few-shot examples conflict with learned policy |
| GRPO 8-shot (`<think>` aligned) | 34.1% (-19.4pp) | Format-aligned examples hurt even more |

### Key Finding: Demonstration-to-Policy Shift

GRPO training shifted the model from **demonstration-based reasoning** to **policy-based reasoning**. After training, the model:

- **Performs best zero-shot** – it reasons autonomously using `<think>` tags
- **Is hurt by few-shot examples** – demonstrations conflict with its learned internal reasoning policy
- **Is hurt even more by format-aligned few-shot** – `<think>` tags in the examples caused the model to confuse context with its own generation, dropping accuracy to 34.1%

This is a behavioral shift, not a regression: the model no longer needs (or wants) demonstrations, mirroring what DeepSeek-R1 demonstrated at 670B scale.

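The three eval settings above differ only in how the prompt is assembled. A hypothetical sketch of that assembly (`build_prompt` is illustrative; the actual eval harness is not part of this card):

```python
def build_prompt(question: str, shots=(), think_aligned=False) -> str:
    """Assemble a zero-shot or few-shot prompt for one GSM8K question.

    shots: (question, reasoning, answer) demonstration triples; empty -> zero-shot.
    think_aligned: wrap demonstration reasoning in <think> tags to match the
    model's trained output format (the setting that scored 34.1% above).
    """
    parts = []
    for q, reasoning, answer in shots:
        if think_aligned:
            parts.append(f"Q: {q}\n<think>{reasoning}</think>\n#### {answer}")
        else:
            parts.append(f"Q: {q}\nA: {reasoning} The answer is {answer}.")
    parts.append(f"Q: {question}")
    return "\n\n".join(parts)
```

With `shots=()` this produces the bare question the GRPO model handles best; adding triples reproduces either 8-shot setting.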
## Training Pipeline

### Phase 1: SFT Warmup
- **Data:** [3,558 reasoning examples](https://huggingface.co/datasets/celestialcreator/Qwen3.5-0.8B-GRPO-Math-Dataset) from 3 sources, standardized to `<think>` tags
  - 1K Claude Sonnet math chains (originally `<thought>`, converted to `<think>`)
  - ~250 from TeichAI Opus reasoning (already `<think>`)
  - ~2.3K from Opus 4.6 Reasoning (thinking/solution combined into `<think>`)
- **Purpose:** Solve the cold-start problem – teach the 0.8B model the `<think>` tag format before RL exploration
- **Stats:** 1 epoch, loss 0.932, 78% token accuracy
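The standardization step can be sketched as a simple tag rewrite (a hypothetical helper, not the project's actual preprocessing script):

```python
import re

def standardize(example: str) -> str:
    """Convert a source dataset's <thought>...</thought> reasoning markup
    into the uniform <think>...</think> format used for SFT."""
    return re.sub(
        r"</?thought>",
        lambda m: m.group(0).replace("thought", "think"),
        example,
    )
```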
### Phase 2: GRPO Training
- **Data:** GSM8K train split (7,473 math word problems)
- **Rewards:** Math correctness (1.0/0.0) + format reward (0.3 for `<think>` tags, 0.2 for `####` answer)
- **Config:** 8 generations/prompt, batch size 1 x 8 grad accum, lr 1e-6, beta=0.04
- **Hardware:** Single NVIDIA RTX 5090 (32GB VRAM)
- **Duration:** ~77 hours, 15,900 steps (epoch 2.13, rewards had plateaued)

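The reward scheme above can be sketched as follows – an illustrative composition of the correctness and format terms, not the actual training code (the `reward` signature and the gating of correctness on a parsed `####` answer are assumptions):

```python
import re

def reward(completion: str, gold_answer: str) -> float:
    r = 0.0
    # Format reward: 0.3 for reasoning wrapped in <think>...</think>
    if re.search(r"<think>.*?</think>", completion, re.DOTALL):
        r += 0.3
    # Format reward: 0.2 for a "#### <number>" final answer
    m = re.search(r"####\s*(-?[\d,.]+)", completion)
    if m:
        r += 0.2
        # Correctness reward: 1.0 if the extracted answer matches the reference
        if m.group(1).replace(",", "") == gold_answer:
            r += 1.0
    return r
```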
GRPO keeps only 2 models in memory (policy + reference) instead of 4, making training feasible on consumer hardware.

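That saving comes from GRPO computing advantages by normalizing rewards within each group of generations for the same prompt, instead of training a separate critic. A pure-Python sketch (function name and the epsilon are illustrative):

```python
from statistics import mean, stdev

def group_advantages(rewards):
    """rewards: per-completion rewards for ONE prompt (8 generations/prompt
    in this run). Each advantage is the reward normalized within its group,
    so no learned value model is needed."""
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + 1e-4) for r in rewards]
```

Completions that beat their group's average get positive advantage and are reinforced; the rest are suppressed.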
## Lessons Learned

### What worked
- **GRPO improved zero-shot reasoning** – +5.9pp; the model internalized step-by-step thinking
- **Demonstration-to-policy shift** – the model developed its own reasoning strategy instead of relying on examples
- **Format + correctness rewards together** – the `<think>` tag bonus helped the model learn structured reasoning alongside accuracy
- **Single consumer GPU is viable** – the full pipeline ran on one RTX 5090

### What we'd do differently
- **Eval after SFT** – we skipped this, so we can't isolate SFT's contribution
- **Try GRPO without SFT** – an ablation would show whether SFT warmup is necessary or whether it trades few-shot ability for format
- **Larger model** – 0.8B is near its capacity ceiling; successful open GRPO reproductions start at 1.5B+

### Technical findings
- **Qwen3.5 DeltaNet needs FLA** – install `flash-linear-attention` + `causal-conv1d`; otherwise the torch fallback is ~10x slower
- **SDPA > FLA for inference** – 3.6x faster first call; use `attn_implementation="sdpa"`
- **Rewards plateau around epoch 1.2** – diminishing returns beyond 2 epochs at this scale
- **RL-trained models are few-shot sensitive** – even format-aligned examples hurt (34.1%), suggesting the model confuses example `<think>` tags with its own generation context

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "..."  # fill in this model's Hub repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # 3.6x faster first call than FLA for inference
    trust_remote_code=True,
)

# Best used zero-shot – the model has its own reasoning policy
messages = [
    {"role": "system", "content": "You are a helpful assistant that thinks step by step. Show your reasoning inside <think> tags before giving your final answer. End math answers with: #### <number>"},
    {"role": "user", "content": "If a train travels at 60 mph for 2.5 hours, how far does it go?"},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```

> **Note:** This model performs best in **zero-shot** mode. Do not use few-shot examples – they conflict with the model's learned reasoning policy and reduce accuracy.

## Training Code
Full pipeline (Dockerfile, k8s configs, scripts): [github.com/CelestialCreator/gpu-lab/tree/main/projects/05-grpo-reasoning](https://github.com/CelestialCreator/gpu-lab/tree/main/projects/05-grpo-reasoning)