"""C20: Variants of C18-2 (the 48.5% recipe) with different replay ratios / learning rate.

Starting from the C17d checkpoint (``BASE``), each condition fine-tunes on the
deduplicated rejection-sampling data from rounds RS1 and RS2 mixed with
"replay" examples drawn from the original GSM8K SFT set:

    C20-1: RS1+RS2 + 2x replay, lr=2e-6
    C20-2: RS1+RS2 + 5x replay, lr=2e-6
    C20-3: RS1+RS2 + 3x replay, lr=3e-6

Each trained model is written to its own directory under ``outputs/models``.
"""
import gc
import json
import os
import random
from dataclasses import dataclass

import numpy as np
import torch

SEED = 42

SP = "주어진 수학 문제를 단계별로 풀고 답변을 작성하세요.\n반드시 최종 답변을 \\boxed{정수} 형식으로 마지막 줄에 출력하세요.\n예시: \\boxed{42}"
BASE = "outputs/models/c17d-gemma-3-1b-it-Math"

# RS1 + RS2 (the winning combo) rejection-sampling outputs, in merge order.
RS_PATHS = ("outputs/c17d_rs/sft_dataset.json", "outputs/c17d_rs2/sft_dataset.json")
ORIG_PATH = "data/GSM8K_full_qwen3_30b.json"

MIN_CORRECT = 4          # keep RS samples whose n_correct is at least this
MAX_ANSWER_CHARS = 1500  # drop replay examples whose answer is longer (characters)


@dataclass(frozen=True)
class Condition:
    """One C20 training run: how much replay to mix in and which lr to use."""

    title: str         # printed header
    replay_mult: int   # number of replay examples = replay_mult * len(RS data)
    lr: float
    ckpt_dir: str      # Trainer output_dir (checkpoints are disabled via save_strategy="no")
    save_dir: str      # final model/tokenizer directory


CONDITIONS = (
    Condition("C20-1: RS1+RS2 + 2x replay", 2, 2e-6,
              "outputs/c20_1_ckpt", "outputs/models/c20-1-2x-replay"),
    Condition("C20-2: RS1+RS2 + 5x replay", 5, 2e-6,
              "outputs/c20_2_ckpt", "outputs/models/c20-2-5x-replay"),
    Condition("C20-3: RS1+RS2 + 3x replay + lr=3e-6", 3, 3e-6,
              "outputs/c20_3_ckpt", "outputs/models/c20-3-lr3e-6"),
)


def _seed_everything(seed=SEED):
    """Seed Python, NumPy and torch RNGs; enable TF32 matmuls on Ampere+ GPUs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Only query the device capability when a CUDA device exists.
        if torch.cuda.get_device_capability()[0] >= 8:
            torch.set_float32_matmul_precision('high')


def _load_json(path):
    """Read a UTF-8 JSON file (the data contains Korean text)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def combine_rs_rounds(rounds, min_correct=MIN_CORRECT):
    """Merge RS rounds into one list of unique (question, answer) examples.

    Rows with ``n_correct < min_correct`` are dropped. The first occurrence of
    each (question, answer) pair wins, so the order follows ``rounds``.
    """
    seen = set()
    combined = []
    for rows in rounds:
        for d in rows:
            if d["n_correct"] < min_correct:
                continue
            key = (d["question"], d["answer"])
            if key in seen:
                continue
            seen.add(key)
            combined.append({"question": d["question"], "answer": d["answer"], "source": "gsm8k"})
    return combined


def build_mixture(rs_data, orig_filtered, replay_mult, seed=SEED):
    """Mix RS data with ``replay_mult``x replay examples drawn from ``orig_filtered``.

    Replay examples whose question already appears in ``rs_data`` are excluded.
    A fresh ``random.Random(seed)`` is used, which produces the same shuffles as
    reseeding the global RNG with ``seed``.

    Returns:
        (mixed, n_replay): the shuffled mixture and the number of replay
        examples actually used (may be fewer than requested if the pool is small).
    """
    rng = random.Random(seed)
    rs_qs = {d["question"] for d in rs_data}
    replay = [d for d in orig_filtered if d["question"] not in rs_qs]
    rng.shuffle(replay)
    n_requested = int(len(rs_data) * replay_mult)
    if n_requested > len(replay):
        print(f" WARNING: requested {n_requested} replay examples but only {len(replay)} available")
    replay = replay[:n_requested]
    mixed = rs_data + replay
    rng.shuffle(mixed)
    return mixed, len(replay)


def to_sft(ex):
    """Convert a {question, answer} row into TRL prompt/completion chat format."""
    return {"prompt": [{"role": "user", "content": SP + "\n\n" + ex["question"]}],
            "completion": [{"role": "assistant", "content": ex["answer"]}]}


def _train_and_save(mixed, lr, ckpt_dir, save_dir):
    """Fine-tune a fresh copy of BASE on ``mixed`` and save model + tokenizer to ``save_dir``."""
    # Training deps are imported lazily so the data helpers above can be
    # imported without them.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    ds = Dataset.from_list(mixed)
    extra_cols = [c for c in ds.column_names if c not in ("prompt", "completion")]
    ds = ds.map(to_sft, remove_columns=extra_cols)

    tokenizer = AutoTokenizer.from_pretrained(BASE)
    model = AutoModelForCausalLM.from_pretrained(
        BASE, dtype=torch.bfloat16, device_map="auto", attn_implementation='flash_attention_2')
    # NOTE(review): kept from the C18-2 recipe; confirm that aliasing pad to EOS
    # does not cause EOS labels to be masked as padding by the collator.
    tokenizer.pad_token = tokenizer.eos_token
    model.gradient_checkpointing_enable()
    model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing

    cfg = SFTConfig(report_to='none', seed=SEED, num_train_epochs=1, warmup_ratio=0.05,
                    weight_decay=0.01, max_grad_norm=1.0, per_device_train_batch_size=8,
                    gradient_accumulation_steps=4, max_length=2048, lr_scheduler_type='cosine',
                    learning_rate=lr, bf16=True, optim="paged_adamw_8bit",
                    output_dir=ckpt_dir, logging_steps=25, save_strategy="no")
    trainer = SFTTrainer(model=model, processing_class=tokenizer, train_dataset=ds, args=cfg)
    result = trainer.train()
    print(f" Loss: {result.training_loss:.4f}")

    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    model.save_pretrained(save_dir, safe_serialization=False)
    tokenizer.save_pretrained(save_dir)

    # Break reference cycles before releasing cached GPU memory, so the next
    # condition does not start with this model/optimizer still resident.
    del model, trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main():
    """Run every C20 condition in sequence."""
    _seed_everything(SEED)

    rs_combined = combine_rs_rounds([_load_json(p) for p in RS_PATHS])
    print(f"RS1+RS2 combined: {len(rs_combined)}")

    orig_filtered = [d for d in _load_json(ORIG_PATH) if len(d["answer"]) <= MAX_ANSWER_CHARS]

    for cond in CONDITIONS:
        print(f"\n=== {cond.title} ===")
        mixed, n_replay = build_mixture(rs_combined, orig_filtered, cond.replay_mult, seed=SEED)
        print(f" RS: {len(rs_combined)} + replay: {n_replay} = {len(mixed)}")
        _train_and_save(mixed, cond.lr, cond.ckpt_dir, cond.save_dir)

    print("\n=== All conditions complete ===")


if __name__ == "__main__":
    main()