# train.py (Updated for Full Fine-tuning)
"""Full fine-tuning driver for the Enhanced RRN QA model on SQuAD.

Loads (or resumes) the model, preprocesses SQuAD into fixed-length training
features, and runs a mixed-precision training loop with gradient accumulation,
saving a full checkpoint at the end of every epoch.
"""

import argparse  # For command line arguments
import json
import logging  # Optional: Better logging
import math
import multiprocessing  # For Windows multiprocessing support
import os

import torch
import torch.optim as optim
from torch.amp import autocast, GradScaler  # For mixed precision training (updated import)
from torch.utils.data import DataLoader

import evaluate  # For metrics (kept for the evaluation loop TODO below)  # noqa: F401
from datasets import load_dataset
from tqdm.auto import tqdm  # Progress bar
from transformers import AutoTokenizer, default_data_collator

# Import our custom modules and config
import config
from model import EnhancedRRN_QA_Model

# Setup basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _parse_args():
    """Parse the training CLI arguments (checkpoint resume, subsetting, delta bypass)."""
    parser = argparse.ArgumentParser(description="Train RRN QA Model")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
    parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
    parser.add_argument(
        "--subset_percentage",
        type=float,
        default=100.0,
        help="Percentage of training data to use (1.0-100.0). Default: 100.0 (full dataset)"
    )
    parser.add_argument(
        "--bypass_delta",
        action="store_true",
        help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
    )
    return parser.parse_args()


def _load_model_and_tokenizer(args):
    """Build the model/tokenizer pair, restoring custom-module weights from a checkpoint when given.

    Returns (model, tokenizer) with the model already moved to config.DEVICE.
    """
    if args.checkpoint:
        logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
        logger.info(f"Loading model from checkpoint: {args.checkpoint}")
        # Initialize the model with base architecture saved alongside the checkpoint
        model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))

        # Check for enhanced model components (older checkpoints lack the gating mechanism)
        gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
        is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)

        # Load custom module weights.
        # map_location="cpu" lets GPU-trained checkpoints load on CPU-only hosts
        # (the whole model is moved to config.DEVICE below); weights_only=True
        # restricts unpickling to tensors, which is all a state dict contains.
        logger.info("Loading model components...")
        model.qa_head.load_state_dict(
            torch.load(os.path.join(args.checkpoint, "qa_head.pth"),
                       map_location="cpu", weights_only=True))
        model.retroactive_update_layer.load_state_dict(
            torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth"),
                       map_location="cpu", weights_only=True))

        # Load gating mechanism if available
        if is_enhanced_checkpoint:
            logger.info("Loading gating mechanism...")
            model.gating_mechanism.load_state_dict(
                torch.load(gating_mechanism_path, map_location="cpu", weights_only=True))

        # Load step controller if available (for learned dynamic steps)
        step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(
                torch.load(step_controller_path, map_location="cpu", weights_only=True))
        # NOTE(review): optimizer/scaler state is not saved in checkpoints, so a
        # resume restarts optimizer state from scratch — confirm this is intended.
    else:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)
        logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
        model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)

    model.to(config.DEVICE)
    return model, tokenizer


def _build_train_dataloader(args, tokenizer):
    """Load SQuAD, optionally subsample it, tokenize into training features, and wrap in a DataLoader.

    Returns (train_dataloader, num_train_features, effective_subset_percentage).
    """
    logger.info("Loading SQuAD dataset...")
    raw_datasets = load_dataset("squad")

    # Handle dataset subsetting
    subset_percentage = args.subset_percentage
    if subset_percentage < 100.0:
        original_train_size = len(raw_datasets["train"])
        # Calculate subset size and validate
        subset_percentage = max(0.1, min(100.0, subset_percentage))  # Clamp between 0.1% and 100%
        train_subset_size = int(original_train_size * subset_percentage / 100)
        train_subset_size = max(100, min(original_train_size, train_subset_size))  # Ensure reasonable bounds

        # Create reproducible subset with fixed seed for consistency
        subset_indices = torch.randperm(
            original_train_size,
            generator=torch.Generator().manual_seed(42)
        )[:train_subset_size].tolist()
        raw_datasets["train"] = raw_datasets["train"].select(subset_indices)
        logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
    else:
        logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    def prepare_train_features(examples):
        """Tokenize question/context pairs and compute start/end token labels per feature."""
        # Leading whitespace on questions can skew truncation for some tokenizers.
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]

        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # One example may overflow into several features; map each feature
        # back to its originating example.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")

        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sequence_ids = tokenized_examples.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]

            if len(answers["answer_start"]) == 0:
                # No gold answer: label both positions with [CLS].
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Locate the token span covering the context side of the pair.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1

                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    # Answer falls outside this overflow window: label with [CLS].
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Walk inward to the exact tokens bracketing the answer characters.
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)

        return tokenized_examples

    logger.info("Preprocessing datasets...")
    # Use single process on Windows to avoid multiprocessing issues
    tokenized_datasets = raw_datasets.map(
        prepare_train_features,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=1  # Use single process to avoid Windows multiprocessing issues
    )

    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=config.BATCH_SIZE
    )
    # Consider adding validation dataloader setup here as well
    # eval_dataloader = DataLoader(...)

    return train_dataloader, len(tokenized_datasets["train"]), subset_percentage


def _save_checkpoint(model, tokenizer, epoch):
    """Save base model, custom module weights, tokenizer and config for 0-based epoch `epoch`."""
    output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"--- Saving model checkpoint to {output_dir} ---")
    try:
        logger.info(f"Saving enhanced model components to {output_dir}")
        # Save base model using its save_pretrained
        model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
        # Save all custom modules' state dicts
        torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
        torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
        torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))

        # Save step controller if using learned dynamic steps
        if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
            torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
            logger.info("Saved step controller for learned dynamic steps")

        # Save tokenizer
        tokenizer.save_pretrained(output_dir)

        # Save configuration so the checkpoint is self-describing
        config_dict = {
            "num_reasoning_steps": config.NUM_REASONING_STEPS,
            "delta_target_ratio": config.DELTA_TARGET_RATIO,
            "lambda_coherence": config.LAMBDA_COHERENCE,
            "lambda_delta_reg": config.LAMBDA_DELTA_REG,
            "memory_max_size": config.MEMORY_MAX_SIZE,
            "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
            "use_mixed_precision": config.USE_MIXED_PRECISION,
            "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
        }
        # Add dynamic steps configuration if enabled
        if config.USE_DYNAMIC_STEPS:
            config_dict.update({
                "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
                "max_reasoning_steps": config.MAX_REASONING_STEPS,
                "min_reasoning_steps": config.MIN_REASONING_STEPS,
                "reasoning_step_type": config.REASONING_STEP_TYPE,
                "early_stop_threshold": config.EARLY_STOP_THRESHOLD
            })
        with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
            json.dump(config_dict, f, indent=2)
        logger.info("Enhanced model checkpoint saved successfully.")
    except Exception as e:
        # Best-effort save: training continues even if a checkpoint write fails.
        logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")


def main():
    """Entry point: parse args, build model and data, run the fine-tuning loop."""
    args = _parse_args()

    # Propagate the delta-bypass switch through the shared config module,
    # which is where the model code reads it.
    if args.bypass_delta:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
        config.BYPASS_DELTA_CALCULATION = True
    else:
        config.BYPASS_DELTA_CALCULATION = False

    # --- 1. Load Tokenizer and Model ---
    model, tokenizer = _load_model_and_tokenizer(args)

    # --- 2. Load and Preprocess Dataset ---
    train_dataloader, num_train_features, subset_percentage = _build_train_dataloader(args, tokenizer)

    # --- 3. Setup Optimizer ---
    logger.info("Setting up optimizer for FULL model fine-tuning...")
    # Optimize all parameters since PEFT is disabled
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
    logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")

    # Calculate total steps considering gradient accumulation.
    # FIX: use ceil, not floor — the loop below also performs an optimizer step
    # on the final partial accumulation window (step == len(dataloader) - 1),
    # so floor division under-reported the true number of optimization steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.GRADIENT_ACCUMULATION_STEPS)
    num_training_steps = config.EPOCHS * num_update_steps_per_epoch
    logger.info(f"Total optimization steps: {num_training_steps}")

    # --- 4. Initialize Mixed Precision Training ---
    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION)  # Updated to fix deprecation warning

    # Log mixed precision and dynamic steps status
    if config.USE_MIXED_PRECISION:
        logger.info("Mixed precision training (FP16) enabled")
    if config.USE_DYNAMIC_STEPS:
        logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
        logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
    # Log bypass delta calculation status
    if config.BYPASS_DELTA_CALCULATION:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")

    # --- 5. Training Loop ---
    logger.info("***** Starting Training *****")
    logger.info(f" Num examples = {num_train_features}")
    logger.info(f" Num Epochs = {config.EPOCHS}")
    logger.info(f" Instantaneous batch size per device = {config.BATCH_SIZE}")
    logger.info(f" Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
    logger.info(f" Total optimization steps = {num_training_steps}")
    # Add note about subset training if applicable
    if subset_percentage < 100.0:
        logger.info(f" NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")

    model.train()  # Set model to training mode
    global_step = 0
    total_loss = 0.0  # Use float for accumulated loss

    # Start from specified epoch (default is 0 if not provided)
    for epoch in range(args.start_epoch, config.EPOCHS):
        logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")

        for step, batch in enumerate(progress_bar):
            # Move only tensor fields to the device; skip any non-tensor items.
            batch_on_device = {
                k: v.to(config.DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)
            }

            try:
                # Forward pass with autocast for mixed precision
                with autocast('cuda', enabled=config.USE_MIXED_PRECISION):  # Updated to fix deprecation warning
                    outputs = model(
                        input_ids=batch_on_device.get("input_ids"),
                        attention_mask=batch_on_device.get("attention_mask"),
                        token_type_ids=batch_on_device.get("token_type_ids"),
                        start_positions=batch_on_device.get("start_positions"),
                        end_positions=batch_on_device.get("end_positions"),
                        use_memory=False  # Disable memory during training steps
                    )

                loss = outputs.loss
                if loss is None:
                    # NOTE(review): skipping here carries any gradients already
                    # accumulated in the current window into the next one.
                    logger.warning(f"Step {step}: Loss is None. Skipping batch.")
                    continue

                # Scale loss for gradient accumulation
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS
                # Accumulate loss value for logging (before backward)
                total_loss += loss.item()
                # Scale loss and perform backward pass with AMP
                scaler.scale(loss).backward()
            except Exception as e:
                logger.error(f"Error during forward/backward pass at step {step}: {e}")
                raise  # bare raise preserves the original traceback (was `raise e`)

            # Optimizer step (perform step only after accumulating gradients);
            # the second condition flushes the trailing partial window.
            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
                # Unscale before optimizer step so clipping sees true gradient magnitudes
                scaler.unscale_(optimizer)
                # Clip gradients to avoid explosion (recommended with mixed precision)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()  # Reset gradients for the next accumulation cycle
                global_step += 1

                # Log progress periodically
                if global_step % 50 == 0:  # Log every 50 optimization steps
                    avg_loss = total_loss / 50  # Average loss over the last 50 steps
                    logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
                    total_loss = 0.0  # Reset loss accumulator

                # Update progress bar description with current step loss and steps info
                postfix = {
                    "Loss": f"{loss.item()*config.GRADIENT_ACCUMULATION_STEPS:.4f}",
                    "Step": global_step
                }
                # Add steps info if using dynamic steps
                if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
                    if 'steps_taken' in model.custom_outputs:
                        postfix["Steps"] = model.custom_outputs['steps_taken']
                progress_bar.set_postfix(postfix)

        # --- (Optional) Evaluation at the end of each epoch ---
        # logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
        # model.eval()
        # # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
        # model.train()  # Set back to train mode

        # --- Save Model Checkpoint ---
        _save_checkpoint(model, tokenizer, epoch)

    logger.info("\n***** Training finished *****")


if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    multiprocessing.freeze_support()
    main()

# Example usage:
# Train on full dataset (default):
#   python train.py
# Train on 10% of data for faster iterations:
#   python train.py --subset_percentage 10.0
# Train on 1% for very quick testing:
#   python train.py --subset_percentage 1.0
# Resume training from checkpoint with subset:
#   python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0
# Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
#   python train.py --bypass_delta --subset_percentage 1.0