#!/usr/bin/env python3 """ Training script for Zenith-7B model. Fine-tunes on OpenThoughts-1.2M with custom data for code generation and EQ. """ import argparse import logging import os import sys from pathlib import Path import torch from transformers import AutoTokenizer # Add current directory to path for imports sys.path.append(str(Path(__file__).parent)) from configs.zenith_config import get_7b_config, DataConfig, TrainingConfig, TrainerConfig from data.openthoughts_processor import OpenThoughtsConfig, OpenThoughtsProcessor, QualityFilter, CurriculumSampler from models.zenith_model import ZenithForCausalLM, LoRAAdapter, QLoRAAdapter from training.trainer import train_zenith_model, Trainer from utils.checkpoint import setup_logging logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Train Zenith-7B model") parser.add_argument("--output_dir", type=str, default="./outputs/zenith-7b", help="Output directory") parser.add_argument("--data_dir", type=str, default="./data", help="Data directory") parser.add_argument("--cache_dir", type=str, default="./cache", help="Cache directory") parser.add_argument("--log_dir", type=str, default="./logs", help="Log directory") # Model parser.add_argument("--base_model", type=str, default="meta-llama/Llama-2-7b-hf", help="Base model to fine-tune") parser.add_argument("--use_lora", action="store_true", help="Use LoRA for efficient fine-tuning") parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank") parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha") parser.add_argument("--use_qlora", action="store_true", help="Use QLoRA (4-bit quantization)") # Data parser.add_argument("--openthoughts_dataset", type=str, default="open-thoughts/OpenThoughts3-1.2M", help="OpenThoughts dataset") parser.add_argument("--custom_datasets", type=str, nargs="+", default=[], help="Custom dataset paths") parser.add_argument("--max_seq_length", type=int, default=8192, help="Maximum sequence length") parser.add_argument("--train_batch_size", type=int, default=4, help="Training batch size") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Gradient accumulation steps") parser.add_argument("--effective_batch_size", type=int, default=32, help="Effective batch size (overrides gradient_accumulation if set)") # Training parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate") parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--max_steps", type=int, default=-1, help="Maximum training steps (-1 for epochs)") parser.add_argument("--warmup_steps", type=int, default=1000, help="Warmup steps") parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") parser.add_argument("--clip_grad_norm", type=float, default=1.0, help="Gradient clipping norm") # Advanced parser.add_argument("--use_curriculum", action="store_true", help="Enable curriculum learning") parser.add_argument("--use_quality_filter", action="store_true", help="Enable quality filtering") parser.add_argument("--use_augmentation", action="store_true", help="Enable data augmentation") parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"], help="Mixed precision") parser.add_argument("--seed", type=int, default=42, help="Random seed") # Logging parser.add_argument("--logging_steps", type=int, default=10, help="Logging steps") parser.add_argument("--eval_steps", type=int, default=500, help="Evaluation steps") parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint steps") parser.add_argument("--report_to", type=str, nargs="+", default=["tensorboard", "wandb"], help="Reporting platforms") # Resume parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint") return parser.parse_args() def main(): args = parse_args() # Setup logging setup_logging(log_dir=args.log_dir) logger.info("Starting Zenith-7B training") logger.info(f"Arguments: {args}") # Set seed torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) # Create output directories os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.cache_dir, exist_ok=True) # Load tokenizer logger.info(f"Loading tokenizer: {args.base_model}") tokenizer = AutoTokenizer.from_pretrained( args.base_model, cache_dir=args.cache_dir, use_fast=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load base model logger.info(f"Loading base model: {args.base_model}") model_kwargs = { "cache_dir": args.cache_dir, "torch_dtype": torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16 if args.mixed_precision == "fp16" else torch.float32, "device_map": "auto" if torch.cuda.is_available() else None, } if args.use_qlora: model_kwargs["load_in_4bit"] = True model_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16 model_kwargs["bnb_4bit_quant_type"] = "nf4" model_kwargs["bnb_4bit_use_double_quant"] = True base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs) # Apply LoRA if requested if args.use_lora or args.use_qlora: logger.info("Applying LoRA adapters...") lora_config = LoRAAdapter( r=args.lora_rank, lora_alpha=args.lora_alpha, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", ) base_model = apply_lora(base_model, lora_config) # Create Zenith model config = get_7b_config() model = ZenithForCausalLM(config, base_model=base_model) logger.info(f"Model initialized: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B trainable parameters") # Data configuration data_config = DataConfig( openthoughts_dataset=args.openthoughts_dataset, custom_datasets=args.custom_datasets, tokenizer_name=args.base_model, max_seq_length=args.max_seq_length, use_curriculum=args.use_curriculum, use_augmentation=args.use_augmentation, cache_dir=args.cache_dir, ) # Quality filter quality_filter = QualityFilter() if args.use_quality_filter else None data_config.quality_filter = quality_filter # Training configuration if args.effective_batch_size: gradient_accumulation_steps = args.effective_batch_size // args.train_batch_size else: gradient_accumulation_steps = args.gradient_accumulation_steps training_config = TrainingConfig( train_batch_size=args.train_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=args.learning_rate, num_train_epochs=args.num_train_epochs, max_steps=args.max_steps, save_steps=args.save_steps, eval_steps=args.eval_steps, logging_steps=args.logging_steps, optimizer=type('obj', (object,), { 'type': 'adamw', 'learning_rate': args.learning_rate, 'weight_decay': args.weight_decay, 'clip_grad_norm': args.clip_grad_norm, })(), scheduler=type('obj', (object,), { 'type': 'cosine', 'warmup_steps': args.warmup_steps, })(), mixed_precision=args.mixed_precision, gradient_ckpt=True, report_to=args.report_to, seed=args.seed, resume_from_checkpoint=args.resume_from_checkpoint, ) # Trainer configuration trainer_config = TrainerConfig( model_config=config, data_config=data_config, training_config=training_config, output_dir=args.output_dir, logging_dir=args.log_dir, checkpoint_dir=f"{args.output_dir}/checkpoints", gradient_accumulation_steps=gradient_accumulation_steps, use_amp=args.mixed_precision != "no", log_interval=args.logging_steps, eval_interval=args.eval_steps, save_interval=args.save_steps, resume_from_checkpoint=args.resume_from_checkpoint, ) # Load dataset logger.info("Loading OpenThoughts dataset...") openthoughts_config = OpenThoughtsConfig( dataset_name=args.openthoughts_dataset, cache_dir=args.cache_dir, quality_filter=quality_filter, use_curriculum=args.use_curriculum, use_augmentation=args.use_augmentation, max_seq_length=args.max_seq_length, tokenizer=tokenizer, ) processor = OpenThoughtsProcessor(openthoughts_config) dataset = processor.load_dataset() # Split dataset logger.info("Splitting dataset...") split_dataset = dataset.train_test_split(test_size=0.05, seed=args.seed) train_dataset = split_dataset["train"] val_dataset = split_dataset["test"] logger.info(f"Train samples: {len(train_dataset)}") logger.info(f"Val samples: {len(val_dataset)}") # Create curriculum sampler if needed if args.use_curriculum: from ..data import create_curriculum_sampler curriculum_sampler = create_curriculum_sampler( train_dataset, data_config.curriculum, current_epoch=0, seed=args.seed, ) if curriculum_sampler: # Will be used in dataloader creation pass # Train logger.info("Starting training...") trainer = train_zenith_model( model=model, tokenizer=tokenizer, config=trainer_config, train_dataset=train_dataset, val_dataset=val_dataset, ) logger.info("Training complete!") logger.info(f"Model saved to {args.output_dir}") # Save final model model.save_pretrained(f"{args.output_dir}/final") tokenizer.save_pretrained(f"{args.output_dir}/final") if __name__ == "__main__": main()