Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding=utf-8 | |
import argparse
import importlib.util
import json
import logging
import os
import sys
from datetime import datetime
| # Import Unsloth first, before other ML imports | |
| try: | |
| from unsloth import FastLanguageModel | |
| from unsloth.chat_templates import get_chat_template | |
| unsloth_available = True | |
| except ImportError: | |
| unsloth_available = False | |
| logger = logging.getLogger(__name__) | |
| logger.warning("Unsloth not available. Please install with: pip install unsloth") | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| TrainingArguments, | |
| Trainer, | |
| TrainerCallback, | |
| set_seed, | |
| BitsAndBytesConfig | |
| ) | |
# Configure logging: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Check for the bitsandbytes package itself. Importing transformers'
# BitsAndBytesConfig only proves transformers ships the config class, not that
# the quantization backend is installed. find_spec looks the package up
# without importing it (no import-time side effects).
bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
if not bitsandbytes_available:
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")

# Check for PEFT
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
def load_env_variables():
    """Populate Hugging Face related environment variables.

    Inside a Hugging Face Space (detected by a non-empty ``SPACE_ID``) the
    variables are taken from the environment as-is; if ``HF_USERNAME`` is
    missing it is derived from the part of ``SPACE_ID`` before the first ``/``.
    Outside a Space, ``shared/.env`` in the parent of this file's directory is
    loaded with python-dotenv when that package is installed.

    Side effects: may set ``HF_USERNAME`` and ``HUGGING_FACE_HUB_TOKEN`` in
    ``os.environ``. Only the presence of secrets is logged, never their values.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info("Running in Hugging Face Space")
        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        # SPACE_ID is expected to look like "<owner>/<space>"; the owner is used
        # as the username when none is configured.
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Not in a Space: try the shared .env file (optional dependency).
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
        else:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            env_path = os.path.join(os.path.dirname(script_dir), "shared", ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                logger.warning(f"No .env file found at {env_path}")

    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Mirror HF_TOKEN into the variable name huggingface_hub reads.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
def load_configs(base_path, optional_files=("hardware_config.json",)):
    """Load the JSON configuration files used by the training script.

    Args:
        base_path: Directory containing the ``*_config.json`` files.
        optional_files: File names that may be absent. A missing optional file
            is logged and skipped instead of aborting the whole load. The
            hardware config is optional by default because callers fall back
            to default hardware settings when it is absent.

    Returns:
        Dict keyed by the file name with ``_config.json`` removed
        (``"transformers"``, ``"hardware"``, ``"dataset"``), each value being
        the parsed JSON.

    Raises:
        Any error opening or parsing a required file (e.g. FileNotFoundError,
        json.JSONDecodeError), and parse errors in optional files; the error is
        logged before being re-raised.
    """
    configs = {}
    config_files = [
        "transformers_config.json",
        "hardware_config.json",
        "dataset_config.json",
    ]
    for config_file in config_files:
        config_name = config_file.replace("_config.json", "")
        file_path = os.path.join(base_path, config_file)
        try:
            # JSON is UTF-8; don't depend on the platform's locale encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                configs[config_name] = json.load(f)
        except FileNotFoundError as e:
            if config_file not in optional_files:
                logger.error(f"Error loading {config_file}: {e}")
                raise
            logger.warning(f"Optional config {config_file} not found at {file_path}; skipping")
            continue
        except Exception as e:
            logger.error(f"Error loading {config_file}: {e}")
            raise
        logger.info(f"Loaded {config_name} configuration from {file_path}")
    return configs
def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Optional list of argument strings (without the program name).
            ``None`` makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``config_dir`` (default ``"."``), the
        directory holding the JSON configuration files.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument("--config_dir", type=str, default=".", help="Directory containing configuration files")
    return parser.parse_args(argv)
def load_model_and_tokenizer(config):
    """Load the model and tokenizer through Unsloth and attach LoRA adapters.

    Args:
        config: Parsed ``transformers_config.json``. The model name is taken
            from ``config["model"]["name"]``, then ``model_name_or_path``, then
            ``model_name``. ``max_seq_length`` and ``chat_template`` are read
            from the top level first, then the ``tokenizer`` section;
            ``gradient_checkpointing`` from the top level, then ``training``.
            LoRA hyper-parameters come from the ``unsloth`` section.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the result of
        ``FastLanguageModel.get_peft_model``.

    Raises:
        ImportError: if Unsloth is not installed.
        ValueError: if no model name is configured.
        Any exception from Unsloth is logged and re-raised.
    """

    def _setting(key, section, default):
        # A top-level value wins whenever it is set (including False/0);
        # otherwise use the nested section, then the default. The old
        # ``config.get(key, default) or ...`` form never consulted the section
        # and turned an explicit False into the section's default.
        value = config.get(key)
        if value is not None:
            return value
        return (config.get(section) or {}).get(key, default)

    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")

        # Model name may live in the nested "model" section or at the top level.
        model_name = (
            (config.get("model") or {}).get("name")
            or config.get("model_name_or_path")
            or config.get("model_name")
        )
        if not model_name:
            raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
        logger.info(f"Loading model: {model_name}")
        logger.info("Using Unsloth optimizations with pre-quantized model")

        # NOTE: the config's "use_flash_attention" flag is not forwarded here;
        # the attention implementation is left to Unsloth.
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=_setting("max_seq_length", "tokenizer", 2048),
            dtype=None,  # Let Unsloth choose optimal dtype
            device_map="auto",
        )

        # Apply Unsloth's training optimizations with config parameters
        unsloth_config = config.get("unsloth") or {}
        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get(
                "target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            ),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=unsloth_config.get("dropout", 0.05),
            bias="none",
            use_gradient_checkpointing=_setting("gradient_checkpointing", "training", True),
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")

        chat_template = config.get("chat_template") or (config.get("tokenizer") or {}).get("chat_template")
        if chat_template:
            # A string names the template; any other truthy value keeps the
            # original "phi" choice.
            template_name = chat_template if isinstance(chat_template, str) else "phi"
            try:
                # get_chat_template takes the tokenizer and returns the updated
                # tokenizer; passing only the template name (as before) could
                # never succeed.
                tokenizer = get_chat_template(tokenizer, chat_template=template_name)
                logger.info(f"Set {template_name} chat template")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")

        # Padding needs a pad id; reuse EOS when the tokenizer defines none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer

    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise
def _text_to_conversations(example):
    """Build a ``conversations`` value from one row's ``text`` field.

    A list is taken to already be a list of conversation turns and is used
    as-is; any other value becomes a single user turn holding ``str(text)``.
    """
    if isinstance(example.get("text"), list):
        return {"conversations": example["text"]}
    return {
        "conversations": [
            {"role": "user", "content": str(example.get("text", ""))}
        ]
    }


def load_dataset_with_mapping(dataset_config):
    """Load the configured dataset split and normalize its columns.

    Reads ``dataset_config["dataset"]``: ``name`` (required), ``split``
    (default ``"train"``), ``column_mapping`` (``{target: source}`` renames)
    and ``processing.sort_by_id``.

    A rename is skipped (with a warning) when its target already exists as a
    different column. If ``conversations`` is missing but ``text`` exists,
    ``text`` is converted via ``_text_to_conversations``. A missing ``id``
    column is only warned about.

    Returns:
        The prepared dataset.

    Raises:
        ValueError: if no dataset name is configured. Any loading error is
        logged and re-raised.
    """
    try:
        ds_cfg = dataset_config.get("dataset") or {}
        dataset_name = ds_cfg.get("name", "")
        dataset_split = ds_cfg.get("split", "train")
        if not dataset_name:
            raise ValueError("Dataset name not provided in configuration")

        logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
        dataset = load_dataset(dataset_name, split=dataset_split)

        # Map columns if specified - with checks to avoid conflicts
        column_mapping = ds_cfg.get("column_mapping", {})
        if column_mapping:
            logger.info(f"Checking column mapping: {column_mapping}")
            safe_mappings = {}
            for target, source in column_mapping.items():
                if source not in dataset.column_names:
                    continue
                if target in dataset.column_names and target != source:
                    logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
                else:
                    safe_mappings[source] = target
            if safe_mappings:
                logger.info(f"Applying safe column mapping: {safe_mappings}")
                for source, target in safe_mappings.items():
                    if source != target:  # identity mappings need no rename
                        dataset = dataset.rename_column(source, target)

        # Verify expected columns exist (tuple keeps the check order stable).
        for col in ("id", "conversations"):
            if col in dataset.column_names:
                continue
            if col == "conversations" and "text" in dataset.column_names:
                logger.info("Converting 'text' field to 'conversations' format")
                dataset = dataset.map(_text_to_conversations)
            else:
                logger.warning(f"Expected column '{col}' not found in dataset")

        sort_by_id = (ds_cfg.get("processing") or {}).get("sort_by_id", False)
        if sort_by_id and "id" in dataset.column_names:
            logger.info("Sorting dataset by ID")
            dataset = dataset.sort("id")
            # Log the first few IDs to verify sorting
            sample_ids = [example["id"] for example in dataset.select(range(min(5, len(dataset))))]
            logger.info(f"First few IDs after sorting: {sample_ids}")

        if "conversations" in dataset.column_names:
            # Index the row, not the column: dataset["conversations"] would
            # materialize the whole column just to show its first element.
            sample_conv = dataset[0]["conversations"] if len(dataset) > 0 else []
            logger.info(f"Example conversation structure: {sample_conv}")

        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
        logger.info(f"Dataset columns: {dataset.column_names}")
        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise
def format_phi_chat(messages, dataset_config):
    """Render chat messages as one prompt string.

    Role templates come from ``dataset_config["data_formatting"]["roles"]``
    (each must contain a ``{content}`` placeholder); when that key is absent a
    ``System:`` / ``Human:`` / ``Assistant:`` layout is used.

    Rules:
      * The first dict message whose content contains
        ``[RESEARCH INTRODUCTION]`` is rendered first with the system template,
        and messages equal to it are removed from the rest.
      * ``user``/``human`` and ``assistant``/``bot`` turns are appended in order.
      * ``system`` turns are prepended to everything rendered so far.
      * Any other role is appended with the system template (with a warning).
      * Entries that are not dicts or lack ``content`` are skipped (with a
        warning).

    Returns:
        The rendered text with surrounding whitespace stripped.
    """
    default_system = "System: {content}\n\n"
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": default_system,
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    })

    def _is_research_intro(msg):
        """Return True if ``msg`` is a dict whose content holds the intro marker."""
        if not isinstance(msg, dict):
            return False
        try:
            return "[RESEARCH INTRODUCTION]" in msg.get("content", "")
        except TypeError:
            # Content such as None or a number does not support ``in``; it
            # cannot contain the marker and must not abort the whole render.
            return False

    formatted_chat = ""
    metadata = next((msg for msg in messages if _is_research_intro(msg)), None)
    if metadata:
        formatted_chat = roles.get("system", default_system).format(content=metadata["content"])
        messages = [msg for msg in messages if msg != metadata]

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        # A missing/None role falls through to the unknown-role branch instead
        # of raising AttributeError on ``.lower()``.
        role = str(message.get("role") or "").lower()
        content = message.get("content", "")

        if role in ("human", "user"):
            template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
            formatted_chat += template.format(content=content)
        elif role in ("assistant", "bot"):
            template = roles.get("assistant", "Assistant: {content}\n\n")
            formatted_chat += template.format(content=content)
        elif role == "system":
            # System messages go in front of everything rendered so far.
            template = roles.get("system", default_system)
            formatted_chat = template.format(content=content) + formatted_chat
        else:
            logger.warning(f"Unknown role '{role}' - treating as system message")
            template = roles.get("system", default_system)
            formatted_chat += template.format(content=content)

    return formatted_chat.strip()
class SimpleDataCollator:
    """Turn raw dataset rows into padded causal-LM batches.

    For each row, the ``conversations`` field (or ``text`` when absent) is
    normalized. When there is no system turn, a system message built from
    ``metadata_format`` is optionally prepended. The result is rendered with
    ``format_phi_chat``, tokenized with truncation to ``max_seq_length``, and
    right-padded to the longest sequence in the batch. Labels are a copy of
    ``input_ids`` with padded positions set to -100.

    Rows that cannot be processed are counted in ``stats["skipped"]`` and
    dropped from the batch rather than aborting training.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Fall back to id 0 when the tokenizer defines no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.prompt_counter = 0  # not read by this class
        # How many times each paper id has been collated so far. NOTE(review):
        # never reset, so across epochs the "chunk" number keeps growing rather
        # than being a fixed chunk index - confirm this is intended.
        self.paper_counters = {}
        # Number of completed log_interval multiples already reported.
        self._last_logged_bucket = 0
        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        metadata_cfg = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
        self.include_metadata = metadata_cfg.get("include_paper_id", True)
        self.include_chunk = metadata_cfg.get("include_chunk_number", True)  # not read by this class
        self.metadata_format = metadata_cfg.get("metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}")
        self.roles = dataset_config.get("data_formatting", {}).get("roles", {})  # not read by this class
        logger.info(f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}")

    def normalize_conversation(self, conversation):
        """Coerce a raw conversation into a list of ``{"role", "content"}`` dicts.

        * A single dict-like object (has ``items``) is wrapped in a list; any
          other non-list input yields ``[]``.
        * Falsy turns are dropped; bare strings become user turns.
        * Dict-like (non-dict) turns keep only their non-None ``role`` and
          ``content``.
        * Turns missing either key, or whose role is not a string, are skipped
          with a warning.
        * Roles are lower-cased, with ``human`` -> ``user`` and ``bot`` ->
          ``assistant``; content is converted to ``str``.
        """
        if not isinstance(conversation, list):
            logger.warning(f"Conversation is not a list: {type(conversation)}")
            if not hasattr(conversation, "items"):
                return []
            conversation = [conversation]

        normalized = []
        for turn in conversation:
            if not turn:
                continue
            if isinstance(turn, str):
                normalized.append({"role": "user", "content": turn})
                continue
            if not isinstance(turn, dict) and hasattr(turn, "get"):
                turn = {k: turn.get(k) for k in ("role", "content") if turn.get(k) is not None}
            if not isinstance(turn, dict) or "role" not in turn or "content" not in turn:
                logger.warning(f"Skipping malformatted conversation turn: {turn}")
                continue

            role = turn.get("role", "")
            if not isinstance(role, str):
                # Previously ``.lower()`` raised here and the whole example was lost.
                logger.warning(f"Skipping conversation turn with non-string role: {turn}")
                continue
            role = role.lower()
            if role in ("user", "human"):
                role = "user"
            elif role in ("assistant", "bot"):
                role = "assistant"

            normalized.append({"role": role, "content": str(turn.get("content", ""))})
        return normalized

    def __call__(self, features):
        """Collate dataset rows into ``input_ids``/``attention_mask``/``labels`` tensors.

        Each tensor is ``torch.long`` with shape ``(rows_kept, longest_row)``.
        If every row is skipped, a ``(1, 1)`` placeholder is returned whose
        label is -100, like padding.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        validation_cfg = self.dataset_config.get("validation", {})
        log_samples = validation_cfg.get("log_samples", 3)

        for example in features:
            try:
                paper_id = example.get("id", "")
                # The conversation may live under 'conversations' or 'text'.
                conversation = self.normalize_conversation(
                    example.get("conversations", example.get("text", []))
                )
                if not conversation:
                    self.stats["skipped"] += 1
                    continue

                self.paper_counters[paper_id] = self.paper_counters.get(paper_id, 0) + 1
                chunk_number = self.paper_counters[paper_id]

                if self.include_metadata:
                    metadata_content = self.metadata_format.format(
                        paper_id=paper_id,
                        chunk_number=chunk_number,
                    )
                    # Only inject the metadata turn when the row has no system turn.
                    if not any(msg.get("role") == "system" for msg in conversation):
                        conversation = [{"role": "system", "content": metadata_content}] + conversation

                formatted_content = format_phi_chat(conversation, self.dataset_config)
                inputs = self.tokenizer(
                    formatted_content,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.max_seq_length,
                    return_tensors=None,
                )
                input_ids = inputs["input_ids"]
                if len(input_ids) == 0:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(input_ids)
                batch["attention_mask"].append(inputs["attention_mask"])
                # For causal language modeling, labels are the same as inputs.
                batch["labels"].append(input_ids.copy())
                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(input_ids)

                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']} format:")
                    logger.info(f"Paper ID: {paper_id} | Chunk: {chunk_number}")
                    logger.info(f"Token count: {len(input_ids)}")
                    logger.info(f"Content preview:\n{formatted_content[:500]}...")
                    logger.info(f"Conversation structure: {conversation[:2]}...")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example: {str(example)[:200]}...")
                self.stats["skipped"] += 1

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                # Same -100 label as padding below; 0 is a real token id.
                "labels": torch.full((1, 1), -100, dtype=torch.long),
            }

        # Right-pad every row to the longest one in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            pad = max_length - len(ids)
            if pad > 0:
                ids.extend([self.pad_token_id] * pad)
                mask.extend([0] * pad)
                labels.extend([-100] * pad)

        batch = {k: torch.tensor(v) for k, v in batch.items()}

        # Report stats once per log_interval examples. An exact
        # ``processed % log_interval == 0`` test misses the boundary whenever a
        # batch adds several examples and steps over the multiple; a
        # non-positive interval disables the report instead of dividing by zero.
        log_interval = validation_cfg.get("log_interval", 100)
        processed = self.stats["processed"]
        if log_interval and log_interval > 0 and processed > 0:
            bucket = processed // log_interval
            if bucket > self._last_logged_bucket:
                self._last_logged_bucket = bucket
                logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                            f"skipped={self.stats['skipped']}, "
                            f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
                            f"unique_papers={len(self.paper_counters)}")

        return batch
def check_dependencies():
    """Report whether every package the script cannot run without is present.

    Unsloth and PEFT are required; their availability comes from the
    module-level import probes. flash-attn is optional, so its presence is
    only logged and never counts as missing.

    Returns:
        True if all required packages are available; otherwise False, after
        logging which requirements are missing.
    """
    required = (
        (unsloth_available, "unsloth>=2024.3"),
        (peft_available, "peft>=0.9.0"),
    )
    missing_packages = [spec for available, spec in required if not available]

    # Optional accelerator: it can be hard to build, so absence is a warning only.
    try:
        import flash_attn
    except ImportError:
        logger.warning("flash-attn not found. Training will work but may be slower.")
    else:
        logger.info("flash-attn found. Flash attention will be used for faster training.")

    if not missing_packages:
        return True

    logger.error("Critical dependencies missing:")
    for pkg in missing_packages:
        logger.error(f" - {pkg}")
    logger.error("Please ensure the space has these packages in requirements.txt")
    return False
def _training_setting(model_config, key, default=None):
    """Return a training option, preferring the top level over the ``training`` section.

    A top-level value is used whenever it is not None; otherwise the value
    comes from ``model_config["training"]``, then ``default``. Reading only
    top-level keys silently ignored options placed under ``training``.
    """
    value = model_config.get(key)
    if value is not None:
        return value
    return (model_config.get("training") or {}).get(key, default)


def _apply_hardware_overrides(model_config, hardware_config):
    """Copy batch-size, accumulation and checkpointing overrides into ``model_config``.

    Values come from ``hardware_config["training_optimizations"]``. They are
    written at the top level of ``model_config``, which is where the training
    arguments and ``load_model_and_tokenizer`` look first. They are also
    mirrored into the ``training`` section when that section exists.
    """
    training_opts = hardware_config.get("training_optimizations", {})
    training_section = model_config.get("training")

    def _set(key, value):
        model_config[key] = value
        if isinstance(training_section, dict):
            training_section[key] = value

    per_device_batch_size = training_opts.get("per_device_batch_size")
    if per_device_batch_size:
        _set("per_device_train_batch_size", per_device_batch_size)
        logger.info(f"Applied hardware-specific batch size: {per_device_batch_size}")

    gradient_accumulation = training_opts.get("gradient_accumulation_steps")
    if gradient_accumulation:
        _set("gradient_accumulation_steps", gradient_accumulation)
        logger.info(f"Applied hardware-specific gradient accumulation: {gradient_accumulation}")

    memory_opts = training_opts.get("memory_optimizations", {})
    if memory_opts.get("use_gradient_checkpointing") is not None:
        _set("gradient_checkpointing", memory_opts["use_gradient_checkpointing"])


def _find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory, or None.

    Entries whose suffix is not an integer, and entries that are not
    directories, are ignored instead of crashing the lookup.
    """
    if not os.path.isdir(output_dir):
        return None
    by_step = {}
    for name in os.listdir(output_dir):
        prefix, sep, step = name.partition("-")
        if prefix == "checkpoint" and sep and step.isdigit() and os.path.isdir(os.path.join(output_dir, name)):
            by_step[int(step)] = name
    if not by_step:
        return None
    return os.path.join(output_dir, by_step[max(by_step)])


def main():
    """Run the full fine-tuning pipeline.

    Steps: parse CLI args, verify required packages, load environment
    variables and the JSON configs, and apply hardware overrides. Then load
    the model/tokenizer and the dataset, and train with a Trainer whose
    dataloader iterates the dataset in order (no shuffling). Training resumes
    from the newest ``checkpoint-<step>`` directory in ``output_dir`` if one
    exists. While training runs, a ``TRAINING_IN_PROGRESS.lock`` file sits
    next to this script.

    Returns:
        0 on success; 1 when a step fails (the failure is logged).
    """
    logger.info("Starting training process")
    args = parse_args()

    if not check_dependencies():
        logger.error("Aborting due to missing critical dependencies")
        return 1

    load_env_variables()

    try:
        configs = load_configs(args.config_dir)
        if not configs:
            logger.error("Failed to load configurations")
            return 1
        if "transformers" not in configs:
            logger.error("transformers_config.json not found or invalid")
            return 1
        if "hardware" not in configs:
            logger.warning("hardware_config.json not found. Using default hardware configuration.")
        if "dataset" not in configs:
            logger.error("dataset_config.json not found or invalid")
            return 1

        model_config = configs["transformers"]
        model_name = (
            (model_config.get("model") or {}).get("name")
            or model_config.get("model_name_or_path")
            or model_config.get("model_name")
        )
        if not model_name:
            logger.error("Model name not specified in configuration")
            logger.error("Please ensure 'name' is specified under 'model' in transformers_config.json")
            return 1
        logger.info(f"Model name: {model_name}")
        logger.info("All configurations loaded successfully")

        hardware_config = configs.get("hardware", {})
        dataset_config = configs["dataset"]
        if hardware_config:
            _apply_hardware_overrides(model_config, hardware_config)
    except Exception as e:
        logger.error(f"Error loading configurations: {e}")
        return 1

    # Set random seed for reproducibility
    seed = model_config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    def setting(key, default=None):
        return _training_setting(model_config, key, default)

    try:
        model, tokenizer = load_model_and_tokenizer(model_config)
        logger.info("Model and tokenizer loaded successfully")

        try:
            dataset = load_dataset_with_mapping(dataset_config)
            logger.info("Dataset loaded and prepared successfully")
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return 1

        data_collator = SimpleDataCollator(tokenizer, dataset_config)

        class LoggingCallback(TrainerCallback):
            """Log step, loss, LR and elapsed time every 50 steps or after 5 quiet minutes."""

            def __init__(self):
                self.last_log_time = datetime.now()
                self.training_start_time = datetime.now()

            def on_step_end(self, args, state, control, **kwargs):
                current_time = datetime.now()
                since_last_log = (current_time - self.last_log_time).total_seconds()
                if state.global_step % 50 != 0 and since_last_log <= 300:
                    return
                elapsed_min = (current_time - self.training_start_time).total_seconds() / 60
                # Not every log_history entry is guaranteed to carry 'loss' /
                # 'learning_rate'; indexing the last one directly could raise
                # KeyError inside the training loop.
                latest = next((entry for entry in reversed(state.log_history) if "loss" in entry), {})
                loss = latest.get("loss", "N/A")
                lr = latest.get("learning_rate", "N/A")
                loss_str = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
                lr_str = f"{lr:.8f}" if isinstance(lr, float) else str(lr)
                logger.info(f"Step: {state.global_step} | Loss: {loss_str} | LR: {lr_str} | Elapsed: {elapsed_min:.2f} min")
                self.last_log_time = current_time

        logger.info("Setting up training arguments")
        output_dir = setting("output_dir", "./results")
        bf16 = setting("bf16", False)
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=setting("num_train_epochs", 3),
            per_device_train_batch_size=setting("per_device_train_batch_size", 4),
            gradient_accumulation_steps=setting("gradient_accumulation_steps", 8),
            learning_rate=setting("learning_rate", 5e-5),
            weight_decay=setting("weight_decay", 0.01),
            warmup_ratio=setting("warmup_ratio", 0.1),
            lr_scheduler_type=setting("lr_scheduler_type", "cosine"),
            logging_steps=setting("logging_steps", 10),
            save_strategy=setting("save_strategy", "steps"),
            save_steps=setting("save_steps", 100),
            save_total_limit=setting("save_total_limit", 3),
            # fp16 and bf16 cannot both be enabled, so fp16 only defaults to
            # on when bf16 is off.
            fp16=setting("fp16", not bf16),
            bf16=bf16,
            max_grad_norm=setting("max_grad_norm", 1.0),
            push_to_hub=setting("push_to_hub", False),
            hub_model_id=setting("hub_model_id"),
            hub_token=os.environ.get("HF_TOKEN", None),
            report_to="tensorboard",
            remove_unused_columns=False,  # the collator reads the raw conversations column
            gradient_checkpointing=setting("gradient_checkpointing", True),
            dataloader_pin_memory=False,
            optim=setting("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,
            dataloader_drop_last=False,  # Process all examples
            dataloader_num_workers=0,
        )

        resume_from_checkpoint = _find_latest_checkpoint(output_dir) or False
        if resume_from_checkpoint:
            logger.info(f"Found checkpoint: {resume_from_checkpoint}. Training will resume from this point.")

        logger.info("Creating trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
            callbacks=[LoggingCallback()],
        )

        def get_train_dataloader_no_shuffle():
            """Create a train DataLoader that visits the dataset in order."""
            logger.info("Creating train dataloader with sequential sampler (no shuffling)")
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=training_args.per_device_train_batch_size,
                sampler=torch.utils.data.SequentialSampler(dataset),
                collate_fn=data_collator,
                drop_last=False,
                num_workers=0,
                pin_memory=False,
            )

        # TrainingArguments has no shuffle switch, so replace the dataloader
        # factory on this trainer instance. NOTE(review): this bypasses the
        # Trainer's own dataloader construction (and any distributed sharding
        # it would apply) - confirm training is single-process.
        trainer.get_train_dataloader = get_train_dataloader_no_shuffle

        logger.info("Starting training")
        logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")

        # Signal (to whoever manages the Space) that training is running.
        lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
        with open(lock_file, "w", encoding="utf-8") as f:
            f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
            f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
        logger.info(f"Created lock file: {lock_file}")

        try:
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
            logger.info("Training completed successfully")
            if training_args.push_to_hub:
                logger.info(f"Pushing model to hub: {training_args.hub_model_id}")
                trainer.push_to_hub()
                logger.info("Model pushed to hub successfully")
            else:
                logger.info(f"Saving model to {output_dir}")
                trainer.save_model()
                logger.info("Model saved successfully")
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            raise
        finally:
            # Remove the lock whether training succeeded or failed.
            if os.path.exists(lock_file):
                os.remove(lock_file)
                logger.info(f"Removed lock file: {lock_file}")

        return 0

    except Exception as e:
        # logger.exception keeps the traceback that a plain error log dropped.
        logger.exception(f"Error in main training loop: {str(e)}")
        return 1
if __name__ == "__main__":
    # Use main()'s return value (0 on success, 1 on failure) as the exit status.
    raise SystemExit(main())