# Spaces: Running on Zero
# File size: 5,108 Bytes
# Revision: a846205
import gradio as gr
import os
import numpy as np
from PIL import Image
import torch
import copy
from omegaconf import OmegaConf
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image
from chord import ChordModel
from chord.module import make
from chord.util import get_positions, rgb_to_srgb
def _list_examples(folder):
    """Return one single-input example row per visible file in *folder*.

    Rows are ``[f"{folder}/{name}"]`` in sorted name order, the shape that
    ``gr.Examples`` expects for a single input component. Hidden entries
    (names starting with ".", e.g. ``.gitkeep``) and sub-directories are
    skipped, since they cannot be loaded as example images.

    Raises:
        FileNotFoundError: if *folder* does not exist (propagated from
            ``os.listdir``), so a broken checkout fails loudly at startup.
    """
    return [
        [f"{folder}/{name}"]
        for name in sorted(os.listdir(folder))
        if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
    ]


# Example inputs for the three example galleries shown in the UI.
EXAMPLES_USECASE_1 = _list_examples("examples/generated")
EXAMPLES_USECASE_2 = _list_examples("examples/in_the_wild")
EXAMPLES_USECASE_3 = _list_examples("examples/specular")

# Model shared across requests; loaded lazily (and reloaded when the
# checkpoint path changes) by inference().
MODEL_OBJ = None
def load_model(ckpt_path, config_path="config/chord.yaml"):
    """Build a ChordModel from its YAML config and load checkpoint weights.

    Args:
        ckpt_path: Path of a checkpoint file whose loaded object has a
            ``"state_dict"`` entry.
        config_path: Path of the OmegaConf YAML used to construct the model.
            Defaults to the previously hard-coded ``config/chord.yaml``.

    Returns:
        The model in eval mode, moved to CUDA when available, else to CPU.
    """
    print("Loading model from:", ckpt_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = OmegaConf.load(config_path)
    model = ChordModel(config)
    # SECURITY(review): weights_only=False unpickles arbitrary objects, which
    # can execute code on load. In this app ckpt_path comes from a Textbox any
    # visitor can edit, so only trusted checkpoint files should reach here.
    # weights_only=True would be safer but may reject checkpoints that pickle
    # non-tensor objects -- confirm against the real checkpoint before changing.
    # Load on CPU first; the model is moved to the target device below.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.to(device)
    return model
def run_model(model, img: Image.Image):
    """Run *model* on *img* after resizing it to 1024x1024.

    Args:
        model: Module; the device of its first parameter decides where the
            input tensor is placed.
        img: Input PIL image. It is converted to RGB first if it is in another
            mode (the Gradio Image component already delivers RGB by default).

    Returns:
        The model's output dict, with the un-resized input tensor (float32 in
        [0, 1], CxHxW) added under ``"input"``.
    """
    import contextlib

    # NOTE(review): assumes the model expects 3-channel input; this guards
    # direct callers passing L/RGBA images, which would yield 1/4 channels.
    if img.mode != "RGB":
        img = img.convert("RGB")

    device = next(model.parameters()).device
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    image = to_tensor(img).to(device)
    x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)

    # Mixed precision only applies on CUDA. The previous unconditional
    # torch.autocast(device_type="cuda") already disabled itself without CUDA,
    # but it emitted a warning on every call; nullcontext keeps the same numerics.
    amp = torch.autocast(device_type="cuda") if device.type == "cuda" else contextlib.nullcontext()
    with torch.no_grad(), amp:
        output = model(x)
    output.update({"input": image})
    return output
def relit(model, maps):
    """Render predicted material maps under a fixed point light.

    The light is created with position [0, 0, 10] and the camera is placed at
    [0, 0, 10]. Every entry of *maps* is expected to be a CxHxW or BxCxHxW
    tensor; each is converted to BxHxWxC before being handed to
    ``model.model.compute_render``. The renderer receives a ``"metallic"``
    entry: the ``"metalness"`` map if present, otherwise zeros shaped like
    ``"basecolor"``.

    The caller's dict is not modified; a new dict is built for rendering.

    Returns:
        The render passed through ``rgb_to_srgb`` and clamped to [0, 1], as a
        BxCxHxW tensor.
    """
    device = next(model.parameters()).device
    h, w = maps["basecolor"].shape[-2:]
    light = make("point-light", {"position": [0, 0, 10]}).to(device)
    pos = get_positions(h, w, 10).to(device)
    camera = torch.tensor([0, 0, 10.0]).to(device)

    # Build a fresh dict rather than rebinding entries of the caller's dict.
    scene = {}
    for key, tensor in maps.items():
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        scene[key] = tensor.permute(0, 2, 3, 1)  # BxCxHxW -> BxHxWxC

    # Fallback zeros are only allocated when no metalness map was predicted.
    if "metalness" in scene:
        scene["metallic"] = scene["metalness"]
    else:
        scene["metallic"] = torch.zeros_like(scene["basecolor"])

    # Pure inference: never build an autograd graph (e.g. through any
    # grad-requiring parameters of the light module).
    with torch.no_grad():
        rgb = model.model.compute_render(scene, camera, pos, light)
        rgb = rgb.squeeze(0).permute(0, 3, 1, 2)  # GxBxHxWxC -> BxCxHxW
        return torch.clamp(rgb_to_srgb(rgb), 0, 1)
def _tensor_to_pil(tensor, size):
    """Resize a (1xCxHxW or CxHxW) float tensor to *size* = (h, w) and return a PIL image.

    Values are clamped to [0, 1] first: to_pil_image scales floats by 255 and
    casts to uint8 without clamping, so out-of-range values would wrap around.
    """
    resized = v2.Resize(size=size, antialias=True)(tensor)
    return to_pil_image(resized.clamp(0, 1).squeeze(0))


def inference(img, ckpt_path):
    """Gradio callback: estimate material maps for *img* and relight them.

    Args:
        img: Input PIL image, or None when nothing was uploaded.
        ckpt_path: Checkpoint to use. The cached model is reloaded whenever
            this differs from the checkpoint it was loaded from.

    Returns:
        A 5-tuple of PIL images (basecolor, normal, roughness, metalness,
        relit render), each resized back to the input resolution. All five
        are None when *img* is None. The metalness slot is None if the model
        did not predict a metalness map.
    """
    global MODEL_OBJ
    # Return early so clicking "Run" with no image does not trigger a model load.
    if img is None:
        return None, None, None, None, None

    # SECURITY(review): ckpt_path is user-editable in the UI and is unpickled
    # by load_model (torch.load with weights_only=False). Consider restricting
    # it to an allow-list of trusted checkpoints.
    if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != ckpt_path:
        MODEL_OBJ = load_model(ckpt_path)
        MODEL_OBJ._ckpt = ckpt_path  # store path inside object

    ori_w, ori_h = img.size  # PIL reports (width, height)
    size = (ori_h, ori_w)
    out = run_model(MODEL_OBJ, img)
    # relit() may modify the dict it receives; render from a deep copy so
    # `out` stays untouched for the channel outputs below.
    rendered = relit(MODEL_OBJ, copy.deepcopy(out))

    # relit() treats a missing metalness map as optional; mirror that here
    # instead of raising KeyError.
    metalness = out.get("metalness")
    return (
        _tensor_to_pil(out["basecolor"], size),
        _tensor_to_pil(out["normal"], size),
        _tensor_to_pil(out["roughness"], size),
        None if metalness is None else _tensor_to_pil(metalness, size),
        _tensor_to_pil(rendered, size),
    )
with gr.Blocks(title="Chord") as demo:
    gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture images**")

    # NOTE(review): any visitor can edit this path, and it is forwarded to
    # inference(), which loads the checkpoint it names.
    ckpt_path = gr.Textbox(
        label="Model Checkpoint Path",
        value="chord_v1.ckpt",
        placeholder="Path to your model checkpoint",
    )
    gr.Markdown("Upload an image or select an example to estimate PBR channels and render the result under custom lighting.")

    with gr.Row():
        # Left column: the input image, three example galleries, and the run button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image", height=512)
            example_sections = (
                ("### Example Inputs — Generated Textures", EXAMPLES_USECASE_1, "Examples (Generated Textures)"),
                ("### Example Inputs — In The Wild Photographs", EXAMPLES_USECASE_2, "Examples (In The Wild Photographs)"),
                ("### Example Inputs — Specular Textures", EXAMPLES_USECASE_3, "Examples (Specular Textures)"),
            )
            for heading, rows, caption in example_sections:
                gr.Markdown(heading)
                gr.Examples(examples=rows, inputs=[input_img], label=caption)
            run_button = gr.Button("Run Estimation")

        # Right column: one image per predicted channel, then the relit render.
        with gr.Column():
            gr.Markdown("### Predicted Channels")
            basecolor_out, normal_out, roughness_out, metallic_out = (
                gr.Image(label=channel, height=512)
                for channel in ("Basecolor", "Normal", "Roughness", "Metalness")
            )
            gr.Markdown("### Relit Output")
            render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)

    # Output order must match the 5-tuple returned by inference().
    run_button.click(
        inference,
        inputs=[input_img, ckpt_path],
        outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out],
    )
def _main():
    """Start the Gradio server hosting the demo UI."""
    demo.launch()


if __name__ == "__main__":
    _main()
# (end of listing)