import logging
import os
import sys
import tempfile

# OpenEXR support in OpenCV must be enabled via env var BEFORE cv2 is imported
# anywhere (the project image utils import cv2).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Make the repository root importable so `olbedo` / `src` resolve when this
# script is run directly from its own directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from olbedo import OlbedoIIDOutput, OlbedoIIDPipeline
from olbedo.util.image_util import float2int
from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
from src.util.seeding import seed_all

seed = 1234
seed_all(seed)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    logging.warning("CUDA is not available. Running on CPU will be slow.")

available_models = [
    "marigold_appearance/finetuned",
    "marigold_appearance/pretrained",
    "marigold_lighting/finetuned",
    "marigold_lighting/pretrained",
    "rgbx/finetuned",
    "rgbx/pretrained",
]
# Cache of already-instantiated pipelines, keyed by checkpoint name, so each
# model is downloaded and moved to `device` at most once per process.
loaded_models = {}
prompts = [
    "Albedo (diffuse basecolor)",
    "Camera-space Normal",
    "Roughness",
    "Metallicness",
    "Irradiance (diffuse lighting)",
]


def get_demo():
    """Build and return the Gradio Blocks UI for the Olbedo IID demo."""

    def load_model(selected_model):
        """Return the cached pipeline for *selected_model*, downloading the
        checkpoint from the HF Hub and instantiating it on first use."""
        if selected_model in loaded_models:
            return loaded_models[selected_model]
        local_dir = snapshot_download(
            repo_id="GDAOSU/olbedo",
            allow_patterns=f"{selected_model}/*",
        )
        model_path = os.path.join(local_dir, selected_model)
        pipe = OlbedoIIDPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
        ).to(device)
        if "rgbx" in selected_model:
            pipe.mode = "rgbx"
        loaded_models[selected_model] = pipe
        return pipe

    def callback(photo, inference_step, selected_model, selected_prompt, processing_res):
        """Run intrinsic image decomposition on *photo*.

        Returns a tuple ``(png_path, npy_path, preview_array)`` matching the
        Gradio outputs ``[result_png, result_npy, result_image]``.
        """
        # The text prompt is only meaningful for rgbx-style checkpoints.
        prompt = selected_prompt if "rgbx" in selected_model else None

        pipe = load_model(selected_model)

        # Fixed-seed generator for reproducible sampling.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        img = read_img_from_file(photo)
        if len(img.shape) == 3:
            img = img_hwc2chw(img)
        # HDR inputs are linear; tone-map to sRGB before feeding the model.
        if is_hdr(photo):
            img = img_linear2srgb(img)
        # Drop the alpha channel if present.
        if img.shape[0] == 4:
            img = img[:3, :, :]
        rgb_float = torch.from_numpy(img).float()
        input_image = float2int(rgb_float).unsqueeze(0)

        if "rgbx" in selected_model:
            pipe.prompt = prompt

        pipe_out: OlbedoIIDOutput = pipe(
            input_image,
            denoising_steps=inference_step,
            ensemble_size=1,
            processing_res=processing_res,
            match_input_res=True,  # boolean flag; was passed as the int 1
            batch_size=0,
            show_progress_bar=False,
            resample_method="bilinear",
            generator=generator,
        )
        target_pred = pipe_out["albedo"].array
        # Scalar maps (roughness / metallicness) come back single-channel;
        # tile to 3 channels so they render as a grayscale RGB image.
        if prompt is not None and ("Metallicness" in prompt or "Roughness" in prompt):
            target_pred = np.repeat(target_pred[0:1, :], 3, axis=0)

        # CHW float -> HWC uint8 for display/saving.
        generated_image = target_pred.transpose(1, 2, 0)
        if generated_image.dtype != np.uint8:
            generated_image = np.clip(generated_image, 0, 1)
            generated_image = (generated_image * 255).astype(np.uint8)

        # Fresh temp dir per request: portable (no hard-coded /tmp) and
        # successive runs no longer overwrite each other's download files.
        tmp_dir = tempfile.mkdtemp(prefix="olbedo_")
        npy_path = os.path.join(tmp_dir, "target_pred.npy")
        np.save(npy_path, target_pred)
        png_path = os.path.join(tmp_dir, "target_pred.png")
        Image.fromarray(generated_image).save(png_path)

        return png_path, npy_path, generated_image

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown(
                "## Olbedo: An Albedo and Shading Aerial Dataset for Large-Scale Outdoor Environments"
            )
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.Image(label="Photo", type="filepath")
                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=4,
                    )
                    processing_res = gr.Number(
                        value=1000,
                        label="Processing Resolution (processing_res)",
                        precision=0,
                    )
                gr.Markdown("### Select Model")
                model_selector = gr.Dropdown(
                    label="Checkpoint",
                    choices=available_models,
                    value="rgbx/finetuned",
                )
                gr.Markdown("### Select Prompt (only for rgbx models)")
                prompt_selector = gr.Dropdown(
                    label="Prompts",
                    choices=prompts,
                    value=prompts[0],
                )
            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_image = gr.Image(label="Output Image", interactive=False)
                result_png = gr.File(label="Download Generated Image (.png)")
                result_npy = gr.File(label="Download Target Albedo (.npy)")

        inputs = [photo, inference_step, model_selector, prompt_selector, processing_res]
        outputs = [result_png, result_npy, result_image]
        run_button.click(fn=callback, inputs=inputs, outputs=outputs, queue=True)

    return block


if __name__ == "__main__":
    demo = get_demo()
    demo.queue(max_size=1)
    demo.launch()