| import torch
|
| from models import Generator
|
| from PIL import Image
|
| import torchvision.transforms as transforms
|
| import numpy as np
|
| import os
|
| import matplotlib.pyplot as plt
|
|
|
def predict(model, image_path, device="cpu", image_size=256):
    """Translate one image with ``model`` and return it as an 8-bit RGB array.

    The image is loaded as RGB and resized to ``image_size`` x ``image_size``.
    It is converted to a tensor in [0, 1] and normalised with mean/std 0.5 per
    channel, which maps it to [-1, 1]. The model output is mapped back with the
    inverse affine transform and converted to uint8 pixels.

    Args:
        model: A ``torch.nn.Module`` that is called on a ``(1, 3, H, W)`` batch.
            NOTE(review): the inverse normalisation below assumes the output
            lies in [-1, 1] (e.g. a tanh head) and has shape (1, C, H, W) —
            confirm against ``Generator``.
        image_path: Path to the input image file.
        device: Device the model lives on; the input tensor is moved there.
        image_size: Side length of the square resize applied before
            inference. Defaults to 256, the previously hard-coded value.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` laid out as (H, W, C), suitable
        for ``matplotlib.pyplot.imshow``.

    Side effects:
        Calls ``model.eval()``; the model's previous train/eval mode is not
        restored.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Close the file handle promptly. convert() reads the pixel data, so the
    # returned image stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        prediction = model(image_tensor)

    # Drop the batch dimension and move the tensor to host memory. detach() is
    # unnecessary because no autograd graph is recorded under no_grad().
    prediction = prediction.squeeze(0).cpu().numpy()
    # Undo Normalize(mean=0.5, std=0.5), then reorder (C, H, W) -> (H, W, C).
    prediction = (prediction * 0.5 + 0.5).transpose(1, 2, 0)
    # Clip before casting. Out-of-range floats would otherwise wrap around
    # (negative values are undefined) when converted to uint8.
    prediction = (np.clip(prediction, 0.0, 1.0) * 255).astype(np.uint8)

    return prediction
|
|
|
def main(checkpoint_path="genz.pth.tar",
         test_image="data/horse2zebra/testA/n02381460_1010.jpg",
         output_path="prediction_result.png"):
    """Run the generator on one test image and save the result as a figure.

    Builds ``Generator(img_channels=3)`` on CUDA if available, otherwise on
    CPU. If ``checkpoint_path`` exists, its weights are loaded into the model;
    otherwise the untrained model is used. If ``test_image`` exists, it is
    passed through ``predict`` and the result is saved to ``output_path``.

    Args:
        checkpoint_path: Path to the generator weights. The file may contain
            either a bare state_dict or a dict that wraps one under the
            ``"state_dict"`` key.
        test_image: Path to the image to translate.
        output_path: Where to save the rendered prediction.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_Z = Generator(img_channels=3).to(device)

    if os.path.exists(checkpoint_path):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Training checkpoints are often saved as
        # {"state_dict": ..., "optimizer": ...}; unwrap so both formats load.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        gen_Z.load_state_dict(checkpoint)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print("Using untrained model (no checkpoint found).")

    if not os.path.exists(test_image):
        print(f"Test image {test_image} not found.")
        return

    result = predict(gen_Z, test_image, device)

    # Draw on a dedicated figure and always close it, so repeated calls
    # don't accumulate open matplotlib figures.
    fig = plt.figure()
    try:
        plt.imshow(result)
        plt.title("Style Transferred Image (Zebra)")
        plt.axis("off")
        plt.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Prediction saved to {output_path}")


if __name__ == "__main__":
    main()
|
|
|