---
language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
- math
- korean
- grpo
- rl
- gemma
datasets:
- NotoriousH2/HRM8K
---

# Gemma-3-1B-IT Math GRPO

A Korean math model trained with a three-stage pipeline: SFT → RS-SFT → GRPO.

## Performance

| Benchmark | Score |
|-----------|-------|
| HRM8K eval GSM8K (264 problems, Korean) | **~46.2%** |
| HRM8K eval MATH (577 problems, Korean) | ~16.5% |

> ⚠️ GRPO gives no meaningful improvement over the base model (RS-SFT, ~46.6%). For a 1B model already optimized with SFT + RS-SFT, the room for further gains from RL is extremely limited.

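To put the gap between ~46.2% and ~46.6% in perspective, it helps to estimate the sampling noise of an accuracy measured on only 264 problems. The sketch below is not part of the original evaluation code. It assumes each problem is an independent pass/fail trial (a binomial model), and the helper name `accuracy_standard_error` is hypothetical.

```python
import math

def accuracy_standard_error(accuracy: float, n_problems: int) -> float:
    """Binomial standard error of an accuracy estimate (assumes independent problems)."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_problems)

# Scores reported in this card for the 264-problem GSM8K subset.
grpo_acc, base_acc, n = 0.462, 0.466, 264

se = accuracy_standard_error(grpo_acc, n)
gap = abs(base_acc - grpo_acc)
print(f"standard error: {se * 100:.1f} pp, gap: {gap * 100:.1f} pp")
```

Under this assumption the standard error is roughly 3 percentage points, while one problem out of 264 is worth about 0.4 points, so the reported gap is about the size of a single problem and well inside the noise.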

## Data & Training Pipeline

### Stage 1-2: SFT → RS-SFT
Same as the RS-SFT model. (SFT teacher distillation → RS sampling → RS-SFT with 5x replay)

### Stage 3: GRPO

#### Reward function

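As a rough illustration only, a correctness reward for this kind of setup might extract the final `\boxed{...}` answer from a completion and compare it with the ground truth. Everything below is an assumption: the function names, the 1.0/0.0 reward values, and the numeric normalization are hypothetical, and this is not the reward actually used to train this model.

```python
from typing import Optional

def extract_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in text, or None if absent."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    i = start + len(marker)
    depth, chars = 1, []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:  # matching close of \boxed{
                return "".join(chars).strip()
        chars.append(ch)
        i += 1
    return None  # unbalanced braces

def normalize_number(answer: str) -> Optional[float]:
    """Parse a numeric answer, ignoring thousands separators and spaces."""
    cleaned = answer.replace(",", "").replace(" ", "")
    try:
        return float(cleaned)
    except ValueError:
        return None

def correctness_reward(completion: str, ground_truth: str) -> float:
    """1.0 if the boxed answer matches the ground truth, else 0.0 (hypothetical scheme)."""
    predicted = extract_boxed(completion)
    if predicted is None:
        return 0.0  # no boxed answer at all
    pred_num = normalize_number(predicted)
    gold_num = normalize_number(ground_truth)
    if pred_num is None or gold_num is None:
        # Fall back to exact string comparison for non-numeric answers.
        return 1.0 if predicted == ground_truth.strip() else 0.0
    return 1.0 if abs(pred_num - gold_num) < 1e-6 else 0.0

print(correctness_reward("The answer is \\boxed{1,234}.", "1234"))
```

A reward of this shape gives no credit to completions without a boxed answer, which is one common way to tie answer format and correctness together.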

#### Training data
- **Prompts only** (GRPO generates its own completions during training)
- 6,871 unique Korean problems from the GSM8K train split
- Each problem includes its ground-truth answer (used to compute the reward)

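Because GRPO samples several completions per prompt and scores each one against the ground truth, the learning signal comes from how a completion compares with the others in its group. The following is a minimal sketch of that group-relative advantage in its commonly described form (reward minus group mean, divided by group standard deviation). The function name and the `eps` constant are illustrative assumptions, and TRL's actual implementation may differ in details such as the standard-deviation estimator.

```python
from statistics import mean, pstdev
from typing import List

def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalize the rewards of completions sampled for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation over the group
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four sampled completions for one prompt: two correct (1.0), two incorrect (0.0).
mixed = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
# Every completion correct: all advantages are zero for this prompt.
solved = group_relative_advantages([1.0, 1.0, 1.0, 1.0])
print(mixed, solved)
```

When every completion in a group receives the same reward, all advantages are zero, so prompts the model already solves (or fails) consistently contribute no signal under this formulation.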

#### Training configuration

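The only hyperparameters this card states are `beta` (0.04 for the run that matches the reported score), the number of steps (500), and the use of vLLM in colocate mode. The fragment below collects just those values as keyword arguments in the style of TRL's `GRPOConfig`. The argument names are assumed to match recent TRL releases and should be checked against the installed version, and `output_dir` is a placeholder borrowed from the model path in the vLLM server log. Settings the card does not state (learning rate, group size, completion length, and so on) are left out rather than guessed.

```python
# Values taken from this card; argument names assumed to match trl.GRPOConfig.
grpo_kwargs = {
    "output_dir": "./grpo_model",  # placeholder: path seen in the vLLM server log
    "beta": 0.04,                  # KL coefficient of the run scoring 46.2%
    "max_steps": 500,              # number of training steps of that run
    "use_vllm": True,              # generate rollouts with vLLM
    "vllm_mode": "colocate",       # run vLLM inside the training process
}

# With TRL installed, this would be passed as: GRPOConfig(**grpo_kwargs)
print(grpo_kwargs)
```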

## DPO vs GRPO Comparison (Experimental Results)

### DPO failure analysis (10 runs)
| Data strategy | GSM8K | Analysis |
|---------------|-------|----------|
| Conventional (shortest correct + longest incorrect) | 48.1% | Learns only a length bias |
| Length-matched (55-character difference) | 46.2% | No signal (DPO acc≈0.50) |
| Teacher-chosen (30B teacher solution = chosen) | 47.3% | OOD problem |
| Multi-pair (3 pairs per question, difficulty-weighted) | 46.6% | More data does not help either |
| vs. base | ±0-2 pp | All within the variance range |

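The "DPO acc≈0.50" entry most likely refers to the preference accuracy that is commonly logged during DPO training: the fraction of pairs where the implicit reward of the chosen response exceeds that of the rejected one. A value near 0.50 means the policy ranks pairs no better than chance. The sketch below computes the standard DPO loss and this accuracy from summed sequence log-probabilities. The function name, the `beta=0.1` default, and the toy numbers are assumptions for illustration, not values from these experiments.

```python
import math
from typing import List, Tuple

def dpo_loss_and_accuracy(
    policy_chosen: List[float], policy_rejected: List[float],
    ref_chosen: List[float], ref_rejected: List[float],
    beta: float = 0.1,
) -> Tuple[float, float]:
    """Mean DPO loss and preference accuracy from sequence log-probabilities."""
    losses, wins = [], 0
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        chosen_reward = beta * (pc - rc)    # implicit reward of the chosen answer
        rejected_reward = beta * (pr - rr)  # implicit reward of the rejected answer
        margin = chosen_reward - rejected_reward
        losses.append(math.log1p(math.exp(-margin)))  # equals -log(sigmoid(margin))
        wins += chosen_reward > rejected_reward
    n = len(losses)
    return sum(losses) / n, wins / n

# Two toy pairs: the policy prefers the chosen answer in the first pair only.
loss, acc = dpo_loss_and_accuracy(
    policy_chosen=[-10.0, -12.0], policy_rejected=[-11.0, -11.0],
    ref_chosen=[-10.5, -11.5], ref_rejected=[-10.5, -11.5],
)
print(f"loss={loss:.4f}, accuracy={acc:.2f}")
```

With margins this small the loss stays close to log 2 ≈ 0.693, the value obtained when the policy does not separate chosen from rejected at all.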

**Root problem with DPO**: the 1B model lacks the capacity to internally distinguish the subtle differences between correct and incorrect solutions.

### GRPO results (2 runs)
| beta | steps | GSM8K | Notes |
|------|-------|-------|-------|
| 0.001 | 200 | 43.9% | Format regression (boxed ↓) |
| 0.04 | 500 | 46.2% | On par with base |

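The `beta` column is the weight of the KL penalty that keeps the policy close to the reference model. For orientation, the GRPO objective as it is commonly written (following the DeepSeekMath formulation) is given below. The notation is not from this card, and TRL's implementation may normalize or aggregate the loss differently.

```latex
\mathcal{J}_{\mathrm{GRPO}}(\theta)
= \mathbb{E}\!\left[
\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}
\Big(
\min\!\big(r_{i,t}\,\hat{A}_{i},\;
\operatorname{clip}(r_{i,t},\,1-\epsilon,\,1+\epsilon)\,\hat{A}_{i}\big)
- \beta\, D_{\mathrm{KL}}\!\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]
\Big)\right],
\qquad
r_{i,t} = \frac{\pi_\theta(o_{i,t}\mid q,\, o_{i,<t})}
               {\pi_{\theta_{\mathrm{old}}}(o_{i,t}\mid q,\, o_{i,<t})}
```

Here $G$ is the number of completions $o_i$ sampled for a prompt $q$ and $\hat{A}_i$ is the group-relative advantage. A very small $\beta$ such as 0.001 leaves the policy almost unconstrained, which is at least consistent with the format regression seen in the first run; $\beta = 0.04$ holds it closer to the RS-SFT base.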

## Environment
- GPU: H100 NVL 95GB
- Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0
- GRPO training: ~55 min (with vLLM colocate)

## Reproduction

    INFO 03-19 14:53:37 [__init__.py:216] Automatically detected platform cuda.
    (APIServer pid=3429210) INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0
    (APIServer pid=3429210) INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}

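The log records the non-default arguments the server was started with. Assuming the server was launched through the `vllm serve` CLI (which the `model_tag` entry suggests), the matching command line can be rebuilt from that dictionary. The helper name below is hypothetical, and the script only prints the command; it does not start a server.

```python
import shlex

# Non-default args copied from the server log above.
logged_args = {
    "model_tag": "./grpo_model",
    "model": "./grpo_model",
    "dtype": "bfloat16",
    "max_model_len": 4096,
    "gpu_memory_utilization": 0.85,
}

def build_vllm_serve_command(args: dict) -> str:
    """Turn logged argument names into `vllm serve` flags (snake_case to --kebab-case)."""
    argv = ["vllm", "serve", str(args["model_tag"])]
    for key, value in args.items():
        if key in ("model_tag", "model"):  # already covered by the positional argument
            continue
        argv += [f"--{key.replace('_', '-')}", str(value)]
    return shlex.join(argv)

command = build_vllm_serve_command(logged_args)
print(command)
```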

## Files
- : Stage 1 SFT
- : RS sampling
- : Stage 2 RS-SFT
- : Stage 3 GRPO
- : HRM8K evaluation