File size: 3,971 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from PIL import Image
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import random
from omegaconf import DictConfig
from ppd.utils.diffusion.timesteps import Timesteps
from ppd.utils.diffusion.schedule import LinearSchedule
from ppd.utils.diffusion.sampler import EulerSampler
from ppd.utils.transform import video2tensor
from ppd.utils.align_vda import align_video_depth

from ppd.models.dit_video import DiT_Video
from safetensors.torch import load_file

# Inference settings for the sliding-window video pipeline. Do not change:
# the values are interdependent (STRIDE == INFER_LEN - OVERLAP and
# len(KEYFRAMES) == OVERLAP).
INFER_LEN = 16  # number of frames processed together in one window
KEYFRAMES = [0, 8, 15]  # indices into the previous window whose frames seed the next window
OVERLAP = 3  # number of leading frames of a window replaced by the previous window's keyframes
STRIDE = 13  # step between the start indices of consecutive windows

class PixelPerfectVideoDepth(nn.Module):
    """Sliding-window video predictor built from a frozen semantics encoder and a DiT sampler.

    A clip is split into windows of ``INFER_LEN`` frames taken every ``STRIDE``
    frames. Each window is denoised by ``DiT_Video`` with an Euler sampler,
    conditioned on the frames themselves and on features from the semantics
    encoder. The per-window outputs are then merged by ``align_video_depth``.
    """

    def __init__(
        self,
        semantics_model='Pi3',
        semantics_pth='checkpoints/pi3.safetensors',
        sampling_steps=4,
        ):
        """Build the semantics encoder, the diffusion sampler and the DiT model.

        Args:
            semantics_model: Name of the semantics encoder. Only ``'Pi3'`` is
                implemented.
            semantics_pth: Path to the safetensors checkpoint loaded into the
                semantics encoder.
            sampling_steps: Number of sampling steps handed to ``Timesteps``.

        Raises:
            ValueError: If ``semantics_model`` is not a supported encoder name.
        """
        super().__init__()
        # Fail early with a clear message instead of an AttributeError on the
        # never-assigned ``self.sem_encoder`` further down.
        if semantics_model != 'Pi3':
            raise ValueError(
                f"Unsupported semantics_model {semantics_model!r}; only 'Pi3' is available."
            )

        self.sampling_steps = sampling_steps
        # Preference order: CUDA, then Apple MPS, then CPU.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu'
        )

        # Imported lazily so the Pi3 package is only needed when this encoder is built.
        from ppd.models.pi3.models.pi3 import Pi3
        self.sem_encoder = Pi3()
        self.sem_encoder.load_state_dict(load_file(semantics_pth))

        # The encoder is used for inference only: eval mode, no gradients.
        self.sem_encoder = self.sem_encoder.to(self.device).eval()
        self.sem_encoder.requires_grad_(False)

        self.configure_diffusion()
        # NOTE(review): the DiT is created here without loading weights or moving
        # it to ``self.device``; presumably the caller does both - confirm.
        self.dit_video = DiT_Video()

    def configure_diffusion(self):
        """Create the linear schedule, the sampling timesteps and the Euler sampler."""
        self.schedule = LinearSchedule(T=1000)
        self.sampling_timesteps = Timesteps(
            T=self.schedule.T,
            steps=self.sampling_steps,
            device=self.device,
            )
        self.sampler = EulerSampler(
            schedule=self.schedule,
            timesteps=self.sampling_timesteps,
            prediction_type='velocity'
            )

    @staticmethod
    def _pad_length(num_frames: int, window: int, stride: int) -> int:
        """Return how many frames must be appended so windows tile the clip exactly.

        After padding, the clip length is ``window + k * stride`` for some
        ``k >= 0``, so stepping by ``stride`` ends exactly on the last frame.
        Clips shorter than one window are padded up to a full window; a plain
        modulo would wrap around for very short clips and leave fewer than
        ``window`` frames, producing no windows at all.
        """
        if num_frames <= window:
            return window - num_frames
        remainder = (num_frames - window) % stride
        return (stride - remainder) % stride

    @torch.no_grad()
    def infer_video(self, images, use_fp16: bool = True):
        """Run the full pipeline on a clip and return one prediction per input frame.

        Args:
            images: Clip in whatever form ``video2tensor`` accepts. The result is
                iterated frame by frame, and each frame must be a 4-D tensor
                accepted by ``F.interpolate``.
            use_fp16: When True (default) the model runs under ``torch.autocast``
                with bfloat16 if CUDA supports it, otherwise float16. When False
                autocast is disabled and the model runs in its native precision.

        Returns:
            The first ``len(images)`` entries of the ``align_video_depth`` result,
            computed on predictions resized back to the first frame's spatial
            size. Entries for padded frames are dropped.

        Raises:
            ValueError: If the clip contains no frames.
        """
        images = video2tensor(images)
        images = [img.to(self.device) for img in images]
        if not images:
            raise ValueError('infer_video() requires at least one frame.')

        # The network operates at a fixed 512x512 resolution.
        p_imgs = [F.interpolate(img, size=(512, 512), mode='bilinear', align_corners=False) for img in images]
        LEN = len(p_imgs)

        # Repeat the last frame so the windows cover the clip exactly.
        pad_len = self._pad_length(LEN, INFER_LEN, STRIDE)
        if pad_len:
            last_img = p_imgs[-1]
            p_imgs.extend(last_img.clone() for _ in range(pad_len))

        # Only query CUDA for bf16 support when actually running on CUDA.
        if self.device.type == 'cuda' and torch.cuda.is_bf16_supported():
            autocast_dtype = torch.bfloat16
        else:
            autocast_dtype = torch.float16
        with torch.autocast(device_type=self.device.type, dtype=autocast_dtype, enabled=use_fp16):
            preds = self.forward_test(p_imgs)

        # Back to the input resolution, then merge the per-window predictions.
        preds = [F.interpolate(pred, size=images[0].shape[-2:], mode='bilinear', align_corners=False) for pred in preds]
        preds = align_video_depth(preds, INFER_LEN, KEYFRAMES, OVERLAP)
        return preds[:LEN]

    @torch.no_grad()
    def forward_test(self, imgs):
        """Denoise every window of ``imgs`` and return the per-window predictions.

        Args:
            imgs: List of per-frame 4-D tensors that can be concatenated along
                dim 0. Its length is expected to be ``INFER_LEN + k * STRIDE``.

        Returns:
            A list with one tensor per window, each the final latent shifted by
            ``+ 0.5``. The latent shape is
            ``(INFER_LEN, 1, imgs[0].shape[2], imgs[0].shape[3])``.
        """
        preds = []
        pre_img = None
        # The same initial noise is reused as the starting latent of every window.
        init_latent = torch.randn(size=[INFER_LEN, 1, imgs[0].shape[2], imgs[0].shape[3]]).to(self.device)
        for start in range(0, len(imgs) - INFER_LEN + 1, STRIDE):
            cur_img = imgs[start:start + INFER_LEN]  # slicing copies, so ``imgs`` is untouched
            if pre_img is not None:
                # Seed this window with the keyframes of the previous window.
                cur_img[:OVERLAP] = [pre_img[k] for k in KEYFRAMES]
            pre_img = cur_img
            concat_img = torch.cat(cur_img, dim=0)
            semantics = self.semantics_prompt(concat_img)
            # Shift the conditioning by -0.5; presumably frames are in [0, 1] - TODO confirm.
            cond = concat_img - 0.5
            latent = init_latent

            for timestep in self.sampling_timesteps:
                model_input = torch.cat([latent, cond], dim=1)
                pred = self.dit_video(x=model_input, semantics=semantics, timestep=timestep)
                latent = self.sampler.step(pred=pred, x_t=latent, t=timestep)
            # Undo the 0.5 shift on the way out.
            preds.append(latent + 0.5)
        return preds

    @torch.no_grad()
    def semantics_prompt(self, images):
        """Return the semantics encoder's ``forward_semantics`` output for ``images``."""
        return self.sem_encoder.forward_semantics(images)