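"""Gradio demo for Olbedo: An Albedo and Shading Aerial Dataset for
Large-Scale Outdoor Environments.

Downloads OlbedoIIDPipeline checkpoints from the GDAOSU/olbedo Hugging Face
repository and predicts intrinsic maps (albedo, camera-space normals,
roughness, metallicness, or diffuse irradiance) for a single input photo.
"""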
import logging
import os
import sys

# OPENCV_IO_ENABLE_OPENEXR must be set before OpenCV is first imported,
# otherwise EXR/HDR inputs cannot be decoded.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Make the repository root importable so `src` and `olbedo` resolve.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from olbedo import OlbedoIIDOutput, OlbedoIIDPipeline
from olbedo.util.image_util import float2int
from src.util.image_util import read_img_from_file, img_hwc2chw, img_linear2srgb, is_hdr
from src.util.seeding import seed_all
seed = 1234
seed_all(seed)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    logging.warning("CUDA is not available. Running on CPU will be slow.")
available_models = [
    "marigold_appearance/finetuned",
    "marigold_appearance/pretrained",
    "marigold_lighting/finetuned",
    "marigold_lighting/pretrained",
    "rgbx/finetuned",
    "rgbx/pretrained",
]
# Pipelines are downloaded lazily and cached here, keyed by checkpoint name.
loaded_models = {}
# Prompts are only consumed by the rgbx checkpoints.
prompts = [
    "Albedo (diffuse basecolor)",
    "Camera-space Normal",
    "Roughness",
    "Metallicness",
    "Irradiance (diffuse lighting)",
]
def get_demo():
    def load_model(selected_model):
        """Download (if necessary) and cache the pipeline for a checkpoint."""
        if selected_model in loaded_models:
            return loaded_models[selected_model]
        # snapshot_download fetches only the selected checkpoint's files and
        # reuses the local Hugging Face cache on subsequent calls.
        local_dir = snapshot_download(
            repo_id="GDAOSU/olbedo",
            allow_patterns=f"{selected_model}/*",
        )
        model_path = os.path.join(local_dir, selected_model)
        pipe = OlbedoIIDPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
        ).to(device)
        if "rgbx" in selected_model:
            pipe.mode = "rgbx"
        loaded_models[selected_model] = pipe
        return pipe
    def callback(
        photo,
        inference_step,
        selected_model,
        selected_prompt,
        processing_res,
    ):
        # Only the rgbx checkpoints are prompt-conditioned; for the marigold
        # checkpoints the prompt selector is ignored.
        if "rgbx" in selected_model:
            prompt = selected_prompt
        else:
            prompt = None

        pipe = load_model(selected_model)
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        # Load the image as CHW float, tone-map HDR inputs to sRGB, and
        # drop any alpha channel before quantizing for the pipeline.
        img = read_img_from_file(photo)
        if len(img.shape) == 3:
            img = img_hwc2chw(img)
        if is_hdr(photo):
            img = img_linear2srgb(img)
        if img.shape[0] == 4:
            img = img[:3, :, :]
        rgb_float = torch.from_numpy(img).float()
        input_image = float2int(rgb_float).unsqueeze(0)
if "rgbx" in selected_model:
pipe.prompt = prompt
pipe_out: OlbedoIIDOutput = pipe(
input_image,
denoising_steps=inference_step,
ensemble_size=1,
processing_res=processing_res,
match_input_res=1,
batch_size=0,
show_progress_bar=False,
resample_method="bilinear",
generator=generator,
)
        target_pred = pipe_out["albedo"].array
        # Roughness and metallicness are single-channel; tile to three
        # channels so they display as a grayscale image.
        if prompt is not None and ("Metallicness" in prompt or "Roughness" in prompt):
            target_pred = np.repeat(target_pred[0:1, :], 3, axis=0)

        generated_image = target_pred.transpose(1, 2, 0)
        if generated_image.dtype != np.uint8:
            generated_image = np.clip(generated_image, 0, 1)
            generated_image = (generated_image * 255).astype(np.uint8)

        # Persist both the raw prediction (.npy) and a viewable PNG so the
        # UI can offer them as downloads.
        TMP_DIR = "/tmp"
        os.makedirs(TMP_DIR, exist_ok=True)
        npy_path = os.path.join(TMP_DIR, "target_pred.npy")
        np.save(npy_path, target_pred)
        png_path = os.path.join(TMP_DIR, "target_pred.png")
        Image.fromarray(generated_image).save(png_path)
        return png_path, npy_path, generated_image
    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Olbedo: An Albedo and Shading Aerial Dataset for Large-Scale Outdoor Environments")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.Image(label="Photo", type="filepath")
                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=4,
                    )
                    processing_res = gr.Number(
                        value=1000,
                        label="Processing Resolution (processing_res)",
                        precision=0,
                    )
                gr.Markdown("### Select Model")
                model_selector = gr.Dropdown(
                    label="Checkpoint",
                    choices=available_models,
                    value="rgbx/finetuned",
                )
                gr.Markdown("### Select Prompt (only for rgbx models)")
                prompt_selector = gr.Dropdown(
                    label="Prompts",
                    choices=prompts,
                    value=prompts[0],
                )
            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_image = gr.Image(label="Output Image", interactive=False)
                result_png = gr.File(label="Download Generated Image (.png)")
                result_npy = gr.File(label="Download Target Albedo (.npy)")
        inputs = [
            photo,
            inference_step,
            model_selector,
            prompt_selector,
            processing_res,
        ]
        outputs = [result_png, result_npy, result_image]
        run_button.click(fn=callback, inputs=inputs, outputs=outputs, queue=True)
    return block
if __name__ == "__main__":
    demo = get_demo()
    demo.queue(max_size=1)
    demo.launch()
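# --- Usage sketch (not executed) ---------------------------------------------
# A minimal example of driving this demo programmatically with the official
# gradio_client package. It assumes the app is running locally on Gradio's
# default port and that the click handler is exposed under its function name
# ("/callback"); the input file path is hypothetical.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   png_path, npy_path, image = client.predict(
#       handle_file("aerial_photo.jpg"),    # photo (hypothetical example file)
#       4,                                  # inference_step
#       "rgbx/finetuned",                   # selected_model
#       "Albedo (diffuse basecolor)",       # selected_prompt
#       1000,                               # processing_res
#       api_name="/callback",
#   )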