# Source: rrn-qa / code / train.py
# Uploaded by will4381 via huggingface_hub (revision 3451ca0, verified)
# train.py (Updated for Full Fine-tuning)
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler # For mixed precision training (updated import)
from transformers import AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm # Progress bar
import os
import evaluate # For metrics (NOTE(review): imported but unused in this file — kept for planned eval loop)
import logging # Optional: Better logging
import multiprocessing # For Windows multiprocessing support
import argparse # For command line arguments
# Import our custom modules and config
import config
from model import EnhancedRRN_QA_Model
# Setup basic logging: module-level logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main() -> None:
    """Fully fine-tune the EnhancedRRN_QA_Model on SQuAD.

    Command-line flags:
        --checkpoint: path to a previous epoch's checkpoint directory to resume from.
        --start_epoch: epoch index to resume counting from (default 0).
        --subset_percentage: percentage (1.0-100.0) of the training split to use.
        --bypass_delta: disable the RRN delta calculation (delta = zeros_like(h0)).

    Side effects: downloads SQuAD via `datasets`, trains on config.DEVICE, and
    writes one checkpoint directory per completed epoch (./rrn_qa_model_epoch_N).

    NOTE(review): resuming with --checkpoint restores model weights only — the
    optimizer (AdamW moments) and GradScaler state are neither saved nor loaded,
    so a resumed run does not exactly continue the original optimization.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Train RRN QA Model")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
    parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
    parser.add_argument(
        "--subset_percentage",
        type=float,
        default=100.0,
        help="Percentage of training data to use (1.0-100.0). Default: 100.0 (full dataset)"
    )
    parser.add_argument(
        "--bypass_delta",
        action="store_true",
        help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
    )
    args = parser.parse_args()
    # Set bypass delta calculation flag if specified. The flag is written onto the
    # shared `config` module so the model (which presumably reads config at forward
    # time — confirm in model.py) observes it.
    if args.bypass_delta:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
        config.BYPASS_DELTA_CALCULATION = True
    else:
        config.BYPASS_DELTA_CALCULATION = False
    # --- 1. Load Tokenizer and Model ---
    if args.checkpoint:
        logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
        logger.info(f"Loading model from checkpoint: {args.checkpoint}")
        # Initialize the model with base architecture (base transformer weights are
        # loaded from the checkpoint's "base_model" subdirectory)
        model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))
        # Check for enhanced model components: presence of gating_mechanism.pth
        # distinguishes an "enhanced" checkpoint from an older layout
        gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
        is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
        # Load custom module weights.
        # NOTE(review): torch.load is called without map_location; if these tensors
        # were saved from GPU, loading on a CPU-only machine will fail — consider
        # torch.load(..., map_location="cpu") before load_state_dict.
        logger.info("Loading model components...")
        model.qa_head.load_state_dict(torch.load(os.path.join(args.checkpoint, "qa_head.pth")))
        model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth")))
        # Load gating mechanism if available
        if is_enhanced_checkpoint:
            logger.info("Loading gating mechanism...")
            model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path))
        # Load step controller if available (for learned dynamic steps)
        step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path))
    else:
        # Fresh run: start from the configured base model
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)
        logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
        model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
    model.to(config.DEVICE)
    # --- 2. Load and Preprocess Dataset ---
    logger.info("Loading SQuAD dataset...")
    raw_datasets = load_dataset("squad")
    # Handle dataset subsetting (for fast iteration / smoke tests)
    subset_percentage = args.subset_percentage
    if subset_percentage < 100.0:
        original_train_size = len(raw_datasets["train"])
        # Calculate subset size and validate.
        # NOTE(review): the max(100, ...) floor means very small percentages still
        # yield at least 100 examples regardless of the requested fraction.
        subset_percentage = max(0.1, min(100.0, subset_percentage)) # Clamp between 0.1% and 100%
        train_subset_size = int(original_train_size * subset_percentage / 100)
        train_subset_size = max(100, min(original_train_size, train_subset_size)) # Ensure reasonable bounds
        # Create reproducible subset with fixed seed for consistency (seed 42 so
        # repeated runs at the same percentage see the same examples)
        subset_indices = torch.randperm(original_train_size, generator=torch.Generator().manual_seed(42))[:train_subset_size].tolist()
        raw_datasets["train"] = raw_datasets["train"].select(subset_indices)
        logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
    else:
        logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")
    # SQuAD column names used by the preprocessing closure below
    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    # With right-padding tokenizers the question goes first and the context second;
    # left-padding tokenizers reverse the pair (and the truncation side with it)
    pad_on_right = tokenizer.padding_side == "right"
    def prepare_train_features(examples):
        """Tokenize a batch of SQuAD examples into fixed-length training features.

        Long contexts are split into overlapping windows (stride config.DOC_STRIDE),
        so one example can produce several features. For each feature, the answer
        span is mapped from character offsets to token indices; features whose
        window does not contain the answer get (cls_index, cls_index) as the
        "no answer here" target.
        """
        # Some questions have leading/trailing whitespace that skews tokenization
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        # Map each overflow feature back to the example it came from, and keep the
        # char offsets so we can locate the answer span inside each window
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            # Position of [CLS]: the target for unanswerable / out-of-window features
            cls_index = input_ids.index(tokenizer.cls_token_id)
            # sequence_ids marks which segment (question=0/context=1 when padding
            # right) each token belongs to; None for special tokens
            sequence_ids = tokenized_examples.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            if len(answers["answer_start"]) == 0:
                # No gold answer: point both targets at [CLS]
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Character span of the first gold answer
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                # Walk inwards from both ends to the first/last context token
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    # Answer lies outside this overflow window → label with [CLS]
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Advance past the answer start, then step back one token;
                    # mirror the process from the right for the end token
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
        return tokenized_examples
    logger.info("Preprocessing datasets...")
    # Use single process on Windows to avoid multiprocessing issues
    tokenized_datasets = raw_datasets.map(
        prepare_train_features,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=1 # Use single process to avoid Windows multiprocessing issues
    )
    # Features are already padded to max_length, so the default collator suffices
    data_collator = default_data_collator
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=config.BATCH_SIZE
    )
    # Consider adding validation dataloader setup here as well
    # eval_dataloader = DataLoader(...)
    # --- 3. Setup Optimizer ---
    logger.info("Setting up optimizer for FULL model fine-tuning...")
    # Optimize all parameters since PEFT is disabled
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
    logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")
    # Calculate total steps considering gradient accumulation
    num_update_steps_per_epoch = len(train_dataloader) // config.GRADIENT_ACCUMULATION_STEPS
    num_training_steps = config.EPOCHS * num_update_steps_per_epoch
    logger.info(f"Total optimization steps: {num_training_steps}")
    # --- 4. Initialize Mixed Precision Training ---
    # Initialize gradient scaler for mixed precision training; with enabled=False
    # the scaler becomes a no-op, so the same code path handles FP32 runs
    scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION) # Updated to fix deprecation warning
    # Log mixed precision and dynamic steps status
    if config.USE_MIXED_PRECISION:
        logger.info("Mixed precision training (FP16) enabled")
    if config.USE_DYNAMIC_STEPS:
        logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
        logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
    # Log bypass delta calculation status
    if config.BYPASS_DELTA_CALCULATION:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")
    # --- 5. Training Loop ---
    logger.info("***** Starting Training *****")
    logger.info(f"  Num examples = {len(tokenized_datasets['train'])}")
    logger.info(f"  Num Epochs = {config.EPOCHS}")
    logger.info(f"  Instantaneous batch size per device = {config.BATCH_SIZE}")
    logger.info(f"  Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
    logger.info(f"  Total optimization steps = {num_training_steps}")
    # Add note about subset training if applicable
    if subset_percentage < 100.0:
        logger.info(f"  NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")
    model.train() # Set model to training mode
    global_step = 0
    total_loss = 0.0 # Use float for accumulated loss
    # Start from specified epoch (default is 0 if not provided)
    start_epoch = args.start_epoch
    for epoch in range(start_epoch, config.EPOCHS):
        logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")
        for step, batch in enumerate(progress_bar):
            # Move batch to device
            # Ensure only tensors are moved, handle potential non-tensor data if any
            batch_on_device = {}
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch_on_device[k] = v.to(config.DEVICE)
                # else: # Handle or skip non-tensor items if necessary
                #     batch_on_device[k] = v
            try:
                # Forward pass with autocast for mixed precision
                with autocast('cuda', enabled=config.USE_MIXED_PRECISION): # Updated to fix deprecation warning
                    outputs = model(
                        input_ids=batch_on_device.get("input_ids"),
                        attention_mask=batch_on_device.get("attention_mask"),
                        token_type_ids=batch_on_device.get("token_type_ids"),
                        start_positions=batch_on_device.get("start_positions"),
                        end_positions=batch_on_device.get("end_positions"),
                        use_memory=False # Disable memory during training steps
                    )
                    loss = outputs.loss
                if loss is None:
                    logger.warning(f"Step {step}: Loss is None. Skipping batch.")
                    continue
                # Scale loss for gradient accumulation so the summed gradients
                # match a single large-batch step
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS
                # Accumulate loss value for logging (before backward)
                total_loss += loss.item()
                # Scale loss and perform backward pass with AMP
                scaler.scale(loss).backward()
            except Exception as e:
                logger.error(f"Error during forward/backward pass at step {step}: {e}")
                # Optional: Add more detailed error handling or debugging info
                # logger.error(f"Batch keys: {batch.keys()}")
                # logger.error(f"Input IDs shape: {batch_on_device.get('input_ids').shape if batch_on_device.get('input_ids') is not None else 'None'}")
                raise e # Re-raise the exception to stop training
            # Optimizer step (perform step only after accumulating gradients).
            # The second clause flushes a partial accumulation window at epoch end.
            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
                # Unscale before optimizer step (to check for infs/NaNs)
                scaler.unscale_(optimizer)
                # Clip gradients to avoid explosion (optional but recommended with mixed precision)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                # Step with scaler
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad() # Reset gradients for the next accumulation cycle
                global_step += 1
                # Log progress periodically
                if global_step % 50 == 0: # Log every 50 optimization steps
                    avg_loss = total_loss / 50 # Average loss over the last 50 steps
                    logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
                    total_loss = 0.0 # Reset loss accumulator
            # Update progress bar description with current step loss and steps info
            # (loss is rescaled back to the per-batch value for display)
            postfix = {
                "Loss": f"{loss.item()*config.GRADIENT_ACCUMULATION_STEPS:.4f}",
                "Step": global_step
            }
            # Add steps info if using dynamic steps
            if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
                if 'steps_taken' in model.custom_outputs:
                    postfix["Steps"] = model.custom_outputs['steps_taken']
            progress_bar.set_postfix(postfix)
        # --- (Optional) Evaluation at the end of each epoch ---
        # logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
        # model.eval()
        # # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
        # model.train() # Set back to train mode
        # --- Save Model Checkpoint (one directory per completed epoch) ---
        output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"--- Saving model checkpoint to {output_dir} ---")
        # --- Saving Logic for Enhanced Model ---
        try:
            logger.info(f"Saving enhanced model components to {output_dir}")
            # Save base model using its save_pretrained
            model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
            # Save all custom modules' state dicts
            torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
            torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
            torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))
            # Save step controller if using learned dynamic steps
            if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
                torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
                logger.info("Saved step controller for learned dynamic steps")
            # Save tokenizer
            tokenizer.save_pretrained(output_dir)
            # Save configuration snapshot alongside the weights so the checkpoint
            # records the hyperparameters it was trained with
            with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
                import json
                config_dict = {
                    "num_reasoning_steps": config.NUM_REASONING_STEPS,
                    "delta_target_ratio": config.DELTA_TARGET_RATIO,
                    "lambda_coherence": config.LAMBDA_COHERENCE,
                    "lambda_delta_reg": config.LAMBDA_DELTA_REG,
                    "memory_max_size": config.MEMORY_MAX_SIZE,
                    "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
                    "use_mixed_precision": config.USE_MIXED_PRECISION,
                    "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
                }
                # Add dynamic steps configuration if enabled
                if config.USE_DYNAMIC_STEPS:
                    config_dict.update({
                        "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
                        "max_reasoning_steps": config.MAX_REASONING_STEPS,
                        "min_reasoning_steps": config.MIN_REASONING_STEPS,
                        "reasoning_step_type": config.REASONING_STEP_TYPE,
                        "early_stop_threshold": config.EARLY_STOP_THRESHOLD
                    })
                json.dump(config_dict, f, indent=2)
            logger.info("Enhanced model checkpoint saved successfully.")
        except Exception as e:
            # Best-effort save: training continues even if a checkpoint write fails
            logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")
    logger.info("\n***** Training finished *****")
if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    # (freeze_support is a no-op on non-frozen / non-Windows interpreters)
    multiprocessing.freeze_support()
    main()
# Example usage:
# Train on full dataset (default):
# python train.py
# Train on 10% of data for faster iterations:
# python train.py --subset_percentage 10.0
# Train on 1% for very quick testing:
# python train.py --subset_percentage 1.0
# Resume training from checkpoint with subset:
# python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0
# Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
# python train.py --bypass_delta --subset_percentage 1.0