---
language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
- math
- korean
- rejection-sampling
- sft
- gemma
datasets:
- NotoriousH2/HRM8K
---
# Gemma-3-1B-IT Math RS-SFT (Best Model)
A Korean math model trained with a two-stage SFT → Rejection Sampling → SFT pipeline. **Best-performing model.**
## Performance
| Benchmark | Score |
|-----------|-------|
| HRM8K eval GSM8K (264 problems, Korean) | **~46.6%** avg, **48.9%** best run |
| HRM8K eval MATH (577 problems, Korean) | ~17% |
> ⚠️ Even at temperature=0, vLLM inference shows a run-to-run variance of ±2-4%p. The figures above are averages over 3 evaluation runs.
## Data generation pipeline
### Stage 1: SFT data (teacher distillation)
Same data as the SFT model. 7,473 GSM8K problems → 26,254 Korean solutions generated with Qwen3-30B.
### Stage 2: RS data (on-policy sampling)
#### RS sampling
#### RS data filtering
#### RS-SFT training data composition (key!)
**Replay is the key**: training on RS data alone makes the model forget the teacher's solution patterns, and performance drops (catastrophic forgetting).
| Replay ratio | GSM8K | Notes |
|------------|-------|------|
| 0x (RS only) | 46.2% | forgetting |
| 2x | 46.6% | insufficient |
| 3x | 48.5% | good |
| **5x** | **48.9%** | **optimal** |
| max (all) | 47.3% | RS diluted |
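
The replay mixing compared in the table can be sketched roughly as follows. This is a minimal illustration, not the actual training script: the function name `build_rs_sft_mix` is hypothetical, and it assumes that a replay ratio of `k` means `k` teacher (Stage 1) examples are replayed per RS example, with `None` standing for the "max (all)" row.

```python
import random


def build_rs_sft_mix(rs_examples, teacher_examples, replay_ratio, seed=0):
    """Mix on-policy RS examples with replayed teacher (Stage 1) examples.

    Assumption (not confirmed by the model card): replay_ratio=k replays
    k teacher examples per RS example; replay_ratio=None replays all of them.
    """
    rng = random.Random(seed)
    teacher_examples = list(teacher_examples)
    if replay_ratio is None:
        # "max (all)": use every teacher example.
        replay = teacher_examples
    else:
        # Cap at the number of teacher examples actually available.
        n_replay = min(len(teacher_examples), int(replay_ratio * len(rs_examples)))
        replay = rng.sample(teacher_examples, n_replay)
    mixed = list(rs_examples) + replay
    rng.shuffle(mixed)  # interleave RS and replayed examples
    return mixed
```

With 10 RS examples and a ratio of 5, the mix would hold 10 RS examples plus 50 replayed teacher examples. A ratio of 0 reproduces the "RS only" row.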
### RS-SFT training data format
The same question/answer JSON as for SFT. The difference is that each answer is a correct solution generated by the student model (the SFT model) itself.
## Training configuration
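
As an illustration of how student samples could be turned into question/answer records, here is a minimal rejection-filter sketch. It rests on assumptions that the model card does not confirm: the final answer is taken to be the last number in the generated text, exact duplicates are dropped, and the helper names `extract_final_number` and `rejection_filter` are hypothetical.

```python
import re


def extract_final_number(solution_text):
    """Return the last number in the text as a float, or None if none is found.

    Assumption: the final answer is the last number the solution mentions.
    """
    matches = re.findall(r"-?\d[\d,]*(?:\.\d+)?", solution_text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def rejection_filter(question, samples, gold_answer, tol=1e-6):
    """Keep only student samples whose final number matches the gold answer."""
    kept, seen = [], set()
    for sample in samples:
        pred = extract_final_number(sample)
        if pred is None or abs(pred - gold_answer) > tol:
            continue  # reject: no answer found, or wrong answer
        if sample in seen:
            continue  # skip exact duplicates (an assumed extra step)
        seen.add(sample)
        kept.append({"question": question, "answer": sample})
    return kept
```

Each kept record has the same `question`/`answer` keys as the SFT data, so both sources can go through the same training code.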
### Stage 1: SFT
### Stage 2: RS-SFT
## Reproduction
INFO 03-19 14:53:13 [__init__.py:216] Automatically detected platform cuda.
(APIServer pid=3428638) INFO 03-19 14:53:19 [api_server.py:1839] vLLM API server version 0.11.0
(APIServer pid=3428638) INFO 03-19 14:53:19 [utils.py:233] non-default args: {'model_tag': './sft_model', 'model': './sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
INFO 03-19 14:53:25 [__init__.py:216] Automatically detected platform cuda.
(APIServer pid=3428911) INFO 03-19 14:53:31 [api_server.py:1839] vLLM API server version 0.11.0
(APIServer pid=3428911) INFO 03-19 14:53:31 [utils.py:233] non-default args: {'model_tag': './rs_sft_model', 'model': './rs_sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
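
The logs above show vLLM's API server started for `./sft_model` and `./rs_sft_model`. A client could draw multiple samples per question through the server's OpenAI-compatible chat endpoint, roughly as sketched below. The helper names, the default port 8000, the sampling parameters (`n=8`, `temperature=1.0`, `max_tokens=1024`) and the single-user-message prompt are all assumptions for illustration, not the settings used for this model.

```python
import json
import urllib.request


def build_sampling_request(question, model="./sft_model", n=8,
                           temperature=1.0, max_tokens=1024):
    """Build a chat-completions payload asking for n samples of one question.

    All default values here are illustrative assumptions.
    """
    return {
        "model": model,  # served model name; defaults to the path passed to vLLM
        "messages": [{"role": "user", "content": question}],
        "n": n,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }


def sample_solutions(question, base_url="http://localhost:8000", **kwargs):
    """POST the payload to a running vLLM server and return the sampled texts."""
    payload = json.dumps(build_sampling_request(question, **kwargs)).encode("utf-8")
    request = urllib.request.Request(
        base_url + "/v1/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    return [choice["message"]["content"] for choice in body["choices"]]
```

With a server running, `sample_solutions("...", n=8)` would return eight candidate solutions to pass to a correctness filter. For evaluation, the same call with `model="./rs_sft_model"`, `n=1` and `temperature=0` would match the temperature=0 setting mentioned in the performance note.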
## Failed approaches (for reference)
- Iterative RS (running RS again on top of the RS model): always regressed
- DPO (10 variants tried): none had any effect (attributed to the 1B model's limited capacity)
- GRPO (2 variants tried): results stayed within the base variance range
- Different teacher models: large drop due to style mismatch
## Files
- : Stage 1 SFT training
- : RS sampling script (requires a running vLLM server)
- : Stage 2 RS-SFT training (with replay)
- : HRM8K evaluation