# Hugging Face Space (status: Running on Zero)
| import spaces | |
| import os | |
| from typing import cast | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from diffusers import DDIMScheduler | |
| from load_image import load_exr_image, load_ldr_image | |
| from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline | |
# Opt in to OpenCV's OpenEXR reader/writer via its environment flag.
# NOTE(review): this is set after `load_image` has been imported above; if that
# module (or OpenCV) reads the flag at import time this would be too late —
# confirm that .exr loading works, or move this above the imports.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Directory containing this script; the model cache is kept next to it.
current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the X -> RGB pipeline in half precision and move it to the GPU once.
_pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")

# `cast` only narrows the static type for type checkers; it is a no-op at runtime.
pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)

# Swap in a DDIM scheduler built from the pipeline's existing scheduler config,
# overriding the zero-SNR beta rescaling and the timestep spacing.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
# NOTE(review): `spaces` is imported by this file but no `@spaces.GPU` decorator
# is visible on this function — confirm whether the deployment requires one.
def generate(
    albedo,
    normal,
    roughness,
    metallic,
    irradiance,
    prompt: str,
    seed: int,
    inference_step: int,
    num_samples: int,
    guidance_scale: float,
    image_guidance_scale: float,
) -> list[tuple[np.ndarray, str]]:
    """Generate RGB images from the uploaded intrinsic channels.

    Args:
        albedo: Uploaded albedo file from the UI, or ``None`` if not provided.
        normal: Uploaded normal file, or ``None``.
        roughness: Uploaded roughness file, or ``None``.
        metallic: Uploaded metallic file, or ``None``.
        irradiance: Uploaded irradiance file, or ``None``.
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the CUDA random generator.
        inference_step: Passed to the pipeline as ``num_inference_steps``.
        num_samples: Number of pipeline invocations (one image is kept per call).
        guidance_scale: Forwarded to the pipeline unchanged.
        image_guidance_scale: Forwarded to the pipeline unchanged.

    Returns:
        A list of ``(image, caption)`` pairs for the gallery: the generated
        samples first, followed by one entry per intrinsic channel. A channel
        that was not supplied (or had an unsupported extension) is shown as an
        all-zero placeholder carrying the channel's caption.
    """
    # Defensive coercion in case the UI delivers slider values as floats:
    # `range()` needs an int, and the seed should be an int as well.
    seed = int(seed)
    inference_step = int(inference_step)
    num_samples = int(num_samples)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    def process_image(file, **kwargs):
        """Load one uploaded channel onto the GPU, or return None if unusable."""
        if file is None:
            return None
        # Accept both upload objects exposing `.name` and plain path strings.
        path = getattr(file, "name", file)
        # Match extensions case-insensitively so e.g. "IMG.PNG" is not dropped.
        lowered = str(path).lower()
        if lowered.endswith(".exr"):
            return load_exr_image(path, **kwargs).to("cuda")
        if lowered.endswith((".png", ".jpg", ".jpeg")):
            return load_ldr_image(path, **kwargs).to("cuda")
        # Unsupported extension: treat the channel as missing (best effort).
        return None

    albedo_image = process_image(albedo, clamp=True)
    normal_image = process_image(normal, normalize=True)
    roughness_image = process_image(roughness, clamp=True)
    metallic_image = process_image(metallic, clamp=True)
    # `tonemaping` (sic) is the keyword spelling used by the original call.
    irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)

    channels = [
        ("Albedo", albedo_image),
        ("Normal", normal_image),
        ("Roughness", roughness_image),
        ("Metallic", metallic_image),
        ("Irradiance", irradiance_image),
    ]

    # Output size follows the first available channel (dims 1 and 2 of its
    # shape — presumably a channel-first (C, H, W) tensor, consistent with the
    # transpose below); fall back to 768x768 when nothing was uploaded.
    first_image = next((img for _, img in channels if img is not None), None)
    if first_image is None:
        height, width = 768, 768
    else:
        height, width = first_image.shape[1], first_image.shape[2]

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    return_list = []
    # The same generator object is reused across iterations, so its state
    # carries over from one sample to the next.
    for i in range(num_samples):
        generated_image = pipe(
            prompt=prompt,
            albedo=albedo_image,
            normal=normal_image,
            roughness=roughness_image,
            metallic=metallic_image,
            irradiance=irradiance_image,
            num_inference_steps=inference_step,
            height=height,
            width=width,
            generator=generator,
            required_aovs=required_aovs,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            guidance_rescale=0.7,
            output_type="np",
        ).images[0]  # type: ignore
        return_list.append((generated_image, f"Generated Image {i}"))

    def post_process_image(img, label="Image"):
        """Return a captioned gallery entry for a channel (zeros if missing)."""
        if img is None:
            # Always return an (image, caption) pair so every gallery item has
            # the same shape, including placeholders for missing channels.
            return (np.zeros((height, width, 3)), label)
        # Move to host memory and reorder axes (1, 2, 0) for display.
        return (img.cpu().numpy().transpose(1, 2, 0), label)

    # Append the input channels after the generated samples.
    return_list.extend(post_process_image(img, label=label) for label, img in channels)
    return return_list
# Extensions offered by the upload widgets. ".jpeg" is included so the picker
# agrees with the extensions that `generate` accepts when loading images.
channel_file_types = [".exr", ".png", ".jpg", ".jpeg"]

# UI layout: inputs and parameters in the left column, output gallery on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given intrinsic channels")
            albedo = gr.File(label="Albedo", file_types=channel_file_types)
            normal = gr.File(label="Normal", file_types=channel_file_types)
            roughness = gr.File(label="Roughness", file_types=channel_file_types)
            metallic = gr.File(label="Metallic", file_types=channel_file_types)
            irradiance = gr.File(label="Irradiance", file_types=channel_file_types)
            gr.Markdown("### Parameters")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                image_guidance_scale = gr.Slider(
                    label="Image Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.5,
                )
        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    # The order of `inputs` must match the positional parameters of `generate`.
    run_button.click(
        fn=generate,
        inputs=[
            albedo,
            normal,
            roughness,
            metallic,
            irradiance,
            prompt,
            seed,
            inference_step,
            num_samples,
            guidance_scale,
            image_guidance_scale,
        ],
        outputs=result_gallery,
        queue=True,
    )

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)