Upload train_grpo.py with huggingface_hub
Browse files — train_grpo.py (+14 −0)
train_grpo.py
CHANGED
|
@@ -237,6 +237,20 @@ def main():
|
|
| 237 |
model = model.merge_and_unload() # Merge LoRA weights
|
| 238 |
print("Model loaded and LoRA merged.")
|
| 239 |
|
| 240 |
# Initialize reward function
|
| 241 |
reward_fn = QMDRewardFunction()
|
| 242 |
|
|
|
|
| 237 |
model = model.merge_and_unload() # Merge LoRA weights
|
| 238 |
print("Model loaded and LoRA merged.")
|
| 239 |
|
| 240 |
+
# Add new LoRA adapter for GRPO training
|
| 241 |
+
from peft import get_peft_model
|
| 242 |
+
grpo_lora_config = LoraConfig(
|
| 243 |
+
r=8,
|
| 244 |
+
lora_alpha=16,
|
| 245 |
+
lora_dropout=0.05,
|
| 246 |
+
bias="none",
|
| 247 |
+
task_type="CAUSAL_LM",
|
| 248 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 249 |
+
)
|
| 250 |
+
model = get_peft_model(model, grpo_lora_config)
|
| 251 |
+
model.print_trainable_parameters()
|
| 252 |
+
print("Added new LoRA adapter for GRPO.")
|
| 253 |
+
|
| 254 |
# Initialize reward function
|
| 255 |
reward_fn = QMDRewardFunction()
|
| 256 |
|