#!/usr/bin/env python3
"""
Training script for Zenith-28B-p300 model.
Based on Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled with advanced reasoning capabilities.
Optimized for Tenstorrent p300a with 32k context and ring attention.
"""
import argparse
import logging
import os
from pathlib import Path
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from ..configs import get_28b_p300_config, DataConfig, TrainingConfig, TrainerConfig
from ..data import OpenThoughtsConfig, OpenThoughtsProcessor, QualityFilter, CurriculumSampler
# apply_lora is assumed to be exported by the package's models module alongside the adapter classes.
from ..models import ZenithForCausalLM, LoRAAdapter, QLoRAAdapter, apply_lora
from ..training import train_zenith_model
from ..utils import setup_logging
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train Zenith-28B-p300 model")
parser.add_argument("--output_dir", type=str, default="./outputs/zenith-28b-p300", help="Output directory")
parser.add_argument("--data_dir", type=str, default="./data", help="Data directory")
parser.add_argument("--cache_dir", type=str, default="./cache", help="Cache directory")
parser.add_argument("--log_dir", type=str, default="./logs", help="Log directory")
# Model
parser.add_argument("--base_model", type=str, default="Jackrong/Qwen3.5-27B-Claude-4.6-Opus-Reasoning-Distilled", help="Base model to fine-tune")
parser.add_argument("--use_lora", action="store_true", help="Use LoRA for efficient fine-tuning")
parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank (higher for 28B)")
parser.add_argument("--lora_alpha", type=int, default=64, help="LoRA alpha")
parser.add_argument("--use_qlora", action="store_true", help="Use QLoRA (4-bit quantization)")
# Tenstorrent p300 specific
parser.add_argument("--use_tenstorrent_optimizations", action="store_true", default=True, help="Enable p300 optimizations")
parser.add_argument("--tensor_parallel_size", type=int, default=8, help="Tensor parallelism (8 cores/chip)")
parser.add_argument("--pipeline_parallel_size", type=int, default=4, help="Pipeline parallelism (4 cores/chip)")
parser.add_argument("--use_ring_attention", action="store_true", default=True, help="Use ring attention for 32k context")
parser.add_argument("--ring_chunk_size", type=int, default=8192, help="Ring attention chunk size")
parser.add_argument("--ring_overlap", type=int, default=2048, help="Ring attention overlap")
# Data
parser.add_argument("--openthoughts_dataset", type=str, default="open-thoughts/OpenThoughts3-1.2M", help="OpenThoughts dataset")
parser.add_argument("--custom_datasets", type=str, nargs="+", default=[], help="Custom dataset paths")
parser.add_argument("--max_seq_length", type=int, default=32768, help="Maximum sequence length (32k)")
parser.add_argument("--train_batch_size", type=int, default=2, help="Training batch size (smaller for 28B)")
parser.add_argument("--gradient_accumulation_steps", type=int, default=16, help="Gradient accumulation steps")
parser.add_argument("--effective_batch_size", type=int, default=32, help="Effective batch size")
# Training
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--max_steps", type=int, default=-1, help="Maximum training steps (-1 for epochs)")
parser.add_argument("--warmup_steps", type=int, default=2000, help="Warmup steps")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--clip_grad_norm", type=float, default=1.0, help="Gradient clipping norm")
# Advanced
parser.add_argument("--use_curriculum", action="store_true", help="Enable curriculum learning")
parser.add_argument("--use_quality_filter", action="store_true", help="Enable quality filtering")
parser.add_argument("--use_augmentation", action="store_true", help="Enable data augmentation")
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"], help="Mixed precision")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# Logging
parser.add_argument("--logging_steps", type=int, default=10, help="Logging steps")
parser.add_argument("--eval_steps", type=int, default=500, help="Evaluation steps")
parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint steps")
parser.add_argument("--report_to", type=str, nargs="+", default=["tensorboard", "wandb"], help="Reporting platforms")
# Resume
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
return parser.parse_args()
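# Example invocation (illustrative; every flag shown is defined in parse_args above):
#   python train.py \
#       --output_dir ./outputs/zenith-28b-p300 \
#       --use_qlora --lora_rank 32 --lora_alpha 64 \
#       --max_seq_length 32768 --train_batch_size 2 --effective_batch_size 32 \
#       --use_curriculum --use_quality_filter --mixed_precision bf16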
def main():
args = parse_args()
# Setup logging
setup_logging(log_dir=args.log_dir)
logger.info("Starting Zenith-28B-p300 training")
logger.info(f"Base model: {args.base_model}")
logger.info(f"Arguments: {args}")
# Set seed
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
# Create output directories
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.cache_dir, exist_ok=True)
# Load tokenizer
logger.info(f"Loading tokenizer: {args.base_model}")
tokenizer = AutoTokenizer.from_pretrained(
args.base_model,
cache_dir=args.cache_dir,
use_fast=True,
)
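    # Some causal-LM tokenizers ship without a dedicated pad token; fall back to EOS
    # so padded batches can be collated.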
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load base model
logger.info(f"Loading base model: {args.base_model}")
model_kwargs = {
"cache_dir": args.cache_dir,
"torch_dtype": torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16 if args.mixed_precision == "fp16" else torch.float32,
"device_map": "auto" if torch.cuda.is_available() else None,
}
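    # bf16 is used when requested; device_map="auto" lets accelerate spread the 28B
    # weights across whatever devices are visible, and is skipped on CPU-only hosts.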
    if args.use_qlora:
        # 4-bit NF4 quantization via bitsandbytes; transformers expects these settings
        # wrapped in a BitsAndBytesConfig rather than passed as bare keyword arguments.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
try:
base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
except Exception as e:
logger.error(f"Failed to load base model: {e}")
logger.info("Attempting to load with trust_remote_code=True...")
model_kwargs["trust_remote_code"] = True
base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
# Apply LoRA if requested
if args.use_lora or args.use_qlora:
logger.info("Applying LoRA adapters...")
lora_config = LoRAAdapter(
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
)
base_model = apply_lora(base_model, lora_config)
# Create Zenith config with p300 optimizations
config = get_28b_p300_config()
config.max_seq_len = args.max_seq_length
config.use_ring_attention = args.use_ring_attention
config.ring_attention_chunk_size = args.ring_chunk_size
config.ring_attention_overlap = args.ring_overlap
config.tensor_parallel_size = args.tensor_parallel_size
config.pipeline_parallel_size = args.pipeline_parallel_size
# Create Zenith model
model = ZenithForCausalLM(config, base_model=base_model)
logger.info(f"Model initialized: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B trainable parameters")
# Data configuration
data_config = DataConfig(
openthoughts_dataset=args.openthoughts_dataset,
custom_datasets=args.custom_datasets,
tokenizer_name=args.base_model,
max_seq_length=args.max_seq_length,
use_curriculum=args.use_curriculum,
use_augmentation=args.use_augmentation,
cache_dir=args.cache_dir,
)
# Quality filter
quality_filter = QualityFilter() if args.use_quality_filter else None
data_config.quality_filter = quality_filter
# Training configuration
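    # Derive gradient accumulation from the requested effective batch size:
    # effective_batch_size // train_batch_size (32 // 2 = 16 with the defaults).
    # Note: the explicit --gradient_accumulation_steps flag is only consulted when
    # --effective_batch_size is set to 0.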
if args.effective_batch_size:
gradient_accumulation_steps = args.effective_batch_size // args.train_batch_size
else:
gradient_accumulation_steps = args.gradient_accumulation_steps
training_config = TrainingConfig(
train_batch_size=args.train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
max_steps=args.max_steps,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
        # Plain namespaces carrying the optimizer/scheduler settings expected by TrainingConfig.
        optimizer=SimpleNamespace(
            type="adamw",
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            clip_grad_norm=args.clip_grad_norm,
        ),
        scheduler=SimpleNamespace(
            type="cosine",
            warmup_steps=args.warmup_steps,
        ),
mixed_precision=args.mixed_precision,
gradient_ckpt=True,
report_to=args.report_to,
seed=args.seed,
resume_from_checkpoint=args.resume_from_checkpoint,
)
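    # gradient_ckpt=True is taken to enable gradient checkpointing, trading extra
    # recompute for the activation memory that 32k-token sequences would otherwise
    # require.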
# Trainer configuration
trainer_config = TrainerConfig(
model_config=config,
data_config=data_config,
training_config=training_config,
output_dir=args.output_dir,
logging_dir=args.log_dir,
checkpoint_dir=f"{args.output_dir}/checkpoints",
gradient_accumulation_steps=gradient_accumulation_steps,
use_amp=args.mixed_precision != "no",
log_interval=args.logging_steps,
eval_interval=args.eval_steps,
save_interval=args.save_steps,
resume_from_checkpoint=args.resume_from_checkpoint,
)
# Load dataset
logger.info("Loading OpenThoughts dataset...")
openthoughts_config = OpenThoughtsConfig(
dataset_name=args.openthoughts_dataset,
cache_dir=args.cache_dir,
quality_filter=quality_filter,
use_curriculum=args.use_curriculum,
use_augmentation=args.use_augmentation,
max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
)
processor = OpenThoughtsProcessor(openthoughts_config)
dataset = processor.load_dataset()
# Split dataset
logger.info("Splitting dataset...")
split_dataset = dataset.train_test_split(test_size=0.05, seed=args.seed)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
logger.info(f"Train samples: {len(train_dataset)}")
logger.info(f"Val samples: {len(val_dataset)}")
# Train
logger.info("Starting training...")
trainer = train_zenith_model(
model=model,
tokenizer=tokenizer,
config=trainer_config,
train_dataset=train_dataset,
val_dataset=val_dataset,
)
logger.info("Training complete!")
logger.info(f"Model saved to {args.output_dir}")
# Save final model
model.save_pretrained(f"{args.output_dir}/final")
tokenizer.save_pretrained(f"{args.output_dir}/final")
if __name__ == "__main__":
main()