import copy
import torch
from torch import nn
import torch.nn.functional as Fn
from torchvision.transforms import v2

from . import register, make
from .base import Base

from chord.util import fresnelSchlick, GeometrySchlickGGX, DistributionGGX
from chord.util import srgb_to_rgb, tone_gamma, get_positions, safe_01_div

class dummy_module(nn.Module):
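    """Identity pass-through; substituted for the UNet's original conv_in/conv_out
    once the per-map LEGO convolutions take over (see Chord.setup)."""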
    def forward(self, x): return x

def post_decoder(out_dict):
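    # Map raw decoder outputs to named intrinsic maps: drop "approx*" entries,
    # renormalize the normal map back into [0, 1], and split the packed
    # roughness/metalness channels.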
    out = {}
    for key in out_dict.keys():
        if key.startswith("approx"): continue
        elif key == "normal": 
            out[key] = Fn.normalize(2. * out_dict[key] - 1., dim=1) / 2. + 0.5
        elif key == "rou_met":
            out['roughness'], out['metalness'] = out_dict['rou_met'][:,0], out_dict['rou_met'][:,1]
        else: out[key] = out_dict[key]
    return out

def process_irradiance(radiance, kernel_size=25, res=64):
    """
    Process the irradiance using PyTorch, equivalent to the original OpenCV-based function.
    
    Args:
        radiance (torch.Tensor): Input radiance tensor of shape (B, 1, H, W).
        kernel_size (int): Size of the kernel for the median blur.
        res (int): Target resolution for resizing the image.

    Returns:
        torch.Tensor: Processed radiance tensor of shape (B, 1, res, res).
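
    Example (hypothetical shapes; assumes a square input so the resize yields res x res):
        >>> rad = torch.rand(2, 1, 512, 512)
        >>> process_irradiance(rad).shape
        torch.Size([2, 1, 64, 64])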
    """
    # Ensure the input radiance is a 4D tensor (B, 1, H, W)
    assert radiance.shape[1] == 1 and radiance.dim() == 4, f"Invalid radiance shape, got {radiance.shape}"
    # resize to low resolution
    resizer = v2.Resize(size=res, antialias=True)
    radiance = resizer(radiance)

    # Define an 11x11 averaging kernel
    kernel = torch.ones((1, 1, 11, 11), dtype=torch.float32).to(radiance) / 121.0
    # Apply the averaging filter; reflect-pad by 5 so the valid convolution preserves the input size
    radiance = Fn.pad(radiance, (5,)*4, mode="reflect")
    radiance = Fn.conv2d(radiance, kernel, padding=0)

    # Clamp values and scale to [0, 255] for median filtering
    radiance = torch.clamp(radiance * 255, 0, 255)

    # Apply median filtering
    padded_radiance = Fn.pad(radiance, (kernel_size // 2,) * 4, mode="reflect")  # Pad for edge handling
    unfolded = Fn.unfold(padded_radiance, kernel_size)  # Extract patches
    radiance = torch.median(unfolded, dim=1).values.view(radiance.shape)  # Median of patches

    # Normalize to [0, 1]
    rad_min, rad_max = radiance.amin([2,3], keepdim=True), radiance.amax([2,3], keepdim=True)
    radiance = (radiance - rad_min) / (rad_max - rad_min)
    return radiance

def opt_light_dir(_radiance, _num_samples=6):
    '''
        _radiance: (bs, 1, h, w)
    '''
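    # Coarse-to-fine search for the dominant light direction in the image plane:
    # sample _num_samples+1 angles evenly over the current bracket, score each
    # direction by how much brighter the half-plane it points toward is than the
    # opposite half, then shrink the bracket around the best angle until it is
    # narrower than pi/90 (2 degrees).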
    assert _radiance.shape[1] == 1 and _radiance.dim()==4
    bs, _, h, w = _radiance.shape

    def evenly_sample(_num_samples, min=0, max=2*torch.pi):
        # returns a 1-D tensor of _num_samples+1 evenly spaced angles in [min, max]
        return torch.tensor(range(_num_samples+1)) * (max - min) / _num_samples + min

    def compute_radiance_diff(angles):
        num = angles.shape[-1]
        dirs = torch.cat([torch.cos(angles), torch.sin(angles)]).T
        pos_dir = grid_pos.repeat(num, 1, 1, 1)
        pos_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) > 0
        neg_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) < 0
        samples_radiance = _radiance.repeat(1,num,1,1)
        radiance_diff = (samples_radiance*pos_mask[None] - samples_radiance*neg_mask[None]).sum([2,3])
        return radiance_diff
    
    angle_min, angle_max = 0, 2*torch.pi
    grid_pos = Fn.normalize(get_positions(h,w,10)[...,:2], dim=-1, eps=1e-6).to(_radiance)
    while (angle_max - angle_min) > (torch.pi / 90):
        angles = evenly_sample(_num_samples, angle_min, angle_max)[None].to(_radiance)
        diffs = compute_radiance_diff(angles).mean(0)
        angle_min = angles[:,diffs.argmax()].item() - (angle_max - angle_min)/_num_samples
        angle_max = angles[:,diffs.argmax()].item() + (angle_max - angle_min)/_num_samples

    light_angle = angles[:, diffs.argmax()]
    return torch.cat([torch.cos(light_angle), torch.sin(light_angle)]).to(_radiance)


def find_light_dir(raw_irradiance, light):
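    # Estimate the dominant light direction from the approximate irradiance:
    # grayscale and smooth it, search for the in-plane (x, y) direction, then
    # append a fixed upward z component of sqrt(0.5).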
    raw_irradiance = v2.functional.rgb_to_grayscale(raw_irradiance)
    irradiance = process_irradiance(raw_irradiance)
    dir = opt_light_dir(irradiance)
    dir = torch.cat([dir, torch.tensor([0.5**0.5]).to(dir)])
    _light = copy.deepcopy(light)
    _light.direction = dir
    return _light

@register("chord")
class Chord(Base):
    def setup(self):
        # Define forward chain
        self.chain_type = self.config.get("chain_type", "chord")
        self.chain = self.config.get("chain_library", {})[self.chain_type]
        self.prompts = self.config.get("rgbx_prompts", {})
        self.roughness_step = self.config.get("roughness_step", 10)
        self.metallic_step = self.config.get("metallic_step", 0.2)

        self.sd = make(self.config.stable_diffusion.name, self.config.stable_diffusion)
        self.dtype = self.sd.dtype
        self.device = self.sd.device

        # LEGO-conditioning: give each input map its own conv_in and each output
        # map its own conv_out plus first-down/last-up blocks, all initialized
        # from the pretrained UNet weights.
        self.sd.unet.ConvIns = nn.ModuleDict()
        self.sd.unet.ConvOuts = nn.ModuleDict()
        self.sd.unet.FirstDownBlocks = nn.ModuleDict()
        self.sd.unet.LastUpBlocks = nn.ModuleDict()
        for key in list(set("_".join(self.chain.values()).split("_"))) + ["noise"]:
            if "0" in key or "1" in key: continue  # skip numeric id tokens; keep only map names
            self.sd.unet.ConvIns[key] = nn.Conv2d(4, 320, 3, 1, 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvIns[key].load_state_dict(self.sd.unet.conv_in.state_dict())
        for kout in list(set(self.chain.keys())):
            self.sd.unet.ConvOuts[kout] = nn.Conv2d(320, 4, 3, 1, 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvOuts[kout].load_state_dict(self.sd.unet.conv_out.state_dict())
            self.sd.unet.LastUpBlocks[kout] = copy.deepcopy(self.sd.unet.up_blocks[-1]).to(self.device)
            self.sd.unet.FirstDownBlocks[kout] = copy.deepcopy(self.sd.unet.down_blocks[0]).to(self.device)
        self.sd.unet.ConvIns.train()
        self.sd.unet.ConvOuts.train()
        self.sd.unet.FirstDownBlocks.train()
        self.sd.unet.LastUpBlocks.train()
        self.sd.unet.conv_in = dummy_module()
        self.sd.unet.conv_out = dummy_module()

        # Load Lights
        if self.config.get("prior_light", None) is None:
            self.prior_light = make("point-light", {"position": [0, 0, 10]})
        else:
            self.prior_light = make(self.config.prior_light.name, self.config.prior_light)

        # Init Embeddings
        self.text_emb = {}

    # Eq.3: approximate irradiance = (linear render) / (linear basecolor), tone-mapped
    def compute_approxIrr(self, render, basecolor):
        approxIrr = safe_01_div.apply(srgb_to_rgb(render), srgb_to_rgb(basecolor))
        return tone_gamma(approxIrr)

    # Eq.6: recover roughness/metallic via per-pixel grid search against re-renders
    @torch.no_grad()
    def compute_approxRouMet(self, render, maps, separate=False, light=None):
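        # Brute-force grid search: render every (roughness, metallic) candidate
        # under the estimated light and keep, per pixel, the pair whose re-render
        # best matches the input (L1 error), chunked to bound peak memory.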
        render = srgb_to_rgb(render)
        bs, _, h, w = render.shape
        light = find_light_dir(maps['approxIrr'], self.prior_light) if light is None else light
        pos = get_positions(h, w, 10).to(self.device)
        cameras = torch.tensor([0, 0, 10.0]).to(self.device)

        # sample grid
        r_samples = torch.arange(25, 225+self.roughness_step, self.roughness_step) / 255
        m_samples = torch.arange(0., 1.+self.metallic_step, self.metallic_step)

        grid_maps = {} # change map size into: gs, bs, h, w, c
        grid_maps['basecolor'] = maps['basecolor'][None].permute(0,1,3,4,2)
        grid_maps['normal'] = maps['normal'][None].permute(0,1,3,4,2)
        r_values = r_samples[:,None].repeat(1,len(m_samples)).reshape(-1,1,1,1,1).to(maps['basecolor'])
        m_values = m_samples[None].repeat(len(r_samples),1).reshape(-1,1,1,1,1).to(maps['basecolor'])
        # split into chunks to avoid OOM
        chunk_size = 25
        rgb_list, r_list, m_list = [], [], []
        for _r, _m in zip(torch.split(r_values, chunk_size), torch.split(m_values, chunk_size)):
            grid_maps['roughness'], grid_maps['metallic'] = _r, _m
            _rgb = self.compute_render(grid_maps, cameras, pos, light)
            loss = (render[None].permute(0,1,3,4,2) - _rgb).abs().sum(-1,keepdim=True)
            min_idx = loss.argmin(dim=0,keepdim=True)
            r_list.append(torch.gather(grid_maps['roughness'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            m_list.append(torch.gather(grid_maps['metallic'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            rgb_list.append(torch.gather(_rgb, 0, min_idx.repeat(1,1,1,1,3)))
        rgb = torch.cat(rgb_list).permute(0,1,4,2,3)
        roughness = torch.cat(r_list).permute(0,1,4,2,3)
        metallic = torch.cat(m_list).permute(0,1,4,2,3)
        loss = (render[None] - rgb).abs().sum(2,keepdim=True)
        roughness = torch.gather(roughness, 0, loss.argmin(dim=0,keepdim=True))[0]
        metallic = torch.gather(metallic, 0, loss.argmin(dim=0,keepdim=True))[0]
        torch.cuda.empty_cache()
        if separate:
            return roughness, metallic
        else:
            out = torch.cat([roughness, metallic, torch.zeros_like(roughness)], dim=1)
            return out


    @torch.no_grad()
    def compute_render(self, maps, camera_position, pos, light):
        '''
            maps: gs, bs, h, w, c (gs: the number of grids)
        '''
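        # Cook-Torrance-style microfacet shading: Fresnel-Schlick term F, GGX
        # normal distribution NDF, and Schlick-GGX geometry term G, combined
        # with a Lambertian diffuse lobe and a constant ambient term.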
        def cos(x, y): 
            return torch.clamp((x*y).sum(-1, keepdim=True), min=0, max=1)
        
        # pre-process
        albedo = srgb_to_rgb(maps['basecolor'])
        normal = maps['normal'].clone()
        normal[..., :2] = normal[..., [1,0]]
        N = Fn.normalize((normal - 0.5) * 2.0, dim=-1, eps=1e-6)
        roughness = maps['roughness']
        metallic = maps['metallic']
        V = Fn.normalize(camera_position - pos, dim=-1, eps=1e-6).repeat(1,1,1,1,1).to(self.device)
        irradiance, L = light(pos)
        irradiance, L = irradiance.repeat(1,1,1,1,1).to(self.device), L.repeat(1,1,1,1,1).to(self.device)
        # rendering
        H = Fn.normalize(L+V, dim=-1, eps=1e-6)
        f0 = torch.ones_like(albedo).to(self.device) * 0.04
        F0 = torch.lerp(f0, albedo, metallic)
        F = fresnelSchlick(cos(H,V), F0)   
        ks = F

        diffuse = (1-ks) * albedo / torch.pi
        diffuse *= 1-metallic

        NDF = DistributionGGX(cos(N,H), roughness) 
        G = GeometrySchlickGGX(cos(N,L), roughness) * GeometrySchlickGGX(cos(N,V), roughness)

        numerator = NDF * G * F
        denominator = 4.0 * cos(N,V) * cos(N,L) + 1e-3
        specular = numerator / denominator
        ambient = 0.3 * albedo

        rgb = (diffuse + specular) * irradiance * cos(N,L) + ambient

        return rgb
    
    def forward(self, maps:dict):
        # prepare
        bs = maps['render'].shape[0]
        self.sd.scheduler.set_timesteps(1)
        t = self.sd.scheduler.timesteps[0]
        # chain processing
        pred, pred_latent, arxiv_latent = {}, {}, {}
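        # Each chain entry maps an output key to "k1_k2..._ids", where the final
        # underscore-separated token carries one digit per input key:
        # "0" = encode the corresponding input map, "1" = reuse a latent
        # predicted earlier in the chain.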
        for kout, info in self.chain.items():
            info = info.split("_")
            keys, ids = info[:-1], info[-1]
            # Swap active LEGO blocks
            self.sd.unet.down_blocks[0] = self.sd.unet.FirstDownBlocks[kout]
            self.sd.unet.up_blocks[-1] = self.sd.unet.LastUpBlocks[kout]
            # Eq.2, summing input latents
            in_latent = 0
            for k, i in zip(keys, ids):
                if i=="0":
                    if k not in arxiv_latent: arxiv_latent[k] = self.sd.encode_imgs_deterministic(maps[k])
                    zx = arxiv_latent[k]
                else:
                    zx = pred_latent[k]
                in_latent += self.sd.unet.ConvIns[k](zx)
            in_latent = in_latent / len(keys)
            # single-step denoising
            embs = self.produce_embeddings(kout, bs)
            out_latent = self.sd.unet(in_latent, t, **embs)[0]
            out_latent = self.sd.unet.ConvOuts[kout](out_latent)
            pred_latent[kout] = self.sd.scheduler.step(out_latent, t, torch.zeros_like(out_latent)).pred_original_sample
            pred[kout] = self.sd.decode_latents(pred_latent[kout]).float()
            # compute intermediate representations
            if self.chain_type in ["chord"] and kout == "basecolor":
                pred['approxIrr'] = self.compute_approxIrr(maps['render'], pred['basecolor'])
                pred_latent['approxIrr'] = self.sd.encode_imgs_deterministic(pred['approxIrr'])
            if self.chain_type in ["chord"] and kout == "normal":
                pred['approxRM'] = self.compute_approxRouMet(maps['render'], pred, separate=False)
                pred_latent['approxRM'] = self.sd.encode_imgs_deterministic(pred['approxRM'])

        return pred     
    
    @torch.no_grad()
    def produce_embeddings(self, key, batch_size):
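        # Encode each map's text prompt once and cache it; later calls only
        # expand the cached embedding to the batch size.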
        if key not in self.text_emb:
            self.text_emb[key] = self.sd.encode_text(self.prompts[key], "max_length")
        prompt_emb = self.text_emb[key].expand(batch_size, -1, -1)
        return { "encoder_hidden_states": prompt_emb }