# /// script
# dependencies = ["trl>=0.19.0", "peft>=0.7.0", "trackio", "datasets", "accelerate", "bitsandbytes"]
# ///
"""Fine-tune Qwen2.5-3B-Instruct with LoRA adapters for the Forge agent.

Loads the ``train`` split of ``NTA-Dev/bf-forge-training``, trains LoRA
adapters with TRL's ``SFTTrainer``, and pushes the result to the Hugging Face
Hub repo ``NTA-Dev/bf-forge-agent``. Dependencies are declared in the PEP 723
``# /// script`` block above.
"""
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

BASE_MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
DATASET_ID = "NTA-Dev/bf-forge-training"
HUB_MODEL_ID = "NTA-Dev/bf-forge-agent"

dataset = load_dataset(DATASET_ID, split="train")
print(f"Loaded {len(dataset)} examples for Forge")

trainer = SFTTrainer(
    model=BASE_MODEL_ID,
    train_dataset=dataset,
    peft_config=LoraConfig(
        r=32,
        lora_alpha=64,  # alpha / r = 2 scales the adapter update
        # Adapt every attention and MLP projection, not just q/v.
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.05,
        # Without a task type PEFT cannot wrap the model as a causal LM;
        # CAUSAL_LM is the documented setting for SFT on a decoder model.
        task_type="CAUSAL_LM",
    ),
    args=SFTConfig(
        output_dir="forge-agent",
        # With push_to_hub=True the Trainer also uploads on each save
        # (each epoch here); trainer.push_to_hub() below pushes the final state.
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        num_train_epochs=5,
        # Effective batch: 2 x 4 = 8 sequences per device per optimizer step.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        warmup_ratio=0.1,
        logging_steps=5,
        save_strategy="epoch",
        bf16=True,  # needs hardware bfloat16 support on the GPU
        gradient_checkpointing=True,  # trade recompute for activation memory
        # `max_length` replaced `max_seq_length` in newer TRL releases; the
        # version floor in the script header is raised so this keyword exists.
        max_length=2048,
        # NOTE(review): `trackio` is a declared dependency but `report_to` is
        # not set, so logging depends on the installed transformers default.
        # Set report_to="trackio" explicitly if trackio logging is intended.
    ),
)

trainer.train()
trainer.push_to_hub()
print("Forge training complete!")