File size: 3,416 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Optional
from torch import Tensor,device

from tqdm import tqdm
import torch
from collections import deque

from TorchJaekwon.Util.UtilTorch import UtilTorch
from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM

from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil

class PNDM:
    """Sampler using the PLMS method from Pseudo Numerical methods for
    Diffusion Models on manifolds (PNDM), by Luping Liu, Yi Ren, Zhijie Lin
    and Zhou Zhao (https://arxiv.org/abs/2202.09778).

    It wraps an existing ``DDPM`` module and reuses that module's model
    (``apply_model``), noise schedule (``alphas_cumprod``), pre/post-processing
    and ``cfg_scale``; only the reverse-diffusion update rule is replaced.
    """

    def __init__(self, ddpm_module:DDPM) -> None:
        self.ddpm_module = ddpm_module

    @torch.no_grad()
    def infer(self,
              x_shape:Optional[tuple],
              cond:Optional[dict] = None,
              is_cond_unpack:bool = False,
              pndm_speedup:int = 10) -> Tensor:
        """Draw a sample by running PLMS over a strided subset of timesteps.

        Args:
            x_shape: Shape of the sample to generate. If ``None``, it is taken
                from ``ddpm_module.get_x_shape(cond=cond)``. ``x_shape[0]`` is
                used as the length of the per-step timestep tensor.
            cond: Conditioning; passed through ``ddpm_module.preprocess`` and
                then on to the model at every step.
            is_cond_unpack: Forwarded unchanged to ``ddpm_module.apply_model``.
            pndm_speedup: Stride between visited timesteps. The sampler visits
                ``range(0, ddpm_module.timesteps, pndm_speedup)`` in reverse.

        Returns:
            The result of ``ddpm_module.postprocess`` applied to the final sample.
        """
        _, cond, additional_data_dict = self.ddpm_module.preprocess(x_start = None, cond=cond)
        if x_shape is None: x_shape = self.ddpm_module.get_x_shape(cond=cond)
        total_timesteps:int = self.ddpm_module.timesteps
        model_device:device = UtilTorch.get_model_device(self.ddpm_module)
        x:Tensor = torch.randn(x_shape, device = model_device)
        # History of past noise predictions consumed by the multistep formulas
        # in p_sample_plms (at most the last 3 are read); reset for every sample.
        self.noise_list = deque(maxlen=4)

        timestep_range = range(0, total_timesteps, pndm_speedup)
        # len(range) is the true number of iterations (ceil division);
        # total_timesteps // pndm_speedup undercounts when not divisible.
        for i in tqdm(reversed(timestep_range), desc='sample time step', total=len(timestep_range)):
            x = self.p_sample_plms(x, torch.full((x_shape[0],), i, device=model_device, dtype=torch.long), pndm_speedup, cond, is_cond_unpack)

        return self.ddpm_module.postprocess(x, additional_data_dict)

    def _predict_noise(self, x, t, cond, is_cond_unpack):
        """Return the model's noise estimate at ``(x, t)``.

        Calls ``ddpm_module.apply_model`` with the module's ``cfg_scale``. For
        ``'v_prediction'`` models, the raw output is converted with
        ``predict_noise_from_v``. Every estimate combined by the PLMS formulas
        therefore lives in the same (noise) space and uses the same guidance.
        """
        model_output = self.ddpm_module.apply_model(x, t, cond, is_cond_unpack, self.ddpm_module.cfg_scale)
        if self.ddpm_module.model_output_type == 'v_prediction':
            return self.ddpm_module.predict_noise_from_v(x, t, model_output)
        return model_output

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, is_cond_unpack):
        """
        Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).

        Performs one reverse step from timestep ``t`` to ``max(t - interval, 0)``
        and appends the current noise prediction to ``self.noise_list``
        (which ``infer`` must have initialised).

        Args:
            x: Current sample at timestep ``t``.
            t: Integer (long) timestep tensor, one entry per batch element.
            interval: Timestep stride of this step.
            cond: Conditioning forwarded to the model.
            is_cond_unpack: Forwarded to ``ddpm_module.apply_model``.

        Returns:
            The sample at the previous (smaller) timestep.
        """
        noise_list = self.noise_list
        noise_pred = self._predict_noise(x, t, cond, is_cond_unpack)

        if len(noise_list) == 0:
            # Warm-up, no history yet: take a trial step, re-evaluate the model
            # there, and average both estimates (a two-evaluation, Heun-like step).
            x_pred = self.get_x_pred(x, noise_pred, t, interval)
            t_prev = (t - interval).clamp(min=0)
            noise_pred_prev = self._predict_noise(x_pred, t_prev, cond, is_cond_unpack)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            # Adams–Bashforth coefficients of order 2, 3 and 4 respectively.
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        else:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = self.get_x_pred(x, noise_pred_prime, t, interval)
        noise_list.append(noise_pred)

        return x_prev

    def get_x_pred(self, x, noise_t, t, interval):
        """Transfer ``x`` from timestep ``t`` to ``max(t - interval, 0)`` given a noise estimate.

        With ``a_t``/``a_prev`` the cumulative alpha products at the two
        timesteps, this expression is algebraically equal to the deterministic
        DDIM update
        ``sqrt(a_prev / a_t) * x + (sqrt(1 - a_prev) - sqrt(a_prev * (1 - a_t) / a_t)) * noise_t``.
        When ``t - interval`` clamps to ``t`` itself (``t == 0``), ``a_prev == a_t``
        and ``x`` is returned unchanged.
        """
        a_t = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, t, x.shape)
        a_prev = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, (t - interval).clamp(min=0), x.shape)
        a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

        x_coef = 1 / (a_t_sq * (a_t_sq + a_prev_sq))
        noise_coef = 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt()))
        x_delta = (a_prev - a_t) * (x_coef * x - noise_coef * noise_t)
        x_pred = x + x_delta

        return x_pred