import argparse
import json
import logging
import multiprocessing
import os

import torch
import torch.optim as optim
from torch.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, default_data_collator

import config
from model import EnhancedRRN_QA_Model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
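    """Fine-tune the Enhanced RRN QA model on SQuAD, checkpointing every epoch."""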
    parser = argparse.ArgumentParser(description="Train RRN QA Model")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint directory to resume from")
    parser.add_argument("--start_epoch", type=int, default=0, help="Epoch to start training from")
    parser.add_argument(
        "--subset_percentage",
        type=float,
        default=100.0,
        help="Percentage of training data to use (0.1-100.0). Default: 100.0 (full dataset)",
    )
    parser.add_argument(
        "--bypass_delta",
        action="store_true",
        help="Bypass RRN delta calculation (sets delta = torch.zeros_like(h0))",
    )
    args = parser.parse_args()
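
    # Example invocations (script name and values are illustrative):
    #   python train.py --subset_percentage 10
    #   python train.py --checkpoint ./rrn_qa_model_epoch_2 --start_epoch 2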

    # Propagate the CLI flag into the shared config module so the model sees it.
    config.BYPASS_DELTA_CALCULATION = args.bypass_delta
    if args.bypass_delta:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Setting delta = torch.zeros_like(h0)")

    if args.checkpoint:
        logger.info(f"Loading tokenizer from checkpoint: {args.checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)

        logger.info(f"Loading model from checkpoint: {args.checkpoint}")
        model = EnhancedRRN_QA_Model(os.path.join(args.checkpoint, "base_model"))

        # The gating mechanism file is only written by "enhanced" checkpoints,
        # so its presence distinguishes the two checkpoint layouts.
        gating_mechanism_path = os.path.join(args.checkpoint, "gating_mechanism.pth")
        is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)

        logger.info("Loading model components...")
        # map_location="cpu" keeps loading robust when the checkpoint was saved
        # on a different device; the model is moved to config.DEVICE below.
        model.qa_head.load_state_dict(
            torch.load(os.path.join(args.checkpoint, "qa_head.pth"), map_location="cpu")
        )
        model.retroactive_update_layer.load_state_dict(
            torch.load(os.path.join(args.checkpoint, "retroactive_layer.pth"), map_location="cpu")
        )

        if is_enhanced_checkpoint:
            logger.info("Loading gating mechanism...")
            model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location="cpu"))

        step_controller_path = os.path.join(args.checkpoint, "step_controller.pth")
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path, map_location="cpu"))
    else:
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(config.BASE_MODEL_NAME)

        logger.info("Instantiating Enhanced RRN QA Model for Full Fine-tuning...")
        model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)

    model.to(config.DEVICE)

    logger.info("Loading SQuAD dataset...")
    raw_datasets = load_dataset("squad")

    subset_percentage = args.subset_percentage
    if subset_percentage < 100.0:
        original_train_size = len(raw_datasets["train"])

        # Clamp to a sane range: at least 0.1% of the data and at least 100 examples.
        subset_percentage = max(0.1, min(100.0, subset_percentage))
        train_subset_size = int(original_train_size * subset_percentage / 100)
        train_subset_size = max(100, min(original_train_size, train_subset_size))

        # Fixed seed so the same subset is drawn on every run.
        subset_indices = torch.randperm(
            original_train_size, generator=torch.Generator().manual_seed(42)
        )[:train_subset_size].tolist()
        raw_datasets["train"] = raw_datasets["train"].select(subset_indices)

        logger.info(f"Using {subset_percentage:.1f}% of training data ({train_subset_size}/{original_train_size} examples)")
    else:
        logger.info(f"Using full training dataset ({len(raw_datasets['train'])} examples)")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    def prepare_train_features(examples):
        # Standard SQuAD preprocessing: long contexts are split into overlapping
        # features (stride = config.DOC_STRIDE), and character-level answer spans
        # are mapped onto token start/end positions via the offset mapping.
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_examples.pop("offset_mapping")
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sequence_ids = tokenized_examples.sequence_ids(i)
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]

            if len(answers["answer_start"]) == 0:
                # No answer: point both positions at the CLS token.
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])
                # Find the first and last tokens belonging to the context.
                token_start_index = 0
                while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                    token_start_index += 1
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                    token_end_index -= 1
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    # The answer lies outside this feature's stride window;
                    # label it with the CLS index instead.
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                else:
                    # Walk inward to the exact token boundaries of the answer span.
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
        return tokenized_examples
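
    # Note: with return_overflowing_tokens=True, one SQuAD example can yield
    # several overlapping features, so the mapped dataset below is typically
    # larger than the raw one (feature count depends on MAX_SEQ_LENGTH and
    # DOC_STRIDE in config.py).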

    logger.info("Preprocessing datasets...")
    tokenized_datasets = raw_datasets.map(
        prepare_train_features,
        batched=True,
        remove_columns=raw_datasets["train"].column_names,
        num_proc=1,
    )
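
    # default_data_collator simply stacks features into tensors; padding to
    # max_length was already applied during tokenization.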
    data_collator = default_data_collator
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=config.BATCH_SIZE,
    )

    logger.info("Setting up optimizer for FULL model fine-tuning...")
    optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
    logger.info(f"Optimizer: AdamW with LR={config.LEARNING_RATE}")

    num_update_steps_per_epoch = len(train_dataloader) // config.GRADIENT_ACCUMULATION_STEPS
    num_training_steps = config.EPOCHS * num_update_steps_per_epoch
    logger.info(f"Total optimization steps: {num_training_steps}")

    # GradScaler and autocast become no-ops when mixed precision is disabled.
    scaler = GradScaler("cuda", enabled=config.USE_MIXED_PRECISION)

    if config.USE_MIXED_PRECISION:
        logger.info("Mixed precision training (FP16) enabled")
    if config.USE_DYNAMIC_STEPS:
        logger.info(f"Dynamic reasoning steps enabled (type: {config.REASONING_STEP_TYPE})")
        logger.info(f"Min steps: {config.MIN_REASONING_STEPS}, Max steps: {config.MAX_REASONING_STEPS}")
    if config.BYPASS_DELTA_CALCULATION:
        logger.info("BYPASS_DELTA_CALCULATION enabled: Delta calculation is bypassed (delta = torch.zeros_like(h0))")

    logger.info("***** Starting Training *****")
    logger.info(f" Num examples = {len(tokenized_datasets['train'])}")
    logger.info(f" Num Epochs = {config.EPOCHS}")
    logger.info(f" Instantaneous batch size per device = {config.BATCH_SIZE}")
    logger.info(f" Gradient Accumulation steps = {config.GRADIENT_ACCUMULATION_STEPS}")
    logger.info(f" Total optimization steps = {num_training_steps}")
    if subset_percentage < 100.0:
        logger.info(f" NOTE: Training on {subset_percentage:.1f}% of data - metrics may not represent full dataset performance")

    model.train()
    global_step = 0
    total_loss = 0.0
    start_epoch = args.start_epoch

    for epoch in range(start_epoch, config.EPOCHS):
        logger.info(f"\n--- Starting Epoch {epoch + 1}/{config.EPOCHS} ---")
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", unit="batch")

        for step, batch in enumerate(progress_bar):
            # Move every tensor in the batch to the training device.
            batch_on_device = {
                k: v.to(config.DEVICE)
                for k, v in batch.items()
                if isinstance(v, torch.Tensor)
            }

            try:
                with autocast("cuda", enabled=config.USE_MIXED_PRECISION):
                    outputs = model(
                        input_ids=batch_on_device.get("input_ids"),
                        attention_mask=batch_on_device.get("attention_mask"),
                        token_type_ids=batch_on_device.get("token_type_ids"),
                        start_positions=batch_on_device.get("start_positions"),
                        end_positions=batch_on_device.get("end_positions"),
                        use_memory=False,
                    )
                    loss = outputs.loss

                if loss is None:
                    logger.warning(f"Step {step}: Loss is None. Skipping batch.")
                    continue

                # Normalize so accumulated gradients average over the micro-batches.
                loss = loss / config.GRADIENT_ACCUMULATION_STEPS
                total_loss += loss.item()
                scaler.scale(loss).backward()
            except Exception as e:
                logger.error(f"Error during forward/backward pass at step {step}: {e}")
                # Re-raise with the original traceback intact.
                raise
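
            # Step the optimizer every GRADIENT_ACCUMULATION_STEPS micro-batches,
            # plus once at the end of the epoch to flush any remainder.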
            if (step + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0 or step == len(train_dataloader) - 1:
                # Unscale before clipping so the norm is computed on true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1

                if global_step % 50 == 0:
                    avg_loss = total_loss / 50
                    logger.info(f"Step: {global_step}, Avg Loss: {avg_loss:.4f}")
                    total_loss = 0.0

                postfix = {
                    "Loss": f"{loss.item() * config.GRADIENT_ACCUMULATION_STEPS:.4f}",
                    "Step": global_step,
                }
                if config.USE_DYNAMIC_STEPS and hasattr(model, "custom_outputs"):
                    if "steps_taken" in model.custom_outputs:
                        postfix["Steps"] = model.custom_outputs["steps_taken"]
                progress_bar.set_postfix(postfix)

        output_dir = f"./rrn_qa_model_epoch_{epoch + 1}"
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"--- Saving model checkpoint to {output_dir} ---")

        try:
            logger.info(f"Saving enhanced model components to {output_dir}")
            # The transformer backbone is saved in HF format; the RRN-specific
            # heads are saved as separate state dicts.
            model.base_model.save_pretrained(os.path.join(output_dir, "base_model"))
            torch.save(model.qa_head.state_dict(), os.path.join(output_dir, "qa_head.pth"))
            torch.save(model.retroactive_update_layer.state_dict(), os.path.join(output_dir, "retroactive_layer.pth"))
            torch.save(model.gating_mechanism.state_dict(), os.path.join(output_dir, "gating_mechanism.pth"))

            if config.USE_DYNAMIC_STEPS and config.REASONING_STEP_TYPE == "learned" and hasattr(model, "step_controller"):
                torch.save(model.step_controller.state_dict(), os.path.join(output_dir, "step_controller.pth"))
                logger.info("Saved step controller for learned dynamic steps")

            tokenizer.save_pretrained(output_dir)

            # Record the hyperparameters needed to rebuild the model at load time.
            config_dict = {
                "num_reasoning_steps": config.NUM_REASONING_STEPS,
                "delta_target_ratio": config.DELTA_TARGET_RATIO,
                "lambda_coherence": config.LAMBDA_COHERENCE,
                "lambda_delta_reg": config.LAMBDA_DELTA_REG,
                "memory_max_size": config.MEMORY_MAX_SIZE,
                "memory_retrieval_k": config.MEMORY_RETRIEVAL_K,
                "use_mixed_precision": config.USE_MIXED_PRECISION,
                "bypass_delta_calculation": config.BYPASS_DELTA_CALCULATION,
            }
            if config.USE_DYNAMIC_STEPS:
                config_dict.update({
                    "use_dynamic_steps": config.USE_DYNAMIC_STEPS,
                    "max_reasoning_steps": config.MAX_REASONING_STEPS,
                    "min_reasoning_steps": config.MIN_REASONING_STEPS,
                    "reasoning_step_type": config.REASONING_STEP_TYPE,
                    "early_stop_threshold": config.EARLY_STOP_THRESHOLD,
                })
            with open(os.path.join(output_dir, "enhanced_config.json"), "w") as f:
                json.dump(config_dict, f, indent=2)

            logger.info("Enhanced model checkpoint saved successfully.")
        except Exception as e:
            logger.error(f"Error saving checkpoint at epoch {epoch + 1}: {e}")

    logger.info("\n***** Training finished *****")


if __name__ == "__main__":
    # freeze_support() is needed for frozen Windows executables that spawn processes.
    multiprocessing.freeze_support()
    main()