Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| import matplotlib.pyplot as plt | |
| from model import CycleGAN | |
| # Load and preprocess the input image | |
| def load_image(image_path, device, image_size=(256, 256)): | |
| transform = T.Compose([ | |
| T.Resize(image_size), | |
| T.ToTensor(), | |
| T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # Normalize to [-1, 1] | |
| ]) | |
| image = Image.open(image_path).convert("RGB") | |
| image = transform(image).unsqueeze(0).to(device) | |
| return image | |
| # Display the output image | |
| def display_image(tensor_image): | |
| tensor_image = tensor_image.squeeze(0).cpu() # Remove batch dimension | |
| tensor_image = (tensor_image * 0.5 + 0.5).clamp(0, 1) # Denormalize | |
| plt.imshow(tensor_image.permute(1, 2, 0)) # CHW to HWC | |
| plt.axis("off") | |
| plt.show() | |
| # Load the input image | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = CycleGAN.load_from_checkpoint("/content/cyclegan_monet_unet_250_epochs.ckpt", **MODEL_CONFIG) | |