# Enable OpenEXR support in OpenCV. This must be set before cv2 is imported
# anywhere (presumably inside the image utilities imported below) — OpenCV only
# honors OPENCV_IO_ENABLE_OPENEXR at import time. TODO confirm cv2 is first
# imported via src.util.image_util.
import os
import sys
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Make the repository root importable so the `olbedo` and `src.*` packages
# resolve when this script is launched from its own directory (two levels down).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import gradio as gr
import numpy as np
import torch
from olbedo import OlbedoIIDOutput, OlbedoIIDPipeline
from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
from olbedo.util.image_util import float2int
from src.util.seeding import seed_all
import logging
from huggingface_hub import snapshot_download
| |
|
# Fixed seed so repeated runs of the demo produce identical outputs.
seed = 1234
seed_all(seed)

# Prefer GPU; fall back to CPU with a warning since diffusion inference is slow there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    logging.warning("CUDA is not available. Running on CPU will be slow.")
| |
|
# Checkpoints available on the GDAOSU/olbedo Hugging Face repo.
available_models = [
    "marigold_appearance/finetuned",
    "marigold_appearance/pretrained",
    "marigold_lighting/finetuned",
    "marigold_lighting/pretrained",
    "rgbx/finetuned",
    "rgbx/pretrained",
]

# Cache of instantiated pipelines, keyed by checkpoint name.
loaded_models = {}

# Prompt choices; only consumed by the rgbx checkpoints.
prompts = [
    "Albedo (diffuse basecolor)",
    "Camera-space Normal",
    "Roughness",
    "Metallicness",
    "Irradiance (diffuse lighting)",
]
| |
|
def get_demo():
    """Build and return the Gradio Blocks UI for the Olbedo intrinsic image demo.

    Returns:
        gr.Blocks: the assembled demo; caller is expected to ``.queue()`` and
        ``.launch()`` it.
    """

    def load_model(selected_model):
        """Download (if needed) and instantiate the pipeline for ``selected_model``.

        Results are memoized in the module-level ``loaded_models`` cache so each
        checkpoint is only downloaded/constructed once per process.
        """
        if selected_model in loaded_models:
            return loaded_models[selected_model]

        # Fetch only the files belonging to the requested checkpoint.
        local_dir = snapshot_download(
            repo_id="GDAOSU/olbedo",
            allow_patterns=f"{selected_model}/*",
        )
        model_path = os.path.join(local_dir, selected_model)

        pipe = OlbedoIIDPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
        ).to(device)

        # rgbx checkpoints run in prompt-conditioned mode.
        if "rgbx" in selected_model:
            pipe.mode = "rgbx"

        loaded_models[selected_model] = pipe
        return pipe

    def callback(photo, inference_step, selected_model, selected_prompt, processing_res):
        """Run one inference pass on the uploaded image.

        Args:
            photo: filepath of the uploaded image (``gr.Image(type="filepath")``).
            inference_step: number of denoising steps.
            selected_model: checkpoint name from ``available_models``.
            selected_prompt: prompt string; only used by rgbx checkpoints.
            processing_res: processing resolution forwarded to the pipeline.

        Returns:
            (png_path, npy_path, generated_image): downloadable PNG/NPY paths
            and the HWC uint8 array for display.

        Raises:
            gr.Error: if no image was uploaded.
        """
        if photo is None:
            raise gr.Error("Please upload an image first.")

        # The prompt is only meaningful for rgbx checkpoints.
        prompt = selected_prompt if "rgbx" in selected_model else None

        pipe = load_model(selected_model)

        # Seeded generator so the demo output is reproducible.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        img = read_img_from_file(photo)
        if img.ndim == 3:
            img = img_hwc2chw(img)
        # HDR inputs (e.g. EXR) are linear; convert to sRGB before inference.
        if is_hdr(photo):
            img = img_linear2srgb(img)
        # Drop an alpha channel if present.
        if img.shape[0] == 4:
            img = img[:3, :, :]
        rgb_float = torch.from_numpy(img).float()
        input_image = float2int(rgb_float).unsqueeze(0)

        if "rgbx" in selected_model:
            pipe.prompt = prompt

        pipe_out: OlbedoIIDOutput = pipe(
            input_image,
            denoising_steps=inference_step,
            ensemble_size=1,
            processing_res=processing_res,
            match_input_res=True,
            batch_size=0,
            show_progress_bar=False,
            resample_method="bilinear",
            generator=generator,
        )

        target_pred = pipe_out["albedo"].array
        # Scalar maps (roughness/metallicness) are single-channel; tile to 3
        # channels so they display as a grayscale RGB image.
        if prompt is not None and ("Metallicness" in prompt or "Roughness" in prompt):
            target_pred = np.repeat(target_pred[0:1, :], 3, axis=0)
        generated_image = target_pred.transpose(1, 2, 0)
        if generated_image.dtype != np.uint8:
            generated_image = np.clip(generated_image, 0, 1)
            generated_image = (generated_image * 255).astype(np.uint8)

        # Write downloadable artifacts into a fresh per-call temp directory:
        # portable (no hard-coded /tmp) and avoids concurrent sessions
        # overwriting each other's fixed-name files.
        import tempfile
        out_dir = tempfile.mkdtemp(prefix="olbedo_")

        npy_path = os.path.join(out_dir, "target_pred.npy")
        np.save(npy_path, target_pred)

        from PIL import Image
        png_path = os.path.join(out_dir, "target_pred.png")
        Image.fromarray(generated_image).save(png_path)

        return png_path, npy_path, generated_image

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Olbedo: An Albedo and Shading Aerial Dataset for Large-Scale Outdoor Environments")
        with gr.Row():
            # Left column: input image + run controls.
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.Image(label="Photo", type="filepath")

                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=4,
                    )
                    processing_res = gr.Number(value=1000, label="Processing Resolution (processing_res)", precision=0)

                gr.Markdown("### Select Model")
                model_selector = gr.Dropdown(
                    label="Checkpoint",
                    choices=available_models,
                    value="rgbx/finetuned",
                )

                gr.Markdown("### Select Prompt (only for rgbx models)")
                prompt_selector = gr.Dropdown(
                    label="Prompts",
                    choices=prompts,
                    value=prompts[0],
                )

            # Right column: prediction display and download links.
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_image = gr.Image(label="Output Image", interactive=False)
                result_png = gr.File(label="Download Generated Image (.png)")
                result_npy = gr.File(label="Download Target Albedo (.npy)")

        inputs = [
            photo,
            inference_step,
            model_selector,
            prompt_selector,
            processing_res,
        ]
        outputs = [result_png, result_npy, result_image]
        run_button.click(fn=callback, inputs=inputs, outputs=outputs, queue=True)

    return block
| |
|
| |
|
if __name__ == "__main__":
    # Build the UI, cap the request queue at one pending job, and serve it.
    app = get_demo()
    app.queue(max_size=1)
    app.launch()