---
language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
- math
- korean
- grpo
- rl
- gemma
datasets:
- NotoriousH2/HRM8K
---
# Gemma-3-1B-IT Math GRPO
A Korean math model trained with a three-stage pipeline: SFT → RS-SFT → GRPO.
## Performance
| Benchmark | Score |
|-----------|-------|
| HRM8K eval GSM8K (264 problems, Korean) | **~46.2%** |
| HRM8K eval MATH (577 problems, Korean) | ~16.5% |
> ⚠️ GRPO yields no meaningful improvement over the base (RS-SFT, ~46.6%). For a 1B model already optimized with SFT + RS-SFT, the headroom for additional RL is extremely limited.
## Data & Training Pipeline
### Stage 1-2: SFT → RS-SFT
Identical to the RS-SFT model's pipeline (SFT teacher distillation → RS sampling → RS-SFT with 5x replay).
### Stage 3: GRPO
#### Reward function
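The reward code itself is not reproduced in this card. As a purely illustrative sketch, a correctness reward for this setup might extract the final `\boxed{...}` answer from each completion and compare it with the ground-truth answer. The names `extract_boxed`, `normalize`, and `correctness_reward`, the 1.0/0.0 reward values, and the normalization rules are all assumptions, not the author's implementation.

```python
from typing import List, Optional

def extract_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    out = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out).strip()
        out.append(ch)
    return None  # braces never closed

def normalize(answer: str) -> str:
    # Hypothetical normalization: drop thousands separators, spaces, trailing period.
    return answer.replace(",", "").replace(" ", "").rstrip(".")

def correctness_reward(completions: List[str], answers: List[str]) -> List[float]:
    """1.0 when the boxed answer matches the ground truth, otherwise 0.0."""
    rewards = []
    for completion, answer in zip(completions, answers):
        pred = extract_boxed(completion)
        matched = pred is not None and normalize(pred) == normalize(answer)
        rewards.append(1.0 if matched else 0.0)
    return rewards
```

A completion without a `\boxed{...}` answer receives 0.0 in this sketch, so the reward also penalizes missing answer formatting.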
#### Training data
- **Only prompts are needed** (GRPO generates its own completions during training)
- 6,871 unique Korean problems from the GSM8K train split
- Each problem includes its ground-truth answer (used to compute the reward)
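To illustrate why prompts and ground-truth answers suffice: GRPO samples several completions per prompt, scores each one with the reward function, and normalizes the rewards within that group to obtain advantages. The sketch below shows only that normalization step. The group size, the epsilon value, and the use of the population standard deviation are assumptions; library implementations differ in such details.

```python
from statistics import mean, pstdev
from typing import List

def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalize rewards within one prompt's group of sampled completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One prompt, four sampled completions, binary correctness rewards.
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))

# If every completion in a group gets the same reward, all advantages are zero
# and that prompt contributes no learning signal.
print(group_relative_advantages([0.0, 0.0, 0.0, 0.0]))
```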
#### Training configuration
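The full hyperparameter list is not reproduced here. The values this card does report for the final run (beta 0.04, 500 steps, vLLM in colocate mode) could be collected as keyword arguments for TRL's `GRPOConfig`, as sketched below. The `output_dir` default is an assumption based on the model path in the vLLM log, and every setting not shown is left unspecified rather than guessed.

```python
# Settings taken from this card: the second GRPO run and the environment notes.
grpo_kwargs = {
    "beta": 0.04,             # KL coefficient of the run that matched the base model
    "max_steps": 500,         # number of training steps for that run
    "use_vllm": True,         # generate rollouts with vLLM
    "vllm_mode": "colocate",  # vLLM shares the training GPU
}

def build_config(output_dir: str = "./grpo_model"):
    """Build a GRPOConfig from the reported settings (output_dir is assumed)."""
    # Imported lazily so the dictionary above can be inspected without TRL installed.
    from trl import GRPOConfig
    return GRPOConfig(output_dir=output_dir, **grpo_kwargs)
```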
## DPO vs GRPO comparison (experimental results)
### DPO failure analysis (10 runs)
| Data strategy | GSM8K | Analysis |
|------------|-------|------|
| Conventional (shortest correct + longest incorrect) | 48.1% | Learns only a length bias |
| Length-matched (55-character difference) | 46.2% | No signal (DPO acc≈0.50) |
| Teacher-chosen (30B teacher solution = chosen) | 47.3% | OOD issue |
| Multi-pair (3 pairs per question, difficulty-weighted) | 46.6% | More data does not help either |
| vs. base | ±0-2%p | All within variance |
**Root problem with DPO**: the 1B model lacks the capacity to internally distinguish the subtle differences between correct and incorrect solutions.
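As an illustration of the first two data strategies in the table, the sketch below builds one preference pair per question from sampled solutions. The `Sample` structure and the function names are hypothetical, and reading the 55-character figure as a maximum allowed length gap is an assumption. Only the selection rules themselves (shortest correct vs. longest incorrect, and a length-matched pair) come from the table.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Optional, Tuple

@dataclass
class Sample:
    text: str      # one sampled solution
    correct: bool  # whether its final answer matched the ground truth

def conventional_pair(samples: List[Sample]) -> Optional[Tuple[str, str]]:
    """Chosen = shortest correct solution, rejected = longest incorrect one."""
    good = [s for s in samples if s.correct]
    bad = [s for s in samples if not s.correct]
    if not good or not bad:
        return None  # a pair needs at least one solution of each kind
    chosen = min(good, key=lambda s: len(s.text))
    rejected = max(bad, key=lambda s: len(s.text))
    return chosen.text, rejected.text

def length_matched_pair(samples: List[Sample], max_gap: int = 55) -> Optional[Tuple[str, str]]:
    """Correct/incorrect pair with the smallest length gap, if within max_gap characters."""
    good = [s for s in samples if s.correct]
    bad = [s for s in samples if not s.correct]
    if not good or not bad:
        return None
    chosen, rejected = min(
        product(good, bad),
        key=lambda pair: abs(len(pair[0].text) - len(pair[1].text)),
    )
    if abs(len(chosen.text) - len(rejected.text)) > max_gap:
        return None
    return chosen.text, rejected.text
```

The conventional rule makes length a reliable shortcut for telling chosen from rejected, which is consistent with the "learns only a length bias" observation. The length-matched rule removes that shortcut.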
### GRPO results (2 runs)
| beta | steps | GSM8K | Notes |
|------|-------|-------|------|
| 0.001 | 200 | 43.9% | Format regression (boxed ↓) |
| 0.04 | 500 | 46.2% | On par with base |
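To put differences of this size in context, one can estimate the sampling noise of an accuracy measured on 264 problems. The sketch below uses the binomial standard error, which assumes independent problems and ignores run-to-run training variance. It is a rough check, not the author's analysis.

```python
from math import sqrt

def accuracy_standard_error(accuracy: float, n_problems: int) -> float:
    """Binomial standard error of an accuracy estimated from n_problems items."""
    return sqrt(accuracy * (1.0 - accuracy) / n_problems)

n = 264  # HRM8K eval GSM8K problems
se = accuracy_standard_error(0.462, n)
print(f"one problem = {100 / n:.2f}%p, standard error = {100 * se:.1f}%p")
```

Under these assumptions the standard error is roughly 3%p, so gaps of 0-2%p between runs are consistent with noise on an evaluation set of this size.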
## Environment
- GPU: H100 NVL 95GB
- Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0
- GRPO training: ~55 minutes (with vLLM colocate)
## How to reproduce
    INFO 03-19 14:53:37 [__init__.py:216] Automatically detected platform cuda.
    (APIServer pid=3429210) INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0
    (APIServer pid=3429210) INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
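The commands for this section are not preserved; only the vLLM startup log above remains. The non-default arguments in that log correspond to standard `vllm serve` options, so the serving step can be reconstructed roughly as sketched below. The helper name `vllm_serve_command` is hypothetical, and the training and evaluation commands cannot be recovered from the log.

```python
import shlex

# Non-default arguments reported in the vLLM log.
logged_args = {
    "model": "./grpo_model",
    "dtype": "bfloat16",
    "max_model_len": 4096,
    "gpu_memory_utilization": 0.85,
}

def vllm_serve_command(args: dict) -> str:
    """Turn logged argument names into a `vllm serve` command line."""
    parts = ["vllm", "serve", str(args["model"])]
    for key, value in args.items():
        if key == "model":
            continue  # already passed as the positional model argument
        parts += ["--" + key.replace("_", "-"), str(value)]
    return shlex.join(parts)

print(vllm_serve_command(logged_args))
```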
## Files
- : Stage 1 SFT
- : RS sampling
- : Stage 2 RS-SFT
- : Stage 3 GRPO
- : HRM8K evaluation