Spaces:
Sleeping
Sleeping
def format_phi_chat(messages, dataset_config):
    """Format messages according to phi-4's chat template and dataset config.

    Role templates are read from ``dataset_config["data_formatting"]["roles"]``;
    each is a ``str.format`` template with a single ``{content}`` field. When
    that key is absent, plain "System:/Human:/Assistant:" templates are used.

    Output order:
      1. every ``system``-role message, in its original relative order;
      2. the first message whose string content contains
         ``[RESEARCH INTRODUCTION]``, rendered with the system template;
      3. all remaining messages in their original order -- ``user``/``human``
         with the user template, ``assistant``/``bot`` with the assistant
         template, and any other role with the system template (a warning is
         logged).

    Args:
        messages: Iterable of message dicts with ``"role"`` and ``"content"``
            keys. Non-dict entries and dicts lacking ``"content"`` are skipped
            with a warning.
        dataset_config: Dataset configuration mapping.

    Returns:
        The concatenated conversation text with leading/trailing whitespace
        stripped ("" when there are no usable messages).
    """
    default_roles = {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    }
    roles = dataset_config.get("data_formatting", {}).get("roles", default_roles)
    system_template = roles.get("system", "System: {content}\n\n")
    user_template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
    assistant_template = roles.get("assistant", "Assistant: {content}\n\n")

    # Materialize once: both the metadata scan and the main loop iterate it,
    # which would silently drop items if a generator were passed in.
    messages = list(messages)

    # Locate the research-introduction metadata. Only string content is
    # searched so a None / non-text content field cannot raise TypeError.
    metadata = next(
        (
            msg
            for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and "[RESEARCH INTRODUCTION]" in msg["content"]
        ),
        None,
    )

    # System messages are collected separately so they end up ahead of all
    # other text while keeping their relative order (prepending them one at a
    # time would reverse it).
    system_parts = []
    body_parts = []
    if metadata is not None:
        body_parts.append(system_template.format(content=metadata["content"]))

    for message in messages:
        # Identity check: skip only the metadata message itself, not other
        # messages that merely compare equal to it.
        if metadata is not None and message is metadata:
            continue
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue
        # `or ""` tolerates an explicit None role instead of crashing on .lower().
        role = str(message.get("role") or "").lower()
        content = message.get("content", "")
        if role in ("human", "user"):
            body_parts.append(user_template.format(content=content))
        elif role in ("assistant", "bot"):
            body_parts.append(assistant_template.format(content=content))
        elif role == "system":
            system_parts.append(system_template.format(content=content))
        else:
            # Unknown roles keep their position but use the system template.
            logger.warning(f"Unknown role '{role}' - treating as system message")
            body_parts.append(system_template.format(content=content))

    return "".join(system_parts + body_parts).strip()
class SimpleDataCollator:
    """Collate pre-audited conversation examples into padded causal-LM batches.

    Each feature is expected to be a dict with an ``"id"`` and a
    ``"conversations"`` list of ``{"role": ..., "content": ...}`` messages.
    A conversation is tokenized with the tokenizer's chat template (falling
    back to the plain concatenation of message contents if that raises),
    truncated to ``max_seq_length`` when that is positive, and right-padded to
    the longest example in the batch. Labels are a copy of the input ids with
    padding positions set to -100.
    """

    def __init__(self, tokenizer, dataset_config):
        """Store the tokenizer/config and read ``max_seq_length``.

        ``max_seq_length`` comes from
        ``dataset_config["dataset"]["processing"]["max_seq_length"]``
        (default 2048); a value <= 0 disables truncation.
        """
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Use id 0 for padding when the tokenizer defines no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = (
            dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        )
        # Value of stats["processed"] when the periodic stats line was last logged.
        self._last_stats_log = 0
        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
        logger.info("Using exact dataset structure without reformatting")
        # Informational only: nothing in this class places tensors on this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _as_id_list(encoded):
        """Return ``encoded`` as a flat list of token ids.

        Tokenizer ``__call__`` returns a mapping with an ``"input_ids"`` entry,
        as does ``apply_chat_template`` when it returns a dict; a plain
        sequence of ids is copied into a new list.
        """
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def _encode_conversation(self, conversations, paper_id):
        """Tokenize one conversation, falling back to plain text on template errors."""
        try:
            encoded = self.tokenizer.apply_chat_template(
                conversations,
                return_tensors=None,
                add_generation_prompt=False,
            )
        except Exception as chat_error:
            logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
            # Basic representation: message contents separated by blank lines.
            conversation_text = "".join(
                msg.get('content', '') + "\n\n"
                for msg in conversations
                if isinstance(msg, dict) and 'content' in msg
            )
            encoded = self.tokenizer(
                conversation_text,
                add_special_tokens=True,
                return_tensors=None,
            )
        return self._as_id_list(encoded)

    def _maybe_log_stats(self, log_interval):
        """Log cumulative stats once each time ``processed`` crosses a multiple of ``log_interval``.

        A plain ``processed % log_interval == 0`` test would miss milestones
        whenever the batch size does not divide the interval or examples are
        skipped. A non-positive interval disables the log.
        """
        processed = self.stats["processed"]
        if log_interval <= 0 or processed == 0:
            return
        if processed // log_interval > self._last_stats_log // log_interval:
            self._last_stats_log = processed
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")

    def __call__(self, features):
        """Tokenize, truncate and pad ``features`` into a batch of long tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors
            of shape (number of kept examples, longest kept length). Examples
            with no conversations, no tokens, or that raise during processing
            are skipped and counted in ``self.stats["skipped"]``. If nothing is
            kept, a (1, 1) placeholder batch is returned whose single position
            is masked out (attention 0, label -100).
        """
        validation_config = self.dataset_config.get("validation", {})
        log_samples = validation_config.get("log_samples", 3)
        batch = {"input_ids": [], "attention_mask": [], "labels": []}

        for example in features:
            try:
                paper_id = example.get("id", "")
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                inputs = self._encode_conversation(conversations, paper_id)

                # Apply length cap if needed (shouldn't be necessary for pre-audited data).
                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
                    inputs = inputs[:self.max_seq_length]

                if not inputs:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(inputs)
                batch["attention_mask"].append([1] * len(inputs))
                # For causal language modeling, labels are the same as inputs.
                batch["labels"].append(list(inputs))
                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(inputs)

                # Debug logging for the first few examples.
                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']}:")
                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs)}")
                    logger.info(f"Conversation entries: {len(conversations)}")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                # Guarded: the failing example may not support .get() at all.
                example_id = example.get('id', 'unknown') if hasattr(example, "get") else 'unknown'
                logger.warning(f"Problematic example ID: {example_id}")
                self.stats["skipped"] += 1

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.full((1, 1), self.pad_token_id, dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                # -100 matches the padding convention so the placeholder adds no supervision.
                "labels": torch.full((1, 1), -100, dtype=torch.long),
            }

        # Right-pad every row to the longest example in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = max_length - len(ids)
            ids.extend([self.pad_token_id] * padding_length)
            mask.extend([0] * padding_length)
            labels.extend([-100] * padding_length)

        tensors = {key: torch.tensor(value, dtype=torch.long) for key, value in batch.items()}
        self._maybe_log_stats(validation_config.get("log_interval", 100))
        return tensors
class LoggingCallback(TrainerCallback):
    """Trainer callback that reports progress, settings and GPU memory via ``log_info``.

    - every 50 steps, or when more than 5 minutes passed since the last report,
      logs the ``loss`` of the most recent ``state.log_history`` entry;
    - every 15 minutes (CUDA only), logs allocated/reserved memory per GPU;
    - at train start/end, logs batch-size settings, dataset info and memory.
    """

    def __init__(self):
        self.last_log_time = time.time()
        self.last_memory_log_time = time.time()

    @staticmethod
    def _gpu_memory_summary(include_peak):
        """Return a comma-separated per-GPU memory string in MB, or None without CUDA.

        With ``include_peak`` each entry is "allocated (max: peak allocated)";
        otherwise it is "allocated/reserved".
        """
        if not torch.cuda.is_available():
            return None
        parts = []
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**2
            if include_peak:
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                parts.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            else:
                reserved = torch.cuda.memory_reserved(i) / 1024**2
                parts.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
        return ', '.join(parts)

    @staticmethod
    def _resolve_train_dataset(kwargs):
        """Best-effort lookup of the training dataset for logging.

        Uses a module-level ``trainer`` object when one exists; otherwise the
        ``.dataset`` of the ``train_dataloader`` passed to callback events.
        Returns None when neither is available (previously a missing global
        ``trainer`` raised NameError and aborted training start).
        """
        module_trainer = globals().get("trainer")
        dataset = getattr(module_trainer, "train_dataset", None)
        if dataset is None:
            dataset = getattr(kwargs.get("train_dataloader"), "dataset", None)
        return dataset

    def on_step_end(self, args, state, control, **kwargs):
        current_time = time.time()

        # Log loss every 50 steps or every 5 minutes, whichever comes first.
        if state.global_step % 50 == 0 or current_time - self.last_log_time > 300:
            if state.log_history:
                loss = state.log_history[-1].get('loss', 'N/A')
                # Simple formatting for better HF Space log compatibility.
                log_info(f"Step {state.global_step}: Loss {loss}")
            else:
                log_info(f"Step {state.global_step}: No loss data available")
            self.last_log_time = current_time

        # Log memory usage every 15 minutes.
        if current_time - self.last_memory_log_time > 900:
            summary = self._gpu_memory_summary(include_peak=False)
            if summary is not None:
                log_info(f"Memory usage - {summary}")
            self.last_memory_log_time = current_time

    def on_train_begin(self, args, state, control, **kwargs):
        log_info("=== Training is starting ===")
        # Treat CPU-only runs as one device when computing the effective batch.
        num_devices = max(1, torch.cuda.device_count())
        effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
        log_info(f"Per device batch size: {args.per_device_train_batch_size}")
        log_info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
        log_info(f"Number of GPUs: {num_devices}")
        log_info(f"Total effective batch size: {effective_batch_size}")
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")

        train_dataset = self._resolve_train_dataset(kwargs)
        if train_dataset is not None:
            # len() is inside the try: iterable-style datasets may not support it.
            try:
                dataset_size = len(train_dataset)
                log_info(f"Dataset size: {dataset_size} examples")
                # Log the first few prompt numbers to verify sequence order.
                prompt_numbers = []
                for i in range(min(5, dataset_size)):
                    sample = train_dataset[i]
                    if 'prompt_number' in sample:
                        prompt_numbers.append(sample['prompt_number'])
                if prompt_numbers:
                    log_info(f"First few prompt numbers: {prompt_numbers}")
            except Exception as e:
                log_info(f"Error accessing dataset samples: {e}")

        summary = self._gpu_memory_summary(include_peak=True)
        if summary is not None:
            log_info(f"Initial memory usage - {summary}")

    def on_train_end(self, args, state, control, **kwargs):
        log_info("=== Training completed ===")
        summary = self._gpu_memory_summary(include_peak=True)
        if summary is not None:
            log_info(f"Final memory usage - {summary}")
        log_info(f"Total steps: {state.global_step}")
        final_loss = state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'
        log_info(f"Final loss: {final_loss}")
def custom_get_train_dataloader():
    """Build a training DataLoader that iterates ``dataset`` in its stored order.

    Reads the module-level names ``dataset``, ``dataset_config``,
    ``training_args``, ``NUM_GPUS`` and ``data_collator``. It takes no
    arguments, so presumably it is assigned over a Trainer's
    ``get_train_dataloader`` -- confirm at the call site.

    Returns:
        A ``torch.utils.data.DataLoader`` using a ``SequentialSampler``.

    Raises:
        ValueError: if ``dataset_config["data_loading"]["shuffle"]`` is truthy,
            because shuffling would break the required sequential order.
    """
    log_info("Creating sequential dataloader to maintain original dataset order")

    # Fail fast on a shuffle request before building anything.
    data_loading_config = dataset_config.get("data_loading", {})
    if data_loading_config.get("shuffle", False):
        log_info("CRITICAL ERROR: Shuffle is enabled! This will randomize data entry order!")
        raise ValueError(
            "Dataset shuffling is enabled but sequential processing is required. "
            "Please disable shuffling in your configuration."
        )

    sequential_sampler = torch.utils.data.SequentialSampler(dataset)
    log_info("Using SequentialSampler to guarantee original dataset order is preserved")
    log_info("Data order preservation is critical for proper training sequence")

    # `no_cuda` is the deprecated spelling of `use_cpu` in transformers; honour
    # either so CPU-only runs do not scale the batch size by the GPU count.
    cpu_only = getattr(training_args, "use_cpu", False) or getattr(training_args, "no_cuda", False)
    if cpu_only:
        batch_size = training_args.per_device_train_batch_size
    else:
        # NOTE(review): scales by every visible GPU; under per-process DDP each
        # rank builds its own loader, so confirm this matches the launch mode.
        batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)
    log_info(f"Using sequential sampler with batch size {batch_size}")

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sequential_sampler,
        collate_fn=data_collator,
        drop_last=training_args.dataloader_drop_last,
        num_workers=training_args.dataloader_num_workers,
        pin_memory=training_args.dataloader_pin_memory,
    )
def check_dependencies():
    """Log the availability of optional/required packages and of CUDA.

    flash-attn is optional and only reported; a missing CUDA device only
    produces a warning; unsloth is required.

    Returns:
        True if ``importlib.util.find_spec("unsloth")`` locates the package,
        False if it does not or if the lookup itself raises.
    """
    import importlib.util

    # find_spec locates a package without importing (initialising) it.
    try:
        if importlib.util.find_spec("flash_attn") is not None:
            log_info("flash-attn found! Using Flash Attention for faster training.")
        else:
            log_info("flash-attn not found. Training will continue but may be slower.")
            log_info("To use flash attention, install: pip install flash-attn==2.5.2 --no-build-isolation")
            # Still continue as this is optional.
    except Exception as e:
        log_info(f"Error checking for flash-attn: {e}")

    if not torch.cuda.is_available():
        log_info("WARNING: CUDA not available. Training will be extremely slow on CPU!")
    else:
        log_info(f"Found {torch.cuda.device_count()} CUDA devices")

    try:
        unsloth_found = importlib.util.find_spec("unsloth") is not None
    except Exception as e:
        log_info(f"Error checking for unsloth: {e}")
        return False

    if not unsloth_found:
        log_info("CRITICAL: Unsloth not found. This pipeline requires Unsloth.")
        # Quoted so a shell does not treat ">=" as output redirection when pasted.
        log_info('Install with: pip install "unsloth>=2024.3"')
        return False

    log_info("Unsloth found! Using Unsloth for optimized training.")
    return True
def main():
    """Run the Phi-4 training entry point.

    Parses CLI args, loads environment variables and the config file, checks
    dependencies and logs hardware info. The training step itself is still a
    placeholder in this function.

    Returns:
        Process exit code: 0 on success, 1 when a critical dependency is
        missing or any exception occurs. Exceptions are logged with their
        traceback and, when CUDA is available, per-GPU memory usage.
    """
    import traceback

    try:
        log_info("Starting Phi-4 training process")
        args = parse_args()
        load_env_variables()
        config = load_configs(args.config)

        # Extract specific configurations.
        hardware_config = config.get("hardware", {})
        # Not used yet in this function; presumably for the training code below.
        dataset_config = config.get("dataset", {})

        # Define multi_gpu_strategy early to prevent undefined errors.
        multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
        log_info(f"Multi-GPU strategy: {multi_gpu_strategy}")

        if not check_dependencies():
            log_info("Aborting due to missing critical dependencies")
            return 1

        cuda_available = torch.cuda.is_available()
        num_gpus = torch.cuda.device_count() if cuda_available else 0
        log_info(f"Hardware: {num_gpus} GPUs detected" if cuda_available else "Hardware: CPU only")

        # Rest of training code would go here
        # ...
        return 0
    except Exception as e:
        log_info(f"Error in main training loop: {str(e)}")
        # Keep the stack trace; the message alone rarely pinpoints the failure.
        log_info(traceback.format_exc())
        if torch.cuda.is_available():
            try:
                memory_info = []
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**2
                    reserved = torch.cuda.memory_reserved(i) / 1024**2
                    memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
                log_info(f"GPU memory at failure: {', '.join(memory_info)}")
            except Exception as mem_error:
                # Best effort: a failing memory query must not mask the
                # original error (and must not swallow KeyboardInterrupt).
                log_info(f"Could not read GPU memory at failure: {mem_error}")
        return 1
if __name__ == "__main__":
    # Hand main()'s return code to the interpreter as the process exit status
    # (equivalent to sys.exit(main())).
    raise SystemExit(main())