# Bootstrap that must run BEFORE the project imports below:
# - OPENCV_IO_ENABLE_OPENEXR is read by OpenCV at import time (presumably
#   pulled in via src.util.image_util — verify), so it has to be set first
#   for HDR/EXR inputs to load.
# - The repo root lives two directories up from this file; prepend it so the
#   `marigold` and `src` packages resolve.
import os
import sys
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
|
|
|
import logging
import tempfile

import gradio as gr
import numpy as np
import torch

from marigold import MarigoldIIDOutput, MarigoldIIDPipeline
from marigold.util.image_util import float2int
from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
from src.util.seeding import seed_all
|
|
# Hugging Face repository that hosts every demo checkpoint.
HF_REPO_ID = "GDAOSU/olbedo"

# Fixed global seed so repeated runs produce identical generations.
seed = 1234
seed_all(seed)

# Prefer the GPU; fall back to CPU with a warning, since diffusion inference
# on CPU is very slow.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    logging.warning("CUDA is not available. Running on CPU will be slow.")
else:
    device = torch.device("cuda")
|
|
# Checkpoint subfolders available on the HF repo, named "<family>/<variant>".
available_models = [
    "marigold_appearance/finetuned",
    "marigold_appearance/pretrained",
    "marigold_lighting/finetuned",
    "marigold_lighting/pretrained",
    "rgbx/finetuned",
    "rgbx/pretrained"
]


# Cache of already-instantiated pipelines, keyed by checkpoint subfolder,
# so switching back to a model does not re-download/reload it.
loaded_models = {}


# Text prompts understood by the rgbx checkpoints (the other models ignore
# the prompt selector entirely).
prompts = ["Albedo (diffuse basecolor)", "Camera-space Normal","Roughness", "Metallicness","Irradiance (diffuse lighting)"]
|
|
def get_demo():
    """Build and return the Gradio Blocks UI for the albedo demo.

    The UI lets the user upload a photo, pick a checkpoint and (for rgbx
    models) a text prompt, runs the selected MarigoldIIDPipeline, and offers
    the prediction as an on-screen preview plus downloadable .png/.npy files.
    """

    def callback(
        photo,
        inference_step,
        selected_model,
        selected_prompt,
        processing_res,
    ):
        """Run one inference request from the UI.

        Args:
            photo: Filepath of the uploaded image (gr.Image type="filepath").
            inference_step: Number of denoising steps.
            selected_model: Checkpoint subfolder on the HF repo.
            selected_prompt: Text prompt; only used by rgbx checkpoints.
            processing_res: Processing resolution forwarded to the pipeline
                (0 lets the pipeline use its default).

        Returns:
            (png_path, npy_path, preview_array) matching the Gradio outputs.
        """
        # rgbx checkpoints are prompt-conditioned; all others ignore prompts.
        mode = "rgbx" if "rgbx" in selected_model else "other"

        # Lazily load and cache one pipeline per checkpoint.
        if selected_model not in loaded_models:
            pipe = MarigoldIIDPipeline.from_pretrained(
                HF_REPO_ID,
                subfolder=selected_model,
                torch_dtype=torch.float32,
            ).to(device)
            pipe.mode = mode
            loaded_models[selected_model] = pipe
        else:
            pipe = loaded_models[selected_model]

        # The prompt can change between calls on a cached pipeline, so
        # refresh it on every request.
        if mode == "rgbx":
            pipe.prompt = selected_prompt

        # Fixed seed per request for reproducible generations.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        img = read_img_from_file(photo)
        if img.ndim == 3:
            img = img_hwc2chw(img)
        if is_hdr(photo):
            # HDR inputs are linear; convert to sRGB before inference.
            img = img_linear2srgb(img)
        if img.shape[0] == 4:
            # Drop the alpha channel of RGBA inputs.
            img = img[:3, :, :]
        rgb_float = torch.from_numpy(img).float()
        input_image = float2int(rgb_float).unsqueeze(0)  # add batch dim

        pipe_out: MarigoldIIDOutput = pipe(
            input_image,
            denoising_steps=inference_step,
            ensemble_size=1,
            processing_res=processing_res,
            match_input_res=True,
            batch_size=0,  # 0 = let the pipeline choose its batch size
            show_progress_bar=False,
            resample_method="bilinear",
            generator=generator,
        )
        target_pred = pipe_out["albedo"].array
        if mode == "rgbx" and ("Metallicness" in selected_prompt or "Roughness" in selected_prompt):
            # Scalar maps come back single-channel; tile the channel to 3 so
            # they display as a grayscale RGB image.
            target_pred = np.repeat(target_pred[0:1, :], 3, axis=0)

        generated_image = target_pred.transpose(1, 2, 0)  # CHW -> HWC
        if generated_image.dtype != np.uint8:
            generated_image = np.clip(generated_image, 0, 1)
            generated_image = (generated_image * 255).astype(np.uint8)

        # Write results to unique temp files (portable across OSes, and
        # concurrent/successive requests cannot clobber each other's output,
        # which fixed names under a hard-coded /tmp did).
        npy_fd, npy_path = tempfile.mkstemp(prefix="target_pred_", suffix=".npy")
        os.close(npy_fd)
        np.save(npy_path, target_pred)

        from PIL import Image
        png_fd, png_path = tempfile.mkstemp(prefix="target_pred_", suffix=".png")
        os.close(png_fd)
        Image.fromarray(generated_image).save(png_path)

        return png_path, npy_path, generated_image

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## OSU albedo demo")
        with gr.Row():
            # Left column: input image, parameters, model/prompt selection.
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.Image(label="Photo", type="filepath")

                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=4,
                    )
                    processing_res = gr.Number(value=0, label="Processing Resolution (processing_res)", precision=0)

                gr.Markdown("### Select Model")
                model_selector = gr.Dropdown(
                    label="Checkpoint",
                    choices=available_models,
                    value=available_models[0],
                )

                gr.Markdown("### Select Prompt (only for rgbx models)")
                prompt_selector = gr.Dropdown(
                    label="Prompts",
                    choices=prompts,
                    value=prompts[0],
                )

            # Right column: preview image and downloadable artifacts.
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_image = gr.Image(label="Output Image", interactive=False)
                result_png = gr.File(label="Download Generated Image (.png)")
                result_npy = gr.File(label="Download Target Albedo (.npy)")

        inputs = [
            photo,
            inference_step,
            model_selector,
            prompt_selector,
            processing_res,
        ]
        outputs = [result_png, result_npy, result_image]
        run_button.click(fn=callback, inputs=inputs, outputs=outputs, queue=True)

    return block
|
|
|
|
if __name__ == "__main__":
    demo = get_demo()
    # max_size=1: allow at most one queued request at a time — presumably to
    # keep GPU memory use bounded for the cached pipelines (verify intent).
    demo.queue(max_size=1)
    demo.launch()