# autoencoderkiranbackend1 / Model_evaluation.py
# kiran6969's picture
# Initial Commit
# cd698c9
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
import os
# ----------------------------
# 1️⃣ Device setup
# ----------------------------
# Run on the current CUDA device when PyTorch can see one; otherwise fall
# back to the CPU. Every tensor and the model are moved to this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
# ----------------------------
# 2️⃣ Hyperparameters
# ----------------------------
# Evaluation-only settings. All reconstruction images are written under
# save_dir, which is created up front (no error if it already exists).
batch_size: int = 16  # smaller batch for evaluation
save_dir: str = "./reconstructions"
os.makedirs(name=save_dir, exist_ok=True)
# ----------------------------
# 3️⃣ Data transforms and test dataset
# ----------------------------
# Resize every face to 128x128. ToTensor converts the PIL image to a float
# CHW tensor in [0, 1], the same range as the decoder's Sigmoid output.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
test_dataset = datasets.CelebA(root='./data', split='test', download=True, transform=transform)
# Pinned host memory only speeds up host->GPU copies. With no CUDA device,
# PyTorch warns that pin_memory cannot be used, so request it only when the
# evaluation actually runs on CUDA.
# NOTE(review): with num_workers > 0 and the spawn start method (the default
# on Windows and macOS), each worker re-imports the main module. This script
# runs everything at module level, so it may need an
# `if __name__ == "__main__":` guard on those platforms. Confirm the target OS.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # fixed order, so the batch visualized later is reproducible
    num_workers=4,
    pin_memory=device.type == "cuda",
)
# ----------------------------
# 4️⃣ Autoencoder definition (must match training)
# ----------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel 128x128 images.

    The attribute names (``encoder`` / ``decoder``) and the order of layers in
    each ``nn.Sequential`` determine the state_dict keys (``encoder.0.weight``,
    ...). They must stay identical to the training-time definition or
    ``load_state_dict`` will fail.

    Spatial sizes for an (N, 3, 128, 128) input:
        encoder: 128 -> 64 -> 32 (3x3 stride-2 convs) -> 26 (7x7 conv, no
                 padding), giving an (N, 64, 26, 26) latent;
        decoder: 26 -> 32 (7x7 transposed conv) -> 64 -> 128, with a final
                 Sigmoid so outputs lie in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to the latent map, then decode it back to image space."""
        latent = self.encoder(x)
        return self.decoder(latent)
# ----------------------------
# 5️⃣ Load trained model
# ----------------------------
model = Autoencoder().to(device)
# map_location remaps stored tensors onto the chosen device, so a checkpoint
# saved on GPU also loads on a CPU-only machine. weights_only=True (torch >=
# 1.13) restricts unpickling to tensors and plain containers. A state_dict
# needs nothing more, and a tampered checkpoint cannot execute arbitrary code.
state_dict = torch.load("celeba_autoencoder.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # put the model in evaluation mode before inference
print("Model loaded successfully!")
# ----------------------------
# 6️⃣ Evaluate on test dataset
# ----------------------------
# MSELoss defaults to reduction='mean', so each batch yields the mean over all
# of its elements. Re-weighting by batch size and dividing by the dataset
# length gives the mean squared error over the whole test set, independent of
# how the last batch is sized.
criterion = nn.MSELoss()
test_loss = 0.0
with torch.no_grad():
    for imgs, _ in test_loader:
        imgs = imgs.to(device)
        batch_mse = criterion(model(imgs), imgs)
        test_loss += batch_mse.item() * imgs.size(0)
test_loss /= len(test_loader.dataset)
print(f"Test Loss: {test_loss:.4f}")
# ----------------------------
# 7️⃣ Save original + reconstructed images
# ----------------------------
def save_image(img_tensor, filename):
    """Write one channel-first image tensor to ``filename`` via ``plt.imsave``.

    Args:
        img_tensor: tensor laid out as (C, H, W) with values meant to be in
            [0, 1]. ``plt.imsave`` needs C to be 3 (RGB) or 4 (RGBA); this
            script only passes RGB tensors.
        filename: destination path; matplotlib infers the format from the
            extension.
    """
    # detach(): numpy() refuses tensors that require grad, so without it the
    # helper only works inside torch.no_grad(). clamp(): plt.imsave rejects
    # float RGB values outside [0, 1].
    img = img_tensor.detach().cpu().clamp(0.0, 1.0).permute(1, 2, 0).numpy()  # C,H,W -> H,W,C
    plt.imsave(filename, img)
# Take one batch for visualization. The loader is built with shuffle=False,
# so this is the first batch of the test split.
images, _ = next(iter(test_loader))
images = images.to(device)
with torch.no_grad():
    outputs = model(images)
# Size from the batch actually returned, not the configured batch_size. A
# dataset smaller than batch_size yields a shorter batch, and indexing past
# its end would raise IndexError.
num_images = min(images.size(0), 8)  # save at most the first 8 images
for i in range(num_images):
    save_image(images[i], os.path.join(save_dir, f'original_{i}.png'))
    save_image(outputs[i], os.path.join(save_dir, f'reconstructed_{i}.png'))
print(f"Saved {num_images} original and reconstructed images to {save_dir}")