hugruby commited on
Commit
7defbb4
·
verified ·
1 Parent(s): 628715a

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +139 -0
README.md ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ base_model: mistralai/Mathstral-7B-v0.1
4
+ library_name: peft
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - lora
8
+ - peft
9
+ - grpo
10
+ - dr-grpo
11
+ - mathematical-reasoning
12
+ - math
13
+ datasets:
14
+ - hugruby/mismatched-wrong-drafts
15
+ language:
16
+ - en
17
+ ---
18
+
19
+ # Mathstral-7B · Matched × Wrong drafts
20
+
21
+ Ablation — a *wrong* draft **matched** to its own problem.
22
+
23
+ LoRA adapter for **`mistralai/Mathstral-7B-v0.1`**, trained with **Dr. GRPO** as the **Matched × Wrong drafts** condition in *"Weak-to-Strong Elicitation via Mismatched Wrong Drafts"* (Wei Deng, [arXiv:2605.17314](https://arxiv.org/abs/2605.17314)).
24
+
25
+ - **Base model:** `mistralai/Mathstral-7B-v0.1` (Apache-2.0)
26
+ - **Draft model:** `Qwen/Qwen2.5-Math-1.5B` (writes the training-time draft)
27
+ - **Adapter:** LoRA r=16, α=32 (167 MB), released at **global step 2000**
28
+ - **Training data:** config `matched_wrong` of [`hugruby/mismatched-wrong-drafts`](https://huggingface.co/datasets/hugruby/mismatched-wrong-drafts) — 8,888 Level 3–5 MATH problems (**MATH-500 held out**)
29
+ - **License:** Apache-2.0
30
+
31
+ ## How to use
32
+
33
+ This is a **LoRA adapter** — load it on top of the base model.
34
+
35
+ ```python
36
+ import torch
37
+ from transformers import AutoModelForCausalLM, AutoTokenizer
38
+ from peft import PeftModel
39
+
40
+ BASE = "mistralai/Mathstral-7B-v0.1"
41
+ ADAPTER = "hugruby/mathstral-7b-matched-wrong-drafts"
42
+
43
+ tok = AutoTokenizer.from_pretrained(ADAPTER)
44
+ model = PeftModel.from_pretrained(
45
+ AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16, device_map="auto"),
46
+ ADAPTER,
47
+ ).eval()
48
+
49
+ problem = "If $x+y=6$ and $xy=5$, find $x^2+y^2$."
50
+ gen = dict(max_new_tokens=4096, do_sample=False)
51
+
52
+ # CANONICAL — the plain draft-free prompt the model was trained and evaluated on (no [INST]):
53
+ PROMPT = (
54
+ "Problem: " + problem + "\n\n"
55
+ "Thinking: N/A\n\n"
56
+ "The thinking section may contain errors. Solve the math problem step by step. "
57
+ "Write your own correct solution. Put your final answer within \\boxed{}.\n\n"
58
+ "Correct Solution:"
59
+ )
60
+ ids = tok(PROMPT, return_tensors="pt").to(model.device)
61
+ print(tok.decode(model.generate(**ids, **gen)[0][ids.input_ids.shape[1]:], skip_special_tokens=True))
62
+ ```
63
+
64
+ ### Optional: the `[INST]` chat format (out-of-distribution)
65
+
66
+ The shipped `chat_template.jinja` is Mathstral's original `[INST]` chat template. This adapter was **not** trained in that format, so `apply_chat_template(...)` is **out-of-distribution** and generally underperforms the plain prompt above — it is included only so you can A/B both:
67
+
68
+ ```python
69
+ ids = tok.apply_chat_template(
70
+ [{"role": "user",
71
+ "content": problem + "\n\nPlease reason step by step, and put your final answer within \\boxed{}."}],
72
+ add_generation_prompt=True, return_tensors="pt").to(model.device)
73
+ print(tok.decode(model.generate(ids, **gen)[0][ids.shape[1]:], skip_special_tokens=True))
74
+ ```
75
+
76
+ ## How it was trained
77
+
78
+ Trained with **Dr. GRPO** (`loss_type=dr_grpo`, `scale_rewards=False`) using TRL `GRPOTrainer` on top of Unsloth `FastLanguageModel`, on the `matched_wrong` data config. The reward is binary `mathematically_quasi_correct`. The correction-bonus, copy-penalty, and corrupt-penalty terms are all **0**, and the reward is pure binary.
79
+
80
+ Training command:
81
+
82
+ ```bash
83
+ python scripts/train.py \
84
+ --model mistralai/Mathstral-7B-v0.1 \
85
+ --dataset-path data/matched_wrong \
86
+ --output-dir outputs/matched_wrong \
87
+ --max-steps 2222 \
88
+ --gradient-accumulation-steps 4 \
89
+ --max-completion-length 4096 \
90
+ --max-seq-length 7168 \
91
+ --max-prompt-tokens 3072 \
92
+ --learning-rate 5e-6 --lr-scheduler-type constant \
93
+ --beta 0 \
94
+ --correction-bonus 0.0 --copy-penalty 0.0 --corrupt-penalty 0.0 \
95
+ --adam-beta2 0.99 \
96
+ --save-steps 50 --gpu-mem-util 0.5
97
+ ```
98
+
99
+ | Hyperparameter | Value |
100
+ |---|---|
101
+ | Base model | `mistralai/Mathstral-7B-v0.1` |
102
+ | Method | Dr. GRPO (`loss_type=dr_grpo`, `scale_rewards=False`) |
103
+ | LoRA rank / alpha | **r = 16, α = 32** → scaling **γ = α/r = 2** |
104
+ | LoRA targets / dropout | `q,k,v,o,gate,up,down` (7 projections) / 0.0 |
105
+ | KL coefficient β | 0 |
106
+ | Reward bonuses | correction 0, copy-penalty 0, corrupt-penalty 0 |
107
+ | Generations per prompt | 16 |
108
+ | Per-device batch | 1 |
109
+ | Gradient accumulation | 4 → 4 problems × 16 = 64 completions/step |
110
+ | Learning rate | 5e-6, **constant** schedule |
111
+ | Adam β₂ | 0.99 |
112
+ | Max completion length | 4096 |
113
+ | Max sequence length | 7168 |
114
+ | Max prompt tokens | 3072 |
115
+ | Max steps | 2222 |
116
+ | **Released checkpoint** | **global step 2000** (epoch 0.900) |
117
+ | Random seed | 42 |
118
+
119
+ ## Files
120
+
121
+ - `adapter_model.safetensors`, `adapter_config.json` — the LoRA adapter (load with PEFT on the base model)
122
+ - `tokenizer.json`, `tokenizer.model`, `tokenizer_config.json`, `special_tokens_map.json` — tokenizer
123
+ - `chat_template.jinja` — Mathstral's `[INST]` template (see the out-of-distribution note above)
124
+
125
+ ## Citation
126
+
127
+ ```bibtex
128
+ @article{deng2026mismatched,
129
+ title = {Weak-to-Strong Elicitation via Mismatched Wrong Drafts},
130
+ author = {Deng, Wei},
131
+ journal = {arXiv preprint arXiv:2605.17314},
132
+ year = {2026},
133
+ url = {https://arxiv.org/abs/2605.17314}
134
+ }
135
+ ```
136
+
137
+ ## License
138
+
139
+ Apache-2.0. The base model (`Mathstral-7B-v0.1`) and the draft model (`Qwen2.5-Math-1.5B`) are both Apache-2.0.