---
library_name: transformers
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- math
- reasoning
- computational-graph
- bangla
- low-resource
- distractor-aware
- grpo
- reinforcement-learning
base_model:
- google/gemma-3-12b-it
language:
- bn
- en
datasets:
- dipta007/dagger
- dipta007/DistractMath-Bn
---

# DAGGER-12B-GRPO

<a href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank">
<img alt="arXiv" src="https://img.shields.io/badge/arXiv-XXXX.XXXXX-b31b1b" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://github.com/your-username/dagger" target="_blank">
<img alt="GitHub" src="https://img.shields.io/badge/GitHub-Code-black" style="display: inline-block; vertical-align: middle;"/>
</a>

## Model Description

**DAGGER-12B-GRPO** is trained with Group Relative Policy Optimization (GRPO) directly from Gemma-3-12B-Instruct, **without SFT initialization**. It demonstrates that GRPO alone can learn computational graph generation, although SFT initialization yields better distractor robustness (see the ablation below).
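
The model's output is an executable graph serialized as JSON rather than free-form chain-of-thought. The official schema is defined by the DAGGER paper and code; purely as an illustration, with hypothetical node and field names, a graph for "100 pens at 5 taka each" might look like:

```python
# Hypothetical graph layout; the real DAGGER schema may differ.
graph = {
    "nodes": [
        {"id": "n1", "op": "const", "value": 100},        # number of pens
        {"id": "n2", "op": "const", "value": 5},          # price per pen (taka)
        {"id": "n3", "op": "mul", "args": ["n1", "n2"]},  # total cost = 500
    ],
    "answer": "n3",  # node whose value is the final answer
}
```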

## Highlights

- **Base → GRPO training** (no SFT phase)
- **Executable reward signal**: learns from format, execution, and correctness rewards
- **Ablation model**: isolates the contribution of SFT initialization

## Model Overview

| Attribute | Value |
|-----------|-------|
| Base Model | Gemma-3-12B-Instruct |
| Training | GRPO (from base) |
| Parameters | 12B |
| LoRA Rank | 64 |

## Performance

Accuracy (%) on each benchmark, on the original questions and on their distractor-augmented (+Distractor) variants; "Drop" is the absolute decrease.

| Dataset | Original | +Distractor | Drop |
|---------|----------|-------------|------|
| MGSM | 67.6 | 48.4 | 19.2 |
| MSVAMP | 75.0 | 59.6 | 15.4 |

### Ablation: Effect of SFT Initialization

Accuracy (%) on the distractor-augmented (+D) test sets:

| Initialization | MGSM (+D) | MSVAMP (+D) |
|----------------|-----------|-------------|
| Base → GRPO | 48.4 | 59.6 |
| **SFT → GRPO** | **64.0** (+15.6) | **66.8** (+7.2) |

**Key Finding**: SFT initialization provides scaffolding that stabilizes GRPO training and improves distractor robustness by 7–16 points.

## Quickstart

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "dipta007/dagger-12B_GRPO"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

# Bangla: "Mina has 100 pens. Each pen costs 5 taka."
question = "মিনার কাছে ১০০টি কলম আছে। প্রতিটি কলমের দাম ৫ টাকা।"

messages = [
    {"role": "system", "content": "You are an expert Bangla Math Reasoner. Solve by constructing a Computational Graph."},
    {"role": "user", "content": question},
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

outputs = model.generate(**inputs, max_new_tokens=1024)
# Decode only the newly generated tokens, skipping the prompt.
response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
print(response)
```
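
Once decoded, the response should contain a graph like the illustrative JSON above. A minimal executor for that hypothetical format, assuming nodes appear in topological order, might look as follows (the official executor ships with the DAGGER code):

```python
import json

def execute_graph(raw: str):
    """Evaluate the hypothetical graph format sketched in Model Description."""
    graph = json.loads(raw)
    values = {}
    for node in graph["nodes"]:  # assumes topological order
        if node["op"] == "const":
            values[node["id"]] = node["value"]
            continue
        args = [values[a] for a in node["args"]]
        if node["op"] == "add":
            values[node["id"]] = args[0] + args[1]
        elif node["op"] == "sub":
            values[node["id"]] = args[0] - args[1]
        elif node["op"] == "mul":
            values[node["id"]] = args[0] * args[1]
        elif node["op"] == "div":
            values[node["id"]] = args[0] / args[1]
        else:
            raise ValueError(f"unknown op: {node['op']}")
    return values[graph["answer"]]

# e.g. execute_graph(response) -> 500 for the pen-cost graph above
```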

## Training Configuration

| Parameter | Value |
|-----------|-------|
| Base Model | Gemma-3-12B-Instruct (no SFT) |
| LoRA Rank / Alpha | 64 / 128 |
| Global Batch Size | 32 |
| Generations per Prompt | 8 |
| Loss Type | BNPO |
| β / ε / ε_high | 0.0 / 0.2 / 0.28 |
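
This card does not state which trainer produced these runs. If reproducing with TRL's `GRPOTrainer`, the table maps roughly onto a config like the sketch below (the dataset split, the batch split across devices, and the `dagger_reward` name are assumptions, not confirmed details):

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

train_dataset = load_dataset("dipta007/dagger", split="train")  # split name assumed

peft_config = LoraConfig(r=64, lora_alpha=128, task_type="CAUSAL_LM")

args = GRPOConfig(
    output_dir="dagger-12B-grpo",
    num_generations=8,               # generations per prompt
    loss_type="bnpo",
    beta=0.0,                        # β: KL penalty weight
    epsilon=0.2,                     # ε: lower clipping range
    epsilon_high=0.28,               # ε_high: upper clipping range
    per_device_train_batch_size=4,   # illustrative split; only the
    gradient_accumulation_steps=8,   # global batch size (32) is known
)

trainer = GRPOTrainer(
    model="google/gemma-3-12b-it",
    args=args,
    train_dataset=train_dataset,
    reward_funcs=[dagger_reward],    # sketched in the Reward Function section below
    peft_config=peft_config,
)
trainer.train()
```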

**Reward Function** (a code sketch follows this list):
- Valid JSON: +0.5
- Successful execution: +0.5
- Correct answer: +1.0
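
A minimal per-sample sketch of this reward, assuming the three components are additive and reusing the hypothetical `execute_graph` from the Quickstart section (TRL expects a batched signature; the official implementation is in the DAGGER code):

```python
import json

def dagger_reward(completion: str, gold_answer: float) -> float:
    reward = 0.0
    try:
        json.loads(completion)              # valid JSON: +0.5
        reward += 0.5
    except json.JSONDecodeError:
        return reward
    try:
        answer = execute_graph(completion)  # successful execution: +0.5
        reward += 0.5
    except Exception:
        return reward
    if answer == gold_answer:               # correct answer: +1.0
        reward += 1.0
    return reward
```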

## When to Use This Model

- **Ablation studies**: understanding the contribution of SFT vs. GRPO
- **GRPO-only scenarios**: when SFT data is unavailable
- **Research**: studying policy optimization for structured generation

## Related Models

| Model | Training | MGSM (+D) | MSVAMP (+D) |
|-------|----------|-----------|-------------|
| **dagger-12B_GRPO** (this model) | Base → GRPO | 48.4 | 59.6 |
| [dagger-12B_SFT_GRPO](https://huggingface.co/dipta007/dagger-12B_SFT_GRPO) | SFT → GRPO | **64.0** | **66.8** |
| [dagger-12B_SFT](https://huggingface.co/dipta007/dagger-12B_SFT) | SFT only | 56.8 | 65.4 |

## Citation

```bibtex
will be updated
```