language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
- math
- korean
- rejection-sampling
- sft
- gemma
datasets:
- NotoriousH2/HRM8K
Gemma-3-1B-IT Math RS-SFT (Best Model)
Korean math model trained with a two-stage SFT → Rejection Sampling → SFT pipeline. This is the best-performing model.
Performance
| Benchmark | Score |
|---|---|
| HRM8K eval GSM8K (264 problems, Korean) | ~46.6% avg, 48.9% best run |
| HRM8K eval MATH (577 problems, Korean) | ~17% |
⚠️ Even at temperature=0, vLLM inference shows run-to-run variance of ±2-4 percentage points. The averages above are computed over three evaluation runs.
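On the 264-problem GSM8K split, one problem is worth 100/264 ≈ 0.38 percentage points, so a 2-4 pp swing means roughly 5 to 11 problems changing outcome between runs. The sketch below shows one way to summarize repeated runs. It assumes per-run accuracies in percent. The helper names are hypothetical and are not taken from the evaluation script.

```python
from statistics import mean


def accuracy(correct: int, total: int) -> float:
    """Accuracy in percent, e.g. for one run over the 264-problem GSM8K split."""
    return 100.0 * correct / total


def summarize_runs(accuracies: list[float]) -> dict[str, float]:
    """Mean, best run, and max-min spread (in percentage points) over repeated runs."""
    if not accuracies:
        raise ValueError("need at least one run")
    return {
        "mean": mean(accuracies),
        "best": max(accuracies),
        "spread_pp": max(accuracies) - min(accuracies),
    }
```

For example, 48.9% on 264 problems corresponds to 129 correct answers (`accuracy(129, 264)` ≈ 48.86%).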
Data generation pipeline
Stage 1: SFT data (teacher distillation)
Same as the SFT model: 7,473 GSM8K problems → 26,254 Korean solutions (about 3.5 per problem) generated with Qwen3-30B.
Stage 2: RS data (on-policy sampling)
RS sampling
RS data filtering
RS-SFT training data composition (key step)
Replay is the key: training on RS data alone makes the model forget the teacher's solution patterns, and performance drops (catastrophic forgetting).
| Replay ratio | GSM8K | Notes |
|---|---|---|
| 0x (RS only) | 46.2% | forgetting |
| 2x | 46.6% | insufficient |
| 3x | 48.5% | good |
| 5x | 48.9% | optimal |
| max (all) | 47.3% | RS data diluted |
RS-SFT training data format
Same question/answer JSON format as the SFT data. The only difference is that each answer is a correct solution generated by the student model (the SFT model) itself.
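Because RS-SFT reuses the SFT schema, the same serialization works for both datasets. The sketch below assumes one JSON object per line (JSONL) with exactly the `question` and `answer` fields. The card only says "question/answer JSON", so the line-delimited layout and the helper names are assumptions.

```python
import json


def to_jsonl(records: list[dict[str, str]]) -> str:
    """Serialize question/answer records, one JSON object per line."""
    lines = []
    for rec in records:
        # Same schema for SFT and RS-SFT data; only the origin of "answer" differs
        # (teacher solution vs. the student model's own correct solution).
        row = {"question": rec["question"], "answer": rec["answer"]}
        # ensure_ascii=False writes Korean as-is instead of \uXXXX escapes.
        lines.append(json.dumps(row, ensure_ascii=False))
    return "\n".join(lines) + "\n"


def from_jsonl(text: str) -> list[dict[str, str]]:
    """Parse JSONL text back into records, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```

When saving, pass the encoding explicitly, e.g. `Path("rs_sft.jsonl").write_text(to_jsonl(records), encoding="utf-8")`. This keeps the Korean text round-tripping correctly regardless of the platform's default encoding.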
Training configuration
Stage 1: SFT
Stage 2: RS-SFT
How to reproduce
vLLM 0.11.0 API servers were launched first for `./sft_model` and then for `./rs_sft_model`, with the non-default arguments `dtype=bfloat16`, `max_model_len=4096`, and `gpu_memory_utilization=0.85`. This corresponds to `vllm serve ./sft_model --dtype bfloat16 --max-model-len 4096 --gpu-memory-utilization 0.85`, and likewise for `./rs_sft_model`.
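vLLM's API server exposes an OpenAI-compatible `/v1/chat/completions` endpoint, on port 8000 by default. Unless `--served-model-name` is set, it registers the model under the path it was launched with. The sketch below queries a served checkpoint with greedy decoding (temperature=0). It assumes the raw question is sent as a single user message through the model's chat template. The prompt format, `max_tokens`, and the helper names are hypothetical and are not taken from the evaluation script.

```python
import json
import urllib.request


def build_chat_request(question: str, model: str = "./rs_sft_model",
                       base_url: str = "http://localhost:8000") -> urllib.request.Request:
    """Build a greedy-decoding chat request for a vLLM OpenAI-compatible server."""
    payload = {
        "model": model,  # served under its launch path unless --served-model-name is set
        "messages": [{"role": "user", "content": question}],
        "temperature": 0,  # greedy decoding, as in the reported evaluation
        "max_tokens": 1024,  # hypothetical cap; prompt + output must fit max_model_len=4096
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(question: str, **kwargs) -> str:
    """Send one question to the running server and return the generated solution."""
    req = build_chat_request(question, **kwargs)
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

Pass `model="./sft_model"` to query the Stage 1 server instead. Even with temperature=0, repeated calls may not return identical text, which matches the variance noted under Performance.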
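Failed approaches (for reference)
- Iterative RS (running RS again on top of the RS model): always regressed
- DPO (10 variants): none helped (attributed to the limited capacity of a 1B model)
- GRPO (2 variants): results stayed within the base model's run-to-run variance
- Different teacher models: large drop due to style mismatch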
Files
- : Stage 1 SFT training
- : RS sampling script (requires a running vLLM server)
- : Stage 2 RS-SFT training (with replay)
- : HRM8K evaluation