"""Run a horse->zebra generator on a single test image and save the result as a PNG."""

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models import Generator


def predict(model: torch.nn.Module, image_path, device="cpu") -> np.ndarray:
    """Translate one image with *model* and return it as an 8-bit HWC array.

    The input is converted to RGB, resized to 256x256 and normalized to
    [-1, 1] (mean 0.5, std 0.5 per channel). The model runs in eval mode
    without gradient tracking, and its output is mapped back to [0, 255].

    Args:
        model: Image-to-image network. Assumes it takes a (1, 3, H, W) tensor
            and returns a (1, C, H, W) tensor roughly in [-1, 1] -- TODO confirm
            against ``models.Generator``.
        image_path: Path to the input image file.
        device: Device the input tensor is moved to; should match the model's.

    Returns:
        ``np.uint8`` array of shape (H, W, C) suitable for ``plt.imshow``.

    Note:
        Calls ``model.eval()``, which leaves the caller's model in eval mode.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Map [0, 1] pixel values to [-1, 1]; inverted after inference below.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # PIL opens files lazily; the context manager guarantees the handle is closed.
    with Image.open(image_path) as image:
        image_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        prediction = model(image_tensor)

    # Drop the batch dim, undo the [-1, 1] normalization, and go CHW -> HWC.
    prediction = prediction.squeeze(0).cpu().numpy()
    prediction = (prediction * 0.5 + 0.5).transpose(1, 2, 0)
    # Clip before casting: out-of-range values would otherwise wrap around in uint8.
    prediction = np.clip(prediction, 0.0, 1.0)
    return (prediction * 255).astype(np.uint8)


def main():
    """Load the generator (from ``genz.pth.tar`` if present), translate the
    example test image, and save the plot to ``prediction_result.png``."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_Z = Generator(img_channels=3).to(device)

    # Check if a checkpoint exists
    checkpoint_path = "genz.pth.tar"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # NOTE(review): training scripts often save {"state_dict": ..., "optimizer": ...};
        # accept that wrapper as well as a bare state_dict. Confirm the saved format.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        gen_Z.load_state_dict(checkpoint)
        print(f"Loaded checkpoint from {checkpoint_path}")
    else:
        print("Using untrained model (no checkpoint found).")

    test_image = "data/horse2zebra/testA/n02381460_1010.jpg"  # Example horse image
    if os.path.exists(test_image):
        result = predict(gen_Z, test_image, device)
        plt.imshow(result)
        plt.title("Style Transferred Image (Zebra)")
        plt.axis("off")
        plt.savefig("prediction_result.png")
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close()
        print("Prediction saved to prediction_result.png")
    else:
        print(f"Test image {test_image} not found.")


if __name__ == "__main__":
    main()