# test_model.py - RRN QA Model evaluation script with multi-step reasoning support
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, default_data_collator
from datasets import load_dataset
from tqdm.auto import tqdm
import os
import sys
import evaluate as hf_evaluate # Import with alias to avoid naming conflict
import collections
import numpy as np
import logging
import multiprocessing # For Windows multiprocessing support
import json
import argparse
import matplotlib.pyplot as plt
from collections import defaultdict
# Import custom modules and config
import config
from model import EnhancedRRN_QA_Model # Import the enhanced model
# Make sure memory.py and modules.py are accessible
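# NOTE: config.py is expected to provide the constants used below, e.g. BASE_MODEL_NAME,
# DEVICE, MAX_SEQ_LENGTH, DOC_STRIDE, MIN/MAX_REASONING_STEPS, REASONING_STEP_TYPE and
# EARLY_STOP_THRESHOLD; all of them are referenced later in this script.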
# --- Configuration ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description="Test RRN QA Model")
parser.add_argument("--checkpoint", type=str, default="./rrn_qa_model_epoch_3",
help="Path to checkpoint directory (default: ./rrn_qa_model_epoch_3)")
parser.add_argument("--batch_size", type=int, default=8,
help="Evaluation batch size (default: 8)")
parser.add_argument("--fixed_steps", type=int, default=None,
help="Override to use fixed number of reasoning steps (default: None, use model's dynamic steps)")
parser.add_argument("--use_memory", action="store_true",
help="Enable active memory during evaluation")
parser.add_argument("--output_dir", type=str, default="./eval_results",
help="Directory to save evaluation results (default: ./eval_results)")
parser.add_argument("--visualize", action="store_true",
help="Generate visualizations of reasoning steps")
args = parser.parse_args()
CHECKPOINT_DIR = args.checkpoint
EVAL_BATCH_SIZE = args.batch_size
DEVICE = config.DEVICE
USE_MEMORY = args.use_memory
OUTPUT_DIR = args.output_dir
# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)
logger.info(f"Evaluation configuration:")
logger.info(f" Checkpoint: {CHECKPOINT_DIR}")
logger.info(f" Batch size: {EVAL_BATCH_SIZE}")
logger.info(f" Device: {DEVICE}")
logger.info(f" Use memory: {USE_MEMORY}")
logger.info(f" Output directory: {OUTPUT_DIR}")
if args.fixed_steps is not None:
logger.info(f" Using fixed {args.fixed_steps} reasoning steps (overriding model config)")
# --- 1. Load Tokenizer and Model from Checkpoint ---
logger.info(f"Loading tokenizer from {CHECKPOINT_DIR}...")
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
logger.info(f"Loading Enhanced RRN QA Model architecture...")
# Instantiate the enhanced model architecture
model = EnhancedRRN_QA_Model(config.BASE_MODEL_NAME)
# Check if we're loading from a checkpoint with the enhanced architecture
base_model_path = os.path.join(CHECKPOINT_DIR, "base_model")
qa_head_path = os.path.join(CHECKPOINT_DIR, "qa_head.pth")
retroactive_layer_path = os.path.join(CHECKPOINT_DIR, "retroactive_layer.pth")
gating_mechanism_path = os.path.join(CHECKPOINT_DIR, "gating_mechanism.pth")
step_controller_path = os.path.join(CHECKPOINT_DIR, "step_controller.pth")
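# Expected checkpoint layout (presumably written by the training script):
# base_model/ - the transformer backbone saved with save_pretrained()
# qa_head.pth, retroactive_layer.pth, gating_mechanism.pth, step_controller.pth - state dicts
# enhanced_config.json - optional reasoning-step configuration (loaded further below)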
# Check for required components
if not os.path.exists(base_model_path):
logger.error(f"Base model directory not found at: {base_model_path}")
sys.exit(1)
if not os.path.exists(qa_head_path):
logger.error(f"QA head weights not found at: {qa_head_path}")
sys.exit(1)
if not os.path.exists(retroactive_layer_path):
logger.error(f"Retroactive layer weights not found at: {retroactive_layer_path}")
sys.exit(1)
# Load base model weights
logger.info(f"Loading base model weights from {base_model_path}...")
model.base_model = AutoModel.from_pretrained(base_model_path)
# Check if we're loading from an enhanced checkpoint or a legacy checkpoint
is_enhanced_checkpoint = os.path.exists(gating_mechanism_path)
if is_enhanced_checkpoint:
# Load all enhanced components
logger.info("Loading enhanced model components...")
model.qa_head.load_state_dict(torch.load(qa_head_path, map_location='cpu'))
model.retroactive_update_layer.load_state_dict(torch.load(retroactive_layer_path, map_location='cpu'))
model.gating_mechanism.load_state_dict(torch.load(gating_mechanism_path, map_location='cpu'))
# Load step controller if available (for learned dynamic steps)
if os.path.exists(step_controller_path) and hasattr(model, "step_controller"):
logger.info("Loading step controller for learned dynamic steps...")
model.step_controller.load_state_dict(torch.load(step_controller_path, map_location='cpu'))
logger.info("Enhanced model loaded successfully.")
else:
# We're loading from a legacy checkpoint - need to adapt the weights
logger.info("Loading from legacy checkpoint - adapting weights to enhanced architecture...")
# For the QA head, we need to initialize the enhanced QA head from scratch
# since the architectures are different
logger.info("Initializing enhanced QA head with random weights...")
# For the retroactive layer, we can try to load the weights but might need adjustments
logger.warning("Note: The enhanced model uses a different architecture than the checkpoint.")
logger.warning("Some components will use random initialization.")
# Load enhanced config if available
enhanced_config_path = os.path.join(CHECKPOINT_DIR, "enhanced_config.json")
if os.path.exists(enhanced_config_path):
logger.info(f"Loading enhanced configuration from {enhanced_config_path}")
with open(enhanced_config_path, 'r') as f:
enhanced_config = json.load(f)
# Override model configuration with saved values
if "num_reasoning_steps" in enhanced_config:
model.num_reasoning_steps = enhanced_config["num_reasoning_steps"]
logger.info(f"Using {model.num_reasoning_steps} reasoning steps from config")
if "use_dynamic_steps" in enhanced_config:
model.use_dynamic_steps = enhanced_config["use_dynamic_steps"]
if model.use_dynamic_steps:
model.max_reasoning_steps = enhanced_config.get("max_reasoning_steps", config.MAX_REASONING_STEPS)
model.min_reasoning_steps = enhanced_config.get("min_reasoning_steps", config.MIN_REASONING_STEPS)
model.reasoning_step_type = enhanced_config.get("reasoning_step_type", config.REASONING_STEP_TYPE)
model.early_stop_threshold = enhanced_config.get("early_stop_threshold", config.EARLY_STOP_THRESHOLD)
logger.info(f"Using dynamic reasoning steps (type: {model.reasoning_step_type})")
logger.info(f"Min steps: {model.min_reasoning_steps}, Max steps: {model.max_reasoning_steps}")
# Override with fixed steps if specified
if args.fixed_steps is not None:
logger.info(f"Overriding with fixed {args.fixed_steps} reasoning steps")
model.use_dynamic_steps = False
model.num_reasoning_steps = args.fixed_steps
model.to(DEVICE)
model.eval() # Set model to evaluation mode
logger.info("Model loaded successfully and set to evaluation mode.")
# --- 2. Load and Preprocess Validation Dataset ---
logger.info("Loading SQuAD validation dataset...")
raw_datasets = load_dataset("squad", split="validation")
question_column_name = "question"
context_column_name = "context"
answer_column_name = "answers"
pad_on_right = tokenizer.padding_side == "right"
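# With a right-padding tokenizer the question is passed first and the context second, so
# truncation="only_second" trims only the context; the order (and the truncated side) is
# flipped for left-padding tokenizers.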
# Validation preprocessing: Keep example_id and offset_mapping
def prepare_validation_features(examples):
examples[question_column_name] = [q.strip() for q in examples[question_column_name]]
tokenized_examples = tokenizer(
examples[question_column_name if pad_on_right else context_column_name],
examples[context_column_name if pad_on_right else question_column_name],
truncation="only_second" if pad_on_right else "only_first",
max_length=config.MAX_SEQ_LENGTH,
stride=config.DOC_STRIDE,
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)
# Keep track of which feature belongs to which example
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
# Add the example_id to link features to original examples
tokenized_examples["example_id"] = []
for i in range(len(tokenized_examples["input_ids"])):
sequence_ids = tokenized_examples.sequence_ids(i)
context_index = 1 if pad_on_right else 0
sample_index = sample_mapping[i]
tokenized_examples["example_id"].append(examples["id"][sample_index])
# Set offset mapping to None for question tokens to avoid predicting answers there
tokenized_examples["offset_mapping"][i] = [
(o if sequence_ids[k] == context_index else None)
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
]
return tokenized_examples
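# A long context can overflow into several features (windows offset by DOC_STRIDE tokens), so
# one example may map to multiple rows here; example_id and offset_mapping are kept so that
# span predictions can be mapped back to character positions in the original context.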
logger.info("Preprocessing validation dataset...")
# Disable multiprocessing which can hang on some systems
logger.info("Using single process for preprocessing to prevent hanging")
eval_dataset = raw_datasets.map(
prepare_validation_features,
batched=True,
remove_columns=raw_datasets.column_names,
num_proc=1, # Disable multiprocessing to avoid hanging
)
# Custom collator to handle None values in offset_mapping
def custom_data_collator(features):
# First, remove offset_mapping which contains None values that can't be batched
offset_mappings = [f.pop("offset_mapping") for f in features]
# Use default collator for everything else
batch = default_data_collator(features)
# Add offset_mapping back as a list since it can't be converted to a tensor
batch["offset_mapping"] = offset_mappings
return batch
# Use custom data collator
data_collator = custom_data_collator
eval_dataloader = DataLoader(
eval_dataset,
collate_fn=data_collator,
batch_size=EVAL_BATCH_SIZE
)
# --- 3. Run Inference ---
logger.info("***** Running Evaluation *****")
logger.info(f" Num examples = {len(eval_dataset)}")
logger.info(f" Batch size = {EVAL_BATCH_SIZE}")
all_start_logits = []
all_end_logits = []
feature_indices = [] # Keep track of the order
# Track multi-step reasoning metrics
reasoning_steps_taken = []
delta_magnitudes = []
gate_values = []
initial_vs_final_changes = []
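# These statistics are read from model.custom_outputs, which the enhanced model is expected
# to populate during forward() with keys such as 'steps_taken', 'all_deltas', 'all_gates'
# and 'y0_start_logits'/'y0_end_logits' (the pre-reasoning predictions).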
with torch.no_grad():
for step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
# Move batch to device
batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}
# Record the global feature indices covered by this batch
# (the dataloader is not shuffled, so features arrive in dataset order)
current_indices = list(range(step * EVAL_BATCH_SIZE, step * EVAL_BATCH_SIZE + len(batch_on_device['input_ids'])))
feature_indices.extend(current_indices)
# Forward pass - pass only inputs needed by model.forward
outputs = model(
input_ids=batch_on_device.get("input_ids"),
attention_mask=batch_on_device.get("attention_mask"),
token_type_ids=batch_on_device.get("token_type_ids"),
use_memory=USE_MEMORY, # Use memory if enabled
return_dict=True
)
# Get the final logits (y1)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
all_start_logits.append(start_logits.cpu().numpy())
all_end_logits.append(end_logits.cpu().numpy())
# Collect multi-step reasoning metrics from custom_outputs
if hasattr(model, 'custom_outputs'):
# Number of reasoning steps taken
if 'steps_taken' in model.custom_outputs:
reasoning_steps_taken.append(model.custom_outputs['steps_taken'])
# Delta magnitudes (how much the model updates at each step)
if 'all_deltas' in model.custom_outputs and len(model.custom_outputs['all_deltas']) > 0:
batch_deltas = []
for delta in model.custom_outputs['all_deltas']:
# Calculate mean delta magnitude across sequence dimension
delta_norm = delta.norm(dim=-1).mean().cpu().item()
batch_deltas.append(delta_norm)
delta_magnitudes.append(batch_deltas)
# Gate values (how selective the updates are)
if 'all_gates' in model.custom_outputs and len(model.custom_outputs['all_gates']) > 0:
batch_gates = []
for gate in model.custom_outputs['all_gates']:
# Calculate mean gate value across sequence dimension
gate_mean = gate.mean().cpu().item()
batch_gates.append(gate_mean)
gate_values.append(batch_gates)
# Compare initial vs final predictions
if 'y0_start_logits' in model.custom_outputs and 'y0_end_logits' in model.custom_outputs:
y0_start = model.custom_outputs['y0_start_logits']
y0_end = model.custom_outputs['y0_end_logits']
# Calculate how much the predictions changed
start_change = (start_logits - y0_start).abs().mean().cpu().item()
end_change = (end_logits - y0_end).abs().mean().cpu().item()
initial_vs_final_changes.append((start_change + end_change) / 2)
# Concatenate all results
all_start_logits = np.concatenate(all_start_logits, axis=0)
all_end_logits = np.concatenate(all_end_logits, axis=0)
# Ensure the number of predictions matches the number of features
if len(all_start_logits) != len(eval_dataset):
logger.warning(f"Mismatch in prediction count ({len(all_start_logits)}) and feature count ({len(eval_dataset)}). Check dataloader/inference loop.")
# Defensive truncation in case more predictions were collected than there are features
all_start_logits = all_start_logits[:len(eval_dataset)]
all_end_logits = all_end_logits[:len(eval_dataset)]
# Create dictionary mapping feature index to its logits
predictions_dict = {
feature_index: (start_logit, end_logit)
for feature_index, (start_logit, end_logit) in zip(feature_indices, zip(all_start_logits, all_end_logits))
}
# --- 4. Post-Processing ---
# (Adapted from Hugging Face run_qa.py example script)
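# For each example: gather its features, take the top n_best_size start/end indices, keep
# spans that fall inside the context, have start <= end and are at most max_answer_length
# tokens long, score each span as start_logit + end_logit, and return the best-scoring text.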
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30, tokenizer=tokenizer):
all_start_logits, all_end_logits = zip(*raw_predictions.values())
# Build a map from example ID to list of related feature indices
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# Dictionary to store predictions
predictions = collections.OrderedDict()
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Loop over all examples
for example_index, example in enumerate(tqdm(examples, desc="Post-processing")):
feature_indices = features_per_example[example_index] # Indices of features related to this example
min_null_score = None # Used to identify impossible answers
valid_answers = []
context = example["context"]
# Loop through features associated with the current example
for feature_index in feature_indices:
start_logits = all_start_logits[feature_index]
end_logits = all_end_logits[feature_index]
offset_mapping = features[feature_index]["offset_mapping"]
# Track the minimum null (CLS-token) score across features; it is only used below as a fallback when no valid span is found
cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
feature_null_score = start_logits[cls_index] + end_logits[cls_index]
if min_null_score is None or min_null_score > feature_null_score:
min_null_score = feature_null_score
# Go through all possibilities for start/end positions
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Skip invalid pairs (start > end, index out of bounds, answer in question part)
if start_index >= len(offset_mapping) or end_index >= len(offset_mapping) or \
offset_mapping[start_index] is None or offset_mapping[end_index] is None or \
end_index < start_index:
continue
# Check answer length
if end_index - start_index + 1 > max_answer_length:
continue
# Extract text and score
start_char = offset_mapping[start_index][0]
end_char = offset_mapping[end_index][1]
score = start_logits[start_index] + end_logits[end_index]
valid_answers.append({
"score": score,
"text": context[start_char: end_char]
})
# Select the best answer across all features for this example
if len(valid_answers) > 0:
best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
else:
# Fallback for no valid answers found
best_answer = {"text": "", "score": min_null_score} # Assign CLS score if needed
# Assign final prediction (use empty string if null score is best)
# Simple version: always take the best scoring valid answer
# More sophisticated versions might compare best_answer["score"] vs min_null_score
predictions[example["id"]] = best_answer["text"]
return predictions
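# A more complete variant (needed for SQuAD v2 style data) would also compare
# best_answer["score"] against min_null_score and emit an empty prediction when the null
# score wins; SQuAD v1 has no unanswerable questions, so the simple version above suffices.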
logger.info("Starting post-processing...")
final_predictions = postprocess_qa_predictions(raw_datasets, eval_dataset, predictions_dict)
# --- 5. Compute Metrics ---
logger.info("Calculating SQuAD metrics...")
metric = hf_evaluate.load("squad")
# Format predictions and references for the metric
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
formatted_references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in raw_datasets]
results = metric.compute(predictions=formatted_predictions, references=formatted_references)
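# The "squad" metric returns {'exact_match': ..., 'f1': ...}, both expressed as percentages.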
logger.info("***** Evaluation Results *****")
print(results)
# --- 6. Analyze Multi-step Reasoning Metrics ---
logger.info("\n***** Multi-step Reasoning Analysis *****")
# Calculate average number of reasoning steps
if reasoning_steps_taken:
avg_steps = sum(reasoning_steps_taken) / len(reasoning_steps_taken)
logger.info(f"Average reasoning steps: {avg_steps:.2f}")
# Count frequency of each step count
step_counts = collections.Counter(reasoning_steps_taken)
logger.info(f"Step count distribution: {dict(sorted(step_counts.items()))}")
# Calculate average delta magnitudes per step
if delta_magnitudes:
# Transpose to get step-wise averages
steps_delta_magnitudes = defaultdict(list)
for batch_deltas in delta_magnitudes:
for step_idx, delta in enumerate(batch_deltas):
steps_delta_magnitudes[step_idx].append(delta)
avg_delta_by_step = {step: sum(deltas)/len(deltas) for step, deltas in steps_delta_magnitudes.items()}
logger.info(f"Average delta magnitude by step: {avg_delta_by_step}")
# Calculate average gate values per step
if gate_values:
# Transpose to get step-wise averages
steps_gate_values = defaultdict(list)
for batch_gates in gate_values:
for step_idx, gate in enumerate(batch_gates):
steps_gate_values[step_idx].append(gate)
avg_gate_by_step = {step: sum(gates)/len(gates) for step, gates in steps_gate_values.items()}
logger.info(f"Average gate value by step: {avg_gate_by_step}")
# Calculate average change from initial to final predictions
if initial_vs_final_changes:
avg_change = sum(initial_vs_final_changes) / len(initial_vs_final_changes)
logger.info(f"Average change from initial to final predictions: {avg_change:.4f}")
# --- 7. Save Results ---
results_file = os.path.join(OUTPUT_DIR, "eval_results.json")
with open(results_file, 'w') as f:
# Combine SQuAD metrics with multi-step reasoning metrics
full_results = {
"squad_metrics": results,
"multi_step_metrics": {
"avg_reasoning_steps": avg_steps if reasoning_steps_taken else None,
"step_count_distribution": dict(sorted(step_counts.items())) if reasoning_steps_taken else None,
"avg_delta_by_step": avg_delta_by_step if delta_magnitudes else None,
"avg_gate_by_step": avg_gate_by_step if gate_values else None,
"avg_prediction_change": avg_change if initial_vs_final_changes else None
}
}
json.dump(full_results, f, indent=2)
logger.info(f"Results saved to {results_file}")
# --- 8. Generate Visualizations (if requested) ---
if args.visualize and (delta_magnitudes or gate_values or reasoning_steps_taken):
logger.info("Generating visualizations...")
# Create visualization directory
viz_dir = os.path.join(OUTPUT_DIR, "visualizations")
os.makedirs(viz_dir, exist_ok=True)
# Plot step distribution
if reasoning_steps_taken:
plt.figure(figsize=(10, 6))
plt.bar(step_counts.keys(), step_counts.values())
plt.xlabel('Number of Reasoning Steps')
plt.ylabel('Frequency')
plt.title('Distribution of Reasoning Steps')
plt.savefig(os.path.join(viz_dir, 'step_distribution.png'))
plt.close()
# Plot delta magnitudes by step
if delta_magnitudes and steps_delta_magnitudes:
plt.figure(figsize=(10, 6))
steps = sorted(steps_delta_magnitudes.keys())
values = [avg_delta_by_step[step] for step in steps]
plt.plot(steps, values, marker='o')
plt.xlabel('Reasoning Step')
plt.ylabel('Average Delta Magnitude')
plt.title('Delta Magnitude by Reasoning Step')
plt.grid(True)
plt.savefig(os.path.join(viz_dir, 'delta_magnitudes.png'))
plt.close()
# Plot gate values by step
if gate_values and steps_gate_values:
plt.figure(figsize=(10, 6))
steps = sorted(steps_gate_values.keys())
values = [avg_gate_by_step[step] for step in steps]
plt.plot(steps, values, marker='o')
plt.xlabel('Reasoning Step')
plt.ylabel('Average Gate Value')
plt.title('Gate Value by Reasoning Step')
plt.grid(True)
plt.savefig(os.path.join(viz_dir, 'gate_values.png'))
plt.close()
logger.info(f"Visualizations saved to {viz_dir}")
if __name__ == "__main__":
# This is required for Windows to properly handle multiprocessing
multiprocessing.freeze_support()
main()
# Example usage:
# Test with default settings (epoch 3 checkpoint):
# python test_model.py
# Test with specific checkpoint:
# python test_model.py --checkpoint ./rrn_qa_model_epoch_2
# Test with fixed number of reasoning steps:
# python test_model.py --fixed_steps 3
# Test with active memory:
# python test_model.py --use_memory
# Test with visualizations:
# python test_model.py --visualize
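# Combine options, e.g. a specific checkpoint with memory, visualizations and a custom output dir:
# python test_model.py --checkpoint ./rrn_qa_model_epoch_2 --use_memory --visualize --output_dir ./eval_results_epoch2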