--- language: ko license: apache-2.0 base_model: google/gemma-3-1b-it tags: - math - korean - grpo - rl - gemma datasets: - NotoriousH2/HRM8K --- # Gemma-3-1B-IT Math GRPO SFT → RS-SFT → GRPO 3단계 파이프라인으로 학습한 한국어 수학 모델. ## 성능 | Benchmark | Score | |-----------|-------| | HRM8K eval GSM8K (264문제, Korean) | **~46.2%** | | HRM8K eval MATH (577문제, Korean) | ~16.5% | > ⚠️ GRPO는 base (RS-SFT, ~46.6%)대비 유의미한 개선 없음. 이미 SFT+RS-SFT로 최적화된 1B 모델에서 RL 추가 여지가 극히 제한적. ## 데이터 & 학습 파이프라인 ### Stage 1-2: SFT → RS-SFT 위 RS-SFT 모델과 동일. (SFT 교사 증류 → RS 샘플링 → RS-SFT with 5x replay) ### Stage 3: GRPO #### Reward 함수 #### 학습 데이터 - **프롬프트만 필요** (GRPO는 학습 중 자체 생성) - GSM8K train 6,871개 unique 한국어 문제 - 각 문제에 ground truth 답 포함 (reward 계산용) #### 학습 설정 ## DPO vs GRPO 비교 (실험 결과) ### DPO 실패 분석 (10회) | 데이터 전략 | GSM8K | 분석 | |------------|-------|------| | 기존 방식 (shortest correct + longest incorrect) | 48.1% | 길이 편향만 학습 | | Length-matched (55자 차이) | 46.2% | 신호 없음 (DPO acc≈0.50) | | Teacher-chosen (30B 교사 풀이=chosen) | 47.3% | OOD 문제 | | Multi-pair (질문당 3쌍, 난이도 가중) | 46.6% | 양 증가도 무효 | | base 대비 | ±0-2%p | 모두 variance 범위 | **DPO 근본 문제**: 1B 모델이 정답/오답 풀이의 미묘한 차이를 내부적으로 구분할 capacity 부족. ### GRPO 결과 (2회) | beta | steps | GSM8K | 비고 | |------|-------|-------|------| | 0.001 | 200 | 43.9% | format 퇴보 (boxed↓) | | 0.04 | 500 | 46.2% | base와 동일 수준 | ## 환경 - GPU: H100 NVL 95GB - Framework: trl 0.29.0, transformers 4.57.3, vllm 0.11.0 - GRPO 학습: ~55분 (vLLM colocate 사용) ## 재현 방법 INFO 03-19 14:53:37 [__init__.py:216] Automatically detected platform cuda. (APIServer pid=3429210) INFO 03-19 14:53:43 [api_server.py:1839] vLLM API server version 0.11.0 (APIServer pid=3429210) INFO 03-19 14:53:43 [utils.py:233] non-default args: {'model_tag': './grpo_model', 'model': './grpo_model', 'dtype': 'bfloat16', 'max_model_len': 4096, 'gpu_memory_utilization': 0.85} ## 파일 - : Stage 1 SFT - : RS 샘플링 - : Stage 2 RS-SFT - : Stage 3 GRPO - : HRM8K 평가