metadata
language: ko
license: apache-2.0
base_model: google/gemma-3-1b-it
tags:
  - math
  - korean
  - rejection-sampling
  - sft
  - gemma
datasets:
  - NotoriousH2/HRM8K

Gemma-3-1B-IT Math RS-SFT (Best Model)

A Korean math model trained with a two-stage SFT → Rejection Sampling → SFT pipeline. This is the best-performing variant.

Performance

Benchmark | Score
--- | ---
HRM8K eval GSM8K (264 problems, Korean) | ~46.6% avg, 48.9% best run
HRM8K eval MATH (577 problems, Korean) | ~17%

⚠️ Even at temperature=0, vLLM inference varies by ±2-4 percentage points between runs. The figures above are averages over 3 evaluation runs.
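Because single runs fluctuate, the table reports both an average and a best run. A minimal sketch of that summary is shown below; the helper name is hypothetical, and it simply takes per-run accuracies (in %) for one checkpoint.

```python
from statistics import mean


def summarize_runs(accuracies: list[float]) -> dict[str, float]:
    """Summarize one checkpoint's accuracy (in %) over repeated evaluation runs."""
    if not accuracies:
        raise ValueError("need at least one evaluation run")
    return {
        "mean": mean(accuracies),                     # reported as "avg"
        "best": max(accuracies),                      # reported as "best run"
        "spread": max(accuracies) - min(accuracies),  # run-to-run variation, in percentage points
    }
```

Reporting the spread alongside the mean makes it easier to tell whether a difference between two checkpoints exceeds the ±2-4 point inference noise.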

Data generation pipeline

Stage 1: SFT data (teacher distillation)

Same as the SFT model: 7,473 GSM8K problems → 26,254 Korean solutions generated with Qwen3-30B.

Stage 2: RS data (on-policy sampling)

RS sampling
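The sampling script itself is not reproduced here. As a minimal sketch, assuming the Stage 1 model is served through vLLM's OpenAI-compatible server on its default port 8000 (so the served model name defaults to the launch path `./sft_model`) and that each question is sent as a single user turn, on-policy sampling can request several completions per question via the chat-completions `n` parameter. `VLLM_URL`, `MODEL_NAME`, the helper names, and the `n`, `temperature`, and `max_tokens` defaults are illustrative assumptions, not the author's settings.

```python
import json
import urllib.request

# Assumptions: vLLM's default port, and the served model name equal to the launch path.
VLLM_URL = "http://localhost:8000/v1/chat/completions"
MODEL_NAME = "./sft_model"


def build_request(question: str, n: int = 8, temperature: float = 1.0,
                  max_tokens: int = 1024) -> dict:
    """Build one chat-completions payload asking for n sampled solutions (hypothetical defaults)."""
    return {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": question}],
        "n": n,                      # independent samples per question
        "temperature": temperature,  # > 0 so the n samples can differ
        "max_tokens": max_tokens,
    }


def sample_solutions(question: str, **kwargs) -> list[str]:
    """POST the payload to a running vLLM server and return the n sampled completions."""
    payload = json.dumps(build_request(question, **kwargs)).encode("utf-8")
    req = urllib.request.Request(
        VLLM_URL, data=payload, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return [choice["message"]["content"] for choice in body["choices"]]
```

Sampling from the SFT model itself (rather than the teacher) is what makes the data on-policy: the kept solutions are ones the student can already produce.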

RS data filtering
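A minimal sketch of answer-based filtering, assuming correctness is judged by comparing the last number in a sampled solution with the last number in the gold GSM8K answer (which ends in `#### <answer>`). The helper names and the exact-duplicate removal are hypothetical additions and may differ from the original script.

```python
import re
from typing import Optional

# Integers or decimals, optionally with thousands separators (e.g. "1,234").
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_final_number(text: str) -> Optional[float]:
    """Return the last number in a solution, ignoring thousands separators."""
    matches = NUMBER_RE.findall(text)
    if not matches:
        return None
    return float(matches[-1].replace(",", ""))


def filter_correct(question: str, gold: str, samples: list[str]) -> list[dict]:
    """Keep samples whose final number matches the gold answer; skip exact duplicates."""
    gold_value = extract_final_number(gold)
    kept, seen = [], set()
    for sample in samples:
        pred = extract_final_number(sample)
        if pred is None or gold_value is None:
            continue  # no parsable answer
        if abs(pred - gold_value) > 1e-6 or sample in seen:
            continue  # wrong answer or duplicate
        seen.add(sample)
        kept.append({"question": question, "answer": sample})
    return kept
```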

RS-SFT training data composition (key!)

Replay is the key: training on RS data alone makes the model forget the teacher's solution patterns, and performance drops (catastrophic forgetting).

Replay ratio | GSM8K | Notes
--- | --- | ---
0x (RS only) | 46.2% | forgetting
2x | 46.6% | insufficient
3x | 48.5% | good
5x | 48.9% | optimal
max (all) | 47.3% | RS data diluted
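The table does not define the replay ratio precisely. A minimal sketch, assuming `Nx` means mixing in N times as many Stage 1 teacher samples as there are RS samples and `max` means replaying the entire Stage 1 set; the function name and the seeded sampling scheme are hypothetical.

```python
import random
from typing import Optional


def build_rs_sft_mix(rs_data: list[dict], sft_data: list[dict],
                     replay_ratio: Optional[float], seed: int = 0) -> list[dict]:
    """Mix RS samples with replayed teacher (Stage 1) samples.

    replay_ratio=k draws k * len(rs_data) teacher samples (capped at the SFT set size);
    replay_ratio=None replays the whole SFT set (the "max" row).
    """
    rng = random.Random(seed)
    if replay_ratio is None:
        replay = list(sft_data)
    else:
        n_replay = min(int(replay_ratio * len(rs_data)), len(sft_data))
        replay = rng.sample(sft_data, n_replay)  # without replacement
    mixed = list(rs_data) + replay
    rng.shuffle(mixed)  # interleave so batches see both sources
    return mixed
```

Under this reading, the "max" row degrades because the RS samples become a small fraction of the mix, which matches the "RS data diluted" note.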

RS-SFT training data format

Same question/answer JSON as the SFT data. The only difference is that each answer is a correct solution generated by the student model (the Stage 1 SFT model) itself.
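A minimal sketch of writing and reading such records, assuming one JSON object per line (JSONL) with exactly the `question` and `answer` keys; the file layout and helper names are assumptions. `ensure_ascii=False` keeps the Korean text human-readable on disk.

```python
import json


def to_jsonl(records: list[dict]) -> str:
    """Serialize question/answer records, one JSON object per line."""
    lines = [
        # Keep only the two fields described above; anything else is dropped.
        json.dumps({"question": r["question"], "answer": r["answer"]}, ensure_ascii=False)
        for r in records
    ]
    return "\n".join(lines) + "\n"


def from_jsonl(text: str) -> list[dict]:
    """Parse JSONL text back into records, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```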

Training configuration

Stage 1: SFT

Stage 2: RS-SFT

How to reproduce

INFO 03-19 14:53:13 [__init__.py:216] Automatically detected platform cuda.
(APIServer pid=3428638) INFO 03-19 14:53:19 [api_server.py:1839] vLLM API server version 0.11.0
(APIServer pid=3428638) INFO 03-19 14:53:19 [utils.py:233] non-default args: {'model_tag': './sft_model', 'model': './sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
INFO 03-19 14:53:25 [__init__.py:216] Automatically detected platform cuda.
(APIServer pid=3428911) INFO 03-19 14:53:31 [api_server.py:1839] vLLM API server version 0.11.0
(APIServer pid=3428911) INFO 03-19 14:53:31 [utils.py:233] non-default args: {'model_tag': './rs_sft_model', 'model': './rs_sft_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85}
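The non-default arguments in the vLLM log correspond to a `vllm serve` invocation, run once for `./sft_model` and once for `./rs_sft_model`. The sketch below is reconstructed from those logged arguments only; the helper names are hypothetical, and actually launching the server requires vLLM installed and a CUDA GPU.

```python
import shlex
import subprocess


def vllm_serve_argv(model_path: str) -> list[str]:
    """Command line matching the non-default args in the log."""
    return [
        "vllm", "serve", model_path,
        "--dtype", "bfloat16",
        "--max-model-len", "4096",
        "--gpu-memory-utilization", "0.85",
    ]


def launch(model_path: str) -> subprocess.Popen:
    """Start the OpenAI-compatible server in the background."""
    return subprocess.Popen(vllm_serve_argv(model_path))


if __name__ == "__main__":
    # Print the two commands instead of launching them.
    for path in ("./sft_model", "./rs_sft_model"):
        print(shlex.join(vllm_serve_argv(path)))
```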

Failed approaches (for reference)

  • Iterative RS (running RS again on top of the RS model): always regressed
  • DPO (10 variants tried): none had any effect (attributed to the limited capacity of a 1B model)
  • GRPO (2 variants tried): changes stayed within the base model's run-to-run variance
  • Other teacher models: large drop due to style mismatch

Files

  • : Stage 1 SFT training
  • : RS sampling script (requires a running vLLM server)
  • : Stage 2 RS-SFT training (with replay)
  • : HRM8K evaluation