Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Basic Python imports | |
| import os | |
| import sys | |
| import json | |
| import argparse | |
| import logging | |
| from datetime import datetime | |
| import time | |
| import warnings | |
| from importlib.util import find_spec | |
| import multiprocessing | |
| import torch | |
| import random | |
| import numpy as np | |
| from tqdm import tqdm | |
# Check hardware capabilities first.
# FIX: the previous check relied on CUDA_VISIBLE_DEVICES / NVIDIA_VISIBLE_DEVICES
# env vars; os.environ.get(...) returns None when unset and None != "" is True,
# so CUDA_AVAILABLE was True even on machines with no GPU. Ask torch directly.
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"

# Set the multiprocessing start method to 'spawn' for CUDA compatibility:
# forked child processes cannot re-initialize CUDA.
if CUDA_AVAILABLE:
    try:
        multiprocessing.set_start_method('spawn', force=True)
        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
    except RuntimeError:
        # Method already set, which is fine
        print("Multiprocessing start method already set")

# (torch is already imported at the top of the file; the duplicate
# `import torch` that used to live here was redundant and has been removed.)

# Configure logging early so every subsequent import/step is captured.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Set other loggers to WARNING to reduce noise and ensure our logs are visible
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("accelerate").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
# Import Unsloth before the other ML libraries (it patches transformers internals).
unsloth_available = False
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template
except ImportError:
    logger.warning("Unsloth not available. Please install with: pip install unsloth")
else:
    unsloth_available = True
    logger.info("Unsloth successfully imported")
# Now import other ML libraries
try:
    import transformers
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        Trainer,
        TrainerCallback,
        set_seed,
        BitsAndBytesConfig
    )
    logger.info(f"Transformers version: {transformers.__version__}")
except ImportError:
    # NOTE(review): execution continues even though transformers is required;
    # later references to Trainer/TrainingArguments would raise NameError.
    logger.error("Transformers not available. This is a critical dependency.")
# Check availability of libraries (find_spec probes without importing)
peft_available = find_spec("peft") is not None
if peft_available:
    import peft
    logger.info(f"PEFT version: {peft.__version__}")
else:
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
# Import datasets library after the main ML libraries
try:
    from datasets import load_dataset
    logger.info("Datasets library successfully imported")
except ImportError:
    logger.error("Datasets library not available. This is required for loading training data.")
# Clean logging helper for HF Space compatibility.
def log_info(message):
    """Emit an INFO message and flush stdout immediately.

    Hugging Face Spaces streams container output, so flushing after each
    message keeps the log visible in (near) real time.
    """
    logger.info(message)
    sys.stdout.flush()
# Check for BitsAndBytes.
# NOTE: BitsAndBytesConfig was already imported in the transformers block above;
# this re-import doubles as an availability probe that sets the feature flag.
try:
    from transformers import BitsAndBytesConfig
    bitsandbytes_available = True
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
# Check for PEFT.
# NOTE: this overwrites the find_spec-based `peft_available` computed earlier;
# here the flag additionally requires the specific LoRA helpers to import.
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
def load_env_variables():
    """Load environment variables from system, .env file, or Hugging Face Space variables.

    Resolution order:
      1. Inside a HF Space (SPACE_ID set): use the injected variables and derive
         HF_USERNAME from SPACE_ID when it is missing.
      2. Otherwise: load a .env next to this script, falling back to ../shared/.env.
    Always warns about missing HF_* variables and mirrors HF_TOKEN into
    HUGGING_FACE_HUB_TOKEN for huggingface_hub.

    Fix: all messages now use module-level ``logging.*`` calls consistently; the
    original mixed ``logging.*`` and ``logger.*``, splitting output across two loggers.
    """
    # Check if we're running in a Hugging Face Space
    if os.environ.get("SPACE_ID"):
        logging.info("Running in Hugging Face Space")
        # Log the presence of variables (without revealing values)
        logging.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logging.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        # If username is not set, try to extract from SPACE_ID (format "user/space")
        if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
            username = os.environ.get("SPACE_ID").split("/")[0]
            os.environ["HF_USERNAME"] = username
            logging.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from .env file if not in a Space
        try:
            from dotenv import load_dotenv
            # First check the current directory
            env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logging.info(f"Loaded environment variables from {env_path}")
                logging.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logging.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logging.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                # Try the shared directory as fallback
                shared_env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
                if os.path.exists(shared_env_path):
                    load_dotenv(shared_env_path)
                    logging.info(f"Loaded environment variables from {shared_env_path}")
                    logging.info(f"HF_TOKEN loaded from shared .env file: {bool(os.environ.get('HF_TOKEN'))}")
                    logging.info(f"HF_USERNAME loaded from shared .env file: {bool(os.environ.get('HF_USERNAME'))}")
                    logging.info(f"HF_SPACE_NAME loaded from shared .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
                else:
                    logging.warning(f"No .env file found in current or shared directory")
        except ImportError:
            logging.warning("python-dotenv not installed, not loading from .env file")
    # Warn about anything still missing (pushes/defaults will be affected)
    if not os.environ.get("HF_TOKEN"):
        logging.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
    if not os.environ.get("HF_USERNAME"):
        logging.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logging.warning("HF_SPACE_NAME is not set. Using default space name.")
    # Set HF_TOKEN for huggingface_hub
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
    """Read the consolidated transformers_config.json configuration.

    Args:
        base_path: Path to the JSON config file.

    Returns:
        The parsed configuration dict.

    Raises:
        Re-raises whatever open()/json.load() raised, after logging it.
    """
    config_file = base_path
    try:
        with open(config_file, "r") as handle:
            loaded = json.load(handle)
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
    logger.info(f"Loaded configuration from {config_file}")
    return loaded
def parse_args():
    """Parse command-line options for the fine-tuning script."""
    cli = argparse.ArgumentParser(
        description="Fine-tune a language model on a text dataset"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="transformers_config.json",
        help="Path to configuration file",
    )
    return cli.parse_args()
def load_model_and_tokenizer(config):
    """Load model and tokenizer with proper error handling and optimizations.

    Expects the consolidated transformers_config.json dict. Requires Unsloth
    (raises ImportError otherwise) and a model name in the config (raises
    ValueError otherwise). Returns a (model, tokenizer) pair with LoRA
    adapters applied via Unsloth's get_peft_model.
    """
    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")
        # Get model name correctly from config (flat key takes priority over nested)
        model_name = config.get("model_name") or config.get("model", {}).get("name")
        logger.info(f"Loading model: {model_name}")  # NOTE(review): logs before the None check below
        if not model_name:
            raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
        logger.info("Using Unsloth optimizations with pre-quantized model")
        # First detect if we have a GPU
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            logger.info(f"Found {gpu_count} CUDA devices")
        else:
            logger.warning("No CUDA devices detected. Training will be slow on CPU!")
            gpu_count = 0
        # Set default dtype for better numerics
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            # Use bfloat16 for Ampere or newer (compute capability >= 8.0)
            dtype = torch.bfloat16
            logger.info("Using bfloat16 precision (Ampere+ GPU)")
        elif torch.cuda.is_available():
            # Use float16 for older GPUs
            dtype = torch.float16
            logger.info("Using float16 precision (pre-Ampere GPU)")
        else:
            # CPU: dtype=None lets the loader pick its default
            dtype = None
            logger.info("Using default precision (CPU)")
        # Check for flash attention as the last dependency check
        use_flash_attention = config.get("use_flash_attention", True)
        if use_flash_attention and not find_spec("flash_attn"):
            logger.warning("flash-attn not found. Will continue without flash attention.")
            logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
            use_flash_attention = False
        # Set device map based on config or default to "auto"
        device_map = config.get("hardware", {}).get("hardware_setup", {}).get("device_map", "auto")
        # Calculate max memory settings if multiple GPUs are available
        max_memory = None
        if gpu_count > 1:
            # 15% headroom per GPU; assumes vram_per_gpu is in GiB — TODO confirm
            memory_per_gpu = config.get("hardware", {}).get("specs", {}).get("vram_per_gpu", 24)
            max_memory = {i: f"{int(memory_per_gpu * 0.85)}GiB" for i in range(gpu_count)}
            max_memory["cpu"] = "64GiB"  # Allow CPU offloading if needed
        # Load model with proper error handling for out-of-memory
        try:
            # Improved memory settings for multi-GPU setup
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
                dtype=dtype,
                device_map=device_map,
                max_memory=max_memory,
                # Don't explicitly use flash attention config here, let Unsloth handle it
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
                raise
            else:
                # Try again with CPU placement to see if it's a memory issue
                logger.warning(f"Error loading model on default device: {str(e)}")
                logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
                model, tokenizer = FastLanguageModel.from_pretrained(
                    model_name=model_name,
                    max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
                    dtype=None,
                    device_map={"": "cpu"},
                )
                logger.warning("Model loaded on CPU. Training will be very slow.")
        # Ensure model and optimizer init is on the same device
        logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
        # Apply Unsloth's training optimizations with config parameters
        unsloth_config = config.get("unsloth", {})
        # Get dropout value; if not explicitly zero, warn about performance implications
        lora_dropout = unsloth_config.get("dropout", 0.05)
        if lora_dropout > 0:
            logger.warning(f"Unsloth works best with dropout=0, but config has dropout={lora_dropout}")
            logger.warning("This will impact performance but training will still work")
            logger.warning("Consider setting dropout=0 in your config for better performance")
        # Apply optimizations (attach LoRA adapters)
        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get("target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=lora_dropout,  # Using the value from config or default
            bias="none",
            use_gradient_checkpointing=config.get("gradient_checkpointing", True) or config.get("training", {}).get("gradient_checkpointing", True),
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")
        # Set up tokenizer settings
        chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
        if chat_template:
            try:
                # Get the correct chat template for phi models
                template = get_chat_template("phi")
                # Correctly apply the template to the tokenizer (it's a string)
                if isinstance(template, str):
                    tokenizer.chat_template = template
                    logger.info("Set phi chat template (string)")
                else:
                    # If it's not a string, it's likely already a template object
                    tokenizer.chat_template = template
                    logger.info("Set phi chat template (object)")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")
                logger.warning("Chat formatting may not work correctly, but training can continue")
        # Ensure proper token settings
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise
def load_dataset_with_mapping(dataset_config):
    """Load dataset and apply appropriate column mappings.

    Args:
        dataset_config: Config dict with a nested "dataset" section containing
            "name" and optional "split" (defaults to "train").

    Returns:
        The loaded dataset, or None when loading/validation failed.
        Fix: this previously returned the int ``1`` on failure, which the
        caller's ``dataset is None`` check could not detect.
    """
    try:
        # Load dataset
        dataset_name = dataset_config.get("dataset", {}).get("name", "")
        dataset_split = dataset_config.get("dataset", {}).get("split", "train")
        if not dataset_name:
            raise ValueError("Dataset name not provided in configuration")
        logger.info(f"Loading pre-processed dataset {dataset_name}, split {dataset_split}")
        try:
            dataset = load_dataset(dataset_name, split=dataset_split)
            # Verify the dataset was actually loaded and is not None
            if dataset is None:
                raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) loaded as None - check dataset exists and is accessible")
            # Check if the dataset is empty
            if len(dataset) == 0:
                raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
            # Verify conversations field specifically
            if "conversations" not in dataset.column_names:
                raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
            # Validate conversation structure on the first example
            if len(dataset) > 0:
                sample = dataset[0]
                conversations = sample.get("conversations", [])
                if conversations:
                    first_conv = conversations[0]
                    if isinstance(first_conv, dict):
                        # Check actual fields
                        fields = list(first_conv.keys())
                        logger.info(f"Conversation fields: {fields}")
                        # Verify only 'content' field exists
                        if fields == ["content"]:
                            logger.info("Confirmed conversations have correct format with only 'content' field")
                        else:
                            logger.warning(f"Unexpected conversation fields: {fields}")
                            logger.warning("Expected only 'content' field")
            # Check a sample of conversation entries to validate structure
            logger.info("Validating conversation structure...")
            for i in range(min(5, len(dataset))):
                conv = dataset[i].get("conversations")
                if conv is None:
                    logger.warning(f"Example {i} has None as 'conversations' value")
                elif not isinstance(conv, list):
                    logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
                elif len(conv) == 0:
                    logger.warning(f"Example {i} has empty conversations list")
                else:
                    # Look at the first conversation entry
                    first_entry = conv[0]
                    if isinstance(first_entry, dict) and "content" in first_entry:
                        logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
                    else:
                        logger.warning(f"Example {i} missing 'content' key in conversation")
        except Exception as dataset_error:
            logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
            logger.error("Make sure the dataset exists and you have proper access permissions")
            logger.error("This could be due to authentication issues with your HF_TOKEN")
            raise
        return dataset
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        # Return None (not an int) so the caller's `dataset is None` check fires.
        return None
def format_phi_chat(messages, dataset_config):
    """Render a conversation into phi-4's plain-text chat layout.

    Only the surrounding role markers come from the dataset config; the
    message content itself is passed through untouched (no stripping).
    """
    # Role templates: overridable via dataset_config["data_formatting"]["roles"].
    fallback_roles = {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n"
    }
    roles = dataset_config.get("data_formatting", {}).get("roles", fallback_roles)

    rendered = ""
    for message in messages:
        # Tolerate malformed entries: skip anything without a "content" key.
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue
        content = message.get("content", "")  # exact content, no strip()
        if not content:
            continue
        if "[RESEARCH INTRODUCTION]" in content:
            # System message: prepend so it leads the transcript.
            rendered = roles.get("system", "System: {content}\n\n").format(content=content) + rendered
        else:
            # Alternate turns; equal Human/Assistant counts mean a Human turn is next.
            if rendered.count("Human:") == rendered.count("Assistant:"):
                turn = roles.get("human", "Human: {content}\n\n")
            else:
                turn = roles.get("assistant", "Assistant: {content}\n\n")
            rendered += turn.format(content=content)
    return rendered
class SimpleDataCollator:
    """Minimal collator for pre-formatted single-turn 'conversations' examples.

    Each feature must carry a "conversations" list whose FIRST entry is a dict
    with a "content" string; that string is tokenized and used both as inputs
    and as labels for causal-LM training.
    """

    def __init__(self, tokenizer, dataset_config):
        # Tokenizer is used for encoding in __call__ and for batch padding.
        self.tokenizer = tokenizer
        # Cap requested max_seq_length at the tokenizer's own model_max_length.
        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
        # Running counters, reported every 100 processed examples.
        self.stats = {
            "processed": 0,
            "skipped": 0,
            "total_tokens": 0
        }
        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")

    def __call__(self, features):
        # Initialize tensors on CPU to save GPU memory
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }
        for feature in features:
            # article_id / prompt_number are only used to label log messages.
            paper_id = feature.get("article_id", "unknown")
            prompt_num = feature.get("prompt_number", 0)
            conversations = feature.get("conversations", [])
            if not conversations:
                logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue
            # Get the content directly (only the FIRST conversation entry is used)
            content = conversations[0].get("content", "")
            if not content:
                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue
            # Process the content string by tokenizing it
            if isinstance(content, str):
                # Tokenize the content string
                input_ids = self.tokenizer.encode(content, add_special_tokens=True)
            else:
                # If somehow the content is already tokenized (not a string), use it directly
                # NOTE(review): assumes a list of token ids here — .copy() below requires it
                input_ids = content
            # Truncate if needed (overflow beyond max_seq_length is dropped, no windowing)
            if len(input_ids) > self.max_seq_length:
                input_ids = input_ids[:self.max_seq_length]
                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
            # Create attention mask (1s for all tokens)
            attention_mask = [1] * len(input_ids)
            # Add to batch
            batch["input_ids"].append(input_ids)
            batch["attention_mask"].append(attention_mask)
            batch["labels"].append(input_ids.copy())  # For causal LM, labels = input_ids
            self.stats["processed"] += 1
            self.stats["total_tokens"] += len(input_ids)
            # Log statistics periodically
            if self.stats["processed"] % 100 == 0:
                avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
                logger.info(f"Data collation stats: processed={self.stats['processed']}, "
                    f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")
        # Convert to tensors or pad sequences (PyTorch will handle)
        if batch["input_ids"]:
            # Pad every sequence up to max_seq_length using the tokenizer.
            # NOTE(review): labels appear to be padded with pad_token_id rather than
            # -100, so loss may be computed on padding unless masked downstream — confirm.
            batch = self.tokenizer.pad(
                batch,
                padding="max_length",
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            return batch
        else:
            # Return empty batch if no valid examples
            return {k: [] for k in batch}
class LoggingCallback(TrainerCallback):
    """Trainer callback that logs a concise training plan and GPU memory snapshot."""

    def __init__(self, model=None, dataset=None):
        super().__init__()
        # Wall-clock bookkeeping for later progress logging.
        self.training_started = time.time()
        self.last_log_time = time.time()
        self.last_step = 0
        self.model = model
        self.dataset = dataset

    def on_train_begin(self, args, state, control, **kwargs):
        """Log model/dataset stats and the effective batch plan at train start."""
        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
        # Log model info if available
        if self.model is not None:
            log_info(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
        # Log dataset info if available
        if self.dataset is not None:
            log_info(f"Dataset size: {len(self.dataset)} examples")
        # FIX: guard against NUM_GPUS == 0 (CPU-only run), which previously caused
        # a ZeroDivisionError in the steps estimate and a total batch size of 0.
        gpu_factor = max(1, NUM_GPUS)
        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * gpu_factor
        examples = len(self.dataset or [])
        total_steps = int(examples / (args.per_device_train_batch_size * gpu_factor * args.gradient_accumulation_steps) * args.num_train_epochs)
        log_info(f"Training plan: {examples} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {gpu_factor} GPUs = {total_batch_size} total")
        # Log memory information in compact format
        if CUDA_AVAILABLE:
            memory_info = []
            for i in range(NUM_GPUS):
                allocated = torch.cuda.memory_allocated(i) / 1024**2
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            log_info(f"Initial memory usage - {', '.join(memory_info)}")
def check_dependencies():
    """Check if all required dependencies are installed and in the correct order.

    Returns:
        True when every required package is importable (import-order problems
        are only warnings); False when any required package is missing.
    """
    missing_packages = []
    order_issues = []
    # Required packages with minimum versions (versions only feed the install hint).
    required_packages = {
        "unsloth": ">=2024.3",
        "transformers": ">=4.38.0",
        "peft": ">=0.9.0",
        "accelerate": ">=0.27.0"
    }
    # Check for required packages
    for package, version in required_packages.items():
        try:
            # unsloth/peft availability was already probed at import time.
            if package == "unsloth" and not unsloth_available:
                missing_packages.append(f"{package}{version}")
            elif package == "peft" and not peft_available:
                missing_packages.append(f"{package}{version}")
            else:
                module = __import__(package)
                logger.info(f"Using {package} version {getattr(module, '__version__', 'unknown')}")
        except ImportError:
            missing_packages.append(f"{package}{version}")
    # Check import order: unsloth patches transformers, so it should come first.
    # FIX: removed the redundant local `import sys` that shadowed the module-level import.
    try:
        modules = list(sys.modules.keys())
        if 'transformers' in modules and 'unsloth' in modules:
            try:
                transformers_idx = modules.index('transformers')
                unsloth_idx = modules.index('unsloth')
                if transformers_idx < unsloth_idx:
                    order_issues.append("For optimal performance, unsloth should be imported before transformers")
            except ValueError:
                pass
    except Exception as e:
        logger.warning(f"Could not check module import order: {str(e)}")
    # Check optional dependencies
    optional_packages = {
        "flash_attn": "Flash attention support",
        "bitsandbytes": "4-bit quantization support"
    }
    for package, feature in optional_packages.items():
        if find_spec(package):
            logger.info(f"Found {package} - {feature} enabled")
        else:
            logger.warning(f"{package} not found - {feature} will not be available")
    # Report missing required packages
    if missing_packages:
        logger.error("Critical dependencies missing:")
        for pkg in missing_packages:
            logger.error(f" - {pkg}")
        logger.error("Please install the missing dependencies with:")
        logger.error(f" pip install {' '.join(missing_packages)}")
        return False
    # Report order issues as warnings only — they do not block training.
    for issue in order_issues:
        logger.warning(issue)
    return True
def update_huggingface_space():
    """Update the Hugging Face Space with the current code.

    Runs the sibling update_space.py script as a subprocess and returns
    True on success, False when the script is missing or fails.
    """
    log_info("Updating Hugging Face Space...")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    update_script = os.path.join(script_dir, "update_space.py")
    if not os.path.exists(update_script):
        logger.warning(f"Update space script not found at {update_script}")
        return False
    try:
        import subprocess
        # Explicitly set space_name to ensure we're targeting the right Space
        proc = subprocess.run(
            [sys.executable, update_script, "--force", "--space_name", "phi4training"],
            capture_output=True, text=True, check=False
        )
        if proc.returncode != 0:
            logger.error(f"Failed to update Hugging Face Space: {proc.stderr}")
            return False
        log_info("Hugging Face Space updated successfully!")
        log_info(f"Space URL: https://huggingface.co/spaces/George-API/phi4training")
        return True
    except Exception as e:
        logger.error(f"Error updating Hugging Face Space: {str(e)}")
        return False
def validate_huggingface_credentials():
    """Validate Hugging Face credentials to ensure they work correctly.

    Logs in with HF_TOKEN, verifies the account via whoami, and checks the
    configured Space is reachable. Returns True only when the Space is
    accessible; returns False for missing token/deps or any API failure.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.")
        return False
    try:
        # Imported lazily so huggingface_hub is only required when validating.
        from huggingface_hub import HfApi, login

        login(token=token)
        api = HfApi()
        username = os.environ.get("HF_USERNAME", "George-API")
        space_name = os.environ.get("HF_SPACE_NAME", "phi4training")

        user_info = api.whoami()
        logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")

        # Warn when the configured Space differs from the expected one.
        expected_space_id = "George-API/phi4training"
        actual_space_id = f"{username}/{space_name}"
        if actual_space_id != expected_space_id:
            logger.warning(f"Using Space '{actual_space_id}' instead of the expected '{expected_space_id}'")
            logger.warning(f"Make sure this is intentional. To use the correct Space, update your .env file.")
        else:
            logger.info(f"Confirmed using Space: {expected_space_id}")

        # Check if the space exists / is reachable.
        try:
            space_id = f"{username}/{space_name}"
            api.space_info(repo_id=space_id)
            logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}")
            return True
        except Exception as e:
            logger.warning(f"Could not access Space {username}/{space_name}: {str(e)}")
            logger.warning("Space updating may not work correctly")
            return False
    except ImportError:
        logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.")
        return False
    except Exception as e:
        logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
        return False
| def main(): | |
| # Set up logging | |
| logger.info("Starting training process") | |
| try: | |
| # Check dependencies first, before any other operations | |
| if not check_dependencies(): | |
| logger.error("Aborting due to missing critical dependencies") | |
| return 1 | |
| # Parse arguments | |
| args = parse_args() | |
| # Load environment variables | |
| load_env_variables() | |
| # Validate Hugging Face credentials if we're going to use them | |
| validate_huggingface_credentials() | |
| # Load configuration | |
| try: | |
| transformers_config = load_configs(args.config) | |
| hardware_config = transformers_config.get("hardware", {}) | |
| dataset_config = transformers_config.get("dataset", {}) | |
| logger.info("Configuration loaded successfully") | |
| except Exception as e: | |
| logger.error(f"Error loading configuration: {e}") | |
| return 1 | |
| # Check if we're in distributed mode | |
| is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1 | |
| if is_distributed: | |
| local_rank = int(os.environ.get("LOCAL_RANK", "0")) | |
| log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}") | |
| else: | |
| log_info("Running in non-distributed mode (single process)") | |
| # Set random seed for reproducibility | |
| seed = transformers_config.get("seed", 42) | |
| set_seed(seed) | |
| logger.info(f"Set random seed to {seed}") | |
| # Load model and tokenizer using the consolidated config | |
| model, tokenizer = load_model_and_tokenizer(transformers_config) | |
| # Empty CUDA cache to ensure clean state | |
| if CUDA_AVAILABLE: | |
| torch.cuda.empty_cache() | |
| log_info("Cleared CUDA cache") | |
| # Setup environment variable for CUDA memory allocation | |
| if CUDA_AVAILABLE: | |
| system_settings = hardware_config.get("system_settings", {}) | |
| cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85) | |
| if cuda_memory_fraction < 1.0: | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True" | |
| log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128") | |
| try: | |
| log_info("Loading dataset...") | |
| dataset = load_dataset_with_mapping(dataset_config) | |
| # Extra validation to catch None/empty dataset issues | |
| if dataset is None: | |
| logger.error("Dataset is None! Cannot proceed with training.") | |
| return 1 | |
| if not hasattr(dataset, '__len__') or len(dataset) == 0: | |
| logger.error("Dataset is empty! Cannot proceed with training.") | |
| return 1 | |
| log_info(f"Dataset loaded with {len(dataset)} examples") | |
| # Minimal validation before proceeding | |
| if dataset is None or len(dataset) == 0: | |
| logger.error("Dataset is empty or None! Cannot proceed with training.") | |
| return 1 | |
| # Create data collator | |
| data_collator = SimpleDataCollator(tokenizer, dataset_config) | |
            # Resolve mixed-precision mode: at most one of bf16/fp16 may end
            # up enabled, with bf16 taking precedence. The hardware config is
            # authoritative; the transformers config is only a fallback.
            # use_bf16/use_fp16 are read later when building TrainingArguments
            # and the FSDP config.
            use_bf16 = False
            use_fp16 = False

            # Check hardware config first
            hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
            if hardware_precision.lower() == "bf16":
                use_bf16 = True
                log_info("Using BF16 precision from hardware config")
            elif hardware_precision.lower() == "fp16":
                use_fp16 = True
                log_info("Using FP16 precision from hardware config")
            else:
                # Fall back to transformers config: either an explicit bf16
                # flag or torch_dtype == "bfloat16" selects bf16.
                use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16"
                use_fp16 = transformers_config.get("fp16", False) and not use_bf16  # Only use fp16 if bf16 is not set

            log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
            # Per-device batch size and gradient accumulation come from the
            # transformers config; both are reused later for TrainingArguments
            # and the total-steps estimate.
            per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
            gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)

            # Multi-GPU strategy from hardware config (default: data_parallel).
            multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
            logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")

            # For multi-GPU setup, log the adjustment (no values are changed here).
            if CUDA_AVAILABLE and NUM_GPUS > 1:
                log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
            # Set up FSDP for multi-GPU training if specified and in distributed mode.
            # NOTE(review): the fsdp_config dict assembled in this branch is
            # unconditionally reassigned from transformers_config a few lines
            # further down, before it is ever read — so this whole section
            # currently only produces log output (and import side effects).
            # Confirm which of the two FSDP configurations is intended.
            fsdp_config = None
            if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
                try:
                    # Imports double as a capability probe: if they fail we
                    # fall back to DDP below.
                    from torch.distributed.fsdp import (
                        FullyShardedDataParallel as FSDP,
                        MixedPrecision,
                        BackwardPrefetch,
                        ShardingStrategy,
                        CPUOffload,
                    )
                    from torch.distributed.fsdp.wrap import (
                        transformer_auto_wrap_policy,
                        enable_wrap,
                        wrap,
                    )
                    log_info("Using FSDP for distributed training")

                    # Configure FSDP (FULL_SHARD, wrap Llama decoder layers)
                    fsdp_config = {
                        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
                        "fsdp_offload_params": False,
                        "fsdp_backward_prefetch": "BACKWARD_PRE",
                        "fsdp_min_num_params": 1e6,
                        "fsdp_sharding_strategy": 1,  # FULL_SHARD
                    }

                    if use_bf16 or use_fp16:
                        precision_type = "bf16" if use_bf16 else "fp16"
                        fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT"
                        log_info(f"FSDP using mixed precision: {precision_type}")
                except ImportError:
                    log_info("FSDP imports failed, falling back to standard DDP")
                    fsdp_config = None
            elif multi_gpu_strategy == "fsdp" and not is_distributed:
                log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
                log_info("Using DataParallel for multi-GPU training instead")
            else:
                log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
            # Dataloader tuning knobs from the hardware config; reused when
            # building TrainingArguments and the custom dataloader.
            dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
            pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)

            # Set up training arguments
            log_info("Setting up training arguments")

            # Handle FSDP configuration.
            # NOTE(review): this overwrites the fsdp_config assembled in the
            # multi_gpu_strategy branch above — from here on, only the
            # transformers_config["distributed_training"]["fsdp_config"]
            # settings matter.
            fsdp_config = transformers_config.get("distributed_training", {}).get("fsdp_config", {})
            fsdp_enabled = fsdp_config.get("enabled", False)

            # Only build TrainingArguments FSDP kwargs if explicitly enabled
            # AND we are actually running distributed on multiple GPUs.
            fsdp_args = None
            if fsdp_enabled and is_distributed and NUM_GPUS > 1:
                fsdp_args = {
                    "fsdp": ["full_shard", "auto_wrap"],
                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
                    "fsdp_offload_params": fsdp_config.get("offload_params", False),
                    "fsdp_state_dict_type": "FULL_STATE_DICT",
                    "fsdp_sharding_strategy": 1,  # FULL_SHARD
                }
                log_info("FSDP configuration enabled")
            else:
                log_info("FSDP disabled, using standard data parallel")
            # SPACE_ID is set by Hugging Face Spaces; used below to decide
            # whether to pass an explicit hub token.
            is_space = bool(os.environ.get("SPACE_ID"))

            # Create training arguments. Values come from the consolidated
            # transformers_config with conservative defaults; precision and
            # dataloader settings were resolved above.
            training_args = TrainingArguments(
                # NOTE(review): the `or` fallback only triggers when the config
                # explicitly stores a falsy output_dir (e.g. None); otherwise
                # .get's default "./results" wins — confirm precedence intent.
                output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
                num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
                per_device_train_batch_size=per_device_batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5),
                weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01),
                warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05),
                lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"),
                logging_steps=transformers_config.get("training", {}).get("logging_steps", 10),
                save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"),
                save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100),
                save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3),
                fp16=use_fp16,
                bf16=use_bf16,
                max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
                push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
                hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
                # In a Space the ambient credentials are used; elsewhere fall
                # back to the HF_TOKEN environment variable.
                hub_token=None if is_space else os.environ.get("HF_TOKEN", None),
                report_to="tensorboard",
                remove_unused_columns=False,  # Keep all columns
                gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True),
                dataloader_pin_memory=pin_memory,
                optim=transformers_config.get("training", {}).get("optim", "adamw_torch"),
                ddp_find_unused_parameters=False,  # Improve distributed training efficiency
                dataloader_drop_last=False,  # Process all examples
                dataloader_num_workers=dataloader_workers,
                no_cuda=False if CUDA_AVAILABLE else True,  # Use CUDA if available
                **({} if fsdp_args is None else fsdp_args)  # Only include FSDP args if configured
            )

            log_info("Training arguments created successfully")
| # Validate dataset before creating sampler | |
| if dataset is None: | |
| raise ValueError("Dataset is None - cannot create sampler") | |
| # Create sequential sampler to maintain original dataset order | |
| sequential_sampler = torch.utils.data.SequentialSampler(dataset) | |
| log_info("Sequential sampler created") | |
| # Initialize trainer first | |
| log_info("Initializing Trainer") | |
| trainer = Trainer( | |
| model=model, | |
| args=training_args, | |
| train_dataset=dataset, | |
| data_collator=data_collator, | |
| callbacks=[LoggingCallback(model=model, dataset=dataset)], | |
| ) | |
            # Then override the get_train_dataloader method so training sees
            # examples strictly in stored order (the data is pre-ordered by
            # prompt_number and must not be shuffled).
            def custom_get_train_dataloader():
                """Custom dataloader that preserves original dataset order.

                Closure over `dataset`, `dataset_config`, `training_args`,
                `data_collator`, and `NUM_GPUS` from the enclosing scope.
                Always uses a SequentialSampler, overriding any shuffle
                setting in the config.
                """
                log_info("Creating sequential dataloader to maintain original dataset order")

                # Safety check - make sure dataset exists and is not None
                if dataset is None:
                    raise ValueError("Dataset is None - cannot create dataloader")

                # Make sure dataset is not empty
                if len(dataset) == 0:
                    raise ValueError("Dataset is empty - cannot create dataloader")

                # Sequential sampler guarantees index order 0..len-1.
                sequential_sampler = torch.utils.data.SequentialSampler(dataset)

                # Warn (but do not honor) a shuffle=True config setting.
                data_loading_config = dataset_config.get("data_loading", {})
                shuffle_enabled = data_loading_config.get("shuffle", False)
                if shuffle_enabled:
                    log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")
                    # We enforce sequential processing regardless of config

                # Log our approach clearly
                log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")

                # Verify column order and check for the required 'conversations'
                # field; column order itself is informational only (access is
                # field-based), but a missing 'conversations' is fatal.
                expected_order = ["prompt_number", "article_id", "conversations"]
                if hasattr(dataset, 'column_names'):
                    actual_order = dataset.column_names

                    # Verify all required fields exist
                    missing_fields = [field for field in ["conversations"] if field not in actual_order]
                    if missing_fields:
                        raise ValueError(f"Dataset missing critical fields: {missing_fields}")

                    if actual_order == expected_order:
                        log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
                    else:
                        log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
                        log_info("This is handled correctly by field-based access, but noting for clarity")

                log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")

                # Spot-check the first few samples; problems are logged as
                # warnings rather than raised so training can still proceed.
                for i in range(min(3, len(dataset))):
                    sample = dataset[i]
                    if "conversations" not in sample:
                        log_info(f"WARNING: Sample {i} missing 'conversations' field")
                    elif sample["conversations"] is None:
                        log_info(f"WARNING: Sample {i} has None 'conversations' field")
                    elif not isinstance(sample["conversations"], list):
                        log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")

                # Effective batch size: per-device on CPU, per-device x GPUs
                # when CUDA is in use (DataParallel-style scaling).
                if getattr(training_args, "no_cuda", False):
                    batch_size = training_args.per_device_train_batch_size
                else:
                    batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)

                log_info(f"Using sequential sampler with batch size {batch_size}")

                # Return DataLoader with sequential sampler and extra error handling
                try:
                    return torch.utils.data.DataLoader(
                        dataset,
                        batch_size=batch_size,
                        sampler=sequential_sampler,  # Always use sequential sampler
                        collate_fn=data_collator,
                        drop_last=training_args.dataloader_drop_last,
                        num_workers=training_args.dataloader_num_workers,
                        pin_memory=training_args.dataloader_pin_memory,
                    )
                except Exception as e:
                    log_info(f"Error creating DataLoader: {str(e)}")

                    # Fallback: retry with the most conservative settings
                    # (batch of 1, no workers, no pinned memory).
                    log_info("Attempting to create DataLoader with minimal settings")
                    return torch.utils.data.DataLoader(
                        dataset,
                        batch_size=1,  # Minimal batch size
                        sampler=sequential_sampler,
                        collate_fn=data_collator,
                        num_workers=0,  # No parallel workers
                        pin_memory=False,
                    )
| # Override the get_train_dataloader method | |
| trainer.get_train_dataloader = custom_get_train_dataloader | |
| # Start training | |
| log_info("=== Starting Training ===") | |
| try: | |
| # Empty cache again right before training | |
| if CUDA_AVAILABLE: | |
| torch.cuda.empty_cache() | |
| log_info("Cleared CUDA cache before training") | |
| # Display compact training info | |
| total_steps = int((len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps)) * training_args.num_train_epochs) | |
| log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps") | |
| trainer.train() | |
| log_info("Training completed successfully!") | |
| # Save the final model | |
| log_info("Saving final model...") | |
| trainer.save_model() | |
| log_info(f"Model saved to {training_args.output_dir}") | |
| # Push to hub if enabled | |
| if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False): | |
| hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model") | |
| log_info(f"Pushing model to Hugging Face Hub as {hub_id}...") | |
| trainer.push_to_hub() | |
| log_info("Model successfully pushed to Hub") | |
| # Update the Hugging Face Space with current code | |
| if os.environ.get("HF_TOKEN") and os.environ.get("HF_USERNAME") and os.environ.get("HF_SPACE_NAME"): | |
| update_huggingface_space() | |
| return 0 | |
| except Exception as e: | |
| logger.error(f"Training failed with error: {str(e)}") | |
| # Log CUDA memory info if available in compact format | |
| if CUDA_AVAILABLE: | |
| memory_info = [] | |
| for i in range(NUM_GPUS): | |
| allocated = torch.cuda.memory_allocated(i) / 1024**2 | |
| reserved = torch.cuda.memory_reserved(i) / 1024**2 | |
| max_mem = torch.cuda.max_memory_allocated(i) / 1024**2 | |
| memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)") | |
| logger.error(f"GPU memory at failure: {', '.join(memory_info)}") | |
| raise | |
        except Exception as e:
            # Catches failures anywhere in the dataset/trainer setup above;
            # return a non-zero exit code instead of crashing.
            logger.error(f"Error in main training loop: {str(e)}")
            return 1
    except Exception as e:
        # Outermost guard for main(): any otherwise-unhandled error (including
        # those re-raised by the inner training handler) becomes exit code 1.
        logger.error(f"Error in main function: {str(e)}")
        return 1
if __name__ == "__main__":
    # Propagate main()'s return value (0 = success, 1 = failure) as the
    # process exit code.
    sys.exit(main())