sayshara's picture
added diffqrcoder_wrapper
70be616
import torch
from torch import nn
from diffqrcoder.losses import PerceptualLoss, ScanningRobustLoss
GRADIENT_SCALE = 100
class ScanningRobustPerceptualGuidance(nn.Module):
def __init__(
self,
module_size: int = 20,
scanning_robust_guidance_scale: int = 500,
perceptual_guidance_scale: int = 2,
):
super().__init__()
self.module_size = module_size
self.scanning_robust_guidance_scale = scanning_robust_guidance_scale
self.perceptual_guidance_scale = perceptual_guidance_scale
self.scanning_robust_loss_fn = ScanningRobustLoss(module_size=module_size)
self.perceptual_loss_fn = PerceptualLoss()
def compute_loss(self, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
return (
self.scanning_robust_guidance_scale * self.scanning_robust_loss_fn(image, qrcode) + \
self.perceptual_guidance_scale * self.perceptual_loss_fn(image, ref_image)
) * GRADIENT_SCALE
def compute_score(self, latents: torch.Tensor, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
loss = self.compute_loss(image, qrcode, ref_image)
return torch.autograd.grad(loss, latents)[0] / GRADIENT_SCALE