import torch import numpy as np import os from PIL import Image from fused_ssim import fused_ssim gt_image = torch.tensor(np.array(Image.open(os.path.join("..", "images", "albert.jpg"))), dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) / 255.0 pred_image = torch.nn.Parameter(torch.rand_like(gt_image)) with torch.no_grad(): ssim_value = fused_ssim(pred_image, gt_image, train=False) print("Starting with SSIM value:", ssim_value) optimizer = torch.optim.Adam([pred_image]) while ssim_value < 0.9999: optimizer.zero_grad() loss = 1.0 - fused_ssim(pred_image, gt_image) loss.backward() optimizer.step() with torch.no_grad(): ssim_value = fused_ssim(pred_image, gt_image, train=False) print("SSIM value:", ssim_value) pred_image = (pred_image * 255.0).squeeze(0).squeeze(0) to_save = pred_image.detach().cpu().numpy().astype(np.uint8) Image.fromarray(to_save).save(os.path.join("..", "images", "predicted.jpg"))