"""Gradio demo for Chord: PBR material estimation from a single texture image.

The app downloads the Chord checkpoint from the Hugging Face Hub, predicts
basecolor / normal / roughness / metalness maps for an uploaded image, and
shows a relit preview rendered from those maps under a point light.
"""

import copy
import os

import gradio as gr
import numpy as np
from PIL import Image
import torch
from omegaconf import OmegaConf
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image
from huggingface_hub import hf_hub_download, login
import spaces

from chord import ChordModel
from chord.module import make
from chord.util import get_positions, rgb_to_srgb
from chord.io import load_torch_file


def _list_examples(folder):
    """Return Gradio example rows (one file path per row) for *folder*.

    Files are listed in sorted name order. Raises ``FileNotFoundError`` if
    *folder* does not exist (same as the underlying ``os.listdir``).
    """
    return [[f"{folder}/{name}"] for name in sorted(os.listdir(folder))]


EXAMPLES_USECASE_1 = _list_examples("examples/generated")
EXAMPLES_USECASE_2 = _list_examples("examples/in_the_wild")
EXAMPLES_USECASE_3 = _list_examples("examples/specular")

# Spatial size every input is resized to before being fed to the model.
MODEL_INPUT_SIZE = (1024, 1024)

# Lazily-created model instance; populated on the first call to ``inference``.
MODEL_OBJ = None

# NOTE: raises KeyError at import time when HF_TOKEN is not set.
login(token=os.environ["HF_TOKEN"])
MODEL_CKPT_PATH = hf_hub_download(
    repo_id="Ubisoft/ubisoft-laforge-chord",
    filename="chord_v1.safetensors",
)


def load_model(ckpt_path):
    """Build a ``ChordModel`` from ``config/chord.yaml`` and load its weights.

    Args:
        ckpt_path: Path to the checkpoint file read by ``load_torch_file``.

    Returns:
        The model in eval mode, moved to CUDA when available, else CPU.
    """
    print("Loading model from:", ckpt_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = OmegaConf.load("config/chord.yaml")
    model = ChordModel(config)
    state_dict = load_torch_file(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model


def run_model(model, img: Image.Image):
    """Run the model on a single PIL image.

    The image is converted to a float32 tensor scaled by ``v2.ToDtype``,
    moved to the model's device, resized to ``MODEL_INPUT_SIZE`` and given a
    leading batch dimension. The forward pass runs without gradient tracking
    and under autocast for the model's device type.

    Args:
        model: A module whose parameters determine the target device.
        img: Input image.

    Returns:
        Whatever ``model(x)`` returns; ``inference`` indexes it by the keys
        "basecolor", "normal", "roughness" and "metalness".
    """
    device = next(model.parameters()).device
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    image = to_tensor(img).to(device)
    x = v2.Resize(size=MODEL_INPUT_SIZE, antialias=True)(image).unsqueeze(0)
    with torch.no_grad(), torch.autocast(device_type=device.type):
        output = model(x)
    return output


def relit(model, maps):
    """Render the predicted maps under a point light and return an sRGB image.

    NOTE: *maps* is modified in place (a "metallic" entry is added and every
    entry is re-laid-out channels-last); pass a copy if the original layout
    is still needed.

    Args:
        model: Model exposing ``model.model.compute_render``.
        maps: Mapping of map name to tensor; must contain "basecolor".

    Returns:
        The rendered image after ``rgb_to_srgb``, clamped to [0, 1].
    """
    # Expose the metalness prediction under the "metallic" key (presumably the
    # name the renderer reads -- confirm against ``compute_render``). The zero
    # fallback is only allocated when no metalness map is present.
    if "metalness" in maps:
        maps["metallic"] = maps["metalness"]
    else:
        maps["metallic"] = torch.zeros_like(maps["basecolor"])

    device = next(model.parameters()).device
    h, w = maps["basecolor"].shape[-2:]
    light = make("point-light", {"position": [0, 0, 10]}).to(device)
    pos = get_positions(h, w, 10).to(device)
    camera = torch.tensor([0, 0, 10.0]).to(device)

    # Only values are reassigned (no keys added/removed), so iterating the
    # mapping while updating it is safe.
    for key in maps:
        if maps[key].dim() == 3:
            maps[key] = maps[key].unsqueeze(0)
        maps[key] = maps[key].permute(0, 2, 3, 1)  # BxCxHxW -> BxHxWxC

    rendered = model.model.compute_render(maps, camera, pos, light)
    rgb = rendered.squeeze(0).permute(0, 3, 1, 2)  # GxBxHxWxC -> BxCxHxW
    return torch.clamp(rgb_to_srgb(rgb), 0, 1)


@spaces.GPU
def inference(img):
    """Gradio callback: estimate PBR maps for *img* and a relit preview.

    Args:
        img: PIL image from the input component, or ``None`` if empty.

    Returns:
        A 5-tuple of PIL images (basecolor, normal, roughness, metalness,
        relit render) resized back to the input resolution, or five ``None``
        values when no image was provided.
    """
    global MODEL_OBJ

    # Bail out before touching the model so an empty click does not trigger
    # an expensive checkpoint load.
    if img is None:
        return None, None, None, None, None

    if MODEL_OBJ is None or getattr(MODEL_OBJ, "_ckpt", None) != MODEL_CKPT_PATH:
        MODEL_OBJ = load_model(MODEL_CKPT_PATH)
        MODEL_OBJ._ckpt = MODEL_CKPT_PATH  # store path inside object

    ori_w, ori_h = img.size  # PIL reports (width, height)
    out = run_model(MODEL_OBJ, img)

    # ``relit`` mutates its argument, so render from a copy and keep ``out``
    # in its original layout for the per-channel outputs below.
    maps = copy.deepcopy(out)
    rendered = relit(MODEL_OBJ, maps)

    resize_back = v2.Resize(size=(ori_h, ori_w), antialias=True)

    def to_image(tensor):
        # Restore the input resolution and drop the batch dimension.
        return to_pil_image(resize_back(tensor).squeeze(0))

    return (
        to_image(out["basecolor"]),
        to_image(out["normal"]),
        to_image(out["roughness"]),
        to_image(out["metalness"]),
        to_image(rendered),
    )


with gr.Blocks(title="Chord") as demo:
    gr.Markdown("# **Chord: Chain of Rendering Decomposition for PBR Material Estimation from Generated Texture Images**")
    gr.Markdown("Upload an image or select an example to estimate PBR channels.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image", height=512)

            # One titled example gallery per use case.
            for title, examples in (
                ("Generated Textures", EXAMPLES_USECASE_1),
                ("In The Wild Photographs", EXAMPLES_USECASE_2),
                ("Specular Textures", EXAMPLES_USECASE_3),
            ):
                gr.Markdown(f"### Example Inputs — {title}")
                gr.Examples(
                    examples=examples,
                    inputs=[input_img],
                    label=f"Examples ({title})",
                )

            run_button = gr.Button("Run Estimation")

        with gr.Column():
            gr.Markdown("### Predicted Channels")
            basecolor_out = gr.Image(label="Basecolor", height=512)
            normal_out = gr.Image(label="Normal", height=512)
            roughness_out = gr.Image(label="Roughness", height=512)
            metallic_out = gr.Image(label="Metalness", height=512)

            gr.Markdown("### Relit Output")
            render_out = gr.Image(label="Relit Image (Centered Point Light)", height=512)

    run_button.click(
        inference,
        inputs=[input_img],
        outputs=[basecolor_out, normal_out, roughness_out, metallic_out, render_out],
    )


if __name__ == "__main__":
    demo.launch()