# Standard library imports
import functools
import logging
import logging.config
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional

# Third-party imports
import datasets
import psutil
import torch
import torch.amp
import torch.multiprocessing as mp
import transformers
from peft import LoraConfig
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

# Local imports
from lpm_kernel.L2.memory_manager import get_memory_manager
from lpm_kernel.L2.utils import (
    create_and_prepare_model,
    formatting_prompts_func,
    create_chat_data,
    release_ollama_models_early,
)
from lpm_kernel.configs.logging import LOGGING_CONFIG, get_train_process_logger

logger = get_train_process_logger()


# Configure how tqdm displays in logs
class LogTqdm(tqdm):
    def __init__(self, *args, **kwargs):
        kwargs.setdefault("mininterval", 1.0)
        kwargs.setdefault("ascii", True)
        super().__init__(*args, **kwargs)


# Replace the default tqdm
sys.modules["tqdm"].tqdm = LogTqdm


# Debug callback for logging training progress
class DebugCallback(transformers.TrainerCallback):
    def __init__(self):
        self.total_time = 0
        self.last_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 10 == 0:
            current_time = time.time()
            # Elapsed wall time since the last report, i.e. over the last 10 steps
            step_time = current_time - self.last_time
            self.total_time += step_time
            self.last_time = current_time
            logger.info(
                f"Step {state.global_step}: last 10 steps took {step_time:.2f}s "
                f"- total training time: {self.total_time:.2f}s"
            )

    def on_epoch_end(self, args, state, control, **kwargs):
        logger.info(f"Epoch {state.epoch} completed")
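# The dataclasses below, together with trl's SFTConfig, are consumed by
# HfArgumentParser in the __main__ block: each field maps to a CLI flag of the
# same name (e.g. --model_name_or_path, --lora_r), or to a key in a single JSON
# config file passed as the script's only argument.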
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={
            "help": "Comma-separated list of target modules to apply LoRA layers to."
        },
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models."},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models."},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="float32",
        metadata={"help": "Quantization storage dtype for 4bit base models."},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type: fp4 or nf4."},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient checkpointing parameter; refer to the related docs."},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Unsloth for training."},
    )
    use_cuda: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables CUDA GPU acceleration for training and inference when available."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, appends `eos_token_id` at the end of each sample being packed."
        },
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If True, the tokenizer adds special tokens to each sample being packed."
        },
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )
    is_sequential: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is sequential."},
    )
    is_cot: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the dataset is a CoT (chain-of-thought) dataset."},
    )
    user_name: Optional[str] = field(
        default="User",
        metadata={"help": "The name of the user."},
    )
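# Rough flow of main(): free memory and release any Ollama-held VRAM, tune the
# TrainingArguments via the memory manager, build model kwargs (quantization,
# attention implementation, device placement), load the model and tokenizer,
# prepare the chat dataset and completion-only collator, then run SFTTrainer
# with memory-monitoring callbacks and save the result.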
def main(model_args, data_args, training_args):
    logger.info(f"Python version: {sys.version}")
    # Configure logging
    logging.config.dictConfig(LOGGING_CONFIG)
    logger.info("Begin training...")
    # Ensure logs are flushed immediately
    for handler in logging.getLogger().handlers:
        handler.flush()

    # Get memory manager for optimization
    memory_manager = get_memory_manager()
    memory_manager.cleanup_memory(force=True)

    # Release Ollama models if they exist to free up VRAM
    if torch.cuda.is_available() and model_args.use_cuda:
        release_ollama_models_early()

    logger.info("Initializing training with memory optimizations")
    set_seed(training_args.seed)

    # Apply PyTorch memory optimizations to training arguments
    logger.info("Applying memory optimizations to training configuration")
    training_args = memory_manager.optimize_training_args(training_args)

    # --- Accelerate optimizer state offloading logic ---
    # Offload optimizer state to CPU if VRAM is low and DeepSpeed is not in use
    vram_total = memory_manager.get_memory_info().get("vram_total_gb", 0)
    use_accelerate_offload = False
    if torch.cuda.is_available() and model_args.use_cuda and 0 < vram_total < 16:
        # Only set if not already using DeepSpeed
        if getattr(training_args, "deepspeed", None) is None:
            logger.info("Enabling Hugging Face Accelerate optimizer state offload to CPU for low-VRAM GPUs")
            accelerate_config = {
                "compute_environment": "LOCAL_MACHINE",
                "deepspeed_config": None,
                "distributed_type": "NO",
                "downcast_bf16": False,
                "fsdp_config": {},
                "main_training_function": "main",
                "mixed_precision": "no",
                "num_machines": 1,
                "num_processes": 1,
                "use_cpu": False,
                "zero3_init_flag": False,
                "offload_optimizer_device": "cpu",
                "offload_param_device": "none",
            }
            training_args.accelerate_config = accelerate_config
            use_accelerate_offload = True

    # Model loading with device_map="auto" for automatic offloading
    logger.info(f"Loading model with automatic memory management from {model_args.model_name_or_path}")

    # Create model arguments dict with automatic offloading
    model_kwargs = {
        # Don't use "auto" device_map initially to avoid meta tensor issues
        "device_map": None,
        "trust_remote_code": True,
    }

    # Configure quantization if requested
    if model_args.use_4bit_quantization:
        from transformers import BitsAndBytesConfig

        compute_dtype = getattr(torch, model_args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, model_args.bnb_4bit_quant_storage_dtype)
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=model_args.use_4bit_quantization,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
        # For 4-bit models, we can use device_map="auto"
        model_kwargs["device_map"] = "auto"
        logger.info("Using 4-bit quantization for memory efficiency")
    elif model_args.use_8bit_quantization:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=model_args.use_8bit_quantization
        )
        # For 8-bit models, we can use device_map="auto"
        model_kwargs["device_map"] = "auto"
        logger.info("Using 8-bit quantization for memory efficiency")

    # Flash attention for memory efficiency when supported
    if model_args.use_flash_attn and torch.cuda.is_available() and model_args.use_cuda:
        model_kwargs["attn_implementation"] = "flash_attention_2"
        logger.info("Using Flash Attention 2 for memory efficiency")
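    # Note: device_map="auto" (set above for quantized loads) lets accelerate
    # dispatch layers across GPU, CPU, and disk based on available memory; with
    # device_map=None the model is materialized normally and placed explicitly
    # later, which avoids the meta-tensor handling paths below.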
    # Load model with built-in memory management features
    model, peft_config, tokenizer = create_and_prepare_model(
        model_args, data_args, training_args, model_kwargs=model_kwargs
    )

    # If the model has meta tensors, materialize them properly
    if hasattr(model, "is_meta") and model.is_meta:
        logger.info("Model has meta tensors, using to_empty() to properly initialize")
        device = "cuda" if torch.cuda.is_available() and model_args.use_cuda else "cpu"
        model = model.to_empty(device=device)

    # Apply gradient checkpointing for memory efficiency
    if training_args.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        logger.info("Enabling gradient checkpointing for memory efficiency")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False

    # Cap this process's CUDA allocations at 90% of VRAM on small GPUs
    if torch.cuda.is_available() and memory_manager.get_memory_info().get("vram_total_gb", 0) < 8:
        torch.cuda.set_per_process_memory_fraction(0.9)
        logger.info("Setting memory fraction limit to avoid OOM errors")

    # Build the chat-formatted training dataset
    train_dataset = create_chat_data(
        data_args,
        tokenizer,
    )
    # Mask everything before the assistant turn so loss is computed on completions only
    response_template = "\n<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # Use DeepSpeed to handle meta tensors if available
    try:
        # Only configure DeepSpeed if meta tensors are still present
        if hasattr(model, "is_meta") and model.is_meta:
            logger.info("Model has meta tensors, checking DeepSpeed availability")
            # First verify DeepSpeed is properly installed and importable
            try:
                import deepspeed

                logger.info("DeepSpeed is available, configuring for meta tensor handling")
                # Configure ZeRO stage 3 with CPU offload for meta tensors
                training_args.deepspeed = {
                    "zero_stage": 3,
                    "offload_optimizer": {"device": "cpu"},
                    "offload_param": {"device": "cpu"},
                    "zero3_init_flag": True,
                    "zero_force_ds_cpu_optimizer": False,
                }
                logger.info("DeepSpeed configured for meta tensor handling")
            except ImportError:
                logger.warning("DeepSpeed is not available, meta tensors will be handled differently")
                # Without DeepSpeed, materialize the meta tensors directly;
                # to_empty() returns a new module, so the result must be assigned
                if torch.cuda.is_available() and model_args.use_cuda:
                    logger.info("Initializing meta tensors on GPU")
                    model = model.to_empty(device="cuda")
                else:
                    logger.info("Initializing meta tensors on CPU")
                    model = model.to_empty(device="cpu")
    except Exception as e:
        logger.warning(f"Could not configure meta tensor handling: {e}")
        logger.warning(traceback.format_exc())

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        peft_config=peft_config,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )
    # Print model details
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()
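    # The callbacks registered below subclass transformers.TrainerCallback; the
    # Trainer invokes their on_step_end/on_save/on_epoch_end hooks at the matching
    # points of the training loop, so they can react to memory pressure and log
    # timing without modifying the loop itself.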
    # Memory usage tracking callback
    class MemoryMonitorCallback(transformers.TrainerCallback):
        def __init__(self):
            self.memory_manager = get_memory_manager()

        def on_step_end(self, args, state, control, **kwargs):
            # Check memory every 5 steps
            if state.global_step % 5 == 0 and torch.cuda.is_available():
                info = self.memory_manager.get_memory_info()
                # Guard against a zero or missing total so the division cannot fail
                vram_total = info.get("vram_total_gb", 0) or 1
                vram_usage_pct = info.get("vram_used_gb", 0) / vram_total * 100
                if vram_usage_pct > 90:
                    logger.info(f"VRAM usage high ({vram_usage_pct:.1f}%), cleaning cache")
                    self.memory_manager.cleanup_memory()

        def on_save(self, args, state, control, **kwargs):
            # Free up memory before saving
            self.memory_manager.cleanup_memory(force=True)

    # Add memory monitoring
    trainer.add_callback(MemoryMonitorCallback())
    # Add existing debug callback
    trainer.add_callback(DebugCallback())

    # Resume from checkpoint if specified
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint

    # Training with automatic memory management
    try:
        logger.info("Starting training with memory-optimized configuration")
        trainer.train(resume_from_checkpoint=checkpoint)
    except Exception as e:
        logger.error(f"Error during training: {str(e)}")
        logger.error(f"Error type: {type(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise

    # Save the model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    # Clean up before saving
    memory_manager.cleanup_memory(force=True)
    trainer.save_model()
    logger.info("Training completed successfully")


# Compatibility patch: older torch exposes torch.cpu.amp.autocast, newer torch
# prefers torch.amp.autocast("cpu", ...)
def get_autocast():
    if hasattr(torch.cpu, "amp") and hasattr(torch.cpu.amp, "autocast"):
        # Old API already present; reuse it unchanged
        return torch.cpu.amp.autocast
    # New API: pin the device to CPU
    return lambda **kwargs: torch.amp.autocast("cpu", **kwargs)


# Install the compatible function where legacy callers expect it; skip the patch
# entirely if torch.cpu.amp does not exist, since there is nothing to patch onto
if hasattr(torch.cpu, "amp"):
    torch.cpu.amp.autocast = get_autocast()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a JSON file,
        # parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
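# Example invocation (the script name and all values are illustrative):
#   python train.py config.json
# where config.json supplies the dataclass fields defined above, e.g.
#   {"model_name_or_path": "models/base", "use_peft_lora": true,
#    "use_4bit_quantization": true, "output_dir": "output", ...}
# or, equivalently, pass the same settings as individual CLI flags.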