hugruby committed
Commit 9856910 · verified · 1 Parent(s): 64967c4

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +137 -0
README.md ADDED
@@ -0,0 +1,137 @@
+ ---
+ license: apache-2.0
+ base_model: mistralai/Mathstral-7B-v0.1
+ library_name: peft
+ pipeline_tag: text-generation
+ tags:
+ - lora
+ - peft
+ - grpo
+ - dr-grpo
+ - mathematical-reasoning
+ - math
+ datasets:
+ - hugruby/mismatched-wrong-drafts
+ language:
+ - en
+ ---
+
+ # Mathstral-7B · No draft (GRPO baseline)
+
+ On-policy Dr. GRPO baseline with no draft injected.
+
+ LoRA adapter for **`mistralai/Mathstral-7B-v0.1`**, trained with **Dr. GRPO** as the **No draft (GRPO baseline)** condition in *"Weak-to-Strong Elicitation via Mismatched Wrong Drafts"* (Wei Deng, [arXiv:2605.17314](https://arxiv.org/abs/2605.17314)).
+
+ - **Base model:** `mistralai/Mathstral-7B-v0.1` (Apache-2.0)
+ - **Adapter:** LoRA r=16, α=32 (167 MB), released at **global step 2000**
+ - **Training data:** config `no_draft` of [`hugruby/mismatched-wrong-drafts`](https://huggingface.co/datasets/hugruby/mismatched-wrong-drafts): 8,888 Level 3–5 MATH problems (**MATH-500 held out**)
+ - **License:** Apache-2.0
+
+ ## How to use
+
+ This is a **LoRA adapter**: load it on top of the base model.
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ BASE = "mistralai/Mathstral-7B-v0.1"
+ ADAPTER = "hugruby/mathstral-7b-grpo-no-draft"
+
+ # The tokenizer is loaded from the adapter repo; the weights are the base model plus the LoRA adapter.
+ tok = AutoTokenizer.from_pretrained(ADAPTER)
+ model = PeftModel.from_pretrained(
+     AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map="auto"),
+     ADAPTER,
+ ).eval()
+
+ problem = "If $x+y=6$ and $xy=5$, find $x^2+y^2$."
+ gen = dict(max_new_tokens=4096, do_sample=False)  # greedy decoding
+
+ # CANONICAL: the plain draft-free prompt the model was trained and evaluated on (no [INST]).
+ PROMPT = (
+     "Problem: " + problem + "\n\n"
+     "Thinking: N/A\n\n"
+     "The thinking section may contain errors. Solve the math problem step by step. "
+     "Write your own correct solution. Put your final answer within \\boxed{}.\n\n"
+     "Correct Solution:"
+ )
+ ids = tok(PROMPT, return_tensors="pt").to(model.device)
+ # Decode only the newly generated tokens, skipping the prompt.
+ print(tok.decode(model.generate(**ids, **gen)[0][ids.input_ids.shape[1]:], skip_special_tokens=True))
+ ```
+
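Because the prompt asks for the final answer inside `\boxed{}`, a small post-processing step is usually needed before scoring a completion. The sketch below is one way to do that with the standard library only. `build_prompt` and `extract_boxed` are hypothetical helper names introduced here for illustration, not part of this repository, and the brace matching is a simplification that assumes the answer's braces are balanced.

```python
from typing import Optional


def build_prompt(problem: str) -> str:
    """Format a problem in the draft-free prompt layout shown above."""
    return (
        "Problem: " + problem + "\n\n"
        "Thinking: N/A\n\n"
        "The thinking section may contain errors. Solve the math problem step by step. "
        "Write your own correct solution. Put your final answer within \\boxed{}.\n\n"
        "Correct Solution:"
    )


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None if absent or unbalanced."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1          # we are already inside the opening brace of \boxed{
    out = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out)  # matching close brace found
        out.append(ch)
    return None        # ran out of text before the braces balanced
```

For instance, a completion that ends with `\boxed{26}` would give `extract_boxed(completion) == "26"`, while a completion with no boxed answer gives `None`.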
+ ### Optional: the `[INST]` chat format (out-of-distribution)
+
+ The shipped `chat_template.jinja` is Mathstral's original `[INST]` chat template. This adapter was **not** trained in that format, so `apply_chat_template(...)` is **out-of-distribution** and generally underperforms the plain prompt above. It is included only so you can A/B both:
+
+ ```python
+ # Reuses tok, model, problem, and gen from the snippet above.
+ ids = tok.apply_chat_template(
+     [{"role": "user",
+       "content": problem + "\n\nPlease reason step by step, and put your final answer within \\boxed{}."}],
+     add_generation_prompt=True, return_tensors="pt").to(model.device)
+ print(tok.decode(model.generate(ids, **gen)[0][ids.shape[1]:], skip_special_tokens=True))
+ ```
+
+ ## How it was trained
+
+ Trained with **Dr. GRPO** (`loss_type=dr_grpo`, `scale_rewards=False`) using TRL `GRPOTrainer` on top of Unsloth `FastLanguageModel`, on the `no_draft` data config. The reward is the binary `mathematically_quasi_correct` check alone: the correction-bonus, copy-penalty, and corrupt-penalty terms are all **0**.
+
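To make the `scale_rewards=False` setting concrete, here is a minimal pure-Python sketch of group-relative advantages computed from binary rewards: each reward is centered on its group mean and is not divided by the group standard deviation. This is an illustration under stated assumptions, not the TRL implementation. `group_advantages` is a hypothetical helper, the reward values are made up, and the token-level loss normalization that `loss_type=dr_grpo` also changes is not modeled here.

```python
from statistics import fmean


def group_advantages(rewards):
    """Mean-center the rewards of one prompt's completions, with no std scaling."""
    mean = fmean(rewards)
    return [r - mean for r in rewards]


# Made-up group of 16 completions for one prompt: 4 correct (1.0), 12 wrong (0.0).
rewards = [1.0] * 4 + [0.0] * 12
advantages = group_advantages(rewards)

# A group where every completion is correct carries no learning signal.
all_correct = group_advantages([1.0] * 16)
```

In this example the group mean is 0.25, so correct completions get an advantage of 0.75 and wrong ones get -0.25. When all 16 completions agree, every advantage is zero.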
+ Training command:
+
+ ```bash
+ python scripts/train.py \
+   --model mistralai/Mathstral-7B-v0.1 \
+   --dataset-path data/no_draft \
+   --output-dir outputs/no_draft \
+   --max-steps 2222 \
+   --gradient-accumulation-steps 4 \
+   --max-completion-length 4096 \
+   --max-seq-length 7168 \
+   --learning-rate 5e-6 --lr-scheduler-type constant \
+   --beta 0 \
+   --correction-bonus 0.0 --copy-penalty 0.0 --corrupt-penalty 0.0 \
+   --adam-beta2 0.99 \
+   --save-steps 50 --gpu-mem-util 0.5
+ ```
+
+ | Hyperparameter | Value |
+ |---|---|
+ | Base model | `mistralai/Mathstral-7B-v0.1` |
+ | Method | Dr. GRPO (`loss_type=dr_grpo`, `scale_rewards=False`) |
+ | LoRA rank / alpha | **r = 16, α = 32** → scaling **γ = α/r = 2** |
+ | LoRA targets / dropout | `q,k,v,o,gate,up,down` (7 projections) / 0.0 |
+ | KL coefficient β | 0 |
+ | Reward bonuses | correction 0, copy-penalty 0, corrupt-penalty 0 |
+ | Generations per prompt | 16 |
+ | Per-device batch | 1 |
+ | Gradient accumulation | 4 → 4 problems × 16 = 64 completions/step |
+ | Learning rate | 5e-6, **constant** schedule |
+ | Adam β₂ | 0.99 |
+ | Max completion length | 4096 |
+ | Max sequence length | 7168 |
+ | Max prompt tokens | Disabled (no truncation). With no drafts, prompts are short: the longest is 1,899 tokens, well under the implicit budget of 7,168 − 4,096 = 3,072, so disabling the limit is equivalent to setting it to 3,072. |
+ | Max steps | 2222 |
+ | **Released checkpoint** | **global step 2000** (epoch 0.900) |
+ | Random seed | 42 |
+
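The derived numbers in the table (completions per step, the implicit prompt budget, and the epoch of the released checkpoint) follow from the listed values. The short check below recomputes them. It assumes, as the gradient-accumulation row states, that each optimizer step consumes 4 problems, and that an epoch is measured as problems consumed divided by the 8,888 training problems. The variable names are illustrative only.

```python
# Values copied from the model card.
NUM_PROBLEMS = 8_888
GENERATIONS_PER_PROMPT = 16
PROBLEMS_PER_STEP = 4          # per the gradient-accumulation row
MAX_SEQ_LEN = 7_168
MAX_COMPLETION_LEN = 4_096
LONGEST_PROMPT = 1_899
MAX_STEPS = 2_222
RELEASED_STEP = 2_000

# Completions sampled per optimizer step.
completions_per_step = PROBLEMS_PER_STEP * GENERATIONS_PER_PROMPT

# Tokens left for the prompt once the completion budget is reserved.
prompt_budget = MAX_SEQ_LEN - MAX_COMPLETION_LEN

# Fraction of the dataset seen at the released checkpoint and at the final step.
released_epoch = RELEASED_STEP * PROBLEMS_PER_STEP / NUM_PROBLEMS
final_epoch = MAX_STEPS * PROBLEMS_PER_STEP / NUM_PROBLEMS

print(completions_per_step, prompt_budget, round(released_epoch, 3), final_epoch)
```

Under these assumptions the full 2,222-step run covers the dataset exactly once, and the released step-2000 checkpoint sits at about 0.900 of that pass, which agrees with the table.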
+ ## Files
+
+ - `adapter_model.safetensors`, `adapter_config.json`: the LoRA adapter (load with PEFT on the base model)
+ - `tokenizer.json`, `tokenizer.model`, `tokenizer_config.json`, `special_tokens_map.json`: tokenizer
+ - `chat_template.jinja`: Mathstral's `[INST]` template (see the out-of-distribution note above)
+
+ ## Citation
+
+ ```bibtex
+ @article{deng2026mismatched,
+   title   = {Weak-to-Strong Elicitation via Mismatched Wrong Drafts},
+   author  = {Deng, Wei},
+   journal = {arXiv preprint arXiv:2605.17314},
+   year    = {2026},
+   url     = {https://arxiv.org/abs/2605.17314}
+ }
+ ```
+
+ ## License
+
+ Apache-2.0. The base model (`Mathstral-7B-v0.1`) and the draft model (`Qwen2.5-Math-1.5B`) are both Apache-2.0.