| --- |
| language: ko |
| license: apache-2.0 |
| base_model: google/gemma-3-1b-it |
| tags: |
| - math |
| - korean |
| - rejection-sampling |
| - sft |
| - gemma |
| datasets: |
| - NotoriousH2/HRM8K |
| --- |
| |
| # Gemma-3-1B-IT Math RS-SFT (Best Model) |
|
|
| A Korean math model trained with a two-stage SFT → Rejection Sampling → SFT pipeline. **Best performance.** |
|
|
| ## Performance |
|
|
| | Benchmark | Score | |
| |-----------|-------| |
| HRM8K eval GSM8K (264 problems, Korean) | **~46.6%** avg, **48.9%** best run |
| HRM8K eval MATH (577 problems, Korean) | ~17% |
|
|
| > ⚠️ Even at temperature=0, vLLM inference shows a variance of ±2-4%p. The figures above are averages over 3 evaluation runs. |
|
|
| ## Data generation pipeline |
|
|
| ### Stage 1: SFT data (teacher distillation) |
| Same as the SFT model. GSM8K 7,473 problems → 26,254 Korean solutions generated with Qwen3-30B. |
|
|
| ### Stage 2: RS data (on-policy sampling) |
|
|
| #### RS sampling |
|
|
|
|
| #### RS data filtering |
|
|
|
|
| #### RS-SFT training data composition (key!) |
|
|
|
|
| **Replay is key**: training on RS data alone makes the model forget the teacher's solution patterns, and performance drops (catastrophic forgetting). |
|
|
| Replay ratio | GSM8K | Notes |
|------------|-------|------|
| 0x (RS only) | 46.2% | forgetting |
| 2x | 46.6% | insufficient |
| 3x | 48.5% | good |
| **5x** | **48.9%** | **optimal** |
| max (all) | 47.3% | RS diluted |
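The replay mix compared in the table can be sketched as follows. This is a minimal illustration, not the author's training script: it assumes that a replay ratio of k× means replaying k × (number of RS examples) teacher-distilled SFT examples, capped at the full SFT set (the `max` row). The function name `build_rs_sft_mix`, the `seed` parameter, and the toy records are hypothetical.

```python
import random


def build_rs_sft_mix(rs_data, sft_data, replay_ratio, seed=0):
    """Mix on-policy RS examples with replayed teacher SFT examples.

    Assumption: replay_ratio=k replays k * len(rs_data) SFT examples;
    replay_ratio=None replays the whole SFT set (the "max" setting).
    """
    rng = random.Random(seed)
    if replay_ratio is None:
        n_replay = len(sft_data)
    else:
        n_replay = min(int(replay_ratio * len(rs_data)), len(sft_data))
    replay = rng.sample(sft_data, n_replay)  # sampled without replacement
    mixed = list(rs_data) + replay
    rng.shuffle(mixed)  # interleave student and teacher solutions
    return mixed


if __name__ == "__main__":
    # Toy records in the question/answer format; not real training data.
    rs = [{"question": f"q{i}", "answer": f"student-{i}"} for i in range(4)]
    sft = [{"question": f"q{i}", "answer": f"teacher-{i}"} for i in range(100)]
    print(len(build_rs_sft_mix(rs, sft, replay_ratio=5)))  # 4 RS + 20 replay = 24
```

Under this reading, a small ratio leaves too few teacher solutions to prevent forgetting, while replaying everything lets the teacher data swamp the RS examples.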
|
|
| ### RS-SFT training data format |
| The same question/answer JSON as SFT. The difference is that answer is a correct solution generated by the student model (SFT) itself. |
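A rejection-sampling filter that produces such records might look like the sketch below. It assumes each sampled solution ends with a final numeric answer and keeps only samples whose last number matches the reference. The helper names, the answer-extraction regex, and the per-question cap `max_keep` are illustrative assumptions, not the filtering rules actually used for this model.

```python
import re

# Matches integers and decimals, optionally negative, with thousands commas.
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text):
    """Return the last number found in text as a float, or None."""
    matches = NUMBER.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def filter_correct(question, samples, gold, max_keep=2):
    """Keep up to max_keep distinct samples whose final number equals gold."""
    kept, seen = [], set()
    for s in samples:
        pred = extract_final_number(s)
        if pred is None or abs(pred - gold) > 1e-6 or s in seen:
            continue  # reject: unparseable, wrong, or duplicate
        seen.add(s)
        kept.append({"question": question, "answer": s})
        if len(kept) >= max_keep:
            break
    return kept
```

Each kept record has the same `question`/`answer` keys as the Stage 1 data, so both sources can be concatenated without conversion.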
|
|
| ## Training configuration |
|
|
| ### Stage 1: SFT |
|
|
|
|
| ### Stage 2: RS-SFT |
|
|
|
|
| ## How to reproduce |
|
|
| INFO 03-19 14:53:13 [__init__.py:216] Automatically detected platform cuda. |
| (APIServer pid=3428638) INFO 03-19 14:53:19 [api_server.py:1839] vLLM API server version 0.11.0 |
| (APIServer pid=3428638) INFO 03-19 14:53:19 [utils.py:233] non-default args: {'model_tag': './sft_model', 'model': './sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85} |
| INFO 03-19 14:53:25 [__init__.py:216] Automatically detected platform cuda. |
| (APIServer pid=3428911) INFO 03-19 14:53:31 [api_server.py:1839] vLLM API server version 0.11.0 |
| (APIServer pid=3428911) INFO 03-19 14:53:31 [utils.py:233] non-default args: {'model_tag': './rs_sft_model', 'model': './rs_sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85} |
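Once a server like the ones logged above is running, it can be queried through vLLM's OpenAI-compatible HTTP API. The sketch below is an illustration under stated assumptions: the host and port (`localhost:8000`, vLLM's default), the `max_tokens` value, and the helper names are not taken from the original scripts. Only the model path `./rs_sft_model` and `temperature=0` come from this card, and the served model name is assumed to equal that path.

```python
import json
import urllib.request


def build_chat_request(question, model="./rs_sft_model", max_tokens=1024):
    """Build a greedy-decoding chat completion payload."""
    return {
        "model": model,  # assumed to match the path the server was started with
        "messages": [{"role": "user", "content": question}],
        "temperature": 0,
        "max_tokens": max_tokens,  # hypothetical limit, below max_model_len=4096
    }


def ask(question, base_url="http://localhost:8000/v1"):
    """POST the payload to the server and return the generated text."""
    req = urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(build_chat_request(question)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

Calling `ask(...)` requires a running server. `build_chat_request` can be inspected offline to check the payload.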
|
|
| ## Failed approaches (for reference) |
| - Iterative RS (running RS again on top of the RS model): always regressed |
| - DPO (10 attempts): all ineffective (insufficient capacity in the 1B model) |
| - GRPO (2 attempts): within the base variance range |
| - Other teacher models: large drop due to style mismatch |
|
|
| ## Files |
| - : Stage 1 SFT training |
| - : RS sampling script (requires vLLM serving) |
| - : Stage 2 RS-SFT training (includes replay) |
| - : HRM8K evaluation |
|
|