# test_model.py - RRN QA Model evaluation script with multi-step reasoning support
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import evaluate as hf_evaluate  # Import with alias to avoid naming conflict
import collections
import numpy as np
import logging
import multiprocessing  # For Windows multiprocessing support
import json
import argparse
import matplotlib.pyplot as plt
from collections import defaultdict

# Import custom modules and config
import config
from model import EnhancedRRN_QA_Model  # Import the enhanced model
# Make sure memory.py and modules.py are accessible

# --- Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Test RRN QA Model")
    parser.add_argument("--checkpoint", type=str, default="./rrn_qa_model_epoch_3",
                        help="Path to checkpoint directory (default: ./rrn_qa_model_epoch_3)")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Evaluation batch size (default: 8)")
    parser.add_argument("--fixed_steps", type=int, default=None,
                        help="Override to use fixed number of reasoning steps (default: None, use model's dynamic steps)")
    parser.add_argument("--use_memory", action="store_true",
                        help="Enable active memory during evaluation")
    parser.add_argument("--output_dir", type=str, default="./eval_results",
                        help="Directory to save evaluation results (default: ./eval_results)")
    parser.add_argument("--visualize", action="store_true",
                        help="Generate visualizations of reasoning steps")
    args = parser.parse_args()

    CHECKPOINT_DIR = args.checkpoint
    EVAL_BATCH_SIZE = args.batch_size
    DEVICE = config.DEVICE
    USE_MEMORY = args.use_memory
    OUTPUT_DIR = args.output_dir

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    logger.info("Evaluation configuration:")
    logger.info(f"  Checkpoint: {CHECKPOINT_DIR}")
    logger.info(f"  Batch size: {EVAL_BATCH_SIZE}")
    logger.info(f"  Device: {DEVICE}")
    logger.info(f"  Use memory: {USE_MEMORY}")
    logger.info(f"  Output directory: {OUTPUT_DIR}")
    if args.fixed_steps is not None:
        logger.info(f"  Using fixed {args.fixed_steps} reasoning steps (overriding model config)")

    # --- 1. Load Tokenizer and Model from Checkpoint ---
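    # NOTE (illustrative, inferred from the paths loaded below - not defined in this file):
    # a compatible checkpoint directory is assumed to look roughly like this:
    #
    #   rrn_qa_model_epoch_3/
    #   ├── tokenizer files (tokenizer_config.json, vocab, ...)
    #   ├── base_model/              # Hugging Face AutoModel directory
    #   ├── qa_head.pth              # required
    #   ├── retroactive_layer.pth    # required
    #   ├── gating_mechanism.pth     # present only for "enhanced" checkpoints
    #   ├── step_controller.pth      # optional, for learned dynamic steps
    #   └── enhanced_config.json     # optional overrides (see example at end of file)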
    logger.info(f"Loading tokenizer from {CHECKPOINT_DIR}...")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)

    logger.info("Loading Enhanced RRN QA Model architecture...")
    # Instantiate the enhanced model architecture
    model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)

    # Check if we're loading from a checkpoint with the enhanced architecture
    base_model_path = os.path.join(CHECKPOINT_DIR, "base_model")
    qa_head_path = os.path.join(CHECKPOINT_DIR, "qa_head.pth")
    retroactive_layer_path = os.path.join(CHECKPOINT_DIR, "retroactive_layer.pth")
    gating_mechanism_path = os.path.join(CHECKPOINT_DIR, "gating_mechanism.pth")
    step_controller_path = os.path.join(CHECKPOINT_DIR, "step_controller.pth")

    # Check for required components
    if not os.path.exists(base_model_path):
        logger.error(f"Base model directory not found at: {base_model_path}")
        exit()
    if not os.path.exists(qa_head_path):
        logger.error(f"QA head weights not found at: {qa_head_path}")
        exit()
    if not os.path.exists(retroactive_layer_path):
        logger.error(f"Retroactive layer weights not found at: {retroactive_layer_path}")
        exit()

    # Load base model weights
    logger.info(f"Loading base model weights from {base_model_path}...")
    model.base_model = AutoModel.from_pretrained(base_model_path)

    # Check if we're loading from an enhanced checkpoint or a legacy checkpoint
    is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)

    if is_enhanced_checkpoint:
        # Load all enhanced components
        logger.info("Loading enhanced model components...")
        model.qa_head.load_state_dict(torch.load(qa_head_path, map_location='cpu'))
        model.retroactive_update_layer.load_state_dict(torch.load(retroactive_layer_path, map_location='cpu'))
        model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location='cpu'))

        # Load step controller if available (for learned dynamic steps)
        if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
            logger.info("Loading step controller for learned dynamic steps...")
            model.step_controller.load_state_dict(torch.load(step_controller_path, map_location='cpu'))
        logger.info("Enhanced model loaded successfully.")
    else:
        # We're loading from a legacy checkpoint - need to adapt the weights
        logger.info("Loading from legacy checkpoint - adapting weights to enhanced architecture...")

        # For the QA head, we need to initialize the enhanced QA head from scratch
        # since the architectures are different
        logger.info("Initializing enhanced QA head with random weights...")

        # For the retroactive layer, we can try to load the weights but might need adjustments
        logger.warning("Note: The enhanced model uses a different architecture than the checkpoint.")
        logger.warning("Some components will use random initialization.")

    # Load enhanced config if available
    enhanced_config_path = os.path.join(CHECKPOINT_DIR, "enhanced_config.json")
    if os.path.exists(enhanced_config_path):
        logger.info(f"Loading enhanced configuration from {enhanced_config_path}")
        with open(enhanced_config_path, 'r') as f:
            enhanced_config = json.load(f)

        # Override model configuration with saved values
        if "num_reasoning_steps" in enhanced_config:
            model.num_reasoning_steps = enhanced_config["num_reasoning_steps"]
            logger.info(f"Using {model.num_reasoning_steps} reasoning steps from config")
        if "use_dynamic_steps" in enhanced_config:
            model.use_dynamic_steps = enhanced_config["use_dynamic_steps"]
            if model.use_dynamic_steps:
                model.max_reasoning_steps = enhanced_config.get("max_reasoning_steps", config.MAX_REASONING_STEPS)
                model.min_reasoning_steps = enhanced_config.get("min_reasoning_steps", config.MIN_REASONING_STEPS)
                model.reasoning_step_type = enhanced_config.get("reasoning_step_type", config.REASONING_STEP_TYPE)
                model.early_stop_threshold = enhanced_config.get("early_stop_threshold", config.EARLY_STOP_THRESHOLD)
                logger.info(f"Using dynamic reasoning steps (type: {model.reasoning_step_type})")
                logger.info(f"Min steps: {model.min_reasoning_steps}, Max steps: {model.max_reasoning_steps}")

    # Override with fixed steps if specified
    if args.fixed_steps is not None:
        logger.info(f"Overriding with fixed {args.fixed_steps} reasoning steps")
        model.use_dynamic_steps = False
        model.num_reasoning_steps = args.fixed_steps

    model.to(DEVICE)
    model.eval()  # Set model to evaluation mode
    logger.info("Model loaded successfully and set to evaluation mode.")

    # --- 2. Load and Preprocess Validation Dataset ---
    logger.info("Loading SQuAD validation dataset...")
    raw_datasets = load_dataset("squad", split="validation")

    question_column_name = "question"
    context_column_name = "context"
    answer_column_name = "answers"
    pad_on_right = tokenizer.padding_side == "right"

    # Validation preprocessing: Keep example_id and offset_mapping
    def prepare_validation_features(examples):
        examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=config.MAX_SEQ_LENGTH,
            stride=config.DOC_STRIDE,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        # Keep track of which feature belongs to which example
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # Add the example_id to link features to original examples
        tokenized_examples["example_id"] = []
        for i in range(len(tokenized_examples["input_ids"])):
            sequence_ids = tokenized_examples.sequence_ids(i)
            context_index = 1 if pad_on_right else 0
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set offset mapping to None for question tokens to avoid predicting answers there
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples

    logger.info("Preprocessing validation dataset...")
    # Disable multiprocessing which can hang on some systems
    logger.info("Using single process for preprocessing to prevent hanging")
    eval_dataset = raw_datasets.map(
        prepare_validation_features,
        batched=True,
        remove_columns=raw_datasets.column_names,
        num_proc=1,  # Disable multiprocessing to avoid hanging
    )

    # Custom collator to handle None values in offset_mapping
    def custom_data_collator(features):
        # First, remove offset_mapping which contains None values that can't be batched
        offset_mappings = [f.pop("offset_mapping") for f in features]
        # Use default collator for everything else
        batch = default_data_collator(features)
        # Add offset_mapping back as a list since it can't be converted to a tensor
        batch["offset_mapping"] = offset_mappings
        return batch

    # Use custom data collator
    data_collator = custom_data_collator

    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator, batch_size=EVAL_BATCH_SIZE
    )

    # --- 3. Run Inference ---
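    # NOTE: the metric collection below assumes EnhancedRRN_QA_Model exposes a
    # `custom_outputs` dict after each forward pass (defined in model.py, not here).
    # Based on the keys accessed in this loop, it is expected to contain roughly:
    #   'steps_taken'                         - number of reasoning steps used for the batch
    #   'all_deltas'                          - list of per-step update tensors
    #   'all_gates'                           - list of per-step gate tensors
    #   'y0_start_logits' / 'y0_end_logits'   - initial (step-0) start/end logits
    # Batches where a key is absent are simply skipped for that metric.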
    logger.info("***** Running Evaluation *****")
    logger.info(f"  Num examples = {len(eval_dataset)}")
    logger.info(f"  Batch size = {EVAL_BATCH_SIZE}")

    all_start_logits = []
    all_end_logits = []
    feature_indices = []  # Keep track of the order

    # Track multi-step reasoning metrics
    reasoning_steps_taken = []
    delta_magnitudes = []
    gate_values = []
    initial_vs_final_changes = []

    with torch.no_grad():
        for step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
            # Move batch to device
            batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}

            # Store feature indices corresponding to this batch
            # Assuming 'input_ids' or similar key represents features in order
            current_indices = list(range(step * EVAL_BATCH_SIZE,
                                         step * EVAL_BATCH_SIZE + len(batch_on_device['input_ids'])))
            feature_indices.extend(current_indices)

            # Forward pass - pass only inputs needed by model.forward
            outputs = model(
                input_ids=batch_on_device.get("input_ids"),
                attention_mask=batch_on_device.get("attention_mask"),
                token_type_ids=batch_on_device.get("token_type_ids"),
                use_memory=USE_MEMORY,  # Use memory if enabled
                return_dict=True
            )

            # Get the final logits (y1)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            all_start_logits.append(start_logits.cpu().numpy())
            all_end_logits.append(end_logits.cpu().numpy())

            # Collect multi-step reasoning metrics from custom_outputs
            if hasattr(model, 'custom_outputs'):
                # Number of reasoning steps taken
                if 'steps_taken' in model.custom_outputs:
                    reasoning_steps_taken.append(model.custom_outputs['steps_taken'])

                # Delta magnitudes (how much the model updates at each step)
                if 'all_deltas' in model.custom_outputs and len(model.custom_outputs['all_deltas']) > 0:
                    batch_deltas = []
                    for delta in model.custom_outputs['all_deltas']:
                        # Calculate mean delta magnitude across sequence dimension
                        delta_norm = delta.norm(dim=-1).mean().cpu().item()
                        batch_deltas.append(delta_norm)
                    delta_magnitudes.append(batch_deltas)

                # Gate values (how selective the updates are)
                if 'all_gates' in model.custom_outputs and len(model.custom_outputs['all_gates']) > 0:
                    batch_gates = []
                    for gate in model.custom_outputs['all_gates']:
                        # Calculate mean gate value across sequence dimension
                        gate_mean = gate.mean().cpu().item()
                        batch_gates.append(gate_mean)
                    gate_values.append(batch_gates)

                # Compare initial vs final predictions
                if 'y0_start_logits' in model.custom_outputs and 'y0_end_logits' in model.custom_outputs:
                    y0_start = model.custom_outputs['y0_start_logits']
                    y0_end = model.custom_outputs['y0_end_logits']
                    # Calculate how much the predictions changed
                    start_change = (start_logits - y0_start).abs().mean().cpu().item()
                    end_change = (end_logits - y0_end).abs().mean().cpu().item()
                    initial_vs_final_changes.append((start_change + end_change) / 2)

    # Concatenate all results
    all_start_logits = np.concatenate(all_start_logits, axis=0)
    all_end_logits = np.concatenate(all_end_logits, axis=0)

    # Ensure the number of predictions matches the number of features
    if len(all_start_logits) != len(eval_dataset):
        logger.warning(f"Mismatch in prediction count ({len(all_start_logits)}) and "
                       f"feature count ({len(eval_dataset)}). Check dataloader/inference loop.")
Check dataloader/inference loop.") # Attempt to slice if predictions exceed features (might happen if last batch wasn't full) all_start_logits = all_start_logits[:len(eval_dataset)] all_end_logits = all_end_logits[:len(eval_dataset)] # Create dictionary mapping feature index to its logits predictions_dict = { feature_index: (start_logit, end_logit) for feature_index, (start_logit, end_logit) in zip(feature_indices, zip(all_start_logits, all_end_logits)) } # --- 4. Post-Processing --- # (Adapted from Hugging Face run_qa.py example script) def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, tokenizer=tokenizer): all_start_logits, all_end_logits = zip(*raw_predictions.values()) # Build a map from example ID to list of related feature indices example_id_to_index = {k: i for i, k in enumerate(examples["id"])} features_per_example = collections.defaultdict(list) for i, feature in enumerate(features): features_per_example[example_id_to_index[feature["example_id"]]].append(i) # Dictionary to store predictions predictions = collections.OrderedDict() logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Loop over all examples for example_index, example in enumerate(tqdm(examples, desc="Post-processing")): feature_indices = features_per_example[example_index] # Indices of features related to this example min_null_score = None # Used to identify impossible answers valid_answers = [] context = example["context"] # Loop through features associated with the current example for feature_index in feature_indices: start_logits = all_start_logits[feature_index] end_logits = all_end_logits[feature_index] offset_mapping = features[feature_index]["offset_mapping"] # Update minimum null prediction score cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id) feature_null_score = start_logits[cls_index] + end_logits[cls_index] if min_null_score is None or min_null_score < feature_null_score: min_null_score = feature_null_score # Go through all possibilities for start/end positions start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() for start_index in start_indexes: for end_index in end_indexes: # Skip invalid pairs (start > end, index out of bounds, answer in question part) if start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or \ offset_mapping[start_index] is None or offset_mapping[end_index] is None or \ end_index < start_index: continue # Check answer length if end_index - start_index + 1 > max_answer_length: continue # Extract text and score start_char = offset_mapping[start_index][0] end_char = offset_mapping[end_index][1] score = start_logits[start_index] + end_logits[end_index] valid_answers.append({ "score": score, "text": context[start_char: end_char] }) # Select the best answer across all features for this example if len(valid_answers) > 0: best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0] else: # Fallback for no valid answers found best_answer = {"text": "", "score": min_null_score} # Assign CLS score if needed # Assign final prediction (use empty string if null score is best) # Simple version: always take the best scoring valid answer # More sophisticated versions might compare best_answer["score"] vs min_null_score predictions[example["id"]] = best_answer["text"] return predictions logger.info("Starting post-processing...") 
    final_predictions = postprocess_qa_predictions(raw_datasets, eval_dataset, predictions_dict)

    # --- 5. Compute Metrics ---
    logger.info("Calculating SQuAD metrics...")
    metric = hf_evaluate.load("squad")

    # Format predictions and references for the metric
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    formatted_references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in raw_datasets]

    results = metric.compute(predictions=formatted_predictions, references=formatted_references)

    logger.info("***** Evaluation Results *****")
    print(results)

    # --- 6. Analyze Multi-step Reasoning Metrics ---
    logger.info("\n***** Multi-step Reasoning Analysis *****")

    # Calculate average number of reasoning steps
    if reasoning_steps_taken:
        avg_steps = sum(reasoning_steps_taken) / len(reasoning_steps_taken)
        logger.info(f"Average reasoning steps: {avg_steps:.2f}")

        # Count frequency of each step count
        step_counts = collections.Counter(reasoning_steps_taken)
        logger.info(f"Step count distribution: {dict(sorted(step_counts.items()))}")

    # Calculate average delta magnitudes per step
    if delta_magnitudes:
        # Transpose to get step-wise averages
        steps_delta_magnitudes = defaultdict(list)
        for batch_deltas in delta_magnitudes:
            for step_idx, delta in enumerate(batch_deltas):
                steps_delta_magnitudes[step_idx].append(delta)

        avg_delta_by_step = {step: sum(deltas) / len(deltas) for step, deltas in steps_delta_magnitudes.items()}
        logger.info(f"Average delta magnitude by step: {avg_delta_by_step}")

    # Calculate average gate values per step
    if gate_values:
        # Transpose to get step-wise averages
        steps_gate_values = defaultdict(list)
        for batch_gates in gate_values:
            for step_idx, gate in enumerate(batch_gates):
                steps_gate_values[step_idx].append(gate)

        avg_gate_by_step = {step: sum(gates) / len(gates) for step, gates in steps_gate_values.items()}
        logger.info(f"Average gate value by step: {avg_gate_by_step}")

    # Calculate average change from initial to final predictions
    if initial_vs_final_changes:
        avg_change = sum(initial_vs_final_changes) / len(initial_vs_final_changes)
        logger.info(f"Average change from initial to final predictions: {avg_change:.4f}")

    # --- 7. Save Results ---
    results_file = os.path.join(OUTPUT_DIR, "eval_results.json")
    with open(results_file, 'w') as f:
        # Combine SQuAD metrics with multi-step reasoning metrics
        full_results = {
            "squad_metrics": results,
            "multi_step_metrics": {
                "avg_reasoning_steps": avg_steps if reasoning_steps_taken else None,
                "step_count_distribution": dict(sorted(step_counts.items())) if reasoning_steps_taken else None,
                "avg_delta_by_step": avg_delta_by_step if delta_magnitudes else None,
                "avg_gate_by_step": avg_gate_by_step if gate_values else None,
                "avg_prediction_change": avg_change if initial_vs_final_changes else None
            }
        }
        json.dump(full_results, f, indent=2)
    logger.info(f"Results saved to {results_file}")
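    # For reference, eval_results.json written above has the shape of the dict constructed
    # in section 7 (numeric values here are purely illustrative; the SQuAD metric returns
    # "exact_match" and "f1"):
    #
    #   {
    #     "squad_metrics": {"exact_match": 0.0, "f1": 0.0},
    #     "multi_step_metrics": {
    #       "avg_reasoning_steps": null,
    #       "step_count_distribution": null,
    #       "avg_delta_by_step": null,
    #       "avg_gate_by_step": null,
    #       "avg_prediction_change": null
    #     }
    #   }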
    # --- 8. Generate Visualizations (if requested) ---
    if args.visualize and (delta_magnitudes or gate_values or reasoning_steps_taken):
        logger.info("Generating visualizations...")

        # Create visualization directory
        viz_dir = os.path.join(OUTPUT_DIR, "visualizations")
        os.makedirs(viz_dir, exist_ok=True)

        # Plot step distribution
        if reasoning_steps_taken:
            plt.figure(figsize=(10, 6))
            plt.bar(step_counts.keys(), step_counts.values())
            plt.xlabel('Number of Reasoning Steps')
            plt.ylabel('Frequency')
            plt.title('Distribution of Reasoning Steps')
            plt.savefig(os.path.join(viz_dir, 'step_distribution.png'))
            plt.close()

        # Plot delta magnitudes by step
        if delta_magnitudes and steps_delta_magnitudes:
            plt.figure(figsize=(10, 6))
            steps = sorted(steps_delta_magnitudes.keys())
            values = [avg_delta_by_step[step] for step in steps]
            plt.plot(steps, values, marker='o')
            plt.xlabel('Reasoning Step')
            plt.ylabel('Average Delta Magnitude')
            plt.title('Delta Magnitude by Reasoning Step')
            plt.grid(True)
            plt.savefig(os.path.join(viz_dir, 'delta_magnitudes.png'))
            plt.close()

        # Plot gate values by step
        if gate_values and steps_gate_values:
            plt.figure(figsize=(10, 6))
            steps = sorted(steps_gate_values.keys())
            values = [avg_gate_by_step[step] for step in steps]
            plt.plot(steps, values, marker='o')
            plt.xlabel('Reasoning Step')
            plt.ylabel('Average Gate Value')
            plt.title('Gate Value by Reasoning Step')
            plt.grid(True)
            plt.savefig(os.path.join(viz_dir, 'gate_values.png'))
            plt.close()

        logger.info(f"Visualizations saved to {viz_dir}")


if __name__ == "__main__":
    # This is required for Windows to properly handle multiprocessing
    multiprocessing.freeze_support()
    main()

# Example usage:
# Test with default settings (epoch 3 checkpoint):
#   python test_model.py
# Test with specific checkpoint:
#   python test_model.py --checkpoint ./rrn_qa_model_epoch_2
# Test with fixed number of reasoning steps:
#   python test_model.py --fixed_steps 3
# Test with active memory:
#   python test_model.py --use_memory
# Test with visualizations:
#   python test_model.py --visualize
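# Example enhanced_config.json (illustrative values): this script only reads the keys
# listed below; each key is optional, and missing keys leave the model's own defaults
# (or the config.py fallbacks used above) in place.
#
#   {
#     "num_reasoning_steps": 3,
#     "use_dynamic_steps": true,
#     "max_reasoning_steps": 5,
#     "min_reasoning_steps": 1,
#     "reasoning_step_type": "learned",
#     "early_stop_threshold": 0.05
#   }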