# bf-agent-training-scripts / train_manny.py
# Uploaded by NTA-Dev with huggingface_hub (commit c0b8ce5, verified).
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "accelerate", "bitsandbytes"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# Supervised fine-tuning of Qwen/Qwen2.5-3B-Instruct with LoRA adapters on the
# "Manny" dataset; the result is pushed to the NTA-Dev/bf-manny-agent Hub repo.

# Load the training split of the Manny dataset and report its size.
dataset = load_dataset("NTA-Dev/bf-manny-training", split="train")
print(f"Loaded {len(dataset)} examples for Manny")

# LoRA adapter settings: rank 32, alpha 64 (alpha / r = 2), 5% dropout,
# applied to every module whose name appears in the list below.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=lora_target_modules,
    lora_dropout=0.05,
)

# Trainer arguments. With a per-device batch of 2 and 4 accumulation steps,
# each optimizer step sees 2 * 4 = 8 examples per device.
training_args = SFTConfig(
    # Where checkpoints go locally and on the Hub.
    output_dir="manny-agent",
    push_to_hub=True,
    hub_model_id="NTA-Dev/bf-manny-agent",
    # Optimisation schedule.
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    # Log every 5 steps; write a checkpoint at the end of each epoch.
    logging_steps=5,
    save_strategy="epoch",
    # Precision / memory settings and the sequence-length limit.
    bf16=True,
    gradient_checkpointing=True,
    max_length=2048,
)

# The base model is passed by Hub id, so the trainer is responsible for
# loading it; the LoRA config is handed over via `peft_config`.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    train_dataset=dataset,
    peft_config=lora_config,
    args=training_args,
)

# Run training, then do an explicit final push to the Hub.
trainer.train()
trainer.push_to_hub()
print("Manny training complete!")