Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import compat_patch | |
| # import spaces #[uncomment to use ZeroGPU] | |
| from scripts.cubemap_vae import CubemapVAE | |
| from scripts.cubemap_unet import CubemapUNet | |
| from diffusers import DiffusionPipeline | |
| from scripts.cubemap_diffusion_pipeline import CubemapDiffusionInpaintPipeline | |
| from scripts.utils import resize_and_crop,convert_to_equirectangular,to_cubemap_dict,cubemap_unfold | |
| from diffusers import AutoencoderKL,UNet2DConditionModel | |
| from contextlib import nullcontext | |
| import torch | |
| from PIL import Image | |
| import base64 | |
| from io import BytesIO | |
| import json | |
| import os | |
| from datetime import datetime | |
| import time | |
| from realesrgan import RealESRGANer | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
# Use the GPU when one is present. Diffusion weights are loaded in fp16 on CUDA and in fp32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "zimhe/SpatialDiffusion"  # Replace with the model you would like to use
# Real-ESRGAN x4plus weights; passed to RealESRGANer as model_path.
upscale_model_id = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"

if torch.cuda.is_available():
    print("CUDA is available")
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# Load the pretrained VAE/UNet from the repo's "vae"/"unet" subfolders and wrap them in the
# cubemap-aware variants that the inpainting pipeline is built with.
pretrained_vae = AutoencoderKL.from_pretrained(
    model_repo_id, subfolder="vae", torch_dtype=torch_dtype
)
pretrained_unet = UNet2DConditionModel.from_pretrained(
    model_repo_id, subfolder="unet", torch_dtype=torch_dtype
)
cubemap_unet = CubemapUNet(pretrained_unet=pretrained_unet)
# VAE wrapper over the six cubemap views (one per face).
cubemap_vae = CubemapVAE(num_views=6, pretrained_vae=pretrained_vae, in_channels=3)
pipe = CubemapDiffusionInpaintPipeline.from_pretrained(
    model_repo_id,
    vae=cubemap_vae,
    unet=cubemap_unet,
    torch_dtype=torch_dtype,
    safety_checker=None,
)
pipe = pipe.to(device)

# RRDBNet backbone matching the RealESRGAN_x4plus checkpoint (23 RRDB blocks, x4 scale).
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path=upscale_model_id,
    model=model,
    tile=512,  # process the image in 512px tiles to bound peak memory
    tile_pad=32,
    pre_pad=0,
    device=device,
    # Half precision only on CUDA: on CPU, fp16 convolutions may be unsupported or very
    # slow depending on the torch version, so fall back to fp32 there.
    half=(device == "cuda"),
)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 512

# Resolve bundled assets relative to this script (not the current working directory),
# so viewer.html and examples.json are found consistently.
current_dir = os.path.dirname(os.path.abspath(__file__))
viewer_html_path = os.path.join(current_dir, "viewer.html")
examples_json_path = os.path.join(current_dir, "examples", "examples.json")
# Image URL the panorama viewer is asked to load when the page first opens.
default_image_url = "../examples/004.png"

# Read the standalone panorama viewer page; it is embedded into the UI as an iframe.
with open(viewer_html_path, "r", encoding="utf-8") as f:
    viewer_html_content = f.read()

# Example presets: a mapping of label -> {img, global, front, back, left, right, top, bottom}.
# Opened as UTF-8 explicitly so prompts decode the same on every platform.
with open(examples_json_path, "r", encoding="utf-8") as f:
    examples_data = json.load(f)
# Flatten the example presets into the row format gr.Examples expects.
# Each JSON entry becomes one row with columns in this fixed order:
# condition image, global prompt, then the six per-face prompts.
# The entry's key doubles as its display label.
_EXAMPLE_COLUMNS = ("img", "global", "front", "back", "left", "right", "top", "bottom")
example_labels = list(examples_data)
examples = [
    [examples_data[label][column] for column in _EXAMPLE_COLUMNS]
    for label in example_labels
]
def process_panorama(image):
    """Encode a panorama image as a base64 JPEG string.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``. ``None`` is allowed.

    Returns:
        The base64-encoded JPEG bytes as a ``str``. Returns ``None`` when ``image`` is
        ``None`` or encoding fails; in the failure case the error is printed.
    """
    if image is None:
        return None
    try:
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        # JPEG cannot store alpha or palette images (Pillow raises for RGBA, P, ...),
        # so normalise anything other than RGB/greyscale to RGB before encoding.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
        buffered = BytesIO()
        pil_image.save(buffered, format="JPEG", quality=95, optimize=True)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        # Best-effort: report and return None ("error while processing image").
        print(f"处理图片时出错: {str(e)}")
        return None
def _center_crop_to_square(img):
    """Return the largest centred square crop of a PIL image.

    When the width/height difference is odd, the spare pixel is trimmed from the
    right/bottom edge, because the left/top offset is rounded down.
    """
    img_w, img_h = img.size
    side = min(img_w, img_h)
    left = (img_w - side) // 2
    top = (img_h - side) // 2
    return img.crop((left, top, left + side, top + side))


def _autocast_context():
    """Return the mixed-precision context used around the diffusion pipeline call.

    MPS gets no autocast, CUDA uses CUDA autocast, and everything else uses CPU autocast.
    ``torch.amp.autocast(device_type="cpu")`` is the non-deprecated spelling of
    ``torch.cpu.amp.autocast()``.
    """
    if torch.backends.mps.is_available():
        return nullcontext()
    if torch.cuda.is_available():
        return torch.amp.autocast(device_type="cuda")
    return torch.amp.autocast(device_type="cpu")


def _swap_red_blue(arr):
    """Return a copy of an HxWxC array with channels 0 and 2 swapped (RGB <-> BGR).

    2-D arrays and arrays with fewer than three channels are returned unchanged.
    A fourth (alpha) channel stays in place.
    """
    if arr.ndim != 3 or arr.shape[2] < 3:
        return arr
    swapped = arr.copy()
    swapped[..., [0, 2]] = arr[..., [2, 0]]
    return swapped


def _upscale_panorama(pano_img):
    """Upscale the panorama with the module-level Real-ESRGAN ``upsampler`` (outscale=2).

    Best-effort: on any error the message is printed and ``pano_img`` is returned as is.
    """
    try:
        rgb = np.array(pano_img).astype(np.uint8)
        # RealESRGANer.enhance follows the OpenCV convention: it takes BGR(A) input and
        # returns the same channel order. Convert RGB -> BGR going in and back coming out.
        output, _ = upsampler.enhance(img=_swap_red_blue(rgb), outscale=2)
        return Image.fromarray(_swap_red_blue(output))
    except Exception as e:
        print(f"Upscaling error: {str(e)}")
        return pano_img


def infer(
    prompt,
    front_prompt,
    back_prompt,
    left_prompt,
    right_prompt,
    top_prompt,
    bottom_prompt,
    cond_img: Image.Image,  # Declare cond_img as a PIL Image
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    upscale=False,
    progress=gr.Progress(track_tqdm=True),  # lets Gradio mirror tqdm progress bars
):
    """Generate six cubemap faces and an equirectangular panorama from a condition image.

    Args:
        prompt: Global text prompt applied to the whole panorama.
        front_prompt, back_prompt, left_prompt, right_prompt, top_prompt, bottom_prompt:
            Per-face prompts, passed to the pipeline as ``per_face_prompts``.
        cond_img: Condition image; it is centre-cropped to a square before use.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed`` is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Pipeline output size (cast to int).
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps (cast to int).
        upscale: If true, upscale the panorama with Real-ESRGAN (best-effort).
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        Tuple ``(front, back, left, right, top, bottom, panorama, seed)``. The faces are the
        ``F/B/L/R/U/D`` entries of the cubemap dict; the panorama is 2048x1024 before any
        upscaling; ``seed`` is the seed actually used.

    Raises:
        gr.Error: If no condition image was supplied.
    """
    if cond_img is None:
        raise gr.Error("Please upload a condition image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch.Generator.manual_seed requires an int.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    # The pipeline works on square faces, so crop the input to 1:1 first.
    cond_img = _center_crop_to_square(cond_img)

    face_prompt_dict = {
        "front": front_prompt,
        "back": back_prompt,
        "left": left_prompt,
        "right": right_prompt,
        "top": top_prompt,
        "bottom": bottom_prompt,
    }
    with _autocast_context():
        images = pipe(
            global_prompt=prompt,
            per_face_prompts=face_prompt_dict,
            image=cond_img,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            output_type="np",
            generator=generator,
        ).images

    cubemaps = [resize_and_crop(image=image, padding=16) for image in images]
    cubemap_dict = to_cubemap_dict(cubemaps)
    pano_img = convert_to_equirectangular(cubemap_dict, width=2048, height=1024)
    if device == "cuda":
        torch.cuda.empty_cache()

    if upscale:
        pano_img = _upscale_panorama(pano_img)
        if device == "cuda":
            torch.cuda.empty_cache()

    return (
        cubemap_dict["F"],
        cubemap_dict["B"],
        cubemap_dict["L"],
        cubemap_dict["R"],
        cubemap_dict["U"],
        cubemap_dict["D"],
        pano_img,
        seed,
    )
# Page-level CSS passed to gr.Blocks(css=...).
#   #col-container   centres the main column and caps its width at 980px.
#   #input_container 640px centred container. NOTE(review): no component in this file
#                    sets elem_id="input_container"; possibly dead, confirm before removing.
#   #squre_image     forces a 1:1 box. The id spelling must stay in sync with the
#                    elem_id="squre_image" components. NOTE(review): several components
#                    share this id; a class selector would be valid HTML.
#   #pano_image      forces a 2:1 box for the equirectangular output.
css = """
#col-container {
margin: 0 auto;
max-width: 980px;
}
#input_container {
margin: 0 auto;
max-width: 640px;
}
#squre_image {
width: 100%;
height: auto;
aspect-ratio: 1 / 1;
}
#pano_image {
width: 100%;
height: auto;
aspect-ratio: 2 / 1;
}
"""
# Splice the viewer page and default image URL into the JavaScript handlers below as
# JSON string literals. Pasting the raw HTML inside a JS template literal (`...`) breaks
# whenever viewer.html contains a backtick, a "${" sequence, or a backslash escape
# (e.g. a regex in its <script>). json.dumps always yields a valid JS string literal.
viewer_html_js = json.dumps(viewer_html_content)
default_image_url_js = json.dumps(default_image_url)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Spatial Diffusion")
        # The JS handlers below find this component via its "panorama-output" class and
        # replace its contents with the viewer iframe.
        pano_html = gr.HTML(label="panorama viewer", elem_classes=["panorama-output"], container=True)

        gr.Markdown("## Input Parameters")
        with gr.Row():
            with gr.Column(scale=1):
                # Image upload shown in a 1:1 box; infer() centre-crops it to a square.
                cond_img = gr.Image(
                    label="Condition Image",
                    type="pil",
                    sources=["upload", "webcam", "clipboard"],
                    elem_id="squre_image",
                    container=True,
                )
            with gr.Column(scale=1):
                global_prompt = gr.Text(
                    label="Global Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Enter global prompt",
                    container=True,
                )
                # One single-line prompt box per cubemap face, keyed by face name.
                face_prompts = {}
                for face in ["front", "back", "left", "right", "top", "bottom"]:
                    face_prompts[face] = gr.Text(
                        label=f"{face.capitalize()} Prompt",
                        show_label=True,
                        max_lines=1,
                        placeholder=f"Enter {face.lower()} prompt",
                        container=False,
                    )
                run_button = gr.Button("Run", variant="primary")

        # Example rows: [image, global, front, back, left, right, top, bottom].
        gr.Examples(
            examples=examples,
            example_labels=example_labels,
            inputs=[
                cond_img,
                global_prompt,
                face_prompts["front"],
                face_prompts["back"],
                face_prompts["left"],
                face_prompts["right"],
                face_prompts["top"],
                face_prompts["bottom"],
            ],
        )

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                value=(
                    "grids, lines, texts, labels, blury, bad quality, bad image, wrong scale, "
                    "clear seams, distorted objects, disconnected edges, replicated items,\n"
                    "blurry, overexposed, chaotic, low resolution, 3D render, overly dramatic, unrealistic"
                ),
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            upscale = gr.Checkbox(label="Upscale", value=False)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=15.0,
                    step=0.1,
                    value=9.0,  # Replace with defaults that work for your model
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,  # Replace with defaults that work for your model
                )

        gr.Markdown("## Result")
        with gr.Row():
            left_face = gr.Image(label="Left", show_label=True, elem_id="squre_image", format="png")
            front_face = gr.Image(label="Front", show_label=True, elem_id="squre_image", format="png")
            right_face = gr.Image(label="Right", show_label=True, elem_id="squre_image", format="png")
        with gr.Row():
            back_face = gr.Image(label="Back", show_label=True, elem_id="squre_image", format="png")
            top_face = gr.Image(label="Top", show_label=True, elem_id="squre_image", format="png")
            bottom_face = gr.Image(label="Bottom", show_label=True, elem_id="squre_image", format="png")
        pano = gr.Image(
            label="Equirectangular Image",
            show_label=True,
            interactive=False,
            type="pil",
            elem_id="pano_image",
            format="png",
        )
        # NOTE(review): no click handler is attached to this button here.
        save_button = gr.Button("Save All", variant="primary")

    # When the panorama output changes, (re)build the viewer iframe from viewer.html in the
    # browser, fetch the new image, and post it to the iframe as a data URL
    # ({type: 'loadPanorama', image}).
    # NOTE(review): the original comment on `fn` said no Python function is needed. The JS
    # returns nothing, and depending on the Gradio version its return value may replace the
    # input passed to process_panorama. Confirm whether `fn` is needed at all.
    pano.change(
        fn=process_panorama,
        inputs=[pano],
        outputs=[pano_html],
        js=f"""
async (img_obj) => {{
    if (!img_obj || !img_obj.url) return;
    // 创建 iframe 容器
    const container = document.querySelector('.panorama-output');
    if (container) {{
        // 将 viewer.html 内容转换为 data URL
        const viewerHtml = {viewer_html_js};
        const viewerBlob = new Blob([viewerHtml], {{ type: 'text/html' }});
        const viewerUrl = URL.createObjectURL(viewerBlob);
        container.innerHTML = `<iframe id="panorama-viewer" style="width: 100%; height: 480px; border: none;" src="${{viewerUrl}}"></iframe>`;
        // 等待 iframe 加载完成
        const iframe = document.getElementById('panorama-viewer');
        iframe.onload = async () => {{
            try {{
                // 从 URL 获取图片数据
                const response = await fetch(img_obj.url);
                const blob = await response.blob();
                const reader = new FileReader();
                reader.onloadend = () => {{
                    // 向 iframe 发送图片数据
                    iframe.contentWindow.postMessage({{
                        type: 'loadPanorama',
                        image: reader.result
                    }}, '*');
                }};
                reader.readAsDataURL(blob);
            }} catch (error) {{
                console.error('Error processing image:', error);
                console.log('Image object:', img_obj);
            }}
        }};
    }}
}}
""",
    )

    # Output order must match infer()'s return tuple.
    run_button.click(
        fn=infer,
        inputs=[
            global_prompt,
            face_prompts["front"],
            face_prompts["back"],
            face_prompts["left"],
            face_prompts["right"],
            face_prompts["top"],
            face_prompts["bottom"],
            cond_img,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            upscale,
        ],
        outputs=[
            front_face,
            back_face,
            left_face,
            right_face,
            top_face,
            bottom_face,
            pano,
            seed,  # the seed actually used (may have been randomized)
        ],
    )

    # On page load, build the viewer iframe and ask it to show the default panorama.
    # NOTE(review): default_image_url is relative and is interpreted inside a blob-URL
    # iframe; confirm the viewer actually resolves it.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js=f"""
() => {{
    // 创建 iframe 容器
    const container = document.querySelector('.panorama-output');
    if (container) {{
        // 将 viewer.html 内容转换为 data URL
        const viewerHtml = {viewer_html_js};
        const viewerBlob = new Blob([viewerHtml], {{ type: 'text/html' }});
        const viewerUrl = URL.createObjectURL(viewerBlob);
        container.innerHTML = `<iframe id="panorama-viewer" style="width: 100%; height: 480px; border: none;" src="${{viewerUrl}}"></iframe>`;
        // 等待 iframe 加载完成
        const iframe = document.getElementById('panorama-viewer');
        iframe.onload = () => {{
            // 使用本地默认全景图
            const defaultImage = {default_image_url_js};
            // 向 iframe 发送图片数据
            iframe.contentWindow.postMessage({{
                type: 'loadPanorama',
                image: defaultImage
            }}, '*');
        }};
    }}
}}
""",
    )
def _launch_app():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


# Only serve when executed as a script, not when the module is imported.
if __name__ == "__main__":
    _launch_app()