Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Basic Python imports | |
| import os | |
| import sys | |
| import json | |
| import argparse | |
| import logging | |
| from datetime import datetime | |
| import time | |
| import warnings | |
| import traceback | |
| from importlib.util import find_spec | |
| import multiprocessing | |
| import torch | |
| import random | |
| import numpy as np | |
| from tqdm import tqdm | |
# Check hardware capabilities first.
# Ask PyTorch directly instead of inspecting environment variables: the previous
# check (`os.environ.get("NVIDIA_VISIBLE_DEVICES") != ""`) evaluated to True
# whenever the variable was simply unset, so CPU-only hosts were reported as CUDA.
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"

# Set the multiprocessing start method to 'spawn' for CUDA compatibility
if CUDA_AVAILABLE:
    try:
        multiprocessing.set_start_method('spawn', force=True)
        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
    except RuntimeError:
        # Method already set, which is fine
        print("Multiprocessing start method already set")
| # Import order is important: unsloth should be imported before transformers | |
| # Check for libraries without importing them | |
| unsloth_available = find_spec("unsloth") is not None | |
| if unsloth_available: | |
| import unsloth | |
| # Import torch first, then transformers if available | |
| import torch | |
| transformers_available = find_spec("transformers") is not None | |
| if transformers_available: | |
| import transformers | |
| from transformers import AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, set_seed | |
| from torch.utils.data import DataLoader | |
| peft_available = find_spec("peft") is not None | |
| if peft_available: | |
| import peft | |
| # Only import HF datasets if available | |
| datasets_available = find_spec("datasets") is not None | |
| if datasets_available: | |
| from datasets import load_dataset | |
# Module logger: INFO and above, emitted through one stream handler that
# prefixes every record with a timestamp and the level name.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
log_handler = logging.StreamHandler()
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
# Thin logging wrapper used throughout the script for HF Space compatibility
def log_info(message):
    """Emit *message* at INFO level through the module logger.

    Standard output is flushed afterwards so that anything written with
    ``print`` shows up promptly when the output is being streamed.
    """
    logger.log(logging.INFO, message)
    sys.stdout.flush()
| # Check for BitsAndBytes | |
| try: | |
| from transformers import BitsAndBytesConfig | |
| bitsandbytes_available = True | |
| except ImportError: | |
| bitsandbytes_available = False | |
| logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.") | |
| # Check for PEFT | |
| try: | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| peft_available = True | |
| except ImportError: | |
| peft_available = False | |
| logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.") | |
def load_env_variables():
    """Populate Hugging Face credentials in ``os.environ``.

    Inside a Hugging Face Space (``SPACE_ID`` is set) the variables are taken
    from the process environment, and ``HF_USERNAME`` is derived from the
    ``owner/name`` form of ``SPACE_ID`` when it is missing.  Outside a Space a
    ``.env`` file is loaded with python-dotenv: first the one next to this
    script, otherwise ``../shared/.env``.

    Afterwards a warning is logged for each of ``HF_TOKEN``, ``HF_USERNAME``
    and ``HF_SPACE_NAME`` that is still unset, and ``HF_TOKEN`` is mirrored
    into ``HUGGING_FACE_HUB_TOKEN``.

    All messages go through the module ``logger``.  The root-level
    ``logging.info`` calls used previously were dropped at the root logger's
    default WARNING level and implicitly installed a root handler.
    """
    tracked_vars = ("HF_TOKEN", "HF_USERNAME", "HF_SPACE_NAME")
    space_id = os.environ.get("SPACE_ID")

    if space_id:
        logger.info("Running in Hugging Face Space")
        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        # If username is not set, try to extract it from SPACE_ID ("owner/name")
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from a .env file if not in a Space
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
        else:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            # Candidates in priority order; only the first existing file is loaded.
            candidates = (
                (os.path.join(script_dir, ".env"), ".env file"),
                (os.path.join(os.path.dirname(script_dir), "shared", ".env"), "shared .env file"),
            )
            for env_path, source_label in candidates:
                if not os.path.exists(env_path):
                    continue
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                for var_name in tracked_vars:
                    logger.info(f"{var_name} loaded from {source_label}: {bool(os.environ.get(var_name))}")
                break
            else:
                logger.warning("No .env file found in current or shared directory")

    # NOTE(review): the checks below are assumed to run for both branches
    # (the original indentation was not recoverable) - confirm.
    if not os.environ.get("HF_TOKEN"):
        logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Set HF_TOKEN for huggingface_hub
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
    """Load the consolidated JSON configuration file.

    Args:
        base_path: Path of the JSON configuration file itself (despite the
            name, it is opened directly rather than used as a directory).

    Returns:
        The parsed configuration, as returned by ``json.load``.

    Raises:
        Exception: Any error from opening or parsing the file is logged and
            then re-raised unchanged.
    """
    # Using a single consolidated config file
    config_file = base_path
    try:
        # Read explicitly as UTF-8 rather than the platform default encoding,
        # so non-ASCII values load the same way on every host.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        logger.info(f"Loaded configuration from {config_file}")
        return config
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
def parse_args():
    """
    Build the command line interface of the training script and parse ``sys.argv``.

    Recognised options are ``--config_file``, ``--seed`` and ``--log_level``.

    Returns:
        argparse.Namespace: The parsed command line arguments
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_table = (
        ("--config_file", {
            "type": str,
            "default": None,
            "help": "Path to the configuration file (default: transformers_config.json in script directory)",
        }),
        ("--seed", {
            "type": int,
            "default": None,
            "help": "Random seed for reproducibility (default: based on current time)",
        }),
        ("--log_level", {
            "type": str,
            "choices": ["debug", "info", "warning", "error", "critical"],
            "default": "info",
            "help": "Logging level (default: info)",
        }),
    )

    cli = argparse.ArgumentParser(description="Run training for language models")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _resolve_torch_dtype(dtype_name):
    """Map the config's ``torch_dtype`` value onto a ``torch.dtype``.

    Only ``"bfloat16"`` selects bfloat16; any other value falls back to
    float16, which is the mapping the standard loading path always used.
    """
    return torch.bfloat16 if dtype_name == "bfloat16" else torch.float16


def _configure_tokenizer(tokenizer, add_eos_token, chat_template):
    """Apply the EOS-token and chat-template settings shared by both loading paths."""
    if add_eos_token and tokenizer.eos_token is None:
        log_info("Setting EOS token")
        tokenizer.add_special_tokens({"eos_token": "</s>"})
    # Set chat template if specified
    if chat_template:
        log_info(f"Setting chat template: {chat_template}")
        if hasattr(tokenizer, "chat_template"):
            tokenizer.chat_template = chat_template
        else:
            log_info("Tokenizer does not support chat templates, using default formatting")


def _read_lora_settings(unsloth_config):
    """Read the LoRA hyper-parameters from the ``unsloth`` config section and log them."""
    settings = {
        "r": get_config_value(unsloth_config, "r", 16),
        "lora_alpha": get_config_value(unsloth_config, "alpha", 32),
        "lora_dropout": get_config_value(unsloth_config, "dropout", 0),
        "target_modules": get_config_value(
            unsloth_config, "target_modules",
            ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]),
    }
    log_info(f"Applying LoRA with r={settings['r']}, alpha={settings['lora_alpha']}, "
             f"dropout={settings['lora_dropout']}")
    return settings


def load_model_and_tokenizer(config):
    """
    Load the model and tokenizer according to the configuration.

    When Unsloth is enabled in the config and importable, the model is loaded
    through ``FastLanguageModel`` and LoRA adapters are attached with it.
    Otherwise the model is loaded with ``AutoModelForCausalLM`` and, if PEFT
    is available and the ``unsloth.enabled`` flag is set, LoRA is applied via
    ``peft.get_peft_model``.

    Args:
        config (dict): Complete configuration dictionary
    Returns:
        tuple: (model, tokenizer) - The loaded model and tokenizer, or
        ``(None, None)`` if anything fails during loading (the error and
        traceback are printed rather than raised).
    """
    # Extract model configuration
    model_config = get_config_value(config, "model", {})
    model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
    use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
    trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
    model_revision = get_config_value(config, "model_revision", "main")

    # Unsloth configuration
    unsloth_config = get_config_value(config, "unsloth", {})
    unsloth_enabled = get_config_value(unsloth_config, "enabled", True)

    # Tokenizer configuration
    tokenizer_config = get_config_value(config, "tokenizer", {})
    max_seq_length = min(
        get_config_value(tokenizer_config, "max_seq_length", 2048),
        4096  # Maximum supported by most models
    )
    add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
    chat_template = get_config_value(tokenizer_config, "chat_template", None)
    padding_side = get_config_value(tokenizer_config, "padding_side", "right")

    # Check for flash attention
    use_flash_attention = get_config_value(config, "use_flash_attention", False)
    flash_attention_available = False
    try:
        import flash_attn
        flash_attention_available = True
        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
    except ImportError:
        if use_flash_attention:
            log_info("Flash Attention requested but not available")

    log_info(f"Loading model: {model_name} (revision: {model_revision})")
    log_info(f"Max sequence length: {max_seq_length}")

    try:
        # Both loading paths now receive a real torch.dtype; the Unsloth path
        # previously forwarded the raw config string.
        torch_dtype = _resolve_torch_dtype(get_config_value(config, "torch_dtype", "bfloat16"))

        if unsloth_enabled and unsloth_available:
            log_info("Using Unsloth for 4-bit quantized model and LoRA")
            # Load using Unsloth
            from unsloth import FastLanguageModel
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=torch_dtype,
                revision=model_revision,
                trust_remote_code=trust_remote_code,
                use_flash_attention_2=use_flash_attention and flash_attention_available
            )

            # Configure tokenizer settings
            tokenizer.padding_side = padding_side
            _configure_tokenizer(tokenizer, add_eos_token, chat_template)

            # Apply LoRA
            lora_settings = _read_lora_settings(unsloth_config)
            model = FastLanguageModel.get_peft_model(
                model,
                bias="none",
                use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
                random_state=0,
                max_seq_length=max_seq_length,
                modules_to_save=None,
                **lora_settings
            )

            if use_flash_attention and flash_attention_available:
                log_info("🚀 Using Flash Attention for faster training")
            elif use_flash_attention and not flash_attention_available:
                log_info("⚠️ Flash Attention requested but not available - using standard attention")
        else:
            # Standard HuggingFace loading
            log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
            from transformers import AutoModelForCausalLM, AutoTokenizer

            # Check if flash attention should be enabled in config
            use_attn_implementation = None
            if use_flash_attention and flash_attention_available:
                use_attn_implementation = "flash_attention_2"
                log_info("🚀 Using Flash Attention for faster training")

            # Load tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                use_fast=use_fast_tokenizer,
                revision=model_revision,
                padding_side=padding_side
            )
            _configure_tokenizer(tokenizer, add_eos_token, chat_template)

            # Now load model with updated tokenizer
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                revision=model_revision,
                torch_dtype=torch_dtype,
                device_map="auto" if CUDA_AVAILABLE else None,
                attn_implementation=use_attn_implementation
            )

            # Apply PEFT/LoRA if enabled but using standard loading.
            # The LoRA switch is the same `unsloth.enabled` flag read above.
            if peft_available and unsloth_enabled:
                log_info("Applying standard PEFT/LoRA configuration")
                from peft import LoraConfig, get_peft_model
                lora_settings = _read_lora_settings(unsloth_config)
                lora_config = LoraConfig(
                    bias="none",
                    task_type="CAUSAL_LM",
                    **lora_settings
                )
                model = get_peft_model(model, lora_config)

        # Print model summary
        log_info(f"Model loaded successfully: {model.__class__.__name__}")
        if hasattr(model, "print_trainable_parameters"):
            model.print_trainable_parameters()
        else:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
        return model, tokenizer
    except Exception as e:
        # Deliberate best-effort: report the failure and let the caller check for None.
        log_info(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return None, None
def _log_dataset_preview(dataset, limit=3):
    """Log the first *limit* examples; tolerates empty or non-string message content."""
    for i in range(min(limit, len(dataset))):
        example = dataset[i]
        log_info(f"Example {i}:")
        for key, value in example.items():
            if key != "conversations":
                log_info(f" {key}: {value}")
                continue
            turns = value or []
            log_info(f" conversations: {len(turns)} messages")
            # Show first message only to avoid cluttering logs
            if not turns:
                continue
            first_msg = turns[0]
            if isinstance(first_msg, dict) and "content" in first_msg:
                # str() keeps the preview from raising on None / non-string
                # content, which would otherwise abort the whole load.
                content = str(first_msg["content"])
                log_info(f" First message: {content[:50]}..." if len(content) > 50 else f" First message: {content}")


def load_dataset_with_mapping(config):
    """
    Load dataset from Hugging Face or local files and apply necessary transformations.

    A dataset name that exists on disk is loaded with the ``json`` builder
    (all ``*.json`` files for a directory, the file itself otherwise); any
    other name is passed to ``load_dataset`` as a Hub dataset.  If the
    ``conversations`` column is missing it is derived from a ``messages``
    column, or from a ``text`` column wrapped as a single user message.

    Args:
        config (dict): Dataset configuration dictionary
    Returns:
        Dataset: The loaded and processed dataset, or ``None`` when loading or
        conversion fails (the error and traceback are printed).
    Raises:
        ValueError: If no dataset name is configured.
    """
    # Extract dataset configuration
    dataset_info = get_config_value(config, "dataset", {})
    dataset_name = get_config_value(dataset_info, "name", None)
    dataset_split = get_config_value(dataset_info, "split", "train")

    if not dataset_name:
        raise ValueError("Dataset name not specified in config")

    log_info(f"Loading dataset: {dataset_name} (split: {dataset_split})")
    try:
        # Load dataset from Hugging Face or local path
        from datasets import load_dataset

        # A relative name is already resolved against the working directory,
        # so a single existence check covers both cases.
        if os.path.exists(dataset_name):
            log_info(f"Loading dataset from local path: {dataset_name}")
            if os.path.isdir(dataset_name):
                # Directory - look for data files
                train_files = os.path.join(dataset_name, "*.json")
            else:
                # Single file
                train_files = dataset_name
            dataset = load_dataset(
                "json",
                data_files={"train": train_files},
                split=dataset_split
            )
        else:
            # Hugging Face dataset
            log_info(f"Loading dataset from Hugging Face: {dataset_name}")
            dataset = load_dataset(dataset_name, split=dataset_split)

        log_info(f"Dataset loaded with {len(dataset)} examples")

        # Check if dataset contains required fields
        required_fields = ["conversations"]
        missing_fields = [field for field in required_fields if field not in dataset.column_names]
        if missing_fields:
            log_info(f"WARNING: Dataset missing required fields: {missing_fields}")
            log_info("Attempting to map dataset structure to required format")
            # Implement conversion logic based on dataset structure
            if "messages" in dataset.column_names:
                log_info("Converting 'messages' field to 'conversations' format")
                dataset = dataset.map(
                    lambda x: {"conversations": x["messages"]},
                    remove_columns=["messages"]
                )
            elif "text" in dataset.column_names:
                log_info("Converting plain text to conversations format")
                dataset = dataset.map(
                    lambda x: {"conversations": [{"role": "user", "content": x["text"]}]},
                    remove_columns=["text"]
                )
            else:
                # Caught by the handler below, so the caller receives None.
                raise ValueError("Cannot convert dataset format - missing required fields and no conversion path available")

        # Log dataset info
        log_info(f"Dataset has {len(dataset)} examples and columns: {dataset.column_names}")
        # Show a few examples for verification
        _log_dataset_preview(dataset)
        return dataset
    except Exception as e:
        log_info(f"Error loading dataset: {str(e)}")
        traceback.print_exc()
        return None
def format_phi_chat(messages, dataset_config):
    """Format messages according to phi-4's chat template and dataset config.

    Only the conversation structure is added; message content is preserved
    exactly (it is not stripped).  The ``role`` key of a message is not
    consulted: a message whose content contains ``[RESEARCH INTRODUCTION]`` is
    treated as a system message and placed in front of everything formatted so
    far, and all other messages alternate human, assistant, human, ...

    Args:
        messages: Iterable of dicts with a ``content`` key.  Entries that are
            not dicts or lack ``content`` are skipped with a warning; entries
            with empty content are skipped silently.
        dataset_config (dict): Configuration whose
            ``data_formatting.roles`` mapping may override the
            ``system`` / ``human`` / ``assistant`` templates.  Each template
            is formatted with ``content=...``.

    Returns:
        str: The concatenated, formatted conversation.
    """
    # Get role templates from config
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n"
    })

    system_parts = []  # in encounter order; emitted reversed because each one is prepended
    turn_parts = []
    # Track the speaker explicitly.  Counting "Human:" / "Assistant:" in the
    # output misfires when message content contains those words or when custom
    # role templates use different prefixes.
    next_is_human = True

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue
        content = message.get("content", "")  # Don't strip() - preserve exact content
        # Skip empty content
        if not content:
            continue

        if "[RESEARCH INTRODUCTION]" in content:
            # System message: does not consume a human/assistant turn
            template = roles.get("system", "System: {content}\n\n")
            system_parts.append(template.format(content=content))
            continue

        # In phi-4 format, human messages come first, followed by assistant responses
        if next_is_human:
            template = roles.get("human", "Human: {content}\n\n")
        else:
            template = roles.get("assistant", "Assistant: {content}\n\n")
        turn_parts.append(template.format(content=content))
        next_is_human = not next_is_human

    return "".join(reversed(system_parts)) + "".join(turn_parts)
class SimpleDataCollator:
    """Collate conversation examples into padded causal-LM training batches.

    For every feature, the ``content`` of the first entry in ``conversations``
    is tokenized, truncated to ``max_seq_length`` and padded to that length.
    Labels are a copy of the padded input ids with padding positions set to
    -100 so that they are ignored by the loss.
    """

    def __init__(self, tokenizer, dataset_config):
        """Store the tokenizer and derive the effective maximum sequence length.

        Args:
            tokenizer: Tokenizer providing ``encode``, ``pad`` and ``model_max_length``.
            dataset_config (dict): Mapping that may define ``max_seq_length``
                (default 2048); the value is capped at ``tokenizer.model_max_length``.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
        # Running counters across all batches seen by this collator
        self.stats = {
            "processed": 0,
            "skipped": 0,
            "total_tokens": 0
        }
        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")

    def __call__(self, features):
        """Build one batch from a list of dataset examples.

        Args:
            features: Sequence of dicts; ``conversations`` is required for an
                example to be used, ``article_id`` and ``prompt_number`` are
                only used in log messages.

        Returns:
            The mapping returned by ``tokenizer.pad`` with ``input_ids``,
            ``attention_mask`` and an added ``labels`` tensor, or a dict of
            three empty lists when no example in *features* was usable.
        """
        input_id_rows = []
        attention_rows = []

        for feature in features:
            paper_id = feature.get("article_id", "unknown")
            prompt_num = feature.get("prompt_number", 0)
            conversations = feature.get("conversations", [])
            if not conversations:
                logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue

            # Get the content directly
            content = conversations[0].get("content", "")
            if not content:
                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue

            if isinstance(content, str):
                # Tokenize the content string
                input_ids = self.tokenizer.encode(content, add_special_tokens=True)
            else:
                # Content that is not a string is presumed to be token ids already
                input_ids = list(content)

            # Truncate if needed
            if len(input_ids) > self.max_seq_length:
                input_ids = input_ids[:self.max_seq_length]
                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")

            input_id_rows.append(input_ids)
            attention_rows.append([1] * len(input_ids))  # 1s for all real tokens

            self.stats["processed"] += 1
            self.stats["total_tokens"] += len(input_ids)

            # Log statistics periodically
            # NOTE(review): placed inside the loop so it fires on every 100th
            # example; the original nesting was not recoverable - confirm.
            if self.stats["processed"] % 100 == 0:
                avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
                logger.info(f"Data collation stats: processed={self.stats['processed']}, "
                            f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")

        if not input_id_rows:
            # Return empty batch if no valid examples
            return {"input_ids": [], "attention_mask": [], "labels": []}

        # Pad every sequence to max_seq_length.  Labels are deliberately not
        # handed to tokenizer.pad: it does not pad that key, which left the
        # labels ragged / shorter than the inputs.
        batch = self.tokenizer.pad(
            {"input_ids": input_id_rows, "attention_mask": attention_rows},
            padding="max_length",
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        # For causal LM, labels = input_ids, with padding masked out of the loss.
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100
        batch["labels"] = labels
        return batch
def log_gpu_memory_usage(step=None, frequency=50, clear_cache_threshold=0.9, label=None):
    """
    Report per-GPU memory statistics through ``log_info``, emptying the CUDA
    cache of any device whose allocated/reserved ratio exceeds the threshold.

    Does nothing when CUDA is unavailable, or when ``step`` is given and is
    not a multiple of ``frequency``.

    Args:
        step: Current training step (if None, logs regardless of frequency)
        frequency: How often to log when step is provided
        clear_cache_threshold: Allocated/reserved ratio (0-1) above which the
            device's cache is cleared; 0 or less disables clearing
        label: Optional label for the log message (e.g., "Initial", "Error", "Step")
    """
    if not CUDA_AVAILABLE:
        return

    # Throttle: with a step supplied, only every `frequency`-th step is logged.
    off_cycle = step is not None and frequency > 0 and step % frequency != 0
    if off_cycle:
        return

    bytes_per_mb = 1024 ** 2
    per_gpu_lines = []
    for gpu_index in range(NUM_GPUS):
        allocated = torch.cuda.memory_allocated(gpu_index) / bytes_per_mb
        reserved = torch.cuda.memory_reserved(gpu_index) / bytes_per_mb
        max_mem = torch.cuda.max_memory_allocated(gpu_index) / bytes_per_mb

        # Share of the reserved pool that is actually allocated
        if reserved > 0:
            usage_percent = (allocated / reserved) * 100
        else:
            usage_percent = 0
        per_gpu_lines.append(
            f"GPU {gpu_index}: {allocated:.1f}MB/{reserved:.1f}MB ({usage_percent:.1f}%, max: {max_mem:.1f}MB)"
        )

        over_threshold = (
            clear_cache_threshold > 0
            and reserved > 0
            and (allocated / reserved) > clear_cache_threshold
        )
        if over_threshold:
            log_info(f"Clearing CUDA cache for GPU {gpu_index} - high utilization ({allocated:.1f}/{reserved:.1f}MB)")
            with torch.cuda.device(gpu_index):
                torch.cuda.empty_cache()

    prefix = ""
    if label:
        prefix = f"{label} "
    log_info(f"{prefix}GPU Memory: {', '.join(per_gpu_lines)}")
class LoggingCallback(TrainerCallback):
    """
    Custom callback for logging training progress and metrics.

    Logs the training plan when training starts, step timing / GPU memory /
    loss at logging steps, checkpoint saves, epoch ends and a summary when
    training finishes.  Every hook catches its own exceptions and downgrades
    them to a warning so that a logging problem never interrupts training.
    """

    def __init__(self, model=None, dataset=None):
        """
        Args:
            model: Optional model, used only to report parameter counts.
            dataset: Optional training dataset, used only to report its size
                and to estimate the number of steps.
        """
        # The base class is resolved when this class statement executes, so no
        # import fallback is needed (or possible) at construction time.
        super().__init__()
        self.training_started = time.time()
        self.last_log_time = time.time()
        self.last_step_time = None   # timestamp of the previous logged step
        self.step_durations = []     # most recent intervals between logged steps (max 100)
        self.best_loss = float('inf')
        self.model = model
        self.dataset = dataset

    def on_train_begin(self, args, state, control, **kwargs):
        """Called at the beginning of training"""
        try:
            log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")

            # Log model info if available
            if self.model is not None:
                total_params = sum(p.numel() for p in self.model.parameters())
                trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
                log_info(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")

            # Log dataset info if available
            if self.dataset is not None:
                log_info(f"Dataset size: {len(self.dataset)} examples")

            # Log important training parameters for visibility.
            # On a CPU-only run NUM_GPUS is 0; count that as one device so the
            # effective batch size is not 0 and the step estimate does not
            # divide by zero.
            num_devices = max(1, NUM_GPUS)
            dataset_size = len(self.dataset or [])
            total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
            if total_batch_size:
                total_steps = int(dataset_size / total_batch_size * args.num_train_epochs)
            else:
                total_steps = 0
            log_info(f"Training plan: {dataset_size} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
            log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {num_devices} GPUs = {total_batch_size} total")

            # Log initial GPU memory usage with label
            log_gpu_memory_usage(label="Initial")
        except Exception as e:
            logger.warning(f"Error logging training begin statistics: {str(e)}")

    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of each step"""
        try:
            if state.global_step == 1 or state.global_step % args.logging_steps == 0:
                # Track step timing.  The duration is the wall-clock time since
                # the previous *logged* step, not a single optimizer step.
                current_time = time.time()
                if self.last_step_time:
                    step_duration = current_time - self.last_step_time
                    self.step_durations.append(step_duration)
                    # Keep only last 100 steps for averaging
                    if len(self.step_durations) > 100:
                        self.step_durations.pop(0)
                    avg_step_time = sum(self.step_durations) / len(self.step_durations)
                    log_info(f"Step {state.global_step}: {step_duration:.2f}s (avg: {avg_step_time:.2f}s)")
                self.last_step_time = current_time

                # Log GPU memory usage with step number
                log_gpu_memory_usage(state.global_step, args.logging_steps)

                # Log loss from the most recent log entry, if it has one
                if state.log_history:
                    latest_logs = state.log_history[-1]
                    if "loss" in latest_logs:
                        loss = latest_logs["loss"]
                        log_info(f"Step {state.global_step} loss: {loss:.4f}")
                        # Track best loss
                        if loss < self.best_loss:
                            self.best_loss = loss
                            log_info(f"New best loss: {loss:.4f}")
        except Exception as e:
            logger.warning(f"Error logging step end statistics: {str(e)}")

    def on_train_end(self, args, state, control, **kwargs):
        """Called at the end of training"""
        try:
            # Calculate training duration
            training_time = time.time() - self.training_started
            hours, remainder = divmod(training_time, 3600)
            minutes, seconds = divmod(remainder, 60)
            log_info(f"=== Training completed at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
            log_info(f"Training duration: {int(hours)}h {int(minutes)}m {int(seconds)}s")
            log_info(f"Final step: {state.global_step}")
            log_info(f"Best loss: {self.best_loss:.4f}")
            # Log final GPU memory usage
            log_gpu_memory_usage(label="Final")
        except Exception as e:
            logger.warning(f"Error logging training end statistics: {str(e)}")

    def on_save(self, args, state, control, **kwargs):
        """Called when a checkpoint is saved"""
        try:
            log_info(f"Saving checkpoint at step {state.global_step}")
        except Exception as e:
            logger.warning(f"Error in on_save: {str(e)}")

    def on_log(self, args, state, control, **kwargs):
        """Called when a log is created (intentionally a no-op)"""
        pass

    def on_evaluate(self, args, state, control, **kwargs):
        """Called when evaluation is performed (intentionally a no-op)"""
        pass

    def on_prediction_step(self, args, state, control, **kwargs):
        """Called when prediction is performed (intentionally a no-op)"""
        pass

    def on_save_model(self, args, state, control, **kwargs):
        """Log GPU memory after a model save.

        NOTE(review): ``on_save_model`` does not look like one of the Trainer's
        standard callback events - confirm that something invokes it.
        """
        try:
            # Log memory usage after saving
            log_gpu_memory_usage(label=f"Save at step {state.global_step}")
        except Exception as e:
            logger.warning(f"Error in on_save_model: {str(e)}")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Called at the end of an epoch"""
        try:
            epoch = state.epoch
            log_info(f"Completed epoch {epoch:.2f}")
            log_gpu_memory_usage(label=f"Epoch {epoch:.2f}")
        except Exception as e:
            logger.warning(f"Error in on_epoch_end: {str(e)}")

    def on_step_begin(self, args, state, control, **kwargs):
        """Called at the beginning of a step (intentionally a no-op)"""
        pass
def install_flash_attention():
    """
    Attempt to install Flash Attention for improved performance.

    Runs ``pip install flash-attn --no-build-isolation`` with the current
    interpreter and then verifies the result by importing ``flash_attn``.
    Installation is not attempted when CUDA is unavailable or when PyTorch
    reports no CUDA version.

    Returns True if installation was successful, False otherwise.  Errors are
    logged rather than raised.
    """
    log_info("Attempting to install Flash Attention...")

    # Check for CUDA before attempting installation
    if not CUDA_AVAILABLE:
        log_info("❌ Cannot install Flash Attention: CUDA not available")
        return False

    try:
        # A CUDA build of PyTorch is required; the version itself is not used further.
        cuda_version = torch.version.cuda
        if cuda_version is None:
            log_info("❌ Cannot determine CUDA version for Flash Attention installation")
            return False

        import importlib
        import subprocess

        # Use --no-build-isolation for better compatibility
        install_cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            "flash-attn",
            "--no-build-isolation"
        ]
        log_info(f"Running: {' '.join(install_cmd)}")
        result = subprocess.run(
            install_cmd,
            capture_output=True,
            text=True,
            check=False
        )

        if result.returncode != 0:
            log_info(f"❌ Flash Attention installation failed with error: {result.stderr}")
            return False

        log_info("✅ Flash Attention installed successfully!")
        # The package was installed while this interpreter was running, so the
        # import system's cached directory listings must be refreshed before
        # the new module can reliably be found.
        importlib.invalidate_caches()
        # Attempt to import to verify installation
        try:
            import flash_attn
            log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available")
            return True
        except ImportError:
            log_info("⚠️ Flash Attention installed but import failed")
            return False
    except Exception as e:
        log_info(f"❌ Error installing Flash Attention: {str(e)}")
        return False
def check_dependencies():
    """
    Check for required and optional dependencies and report the import order.

    Required packages are imported to confirm that they are present. The
    minimum versions listed below are only used to build the install hint;
    they are NOT compared against the installed versions. Optional packages
    are probed the same way, and a missing Flash Attention triggers an
    installation attempt when CUDA is available.

    Returns:
        bool: True if all required dependencies are present, False otherwise.
    """
    # Define required packages with versions and descriptions
    required_packages = {
        "unsloth": {"version": ">=2024.3", "feature": "fast 4-bit quantization and LoRA"},
        "transformers": {"version": ">=4.38.0", "feature": "core model functionality"},
        "peft": {"version": ">=0.9.0", "feature": "parameter-efficient fine-tuning"},
        "accelerate": {"version": ">=0.27.0", "feature": "multi-GPU training"},
    }
    # Optional packages that enhance functionality
    optional_packages = {
        "flash_attn": {"feature": "faster attention computation"},
        "bitsandbytes": {"feature": "quantization support"},
        "optimum": {"feature": "model optimization"},
        "wandb": {"feature": "experiment tracking"},
    }
    # Store results
    missing_packages = []
    package_versions = {}
    missing_optional = []

    # Check required packages
    log_info("Checking required dependencies...")
    for package, info in required_packages.items():
        requirement = f"{package}{info['version']}"
        feature = info["feature"]
        # The module-level find_spec probes already know about these two, so
        # skip the import attempt when they reported the package as absent.
        known_missing = (
            (package == "unsloth" and not unsloth_available)
            or (package == "peft" and not peft_available)
        )
        if known_missing:
            missing_packages.append(requirement)
            log_info(f"❌ {package} - {feature} MISSING")
            continue
        # Try to import and get version
        try:
            module = __import__(package)
        except ImportError:
            missing_packages.append(requirement)
            log_info(f"❌ {package} - {feature} MISSING")
            continue
        version = getattr(module, "__version__", "unknown")
        package_versions[package] = version
        log_info(f"✅ {package} v{version} - {feature}")

    # Check optional packages
    log_info("\nChecking optional dependencies...")
    for package, info in optional_packages.items():
        feature = info["feature"]
        try:
            __import__(package)
        except ImportError:
            log_info(f"⚠️ {package} - {feature} not available")
            missing_optional.append(package)
        except Exception as e:
            # An optional package that is installed but fails during import
            # must not abort the whole dependency check.
            log_info(f"⚠️ {package} - {feature} not available (import failed: {e})")
            missing_optional.append(package)
        else:
            log_info(f"✅ {package} - {feature} available")

    # Check import order for optimal performance
    if "transformers" in package_versions and "unsloth" in package_versions:
        try:
            # sys.modules keeps insertion order, so a module's position
            # reflects when it was first imported.
            loaded = list(sys.modules)
            if loaded.index("transformers") < loaded.index("unsloth"):
                log_info("⚠️ For optimal performance, import unsloth before transformers")
                log_info("This might cause performance issues but won't prevent training")
            else:
                log_info("✅ Import order: unsloth before transformers (optimal)")
        except ValueError as e:
            log_info(f"⚠️ Could not verify import order: {str(e)}")

    # Try to install missing optional packages
    if "flash_attn" in missing_optional and CUDA_AVAILABLE:
        log_info("\nFlash Attention is missing but would improve performance.")
        if install_flash_attention():
            missing_optional.remove("flash_attn")

    # Report missing required packages
    if missing_packages:
        log_info("\n❌ Critical dependencies missing:")
        for pkg in missing_packages:
            log_info(f" - {pkg}")
        log_info("Please install missing dependencies with:")
        # Quote every requirement: an unquoted ">=" would be interpreted by
        # the shell as an output redirection when the hint is pasted.
        quoted_requirements = " ".join(f'"{pkg}"' for pkg in missing_packages)
        log_info(f" pip install {quoted_requirements}")
        return False

    log_info("\n✅ All required dependencies satisfied!")
    return True
def get_config_value(config, path, default=None):
    """
    Look up a nested entry in a config dictionary by a dot-separated path.

    Args:
        config: The configuration dictionary (a falsy value yields ``default``)
        path: Dot-separated key path (e.g., "training.optimizer.lr")
        default: Value returned when any step of the path cannot be resolved

    Returns:
        The value stored at ``path``, or ``default`` if the path is absent
    """
    if not config:
        return default

    node = config
    for key in path.split("."):
        # Stop as soon as the walk leaves dictionaries or hits an unknown key.
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node
def update_huggingface_space():
    """Run ``update_space.py`` (located next to this file) to refresh the Space.

    Returns:
        bool: True when the script exits with status 0; False when the script
        is missing, exits with a non-zero status, or raises.
    """
    log_info("Updating Hugging Face Space...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    update_script = os.path.join(script_dir, "update_space.py")
    if not os.path.exists(update_script):
        logger.warning(f"Update space script not found at {update_script}")
        return False

    try:
        import subprocess

        # Explicitly set space_name to ensure we're targeting the right Space
        command = [sys.executable, update_script, "--force", "--space_name", "phi4training"]
        completed = subprocess.run(command, capture_output=True, text=True, check=False)
        if completed.returncode != 0:
            logger.error(f"Failed to update Hugging Face Space: {completed.stderr}")
            return False
        log_info("Hugging Face Space updated successfully!")
        log_info("Space URL: https://huggingface.co/spaces/George-API/phi4training")
        return True
    except Exception as err:
        logger.error(f"Error updating Hugging Face Space: {err!s}")
        return False
def validate_huggingface_credentials():
    """Check that the Hugging Face token works and the target Space is reachable.

    Reads ``HF_TOKEN``, ``HF_USERNAME`` and ``HF_SPACE_NAME`` from the
    environment.

    Returns:
        bool: True only when login succeeds and the Space can be queried;
        False when the token is missing, ``huggingface_hub`` is not installed,
        or any step fails.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.")
        return False

    try:
        # Import here to avoid requiring huggingface_hub if not needed
        from huggingface_hub import HfApi, login

        login(token=hf_token)
        api = HfApi()
        username = os.environ.get("HF_USERNAME", "George-API")
        space_name = os.environ.get("HF_SPACE_NAME", "phi4training")

        # whoami() doubles as the authentication check.
        user_info = api.whoami()
        logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")

        # Warn when the environment points at a Space other than the usual one.
        expected_space_id = "George-API/phi4training"
        space_id = f"{username}/{space_name}"
        if space_id == expected_space_id:
            logger.info(f"Confirmed using Space: {expected_space_id}")
        else:
            logger.warning(f"Using Space '{space_id}' instead of the expected '{expected_space_id}'")
            logger.warning("Make sure this is intentional. To use the correct Space, update your .env file.")

        # Querying the Space metadata raises when the Space cannot be accessed.
        try:
            api.space_info(repo_id=space_id)
            logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}")
            return True
        except Exception as err:
            logger.warning(f"Could not access Space {username}/{space_name}: {err!s}")
            logger.warning("Space updating may not work correctly")
            return False
    except ImportError:
        logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.")
        return False
    except Exception as err:
        logger.warning(f"Error validating Hugging Face credentials: {err!s}")
        return False
def setup_environment(args):
    """
    Prepare the run: environment variables, random seed, configuration file,
    CUDA allocator settings, and dependency checks.

    Args:
        args: Command line arguments (reads ``seed`` and ``config_file``)

    Returns:
        tuple: (transformers_config, seed) - The loaded configuration and random seed

    Raises:
        FileNotFoundError: If the configuration file does not exist
        RuntimeError: If a required dependency is missing
    """
    # Environment variables must be in place before anything else reads them.
    load_env_variables()

    # Fall back to a time-derived seed when none was given on the command line.
    seed = int(time.time()) % 10000 if args.seed is None else args.seed
    set_seed(seed)
    log_info(f"Using random seed: {seed}")

    # The default configuration file lives next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_file = args.config_file or os.path.join(script_dir, "transformers_config.json")
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"Config file not found: {config_file}")
    log_info(f"Loading configuration from {config_file}")
    transformers_config = load_configs(config_file)

    # Tune the CUDA caching allocator when a memory fraction below 1.0 is configured.
    if CUDA_AVAILABLE:
        memory_fraction = get_config_value(
            transformers_config, "hardware.system_settings.cuda_memory_fraction", 0.75
        )
        if memory_fraction < 1.0:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
            log_info("Set CUDA memory allocation limit to expandable with max_split_size_mb:128")

    # Dependency check may also install optional packages.
    if not check_dependencies():
        raise RuntimeError("Critical dependencies missing")

    # Turn Flash Attention on by default when it is importable and the
    # configuration does not already state a preference.
    try:
        import flash_attn

        log_info(f"Flash Attention will be used (version: {flash_attn.__version__})")
        if "use_flash_attention" not in transformers_config:
            transformers_config["use_flash_attention"] = True
    except ImportError:
        log_info("Flash Attention not available, will use standard attention mechanism")

    return transformers_config, seed
def _apply_tokenizer_settings(tokenizer, add_eos_token, chat_template):
    """Add a fallback EOS token if none is set and install the configured chat template."""
    if add_eos_token and tokenizer.eos_token is None:
        log_info("Setting EOS token")
        tokenizer.add_special_tokens({"eos_token": "</s>"})
    if chat_template:
        log_info(f"Setting chat template: {chat_template}")
        if hasattr(tokenizer, "chat_template"):
            tokenizer.chat_template = chat_template
        else:
            log_info("Tokenizer does not support chat templates, using default formatting")


def _read_lora_settings(unsloth_config):
    """Return (r, alpha, dropout, target_modules) read from the ``unsloth`` config section."""
    return (
        get_config_value(unsloth_config, "r", 16),
        get_config_value(unsloth_config, "alpha", 32),
        get_config_value(unsloth_config, "dropout", 0),
        get_config_value(
            unsloth_config,
            "target_modules",
            ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        ),
    )


def _fit_embeddings_to_tokenizer(model, tokenizer):
    """Grow the model's token embeddings when the tokenizer has more tokens than embedding rows."""
    embeddings = model.get_input_embeddings()
    if embeddings is None:
        return
    vocab_size = len(tokenizer)
    if vocab_size > embeddings.weight.shape[0]:
        log_info(f"Resizing token embeddings to {vocab_size} to match the tokenizer")
        model.resize_token_embeddings(vocab_size)


def setup_model_and_tokenizer(config):
    """
    Load and configure the model and tokenizer.

    Uses Unsloth when it is enabled in the config and installed; otherwise
    falls back to plain Hugging Face loading (with optional BitsAndBytes
    4-bit quantization and PEFT/LoRA). LoRA adapters are applied in both
    paths.

    Args:
        config (dict): Complete configuration dictionary

    Returns:
        tuple: (model, tokenizer) - The loaded model and tokenizer, or
        (None, None) if loading failed (the error is logged and the
        traceback printed).
    """
    # Extract model configuration
    model_config = get_config_value(config, "model", {})
    model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
    use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
    trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
    model_revision = get_config_value(config, "model_revision", "main")

    # The config stores the dtype as a string; both loaders below receive a
    # real torch dtype. Anything other than "bfloat16" maps to float16.
    dtype_name = get_config_value(config, "torch_dtype", "bfloat16")
    torch_dtype = torch.bfloat16 if dtype_name == "bfloat16" else torch.float16

    # Detect if model is already pre-quantized (includes '4bit', 'bnb', 'int4' or 'quant' in name)
    is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant'])
    if is_prequantized:
        log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.")

    # Unsloth configuration
    unsloth_config = get_config_value(config, "unsloth", {})
    unsloth_enabled = get_config_value(unsloth_config, "enabled", True)

    # Tokenizer configuration
    tokenizer_config = get_config_value(config, "tokenizer", {})
    max_seq_length = min(
        get_config_value(tokenizer_config, "max_seq_length", 2048),
        4096  # Maximum supported by most models
    )
    add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
    chat_template = get_config_value(tokenizer_config, "chat_template", None)
    padding_side = get_config_value(tokenizer_config, "padding_side", "right")

    # Check for flash attention
    use_flash_attention = get_config_value(config, "use_flash_attention", False)
    flash_attention_available = False
    try:
        import flash_attn
        flash_attention_available = True
        log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
    except ImportError:
        if use_flash_attention:
            log_info("Flash Attention requested but not available")
    # Flash Attention is only used when it is both requested and importable.
    flash_attention_active = bool(use_flash_attention and flash_attention_available)

    log_info(f"Loading model: {model_name} (revision: {model_revision})")
    log_info(f"Max sequence length: {max_seq_length}")

    try:
        if unsloth_enabled and unsloth_available:
            log_info("Using Unsloth for LoRA fine-tuning")
            if is_prequantized:
                log_info("Using pre-quantized model - no additional quantization will be applied")
            else:
                log_info("Using 4-bit quantization for efficient training")

            # Load using Unsloth
            from unsloth import FastLanguageModel
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                # Pass a torch dtype instead of the raw config string so the
                # call does not rely on Unsloth coercing strings.
                dtype=torch_dtype,
                revision=model_revision,
                trust_remote_code=trust_remote_code,
                use_flash_attention_2=flash_attention_active,
            )

            # Configure tokenizer settings
            tokenizer.padding_side = padding_side
            _apply_tokenizer_settings(tokenizer, add_eos_token, chat_template)
            # A newly added EOS token needs an embedding row before LoRA is applied.
            _fit_embeddings_to_tokenizer(model, tokenizer)

            # Apply LoRA
            lora_r, lora_alpha, lora_dropout, target_modules = _read_lora_settings(unsloth_config)
            log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
            model = FastLanguageModel.get_peft_model(
                model,
                r=lora_r,
                target_modules=target_modules,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                bias="none",
                use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
                random_state=0,
                max_seq_length=max_seq_length,
                modules_to_save=None,
            )

            if flash_attention_active:
                log_info("🚀 Using Flash Attention for faster training")
            elif use_flash_attention:
                log_info("⚠️ Flash Attention requested but not available - using standard attention")
        else:
            # Standard HuggingFace loading
            log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
            from transformers import AutoModelForCausalLM, AutoTokenizer

            # Check if flash attention should be enabled in config
            use_attn_implementation = None
            if flash_attention_active:
                use_attn_implementation = "flash_attention_2"
                log_info("🚀 Using Flash Attention for faster training")

            # Load tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                use_fast=use_fast_tokenizer,
                revision=model_revision,
                padding_side=padding_side,
            )
            _apply_tokenizer_settings(tokenizer, add_eos_token, chat_template)

            # Only apply quantization config if model is not already pre-quantized
            quantization_config = None
            if not is_prequantized and CUDA_AVAILABLE:
                try:
                    from transformers import BitsAndBytesConfig
                    log_info("Using 4-bit quantization (BitsAndBytes) for efficient training")
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                    )
                except ImportError:
                    log_info("BitsAndBytes not available - quantization disabled")

            # Load the model, then make sure it has a row for every tokenizer entry
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                revision=model_revision,
                torch_dtype=torch_dtype,
                device_map="auto" if CUDA_AVAILABLE else None,
                attn_implementation=use_attn_implementation,
                quantization_config=quantization_config,
            )
            _fit_embeddings_to_tokenizer(model, tokenizer)

            # Apply PEFT/LoRA if enabled but using standard loading
            if peft_available and unsloth_enabled:
                log_info("Applying standard PEFT/LoRA configuration")
                from peft import LoraConfig, get_peft_model
                lora_r, lora_alpha, lora_dropout, target_modules = _read_lora_settings(unsloth_config)
                log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
                lora_config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    target_modules=target_modules,
                    lora_dropout=lora_dropout,
                    bias="none",
                    task_type="CAUSAL_LM",
                )
                model = get_peft_model(model, lora_config)

        # Print model summary
        log_info(f"Model loaded successfully: {model.__class__.__name__}")
        if hasattr(model, "print_trainable_parameters"):
            model.print_trainable_parameters()
        else:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
        return model, tokenizer
    except Exception as e:
        # Failure is signalled to the caller through the (None, None) return value.
        log_info(f"Error loading model: {str(e)}")
        traceback.print_exc()
        return None, None
def setup_dataset_and_collator(config, tokenizer):
    """
    Load the training dataset and build the matching data collator.

    Args:
        config: Complete configuration dictionary (its "dataset" section is used)
        tokenizer: The tokenizer handed to the data collator

    Returns:
        tuple: (dataset, data_collator) - The loaded dataset and configured data collator

    Raises:
        ValueError: If the dataset is None, has no length, or is empty
    """
    dataset_config = get_config_value(config, "dataset", {})

    log_info("Loading dataset...")
    dataset = load_dataset_with_mapping(dataset_config)

    # Refuse to continue without a usable dataset.
    if dataset is None:
        raise ValueError("Dataset is None! Cannot proceed with training.")
    has_examples = hasattr(dataset, '__len__') and len(dataset) != 0
    if not has_examples:
        raise ValueError("Dataset is empty! Cannot proceed with training.")
    log_info(f"Dataset loaded with {len(dataset)} examples")

    collator = SimpleDataCollator(tokenizer, dataset_config)
    return dataset, collator
def create_training_arguments(config, dataset):
    """
    Create and configure training arguments for the Trainer.

    Args:
        config: Complete configuration dictionary
        dataset: The training dataset (currently not used by this function)

    Returns:
        TrainingArguments: Configured training arguments
    """
    # Extract configuration sections
    training_config = get_config_value(config, "training", {})
    hardware_config = get_config_value(config, "hardware", {})
    huggingface_config = get_config_value(config, "huggingface_hub", {})
    distributed_config = get_config_value(config, "distributed_training", {})

    # Extract key training parameters
    per_device_batch_size = get_config_value(training_config, "per_device_train_batch_size", 4)
    gradient_accumulation_steps = get_config_value(training_config, "gradient_accumulation_steps", 8)
    learning_rate = get_config_value(training_config, "learning_rate", 2e-5)
    num_train_epochs = get_config_value(training_config, "num_train_epochs", 3)

    # Extract hardware settings; the distributed section supplies the fallback worker count
    dataloader_workers = get_config_value(
        hardware_config,
        "system_settings.dataloader_num_workers",
        get_config_value(distributed_config, "dataloader_num_workers", 2),
    )
    pin_memory = get_config_value(hardware_config, "system_settings.dataloader_pin_memory", True)

    # BF16/FP16 settings - ensure only one is enabled (bf16 wins)
    use_bf16 = get_config_value(training_config, "bf16", False)
    use_fp16 = get_config_value(training_config, "fp16", False) if not use_bf16 else False

    # Configure distributed training
    fsdp_config = get_config_value(distributed_config, "fsdp_config", {})
    fsdp_enabled = get_config_value(fsdp_config, "enabled", False)
    ddp_config = get_config_value(distributed_config, "ddp_config", {})
    ddp_find_unused_parameters = get_config_value(ddp_config, "find_unused_parameters", False)

    # Set up FSDP args if enabled. TrainingArguments takes FSDP settings as
    # an option string plus an `fsdp_config` dict, not as an accelerate
    # plugin object.
    fsdp_args = {}
    if fsdp_enabled and NUM_GPUS > 1:
        # Config names such as "FULL_SHARD" correspond to the lowercase
        # option names accepted by `fsdp`.
        sharding_option = str(get_config_value(fsdp_config, "sharding_strategy", "FULL_SHARD")).lower()
        # NOTE(review): `mixed_precision` from the FSDP config section is no
        # longer forwarded; presumably the bf16/fp16 flags below govern
        # precision under Trainer - confirm.
        fsdp_args = {
            "fsdp": f"{sharding_option} auto_wrap",
            "fsdp_config": {
                "transformer_layer_cls_to_wrap": get_config_value(
                    fsdp_config,
                    "transformer_layer_cls_to_wrap",
                    ["LlamaDecoderLayer", "PhiDecoderLayer"],
                ),
            },
        }

    # Create and return training arguments
    training_args = TrainingArguments(
        output_dir=get_config_value(config, "checkpointing.output_dir", "./results"),
        overwrite_output_dir=True,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=get_config_value(training_config, "weight_decay", 0.01),
        max_grad_norm=get_config_value(training_config, "max_grad_norm", 1.0),
        warmup_ratio=get_config_value(training_config, "warmup_ratio", 0.03),
        lr_scheduler_type=get_config_value(training_config, "lr_scheduler_type", "cosine"),
        logging_steps=get_config_value(training_config, "logging_steps", 10),
        save_strategy=get_config_value(config, "checkpointing.save_strategy", "steps"),
        save_steps=get_config_value(config, "checkpointing.save_steps", 500),
        save_total_limit=get_config_value(config, "checkpointing.save_total_limit", 3),
        bf16=use_bf16,
        fp16=use_fp16,
        push_to_hub=get_config_value(huggingface_config, "push_to_hub", False),
        hub_model_id=get_config_value(huggingface_config, "hub_model_id", None),
        hub_strategy=get_config_value(huggingface_config, "hub_strategy", "every_save"),
        hub_private_repo=get_config_value(huggingface_config, "hub_private_repo", True),
        gradient_checkpointing=get_config_value(training_config, "gradient_checkpointing", True),
        dataloader_pin_memory=pin_memory,
        optim=get_config_value(training_config, "optim", "adamw_torch"),
        ddp_find_unused_parameters=ddp_find_unused_parameters,
        dataloader_drop_last=False,
        dataloader_num_workers=dataloader_workers,
        no_cuda=not CUDA_AVAILABLE,
        **fsdp_args,
    )
    log_info("Training arguments created successfully")
    return training_args
def configure_custom_dataloader(trainer, dataset, config, training_args):
    """
    Replace the trainer's train dataloader with an order-preserving one when
    the dataset config asks for sequential processing (the default).

    Args:
        trainer: The Trainer instance to configure
        dataset: The dataset to iterate over
        config: Complete configuration dictionary
        training_args: The training arguments

    Returns:
        None (modifies trainer in-place)
    """
    dataset_config = get_config_value(config, "dataset", {})

    # Nothing to do unless sequential processing is requested.
    if not get_config_value(dataset_config, "data_loading.sequential_processing", True):
        return

    log_info("Using custom sequential dataloader")
    # A sequential sampler walks the dataset in its stored order.
    ordered_sampler = torch.utils.data.SequentialSampler(dataset)
    log_info("Sequential sampler created")

    def custom_get_train_dataloader():
        """Build a DataLoader that yields the dataset in its original order."""
        worker_count = training_args.dataloader_num_workers
        uses_workers = worker_count > 0
        configured_prefetch = get_config_value(dataset_config, "data_loading.prefetch_factor", 2)
        configured_persistence = get_config_value(dataset_config, "data_loading.persistent_workers", False)
        loader_options = {
            "batch_size": training_args.per_device_train_batch_size,
            "sampler": ordered_sampler,
            "collate_fn": trainer.data_collator,
            "drop_last": get_config_value(dataset_config, "data_loading.drop_last", False),
            "num_workers": worker_count,
            "pin_memory": training_args.dataloader_pin_memory,
            # Worker-only options are neutralised when loading in the main process.
            "prefetch_factor": configured_prefetch if uses_workers else None,
            "persistent_workers": configured_persistence if uses_workers else False,
        }
        return DataLoader(dataset, **loader_options)

    # Override the default dataloader
    trainer.get_train_dataloader = custom_get_train_dataloader
def run_training(trainer, tokenizer, training_args):
    """
    Train, save the final model and tokenizer, and optionally push to the Hub.

    Args:
        trainer: Configured Trainer instance
        tokenizer: The tokenizer to save alongside the model
        training_args: Training arguments (reads ``output_dir``,
            ``push_to_hub`` and ``hub_model_id``)

    Returns:
        int: 0 on success; errors surface as exceptions rather than a return code
    """
    log_info("Starting training...")
    trainer.train()

    log_info("Training complete! Saving final model...")
    trainer.save_model()
    output_dir = training_args.output_dir
    tokenizer.save_pretrained(output_dir)

    # Upload only when the training arguments request it.
    if training_args.push_to_hub:
        hub_target = training_args.hub_model_id
        log_info(f"Pushing model to Hugging Face Hub: {hub_target}")
        trainer.push_to_hub()

    log_info("Training completed successfully!")
    return 0
def main():
    """
    Main entry point for the training script.

    Verifies that transformers is usable, parses arguments, prepares the
    environment, loads the model, dataset and training arguments, builds the
    Trainer, and runs training. Every failure is logged and converted into a
    non-zero return code.

    Returns:
        int: 0 for success, non-zero for failure
    """
    logger.info("Starting training process")
    try:
        # Verify critical imports are available
        if not transformers_available:
            log_info("❌ Error: transformers library not available. Please install it with: pip install transformers")
            return 1

        # Check for required classes
        for required_class in ["Trainer", "TrainingArguments", "TrainerCallback"]:
            if not hasattr(transformers, required_class):
                log_info(f"❌ Error: {required_class} not found in transformers. Please update transformers.")
                return 1

        # Check for potential import order issue and warn early.
        # sys.modules keeps insertion order; snapshot it once for both lookups.
        loaded_modules = list(sys.modules)
        if "transformers" in loaded_modules and "unsloth" in loaded_modules:
            if loaded_modules.index("transformers") < loaded_modules.index("unsloth"):
                log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.")
                log_info(" For optimal performance in future runs, import unsloth first.")

        # Parse command line arguments
        args = parse_args()

        # Set up environment and load configuration
        transformers_config, seed = setup_environment(args)

        # Load model and tokenizer
        try:
            model, tokenizer = setup_model_and_tokenizer(transformers_config)
        except Exception as e:
            logger.error(f"Error setting up model: {str(e)}")
            return 1
        # setup_model_and_tokenizer reports load failures by returning
        # (None, None) instead of raising, so check the result explicitly.
        if model is None or tokenizer is None:
            logger.error("Error setting up model: model or tokenizer could not be loaded")
            return 1

        # Load dataset and create data collator
        try:
            dataset, data_collator = setup_dataset_and_collator(transformers_config, tokenizer)
        except Exception as e:
            logger.error(f"Error setting up dataset: {str(e)}")
            return 1

        # Configure training arguments
        try:
            training_args = create_training_arguments(transformers_config, dataset)
        except Exception as e:
            logger.error(f"Error configuring training arguments: {str(e)}")
            return 1

        # Initialize trainer with callbacks
        log_info("Initializing Trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
            callbacks=[LoggingCallback(model=model, dataset=dataset)],
        )

        # Configure custom dataloader if needed
        try:
            configure_custom_dataloader(trainer, dataset, transformers_config, training_args)
        except Exception as e:
            logger.error(f"Error configuring custom dataloader: {str(e)}")
            return 1

        # Run training process
        try:
            return run_training(trainer, tokenizer, training_args)
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            # Log GPU memory for debugging
            log_gpu_memory_usage(label="Error")
            # Print full stack trace
            traceback.print_exc()
            return 1
    except Exception as e:
        # Top-level boundary: report anything unexpected and exit non-zero.
        logger.error(f"Error in main function: {str(e)}")
        traceback.print_exc()
        return 1
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())