# Hugging Face Space "Chord" — PBR material estimation demo (runs on ZeroGPU hardware).
import gradio as gr
import os
import numpy as np
from PIL import Image
import torch
import copy
from omegaconf import OmegaConf
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image
from huggingface_hub import hf_hub_download, login
import spaces
from chord import ChordModel
from chord.module import make
from chord.util import get_positions, rgb_to_srgb
from chord.io import load_torch_file
# Example galleries for the three demo use-cases. Each entry is a one-element
# list because gr.Examples expects one value per input component.
EXAMPLES_USECASE_1 = [
    [f"examples/generated/{f}"]
    for f in sorted(os.listdir("examples/generated"))
]
EXAMPLES_USECASE_2 = [
    [f"examples/in_the_wild/{f}"]
    for f in sorted(os.listdir("examples/in_the_wild"))
]
EXAMPLES_USECASE_3 = [
    [f"examples/specular/{f}"]
    for f in sorted(os.listdir("examples/specular"))
]
# Lazily-initialized model singleton; populated on the first call to inference().
MODEL_OBJ = None
# Module-level side effects: authenticate against the Hugging Face Hub
# (raises KeyError if the HF_TOKEN env var is unset) and download the
# Chord checkpoint to the local HF cache.
login(token=os.environ["HF_TOKEN"])
MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.safetensors")
def load_model(ckpt_path):
    """Construct a ChordModel from config/chord.yaml, load the given
    checkpoint, and return it in eval mode on CUDA when available
    (CPU otherwise)."""
    print("Loading model from:", ckpt_path)
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    cfg = OmegaConf.load("config/chord.yaml")
    net = ChordModel(cfg)
    net.load_state_dict(load_torch_file(ckpt_path))
    net.eval()
    net.to(target)
    return net
def run_model(model, img: Image.Image):
    """Run Chord on a single image and return the raw prediction.

    The image is converted to a float tensor scaled to [0, 1], resized to the
    model's fixed 1024x1024 working resolution, and evaluated under
    no-grad + autocast to save memory and time.

    Args:
        model: a loaded ChordModel (see load_model); its device is inferred
            from its parameters.
        img: input PIL image of any size (resized internally).

    Returns:
        The model's output for the single-image batch
        (per-channel tensors at 1024x1024 — see callers).
    """
    device = next(model.parameters()).device
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    image = to_tensor(img).to(device)
    x = v2.Resize(size=(1024, 1024), antialias=True)(image).unsqueeze(0)
    # Fix: the context-manager values were bound with `as no_grad` / `as amp`
    # but never used — drop the dead bindings.
    with torch.no_grad(), torch.autocast(device_type=device.type):
        output = model(x)
    return output
def relit(model, maps):
    """Render the predicted material maps under a single centered point light.

    NOTE(review): mutates `maps` in place — it adds a 'metallic' key and
    permutes every tensor to channels-last. Callers should pass a copy
    (inference() deep-copies before calling).

    Returns an sRGB render tensor (BxCxHxW) clamped to [0, 1].
    """
    # The renderer expects a 'metallic' entry; fall back to zeros when the
    # prediction has no 'metalness'. NOTE(review): the fallback copies
    # basecolor's shape (3 channels) — confirm the renderer accepts that
    # for a scalar property.
    maps['metallic'] = maps.get('metalness', torch.zeros_like(maps['basecolor']))
    device = next(model.parameters()).device
    h, w = maps["basecolor"].shape[-2:]
    # Light and camera both placed on the +Z axis at distance 10 from the plane.
    light = make("point-light", {"position": [0, 0, 10]}).to(device)
    pos = get_positions(h, w, 10).to(device)
    camera = torch.tensor([0, 0, 10.0]).to(device)
    # Ensure each map has a batch dim, then convert to the channels-last
    # layout the renderer consumes.
    for key in maps:
        if maps[key].dim() == 3:
            maps[key] = maps[key].unsqueeze(0)
        maps[key] = maps[key].permute(0,2,3,1) # BxCxHxW -> BxHxWxC
    rgb = model.model.compute_render(maps, camera, pos, light).squeeze(0).permute(0,3,1,2) # GxBxHxWxC -> BxCxHxW
    # Linear RGB -> display sRGB, clamped to the valid range.
    return torch.clamp(rgb_to_srgb(rgb), 0, 1)
@spaces.GPU
def inference(img):
    """Gradio callback: estimate PBR channels and a relit render for one image.

    Args:
        img: input PIL image, or None when the input component is empty.

    Returns:
        A 5-tuple of PIL images (basecolor, normal, roughness, metalness,
        relit render), each resized back to the input resolution — or five
        Nones when no image was provided.
    """
    global MODEL_OBJ
    # Fix: guard the empty-input case *before* the lazy model load, so
    # clearing the input no longer triggers a pointless checkpoint load.
    if img is None:
        return None, None, None, None, None
    # Lazily (re)load the model; the checkpoint path is remembered on the
    # object so a changed path forces a reload.
    if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != MODEL_CKPT_PATH:
        MODEL_OBJ = load_model(MODEL_CKPT_PATH)
        MODEL_OBJ._ckpt = MODEL_CKPT_PATH  # store path inside object
    ori_w, ori_h = img.size  # PIL .size is (width, height)
    out = run_model(MODEL_OBJ, img)
    # relit() mutates its argument (adds a key, permutes tensors), so render
    # from a deep copy and keep `out` intact for the channel previews.
    maps = copy.deepcopy(out)
    rendered = relit(MODEL_OBJ, maps)
    resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True)
    return (
        to_pil_image(resize_back(out["basecolor"]).squeeze(0)),
        to_pil_image(resize_back(out["normal"]).squeeze(0)),
        to_pil_image(resize_back(out["roughness"]).squeeze(0)),
        to_pil_image(resize_back(out["metalness"]).squeeze(0)),
        to_pil_image(resize_back(rendered).squeeze(0)),
    )
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Chord") as demo:
    gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**")
    gr.Markdown("Upload an image or select an example to estimate PBR channels.")
    with gr.Row():
        # Left column: input image plus three clickable example galleries.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image", height=512)
            gr.Markdown("### Example Inputs — Generated Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_1,
                inputs=[input_img],
                label="Examples (Generated Textures)"
            )
            gr.Markdown("### Example Inputs — In The Wild Photographs")
            gr.Examples(
                examples=EXAMPLES_USECASE_2,
                inputs=[input_img],
                label="Examples (In The Wild Photographs)"
            )
            gr.Markdown("### Example Inputs — Specular Textures")
            gr.Examples(
                examples=EXAMPLES_USECASE_3,
                inputs=[input_img],
                label="Examples (Specular Textures)"
            )
            run_button = gr.Button("Run Estimation")
        # Right column: predicted PBR channels and the relit preview.
        with gr.Column():
            gr.Markdown("### Predicted Channels")
            basecolor_out = gr.Image(label="Basecolor", height=512)
            normal_out = gr.Image(label="Normal", height=512)
            roughness_out = gr.Image(label="Roughness", height=512)
            metallic_out = gr.Image(label="Metalness", height=512)
            gr.Markdown("### Relit Output")
            render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)
    # Wire the button to the GPU-decorated inference callback; output order
    # matches the tuple returned by inference().
    run_button.click(
        inference,
        inputs=[input_img],
        outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out]
    )

if __name__ == "__main__":
    demo.launch()