|
|
""" |
|
|
Evaluation on real-world damaged characters from the Jiucheng Palace inscription.

Implements the real-world scenario testing described in the paper.
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from transformers import BertTokenizer |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import os |
|
|
|
|
|
from config import Config |
|
|
from models.mmrm import MMRM |
|
|
from evaluation.metrics import RestorationMetrics |
|
|
|
|
|
|
|
|
class RealWorldDataset(Dataset):
    """
    Dataset for real-world damaged characters.

    Reads three resources located under ``config.real_data_dir``:

    * ``true.txt``     - ground-truth labels, one per line, in mask order.
    * ``restore.txt``  - contexts, one per line; each ``[MASK]`` marks a damaged
      character. Lines without any ``[MASK]`` are skipped.
    * ``pic/o<k>.png`` - image of the k-th damaged character, numbered from 1
      and counted over all ``[MASK]`` tokens in file order.
    """

    def __init__(self, config: Config, tokenizer: BertTokenizer):
        """
        Initialize real-world dataset.

        Args:
            config: Configuration object. ``real_data_dir`` is read here;
                ``max_seq_length`` and ``image_size`` are read in
                ``__getitem__``.
            tokenizer: Tokenizer for text encoding
        """
        self.config = config
        self.tokenizer = tokenizer

        # Ground-truth characters, consumed sequentially as masks are found.
        true_path = os.path.join(config.real_data_dir, 'true.txt')
        with open(true_path, 'r', encoding='utf-8') as f:
            self.labels = [line.strip() for line in f]

        # Contexts containing the [MASK] placeholders.
        restore_path = os.path.join(config.real_data_dir, 'restore.txt')
        with open(restore_path, 'r', encoding='utf-8') as f:
            self.contexts = [line.strip() for line in f]

        self.image_dir = os.path.join(config.real_data_dir, 'pic')

        self.samples = []
        label_idx = 0    # next unread position in self.labels
        masks_seen = 0   # number of [MASK] tokens encountered so far

        for context in self.contexts:
            num_masks = context.count('[MASK]')
            if num_masks == 0:
                continue

            # Best effort: take as many labels as are still available.
            context_labels = self.labels[label_idx:label_idx + num_masks]
            label_idx += len(context_labels)

            # Image numbers follow the running mask count (1-based), NOT the
            # label cursor, so a short true.txt cannot make two contexts
            # share or shift image files.
            first_image = masks_seen + 1
            masks_seen += num_masks

            self.samples.append({
                'context': context,
                'labels': context_labels,
                'image_indices': list(range(first_image, first_image + num_masks)),
            })

        if masks_seen > len(self.labels):
            print(
                f"Warning: found {masks_seen} [MASK] tokens but only "
                f"{len(self.labels)} labels in {true_path}"
            )

        print(f"Loaded {len(self.samples)} real-world samples")

    def __len__(self):
        """Return the number of contexts that contain at least one mask."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Get a real-world sample.

        Returns:
            Dictionary with tokenized context, damaged images, and labels:

            * ``input_ids`` / ``attention_mask`` - tokenizer output with the
              leading batch dimension squeezed away.
            * ``mask_positions`` - indices in ``input_ids`` equal to the
              tokenizer's mask token id.
            * ``damaged_images`` - float32 tensor of shape
              ``(num_images, 1, image_size, image_size)`` scaled to [0, 1].
            * ``labels`` - long tensor of label token ids.
        """
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['context'],
            max_length=self.config.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # NOTE(review): truncation=True can cut off trailing [MASK] tokens, in
        # which case mask_positions is shorter than damaged_images/labels.
        input_ids = encoding['input_ids'].squeeze(0)
        mask_positions = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

        size = self.config.image_size
        damaged_images = []
        for img_idx in sample['image_indices']:
            img_path = os.path.join(self.image_dir, f'o{img_idx}.png')
            # The context manager closes the underlying file as soon as the
            # pixel data has been converted, instead of waiting for GC.
            with Image.open(img_path) as raw:
                img = raw.convert('L').resize((size, size))

            img_array = np.array(img).astype(np.float32) / 255.0
            damaged_images.append(torch.from_numpy(img_array).unsqueeze(0))

        if damaged_images:
            damaged_images = torch.stack(damaged_images)
        else:
            # Placeholder that matches the configured image size.
            damaged_images = torch.zeros(1, 1, size, size)

        label_ids = [
            self.tokenizer.convert_tokens_to_ids(label)
            for label in sample['labels']
        ]
        labels = torch.tensor(label_ids, dtype=torch.long)

        return {
            'input_ids': input_ids,
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'mask_positions': mask_positions,
            'damaged_images': damaged_images,
            'labels': labels,
        }
|
|
|
|
|
|
|
|
def evaluate_real_world(config: Config, checkpoint_path: str) -> str:
    """
    Evaluate on real-world damaged characters.

    Loads an ``MMRM`` checkpoint, runs it over ``RealWorldDataset`` one sample
    at a time and accumulates the results in ``RestorationMetrics``.

    Args:
        config: Configuration object. Reads ``device``, ``roberta_model`` and
            ``top_k_values`` here; the dataset and model read further fields.
        checkpoint_path: Path to model checkpoint. The file must contain a
            ``'model_state_dict'`` entry.

    Returns:
        Formatted results string
    """
    # Use the configured device only when CUDA is actually present; otherwise
    # run on CPU. The previous condition picked "cuda" precisely when it was
    # requested but unavailable, which made the run fail instead of falling
    # back.
    device = torch.device(config.device if torch.cuda.is_available() else "cpu")

    model = MMRM(config).to(device)
    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Loaded model from {checkpoint_path}")

    tokenizer = BertTokenizer.from_pretrained(config.roberta_model)

    real_dataset = RealWorldDataset(config, tokenizer)
    # NOTE(review): batch_size=1 presumably because samples hold a varying
    # number of masks/images and cannot be stacked by the default collate.
    real_loader = DataLoader(real_dataset, batch_size=1, shuffle=False)

    metrics = RestorationMetrics(config.top_k_values)

    print("\nEvaluating on real-world data...")

    with torch.no_grad():
        for batch in real_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            mask_positions = batch['mask_positions'].to(device)
            damaged_images = batch['damaged_images'].to(device)
            labels = batch['labels'].to(device)

            # The model returns two values; only the first is scored here.
            text_logits, _ = model(input_ids, attention_mask, mask_positions, damaged_images)

            metrics.update(text_logits, labels)

    results = metrics.compute()

    rule = '=' * 50
    # NOTE(review): the "38 characters" in the header is a fixed string, not
    # derived from the number of characters actually evaluated.
    report_lines = [
        "\nReal-world Evaluation Results (38 characters):\n",
        f"{rule}\n",
        f"Accuracy: {results['accuracy']:.2f}%\n",
        f"Hit@5: {results['hit_5']:.2f}%\n",
        f"Hit@10: {results['hit_10']:.2f}%\n",
        f"Hit@20: {results['hit_20']:.2f}%\n",
        f"MRR: {results['mrr']:.2f}\n",
        f"{rule}\n",
        "\nCompare with paper results:\n",
        " Paper - Accuracy: 55.26%, MRR: 62.28\n",
    ]

    return "".join(report_lines)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # The first positional argument is the checkpoint to evaluate; without it
    # there is nothing to do, so show usage and exit with a failure status.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python evaluate_real.py <checkpoint_path>")
        sys.exit(1)

    checkpoint_path = cli_args[0]

    # Default configuration; the formatted report is printed as-is.
    config = Config()
    print(evaluate_real_world(config, checkpoint_path))
|
|
|