# test_model.py - RRN QA Model evaluation script with multi-step reasoning support
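#
# Pipeline overview (as implemented below):
#   1. Load the tokenizer and enhanced/legacy checkpoint components.
#   2. Tokenize the SQuAD validation split into overlapping features.
#   3. Run inference and collect multi-step reasoning diagnostics.
#   4. Post-process start/end logits into answer strings.
#   5. Compute SQuAD exact-match / F1, save results, and optionally plot them.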
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import evaluate as hf_evaluate # Import with alias to avoid naming conflict
import collections
import numpy as np
import logging
import multiprocessing # For Windows multiprocessing support
import json
import argparse
import matplotlib.pyplot as plt
from collections import defaultdict
# Import custom modules and config
import config
from model import EnhancedRRN_QA_Model # Import the enhanced model
# Make sure memory.py and modules.py are accessible
# --- Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description="Test RRN QA Model")
parser.add_argument("--checkpoint", type=str, default="./rrn_qa_model_epoch_3",
help="Path to checkpoint directory (default: ./rrn_qa_model_epoch_3)")
parser.add_argument("--batch_size", type=int, default=8,
help="Evaluation batch size (default: 8)")
parser.add_argument("--fixed_steps", type=int, default=None,
help="Override to use fixed number of reasoning steps (default: None, use model's dynamic steps)")
parser.add_argument("--use_memory", action="store_true",
help="Enable active memory during evaluation")
parser.add_argument("--output_dir", type=str, default="./eval_results",
help="Directory to save evaluation results (default: ./eval_results)")
parser.add_argument("--visualize", action="store_true",
help="Generate visualizations of reasoning steps")
args = parser.parse_args()
CHECKPOINT_DIR = args.checkpoint
EVAL_BATCH_SIZE = args.batch_size
DEVICE = config.DEVICE
USE_MEMORY = args.use_memory
OUTPUT_DIR = args.output_dir
# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
logger.info(f"Evaluation configuration:")
logger.info(f" Checkpoint: {CHECKPOINT_DIR}")
logger.info(f" Batch size: {EVAL_BATCH_SIZE}")
logger.info(f" Device: {DEVICE}")
logger.info(f" Use memory: {USE_MEMORY}")
logger.info(f" Output directory: {OUTPUT_DIR}")
if args.fixed_steps is not None:
logger.info(f" Using fixed {args.fixed_steps} reasoning steps (overriding model config)")
# --- 1. Load Tokenizer and Model from Checkpoint ---
logger.info(f"Loading tokenizer from {CHECKPOINT_DIR}...")
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
logger.info(f"Loading Enhanced RRN QA Model architecture...")
# Instantiate the enhanced model architecture
model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
# Check if we're loading from a checkpoint with the enhanced architecture
base_model_path = os.path.join(CHECKPOINT_DIR, "base_model")
qa_head_path = os.path.join(CHECKPOINT_DIR, "qa_head.pth")
retroactive_layer_path = os.path.join(CHECKPOINT_DIR, "retroactive_layer.pth")
gating_mechanism_path = os.path.join(CHECKPOINT_DIR, "gating_mechanism.pth")
step_controller_path = os.path.join(CHECKPOINT_DIR, "step_controller.pth")
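    # Expected checkpoint layout, as assumed by the checks and loads below:
    #   <checkpoint>/base_model/            HF-format encoder weights + config
    #   <checkpoint>/qa_head.pth            span-prediction head state_dict
    #   <checkpoint>/retroactive_layer.pth  retroactive update layer state_dict
    #   <checkpoint>/gating_mechanism.pth   enhanced checkpoints only
    #   <checkpoint>/step_controller.pth    optional, learned dynamic steps
    #   <checkpoint>/enhanced_config.json   optional reasoning-step settings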
# Check for required components
    if not os.path.exists(base_model_path):
        logger.error(f"Base model directory not found at: {base_model_path}")
        raise SystemExit(1)
    if not os.path.exists(qa_head_path):
        logger.error(f"QA head weights not found at: {qa_head_path}")
        raise SystemExit(1)
    if not os.path.exists(retroactive_layer_path):
        logger.error(f"Retroactive layer weights not found at: {retroactive_layer_path}")
        raise SystemExit(1)
# Load base model weights
logger.info(f"Loading base model weights from {base_model_path}...")
model.base_model = AutoModel.from_pretrained(base_model_path)
# Check if we're loading from an enhanced checkpoint or a legacy checkpoint
is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
if is_enhanced_checkpoint:
# Load all enhanced components
logger.info("Loading enhanced model components...")
model.qa_head.load_state_dict(torch.load(qa_head_path, map_location='cpu'))
model.retroactive_update_layer.load_state_dict(torch.load(retroactive_layer_path, map_location='cpu'))
model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location='cpu'))
# Load step controller if available (for learned dynamic steps)
if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
logger.info("Loading step controller for learned dynamic steps...")
model.step_controller.load_state_dict(torch.load(step_controller_path, map_location='cpu'))
logger.info("Enhanced model loaded successfully.")
else:
# We're loading from a legacy checkpoint - need to adapt the weights
logger.info("Loading from legacy checkpoint - adapting weights to enhanced architecture...")
# For the QA head, we need to initialize the enhanced QA head from scratch
# since the architectures are different
logger.info("Initializing enhanced QA head with random weights...")
        # The retroactive layer weights from the legacy checkpoint are not loaded either,
        # since their shapes are unlikely to match the enhanced layer
logger.warning("Note: The enhanced model uses a different architecture than the checkpoint.")
logger.warning("Some components will use random initialization.")
# Load enhanced config if available
enhanced_config_path = os.path.join(CHECKPOINT_DIR, "enhanced_config.json")
if os.path.exists(enhanced_config_path):
logger.info(f"Loading enhanced configuration from {enhanced_config_path}")
with open(enhanced_config_path, 'r') as f:
enhanced_config = json.load(f)
# Override model configuration with saved values
if "num_reasoning_steps" in enhanced_config:
model.num_reasoning_steps = enhanced_config["num_reasoning_steps"]
logger.info(f"Using {model.num_reasoning_steps} reasoning steps from config")
if "use_dynamic_steps" in enhanced_config:
model.use_dynamic_steps = enhanced_config["use_dynamic_steps"]
if model.use_dynamic_steps:
model.max_reasoning_steps = enhanced_config.get("max_reasoning_steps", config.MAX_REASONING_STEPS)
model.min_reasoning_steps = enhanced_config.get("min_reasoning_steps", config.MIN_REASONING_STEPS)
model.reasoning_step_type = enhanced_config.get("reasoning_step_type", config.REASONING_STEP_TYPE)
model.early_stop_threshold = enhanced_config.get("early_stop_threshold", config.EARLY_STOP_THRESHOLD)
logger.info(f"Using dynamic reasoning steps (type: {model.reasoning_step_type})")
logger.info(f"Min steps: {model.min_reasoning_steps}, Max steps: {model.max_reasoning_steps}")
# Override with fixed steps if specified
if args.fixed_steps is not None:
logger.info(f"Overriding with fixed {args.fixed_steps} reasoning steps")
model.use_dynamic_steps = False
model.num_reasoning_steps = args.fixed_steps
model.to(DEVICE)
model.eval() # Set model to evaluation mode
logger.info("Model loaded successfully and set to evaluation mode.")
# --- 2. Load and Preprocess Validation Dataset ---
logger.info("Loading SQuAD validation dataset...")
raw_datasets = load_dataset("squad", split="validation")
question_column_name = "question"
context_column_name = "context"
answer_column_name = "answers"
pad_on_right = tokenizer.padding_side == "right"
# Validation preprocessing: Keep example_id and offset_mapping
def prepare_validation_features(examples):
examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=config.MAX_SEQ_LENGTH,
stride=config.DOC_STRIDE,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
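        # With return_overflowing_tokens + stride, a long context is split into several
        # overlapping features (each window sharing DOC_STRIDE tokens with the previous one),
        # all mapped back to their source example via overflow_to_sample_mapping below.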
# Keep track of which feature belongs to which example
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# Add the example_id to link features to original examples
tokenized_examples["example_id"] = []
for i in range(len(tokenized_examples["input_ids"])):
sequence_ids = tokenized_examples.sequence_ids(i)
context_index = 1 if pad_on_right else 0
sample_index = sample_mapping[i]
tokenized_examples["example_id"].append(examples["id"][sample_index])
# Set offset mapping to None for question tokens to avoid predicting answers there
tokenized_examples["offset_mapping"][i] = [
(o if sequence_ids[k] == context_index else None)
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
]
return tokenized_examples
logger.info("Preprocessing validation dataset...")
    # Run the map in a single process: multiprocessing in datasets.map can hang on some systems (notably Windows)
    logger.info("Using a single process for preprocessing to avoid hangs")
eval_dataset = raw_datasets.map(
prepare_validation_features,
batched=True,
remove_columns=raw_datasets.column_names,
num_proc=1, # Disable multiprocessing to avoid hanging
)
# Custom collator to handle None values in offset_mapping
def custom_data_collator(features):
# First, remove offset_mapping which contains None values that can't be batched
offset_mappings = [f.pop("offset_mapping") for f in features]
# Use default collator for everything else
batch = default_data_collator(features)
# Add offset_mapping back as a list since it can't be converted to a tensor
batch["offset_mapping"] = offset_mappings
return batch
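    # Note: default_data_collator silently skips string-valued fields such as example_id,
    # so they never appear in the batch; post-processing reads example_id from eval_dataset instead.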
# Use custom data collator
data_collator = custom_data_collator
eval_dataloader = DataLoader(
eval_dataset,
collate_fn=data_collator,
batch_size=EVAL_BATCH_SIZE
)
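    # DataLoader shuffling defaults to False, so batches arrive in eval_dataset row order;
    # the positional feature_indices bookkeeping in the loop below relies on this.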
# --- 3. Run Inference ---
logger.info("***** Running Evaluation *****")
logger.info(f" Num examples = {len(eval_dataset)}")
logger.info(f" Batch size = {EVAL_BATCH_SIZE}")
all_start_logits = []
all_end_logits = []
feature_indices = [] # Keep track of the order
# Track multi-step reasoning metrics
reasoning_steps_taken = []
delta_magnitudes = []
gate_values = []
initial_vs_final_changes = []
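    # These diagnostics rely on the (assumed) contract that EnhancedRRN_QA_Model fills a
    # `custom_outputs` dict during forward with: 'steps_taken', 'all_deltas', 'all_gates',
    # and the pre-refinement 'y0_start_logits' / 'y0_end_logits'. Missing keys are simply skipped.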
with torch.no_grad():
for step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
# Move batch to device
batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}
# Store feature indices corresponding to this batch
# Assuming 'input_ids' or similar key represents features in order
current_indices = list(range(step * EVAL_BATCH_SIZE, step * EVAL_BATCH_SIZE + len(batch_on_device['input_ids'])))
feature_indices.extend(current_indices)
# Forward pass - pass only inputs needed by model.forward
outputs = model(
input_ids=batch_on_device.get("input_ids"),
attention_mask=batch_on_device.get("attention_mask"),
token_type_ids=batch_on_device.get("token_type_ids"),
use_memory=USE_MEMORY, # Use memory if enabled
return_dict=True
)
# Get the final logits (y1)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
all_start_logits.append(start_logits.cpu().numpy())
all_end_logits.append(end_logits.cpu().numpy())
# Collect multi-step reasoning metrics from custom_outputs
if hasattr(model, 'custom_outputs'):
# Number of reasoning steps taken
if 'steps_taken' in model.custom_outputs:
reasoning_steps_taken.append(model.custom_outputs['steps_taken'])
# Delta magnitudes (how much the model updates at each step)
if 'all_deltas' in model.custom_outputs and len(model.custom_outputs['all_deltas']) > 0:
batch_deltas = []
for delta in model.custom_outputs['all_deltas']:
# Calculate mean delta magnitude across sequence dimension
delta_norm = delta.norm(dim=-1).mean().cpu().item()
batch_deltas.append(delta_norm)
delta_magnitudes.append(batch_deltas)
# Gate values (how selective the updates are)
if 'all_gates' in model.custom_outputs and len(model.custom_outputs['all_gates']) > 0:
batch_gates = []
for gate in model.custom_outputs['all_gates']:
# Calculate mean gate value across sequence dimension
gate_mean = gate.mean().cpu().item()
batch_gates.append(gate_mean)
gate_values.append(batch_gates)
# Compare initial vs final predictions
if 'y0_start_logits' in model.custom_outputs and 'y0_end_logits' in model.custom_outputs:
y0_start = model.custom_outputs['y0_start_logits']
y0_end = model.custom_outputs['y0_end_logits']
# Calculate how much the predictions changed
start_change = (start_logits - y0_start).abs().mean().cpu().item()
end_change = (end_logits - y0_end).abs().mean().cpu().item()
initial_vs_final_changes.append((start_change + end_change) / 2)
# Concatenate all results
all_start_logits = np.concatenate(all_start_logits, axis=0)
all_end_logits = np.concatenate(all_end_logits, axis=0)
# Ensure the number of predictions matches the number of features
if len(all_start_logits) != len(eval_dataset):
logger.warning(f"Mismatch in prediction count ({len(all_start_logits)}) and feature count ({len(eval_dataset)}). Check dataloader/inference loop.")
        # Truncate any extra predictions so downstream indexing stays aligned with the eval_dataset features
all_start_logits = all_start_logits[:len(eval_dataset)]
all_end_logits = all_end_logits[:len(eval_dataset)]
# Create dictionary mapping feature index to its logits
predictions_dict = {
feature_index: (start_logit, end_logit)
for feature_index, (start_logit, end_logit) in zip(feature_indices, zip(all_start_logits, all_end_logits))
}
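    # Dict insertion order is preserved (Python 3.7+), so iterating predictions_dict.values()
    # during post-processing yields logits in the same order as the eval_dataset features.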
# --- 4. Post-Processing ---
# (Adapted from Hugging Face run_qa.py example script)
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, tokenizer=tokenizer):
all_start_logits, all_end_logits = zip(*raw_predictions.values())
# Build a map from example ID to list of related feature indices
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# Dictionary to store predictions
predictions = collections.OrderedDict()
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Loop over all examples
for example_index, example in enumerate(tqdm(examples, desc="Post-processing")):
feature_indices = features_per_example[example_index] # Indices of features related to this example
min_null_score = None # Used to identify impossible answers
valid_answers = []
context = example["context"]
# Loop through features associated with the current example
for feature_index in feature_indices:
start_logits = all_start_logits[feature_index]
end_logits = all_end_logits[feature_index]
offset_mapping = features[feature_index]["offset_mapping"]
# Update minimum null prediction score
cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
feature_null_score = start_logits[cls_index] + end_logits[cls_index]
if min_null_score is None or min_null_score < feature_null_score:
min_null_score = feature_null_score
# Go through all possibilities for start/end positions
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
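                # argsort is ascending, so the reversed slice [-1 : -n_best_size - 1 : -1]
                # gives the indices of the n_best_size largest logits, best first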
for start_index in start_indexes:
for end_index in end_indexes:
# Skip invalid pairs (start > end, index out of bounds, answer in question part)
if start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or \
offset_mapping[start_index] is None or offset_mapping[end_index] is None or \
end_index < start_index:
continue
# Check answer length
if end_index - start_index + 1 > max_answer_length:
continue
# Extract text and score
start_char = offset_mapping[start_index][0]
end_char = offset_mapping[end_index][1]
score = start_logits[start_index] + end_logits[end_index]
valid_answers.append({
"score": score,
"text": context[start_char: end_char]
})
# Select the best answer across all features for this example
if len(valid_answers) > 0:
best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
else:
# Fallback for no valid answers found
best_answer = {"text": "", "score": min_null_score} # Assign CLS score if needed
# Assign final prediction (use empty string if null score is best)
# Simple version: always take the best scoring valid answer
# More sophisticated versions might compare best_answer["score"] vs min_null_score
predictions[example["id"]] = best_answer["text"]
return predictions
logger.info("Starting post-processing...")
final_predictions = postprocess_qa_predictions(raw_datasets, eval_dataset, predictions_dict)
# --- 5. Compute Metrics ---
logger.info("Calculating SQuAD metrics...")
metric = hf_evaluate.load("squad")
# Format predictions and references for the metric
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
formatted_references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in raw_datasets]
results = metric.compute(predictions=formatted_predictions, references=formatted_references)
logger.info("***** Evaluation Results *****")
print(results)
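    # The "squad" metric returns {'exact_match': ..., 'f1': ...} as percentages (0-100)
    # averaged over the whole validation set.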
# --- 6. Analyze Multi-step Reasoning Metrics ---
logger.info("\n***** Multi-step Reasoning Analysis *****")
# Calculate average number of reasoning steps
if reasoning_steps_taken:
avg_steps = sum(reasoning_steps_taken) / len(reasoning_steps_taken)
logger.info(f"Average reasoning steps: {avg_steps:.2f}")
# Count frequency of each step count
step_counts = collections.Counter(reasoning_steps_taken)
logger.info(f"Step count distribution: {dict(sorted(step_counts.items()))}")
# Calculate average delta magnitudes per step
if delta_magnitudes:
# Transpose to get step-wise averages
steps_delta_magnitudes = defaultdict(list)
for batch_deltas in delta_magnitudes:
for step_idx, delta in enumerate(batch_deltas):
steps_delta_magnitudes[step_idx].append(delta)
avg_delta_by_step = {step: sum(deltas)/len(deltas) for step, deltas in steps_delta_magnitudes.items()}
logger.info(f"Average delta magnitude by step: {avg_delta_by_step}")
# Calculate average gate values per step
if gate_values:
# Transpose to get step-wise averages
steps_gate_values = defaultdict(list)
for batch_gates in gate_values:
for step_idx, gate in enumerate(batch_gates):
steps_gate_values[step_idx].append(gate)
avg_gate_by_step = {step: sum(gates)/len(gates) for step, gates in steps_gate_values.items()}
logger.info(f"Average gate value by step: {avg_gate_by_step}")
# Calculate average change from initial to final predictions
if initial_vs_final_changes:
avg_change = sum(initial_vs_final_changes) / len(initial_vs_final_changes)
logger.info(f"Average change from initial to final predictions: {avg_change:.4f}")
# --- 7. Save Results ---
results_file = os.path.join(OUTPUT_DIR, "eval_results.json")
with open(results_file, 'w') as f:
# Combine SQuAD metrics with multi-step reasoning metrics
full_results = {
"squad_metrics": results,
"multi_step_metrics": {
"avg_reasoning_steps": avg_steps if reasoning_steps_taken else None,
"step_count_distribution": dict(sorted(step_counts.items())) if reasoning_steps_taken else None,
"avg_delta_by_step": avg_delta_by_step if delta_magnitudes else None,
"avg_gate_by_step": avg_gate_by_step if gate_values else None,
"avg_prediction_change": avg_change if initial_vs_final_changes else None
}
}
json.dump(full_results, f, indent=2)
logger.info(f"Results saved to {results_file}")
# --- 8. Generate Visualizations (if requested) ---
if args.visualize and (delta_magnitudes or gate_values or reasoning_steps_taken):
logger.info("Generating visualizations...")
# Create visualization directory
viz_dir = os.path.join(OUTPUT_DIR, "visualizations")
os.makedirs(viz_dir, exist_ok=True)
# Plot step distribution
if reasoning_steps_taken:
plt.figure(figsize=(10, 6))
            plt.bar(list(step_counts.keys()), list(step_counts.values()))  # convert dict views to lists for matplotlib
plt.xlabel('Number of Reasoning Steps')
plt.ylabel('Frequency')
plt.title('Distribution of Reasoning Steps')
plt.savefig(os.path.join(viz_dir, 'step_distribution.png'))
plt.close()
# Plot delta magnitudes by step
if delta_magnitudes and steps_delta_magnitudes:
plt.figure(figsize=(10, 6))
steps = sorted(steps_delta_magnitudes.keys())
values = [avg_delta_by_step[step] for step in steps]
plt.plot(steps, values, marker='o')
plt.xlabel('Reasoning Step')
plt.ylabel('Average Delta Magnitude')
plt.title('Delta Magnitude by Reasoning Step')
plt.grid(True)
plt.savefig(os.path.join(viz_dir, 'delta_magnitudes.png'))
plt.close()
# Plot gate values by step
if gate_values and steps_gate_values:
plt.figure(figsize=(10, 6))
steps = sorted(steps_gate_values.keys())
values = [avg_gate_by_step[step] for step in steps]
plt.plot(steps, values, marker='o')
plt.xlabel('Reasoning Step')
plt.ylabel('Average Gate Value')
plt.title('Gate Value by Reasoning Step')
plt.grid(True)
plt.savefig(os.path.join(viz_dir, 'gate_values.png'))
plt.close()
logger.info(f"Visualizations saved to {viz_dir}")
if __name__ == "__main__":
# This is required for Windows to properly handle multiprocessing
multiprocessing.freeze_support()
main()
# Example usage:
# Test with default settings (epoch 3 checkpoint):
# python test_model.py
# Test with specific checkpoint:
# python test_model.py --checkpoint ./rrn_qa_model_epoch_2
# Test with fixed number of reasoning steps:
# python test_model.py --fixed_steps 3
# Test with active memory:
# python test_model.py --use_memory
# Test with visualizations:
# python test_model.py --visualize
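# Combined run (hypothetical paths), saving metrics and plots for a memory-enabled evaluation:
# python test_model.py --checkpoint ./rrn_qa_model_epoch_3 --use_memory --visualize --output_dir ./eval_results_memory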