"""
Evaluation on real-world damaged characters from Jiucheng Palace inscription.

Implements real-world scenario testing from the paper.
"""

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from PIL import Image
import numpy as np
import os

from config import Config
from models.mmrm import MMRM
from evaluation.metrics import RestorationMetrics


class RealWorldDataset(Dataset):
    """
    Dataset for real-world damaged characters.

    Reads ground-truth characters from ``<real_data_dir>/true.txt``, context
    sentences from ``<real_data_dir>/restore.txt`` and damaged-character images
    from ``<real_data_dir>/pic/``.

    Labels in ``true.txt`` are consumed in file order, one per ``[MASK]`` token
    across all contexts. The image for the n-th label (1-based) is read from
    ``pic/o{n}.png``. Contexts without any ``[MASK]`` token are skipped.
    """

    def __init__(self, config: Config, tokenizer: BertTokenizer):
        """
        Initialize real-world dataset.

        Args:
            config: Configuration object (reads ``real_data_dir``,
                ``max_seq_length`` and ``image_size``)
            tokenizer: Tokenizer for text encoding

        Raises:
            ValueError: If ``true.txt`` holds fewer labels than there are
                ``[MASK]`` tokens in ``restore.txt``.
        """
        self.config = config
        self.tokenizer = tokenizer

        # Load ground truth labels (one per line)
        true_path = os.path.join(config.real_data_dir, 'true.txt')
        with open(true_path, 'r', encoding='utf-8') as f:
            self.labels = [line.strip() for line in f]

        # Load context sentences (one per line)
        restore_path = os.path.join(config.real_data_dir, 'restore.txt')
        with open(restore_path, 'r', encoding='utf-8') as f:
            self.contexts = [line.strip() for line in f]

        # Image directory
        self.image_dir = os.path.join(config.real_data_dir, 'pic')

        # Map contexts to labels and images. Only [MASK] tokens are
        # restoration targets; each one consumes the next label in true.txt.
        self.samples = []
        label_idx = 0
        for context in self.contexts:
            num_masks = context.count('[MASK]')
            if num_masks == 0:
                continue

            end = label_idx + num_masks
            if end > len(self.labels):
                # Silently taking fewer labels would shift the image indices
                # onto the previous context's images, so fail loudly instead.
                raise ValueError(
                    f"true.txt provides {len(self.labels)} labels, but "
                    f"restore.txt needs at least {end} (one per [MASK])"
                )

            self.samples.append({
                'context': context,
                'labels': self.labels[label_idx:end],
                # Image files are numbered from 1: o{label_idx + 1}.png, ...
                'image_indices': list(range(label_idx + 1, end + 1)),
            })
            label_idx = end

        print(f"Loaded {len(self.samples)} real-world samples")

    def __len__(self):
        return len(self.samples)

    def _load_image(self, img_idx: int) -> torch.Tensor:
        """Load ``pic/o{img_idx}.png`` as a (1, image_size, image_size) float tensor in [0, 1]."""
        img_path = os.path.join(self.image_dir, f'o{img_idx}.png')
        size = self.config.image_size
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(img_path) as raw:
            img = raw.convert('L').resize((size, size))
        img_array = np.array(img).astype(np.float32) / 255.0
        return torch.from_numpy(img_array).unsqueeze(0)

    def __getitem__(self, idx):
        """
        Get a real-world sample.

        Returns:
            Dictionary with tokenized context, damaged images, and labels:
            ``input_ids`` and ``attention_mask`` (max_seq_length,),
            ``mask_positions`` (num_masks,), ``damaged_images``
            (num_masks, 1, image_size, image_size) and ``labels`` (num_masks,).

        Raises:
            ValueError: If truncation to ``max_seq_length`` dropped [MASK]
                tokens, so the remaining masks no longer match the labels.
        """
        sample = self.samples[idx]

        # Tokenize context
        encoding = self.tokenizer(
            sample['context'],
            max_length=self.config.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Find [MASK] positions
        input_ids = encoding['input_ids'].squeeze(0)
        mask_token_id = self.tokenizer.mask_token_id
        mask_positions = (input_ids == mask_token_id).nonzero(as_tuple=True)[0]
        if len(mask_positions) != len(sample['labels']):
            raise ValueError(
                f"Sample {idx}: found {len(mask_positions)} [MASK] tokens after "
                f"tokenization but {len(sample['labels'])} labels; the context "
                f"may exceed max_seq_length={self.config.max_seq_length}"
            )

        # Load damaged images (one per [MASK])
        image_tensors = [self._load_image(i) for i in sample['image_indices']]
        size = self.config.image_size
        damaged_images = (
            torch.stack(image_tensors) if image_tensors
            else torch.zeros(1, 1, size, size)
        )

        # Convert labels to IDs
        labels = torch.tensor(
            [self.tokenizer.convert_tokens_to_ids(label) for label in sample['labels']],
            dtype=torch.long,
        )

        return {
            'input_ids': input_ids,
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'mask_positions': mask_positions,
            'damaged_images': damaged_images,
            'labels': labels
        }


def evaluate_real_world(config: Config, checkpoint_path: str) -> str:
    """
    Evaluate on real-world damaged characters.

    Loads an MMRM checkpoint (a dict with a ``'model_state_dict'`` entry),
    runs it over every sample of :class:`RealWorldDataset` and summarizes the
    metrics computed by :class:`RestorationMetrics`.

    Args:
        config: Configuration object
        checkpoint_path: Path to model checkpoint

    Returns:
        Formatted results string
    """
    # Use the configured device only when CUDA is available; otherwise run on CPU.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    # Load model
    model = MMRM(config).to(device)
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"Loaded model from {checkpoint_path}")

    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained(config.roberta_model)

    # Create dataset. Contexts carry different numbers of masks/images, so
    # they are processed one at a time.
    real_dataset = RealWorldDataset(config, tokenizer)
    real_loader = DataLoader(real_dataset, batch_size=1, shuffle=False)

    # Evaluate
    metrics = RestorationMetrics(config.top_k_values)
    print("\nEvaluating on real-world data...")
    with torch.no_grad():
        for batch in real_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mask_positions = batch['mask_positions'].to(device)
            damaged_images = batch['damaged_images'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)

            # Update metrics
            metrics.update(text_logits, labels)

    results = metrics.compute()
    # Report the number of characters actually evaluated instead of a fixed count.
    num_chars = sum(len(sample['labels']) for sample in real_dataset.samples)

    separator = '=' * 50
    lines = [
        "",
        f"Real-world Evaluation Results ({num_chars} characters):",
        separator,
        f"Accuracy: {results['accuracy']:.2f}%",
        f"Hit@5: {results['hit_5']:.2f}%",
        f"Hit@10: {results['hit_10']:.2f}%",
        f"Hit@20: {results['hit_20']:.2f}%",
        f"MRR: {results['mrr']:.2f}",
        separator,
        "",
        "Compare with paper results:",
        " Paper - Accuracy: 55.26%, MRR: 62.28",
    ]
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python evaluate_real.py <checkpoint_path>", file=sys.stderr)
        sys.exit(1)

    checkpoint_path = sys.argv[1]
    config = Config()
    results = evaluate_real_world(config, checkpoint_path)
    print(results)