| import torch |
| import torch.nn as nn |
| from torchvision import transforms |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import os |
|
|
| |
| |
| |
# Pick the compute device once: use CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)
|
|
| |
| |
| |
# Preprocessing applied to every input image before it reaches the model:
# resize to a fixed 128x128 (aspect ratio is not preserved), then convert
# the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)
|
|
| |
| |
| |
class Autoencoder(nn.Module):
    """Fully convolutional autoencoder for 3-channel images.

    The encoder is four 3x3, stride-2 convolutions (3 -> 32 -> 64 -> 128 ->
    256 channels) with a ReLU after each one except the last. The decoder
    mirrors it with four 3x3, stride-2 transposed convolutions (256 -> 128 ->
    64 -> 32 -> 3 channels), ReLU between them and a final Sigmoid.

    The module order inside ``encoder`` / ``decoder`` determines the
    ``state_dict`` keys (``encoder.0``, ``encoder.2``, ...), so it must not
    change if previously saved weights are to keep loading.
    """

    def __init__(self):
        super().__init__()

        # Encoder: Conv2d + ReLU pairs, then drop the trailing ReLU so the
        # bottleneck (last conv) output is left without an activation.
        enc_channels = [3, 32, 64, 128, 256]
        enc_layers = []
        for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:]):
            enc_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
            )
            enc_layers.append(nn.ReLU())
        enc_layers.pop()
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder: ConvTranspose2d + ReLU pairs; the final activation is
        # swapped for a Sigmoid.
        dec_channels = [256, 128, 64, 32, 3]
        dec_layers = []
        for c_in, c_out in zip(dec_channels[:-1], dec_channels[1:]):
            dec_layers.append(
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                )
            )
            dec_layers.append(nn.ReLU())
        dec_layers[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))
| |
| |
| |
# Build the network, move it to the selected device, and restore the
# trained parameters from the checkpoint in the working directory.
model = Autoencoder()
model = model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(
    torch.load("celeba_autoencoder.pth", map_location=device)
)
# Switch to evaluation mode for inference.
model.eval()
print("Model loaded successfully!")
|
|
| |
| |
| |
# Image to reconstruct; the result is written next to the script as a
# side-by-side "original | reconstructed" comparison.
custom_img_path = "photoc.jpg"

if os.path.exists(custom_img_path):
    # Open inside a context manager so the underlying file handle is closed
    # promptly; convert() hands back an independent RGB image that stays
    # usable after the source is closed.
    with Image.open(custom_img_path) as source_img:
        img = source_img.convert("RGB")
    orig_size = img.size  # PIL size is (width, height)

    # Preprocess, add a batch dimension, and move to the model's device.
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        reconstructed = model(img_tensor)

    # Drop the batch dimension and reorder channels-first -> channels-last
    # for PIL, then scale to 8-bit values.
    recon_img = reconstructed.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    recon_pil = Image.fromarray((recon_img * 255).astype("uint8"))
    # Bring the reconstruction back to the source resolution so both halves
    # of the comparison have the same size.
    recon_pil = recon_pil.resize(orig_size, Image.BILINEAR)

    # Canvas twice as wide as the source: original on the left,
    # reconstruction on the right.
    combined = Image.new("RGB", (orig_size[0] * 2, orig_size[1]))
    combined.paste(img, (0, 0))
    combined.paste(recon_pil, (orig_size[0], 0))

    combined.save("comparison.png")
    print("✅ Saved comparison as comparison.png (original | reconstructed)")
else:
    print(f"⚠️ Image not found: {custom_img_path}")
|
|