File size: 1,295 Bytes
70be616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from torch import nn

from diffqrcoder.losses import PerceptualLoss, ScanningRobustLoss



# Factor multiplied into the combined guidance loss before differentiation and
# divided back out of the resulting gradient in
# ``ScanningRobustPerceptualGuidance.compute_score``, so the returned gradient
# is that of the unscaled loss.
# NOTE(review): presumably this guards against gradient underflow in reduced
# precision — confirm the original motivation.
GRADIENT_SCALE = 100


class ScanningRobustPerceptualGuidance(nn.Module):
    """Combine a scanning-robust QR-code loss and a perceptual loss into one guidance signal.

    The (scaled) loss computed by :meth:`compute_loss` is::

        (scanning_robust_guidance_scale * ScanningRobustLoss(image, qrcode)
         + perceptual_guidance_scale * PerceptualLoss(image, ref_image)) * GRADIENT_SCALE

    :meth:`compute_score` differentiates that loss with respect to ``latents``
    and divides ``GRADIENT_SCALE`` back out of the gradient.

    Args:
        module_size: Stored on the instance and forwarded to
            ``ScanningRobustLoss(module_size=...)``.
        scanning_robust_guidance_scale: Weight on the scanning-robust loss term.
        perceptual_guidance_scale: Weight on the perceptual loss term.
    """

    def __init__(
        self,
        module_size: int = 20,
        scanning_robust_guidance_scale: float = 500,
        perceptual_guidance_scale: float = 2,
    ):
        super().__init__()
        self.module_size = module_size
        self.scanning_robust_guidance_scale = scanning_robust_guidance_scale
        self.perceptual_guidance_scale = perceptual_guidance_scale
        self.scanning_robust_loss_fn = ScanningRobustLoss(module_size=module_size)
        self.perceptual_loss_fn = PerceptualLoss()

    def compute_loss(self, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of both losses, multiplied by ``GRADIENT_SCALE``.

        Args:
            image: Image being guided; passed to both loss functions.
            qrcode: Target QR code; passed to the scanning-robust loss.
            ref_image: Reference image; passed to the perceptual loss.

        Returns:
            The combined loss, still multiplied by ``GRADIENT_SCALE`` (callers
            that want the raw value must divide it out themselves).
        """
        # Evaluate the scanning-robust term first, then the perceptual term,
        # matching the original evaluation order.
        scanning_robust_term = self.scanning_robust_guidance_scale * self.scanning_robust_loss_fn(image, qrcode)
        perceptual_term = self.perceptual_guidance_scale * self.perceptual_loss_fn(image, ref_image)
        return (scanning_robust_term + perceptual_term) * GRADIENT_SCALE

    def compute_score(self, latents: torch.Tensor, image: torch.Tensor, qrcode: torch.Tensor, ref_image: torch.Tensor) -> torch.Tensor:
        """Return the gradient of the (unscaled) combined loss with respect to ``latents``.

        ``image`` must have been computed from ``latents`` inside the autograd
        graph; otherwise ``torch.autograd.grad`` raises. The loss returned by
        the two loss functions is presumably a scalar, which ``autograd.grad``
        requires when no ``grad_outputs`` is given — TODO confirm.

        Args:
            latents: Tensor to differentiate with respect to.
            image: Image derived from ``latents``.
            qrcode: Target QR code for the scanning-robust loss.
            ref_image: Reference image for the perceptual loss.

        Returns:
            A tensor shaped like ``latents`` holding the gradient, with the
            ``GRADIENT_SCALE`` factor from :meth:`compute_loss` divided out.
        """
        loss = self.compute_loss(image, qrcode, ref_image)
        # autograd.grad returns a tuple with one entry per input tensor.
        (grad,) = torch.autograd.grad(loss, latents)
        return grad / GRADIENT_SCALE