# train.py (Updated for Full Fine-tuning)
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler # Mixed precision; torch.amp replaces the deprecated torch.cuda.amp entry points
from transformers import AutoTokenizer, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm # Progress bar
import os
import evaluate # For metrics
import logging # Optional: Better logging
import multiprocessing # For Windows multiprocessing support
import argparse # For command line arguments
import json # For saving the enhanced config
import math # For ceiling division when counting optimizer steps
# Import our custom modules and config
import config
from model import EnhancedRRN_QA_Model
# Setup basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description="Train RRN QA Model")
parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
parser.add_argument(
"--subset_percentage",
type=float,
default=100.0,
help="Percentage of training data to use (1.0-100.0). Default: 100.0 (full dataset)"
)
parser.add_argument(
"--bypass_delta",
action="store_true",
help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))"
)
args = parser.parse_args()
# Set bypass delta calculation flag if specified
if args.bypass_delta:
logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")
config.BYPASS_DELTA_CALCULATION = True
else:
config.BYPASS_DELTA_CALCULATION = False
# --- 1. Load Tokenizer and Model ---
if args.checkpoint:
logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
logger.info(f"Loading model from checkpoint: {args.checkpoint}")
# Initialize the model with base architecture
model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))
# Check for enhanced model components
gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
# Load custom module weights
logger.info("Loading model components...")
# map_location="cpu" keeps loading robust to whichever device the checkpoint was saved on; model.to(config.DEVICE) follows below
model.qa_head.load_state_dict(torch.load(os.path.join(args.checkpoint, "qa_head.pth"), map_location="cpu"))
model.retroactive_update_layer.load_state_dict(torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth"), map_location="cpu"))
# Load gating mechanism if available
if is_enhanced_checkpoint:
logger.info("Loading gating mechanism...")
model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location="cpu"))
# Load step controller if available (for learned dynamic steps)
step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
logger.info("Loading step controller for learned dynamic steps...")
model.step_controller.load_state_dict(torch.load(step_controller_path, map_location="cpu"))
else:
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)
logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
model.to(config.DEVICE)
# --- 2. Load and Preprocess Dataset ---
logger.info("Loading SQuAD dataset...")
raw_datasets = load_dataset("squad")
# Handle dataset subsetting
subset_percentage = args.subset_percentage
if subset_percentage < 100.0:
original_train_size = len(raw_datasets["train"])
# Calculate subset size and validate
subset_percentage = max(0.1, min(100.0, subset_percentage)) # Clamp between 0.1% and 100%
train_subset_size = int(original_train_size * subset_percentage / 100)
train_subset_size = max(100, min(original_train_size, train_subset_size)) # Ensure reasonable bounds
# Create reproducible subset with fixed seed for consistency
subset_indices = torch.randperm(original_train_size, generator=torch.Generator().manual_seed(42))[:train_subset_size].tolist()
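# A roughly equivalent alternative would be raw_datasets["train"].shuffle(seed=42).select(range(train_subset_size))
# via the datasets library directly; torch.randperm with an explicit generator keeps the seeding self-contained here.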
raw_datasets["train"] = raw_datasets["train"].select(subset_indices)
logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
else:
logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")
question_column_name = "question"
context_column_name = "context"
answer_column_name = "answers"
pad_on_right = tokenizer.padding_side == "right"
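# prepare_train_features follows the standard Hugging Face SQuAD recipe: long contexts are split into
# overlapping windows (return_overflowing_tokens with DOC_STRIDE overlap), each window is traced back to
# its source example via overflow_to_sample_mapping, and windows that do not fully contain the answer
# span are labeled with the CLS index as both start and end position. pad_on_right above decides the
# (question, context) ordering and therefore which segment truncation applies to.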
def prepare_train_features(examples):
examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=config.MAX_SEQ_LENGTH,
stride=config.DOC_STRIDE,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
offset_mapping = tokenized_examples.pop("offset_mapping")
tokenized_examples["start_positions"] = []
tokenized_examples["end_positions"] = []
for i, offsets in enumerate(offset_mapping):
input_ids = tokenized_examples["input_ids"][i]
cls_index = input_ids.index(tokenizer.cls_token_id)
sequence_ids = tokenized_examples.sequence_ids(i)
sample_index = sample_mapping[i]
answers = examples[answer_column_name][sample_index]
if len(answers["answer_start"]) == 0:
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
token_start_index = 0
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
token_start_index += 1
token_end_index = len(input_ids) - 1
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
token_end_index -= 1
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
tokenized_examples["start_positions"].append(cls_index)
tokenized_examples["end_positions"].append(cls_index)
else:
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
token_start_index += 1
tokenized_examples["start_positions"].append(token_start_index - 1)
while offsets[token_end_index][1] >= end_char:
token_end_index -= 1
tokenized_examples["end_positions"].append(token_end_index + 1)
return tokenized_examples
logger.info("Preprocessing datasets...")
# Use single process on Windows to avoid multiprocessing issues
tokenized_datasets = raw_datasets.map(
prepare_train_features,
batched=True,
remove_columns=raw_datasets["train"].column_names,
num_proc=1
)
data_collator = default_data_collator
train_dataloader = DataLoader(
tokenized_datasets["train"],
shuffle=True,
collate_fn=data_collator,
batch_size=config.BATCH_SIZE
)
# Consider adding validation dataloader setup here as well
# eval_dataloader = DataLoader(...)
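# A minimal sketch of that setup (commented out), assuming a prepare_validation_features function
# analogous to prepare_train_features; the helper name and the "example_id" / "offset_mapping"
# columns are hypothetical and follow the usual SQuAD evaluation recipe:
# eval_features = raw_datasets["validation"].map(
#     prepare_validation_features,
#     batched=True,
#     remove_columns=raw_datasets["validation"].column_names,
# )
# eval_dataloader = DataLoader(
#     eval_features.remove_columns(["example_id", "offset_mapping"]),
#     collate_fn=default_data_collator,
#     batch_size=config.BATCH_SIZE,
# )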
# --- 3. Setup Optimizer ---
logger.info("Setting up optimizer for FULL model fine-tuning...")
# Optimize all parameters since PEFT is disabled
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")
# Calculate total steps considering gradient accumulation; ceiling division because the
# training loop also steps on the final, possibly partial, accumulation window of each epoch
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.GRADIENT_ACCUMULATION_STEPS)
num_training_steps = config.EPOCHS * num_update_steps_per_epoch
logger.info(f"Total optimization steps: {num_training_steps}")
# --- 4. Initialize Mixed Precision Training ---
# Initialize gradient scaler for mixed precision training
scaler = GradScaler('cuda', enabled=config.USE_MIXED_PRECISION) # torch.amp.GradScaler takes the device type explicitly
# Log mixed precision and dynamic steps status
if config.USE_MIXED_PRECISION:
logger.info("Mixed precision training (FP16) enabled")
if config.USE_DYNAMIC_STEPS:
logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
# Log bypass delta calculation status
if config.BYPASS_DELTA_CALCULATION:
logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")
# --- 5. Training Loop ---
logger.info("***** Starting Training *****")
logger.info(f" Num examples = {len(tokenized_datasets['train'])}")
logger.info(f" Num Epochs = {config.EPOCHS}")
logger.info(f" Instantaneous batch size per device = {config.BATCH_SIZE}")
logger.info(f" Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
logger.info(f" Total optimization steps = {num_training_steps}")
# Add note about subset training if applicable
if subset_percentage < 100.0:
logger.info(f" NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")
model.train() # Set model to training mode
global_step = 0
total_loss = 0.0 # Use float for accumulated loss
# Start from specified epoch (default is 0 if not provided)
start_epoch = args.start_epoch
for epoch in range(start_epoch, config.EPOCHS):
logger.info(f"\n--- Starting Epoch {epoch+1}/{config.EPOCHS} ---")
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", unit="batch")
for step, batch in enumerate(progress_bar):
# Move batch to device
# Ensure only tensors are moved, handle potential non-tensor data if any
batch_on_device = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch_on_device[k] = v.to(config.DEVICE)
# else: # Handle or skip non-tensor items if necessary
# batch_on_device[k] = v
try:
# Forward pass with autocast for mixed precision
with autocast('cuda', enabled=config.USE_MIXED_PRECISION): # device type passed explicitly per the torch.amp API
outputs = model(
input_ids=batch_on_device.get("input_ids"),
attention_mask=batch_on_device.get("attention_mask"),
token_type_ids=batch_on_device.get("token_type_ids"),
start_positions=batch_on_device.get("start_positions"),
end_positions=batch_on_device.get("end_positions"),
use_memory=False # Disable memory during training steps
)
loss = outputs.loss
if loss is None:
logger.warning(f"Step {step}: Loss is None. Skipping batch.")
continue
# Scale loss for gradient accumulation
loss = loss / config.GRADIENT_ACCUMULATION_STEPS
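# For example, with GRADIENT_ACCUMULATION_STEPS = 4 each micro-batch contributes loss/4 to the
# gradients, so the update after 4 micro-batches matches a single step on a 4x larger batch.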
# Accumulate loss value for logging (before backward)
total_loss += loss.item()
# Scale loss and perform backward pass with AMP
scaler.scale(loss).backward()
except Exception as e:
logger.error(f"Error during forward/backward pass at step {step}: {e}")
# Optional: Add more detailed error handling or debugging info
# logger.error(f"Batch keys: {batch.keys()}")
# logger.error(f"Input IDs shape: {batch_on_device.get('input_ids').shape if batch_on_device.get('input_ids') is not None else 'None'}")
raise # Bare raise preserves the original traceback while stopping training
# Optimizer step (perform step only after accumulating gradients)
if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
# Unscale before optimizer step (to check for infs/NaNs)
scaler.unscale_(optimizer)
# Clip gradients to avoid explosion (optional but recommended with mixed precision)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Step with scaler
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad() # Reset gradients for the next accumulation cycle
global_step += 1
# Log progress periodically
if global_step % 50 == 0: # Log every 50 optimization steps
avg_loss = total_loss / 50 # Average loss over the last 50 steps
logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
total_loss = 0.0 # Reset loss accumulator
# Update progress bar description with current step loss and steps info
postfix = {
"Loss": f"{loss.item()*config.GRADIENT_ACCUMULATION_STEPS:.4f}",
"Step": global_step
}
# Add steps info if using dynamic steps
if config.USE_DYNAMIC_STEPS and hasattr(model, 'custom_outputs'):
if 'steps_taken' in model.custom_outputs:
postfix["Steps"] = model.custom_outputs['steps_taken']
progress_bar.set_postfix(postfix)
# --- (Optional) Evaluation at the end of each epoch ---
# logger.info(f"\n--- Evaluating after Epoch {epoch+1} ---")
# model.eval()
# # Add evaluation loop here (requires validation dataloader, postprocessing, metrics)
# model.train() # Set back to train mode
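# A minimal sketch of that loop (commented out), assuming an eval_dataloader and a SQuAD-style
# postprocessing helper, both hypothetical here; this is where the `evaluate` import above comes in:
# squad_metric = evaluate.load("squad")
# model.eval()
# with torch.no_grad():
#     for eval_batch in eval_dataloader:
#         ...  # collect start/end logits for span postprocessing
# predictions, references = postprocess_qa_predictions(...)  # hypothetical helper
# logger.info(squad_metric.compute(predictions=predictions, references=references))
# model.train()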
# --- Save Model Checkpoint ---
output_dir = f"./rrn_qa_model_epoch_{epoch+1}"
os.makedirs(output_dir, exist_ok=True)
logger.info(f"--- Saving model checkpoint to {output_dir} ---")
# --- Saving Logic for Enhanced Model ---
try:
logger.info(f"Saving enhanced model components to {output_dir}")
# Save base model using its save_pretrained
model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
# Save all custom modules' state dicts
torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))
# Save step controller if using learned dynamic steps
if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
logger.info("Saved step controller for learned dynamic steps")
# Save tokenizer
tokenizer.save_pretrained(output_dir)
# Save configuration
with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
config_dict = {
"num_reasoning_steps": config.NUM_REASONING_STEPS,
"delta_target_ratio": config.DELTA_TARGET_RATIO,
"lambda_coherence": config.LAMBDA_COHERENCE,
"lambda_delta_reg": config.LAMBDA_DELTA_REG,
"memory_max_size": config.MEMORY_MAX_SIZE,
"memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
"use_mixed_precision": config.USE_MIXED_PRECISION,
"bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION
}
# Add dynamic steps configuration if enabled
if config.USE_DYNAMIC_STEPS:
config_dict.update({
"use_dynamic_steps": config.USE_DYNAMIC_STEPS,
"max_reasoning_steps": config.MAX_REASONING_STEPS,
"min_reasoning_steps": config.MIN_REASONING_STEPS,
"reasoning_step_type": config.REASONING_STEP_TYPE,
"early_stop_threshold": config.EARLY_STOP_THRESHOLD
})
json.dump(config_dict, f, indent=2)
logger.info("Enhanced model checkpoint saved successfully.")
except Exception as e:
logger.error(f"Error saving checkpoint at epoch {epoch+1}: {e}")
logger.info("\n***** Training finished *****")
if __name__ == "__main__":
# This is required for Windows to properly handle multiprocessing
multiprocessing.freeze_support()
main()
# Example usage:
# Train on full dataset (default):
# python train.py
# Train on 10% of data for faster iterations:
# python train.py --subset_percentage 10.0
# Train on 1% for very quick testing:
# python train.py --subset_percentage 1.0
# Resume training from checkpoint with subset:
# python train.py --checkpoint ./rrn_qa_model_epoch_1 --start_epoch 1 --subset_percentage 25.0
# Test with bypassed delta calculation (sets delta = torch.zeros_like(h0)):
# python train.py --bypass_delta --subset_percentage 1.0