# Source: Hugging Face repo "ksangk/demo", revision a846205 (file-viewer header artifact)
import copy
import torch
from torch import nn
import torch.nn.functional as Fn
from torchvision.transforms import v2
from . import register, make
from .base import Base
from chord.util import fresnelSchlick, GeometrySchlickGGX, DistributionGGX
from chord.util import srgb_to_rgb, tone_gamma, get_positions, safe_01_div
class dummy_module(nn.Module):
    """No-op module: returns its input unchanged.

    Used to neutralize a layer slot (e.g. the UNet's conv_in/conv_out) while
    keeping the module interface intact.
    """

    def forward(self, x):
        # Identity: hand the input straight back.
        return x
def post_decoder(out_dict):
    """Post-process decoded maps.

    Drops every key starting with "approx", re-unit-normalizes the "normal"
    map (stored in [0, 1]) along the channel axis, and splits the packed
    "rou_met" map into separate 'roughness' / 'metalness' entries. All other
    entries pass through untouched.
    """
    processed = {}
    for name, value in out_dict.items():
        if name.startswith("approx"):
            # Intermediate approximation maps are internal only.
            continue
        if name == "normal":
            # Map [0,1] -> [-1,1], renormalize to unit length, map back.
            unit = Fn.normalize(2. * value - 1., dim=1)
            processed[name] = unit / 2. + 0.5
        elif name == "rou_met":
            # Channel 0 is roughness, channel 1 is metalness.
            processed['roughness'] = value[:, 0]
            processed['metalness'] = value[:, 1]
        else:
            processed[name] = value
    return processed
def process_irradiance(radiance, kernel_size=25, res=64):
    """
    Smooth and normalize an irradiance map (PyTorch port of an OpenCV routine).

    Pipeline: resize so the smaller edge is `res`, 11x11 box blur, square
    median filter of side `kernel_size`, then per-image min-max normalization
    to [0, 1].

    Args:
        radiance (torch.Tensor): Input radiance of shape (B, 1, H, W).
        kernel_size (int): Side of the square median-filter window.
        res (int): Target size of the smaller edge after resizing.

    Returns:
        torch.Tensor: Processed radiance; (B, 1, res, res) for square inputs.
    """
    # Ensure the input radiance is a 4D tensor (B, 1, H, W)
    assert radiance.shape[1] == 1 and radiance.dim() == 4, f"Invalid radiance shape, got {radiance.shape}"
    # Resize to low resolution (int size -> smaller edge becomes `res`).
    resizer = v2.Resize(size=res, antialias=True)
    radiance = resizer(radiance)
    # 11x11 box blur; reflect padding keeps the spatial size unchanged.
    kernel = torch.ones((1, 1, 11, 11), dtype=torch.float32).to(radiance) / 121.0
    radiance = Fn.pad(radiance, (5,) * 4, mode="reflect")
    radiance = Fn.conv2d(radiance, kernel, padding=0)
    # Scale to the 0..255 range the original OpenCV median filter operated on.
    radiance = torch.clamp(radiance * 255, 0, 255)
    # Median filter: extract kernel_size x kernel_size patches, take medians.
    paded_radiance = Fn.pad(radiance, (kernel_size // 2,) * 4, mode="reflect")
    unfolded = Fn.unfold(paded_radiance, kernel_size)  # (B, k*k, H*W) patches
    radiance = torch.median(unfolded, dim=1).values.view(radiance.shape)
    # Per-image min-max normalization to [0, 1]. The eps guards constant
    # images: previously rad_max == rad_min divided by zero and produced NaNs.
    rad_min = radiance.amin([2, 3], keepdim=True)
    rad_max = radiance.amax([2, 3], keepdim=True)
    radiance = (radiance - rad_min) / (rad_max - rad_min).clamp_min(1e-8)
    return radiance
def opt_light_dir(_radiance, _num_samples=6):
    '''
    Estimate the dominant in-plane light direction of a radiance map by a
    coarse-to-fine angular search over [0, 2*pi).

    _radiance: (bs, 1, h, w)
    _num_samples: number of candidate angles evaluated per refinement step.

    Returns a tensor holding [cos(theta), sin(theta)] of the best angle,
    cast to _radiance's device/dtype.
    '''
    assert _radiance.shape[1] == 1 and _radiance.dim()==4
    bs, _, h, w = _radiance.shape

    def evenly_sample(_num_samples, min=0, max=2*torch.pi):
        # _num_samples + 1 evenly spaced angles covering [min, max].
        # returns torch.tensor([1, _num_samples])
        return torch.tensor(range(_num_samples+1)) * (max - min) / _num_samples + min

    def compute_radiance_diff(angles):
        # Score each candidate angle: total radiance on the half-plane facing
        # the candidate direction minus the total on the opposite half-plane.
        # A bright side aligned with the direction yields a large score.
        num = angles.shape[-1]
        dirs = torch.cat([torch.cos(angles), torch.sin(angles)]).T
        pos_dir = grid_pos.repeat(num, 1, 1, 1)
        pos_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) > 0
        neg_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) < 0
        samples_radiance = _radiance.repeat(1,num,1,1)
        radiance_diff = (samples_radiance*pos_mask[None] - samples_radiance*neg_mask[None]).sum([2,3])
        return radiance_diff

    angle_min, angle_max = 0, 2*torch.pi
    # Unit 2D direction of each pixel from the image center (z dropped);
    # used to split the image into half-planes per candidate angle.
    grid_pos = Fn.normalize(get_positions(h,w,10)[...,:2], dim=-1, eps=1e-6).to(_radiance)
    # Refine the angular bracket until it is narrower than 2 degrees (pi/90).
    while(((angle_max - angle_min) > (torch.pi/90))):
        angles = evenly_sample(_num_samples, angle_min, angle_max)[None].to(_radiance)
        diffs = compute_radiance_diff(angles).mean(0)
        # Shrink the bracket around the best-scoring angle.
        # NOTE(review): angle_min is reassigned before angle_max is computed,
        # so the second line uses the already-updated bracket width — confirm
        # this asymmetric shrink is intentional.
        angle_min = angles[:,diffs.argmax()].item() - (angle_max - angle_min)/_num_samples
        angle_max = angles[:,diffs.argmax()].item() + (angle_max - angle_min)/_num_samples
    light_angle = angles[:, diffs.argmax()]
    return torch.tensor([torch.cos(light_angle), torch.sin(light_angle)]).to(_radiance)
def find_light_dir(raw_irradiance, light):
    """Return a copy of `light` whose direction is estimated from an irradiance map.

    The in-plane (x, y) direction comes from opt_light_dir applied to a
    grayscale, smoothed irradiance; a fixed z-component of sqrt(0.5) is
    appended. The input `light` object is not modified.
    """
    gray = v2.functional.rgb_to_grayscale(raw_irradiance)
    smoothed = process_irradiance(gray)
    planar = opt_light_dir(smoothed)
    z_component = torch.tensor([0.5**0.5]).to(planar)
    full_dir = torch.cat([planar, z_component])
    estimated_light = copy.deepcopy(light)
    estimated_light.direction = full_dir
    return estimated_light
@register("chord")
class Chord(Base):
    """Chained intrinsic-decomposition model built on a Stable Diffusion UNet.

    Each output map in `self.chain` (e.g. basecolor, normal) is produced by a
    single-step pass through a shared UNet whose conv_in/conv_out and
    first-down/last-up blocks are swapped per chain step ("LEGO"
    conditioning). Physically derived intermediates (approx irradiance,
    approx roughness/metalness) are computed mid-chain and fed to later steps.
    """

    def setup(self):
        """Build the forward chain, per-key LEGO UNet blocks, and the prior light."""
        # Define forward chain
        self.chain_type = self.config.get("chain_type", "chord")
        self.chain = self.config.get("chain_library", {})[self.chain_type]
        self.prompts = self.config.get("rgbx_prompts", {})
        # Grid-search step sizes used by compute_approxRouMet.
        self.roughness_step = self.config.get("roughness_step", 10)
        self.metallic_step = self.config.get("metallic_step", 0.2)
        self.sd = make(self.config.stable_diffusion.name, self.config.stable_diffusion)
        self.dtype = self.sd.dtype
        self.device = self.sd.device
        # LEGO-conditioning: one conv_in per input key, and one
        # conv_out / first down-block / last up-block per output key,
        # each cloned from the corresponding base UNet layer.
        self.sd.unet.ConvIns = nn.ModuleDict()
        self.sd.unet.ConvOuts = nn.ModuleDict()
        self.sd.unet.FirstDownBlocks = nn.ModuleDict()
        self.sd.unet.LastUpBlocks = nn.ModuleDict()
        # Chain values look like "key1_key2_..._ids"; collect every distinct token.
        for key in list(set("_".join(self.chain.values()).split("_"))) + ["noise"]:
            # Skip the trailing id tokens (strings made of '0'/'1' flags).
            if "0" in key or "1" in key: continue
            self.sd.unet.ConvIns[key] = nn.Conv2d(4, 320, 3, 1 , 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvIns[key].load_state_dict(self.sd.unet.conv_in.state_dict())
        for kout in list(set(self.chain.keys())):
            self.sd.unet.ConvOuts[kout] = nn.Conv2d(320, 4, 3, 1 , 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvOuts[kout].load_state_dict(self.sd.unet.conv_out.state_dict())
            self.sd.unet.LastUpBlocks[kout] = copy.deepcopy(self.sd.unet.up_blocks[-1]).to(self.device)
            self.sd.unet.FirstDownBlocks[kout] = copy.deepcopy(self.sd.unet.down_blocks[0]).to(self.device)
        # Only the LEGO layers are trainable clones; put them in train mode.
        self.sd.unet.ConvIns.train()
        self.sd.unet.ConvOuts.train()
        self.sd.unet.FirstDownBlocks.train()
        self.sd.unet.LastUpBlocks.train()
        # Neutralize the UNet's own conv_in/conv_out — the per-key LEGO
        # replacements are applied manually in forward().
        self.sd.unet.conv_in = dummy_module()
        self.sd.unet.conv_out = dummy_module()
        # Load Lights
        if self.config.get("prior_light", None) is None:
            self.prior_light = make("point-light", {"position": [0, 0, 10]})
        else:
            self.prior_light = make(self.config.prior_light.name, self.config.prior_light)
        # Init Embeddings — per-key text embedding cache filled lazily.
        self.text_emb = {}

    # Eq.3
    def compute_approxIrr(self, render, basecolor):
        """Approximate irradiance as render / basecolor in linear RGB (Eq. 3)."""
        approxIrr = safe_01_div.apply(srgb_to_rgb(render), srgb_to_rgb(basecolor))
        return tone_gamma(approxIrr)

    # Eq.6
    @torch.no_grad()
    def compute_approxRouMet(self, render, maps, seperate=False, light=None):
        """Grid-search per-pixel roughness/metallic that best reproduce `render` (Eq. 6).

        Renders every (roughness, metallic) combination on a sample grid using
        the estimated light and keeps, per pixel, the pair whose rendering is
        closest (L1 over channels) to the input render.

        Args:
            render: (bs, 3, h, w) sRGB render.
            maps: dict holding at least 'basecolor', 'normal', 'approxIrr'.
            seperate: (sic — name kept for compatibility) if True, return
                (roughness, metallic); otherwise a packed 3-channel map
                [roughness, metallic, zeros].
            light: optional light override; defaults to a direction estimated
                from maps['approxIrr'].
        """
        render = srgb_to_rgb(render)
        bs, _, h, w = render.shape
        light = find_light_dir(maps['approxIrr'], self.prior_light) if light is None else light
        # light.direction = estimate_light_dir(render, maps)
        pos = get_positions(h, w, 10).to(self.device)
        cameras = torch.tensor([0, 0, 10.0]).to(self.device)
        # sample grid: roughness in [25/255, 225/255], metallic in [0, 1]
        r_samples = torch.arange(25, 225+self.roughness_step, self.roughness_step) / 255
        m_samples = torch.arange(0., 1.+self.metallic_step, self.metallic_step)
        grid_maps = {} # change map size into: gs, bs, h, w, c
        grid_maps['basecolor'] = maps['basecolor'][None].permute(0,1,3,4,2)
        grid_maps['normal'] = maps['normal'][None].permute(0,1,3,4,2)
        # Cartesian product of the two sample sets, one grid entry per pair.
        r_values = r_samples[:,None].repeat(1,len(m_samples)).reshape(-1,1,1,1,1).to(maps['basecolor'])
        m_values = m_samples[None].repeat(len(r_samples),1).reshape(-1,1,1,1,1).to(maps['basecolor'])
        # split into chunks to avoid OOM
        chunk_size = 25
        rgb_list, r_list, m_list = [], [], []
        for _r, _m in zip(torch.split(r_values, chunk_size), torch.split(m_values, chunk_size)):
            grid_maps['roughness'], grid_maps['metallic'] = _r, _m
            _rgb = self.compute_render(grid_maps, cameras, pos, light)
            # Per-pixel L1 error of each grid entry against the target render.
            loss = (render[None].permute(0,1,3,4,2) - _rgb).abs().sum(-1,keepdim=True)
            min_idx = loss.argmin(dim=0,keepdim=True)
            # roughness/metallic are constant per grid entry (shape g,1,1,1,1),
            # so gathering with the per-pixel argmin picks that entry's scalar.
            r_list.append(torch.gather(grid_maps['roughness'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            m_list.append(torch.gather(grid_maps['metallic'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            rgb_list.append(torch.gather(_rgb, 0, min_idx.repeat(1,1,1,1,3)))
        # Second reduction: pick, per pixel, the best chunk winner overall.
        rgb = torch.cat(rgb_list).permute(0,1,4,2,3)
        roughness = torch.cat(r_list).permute(0,1,4,2,3)
        metallic = torch.cat(m_list).permute(0,1,4,2,3)
        loss = (render[None] - rgb).abs().sum(2,keepdim=True)
        roughness = torch.gather(roughness, 0, loss.argmin(dim=0,keepdim=True))[0]
        metallic = torch.gather(metallic, 0, loss.argmin(dim=0,keepdim=True))[0]
        torch.cuda.empty_cache()
        if seperate:
            return roughness, metallic
        else:
            # Pack as a 3-channel image: [roughness, metallic, 0].
            out = torch.cat([roughness, metallic, torch.zeros_like(roughness)], dim=1)
            return out

    @torch.no_grad()
    def compute_render(self, maps, camera_position, pos, light):
        '''
        Shade the maps with a Cook-Torrance-style microfacet BRDF
        (Schlick Fresnel, GGX distribution, Schlick-GGX geometry) plus a
        constant ambient term.

        maps: gs, bs, h, w, c (gs: the number of grids)
        camera_position: camera location, broadcast against `pos`.
        pos: per-pixel world positions.
        light: callable returning (irradiance, light_direction) for `pos`.
        '''
        def cos(x, y):
            # Dot product along the channel axis, clamped to [0, 1].
            return torch.clamp((x*y).sum(-1, keepdim=True), min=0, max=1)
        # pre-process
        albedo = srgb_to_rgb(maps['basecolor'])
        normal = maps['normal'].clone()
        # Swap the first two normal channels (x <-> y) — presumably to match
        # the renderer's axis convention; confirm against the normal-map source.
        normal[..., :2] = normal[..., [1,0]]
        N = Fn.normalize((normal - 0.5) * 2.0, dim=-1, eps=1e-6)
        roughness = maps['roughness']
        metallic = maps['metallic']
        V = Fn.normalize(camera_position - pos, dim=-1, eps=1e-6).repeat(1,1,1,1,1).to(self.device)
        irradiance, L = light(pos)
        irradiance, L = irradiance.repeat(1,1,1,1,1).to(self.device), L.repeat(1,1,1,1,1).to(self.device)
        # rendering
        H = Fn.normalize(L+V, dim=-1, eps=1e-6)
        # Base reflectance: 0.04 for dielectrics, lerped toward albedo for metals.
        f0 = torch.ones_like(albedo).to(self.device) * 0.04
        F0 = torch.lerp(f0, albedo, metallic)
        F = fresnelSchlick(cos(H,V), F0)
        ks = F
        # Energy split: Fresnel term is the specular fraction; metals have no diffuse.
        diffuse = (1-ks) * albedo / torch.pi
        diffuse *= 1-metallic
        NDF = DistributionGGX(cos(N,H), roughness)
        G = GeometrySchlickGGX(cos(N,L), roughness) * GeometrySchlickGGX(cos(N,V), roughness)
        numerator = NDF * G * F
        # 1e-3 avoids division by zero at grazing angles.
        denominator = 4.0 * cos(N,V) * cos(N,L) + 1e-3
        specular = numerator / denominator
        ambient = 0.3 * albedo
        rgb = (diffuse + specular) * irradiance * cos(N,L) + ambient
        return rgb

    def forward(self, maps:dict):
        """Run the forward chain; return a dict of decoded maps keyed by chain output.

        Each chain entry maps an output key to "in1_in2..._ids" where `ids` is
        a string of per-input flags: '0' = encode the image from `maps`,
        anything else = reuse a previously predicted latent.
        """
        # prepare
        bs = maps['render'].shape[0]
        # Single-step denoising: only one scheduler timestep is used.
        self.sd.scheduler.set_timesteps(1)
        t = self.sd.scheduler.timesteps[0]
        # chain processing
        pred, pred_latent, arxiv_latent = {}, {}, {}
        for kout, info in self.chain.items():
            info = info.split("_")
            keys, ids = info[:-1], info[-1]
            # Swap active LEGO blocks
            self.sd.unet.down_blocks[0] = self.sd.unet.FirstDownBlocks[kout]
            self.sd.unet.up_blocks[-1] = self.sd.unet.LastUpBlocks[kout]
            # Eq.2, summing input latents
            in_latent = 0
            for k, i in zip(keys, ids):
                if i=="0":
                    # Encode (and cache) the ground-truth input image once.
                    if not k in arxiv_latent.keys(): arxiv_latent[k] = self.sd.encode_imgs_deterministic(maps[k])
                    zx = arxiv_latent[k]
                else:
                    zx = pred_latent[k]
                in_latent += self.sd.unet.ConvIns[k](zx)
            in_latent = in_latent / len(keys)
            # single-step denoising
            embs = self.produce_embeddings(kout, bs)
            out_latent = self.sd.unet(in_latent, t, **embs)[0]
            out_latent = self.sd.unet.ConvOuts[kout](out_latent)
            pred_latent[kout] = self.sd.scheduler.step(out_latent, t, torch.zeros_like(zx)).pred_original_sample
            pred[kout] = self.sd.decode_latents(pred_latent[kout]).float()
            # compute intermediate representations for downstream chain steps
            if self.chain_type in ["chord"] and kout == "basecolor":
                pred['approxIrr'] = self.compute_approxIrr(maps['render'], pred['basecolor'])
                pred_latent['approxIrr'] = self.sd.encode_imgs_deterministic(pred['approxIrr'])
            if self.chain_type in ["chord"] and kout == "normal":
                pred['approxRM'] = self.compute_approxRouMet(maps['render'], pred, seperate=False)
                pred_latent['approxRM'] = self.sd.encode_imgs_deterministic(pred['approxRM'])
        return pred

    @torch.no_grad()
    def produce_embeddings(self, key, batch_size):
        """Return cached text-prompt embeddings for `key`, expanded to `batch_size`."""
        if key not in self.text_emb.keys():
            self.text_emb[key] = self.sd.encode_text(self.prompts[key], "max_length")
        prompt_emb = self.text_emb[key].expand(batch_size, -1, -1)
        return { "encoder_hidden_states": prompt_emb }