---
library_name: transformers
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- math
- reasoning
- computational-graph
- bangla
- low-resource
- distractor-aware
base_model:
- google/gemma-3-12b-it
language:
- bn
- en
datasets:
- dipta007/dagger
- dipta007/DistractMath-Bn
model-index:
- name: dagger-12B_SFT_GRPO
  results:
  - task:
      type: question-answering
      name: Math Word Problems
    dataset:
      name: MGSM-BN
      type: mgsm
    metrics:
    - type: accuracy
      value: 78.4
      name: Original Accuracy
    - type: accuracy
      value: 64.0
      name: Distractor Accuracy
  - task:
      type: question-answering
      name: Math Word Problems
    dataset:
      name: MSVAMP-BN
      type: msvamp
    metrics:
    - type: accuracy
      value: 78.8
      name: Original Accuracy
    - type: accuracy
      value: 66.8
      name: Distractor Accuracy
---

# DAGGER-12B-SFT-GRPO

<a href="https://arxiv.org/abs/2601.06853" target="_blank">
    <img alt="arXiv" src="https://img.shields.io/badge/arXiv-2601.06853-b31b1b" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://github.com/dipta007/dagger" target="_blank">
    <img alt="GitHub" src="https://img.shields.io/badge/GitHub-Code-black" style="display: inline-block; vertical-align: middle;"/>
</a>
<a href="https://huggingface.co/datasets/dipta007/DistractMath-Bn" target="_blank">
    <img alt="Dataset" src="https://img.shields.io/badge/Dataset-DistractMath--BN-green" style="display: inline-block; vertical-align: middle;"/>
</a>

## Highlights

**DAGGER-12B-SFT-GRPO** is our best-performing model for distractor-aware mathematical reasoning in Bangla. Key features:

- **89% fewer tokens** than reasoning models while achieving comparable accuracy
- **Robust to distractors**: Only 12-14 point accuracy drop under distractor augmentation (vs. 14-20 for reasoning models, 18-41 for standard CoT)
- **Executable outputs**: Generates computational graphs that can be deterministically executed
- **Explicit distractor modeling**: Identifies irrelevant information as distractor nodes

## Model Overview

| Attribute | Value |
|-----------|-------|
| Base Model | Gemma-3-12B-Instruct |
| Training | SFT → GRPO |
| Parameters | 12B |
| LoRA Rank | 64 |
| Max Sequence Length | 4096 |
| Output Format | JSON Computational Graph |

## Performance

### Accuracy Comparison

| Model | MGSM | MSVAMP | MGSM (+D) | MSVAMP (+D) | Weighted Avg | Tokens |
|-------|------|--------|-----------|-------------|--------------|--------|
| Qwen 3-8B (Reasoning) | 88.0 | 81.1 | 70.5 | 66.9 | 71.4 | 3,128 |
| **DAGGER-12B (Ours)** | **78.4** | **78.8** | **64.0** | **66.8** | **69.4** | **359** |
| Gemma 3-12B (CoT) | 76.8 | 72.3 | 54.3 | 48.7 | 55.7 | 599 |

(+D) = with distractors
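
As a quick check, the accuracy drops and token savings can be recomputed directly from the table above. The snippet below is only a worked calculation; the dictionary and function names are illustrative and not part of any released code.

```python
# Figures copied from the accuracy table above
dagger = {"mgsm": 78.4, "msvamp": 78.8, "mgsm_d": 64.0, "msvamp_d": 66.8, "tokens": 359}
qwen_reasoning = {"mgsm": 88.0, "msvamp": 81.1, "mgsm_d": 70.5, "msvamp_d": 66.9, "tokens": 3128}
gemma_cot = {"mgsm": 76.8, "msvamp": 72.3, "mgsm_d": 54.3, "msvamp_d": 48.7, "tokens": 599}


def distractor_drop(row):
    """Accuracy points lost when distractors are added, per benchmark."""
    return {
        "mgsm": round(row["mgsm"] - row["mgsm_d"], 1),
        "msvamp": round(row["msvamp"] - row["msvamp_d"], 1),
    }


def token_reduction(ours, baseline):
    """Percentage of tokens saved relative to a baseline row."""
    return round(100 * (1 - ours["tokens"] / baseline["tokens"]), 1)


print(distractor_drop(dagger))                  # {'mgsm': 14.4, 'msvamp': 12.0}
print(distractor_drop(gemma_cot))               # {'mgsm': 22.5, 'msvamp': 23.6}
print(token_reduction(dagger, qwen_reasoning))  # 88.5
```

The token saving against Qwen 3-8B comes out at about 88.5%, which rounds to the 89% quoted in the Highlights.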

### Robustness (Error Rate by Distractor Type)

| Distractor Type | Error Rate |
|-----------------|------------|
| Related Entity (RED) | 36% |
| Orthogonal Attribute (OAD) | 34% |
| Null-Effect Event (NEED) | 33% |

## Output Format

The model generates computational graphs in JSON format:

```json
{
  "nodes": [
    {"id": "n1", "op": "const", "val": 122195, "distractor": false, "label": "মিনার কলম"},
    {"id": "n2", "op": "const", "val": 25084, "distractor": true, "label": "রাজুর কলম"},
    {"id": "n3", "op": "const", "val": 45.6, "distractor": false, "label": "প্রতিটি কলমের দাম"},
    {"id": "total", "op": "mul", "args": ["n1", "n3"], "distractor": false, "label": "মোট টাকা"},
    {"id": "final_result", "op": "identity", "args": ["total"], "distractor": false}
  ]
}
```

**Supported Operations**: `const`, `add`, `sub`, `mul`, `div`, `sum`, `mean`, `min`, `max`, `floor`, `ceil`, `round`, `sqrt`, `pow`, `mod`, `gcd`, `lcm`, `identity`
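
Before executing a generated graph, it can help to check it against the structural rules stated in the Quickstart prompt: nodes topologically sorted and exactly one `final_result` node. The following is a minimal sketch of such a check, not part of the released code; `validate_graph` and `SUPPORTED_OPS` are hypothetical names, and the checks are assumptions about what a well-formed graph looks like.

```python
SUPPORTED_OPS = {
    "const", "add", "sub", "mul", "div", "sum", "mean", "min", "max",
    "floor", "ceil", "round", "sqrt", "pow", "mod", "gcd", "lcm", "identity",
}


def validate_graph(graph):
    """Return a list of problems found in a graph dict (empty list means none found)."""
    problems = []
    seen = set()
    nodes = graph.get("nodes", [])
    for node in nodes:
        node_id, op = node.get("id"), node.get("op")
        if node_id in seen:
            problems.append(f"duplicate id: {node_id}")
        if op not in SUPPORTED_OPS:
            problems.append(f"{node_id}: unsupported op {op!r}")
        if op == "const":
            if not isinstance(node.get("val"), (int, float)):
                problems.append(f"{node_id}: const node needs a numeric val")
        else:
            if not node.get("args"):
                problems.append(f"{node_id}: non-const node needs args")
            for arg in node.get("args", []):
                # String args are node references; in a topologically
                # sorted graph they must already have been defined.
                if isinstance(arg, str) and arg not in seen:
                    problems.append(f"{node_id}: arg {arg!r} used before definition")
        seen.add(node_id)
    finals = [n for n in nodes if n.get("id") == "final_result"]
    if len(finals) != 1:
        problems.append(f"expected exactly one final_result node, found {len(finals)}")
    return problems
```

An empty return value means the graph passed these particular checks; it does not guarantee that the graph computes the right answer.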

## Quickstart

### Using Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "dipta007/dagger-12B_SFT_GRPO"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

USER_PROMPT_TEMPLATE = """You are an expert Bengali Math Reasoner. Your task is to solve mathematical problems by constructing a "Computational Graph".

### Graph Rules:
- `id`: Unique identifier (e.g., "n1", "n2").
- `val`: The raw number extracted from text (for input nodes).
- `op`: The operation (`add`, `sub`, `mul`, `div`, `round`, `sqrt`, `floor`, `sum`, `mean`, `ratio_split`). Use `const` for input numbers.
- `args`: List of input node IDs.
- `distractor`: Boolean (`true` / `false`). Set to `true` if the node is NOT used in the final calculation path.
- `label`: Label for the node.

### Available Operations:
- Input: `const` (Use this for all numbers found in text or constants).
- Arithmetic: `add`, `sub`, `mul`, `div`, `abs` (absolute difference).
- Logic/Stats: `sum`, `mean`, `min` (minimum), `max` (maximum).
- Rounding: `round` (nearest int), `floor` (round down), `ceil` (round up).
- Advanced: `sqrt`, `pow`, `mod` (remainder), `gcd`, `lcm`.
- Output: `identity` ("final_result" points to the answer node)

Only output a JSON graph representing the solution, nothing else. Nodes must be topologically sorted, and there must be exactly one "final_result" node that represents the final answer. One example is provided below.

### Example:
Question:
মিনার কাছে ১২২১৯৫ টা কলম আছে। রাজুর কাছে ২৫০৮৪ টা কলম আছে। মিনা রাজুর কাছে ১১২৬ টি কলম চাইল। রাজু ১০০০ টি কলম দিতে রাজি হল, কিন্তু পরে আর দিলেনা। প্রতিটি কলমের দাম ৪৫.৬ টাকা। মিনা যদি কলমগুলো বিক্রি করতে চায়, সে কত টাকা পাবে?

Output:
```json
{{
  "nodes": [
    {{"id": "n1", "op": "const", "val": 122195, "distractor": false, "label": "মিনার কলম"}},
    {{"id": "n2", "op": "const", "val": 25084, "distractor": true, "label": "রাজুর কলম"}},
    {{"id": "n3", "op": "const", "val": 1126, "distractor": true, "label": "মিনা রাজুর কাছে চাইল"}},
    {{"id": "n4", "op": "const", "val": 1000, "distractor": true, "label": "রাজু দিতে রাজি হল"}},
    {{"id": "n5", "op": "const", "val": 45.6, "distractor": false, "label": "প্রতিটি কলমের দাম"}},
    {{"id": "total_money", "op": "mul", "args": ["n1", "n5"], "distractor": false, "label": "মিনার মোট টাকা"}},
    {{"id": "final_result", "op": "identity", "args": ["total_money"], "distractor": false, "label": "চূড়ান্ত উত্তর"}}
  ]
}}```

### Your Task:

Question:
{question}

Output:
"""

question = "রজারের 5টি টেনিস বল আছে। সে আরও 2 ক্যান টেনিস বল কিনেছে। প্রতিটি ক্যানে 3টি করে টেনিস বল আছে। তার কাছে এখন কতগুলি টেনিস বল আছে?"
prompt = USER_PROMPT_TEMPLATE.format(question=question)

messages = [
  {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Generate; pass do_sample=True explicitly so temperature/top_p are applied
outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.8)
response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)

print(response)
```

### Using vLLM

```bash
vllm serve dipta007/dagger-12B_SFT_GRPO --max-model-len 4096
```

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="dipta007/dagger-12B_SFT_GRPO",
    messages=[
        {"role": "system", "content": "You are an expert Bangla Math Reasoner..."},
        {"role": "user", "content": "মিনার কাছে ১০০টি কলম আছে..."}
    ],
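    max_tokens=1024
)

# The generated graph is in the first choice's message content
print(response.choices[0].message.content)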
```

### Graph Execution

```python
import json

def execute_graph(graph_json):
    """Execute a computational graph and return the final result."""
    nodes = {n["id"]: n for n in graph_json["nodes"]}
    cache = {}

    def compute(node_id):
        if node_id in cache:
            return cache[node_id]

        node = nodes[node_id]
        op = node["op"]

        if op == "const":
            result = node["val"]
        elif op == "add":
            result = sum(compute(arg) if isinstance(arg, str) else arg for arg in node["args"])
        elif op == "sub":
            args = [compute(arg) if isinstance(arg, str) else arg for arg in node["args"]]
            result = args[0] - args[1]
        elif op == "mul":
            result = 1
            for arg in node["args"]:
                result *= compute(arg) if isinstance(arg, str) else arg
        elif op == "div":
            args = [compute(arg) if isinstance(arg, str) else arg for arg in node["args"]]
            result = args[0] / args[1]
        elif op == "identity":
            result = compute(node["args"][0])
        # ... add other operations
        else:
            raise ValueError(f"Unsupported operation: {op}")

        cache[node_id] = result
        return result

    return compute("final_result")

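# Parse and execute (strip a Markdown code fence if the model emitted one)
raw = response.strip()
if raw.startswith("```"):
    raw = raw.strip("`").removeprefix("json").strip()  # removeprefix needs Python 3.9+
graph = json.loads(raw)
answer = execute_graph(graph)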
print(f"Answer: {answer}")
```
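
The executor above handles six operations. A table-driven variant covering the rest of the Supported Operations list might look like the sketch below. Several details are assumptions rather than documented semantics: the argument order for `sub`, `div`, `pow`, and `mod` (first argument operated on by the second), integer coercion for `gcd` and `lcm`, and the use of Python's built-in `round`, which rounds halves to the nearest even integer. `OPS` and `run_graph` are hypothetical names.

```python
import math
from functools import reduce


def _binary(fn):
    """Wrap a two-argument function so it takes a list of evaluated args."""
    return lambda xs: fn(xs[0], xs[1])


def _lcm(a, b):
    return a * b // math.gcd(a, b)


# Each entry maps an op name to a function over the list of evaluated args.
OPS = {
    "add": sum,
    "sum": sum,
    "sub": _binary(lambda a, b: a - b),
    "mul": math.prod,
    "div": _binary(lambda a, b: a / b),
    "mean": lambda xs: sum(xs) / len(xs),
    "min": min,
    "max": max,
    "floor": lambda xs: math.floor(xs[0]),
    "ceil": lambda xs: math.ceil(xs[0]),
    "round": lambda xs: round(xs[0]),
    "sqrt": lambda xs: math.sqrt(xs[0]),
    "pow": _binary(lambda a, b: a ** b),
    "mod": _binary(lambda a, b: a % b),
    "gcd": lambda xs: reduce(math.gcd, (int(x) for x in xs)),
    "lcm": lambda xs: reduce(_lcm, (int(x) for x in xs)),
    "identity": lambda xs: xs[0],
}


def run_graph(graph):
    """Evaluate only the nodes that final_result depends on."""
    nodes = {n["id"]: n for n in graph["nodes"]}
    cache = {}

    def value(ref):
        # Non-string args are treated as literal numbers, as in the snippet above
        if not isinstance(ref, str):
            return ref
        if ref not in cache:
            node = nodes[ref]
            if node["op"] == "const":
                cache[ref] = node["val"]
            else:
                cache[ref] = OPS[node["op"]]([value(a) for a in node["args"]])
        return cache[ref]

    return value("final_result")
```

Because evaluation starts from `final_result` and follows `args`, distractor nodes are never computed. An unknown operation raises `KeyError` instead of silently returning a wrong value.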

## Training Details

### Stage 1: Supervised Fine-Tuning (SFT)

| Parameter | Value |
|-----------|-------|
| Base Model | Gemma-3-12B-Instruct |
| LoRA Rank / Alpha | 64 / 128 |
| Global Batch Size | 256 |
| Epochs | 4 |
| Learning Rate | 1e-5 → 1e-6 (cosine) |
| Training Data | 3,000 examples |
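
To make the learning-rate row concrete, a cosine decay from 1e-5 to 1e-6 can be sketched as below. The step count is derived under two assumptions that the table does not state: one optimizer step per global batch (with the final partial batch kept) and no warmup. `cosine_lr` is a hypothetical helper, not the training code.

```python
import math

LR_MAX, LR_MIN = 1e-5, 1e-6                      # from the table above
EXAMPLES, GLOBAL_BATCH, EPOCHS = 3000, 256, 4    # from the table above

# Assumption: one optimizer step per global batch, partial final batch kept
TOTAL_STEPS = math.ceil(EXAMPLES / GLOBAL_BATCH) * EPOCHS


def cosine_lr(step, total_steps=TOTAL_STEPS, lr_max=LR_MAX, lr_min=LR_MIN):
    """Cosine decay from lr_max at step 0 to lr_min at total_steps (no warmup assumed)."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


print(TOTAL_STEPS)  # 48
for step in (0, TOTAL_STEPS // 2, TOTAL_STEPS):
    print(step, cosine_lr(step))
```

Under these assumptions the schedule starts at 1e-5, passes through roughly 5.5e-6 at the midpoint, and ends at 1e-6.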

### Stage 2: Group Relative Policy Optimization (GRPO)

| Parameter | Value |
|-----------|-------|
| Base Model | SFT Checkpoint |
| LoRA Rank / Alpha | 64 / 128 |
| Global Batch Size | 32 |
| Generations per Prompt | 8 |
| Epochs | 4 |
| Loss Type | BNPO |
| β / ε / ε_high | 0.0 / 0.2 / 0.28 |

**Reward Function:**
```
R(g, y) = 0.5 * I_fmt + 0.5 * I_exec + I_acc(exec(g), y)
```
- `I_fmt`: Valid JSON format (+0.5)
- `I_exec`: Successful execution (+0.5)
- `I_acc`: Correct answer (+1.0)
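
The reward formula above can be read as three checks applied in sequence. The sketch below is one possible reading, not the training implementation: `graph_reward` is a hypothetical name, `execute` stands for any graph executor (for example the `execute_graph` function from the Quickstart), and the relative tolerance used to compare answers is an assumption.

```python
import json
import math


def graph_reward(completion, gold, execute, rel_tol=1e-6):
    """R = 0.5 * I_fmt + 0.5 * I_exec + I_acc, following the formula above."""
    total = 0.0

    # I_fmt: the completion parses as JSON
    try:
        graph = json.loads(completion)
    except (json.JSONDecodeError, TypeError):
        return total
    total += 0.5

    # I_exec: the graph executes without raising
    try:
        predicted = execute(graph)
    except Exception:
        return total
    total += 0.5

    # I_acc: the executed result matches the gold answer (tolerance is an assumption)
    if isinstance(predicted, (int, float)) and math.isclose(predicted, gold, rel_tol=rel_tol):
        total += 1.0
    return total
```

With this reading, a completion earns 0.0 for invalid JSON, 0.5 for valid JSON that fails to execute, 1.0 for an executable graph with a wrong answer, and 2.0 for a correct one.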

## Best Practices

1. **Temperature**: Use `temperature=0.7` with `top_p=0.8` for best results
2. **Max Tokens**: 1024 tokens is sufficient for most problems
3. **System Prompt**: Include the graph-generation instructions in the system message
4. **Post-processing**: Parse the JSON and execute the graph to obtain the final numeric answer

## Limitations

- Designed for arithmetic word problems; may not generalize to algebra, geometry, or calculus
- Primarily trained on Bangla; English performance not evaluated
- Requires JSON parsing and graph execution for final answers
- The 4B variant scores much lower (47.3 vs. 69.4 weighted average), suggesting that the approach needs sufficient model capacity

## Related Models

| Model | Training | Weighted Avg |
|-------|----------|--------------|
| [dagger-12B_SFT_GRPO](https://huggingface.co/dipta007/dagger-12B_SFT_GRPO) | SFT → GRPO | **69.4** |
| [dagger-12B_SFT](https://huggingface.co/dipta007/dagger-12B_SFT) | SFT only | 66.7 |
| [dagger-12B_GRPO](https://huggingface.co/dipta007/dagger-12B_GRPO) | Base → GRPO | 69.4 |
| [dagger-4B_SFT_GRPO](https://huggingface.co/dipta007/dagger-4B_SFT_GRPO) | SFT → GRPO | 47.3 |

## Citation

```bibtex
@misc{nazi2026dagdaggerdistractorawaregraphgeneration,
      title={{\dag}DAGGER: Distractor-Aware Graph Generation for Executable Reasoning in Math Problems}, 
      author={Zabir Al Nazi and Shubhashis Roy Dipta and Sudipta Kar},
      year={2026},
      eprint={2601.06853},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.06853}, 
}
```

## Acknowledgments

- [Google Gemma](https://ai.google.dev/gemma) for the base model
- [Unsloth](https://github.com/unslothai/unsloth) for efficient fine-tuning
- [TRL](https://github.com/huggingface/trl) for GRPO implementation