# autoencoderkiranbackend1 / singleimagereconstruction.py
# Author: kiran6969 — Initial Commit (cd698c9)
import torch
import torch.nn as nn
from torchvision import transforms
import matplotlib
matplotlib.use('Agg') # non-interactive backend
import matplotlib.pyplot as plt
from PIL import Image
import os
# ----------------------------
# 1️⃣ Device setup
# ----------------------------
# Prefer the GPU when one is visible; every tensor below follows this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# ----------------------------
# 2️⃣ Transform (same as training)
# ----------------------------
# Inference inputs must pass through the identical preprocessing used at
# training time: resize to 128x128, then convert to a float tensor in [0, 1].
transform = transforms.Compose(
    [transforms.Resize((128, 128)), transforms.ToTensor()]
)
# ----------------------------
# 3️⃣ Autoencoder definition (must match training)
# ----------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder halves the spatial resolution four times while widening the
    channels (3 -> 32 -> 64 -> 128 -> 256); the decoder mirrors it with
    transposed convolutions, ending in a Sigmoid so outputs land in [0, 1].
    The resulting module layout (and thus the state_dict keys) must match
    the checkpoint produced at training time exactly.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder: Conv2d(stride=2) + ReLU pairs, with no activation after
        # the final convolution (matches the trained checkpoint layout).
        enc = []
        in_ch = 3
        for out_ch in (32, 64, 128, 256):
            enc.append(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1))
            enc.append(nn.ReLU())
            in_ch = out_ch
        enc.pop()  # drop the trailing ReLU after the last encoder conv
        self.encoder = nn.Sequential(*enc)

        # Decoder: mirrored ConvTranspose2d(stride=2) + ReLU pairs; the very
        # last activation is a Sigmoid instead of a ReLU.
        dec = []
        widths = (256, 128, 64, 32, 3)
        for in_ch, out_ch in zip(widths, widths[1:]):
            dec.append(
                nn.ConvTranspose2d(
                    in_ch, out_ch, 3, stride=2, padding=1, output_padding=1
                )
            )
            dec.append(nn.ReLU())
        dec[-1] = nn.Sigmoid()  # final activation maps pixels into [0, 1]
        self.decoder = nn.Sequential(*dec)

    def forward(self, x):
        """Encode then decode, returning a reconstruction of ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent)
# ----------------------------
# 4️⃣ Load trained model
# ----------------------------
# Restore the trained weights onto the active device. weights_only=True
# stops torch.load from unpickling arbitrary Python objects: the checkpoint
# is a plain tensor state_dict, so tensor-only loading is sufficient and
# closes the pickle-deserialization hole of a bare torch.load call.
model = Autoencoder().to(device)
state_dict = torch.load(
    "celeba_autoencoder.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout/batch-norm updates
print("Model loaded successfully!")
# ----------------------------
# 5️⃣ Load and reconstruct a single image
# ----------------------------
custom_img_path = "photoc.jpg"  # <--- change to your image path
if os.path.exists(custom_img_path):
    # Load the original image and remember its native size so the
    # reconstruction can be scaled back for a fair side-by-side view.
    img = Image.open(custom_img_path).convert("RGB")
    orig_size = img.size  # (width, height)

    # Preprocess exactly as in training and add a batch dimension.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        reconstructed = model(img_tensor)

    # Tensor (1, C, H, W) in [0, 1] -> HWC uint8 array -> PIL image,
    # then resize back to the original resolution for display.
    recon_img = reconstructed.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    recon_pil = Image.fromarray((recon_img * 255).astype("uint8"))
    recon_pil = recon_pil.resize(orig_size, Image.BILINEAR)

    # Side-by-side canvas: original on the left, reconstruction on the right.
    combined = Image.new("RGB", (orig_size[0] * 2, orig_size[1]))
    combined.paste(img, (0, 0))                   # left: original
    combined.paste(recon_pil, (orig_size[0], 0))  # right: reconstructed
    combined.save("comparison.png")
    # Plain string (was an f-string with no placeholders — ruff F541).
    print("✅ Saved comparison as comparison.png (original | reconstructed)")
else:
    print(f"⚠️ Image not found: {custom_img_path}")