File size: 2,472 Bytes
9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 9adcdbb bab0696 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | ---
language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
- math
- korean
- grpo
- rl
- gemma
datasets:
- NotoriousH2/HRM8K
---
# Gemma-3-1B-IT Math GRPO
SFT โ RS-SFT โ GRPO 3๋จ๊ณ ํ์ดํ๋ผ์ธ์ผ๋ก ํ์ตํ ํ๊ตญ์ด ์ํ ๋ชจ๋ธ.
## ์ฑ๋ฅ
| Benchmark | Score |
|-----------|-------|
| HRM8K eval GSM8K (264๋ฌธ์ , Korean) | **~46.2%** |
| HRM8K eval MATH (577๋ฌธ์ , Korean) | ~16.5% |
> โ ๏ธ GRPO๋ base (RS-SFT, ~46.6%)๋๋น ์ ์๋ฏธํ ๊ฐ์ ์์. ์ด๋ฏธ SFT+RS-SFT๋ก ์ต์ ํ๋ 1B ๋ชจ๋ธ์์ RL ์ถ๊ฐ ์ฌ์ง๊ฐ ๊ทนํ ์ ํ์ .
## ๋ฐ์ดํฐ & ํ์ต ํ์ดํ๋ผ์ธ
### Stage 1-2: SFT โ RS-SFT
์ RS-SFT ๋ชจ๋ธ๊ณผ ๋์ผ. (SFT ๊ต์ฌ ์ฆ๋ฅ โ RS ์ํ๋ง โ RS-SFT with 5x replay)
### Stage 3: GRPO
#### Reward ํจ์
#### ํ์ต ๋ฐ์ดํฐ
- **ํ๋กฌํํธ๋ง ํ์** (GRPO๋ ํ์ต ์ค ์์ฒด ์์ฑ)
- GSM8K train 6,871๊ฐ unique ํ๊ตญ์ด ๋ฌธ์
- ๊ฐ ๋ฌธ์ ์ ground truth ๋ต ํฌํจ (reward ๊ณ์ฐ์ฉ)
#### ํ์ต ์ค์
## DPO vs GRPO ๋น๊ต (์คํ ๊ฒฐ๊ณผ)
### DPO ์คํจ ๋ถ์ (10ํ)
| ๋ฐ์ดํฐ ์ ๋ต | GSM8K | ๋ถ์ |
|------------|-------|------|
| ๊ธฐ์กด ๋ฐฉ์ (shortest correct + longest incorrect) | 48.1% | ๊ธธ์ด ํธํฅ๋ง ํ์ต |
| Length-matched (55์ ์ฐจ์ด) | 46.2% | ์ ํธ ์์ (DPO accโ0.50) |
| Teacher-chosen (30B ๊ต์ฌ ํ์ด=chosen) | 47.3% | OOD ๋ฌธ์ |
| Multi-pair (์ง๋ฌธ๋น 3์, ๋์ด๋ ๊ฐ์ค) | 46.6% | ์ ์ฆ๊ฐ๋ ๋ฌดํจ |
| base ๋๋น | ยฑ0-2%p | ๋ชจ๋ variance ๋ฒ์ |
**DPO ๊ทผ๋ณธ ๋ฌธ์ **: 1B ๋ชจ๋ธ์ด ์ ๋ต/์ค๋ต ํ์ด์ ๋ฏธ๋ฌํ ์ฐจ์ด๋ฅผ ๋ด๋ถ์ ์ผ๋ก ๊ตฌ๋ถํ capacity ๋ถ์กฑ.
### GRPO ๊ฒฐ๊ณผ (2ํ)
| beta | steps | GSM8K | ๋น๊ณ |
|------|-------|-------|------|
| 0.001 | 200 | 43.9% | format ํด๋ณด (boxedโ) |
| 0.04 | 500 | 46.2% | base์ ๋์ผ ์์ค |
## ํ๊ฒฝ
- GPU: H100 NVL 95GB
- Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0
- GRPO ํ์ต: ~55๋ถ (vLLM colocate ์ฌ์ฉ)
## ์ฌํ ๋ฐฉ๋ฒ
INFO 03-19 14:53:37 [__init__.py:216] Automatically detected platform cuda.
[1;36m(APIServer pid=3429210)[0;0m INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0
[1;36m(APIServer pid=3429210)[0;0m INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
## ํ์ผ
- : Stage 1 SFT
- : RS ์ํ๋ง
- : Stage 2 RS-SFT
- : Stage 3 GRPO
- : HRM8K ํ๊ฐ
|