"""Evaluate a trained CelebA convolutional autoencoder on the test split.

Loads the saved checkpoint, computes the mean per-pixel MSE over the test
set, and writes a handful of original/reconstructed image pairs to disk.
"""
import os

import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend (headless-safe; must precede pyplot import)
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ----------------------------
# 1️⃣ Device setup
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------------------------
# 2️⃣ Hyperparameters
# ----------------------------
batch_size = 16  # smaller batch for evaluation
save_dir = "./reconstructions"
os.makedirs(save_dir, exist_ok=True)

# ----------------------------
# 3️⃣ Data transforms and test dataset
# ----------------------------
# ToTensor scales pixels to [0, 1], matching the decoder's Sigmoid output range.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

test_dataset = datasets.CelebA(root='./data', split='test', download=True,
                               transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=True)


# ----------------------------
# 4️⃣ Autoencoder definition (must match training)
# ----------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3x128x128 images.

    Encoder downsamples 128 -> 64 -> 32 spatially, then a 7x7 conv maps to a
    64-channel bottleneck; the decoder mirrors this exactly so the output is
    again 3x128x128, squashed to [0, 1] by the final Sigmoid.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the reconstruction of input batch x (N, 3, 128, 128)."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


# ----------------------------
# 5️⃣ Load trained model
# ----------------------------
model = Autoencoder().to(device)
# NOTE(review): torch.load unpickles arbitrary objects — for an untrusted
# checkpoint, prefer torch.load(..., weights_only=True) on PyTorch >= 1.13.
model.load_state_dict(torch.load("celeba_autoencoder.pth", map_location=device))
model.eval()
print("Model loaded successfully!")
# ----------------------------
# 6️⃣ Evaluate on test dataset
# ----------------------------
criterion = nn.MSELoss()
test_loss = 0.0
with torch.no_grad():  # inference only — no autograd bookkeeping needed
    for data in test_loader:
        imgs, _ = data  # labels are unused for reconstruction
        imgs = imgs.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, imgs)
        # MSELoss averages over the batch; multiply by batch size so the
        # final division by the dataset length yields a true per-sample mean
        # even when the last batch is smaller.
        test_loss += loss.item() * imgs.size(0)

test_loss /= len(test_loader.dataset)
print(f"Test Loss: {test_loss:.4f}")


# ----------------------------
# 7️⃣ Save original + reconstructed images
# ----------------------------
def save_image(img_tensor, filename):
    """Save a single (C, H, W) image tensor with values in [0, 1] to *filename*.

    The tensor is moved to CPU and transposed to (H, W, C) as expected by
    matplotlib's imsave.
    """
    img = img_tensor.cpu().numpy().transpose((1, 2, 0))  # C,H,W -> H,W,C
    plt.imsave(filename, img)


# Take one batch for visualization
data_iter = iter(test_loader)
images, _ = next(data_iter)
images = images.to(device)

with torch.no_grad():
    outputs = model(images)

num_images = min(batch_size, 8)  # save first 8 images
for i in range(num_images):
    save_image(images[i], os.path.join(save_dir, f'original_{i}.png'))
    save_image(outputs[i], os.path.join(save_dir, f'reconstructed_{i}.png'))

print(f"Saved {num_images} original and reconstructed images to {save_dir}")