import gradio as gr
import os
import numpy as np
from PIL import Image
import torch
import copy
from omegaconf import OmegaConf
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image
from huggingface_hub import hf_hub_download, login
import spaces

from chord import ChordModel
from chord.module import make
from chord.util import get_positions, rgb_to_srgb
from chord.io import load_torch_file

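# Example galleries for the three use cases shown in the UI: generated
# textures, in-the-wild photographs, and specular textures.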
EXAMPLES_USECASE_1 = [
    [f"examples/generated/{f}"] 
    for f in sorted(os.listdir("examples/generated"))
]
EXAMPLES_USECASE_2 = [
    [f"examples/in_the_wild/{f}"] 
    for f in sorted(os.listdir("examples/in_the_wild"))
]
EXAMPLES_USECASE_3 = [
    [f"examples/specular/{f}"] 
    for f in sorted(os.listdir("examples/specular"))
]

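# Authenticate with the Space's HF_TOKEN secret and fetch the checkpoint at
# startup; the model itself is built lazily inside inference() so it is
# created on the worker that actually holds the GPU.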
MODEL_OBJ = None
login(token=os.environ["HF_TOKEN"])
MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.safetensors")
def load_model(ckpt_path):
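    """Build the Chord model from its YAML config and load the safetensors weights."""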
    print("Loading model from:", ckpt_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = OmegaConf.load("config/chord.yaml")
    model = ChordModel(config)
    state_dict = load_torch_file(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model

def run_model(model, img: Image.Image):
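    """Resize the input to the model's 1024x1024 working resolution and run
    one forward pass with gradients disabled and autocast enabled."""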
    device = next(model.parameters()).device
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    image = to_tensor(img).to(device)
    x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
    with torch.no_grad(), torch.autocast(device_type=device.type):
        output = model(x)
    return output

def relit(model, maps):
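    """Re-render the predicted maps under a single centered point light."""
    # The renderer consumes a 'metallic' map; alias the model's 'metalness'
    # prediction, falling back to zeros if it is missing.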
    maps['metallic'] = maps.get('metalness', torch.zeros_like(maps['basecolor']))
    device = next(model.parameters()).device
    h, w = maps["basecolor"].shape[-2:]
    light = make("point-light", {"position": [0, 0, 10]}).to(device)
    pos = get_positions(h, w, 10).to(device)
    camera = torch.tensor([0, 0, 10.0]).to(device)
    for key in maps:
        if maps[key].dim() == 3:
            maps[key] = maps[key].unsqueeze(0)
        maps[key] = maps[key].permute(0,2,3,1)  # BxCxHxW -> BxHxWxC
    rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0,3,1,2)  # GxBxHxWxC -> BxCxHxW
    return torch.clamp(rgb_to_srgb(rgb), 0, 1)

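# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to the process
# for the duration of each call to the decorated function.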
@spaces.GPU
def inference(img):
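    """Estimate PBR channels for the input image and relight them for preview."""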
    global MODEL_OBJ

    if img is None:
        return None, None, None, None, None

    # Load the model once per worker and cache it on the module-level handle,
    # tagging it with the checkpoint path it was loaded from.
    if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != MODEL_CKPT_PATH:
        MODEL_OBJ = load_model(MODEL_CKPT_PATH)
        MODEL_OBJ._ckpt = MODEL_CKPT_PATH

    ori_w, ori_h = img.size  # PIL reports (width, height)
    out = run_model(MODEL_OBJ, img)
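    # relit() permutes the map tensors in place, so render from a deep copy and
    # keep `out` in BxCxHxW layout for the image outputs below.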
    maps = copy.deepcopy(out)
    rendered = relit(MODEL_OBJ, maps)
    resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True)
    return (
        to_pil_image(resize_back(out["basecolor"]).squeeze(0)),
        to_pil_image(resize_back(out["normal"]).squeeze(0)),
        to_pil_image(resize_back(out["roughness"]).squeeze(0)),
        to_pil_image(resize_back(out["metalness"]).squeeze(0)),
        to_pil_image(resize_back(rendered).squeeze(0)),
    )

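# Gradio UI: input image and example galleries on the left, the four predicted
# channels and the relit preview on the right.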
with gr.Blocks(title="Chord") as demo:

    gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**")
    gr.Markdown("Upload an image or select an example to estimate PBR channels.")
    
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image", height=512)

            gr.Markdown("### Example Inputs — Generated Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_1,
                inputs=[input_img],
                label="Examples (Generated Textures)"
            )

            gr.Markdown("### Example Inputs — In The Wild Photographs")
            gr.Examples(
                examples=EXAMPLES_USECASE_2,
                inputs=[input_img],
                label="Examples (In The Wild Photographs)"
            )

            gr.Markdown("### Example Inputs — Specular Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_3,
                inputs=[input_img],
                label="Examples (Specular Textures)"
            )

            run_button = gr.Button("Run Estimation")

        with gr.Column():
            gr.Markdown("### Predicted Channels")
            basecolor_out = gr.Image(label="Basecolor", height=512)
            normal_out = gr.Image(label="Normal", height=512)
            roughness_out = gr.Image(label="Roughness", height=512)
            metallic_out = gr.Image(label="Metalness", height=512)

            gr.Markdown("### Relit Output")
            render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)

    run_button.click(
        inference,
        inputs=[input_img],
        outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
    )

if __name__ == "__main__":
    demo.launch()