- .gitignore +2 -0
- LICENSE.txt +99 -0
- app.py +141 -0
- chord/__init__.py +14 -0
- chord/io.py +80 -0
- chord/module/__init__.py +19 -0
- chord/module/base.py +13 -0
- chord/module/chord.py +281 -0
- chord/module/light.py +96 -0
- chord/module/stable_diffusion.py +105 -0
- chord/util.py +67 -0
- config/chord.yaml +29 -0
- requirements.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__
output
LICENSE.txt
ADDED
@@ -0,0 +1,99 @@
Ubisoft Machine Learning License (Research-Only - Copyleft)

This license governs the use, reproduction, and distribution of the Licensed Materials, including AI Models and associated source code, for the sole purpose of scientific research. By accessing, downloading or using the Licensed Materials, you hereby accept to be bound by this [Ubisoft Machine Learning License (Research-Only - Copyleft)] agreement (hereinafter the "License").

1. Licensed Materials

- AI Models
- Source Code

2. Definitions

"Licensed Materials": Refers to the AI Models and/or Source Code licensed under this agreement.
"Source Code" means the preferred form of the work for making modifications to it, corresponding to text written using a human-readable programming language.
"Object Code" means any non-source form of a work.
"AI Model" means any machine learning based assembly or assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Source Code.
"Output" means the results of operating an AI Model as embodied in informational content resulting therefrom.
"Derivative": Any work derived from or based upon the Licensed Materials, including modifications.
"Permitted Purpose": Use for academic or research purposes only. Commercial use is strictly prohibited.
"Distribution": Any sharing of the Licensed Materials or Derivatives with third parties, including hosting as a service.
"Licensor": The rights holder or authorized entity granting this License.
"You": The individual or entity receiving and exercising rights under this License.

3. Grant of Rights

Subject to compliance with the terms of this License, You are granted a worldwide, royalty-free, non-exclusive License to use, study, reproduce, modify, and distribute the Licensed Materials and Derivatives solely for the Permitted Purpose. As between You and Licensor, Licensor claims no rights in the Outputs You generate using the AI Models used in accordance with the Permitted Purpose.

4. Distribution of Licensed Materials and Derivatives

Any Distribution of the Derivatives of the Licensed Materials, or of the Licensed Materials themselves, shall be licensed under the exact same terms as this License. Redistribution shall include this License and retain all notices of author attribution, and all modifications shall be clearly marked.

5. Use Restrictions

You shall not use the Licensed Materials or their Derivatives for:
- any purposes other than the Permitted Purpose, including commercial purposes such as using the Licensed Materials in any activity intended for commercial advantage or monetary compensation, directly or indirectly;
- weaponry, warfare, military applications, surveillance, or any activity that may cause harm or violate human rights;
- engaging in or enabling fully automated decision-making that may adversely impact a natural person's legal rights;
- providing medical advice or making clinical decisions;
- generating content that promotes or incites hatred, violence, discrimination, or harm based on race, ethnicity, religion, gender, sexual orientation, or any other protected characteristic;
- generating content that includes depictions of sexual abuse, sexual violence, explicit pornography, or any form of non-consensual acts, and/or generating content that includes depictions of child nudity, child pornography, or any form of child exploitation.

6. Disclaimer of Warranty

THE LICENSED MATERIALS ARE PROVIDED "AS IS" AND "AS AVAILABLE" WITHOUT WARRANTIES OF ANY KIND, WHETHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, CORRECTNESS, ACCURACY, OR RELIABILITY. THE LICENSOR DISCLAIMS ALL LIABILITY FOR DAMAGES RESULTING FROM THE USE OR INABILITY TO USE THE LICENSED MATERIALS. THE USE OF THE LICENSED MATERIALS AND ANY OUTPUTS YOU MAY GENERATE SHALL BE AT YOUR OWN RISK.

7. Termination

This License terminates automatically if You violate any of its terms. Upon termination, You shall cease all use and distribution of the Licensed Materials and their Derivatives.

8. Governing Law

The validity of this Agreement and any of its terms and provisions, as well as the rights and duties of the parties hereunder, shall be governed, interpreted and enforced in accordance with the laws of France.

9. Miscellaneous

If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.

Copyright (C) 2025 UBISOFT ENTERTAINMENT. All Rights Reserved.
app.py
ADDED
@@ -0,0 +1,141 @@
import gradio as gr
import os
import numpy as np
from PIL import Image
import torch
import copy
from omegaconf import OmegaConf
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image

from chord import ChordModel
from chord.module import make
from chord.util import get_positions, rgb_to_srgb

EXAMPLES_USECASE_1 = [
    [f"examples/generated/{f}"]
    for f in sorted(os.listdir("examples/generated"))
]
EXAMPLES_USECASE_2 = [
    [f"examples/in_the_wild/{f}"]
    for f in sorted(os.listdir("examples/in_the_wild"))
]
EXAMPLES_USECASE_3 = [
    [f"examples/specular/{f}"]
    for f in sorted(os.listdir("examples/specular"))
]

MODEL_OBJ = None

def load_model(ckpt_path):
    print("Loading model from:", ckpt_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = OmegaConf.load("config/chord.yaml")
    model = ChordModel(config)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.to(device)
    return model

def run_model(model, img: Image.Image):
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    image = to_tensor(img).to(next(model.parameters()).device)
    x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        output = model(x)
    output.update({"input": image})
    return output

def relit(model, maps):
    maps['metallic'] = maps.get('metalness', torch.zeros_like(maps['basecolor']))
    device = next(model.parameters()).device
    h, w = maps["basecolor"].shape[-2:]
    light = make("point-light", {"position": [0, 0, 10]}).to(device)
    pos = get_positions(h, w, 10).to(device)
    camera = torch.tensor([0, 0, 10.0]).to(device)
    for key in maps:
        if maps[key].dim() == 3:
            maps[key] = maps[key].unsqueeze(0)
        maps[key] = maps[key].permute(0, 2, 3, 1)  # BxCxHxW -> BxHxWxC
    rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0, 3, 1, 2)  # GxBxHxWxC -> BxCxHxW
    return torch.clamp(rgb_to_srgb(rgb), 0, 1)

def inference(img, ckpt_path):
    global MODEL_OBJ

    if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != ckpt_path:
        MODEL_OBJ = load_model(ckpt_path)
        MODEL_OBJ._ckpt = ckpt_path  # store path inside object

    if img is None:
        return None, None, None, None, None

    ori_h, ori_w = img.size[1], img.size[0]
    out = run_model(MODEL_OBJ, img)
    maps = copy.deepcopy(out)
    rendered = relit(MODEL_OBJ, maps)
    resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True)
    return (
        to_pil_image(resize_back(out["basecolor"]).squeeze(0)),
        to_pil_image(resize_back(out["normal"]).squeeze(0)),
        to_pil_image(resize_back(out["roughness"]).squeeze(0)),
        to_pil_image(resize_back(out["metalness"]).squeeze(0)),
        to_pil_image(resize_back(rendered).squeeze(0)),
    )

with gr.Blocks(title="Chord") as demo:

    gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture images**")
    ckpt_path = gr.Textbox(
        label="Model Checkpoint Path",
        value="chord_v1.ckpt",
        placeholder="Path to your model checkpoint",
    )
    gr.Markdown("Upload an image or select an example to estimate PBR channels and render the result under custom lighting.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image", height=512)

            gr.Markdown("### Example Inputs — Generated Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_1,
                inputs=[input_img],
                label="Examples (Generated Textures)"
            )

            gr.Markdown("### Example Inputs — In The Wild Photographs")
            gr.Examples(
                examples=EXAMPLES_USECASE_2,
                inputs=[input_img],
                label="Examples (In The Wild Photographs)"
            )

            gr.Markdown("### Example Inputs — Specular Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_3,
                inputs=[input_img],
                label="Examples (Specular Textures)"
            )

            run_button = gr.Button("Run Estimation")

        with gr.Column():
            gr.Markdown("### Predicted Channels")
            basecolor_out = gr.Image(label="Basecolor", height=512)
            normal_out = gr.Image(label="Normal", height=512)
            roughness_out = gr.Image(label="Roughness", height=512)
            metallic_out = gr.Image(label="Metalness", height=512)

            gr.Markdown("### Relit Output")
            render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)

    run_button.click(
        inference,
        inputs=[input_img, ckpt_path],
        outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
    )


if __name__ == "__main__":
    demo.launch()
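For reference, the pipeline can also be driven without the Gradio UI. A minimal sketch, assuming the checkpoint chord_v1.ckpt and the examples/ folders exist (the image filename below is hypothetical, and importing app builds the Blocks UI as a side effect):

from PIL import Image
from app import inference

img = Image.open("examples/generated/brick.png").convert("RGB")  # hypothetical filename
basecolor, normal, roughness, metalness, relit_img = inference(img, "chord_v1.ckpt")
relit_img.save("relit.png")  # each output is a PIL image at the input resolution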
chord/__init__.py
ADDED
@@ -0,0 +1,14 @@
import torch
import torch.nn as nn
from chord.module import make
from chord.module.chord import post_decoder

class ChordModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = make(config.model.name, config.model)

    def forward(self, x: torch.Tensor):
        x = {"render": x}
        pred = self.model(x)
        return post_decoder(pred)
chord/io.py
ADDED
@@ -0,0 +1,80 @@
import torch
import imageio.v3 as imageio
import numpy as np
import warnings
import os

import torchvision.transforms.functional as F

def read_image(filename: str, out: torch.Tensor = None) -> torch.Tensor:
    '''
    Read a local image file into a float tensor (pixel values normalized to [0, 1], CxHxW).

    Args:
        filename: Image file path.
        out: If provided, fill this tensor instead of returning a new one.

    Returns:
        Loaded image tensor.
    '''
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # ignore PIL's warning that fp16 images are read as fp32
        img: np.ndarray = imageio.imread(filename)

    # Convert the image array to a float tensor according to its data type
    res = None
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0
    elif img.dtype == np.uint16 or img.dtype == np.int32:
        img = img.astype(np.float32) / 65535.0
    else:
        raise ValueError(f'Unrecognized image pixel value type: {img.dtype}')
    if img.ndim == 2:
        res = torch.from_numpy(img).unsqueeze(0)  # 1xHxW for grayscale images
    elif img.ndim == 3:
        res = torch.from_numpy(img).movedim(2, 0)[:3]  # HxWxC to CxHxW
    else:
        raise ValueError(f'Unrecognized image dimension: {img.shape}')

    if out is None:
        return res
    out.copy_(res)

def create_img(img: torch.Tensor):
    '''
    Convert a tensor to a PIL image.

    Args:
        img: Image tensor CxHxW. Squeezed if BxCxHxW and B == 1.

    Returns:
        PIL image.
    '''
    if img.dim() == 4:
        assert img.shape[0] == 1
        img = img.squeeze(0)

    if img.shape[0] == 4:
        out_img = F.to_pil_image(img, mode="CMYK")
        out_img = out_img.convert('RGB')
    elif img.shape[0] == 3:
        out_img = F.to_pil_image(img, mode="RGB")
    elif img.shape[0] == 1:
        out_img = F.to_pil_image(img, mode="L")
    else:
        raise ValueError("Unsupported image dimension.")
    return out_img

def save_maps(path: str, maps: dict):
    '''
    Save SVBRDF maps to a given path.

    Args:
        path: Output path.
        maps: Named maps of tensor images.
    '''
    if not os.path.exists(path):
        os.makedirs(path)
    for name, image in maps.items():
        out_img = create_img(image)
        out_img.save(os.path.join(path, name + ".png"))
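A hypothetical round trip with the helpers above (the paths are illustrative, not files shipped with the repo):

from chord.io import read_image, save_maps

img = read_image("examples/generated/sample.png")  # hypothetical path; CxHxW float in [0, 1]
save_maps("output/demo", {"basecolor": img})       # writes output/demo/basecolor.png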
chord/module/__init__.py
ADDED
@@ -0,0 +1,19 @@
modules = {}

def register(name):
    def decorator(cls):
        modules[name] = cls
        return cls
    return decorator


def make(name, config):
    model = modules[name](config)
    return model


from . import (
    light,
    stable_diffusion,
    chord,
)
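This registry is how every component in the repo is constructed by name. A small sketch of the pattern (the "toy-light" name and ToyLight class are made up for illustration):

from chord.module import register, make

@register("toy-light")  # hypothetical name, for illustration only
class ToyLight:
    def __init__(self, config):
        self.config = config

light = make("toy-light", {"power": 1.0})  # looks up the class in `modules` and instantiates it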
chord/module/base.py
ADDED
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn

class Base(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.setup()

    def setup(self):
        raise NotImplementedError
chord/module/chord.py
ADDED
@@ -0,0 +1,281 @@
import copy
import torch
from torch import nn
import torch.nn.functional as Fn
from torchvision.transforms import v2

from . import register, make
from .base import Base

from chord.util import fresnelSchlick, GeometrySchlickGGX, DistributionGGX
from chord.util import srgb_to_rgb, tone_gamma, get_positions, safe_01_div

class dummy_module(nn.Module):
    def forward(self, x): return x

def post_decoder(out_dict):
    out = {}
    for key in out_dict.keys():
        if key.startswith("approx"): continue
        elif key == "normal":
            out[key] = Fn.normalize(2. * out_dict[key] - 1., dim=1) / 2. + 0.5
        elif key == "rou_met":
            out['roughness'], out['metalness'] = out_dict['rou_met'][:,0], out_dict['rou_met'][:,1]
        else: out[key] = out_dict[key]
    return out

def process_irradiance(radiance, kernel_size=25, res=64):
    """
    Process the irradiance using PyTorch, equivalent to the original OpenCV-based function.

    Args:
        radiance (torch.Tensor): Input radiance tensor (B, 1, H, W).
        kernel_size (int): Size of the kernel for the median blur.
        res (int): Target resolution for resizing the image.

    Returns:
        torch.Tensor: Processed radiance tensor (res, res).
    """
    # Ensure the input radiance is a 4D tensor (B, 1, H, W)
    assert radiance.shape[1] == 1 and radiance.dim() == 4, f"Invalid radiance shape, got {radiance.shape}"
    # Resize to low resolution
    resizer = v2.Resize(size=res, antialias=True)
    radiance = resizer(radiance)

    # Define an 11x11 averaging kernel
    kernel = torch.ones((1, 1, 11, 11), dtype=torch.float32).to(radiance) / 121.0
    # Apply convolution (averaging filter); reflect-pad first so dimensions are preserved
    radiance = Fn.pad(radiance, (5,)*4, mode="reflect")
    radiance = Fn.conv2d(radiance, kernel, padding=0)

    # Clamp values and scale to [0, 255] for median filtering
    radiance = torch.clamp(radiance * 255, 0, 255)

    # Apply median filtering
    paded_radiance = Fn.pad(radiance, (kernel_size // 2,) * 4, mode="reflect")  # Pad for edge handling
    unfolded = Fn.unfold(paded_radiance, kernel_size)  # Extract patches
    radiance = torch.median(unfolded, dim=1).values.view(radiance.shape)  # Median of patches

    # Normalize to [0, 1]
    rad_min, rad_max = radiance.amin([2,3], keepdim=True), radiance.amax([2,3], keepdim=True)
    radiance = (radiance - rad_min) / (rad_max - rad_min)
    return radiance

def opt_light_dir(_radiance, _num_samples=6):
    '''
    _radiance: (bs, 1, h, w)
    '''
    assert _radiance.shape[1] == 1 and _radiance.dim() == 4
    bs, _, h, w = _radiance.shape

    def evenly_sample(_num_samples, min=0, max=2*torch.pi):
        # returns a tensor of _num_samples+1 evenly spaced angles in [min, max]
        return torch.tensor(range(_num_samples+1)) * (max - min) / _num_samples + min

    def compute_radiance_diff(angles):
        num = angles.shape[-1]
        dirs = torch.cat([torch.cos(angles), torch.sin(angles)]).T
        pos_dir = grid_pos.repeat(num, 1, 1, 1)
        pos_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) > 0
        neg_mask = torch.einsum("abcd,ad->abc", pos_dir, dirs) < 0
        samples_radiance = _radiance.repeat(1,num,1,1)
        radiance_diff = (samples_radiance*pos_mask[None] - samples_radiance*neg_mask[None]).sum([2,3])
        return radiance_diff

    angle_min, angle_max = 0, 2*torch.pi
    grid_pos = Fn.normalize(get_positions(h,w,10)[...,:2], dim=-1, eps=1e-6).to(_radiance)
    while (angle_max - angle_min) > (torch.pi/90):
        angles = evenly_sample(_num_samples, angle_min, angle_max)[None].to(_radiance)
        diffs = compute_radiance_diff(angles).mean(0)
        angle_min = angles[:,diffs.argmax()].item() - (angle_max - angle_min)/_num_samples
        angle_max = angles[:,diffs.argmax()].item() + (angle_max - angle_min)/_num_samples

    light_angle = angles[:, diffs.argmax()]
    return torch.tensor([torch.cos(light_angle), torch.sin(light_angle)]).to(_radiance)


def find_light_dir(raw_irradiance, light):
    raw_irradiance = v2.functional.rgb_to_grayscale(raw_irradiance)
    irradiance = process_irradiance(raw_irradiance)
    dir = opt_light_dir(irradiance)
    dir = torch.cat([dir, torch.tensor([0.5**0.5]).to(dir)])
    _light = copy.deepcopy(light)
    _light.direction = dir
    return _light

@register("chord")
class Chord(Base):
    def setup(self):
        # Define forward chain
        self.chain_type = self.config.get("chain_type", "chord")
        self.chain = self.config.get("chain_library", {})[self.chain_type]
        self.prompts = self.config.get("rgbx_prompts", {})
        self.roughness_step = self.config.get("roughness_step", 10)
        self.metallic_step = self.config.get("metallic_step", 0.2)

        self.sd = make(self.config.stable_diffusion.name, self.config.stable_diffusion)
        self.dtype = self.sd.dtype
        self.device = self.sd.device

        # LEGO-conditioning
        self.sd.unet.ConvIns = nn.ModuleDict()
        self.sd.unet.ConvOuts = nn.ModuleDict()
        self.sd.unet.FirstDownBlocks = nn.ModuleDict()
        self.sd.unet.LastUpBlocks = nn.ModuleDict()
        for key in list(set("_".join(self.chain.values()).split("_"))) + ["noise"]:
            if "0" in key or "1" in key: continue  # skip the gt/pred id tokens like "0" / "01"
            self.sd.unet.ConvIns[key] = nn.Conv2d(4, 320, 3, 1, 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvIns[key].load_state_dict(self.sd.unet.conv_in.state_dict())
        for kout in list(set(self.chain.keys())):
            self.sd.unet.ConvOuts[kout] = nn.Conv2d(320, 4, 3, 1, 1, device=self.device, dtype=self.dtype)
            self.sd.unet.ConvOuts[kout].load_state_dict(self.sd.unet.conv_out.state_dict())
            self.sd.unet.LastUpBlocks[kout] = copy.deepcopy(self.sd.unet.up_blocks[-1]).to(self.device)
            self.sd.unet.FirstDownBlocks[kout] = copy.deepcopy(self.sd.unet.down_blocks[0]).to(self.device)
        self.sd.unet.ConvIns.train()
        self.sd.unet.ConvOuts.train()
        self.sd.unet.FirstDownBlocks.train()
        self.sd.unet.LastUpBlocks.train()
        self.sd.unet.conv_in = dummy_module()
        self.sd.unet.conv_out = dummy_module()

        # Load Lights
        if self.config.get("prior_light", None) is None:
            self.prior_light = make("point-light", {"position": [0, 0, 10]})
        else:
            self.prior_light = make(self.config.prior_light.name, self.config.prior_light)

        # Init Embeddings
        self.text_emb = {}

    # Eq.3
    def compute_approxIrr(self, render, basecolor):
        approxIrr = safe_01_div.apply(srgb_to_rgb(render), srgb_to_rgb(basecolor))
        return tone_gamma(approxIrr)

    # Eq.6
    @torch.no_grad()
    def compute_approxRouMet(self, render, maps, seperate=False, light=None):
        render = srgb_to_rgb(render)
        bs, _, h, w = render.shape
        light = find_light_dir(maps['approxIrr'], self.prior_light) if light is None else light
        # light.direction = estimate_light_dir(render, maps)
        pos = get_positions(h, w, 10).to(self.device)
        cameras = torch.tensor([0, 0, 10.0]).to(self.device)

        # sample grid
        r_samples = torch.arange(25, 225+self.roughness_step, self.roughness_step) / 255
        m_samples = torch.arange(0., 1.+self.metallic_step, self.metallic_step)

        grid_maps = {}  # change map size into: gs, bs, h, w, c
        grid_maps['basecolor'] = maps['basecolor'][None].permute(0,1,3,4,2)
        grid_maps['normal'] = maps['normal'][None].permute(0,1,3,4,2)
        r_values = r_samples[:,None].repeat(1,len(m_samples)).reshape(-1,1,1,1,1).to(maps['basecolor'])
        m_values = m_samples[None].repeat(len(r_samples),1).reshape(-1,1,1,1,1).to(maps['basecolor'])
        # split into chunks to avoid OOM
        chunk_size = 25
        rgb_list, r_list, m_list = [], [], []
        for _r, _m in zip(torch.split(r_values, chunk_size), torch.split(m_values, chunk_size)):
            grid_maps['roughness'], grid_maps['metallic'] = _r, _m
            _rgb = self.compute_render(grid_maps, cameras, pos, light)
            loss = (render[None].permute(0,1,3,4,2) - _rgb).abs().sum(-1,keepdim=True)
            min_idx = loss.argmin(dim=0,keepdim=True)
            r_list.append(torch.gather(grid_maps['roughness'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            m_list.append(torch.gather(grid_maps['metallic'].flatten(), 0, min_idx.flatten()).reshape(min_idx.shape))
            rgb_list.append(torch.gather(_rgb, 0, min_idx.repeat(1,1,1,1,3)))
        rgb = torch.cat(rgb_list).permute(0,1,4,2,3)
        roughness = torch.cat(r_list).permute(0,1,4,2,3)
        metallic = torch.cat(m_list).permute(0,1,4,2,3)
        loss = (render[None] - rgb).abs().sum(2,keepdim=True)
        roughness = torch.gather(roughness, 0, loss.argmin(dim=0,keepdim=True))[0]
        metallic = torch.gather(metallic, 0, loss.argmin(dim=0,keepdim=True))[0]
        torch.cuda.empty_cache()
        if seperate:
            return roughness, metallic
        else:
            out = torch.cat([roughness, metallic, torch.zeros_like(roughness)], dim=1)
            return out


    @torch.no_grad()
    def compute_render(self, maps, camera_position, pos, light):
        '''
        maps: gs, bs, h, w, c (gs: the number of grids)
        '''
        def cos(x, y):
            return torch.clamp((x*y).sum(-1, keepdim=True), min=0, max=1)

        # pre-process
        albedo = srgb_to_rgb(maps['basecolor'])
        normal = maps['normal'].clone()
        normal[..., :2] = normal[..., [1,0]]
        N = Fn.normalize((normal - 0.5) * 2.0, dim=-1, eps=1e-6)
        roughness = maps['roughness']
        metallic = maps['metallic']
        V = Fn.normalize(camera_position - pos, dim=-1, eps=1e-6).repeat(1,1,1,1,1).to(self.device)
        irradiance, L = light(pos)
        irradiance, L = irradiance.repeat(1,1,1,1,1).to(self.device), L.repeat(1,1,1,1,1).to(self.device)
        # rendering
        H = Fn.normalize(L+V, dim=-1, eps=1e-6)
        f0 = torch.ones_like(albedo).to(self.device) * 0.04
        F0 = torch.lerp(f0, albedo, metallic)
        F = fresnelSchlick(cos(H,V), F0)
        ks = F

        diffuse = (1-ks) * albedo / torch.pi
        diffuse *= 1-metallic

        NDF = DistributionGGX(cos(N,H), roughness)
        G = GeometrySchlickGGX(cos(N,L), roughness) * GeometrySchlickGGX(cos(N,V), roughness)

        numerator = NDF * G * F
        denominator = 4.0 * cos(N,V) * cos(N,L) + 1e-3
        specular = numerator / denominator
        ambient = 0.3 * albedo

        rgb = (diffuse + specular) * irradiance * cos(N,L) + ambient

        return rgb

    def forward(self, maps: dict):
        # prepare
        bs = maps['render'].shape[0]
        self.sd.scheduler.set_timesteps(1)
        t = self.sd.scheduler.timesteps[0]
        # chain processing
        pred, pred_latent, arxiv_latent = {}, {}, {}
        for kout, info in self.chain.items():
            info = info.split("_")
            keys, ids = info[:-1], info[-1]
            # Swap active LEGO blocks
            self.sd.unet.down_blocks[0] = self.sd.unet.FirstDownBlocks[kout]
            self.sd.unet.up_blocks[-1] = self.sd.unet.LastUpBlocks[kout]
            # Eq.2: combine input latents (averaged over the inputs below)
            in_latent = 0
            for k, i in zip(keys, ids):
                if i == "0":
                    if not k in arxiv_latent.keys(): arxiv_latent[k] = self.sd.encode_imgs_deterministic(maps[k])
                    zx = arxiv_latent[k]
                else:
                    zx = pred_latent[k]
                in_latent += self.sd.unet.ConvIns[k](zx)
            in_latent = in_latent / len(keys)
            # single-step denoising
            embs = self.produce_embeddings(kout, bs)
            out_latent = self.sd.unet(in_latent, t, **embs)[0]
            out_latent = self.sd.unet.ConvOuts[kout](out_latent)
            pred_latent[kout] = self.sd.scheduler.step(out_latent, t, torch.zeros_like(zx)).pred_original_sample
            pred[kout] = self.sd.decode_latents(pred_latent[kout]).float()
            # compute intermediate representations
            if self.chain_type in ["chord"] and kout == "basecolor":
                pred['approxIrr'] = self.compute_approxIrr(maps['render'], pred['basecolor'])
                pred_latent['approxIrr'] = self.sd.encode_imgs_deterministic(pred['approxIrr'])
            if self.chain_type in ["chord"] and kout == "normal":
                pred['approxRM'] = self.compute_approxRouMet(maps['render'], pred, seperate=False)
                pred_latent['approxRM'] = self.sd.encode_imgs_deterministic(pred['approxRM'])

        return pred

    @torch.no_grad()
    def produce_embeddings(self, key, batch_size):
        if key not in self.text_emb.keys():
            self.text_emb[key] = self.sd.encode_text(self.prompts[key], "max_length")
        prompt_emb = self.text_emb[key].expand(batch_size, -1, -1)
        return { "encoder_hidden_states": prompt_emb }
chord/module/light.py
ADDED
@@ -0,0 +1,96 @@
import torch
from typing import Optional
import torch.nn.functional as Fn
import math
import copy

from . import register
from .base import Base

class BaseLight(Base):
    """
    Base class for light models.
    """

    def setup(self):
        pass

    def forward(self, x: Optional[torch.Tensor] = None):
        """
        Get the light intensity.

        Args:
            x: positions of shape (..., 3).

        Returns:
            color: radiance intensity of shape (..., 3).
            d: directions of shape (..., 3).
        """
        raise NotImplementedError


@register("point-light")
class PointLight(BaseLight):
    """Point light definitions."""

    def setup(self):
        """Initialize the point light.

        Config args:
            position (float, float, float): World coordinate of the light.
            color (float, float, float): Light color in (R, G, B).
            power (float): Light power; it is multiplied directly into each color channel.
        """
        position = self.config.get("position", [0., 0., 10.])
        color = self.config.get("color", [23.47, 21.31, 20.79])
        power = self.config.get("power", 10.)

        self.register_buffer("position", torch.tensor(position))
        self.register_buffer("color", torch.tensor(color) * power)

    def forward(self, x: Optional[torch.Tensor] = None):
        """Compute light radiance and direction.

        Args:
            x: World coordinate of the interacting surface. [B, H, W, 3]

        Returns:
            radiance: radiance intensity of shape [B, H, W, 3].
            direction: directions of shape [B, H, W, 3], V = (light_pos - world_pos).
        """
        distance = torch.norm(self.position - x, dim=-1, keepdim=True)
        attenuation = 1.0 / (distance ** 2)
        radiance = self.color * attenuation
        direction = Fn.normalize(self.position - x, dim=-1)
        return radiance, direction

@register("distant-light")
class DistantLight(BaseLight):
    """Distant light definitions."""

    def setup(self):
        """Initialize the distant light.

        Config args:
            direction (float, float, float): The direction of the light vector.
            color (float, float, float): Light color in (R, G, B).
            power (float): Light power; it is multiplied directly into each color channel.
        """
        direction = self.config.get("direction", [0., 0., 1.])
        color = self.config.get("color", [23.47, 21.31, 20.79])
        power = self.config.get("power", 0.1)

        self.register_buffer("color", torch.tensor(color) * power)
        self.register_buffer("direction", Fn.normalize(torch.tensor(direction), dim=0))

    def forward(self, x: Optional[torch.Tensor] = None):
        """Compute light radiance and direction.

        Args:
            x: World coordinate of the interacting surface. [B, H, W, 3]

        Returns:
            radiance: radiance intensity of shape [B, H, W, 3].
            direction: directions of shape [B, H, W, 3].
        """
        radiance = self.color.repeat(*x.shape[:-1], 1)
        direction = self.direction.repeat(*x.shape[:-1], 1)
        return radiance, direction
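As a quick sanity check on the inverse-square attenuation in PointLight, a minimal sketch (the point coordinates are illustrative):

import torch
from chord.module import make

light = make("point-light", {"position": [0., 0., 10.], "power": 10.})
x = torch.tensor([[0., 0., 0.], [0., 0., 5.]])  # surface points at distance 10 and 5
radiance, direction = light(x)
# radiance[1] is 4x radiance[0]: halving the distance quadruples the intensity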
chord/module/stable_diffusion.py
ADDED
@@ -0,0 +1,105 @@
import torch
from torchvision.transforms import v2

from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer

from . import register
from .base import Base


def apply_padding(model, mode):
    for layer in [layer for _, layer in model.named_modules() if isinstance(layer, torch.nn.Conv2d)]:
        if mode == 'circular':
            layer.padding_mode = 'circular'
        else:
            layer.padding_mode = 'zeros'
    return model

def freeze(model):
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model

@register("stable_diffusion")
class StableDiffusion(Base):
    def setup(self):
        hf_key = self.config.get("hf_key", None)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fp16 = self.config.get("fp16", True)
        self.dtype = torch.bfloat16 if fp16 else torch.float32  # note: the fp16 flag actually selects bfloat16
        vae_padding = self.config.get("vae_padding", "zeros")

        self.sd_version = self.config.get("version", 2.1)
        local_files_only = False
        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
            local_files_only = True
        elif str(self.sd_version) == "2.1":
            # model_key = "stabilityai/stable-diffusion-2-1"
            # StabilityAI deleted the original 2.1 model from HF, use a community version
            model_key = "RedbeardNZ/stable-diffusion-2-1-base"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )

        # Load components separately to avoid downloading unnecessary weights
        # 1. UNet (diffusion backbone)
        unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet")
        self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only)
        self.unet.to(self.device, dtype=self.dtype).eval()
        # 2. VAE (image autoencoder)
        vae_config = AutoencoderKL.load_config(model_key, subfolder="vae")
        self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only)
        self.vae.to(self.device, dtype=self.dtype).eval()
        self.vae = apply_padding(freeze(self.vae), vae_padding)
        # 3. Text encoder (CLIP)
        text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
        self.text_encoder = CLIPTextModel(text_encoder_config)
        self.text_encoder.to(self.device, dtype=self.dtype).eval()
        # 4. Tokenizer (CLIP tokenizer; this one has a vocab, so from_pretrained is needed)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
        # 5. Scheduler
        scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler")
        scheduler_config["prediction_type"] = "v_prediction"
        scheduler_config["timestep_spacing"] = "trailing"
        scheduler_config["rescale_betas_zero_snr"] = True
        self.scheduler = DDIMScheduler.from_config(scheduler_config)

    def encode_text(self, prompt, padding_mode="do_not_pad"):
        # prompt: [str]
        inputs = self.tokenizer(
            prompt,
            padding=padding_mode,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def encode_imgs_deterministic(self, imgs):
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1
        h = self.vae.encoder(imgs)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        latents = mean * self.vae.config.scaling_factor
        return latents
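The two image encoders above differ only in how the VAE posterior is handled. A sketch of the contrast (this constructs the model, which downloads the SD 2.1 components on first run; the config values mirror config/chord.yaml):

import torch
from omegaconf import OmegaConf
from chord.module import make

cfg = OmegaConf.create({"name": "stable_diffusion", "fp16": True, "vae_padding": "zeros", "version": 2.1})
sd = make("stable_diffusion", cfg)

img = torch.rand(1, 3, 512, 512, device=sd.device, dtype=sd.dtype)
z_sampled = sd.encode_imgs(img)               # posterior.sample(): stochastic, varies per run
z_mean = sd.encode_imgs_deterministic(img)    # posterior mean: reproducible; this is what Chord.forward uses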
chord/util.py
ADDED
@@ -0,0 +1,67 @@
import torch

def vector_dot(A: torch.Tensor, B: torch.Tensor, min=0.0) -> torch.Tensor:
    return torch.clamp((A * B).sum(1, keepdim=True), min=min, max=1.0)

def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4)).to(f.dtype)

def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055).to(f.dtype)

def tone_gamma(x: torch.Tensor) -> torch.Tensor:
    x = 1 - torch.exp(-x)
    return torch.pow(x, 1.0/2.2)

# safe division for value range 0-1
class safe_01_div(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return torch.div(a, torch.clamp(b, min=1e-4, max=1.0))

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_input = grad_output.clone()

        return torch.div(1, torch.clamp(b, min=1e-4, max=1.0)) * grad_input, -1 * torch.div(a, torch.clamp(b, min=1e-2, max=1.0)**2) * grad_input


def get_positions(h, w, real_size, use_pixel_centers=True) -> torch.Tensor:
    pixel_center = 0.5 if use_pixel_centers else 0
    i, j = torch.meshgrid(
        torch.arange(h) + pixel_center,
        torch.arange(w) + pixel_center,
        indexing='ij'
    )
    if not isinstance(real_size, list):
        real_size = [real_size] * 2
    pos = torch.stack([(i / h - 0.5) * real_size[0], (j / w - 0.5) * real_size[1], torch.zeros_like(i)], dim=-1)
    return pos

# N, H: (Bx3xHxW), roughness: (Bx1xHxW)
# The "D": microfacet distribution function in the Cook-Torrance model
def DistributionGGX(cosNH, roughness):
    a = roughness * roughness
    a2 = a * a
    cosNH2 = cosNH * cosNH
    num = a2
    denom = cosNH2 * (a2 - 1.0) + 1.0
    denom = torch.pi * denom * denom
    return num / denom

# NdotV, roughness: (Bx1xHxW)
# The "G": Schlick-GGX geometry term
def GeometrySchlickGGX(NdotV: torch.Tensor, roughness: torch.Tensor) -> torch.Tensor:
    r = (roughness + 1.0)
    k = (r*r) / 8.0

    num = NdotV
    denom = NdotV * (1.0 - k) + k

    return num / denom

# cosTheta, F0 (Bx1xHxW)
# The "F": Fresnel term (Schlick approximation)
def fresnelSchlick(cosTheta: torch.Tensor, F0: torch.Tensor) -> torch.Tensor:
    return F0 + (1.0 - F0) * torch.pow(1.0 - cosTheta, 5.0)
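For intuition about the Cook-Torrance terms above, a small numeric sanity check at normal incidence (all cosines equal to 1; the roughness and F0 values are illustrative):

import torch
from chord.util import DistributionGGX, GeometrySchlickGGX, fresnelSchlick

cos_one = torch.ones(1, 1, 1, 1)
roughness = torch.full((1, 1, 1, 1), 0.5)
F0 = torch.full((1, 1, 1, 1), 0.04)              # base reflectance of a dielectric

D = DistributionGGX(cos_one, roughness)           # GGX normal distribution term
G = GeometrySchlickGGX(cos_one, roughness) ** 2   # Smith: product of view and light terms
F = fresnelSchlick(cos_one, F0)                   # equals F0 exactly when cosTheta = 1
specular = D * G * F / (4.0 * 1.0 * 1.0 + 1e-3)   # same combination used in Chord.compute_render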
config/chord.yaml
ADDED
@@ -0,0 +1,29 @@
model:
  name: chord
  roughness_step: 5.
  metallic_step: 1.
  # format: "OutputMapName": ConvInInput1_ConvInInput2_{0/1}
  # 0/1 stands for using gt/pred image;
  chain_type: chord
  chain_library:
    chord:
      basecolor: render_0
      normal: render_approxIrr_01
      rou_met: render_approxRM_01
  rgbx_prompts:
    basecolor: Basecolor
    normal: Normal
    roughness: Roughness
    metallic: Metallic
    irradiance: Irradiance
    rou_met: Roughness and Metallic
  prior_light:
    name: distant-light
    direction: [-1.0, -1.0, 1.0] # Top-left corner towards bottom right
    color: [23.47, 21.31, 20.79]
    power: 0.1
  stable_diffusion:
    name: stable_diffusion
    fp16: true
    vae_padding: circular
    version: 2.1
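The chain entries above are parsed by Chord.forward: each value lists the ConvIn inputs joined by underscores, and the trailing digits say, per input, whether it comes from the ground-truth/input map (0) or a previous prediction (1). A sketch of the parse for the "normal" entry:

info = "render_approxIrr_01".split("_")  # value of chain_library.chord.normal
keys, ids = info[:-1], info[-1]          # keys = ["render", "approxIrr"], ids = "01"
for k, i in zip(keys, ids):
    source = "input map" if i == "0" else "previous prediction"
    print(f"{k}: encoded from the {source}")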
requirements.txt
ADDED
@@ -0,0 +1,8 @@
huggingface_hub
diffusers
transformers
typer
omegaconf
imageio
tqdm
gradio