# File size: 1,119 Bytes | e80846a
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "accelerate", "bitsandbytes"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# ---------------------------------------------------------------------------
# Supervised fine-tuning run for the Forge agent: load the training split,
# attach LoRA adapters to the base model, train, then upload the result.
# ---------------------------------------------------------------------------

# Training data is pulled from the Hub; only the "train" split is used.
dataset = load_dataset("NTA-Dev/bf-forge-training", split="train")
print(f"Loaded {len(dataset)} examples for Forge")

# LoRA adapter settings, applied to the modules named in `target_modules`.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
)

# Trainer arguments: where to write and publish, schedule, and precision.
training_args = SFTConfig(
    # Output location and Hub destination.
    output_dir="forge-agent",
    push_to_hub=True,
    hub_model_id="NTA-Dev/bf-forge-agent",
    # Schedule: 5 epochs, 2 samples per device with 4 accumulation steps.
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    # Logging and checkpoint cadence.
    logging_steps=5,
    save_strategy="epoch",
    # Precision, memory, and sequence-length settings.
    bf16=True,
    gradient_checkpointing=True,
    max_length=2048,
)

# The base model is given by its Hub id and resolved by the trainer.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    train_dataset=dataset,
    peft_config=lora_config,
    args=training_args,
)

trainer.train()
# Explicit upload once training has finished.
trainer.push_to_hub()
print("Forge training complete!")
# End of script.