ksangk committed
Commit a846205 · 1 Parent(s): e5992ee
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ output
LICENSE.txt ADDED
@@ -0,0 +1,99 @@
+ Ubisoft Machine Learning License (Research-Only - Copyleft)
+
+ This license governs the use, reproduction, and distribution of the Licensed
+ Materials, including AI Models and associated source code, for the sole purpose
+ of scientific research. By accessing, downloading, or using the Licensed
+ Materials, you hereby agree to be bound by this [Ubisoft Machine Learning
+ License (Research-Only - Copyleft)] agreement (hereinafter the “License”).
+
+ 1. Licensed Materials
+
+ - AI Models
+ - Source Code
+
+ 2. Definitions
+
+ “Licensed Materials”: Refers to the AI Models and/or Source Code licensed under
+ this agreement.
+ “Source Code” means the preferred form of the work for making modifications to
+ it, i.e., text written in a human-readable programming language.
+ “Object Code” means any non-source form of a work.
+ “AI Model” means any machine-learning-based assembly or assemblies (including
+ checkpoints), consisting of learnt weights and parameters (including optimizer
+ states), corresponding to the model architecture as embodied in the Source Code.
+ “Output” means the results of operating an AI Model as embodied in
+ informational content resulting therefrom.
+ “Derivative”: Any work derived from or based upon the Licensed Materials,
+ including modifications.
+ “Permitted Purpose”: Use for academic or research purposes only. Commercial
+ use is strictly prohibited.
+ “Distribution”: Any sharing of the Licensed Materials or Derivatives with third
+ parties, including hosting as a service.
+ “Licensor”: The rights holder or authorized entity granting this License.
+ “You”: The individual or entity receiving and exercising rights under this
+ License.
+
+ 3. Grant of Rights
+
+ Subject to compliance with the terms of this License, You are granted a
+ worldwide, royalty-free, non-exclusive License to use, study, reproduce,
+ modify, and distribute the Licensed Materials and Derivatives solely for the
+ Permitted Purpose. As between You and Licensor, Licensor claims no rights in
+ the Outputs You generate using the AI Models in accordance with the
+ Permitted Purpose.
+
+ 4. Distribution of Licensed Materials and Derivatives
+
+ Any Distribution of the Licensed Materials or their Derivatives shall be
+ licensed under the exact same terms as this License. Redistributions shall
+ include this License and retain all notices of author attribution, and all
+ modifications shall be clearly marked.
+
+ 5. Use Restrictions
+
+ You shall not use the Licensed Materials or their Derivatives for:
+ - any purpose other than the Permitted Purpose, including commercial
+   purposes such as using the Licensed Materials in any activity intended for
+   commercial advantage or monetary compensation, directly or indirectly;
+ - weaponry, warfare, military applications, surveillance, or any activity that
+   may cause harm or violate human rights;
+ - engaging in or enabling fully automated decision-making that may adversely
+   impact a natural person's legal rights;
+ - providing medical advice or making clinical decisions;
+ - generating content that promotes or incites hatred, violence, discrimination,
+   or harm based on race, ethnicity, religion, gender, sexual orientation, or
+   any other protected characteristic;
+ - generating content that includes depictions of sexual abuse, sexual
+   violence, explicit pornography, or any form of non-consensual acts, and/or
+   generating content that includes depictions of child nudity, child
+   pornography, or any form of child exploitation.
+
+ 6. Disclaimer of Warranty
+
+ THE LICENSED MATERIALS ARE PROVIDED “AS IS” AND “AS AVAILABLE” WITHOUT
+ WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
+ NON-INFRINGEMENT, CORRECTNESS, ACCURACY, OR RELIABILITY. THE LICENSOR DISCLAIMS
+ ALL LIABILITY FOR DAMAGES RESULTING FROM THE USE OR INABILITY TO USE THE
+ LICENSED MATERIALS. THE USE OF THE LICENSED MATERIALS AND ANY OUTPUTS YOU MAY
+ GENERATE SHALL BE AT YOUR OWN RISK.
+
+ 7. Termination
+
+ This License terminates automatically if You violate any of its terms. Upon
+ termination, You shall cease all use and distribution of the Licensed
+ Materials and their Derivatives.
+
+ 8. Governing Law
+
+ The validity of this Agreement and any of its terms and provisions, as well as
+ the rights and duties of the parties hereunder, shall be governed, interpreted,
+ and enforced in accordance with the laws of France.
+
+ 9. Miscellaneous
+
+ If any provision of this License is held to be invalid, illegal, or
+ unenforceable, the remaining provisions shall be unaffected thereby and remain
+ valid as if such provision had not been set forth herein.
+
+ Copyright (C) 2025 UBISOFT ENTERTAINMENT. All Rights Reserved.
app.py ADDED
@@ -0,0 +1,141 @@
+ import gradio as gr
+ import os
+ import numpy as np
+ from PIL import Image
+ import torch
+ import copy
+ from omegaconf import OmegaConf
+ from torchvision.transforms import v2
+ from torchvision.transforms.functional import to_pil_image
+
+ from chord import ChordModel
+ from chord.module import make
+ from chord.util import get_positions, rgb_to_srgb
+
+ EXAMPLES_USECASE_1 = [
+     [f"examples/generated/{f}"]
+     for f in sorted(os.listdir("examples/generated"))
+ ]
+ EXAMPLES_USECASE_2 = [
+     [f"examples/in_the_wild/{f}"]
+     for f in sorted(os.listdir("examples/in_the_wild"))
+ ]
+ EXAMPLES_USECASE_3 = [
+     [f"examples/specular/{f}"]
+     for f in sorted(os.listdir("examples/specular"))
+ ]
+
+ MODEL_OBJ = None
+ def load_model(ckpt_path):
+     print("Loading model from:", ckpt_path)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     config = OmegaConf.load("config/chord.yaml")
+     model = ChordModel(config)
+     ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+     model.load_state_dict(ckpt["state_dict"])
+     model.eval()
+     model.to(device)
+     return model
+
+ def run_model(model, img: Image.Image):
+     to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
+     image = to_tensor(img).to(next(model.parameters()).device)
+     x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
+     with torch.no_grad(), torch.autocast(device_type="cuda"):
+         output = model(x)
+     output.update({"input": image})
+     return output
+
+ def relit(model, maps):
+     maps['metallic'] = maps.get('metalness', torch.zeros_like(maps['basecolor']))
+     device = next(model.parameters()).device
+     h, w = maps["basecolor"].shape[-2:]
+     light = make("point-light", {"position": [0, 0, 10]}).to(device)
+     pos = get_positions(h, w, 10).to(device)
+     camera = torch.tensor([0, 0, 10.0]).to(device)
+     for key in maps:
+         if maps[key].dim() == 3:
+             maps[key] = maps[key].unsqueeze(0)
+         maps[key] = maps[key].permute(0, 2, 3, 1)  # BxCxHxW -> BxHxWxC
+     rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0, 3, 1, 2)  # GxBxHxWxC -> BxCxHxW
+     return torch.clamp(rgb_to_srgb(rgb), 0, 1)
+
+ def inference(img, ckpt_path):
+     global MODEL_OBJ
+
+     if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != ckpt_path:
+         MODEL_OBJ = load_model(ckpt_path)
+         MODEL_OBJ._ckpt = ckpt_path  # remember which checkpoint is loaded
+
+     if img is None:
+         return None, None, None, None, None
+
+     ori_h, ori_w = img.size[1], img.size[0]
+     out = run_model(MODEL_OBJ, img)
+     maps = copy.deepcopy(out)
+     rendered = relit(MODEL_OBJ, maps)
+     resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True)
+     return (
+         to_pil_image(resize_back(out["basecolor"]).squeeze(0)),
+         to_pil_image(resize_back(out["normal"]).squeeze(0)),
+         to_pil_image(resize_back(out["roughness"]).squeeze(0)),
+         to_pil_image(resize_back(out["metalness"]).squeeze(0)),
+         to_pil_image(resize_back(rendered).squeeze(0)),
+     )
+
+ with gr.Blocks(title="Chord") as demo:
+
+     gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**")
+     ckpt_path = gr.Textbox(
+         label="Model Checkpoint Path",
+         value="chord_v1.ckpt",
+         placeholder="Path to your model checkpoint",
+     )
+     gr.Markdown("Upload an image or select an example to estimate PBR channels and render the result under custom lighting.")
+
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(type="pil", label="Input Image", height=512)
+
+             gr.Markdown("### Example Inputs — Generated Textures")
+             gr.Examples(
+                 examples=EXAMPLES_USECASE_1,
+                 inputs=[input_img],
+                 label="Examples (Generated Textures)"
+             )
+
+             gr.Markdown("### Example Inputs — In-the-Wild Photographs")
+             gr.Examples(
+                 examples=EXAMPLES_USECASE_2,
+                 inputs=[input_img],
+                 label="Examples (In-the-Wild Photographs)"
+             )
+
+             gr.Markdown("### Example Inputs — Specular Textures")
+             gr.Examples(
+                 examples=EXAMPLES_USECASE_3,
+                 inputs=[input_img],
+                 label="Examples (Specular Textures)"
+             )
+
+             run_button = gr.Button("Run Estimation")
+
+         with gr.Column():
+             gr.Markdown("### Predicted Channels")
+             basecolor_out = gr.Image(label="Basecolor", height=512)
+             normal_out = gr.Image(label="Normal", height=512)
+             roughness_out = gr.Image(label="Roughness", height=512)
+             metallic_out = gr.Image(label="Metalness", height=512)
+
+             gr.Markdown("### Relit Output")
+             render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)
+
+     run_button.click(
+         inference,
+         inputs=[input_img, ckpt_path],
+         outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()
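
The same pipeline can also be driven headlessly for a quick sanity check. A minimal sketch, assuming `chord_v1.ckpt` and the example image path exist locally (the image filename here is hypothetical); it reuses `load_model`, `run_model`, and `relit` exactly as defined above:

# headless_check.py — usage sketch, not part of the commit
from PIL import Image
import app  # building the Blocks UI at import time is harmless; launch() is behind __main__

model = app.load_model("chord_v1.ckpt")
img = Image.open("examples/generated/sample.png").convert("RGB")  # hypothetical file
out = app.run_model(model, img)          # dict with basecolor/normal/roughness/metalness (+ raw "input")
rendered = app.relit(model, dict(out))   # sRGB re-render under the centered point light, BxCxHxW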
chord/__init__.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ import torch.nn as nn
+ from chord.module import make
+ from chord.module.chord import post_decoder
+
+ class ChordModel(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.model = make(config.model.name, config.model)
+
+     def forward(self, x: torch.Tensor):
+         x = {"render": x}
+         pred = self.model(x)
+         return post_decoder(pred)
chord/io.py ADDED
@@ -0,0 +1,80 @@
+ import torch
+ import imageio.v3 as imageio
+ import numpy as np
+ import warnings
+ import os
+
+ import torchvision.transforms.functional as F
+
+ def read_image(filename: str, out: torch.Tensor = None) -> torch.Tensor:
+     '''
+     Read a local image file into a float tensor (pixel values normalized to [0, 1], CxHxW).
+
+     Args:
+         filename: Image file path.
+         out: Fill in this tensor rather than returning a new tensor, if provided.
+
+     Returns:
+         Loaded image tensor.
+     '''
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")  # ignore PIL's warning when fp16 images are read as fp32
+         img: np.ndarray = imageio.imread(filename)
+
+     # Convert the image array to a float tensor according to its data type
+     res = None
+     if img.dtype == np.uint8:
+         img = img.astype(np.float32) / 255.0
+     elif img.dtype == np.uint16 or img.dtype == np.int32:
+         img = img.astype(np.float32) / 65535.0
+     else:
+         raise ValueError(f'Unrecognized image pixel value type: {img.dtype}')
+     if img.ndim == 2:
+         res = torch.from_numpy(img).unsqueeze(0)  # 1xHxW for grayscale images
+     elif img.ndim == 3:
+         res = torch.from_numpy(img).movedim(2, 0)[:3]  # HxWxC to CxHxW
+     else:
+         raise ValueError(f'Unrecognized image dimension: {img.shape}')
+
+     if out is None:
+         return res
+     out.copy_(res)
+
+ def create_img(img: torch.Tensor):
+     '''
+     Convert a tensor to a PIL image.
+
+     Args:
+         img: Image tensor CxHxW. Squeezed if BxCxHxW and B == 1.
+
+     Returns:
+         PIL image.
+     '''
+     if img.dim() == 4:
+         assert img.shape[0] == 1
+         img = img.squeeze(0)
+
+     if img.shape[0] == 4:
+         out_img = F.to_pil_image(img, mode="CMYK")
+         out_img = out_img.convert('RGB')
+     elif img.shape[0] == 3:
+         out_img = F.to_pil_image(img, mode="RGB")
+     elif img.shape[0] == 1:
+         out_img = F.to_pil_image(img, mode="L")
+     else:
+         raise ValueError("Unsupported image dimension.")
+     return out_img
+
+ def save_maps(path: str, maps: dict):
+     '''
+     Save SVBRDF maps to a given path.
+
+     Args:
+         path: Output path.
+         maps: Named maps of tensor images.
+     '''
+     if not os.path.exists(path):
+         os.makedirs(path)
+     for name, image in maps.items():
+         out_img = create_img(image)
+         out_img.save(os.path.join(path, name + ".png"))
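
A small usage sketch for these helpers (the input path is hypothetical): read an image into a CxHxW float tensor, then write a set of named maps back out as PNGs:

from chord.io import read_image, save_maps

img = read_image("examples/generated/sample.png")  # hypothetical path; float tensor in [0, 1], CxHxW
print(img.shape, img.min().item(), img.max().item())
save_maps("output", {"input": img})                # writes output/input.png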
chord/module/__init__.py ADDED
@@ -0,0 +1,19 @@
+ modules = {}
+
+ def register(name):
+     def decorator(cls):
+         modules[name] = cls
+         return cls
+     return decorator
+
+
+ def make(name, config):
+     model = modules[name](config)
+     return model
+
+
+ from . import (
+     light,
+     stable_diffusion,
+     chord,
+ )
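
This is the usual decorator registry pattern: `register(name)` stores a class in the `modules` dict and `make(name, config)` instantiates it by key. A minimal self-contained sketch with a hypothetical module (note that importing `chord.module` also pulls in the heavier submodules listed at the bottom of the file):

from chord.module import register, make

@register("toy")  # "toy" is a hypothetical module name, for illustration only
class Toy:
    def __init__(self, config):
        self.scale = config.get("scale", 1.0)

toy = make("toy", {"scale": 2.0})
print(toy.scale)  # 2.0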
chord/module/base.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+ import torch.nn as nn
+
+ class Base(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.setup()
+
+     def setup(self):
+         raise NotImplementedError
+
chord/module/chord.py ADDED
@@ -0,0 +1,281 @@
+ import copy
+ import torch
+ from torch import nn
+ import torch.nn.functional as Fn
+ from torchvision.transforms import v2
+
+ from . import register, make
+ from .base import Base
+
+ from chord.util import fresnelSchlick, GeometrySchlickGGX, DistributionGGX
+ from chord.util import srgb_to_rgb, tone_gamma, get_positions, safe_01_div
+
+ class dummy_module(nn.Module):
+     def forward(self, x): return x
+
+ def post_decoder(out_dict):
+     out = {}
+     for key in out_dict.keys():
+         if key.startswith("approx"): continue
+         elif key == "normal":
+             out[key] = Fn.normalize(2. * out_dict[key] - 1., dim=1) / 2. + 0.5
+         elif key == "rou_met":
+             out['roughness'], out['metalness'] = out_dict['rou_met'][:, 0], out_dict['rou_met'][:, 1]
+         else: out[key] = out_dict[key]
+     return out
+
+ def process_irradiance(radiance, kernel_size=25, res=64):
+     """
+     Process the irradiance using PyTorch, equivalent to the original OpenCV-based function.
+
+     Args:
+         radiance (torch.Tensor): Input radiance tensor (B, 1, H, W).
+         kernel_size (int): Size of the kernel for the median blur.
+         res (int): Target resolution for resizing the image.
+
+     Returns:
+         torch.Tensor: Processed radiance tensor (B, 1, res, res).
+     """
+     # Ensure the input radiance is a 4D tensor (B, 1, H, W)
+     assert radiance.shape[1] == 1 and radiance.dim() == 4, f"Invalid radiance shape, got {radiance.shape}"
+     # Resize to low resolution
+     resizer = v2.Resize(size=res, antialias=True)
+     radiance = resizer(radiance)
+
+     # Define an 11x11 averaging kernel
+     kernel = torch.ones((1, 1, 11, 11), dtype=torch.float32).to(radiance) / 121.0
+     # Apply convolution (averaging filter)
+     radiance = Fn.pad(radiance, (5,) * 4, mode="reflect")  # pad for edge handling
+     radiance = Fn.conv2d(radiance, kernel, padding=0)  # reflect padding above preserves the spatial size
+
+     # Clamp values and scale to [0, 255] for median filtering
+     radiance = torch.clamp(radiance * 255, 0, 255)
+
+     # Apply median filtering
+     padded_radiance = Fn.pad(radiance, (kernel_size // 2,) * 4, mode="reflect")  # pad for edge handling
+     unfolded = Fn.unfold(padded_radiance, kernel_size)  # extract patches
+     radiance = torch.median(unfolded, dim=1).values.view(radiance.shape)  # median of patches
+
+     # Normalize to [0, 1]
+     rad_min, rad_max = radiance.amin([2, 3], keepdim=True), radiance.amax([2, 3], keepdim=True)
+     radiance = (radiance - rad_min) / (rad_max - rad_min)
+     return radiance
+
+ def opt_light_dir(_radiance, _num_samples=6):
+     '''
+     _radiance: (bs, 1, h, w)
+     '''
+     assert _radiance.shape[1] == 1 and _radiance.dim() == 4
+     bs, _, h, w = _radiance.shape
+
+     def evenly_sample(_num_samples, min=0, max=2 * torch.pi):
+         # returns a tensor of shape (_num_samples + 1,) with evenly spaced angles in [min, max]
+         return torch.tensor(range(_num_samples + 1)) * (max - min) / _num_samples + min
+
+     def compute_radiance_diff(angles):
+         num = angles.shape[-1]
+         dirs = torch.cat([torch.cos(angles), torch.sin(angles)]).T
+         pos_dir = grid_pos.repeat(num, 1, 1, 1)
+         pos_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) > 0
+         neg_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) < 0
+         samples_radiance = _radiance.repeat(1, num, 1, 1)
+         radiance_diff = (samples_radiance * pos_mask[None] - samples_radiance * neg_mask[None]).sum([2, 3])
+         return radiance_diff
+
+     angle_min, angle_max = 0, 2 * torch.pi
+     grid_pos = Fn.normalize(get_positions(h, w, 10)[..., :2], dim=-1, eps=1e-6).to(_radiance)
+     while (angle_max - angle_min) > (torch.pi / 90):
+         angles = evenly_sample(_num_samples, angle_min, angle_max)[None].to(_radiance)
+         diffs = compute_radiance_diff(angles).mean(0)
+         angle_min = angles[:, diffs.argmax()].item() - (angle_max - angle_min) / _num_samples
+         angle_max = angles[:, diffs.argmax()].item() + (angle_max - angle_min) / _num_samples
+
+     light_angle = angles[:, diffs.argmax()]
+     return torch.tensor([torch.cos(light_angle), torch.sin(light_angle)]).to(_radiance)
+
+
+ def find_light_dir(raw_irradiance, light):
+     raw_irradiance = v2.functional.rgb_to_grayscale(raw_irradiance)
+     irradiance = process_irradiance(raw_irradiance)
+     dir = opt_light_dir(irradiance)
+     dir = torch.cat([dir, torch.tensor([0.5 ** 0.5]).to(dir)])
+     _light = copy.deepcopy(light)
+     _light.direction = dir
+     return _light
+
+ @register("chord")
+ class Chord(Base):
+     def setup(self):
+         # Define forward chain
+         self.chain_type = self.config.get("chain_type", "chord")
+         self.chain = self.config.get("chain_library", {})[self.chain_type]
+         self.prompts = self.config.get("rgbx_prompts", {})
+         self.roughness_step = self.config.get("roughness_step", 10)
+         self.metallic_step = self.config.get("metallic_step", 0.2)
+
+         self.sd = make(self.config.stable_diffusion.name, self.config.stable_diffusion)
+         self.dtype = self.sd.dtype
+         self.device = self.sd.device
+
+         # LEGO-conditioning
+         self.sd.unet.ConvIns = nn.ModuleDict()
+         self.sd.unet.ConvOuts = nn.ModuleDict()
+         self.sd.unet.FirstDownBlocks = nn.ModuleDict()
+         self.sd.unet.LastUpBlocks = nn.ModuleDict()
+         for key in list(set("_".join(self.chain.values()).split("_"))) + ["noise"]:
+             if "0" in key or "1" in key: continue
+             self.sd.unet.ConvIns[key] = nn.Conv2d(4, 320, 3, 1, 1, device=self.device, dtype=self.dtype)
+             self.sd.unet.ConvIns[key].load_state_dict(self.sd.unet.conv_in.state_dict())
+         for kout in list(set(self.chain.keys())):
+             self.sd.unet.ConvOuts[kout] = nn.Conv2d(320, 4, 3, 1, 1, device=self.device, dtype=self.dtype)
+             self.sd.unet.ConvOuts[kout].load_state_dict(self.sd.unet.conv_out.state_dict())
+             self.sd.unet.LastUpBlocks[kout] = copy.deepcopy(self.sd.unet.up_blocks[-1]).to(self.device)
+             self.sd.unet.FirstDownBlocks[kout] = copy.deepcopy(self.sd.unet.down_blocks[0]).to(self.device)
+         self.sd.unet.ConvIns.train()
+         self.sd.unet.ConvOuts.train()
+         self.sd.unet.FirstDownBlocks.train()
+         self.sd.unet.LastUpBlocks.train()
+         self.sd.unet.conv_in = dummy_module()
+         self.sd.unet.conv_out = dummy_module()
+
+         # Load lights
+         if self.config.get("prior_light", None) is None:
+             self.prior_light = make("point-light", {"position": [0, 0, 10]})
+         else:
+             self.prior_light = make(self.config.prior_light.name, self.config.prior_light)
+
+         # Init embeddings
+         self.text_emb = {}
+     # Eq.3
+     def compute_approxIrr(self, render, basecolor):
+         approxIrr = safe_01_div.apply(srgb_to_rgb(render), srgb_to_rgb(basecolor))
+         return tone_gamma(approxIrr)
+     # Eq.6
+     @torch.no_grad()
+     def compute_approxRouMet(self, render, maps, separate=False, light=None):
+         render = srgb_to_rgb(render)
+         bs, _, h, w = render.shape
+         light = find_light_dir(maps['approxIrr'], self.prior_light) if light is None else light
+         # light.direction = estimate_light_dir(render, maps)
+         pos = get_positions(h, w, 10).to(self.device)
+         cameras = torch.tensor([0, 0, 10.0]).to(self.device)
+
+         # sample grid
+         r_samples = torch.arange(25, 225 + self.roughness_step, self.roughness_step) / 255
+         m_samples = torch.arange(0., 1. + self.metallic_step, self.metallic_step)
+
+         grid_maps = {}  # change map size into: gs, bs, h, w, c
+         grid_maps['basecolor'] = maps['basecolor'][None].permute(0, 1, 3, 4, 2)
+         grid_maps['normal'] = maps['normal'][None].permute(0, 1, 3, 4, 2)
+         r_values = r_samples[:, None].repeat(1, len(m_samples)).reshape(-1, 1, 1, 1, 1).to(maps['basecolor'])
+         m_values = m_samples[None].repeat(len(r_samples), 1).reshape(-1, 1, 1, 1, 1).to(maps['basecolor'])
+         # split into chunks to avoid OOM
+         chunk_size = 25
+         rgb_list, r_list, m_list = [], [], []
+         for _r, _m in zip(torch.split(r_values, chunk_size), torch.split(m_values, chunk_size)):
+             grid_maps['roughness'], grid_maps['metallic'] = _r, _m
+             _rgb = self.compute_render(grid_maps, cameras, pos, light)
+             loss = (render[None].permute(0, 1, 3, 4, 2) - _rgb).abs().sum(-1, keepdim=True)
+             min_idx = loss.argmin(dim=0, keepdim=True)
+             r_list.append(torch.gather(grid_maps['roughness'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
+             m_list.append(torch.gather(grid_maps['metallic'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
+             rgb_list.append(torch.gather(_rgb, 0, min_idx.repeat(1, 1, 1, 1, 3)))
+         rgb = torch.cat(rgb_list).permute(0, 1, 4, 2, 3)
+         roughness = torch.cat(r_list).permute(0, 1, 4, 2, 3)
+         metallic = torch.cat(m_list).permute(0, 1, 4, 2, 3)
+         loss = (render[None] - rgb).abs().sum(2, keepdim=True)
+         roughness = torch.gather(roughness, 0, loss.argmin(dim=0, keepdim=True))[0]
+         metallic = torch.gather(metallic, 0, loss.argmin(dim=0, keepdim=True))[0]
+         torch.cuda.empty_cache()
+         if separate:
+             return roughness, metallic
+         else:
+             out = torch.cat([roughness, metallic, torch.zeros_like(roughness)], dim=1)
+             return out
+
+
+     @torch.no_grad()
+     def compute_render(self, maps, camera_position, pos, light):
+         '''
+         maps: gs, bs, h, w, c (gs: the number of grids)
+         '''
+         def cos(x, y):
+             return torch.clamp((x * y).sum(-1, keepdim=True), min=0, max=1)
+
+         # pre-process
+         albedo = srgb_to_rgb(maps['basecolor'])
+         normal = maps['normal'].clone()
+         normal[..., :2] = normal[..., [1, 0]]
+         N = Fn.normalize((normal - 0.5) * 2.0, dim=-1, eps=1e-6)
+         roughness = maps['roughness']
+         metallic = maps['metallic']
+         V = Fn.normalize(camera_position - pos, dim=-1, eps=1e-6).repeat(1, 1, 1, 1, 1).to(self.device)
+         irradiance, L = light(pos)
+         irradiance, L = irradiance.repeat(1, 1, 1, 1, 1).to(self.device), L.repeat(1, 1, 1, 1, 1).to(self.device)
+         # rendering
+         H = Fn.normalize(L + V, dim=-1, eps=1e-6)
+         f0 = torch.ones_like(albedo).to(self.device) * 0.04
+         F0 = torch.lerp(f0, albedo, metallic)
+         F = fresnelSchlick(cos(H, V), F0)
+         ks = F
+
+         diffuse = (1 - ks) * albedo / torch.pi
+         diffuse *= 1 - metallic
+
+         NDF = DistributionGGX(cos(N, H), roughness)
+         G = GeometrySchlickGGX(cos(N, L), roughness) * GeometrySchlickGGX(cos(N, V), roughness)
+
+         numerator = NDF * G * F
+         denominator = 4.0 * cos(N, V) * cos(N, L) + 1e-3
+         specular = numerator / denominator
+         ambient = 0.3 * albedo
+
+         rgb = (diffuse + specular) * irradiance * cos(N, L) + ambient
+
+         return rgb
+
+     def forward(self, maps: dict):
+         # prepare
+         bs = maps['render'].shape[0]
+         self.sd.scheduler.set_timesteps(1)
+         t = self.sd.scheduler.timesteps[0]
+         # chain processing
+         pred, pred_latent, arxiv_latent = {}, {}, {}
+         for kout, info in self.chain.items():
+             info = info.split("_")
+             keys, ids = info[:-1], info[-1]
+             # Swap active LEGO blocks
+             self.sd.unet.down_blocks[0] = self.sd.unet.FirstDownBlocks[kout]
+             self.sd.unet.up_blocks[-1] = self.sd.unet.LastUpBlocks[kout]
+             # Eq.2, summing input latents
+             in_latent = 0
+             for k, i in zip(keys, ids):
+                 if i == "0":
+                     if k not in arxiv_latent: arxiv_latent[k] = self.sd.encode_imgs_deterministic(maps[k])
+                     zx = arxiv_latent[k]
+                 else:
+                     zx = pred_latent[k]
+                 in_latent += self.sd.unet.ConvIns[k](zx)
+             in_latent = in_latent / len(keys)
+             # single-step denoising
+             embs = self.produce_embeddings(kout, bs)
+             out_latent = self.sd.unet(in_latent, t, **embs)[0]
+             out_latent = self.sd.unet.ConvOuts[kout](out_latent)
+             pred_latent[kout] = self.sd.scheduler.step(out_latent, t, torch.zeros_like(zx)).pred_original_sample
+             pred[kout] = self.sd.decode_latents(pred_latent[kout]).float()
+             # compute intermediate representations
+             if self.chain_type in ["chord"] and kout == "basecolor":
+                 pred['approxIrr'] = self.compute_approxIrr(maps['render'], pred['basecolor'])
+                 pred_latent['approxIrr'] = self.sd.encode_imgs_deterministic(pred['approxIrr'])
+             if self.chain_type in ["chord"] and kout == "normal":
+                 pred['approxRM'] = self.compute_approxRouMet(maps['render'], pred, separate=False)
+                 pred_latent['approxRM'] = self.sd.encode_imgs_deterministic(pred['approxRM'])
+
+         return pred
+
+     @torch.no_grad()
+     def produce_embeddings(self, key, batch_size):
+         if key not in self.text_emb.keys():
+             self.text_emb[key] = self.sd.encode_text(self.prompts[key], "max_length")
+         prompt_emb = self.text_emb[key].expand(batch_size, -1, -1)
+         return { "encoder_hidden_states": prompt_emb }
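
`forward` decodes each chain entry (e.g. `render_approxIrr_01` from config/chord.yaml) by splitting on underscores: every segment but the last names an input map, and the final digit string says, position by position, whether that input's latent comes from the ground-truth image (`0`) or from an earlier prediction in the chain (`1`). A standalone sketch of just that parsing step:

# Parsing sketch for chain entries in the format used by config/chord.yaml.
chain = {"basecolor": "render_0", "normal": "render_approxIrr_01"}
for kout, info in chain.items():
    parts = info.split("_")
    keys, ids = parts[:-1], parts[-1]    # input names, per-input gt/pred flags
    for k, i in zip(keys, ids):
        src = "ground truth" if i == "0" else "prediction"
        print(f"{kout}: input {k} encoded from {src}")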
chord/module/light.py ADDED
@@ -0,0 +1,96 @@
+ import torch
+ from typing import Optional
+ import torch.nn.functional as Fn
+ import math
+ import copy
+
+ from . import register
+ from .base import Base
+
+ class BaseLight(Base):
+     """
+     Base class for light models.
+     """
+
+     def setup(self):
+         pass
+
+     def forward(self, x: Optional[torch.Tensor] = None):
+         """
+         Get the light intensity.
+
+         Args:
+             x: positions of shape (..., 3).
+
+         Returns:
+             color: radiance intensity of shape (..., 3).
+             d: directions of shape (..., 3).
+         """
+         raise NotImplementedError
+
+
+ @register("point-light")
+ class PointLight(BaseLight):
+     """Point light definitions
+     """
+     def setup(self):
+         """Initialize the point light.
+
+         Args:
+             position (float, float, float): World coordinate of the light.
+             color (float, float, float): Light color in (R, G, B).
+             power (float): Light power; it is multiplied directly into each color channel.
+         """
+         position = self.config.get("position", [0., 0., 10.])
+         color = self.config.get("color", [23.47, 21.31, 20.79])
+         power = self.config.get("power", 10.)
+
+         self.register_buffer("position", torch.tensor(position))
+         self.register_buffer("color", torch.tensor(color) * power)
+
+     def forward(self, x: Optional[torch.Tensor] = None):
+         """Compute light radiance and direction.
+
+         Args:
+             x: World coordinate of the interacting surface. [B, H, W, 3]
+         Returns:
+             color: radiance intensity of shape [B, H, W, 3]
+             d: directions of shape [B, H, W, 3], V = (light_pos - world_pos)
+         """
+         distance = torch.norm(self.position - x, dim=-1, keepdim=True)
+         attenuation = 1.0 / (distance ** 2)
+         radiance = self.color * attenuation
+         direction = Fn.normalize(self.position - x, dim=-1)
+         return radiance, direction
+
+ @register("distant-light")
+ class DistantLight(BaseLight):
+     """Distant light definitions
+     """
+     def setup(self):
+         """Initialize the distant light.
+
+         Args:
+             direction (float, float, float): The direction of the light vector.
+             color (float, float, float): Light color in (R, G, B).
+             power (float): Light power; it is multiplied directly into each color channel.
+         """
+         direction = self.config.get("direction", [0., 0., 1.])
+         color = self.config.get("color", [23.47, 21.31, 20.79])
+         power = self.config.get("power", 0.1)
+
+         self.register_buffer("color", torch.tensor(color) * power)
+         self.register_buffer("direction", Fn.normalize(torch.tensor(direction), dim=0))
+
+     def forward(self, x: Optional[torch.Tensor] = None):
+         """Compute light radiance and direction.
+
+         Args:
+             x: World coordinate of the interacting surface. [B, H, W, 3]
+         Returns:
+             color: radiance intensity of shape [B, H, W, 3]
+             d: directions of shape [B, H, W, 3]
+         """
+         radiance = self.color.repeat(*x.shape[:-1], 1)
+         direction = self.direction.repeat(*x.shape[:-1], 1)
+         return radiance, direction
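
A quick evaluation sketch: instantiate a point light through the registry (as `app.relit` does) and query radiance and direction at a grid of surface positions built with `chord.util.get_positions`:

import torch
from chord.module import make
from chord.util import get_positions

light = make("point-light", {"position": [0., 0., 10.]})
pos = get_positions(4, 4, 10).unsqueeze(0)   # 1xHxWx3 world coordinates on the z=0 plane
radiance, direction = light(pos)
print(radiance.shape, direction.shape)        # both torch.Size([1, 4, 4, 3])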
chord/module/stable_diffusion.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ from torchvision.transforms import v2
+
+ from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
+ from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer
+
+ from . import register
+ from .base import Base
+
+
+ def apply_padding(model, mode):
+     for layer in [layer for _, layer in model.named_modules() if isinstance(layer, torch.nn.Conv2d)]:
+         if mode == 'circular':
+             layer.padding_mode = 'circular'
+         else:
+             layer.padding_mode = 'zeros'
+     return model
+
+ def freeze(model):
+     model = model.eval()
+     for param in model.parameters():
+         param.requires_grad = False
+     return model
+
+ @register("stable_diffusion")
+ class StableDiffusion(Base):
+     def setup(self):
+         hf_key = self.config.get("hf_key", None)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         fp16 = self.config.get("fp16", True)
+         self.dtype = torch.bfloat16 if fp16 else torch.float32  # the "fp16" flag actually selects bfloat16
+         vae_padding = self.config.get("vae_padding", "zeros")
+
+         self.sd_version = self.config.get("version", 2.1)
+         local_files_only = False
+         if hf_key is not None:
+             print(f"[INFO] using hugging face custom model key: {hf_key}")
+             model_key = hf_key
+             local_files_only = True
+         elif str(self.sd_version) == "2.1":
+             # model_key = "stabilityai/stable-diffusion-2-1"
+             # StabilityAI removed the original 2.1 model from HF, so use a community mirror
+             model_key = "RedbeardNZ/stable-diffusion-2-1-base"
+         else:
+             raise ValueError(
+                 f"Stable-diffusion version {self.sd_version} not supported."
+             )
+
+         # Load components separately to avoid downloading unnecessary weights
+         # 1. UNet (diffusion backbone)
+         unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet")
+         self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only)
+         self.unet.to(self.device, dtype=self.dtype).eval()
+         # 2. VAE (image autoencoder)
+         vae_config = AutoencoderKL.load_config(model_key, subfolder="vae")
+         self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only)
+         self.vae.to(self.device, dtype=self.dtype).eval()
+         self.vae = apply_padding(freeze(self.vae), vae_padding)
+         # 3. Text encoder (CLIP)
+         text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
+         self.text_encoder = CLIPTextModel(text_encoder_config)
+         self.text_encoder.to(self.device, dtype=self.dtype).eval()
+         # 4. Tokenizer (CLIP tokenizer; this one has a vocab, so from_pretrained is needed)
+         self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
+         # 5. Scheduler
+         scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler")
+         scheduler_config["prediction_type"] = "v_prediction"
+         scheduler_config["timestep_spacing"] = "trailing"
+         scheduler_config["rescale_betas_zero_snr"] = True
+         self.scheduler = DDIMScheduler.from_config(scheduler_config)
+
+     def encode_text(self, prompt, padding_mode="do_not_pad"):
+         # prompt: [str]
+         inputs = self.tokenizer(
+             prompt,
+             padding=padding_mode,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         )
+         embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
+         return embeddings
+
+     def decode_latents(self, latents):
+         latents = 1 / self.vae.config.scaling_factor * latents
+         imgs = self.vae.decode(latents).sample
+         imgs = (imgs / 2 + 0.5).clamp(0, 1)
+         return imgs
+
+     def encode_imgs(self, imgs):
+         if imgs.shape[1] == 1:  # for grayscale maps
+             imgs = v2.functional.grayscale_to_rgb(imgs)
+         imgs = 2 * imgs - 1
+         posterior = self.vae.encode(imgs).latent_dist
+         latents = posterior.sample() * self.vae.config.scaling_factor
+         return latents
+
+     def encode_imgs_deterministic(self, imgs):
+         if imgs.shape[1] == 1:  # for grayscale maps
+             imgs = v2.functional.grayscale_to_rgb(imgs)
+         imgs = 2 * imgs - 1
+         h = self.vae.encoder(imgs)
+         moments = self.vae.quant_conv(h)
+         mean, logvar = torch.chunk(moments, 2, dim=1)
+         latents = mean * self.vae.config.scaling_factor
+         return latents
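
Note the difference between the two encoders: `encode_imgs` samples from the VAE posterior, while `encode_imgs_deterministic` takes the posterior mean directly, so repeated encodings of the same map are identical. A roundtrip sketch under stated assumptions: it needs network access to fetch the configs/tokenizer, and the components built here carry random weights from `from_config` (the real weights come from the Chord checkpoint loaded in app.py), so the output is only shape-meaningful:

import torch
from omegaconf import OmegaConf
from chord.module import make

cfg = OmegaConf.create({"name": "stable_diffusion", "fp16": True,
                        "vae_padding": "circular", "version": 2.1})
sd = make("stable_diffusion", cfg)
x = torch.rand(1, 3, 64, 64, device=sd.device, dtype=sd.dtype)
z = sd.encode_imgs_deterministic(x)  # mean of the VAE posterior, scaled: (1, 4, 8, 8)
x_hat = sd.decode_latents(z)         # back to image space: (1, 3, 64, 64)
print(z.shape, x_hat.shape)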
chord/util.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+
+ def vector_dot(A: torch.Tensor, B: torch.Tensor, min=0.0) -> torch.Tensor:
+     return torch.clamp((A * B).sum(1, keepdim=True), min=min, max=1.0)
+
+ def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
+     return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)).to(f.dtype)
+
+ def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
+     return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0 / 2.4) * 1.055 - 0.055).to(f.dtype)
+
+ def tone_gamma(x: torch.Tensor) -> torch.Tensor:
+     x = 1 - torch.exp(-x)
+     return torch.pow(x, 1.0 / 2.2)
+
+ # safe division for values in the range [0, 1]
+ class safe_01_div(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, a, b):
+         ctx.save_for_backward(a, b)
+         return torch.div(a, torch.clamp(b, min=1e-4, max=1.0))
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         a, b = ctx.saved_tensors
+         grad_input = grad_output.clone()
+
+         return torch.div(1, torch.clamp(b, min=1e-4, max=1.0)) * grad_input, -1 * torch.div(a, torch.clamp(b, min=1e-2, max=1.0) ** 2) * grad_input
+
+
+ def get_positions(h, w, real_size, use_pixel_centers=True) -> torch.Tensor:
+     pixel_center = 0.5 if use_pixel_centers else 0
+     i, j = torch.meshgrid(
+         torch.arange(h) + pixel_center,
+         torch.arange(w) + pixel_center,
+         indexing='ij'
+     )
+     if not isinstance(real_size, list):
+         real_size = [real_size] * 2
+     pos = torch.stack([(i / h - 0.5) * real_size[0], (j / w - 0.5) * real_size[1], torch.zeros_like(i)], dim=-1)
+     return pos
+
+ # cosNH: (Bx1xHxW), roughness: (Bx1xHxW)
+ # The "D", facet distribution function in the Cook-Torrance model
+ def DistributionGGX(cosNH, roughness):
+     a = roughness * roughness
+     a2 = a * a
+     cosNH2 = cosNH * cosNH
+     num = a2
+     denom = cosNH2 * (a2 - 1.0) + 1.0
+     denom = torch.pi * denom * denom
+     return num / denom
+
+ # NdotV, roughness: (Bx1xHxW)
+ def GeometrySchlickGGX(NdotV: torch.Tensor, roughness: torch.Tensor) -> torch.Tensor:
+     r = (roughness + 1.0)
+     k = (r * r) / 8.0
+
+     num = NdotV
+     denom = NdotV * (1.0 - k) + k
+
+     return num / denom
+
+ # cosTheta, F0: (Bx1xHxW)
+ # The "F", Fresnel term (Schlick approximation)
+ def fresnelSchlick(cosTheta: torch.Tensor, F0: torch.Tensor) -> torch.Tensor:
+     return F0 + (1.0 - F0) * torch.pow(1.0 - cosTheta, 5.0)
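
A tiny numeric check of the Cook-Torrance terms above, at normal incidence (all cosines equal to 1): GGX reduces to 1/(π·r⁴), the Schlick-GGX geometry term reduces to 1, and Fresnel returns F0 exactly:

import torch
from chord.util import DistributionGGX, GeometrySchlickGGX, fresnelSchlick

one = torch.ones(1, 1, 2, 2)              # cosines at normal incidence
rough = torch.full((1, 1, 2, 2), 0.5)
D = DistributionGGX(one, rough)            # 1 / (pi * 0.5**4) ~= 5.093
G = GeometrySchlickGGX(one, rough)         # 1 / ((1 - k) + k) = 1
F = fresnelSchlick(one, torch.full((1, 1, 2, 2), 0.04))  # = F0 = 0.04
print(D.mean().item(), G.mean().item(), F.mean().item())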
config/chord.yaml ADDED
@@ -0,0 +1,29 @@
+ model:
+   name: chord
+   roughness_step: 5.
+   metallic_step: 1.
+   # format: "OutputMapName": ConvInInput1_ConvInInput2_{0/1}
+   # 0/1 stands for using the gt/pred image
+   chain_type: chord
+   chain_library:
+     chord:
+       basecolor: render_0
+       normal: render_approxIrr_01
+       rou_met: render_approxRM_01
+   rgbx_prompts:
+     basecolor: Basecolor
+     normal: Normal
+     roughness: Roughness
+     metallic: Metallic
+     irradiance: Irradiance
+     rou_met: Roughness and Metallic
+   prior_light:
+     name: distant-light
+     direction: [-1.0, -1.0, 1.0] # top-left corner towards bottom right
+     color: [23.47, 21.31, 20.79]
+     power: 0.1
+   stable_diffusion:
+     name: stable_diffusion
+     fp16: true
+     vae_padding: circular
+     version: 2.1
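
This file is wired up through OmegaConf (see `load_model` in app.py); `ChordModel` passes the `model` subtree to the registered "chord" module. A minimal sketch of the same construction path, assuming the checkpoint file is present:

import torch
from omegaconf import OmegaConf
from chord import ChordModel

config = OmegaConf.load("config/chord.yaml")
model = ChordModel(config)  # builds the module named by config.model.name
ckpt = torch.load("chord_v1.ckpt", map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["state_dict"])
model.eval()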
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ huggingface_hub
+ diffusers
+ transformers
+ torch
+ torchvision
+ typer
+ omegaconf
+ imageio
+ tqdm
+ gradio