#!/usr/bin/env python3
"""Fine-tune Qwen2.5-1.5B-Instruct as an AGORA multi-robot task planner using LoRA.
Reads training data from /mnt/artifacts-datai/logs/project_agora/planning_train.jsonl
Saves checkpoints to /mnt/artifacts-datai/checkpoints/project_agora/
Saves final model to /mnt/artifacts-datai/models/project_agora/agora-planner-v1/
Usage:
    CUDA_VISIBLE_DEVICES=2,3 python scripts/train_planner.py
    CUDA_VISIBLE_DEVICES=2,3 python scripts/train_planner.py --model Qwen/Qwen2.5-0.5B
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
import torch
# ---------------------------------------------------------------------------
# Project and artifact paths
# ---------------------------------------------------------------------------
PROJECT = "project_agora"
ARTIFACTS = "/mnt/artifacts-datai"
CHECKPOINT_DIR = f"{ARTIFACTS}/checkpoints/{PROJECT}"
MODEL_DIR = f"{ARTIFACTS}/models/{PROJECT}/agora-planner-v1"
LOG_DIR = f"{ARTIFACTS}/logs/{PROJECT}"
TB_DIR = f"{ARTIFACTS}/tensorboard/{PROJECT}"
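# Create every artifact directory up front so checkpoint, model, and log
# writes later in the run cannot fail on a missing path.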
for d in [CHECKPOINT_DIR, MODEL_DIR, LOG_DIR, TB_DIR]:
    os.makedirs(d, exist_ok=True)
# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
DEFAULT_MODEL = "/mnt/forge-data/models/Qwen--Qwen2.5-1.5B-Instruct"
DEFAULT_TRAIN_DATA = f"{LOG_DIR}/planning_train.jsonl"
DEFAULT_EVAL_DATA = f"{LOG_DIR}/planning_eval.jsonl"
def main():
    import argparse

    parser = argparse.ArgumentParser(description="Train AGORA planner with LoRA")
    parser.add_argument(
        "--model", default=DEFAULT_MODEL,
        help="Base model path or HF ID",
    )
    parser.add_argument(
        "--train-data", default=DEFAULT_TRAIN_DATA,
        help="Training JSONL path",
    )
    parser.add_argument(
        "--eval-data", default=DEFAULT_EVAL_DATA,
        help="Evaluation JSONL path",
    )
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=4, help="Per-device batch size")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--max-seq-len", type=int, default=2048, help="Max sequence length")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora-dropout", type=float, default=0.05, help="LoRA dropout")
    parser.add_argument("--warmup-ratio", type=float, default=0.05, help="Warmup ratio")
    parser.add_argument("--save-steps", type=int, default=100, help="Save every N steps")
    parser.add_argument("--logging-steps", type=int, default=10, help="Log every N steps")
    # BooleanOptionalAction gives --bf16/--no-bf16; a bare store_true with
    # default=True could never be switched off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True,
                        help="Use bf16 (fp16 otherwise)")
    parser.add_argument("--num-workers", type=int, default=2, help="Dataloader num_workers")
    parser.add_argument("--pin-memory", action="store_true", default=False, help="Pin memory")
    parser.add_argument("--max-steps", type=int, default=-1, help="Max steps (-1=full run)")
    parser.add_argument("--merge-and-save", action=argparse.BooleanOptionalAction, default=True,
                        help="Merge LoRA weights into base model after training")
    args = parser.parse_args()
    # Validate model path
    model_path = Path(args.model)
    if not model_path.exists():
        # Fall back to the shared HF models directory (org--name layout)
        alt = Path("/mnt/forge-data/models") / args.model.replace("/", "--")
        if alt.exists():
            args.model = str(alt)
        else:
            print(f"ERROR: Model not found at {args.model} or {alt}")
            print("Available Qwen models under /mnt/forge-data/models:")
            for p in sorted(Path("/mnt/forge-data/models").iterdir()):
                if p.is_dir() and "qwen" in p.name.lower():
                    print(f"  {p}")
            sys.exit(1)

    # Validate training data
    if not Path(args.train_data).exists():
        print(f"ERROR: Training data not found at {args.train_data}")
        print("Run: python scripts/generate_planning_data.py")
        sys.exit(1)
print("=" * 60)
print("AGORA Planner Training")
print("=" * 60)
print(f"Model: {args.model}")
print(f"Train data: {args.train_data}")
print(f"Eval data: {args.eval_data}")
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"Final model: {MODEL_DIR}")
print(f"TensorBoard: {TB_DIR}")
print(f"Epochs: {args.epochs}")
print(f"Batch size: {args.batch_size} x {args.grad_accum} accum")
print(f"LR: {args.lr}")
print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")
print(f"Max seq len: {args.max_seq_len}")
print(f"bf16: {args.bf16}")
print(f"GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(i)
mem = torch.cuda.get_device_properties(i).total_memory / 1e9
print(f" GPU {i}: {name} ({mem:.1f}GB)")
print("=" * 60)
    # -----------------------------------------------------------------------
    # Load tokenizer and model with LoRA
    # -----------------------------------------------------------------------
    from datasets import load_dataset
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTConfig, SFTTrainer

    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.model,
        trust_remote_code=True,
        padding_side="right",
    )
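    # Some base checkpoints ship without a dedicated pad token; reusing EOS is
    # the usual fallback so right-padded batches stay valid for causal LM SFT.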
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16 if args.bf16 else torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.config.use_cache = False  # Required for gradient checkpointing

    print("Applying LoRA adapter...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # -----------------------------------------------------------------------
    # Load dataset
    # -----------------------------------------------------------------------
    print("\nLoading training data...")
    dataset = load_dataset("json", data_files={
        "train": args.train_data,
        # No held-out file -> eval on the train split; eval_loss then tracks
        # training data, so treat it as a smoke test only.
        "eval": args.eval_data if Path(args.eval_data).exists() else args.train_data,
    })
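    # Assumed record layout (not validated here): each JSONL line holds either
    # a conversational "messages" list or a plain "text" field, both of which
    # SFTTrainer consumes directly, e.g.
    #   {"messages": [{"role": "user", "content": "<team state + tasks>"},
    #                 {"role": "assistant", "content": "<assignments>"}]}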
print(f"Train examples: {len(dataset['train'])}")
print(f"Eval examples: {len(dataset['eval'])}")
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
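    # Effective batch per optimizer step = batch_size x grad_accum
    # (4 x 4 = 16 with the defaults; a world-size factor would apply only
    # under DDP, which this single-process script does not use).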
    training_args = SFTConfig(
        output_dir=CHECKPOINT_DIR,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=args.warmup_ratio,
        bf16=args.bf16,
        fp16=not args.bf16,
        logging_dir=TB_DIR,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        eval_strategy="steps",
        eval_steps=args.save_steps,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=args.max_seq_len,
        max_steps=args.max_steps,
        report_to=["tensorboard"],
        seed=42,
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=args.pin_memory,
        remove_unused_columns=True,
        packing=False,
    )
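    # eval_steps is deliberately tied to save_steps so every checkpoint gets
    # an eval_loss, which load_best_model_at_end needs to pick the best one.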
    # -----------------------------------------------------------------------
    # Train
    # -----------------------------------------------------------------------
    print("\nStarting training...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["eval"],
        processing_class=tokenizer,
    )
    train_result = trainer.train()
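    # With load_best_model_at_end=True the trainer has already reloaded the
    # lowest-eval_loss checkpoint here, so the saves below export the best
    # adapter rather than the last one.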
    # Log final metrics (nan defaults keep the :.1f formats from raising if a
    # metric is missing, which f"{'N/A':.1f}" would)
    metrics = train_result.metrics
    print("\n=== Training Complete ===")
    print(f"Train loss: {metrics.get('train_loss', 'N/A')}")
    print(f"Train runtime: {metrics.get('train_runtime', float('nan')):.1f}s")
    print(f"Train samples/s: {metrics.get('train_samples_per_second', float('nan')):.1f}")

    # Save metrics
    metrics_path = f"{LOG_DIR}/training_metrics.json"
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2, default=str)
    print(f"Metrics saved to: {metrics_path}")
    # -----------------------------------------------------------------------
    # Save
    # -----------------------------------------------------------------------
    # Save LoRA adapter
    lora_path = f"{MODEL_DIR}/lora_adapter"
    print(f"\nSaving LoRA adapter to: {lora_path}")
    model.save_pretrained(lora_path)
    tokenizer.save_pretrained(lora_path)
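    # The adapter can also be used without merging, e.g.:
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(args.model)
    #   planner = PeftModel.from_pretrained(base, lora_path)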
    # Merge and save full model (loads with plain transformers, no peft needed)
    if args.merge_and_save:
        print("Merging LoRA weights into base model...")
        merged_model = model.merge_and_unload()
        merged_path = f"{MODEL_DIR}/merged"
        print(f"Saving merged model to: {merged_path}")
        merged_model.save_pretrained(merged_path)
        tokenizer.save_pretrained(merged_path)
        print("Merged model saved successfully.")
    # Save model card (base model is taken from args so the card stays
    # accurate when a different --model is trained)
    card_path = f"{MODEL_DIR}/README.md"
    with open(card_path, "w") as f:
        f.write(f"""# AGORA Planner v1

Fine-tuned multi-robot task planner for the AGORA coordination framework.

## Base Model

- {args.model}

## Training

- Method: LoRA (r={args.lora_r}, alpha={args.lora_alpha})
- Epochs: {args.epochs}
- Learning rate: {args.lr}
- Effective batch size: {args.batch_size * args.grad_accum}
- Max sequence length: {args.max_seq_len}
- Training loss: {metrics.get('train_loss', 'N/A')}

## Purpose

Task allocation for heterogeneous robot teams. Given a team state (robot
capabilities, battery levels, locations, recent history) and a set of task
requests, the model produces optimal task-to-robot assignments with reasoning.

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{MODEL_DIR}/merged")
tokenizer = AutoTokenizer.from_pretrained("{MODEL_DIR}/merged")
```
""")
print(f"\n{'=' * 60}")
print("TRAINING COMPLETE")
print(f"{'=' * 60}")
print(f"LoRA adapter: {lora_path}")
if args.merge_and_save:
print(f"Merged model: {merged_path}")
print(f"Metrics: {metrics_path}")
print(f"TensorBoard: {TB_DIR}")
print(f"Model card: {card_path}")
if __name__ == "__main__":
    main()