Spaces: Running on Zero
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import random | |
| from PIL import Image | |
| import spaces | |
| import os | |
| import gc | |
| from pa_src.pipeline import RFPanoInversionParallelFluxPipeline | |
| from pa_src.attn_processor import PersonalizeAnythingAttnProcessor, set_flux_transformer_attn_processor | |
| from pa_src.utils import * | |
# Prefer CUDA when available; fp16 is only used on GPU, CPU falls back to fp32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# FLUX.1-dev loaded through the project's panorama RF-inversion pipeline wrapper.
pipe = RFPanoInversionParallelFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype,
    low_cpu_mem_usage=True
).to(device)
# DiT360 panorama LoRA applied on top of the base FLUX weights.
pipe.load_lora_weights("Insta360-Research/DiT360-Panorama-Image-Generation")
# Largest signed 32-bit integer (2**31 - 1) — upper bound for RNG seeds.
MAX_SEED = np.iinfo(np.int32).max


def generate_seed():
    """Return a uniformly random seed in [0, MAX_SEED]."""
    return random.randrange(MAX_SEED + 1)
def create_outpainting_mask(image, target_size=(2048, 1024)):
    """Center *image* on a gray canvas of *target_size* and build a keep-mask.

    Returns (canvas, mask): canvas is RGB with the input pasted in the middle;
    mask is mode "L" where 255 marks the preserved input region and 0 marks
    the area to be generated.
    """
    target_w, target_h = target_size
    src_w, src_h = image.size

    # Symmetric placement so outpainting extends evenly on every side.
    left = (target_w - src_w) // 2
    top = (target_h - src_h) // 2

    # Gray fill is only a fallback — generated content replaces it anyway.
    canvas = Image.new("RGB", (target_w, target_h), (128, 128, 128))
    canvas.paste(image, (left, top))

    # Start fully "generate" (0), then stamp the preserved center with 255.
    mask_img = Image.new("L", (target_w, target_h), 0)
    mask_img.paste(255, (left, top, left + src_w, top + src_h))
    return canvas, mask_img
def prepare_mask_for_pipeline(mask_img, latent_w, latent_h):
    """Downsample the pixel mask to the latent grid and flatten it.

    The left/right edges are padded by replicating the first and last
    columns, widening the grid to latent_w + 2 — this matches the
    `img_dims = latent_h * (latent_w + 2)` used elsewhere.
    NOTE(review): replicate padding, not circular wrap-around — confirm
    this is intended for the 360° seam.
    """
    resized = mask_img.resize((latent_w, latent_h))
    grid = torch.from_numpy(np.array(resized) / 255.0).float().to(device)
    first_col = grid[:, 0:1]
    last_col = grid[:, -1:]
    padded = torch.cat((first_col, grid, last_col), dim=-1)
    return padded.view(-1, 1)
def infer(
    prompt,
    input_image,
    seed,
    num_inference_steps,
    guidance_scale=2.8,
    tau=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint *input_image* into a 2048x1024 equirectangular panorama.

    Args:
        prompt: user scene description, appended to a fixed panorama prefix.
        input_image: PIL image to preserve at the center of the panorama.
        seed: integer RNG seed for the diffusion generator.
        num_inference_steps: steps for both inversion and generation.
        guidance_scale: classifier-free guidance strength.
        tau: 0-100 slider; tau/100 controls how strictly the masked region
            follows the input in the attention processor.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The generated panorama (PIL image) for the full prompt branch.

    Raises:
        gr.Error: when no input image was uploaded.
    """
    # NOTE(review): `import spaces` at the top of the file suggests this
    # function should carry @spaces.GPU for ZeroGPU — confirm deployment target.
    if input_image is None:
        raise gr.Error("Please upload an input image for outpainting.")

    with torch.inference_mode():
        torch.cuda.empty_cache()  # no-op on CPU-only builds
        generator = torch.Generator(device=device).manual_seed(int(seed))

        # Output resolution is fixed: model is trained on 1024x2048 (2:1).
        target_height = 1024
        target_width = 2048

        # Downscale the input so the preserved region stays a minority of
        # the canvas and leaves room to outpaint.
        max_input_side = 640
        input_w, input_h = input_image.size
        if max(input_w, input_h) > max_input_side:
            scale = max_input_side / max(input_w, input_h)
            input_image = input_image.resize(
                (int(input_w * scale), int(input_h * scale)),
                Image.LANCZOS
            )

        # Gray canvas with the input pasted centered; mask: 255 = preserve,
        # 0 = generate.
        canvas = Image.new("RGB", (target_width, target_height), (127, 127, 127))
        paste_x = (target_width - input_image.width) // 2
        paste_y = (target_height - input_image.height) // 2
        canvas.paste(input_image, (paste_x, paste_y))

        mask_img = Image.new("L", (target_width, target_height), 0)
        mask_img.paste(
            255,
            (paste_x, paste_y, paste_x + input_image.width, paste_y + input_image.height),
        )

        # Latent grid sizes; the extra *2 accounts for Flux's 2x2 patch packing.
        scale_factor = pipe.vae_scale_factor
        latent_h = target_height // (scale_factor * 2)
        latent_w = target_width // (scale_factor * 2)
        img_dims = latent_h * (latent_w + 2)  # +2 matches the mask's edge padding

        source_prompt = (
            "a high-quality historical or modern photograph, "
            "realistic scene, natural lighting, detailed architecture and landscape"
        )
        full_prompt = f"A seamless 360° equirectangular panorama, photorealistic, high detail, {prompt.strip()}"

        # Invert the canvas into Flux latents so generation stays anchored on it.
        # (The original file carried a dead `if True/else` dummy-latents branch
        # for local testing; removed as unreachable.)
        inverted_latents, image_latents, latent_image_ids = pipe.invert(
            source_prompt=source_prompt,
            image=canvas,
            height=target_height,
            width=target_width,
            num_inference_steps=num_inference_steps,
            gamma=1.2,
        )

        # Latent-space mask + attention processors that keep the masked region
        # faithful to the input up to timestep tau.
        mask = prepare_mask_for_pipeline(mask_img, latent_w, latent_h)
        set_flux_transformer_attn_processor(
            pipe.transformer,
            set_attn_proc_func=lambda name, dh, nh, ap: PersonalizeAnythingAttnProcessor(
                name=name, tau=tau / 100.0, mask=mask, device=device, img_dims=img_dims
            ),
        )

        result_images = pipe(
            [source_prompt, full_prompt],
            inverted_latents=inverted_latents,
            image_latents=image_latents,
            latent_image_ids=latent_image_ids,
            height=target_height,
            width=target_width,
            start_timestep=0.0,
            stop_timestep=0.99,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=1.0,
            generator=generator,
            mask=mask,
            use_timestep=True,
        ).images

    # BUG FIX: the original assigned result_images[1] to a local and never
    # returned it, so the Gradio output image was always empty.
    return result_images[1]
# -------------------- Gradio UI --------------------
# Custom CSS: image panel centered in a top row; examples and settings
# side-by-side in a bottom row; prompt textarea is non-resizable.
css = """
#main-container {
display: flex;
flex-direction: column;
gap: 2rem;
margin-top: 1rem;
}
#top-row {
display: flex;
flex-direction: row;
justify-content: center;
align-items: flex-start;
gap: 2rem;
}
#bottom-row {
display: flex;
flex-direction: row;
gap: 2rem;
}
#image-panel {
flex: 2;
max-width: 1200px;
margin: 0 auto;
}
#input-panel {
flex: 1;
}
#example-panel {
flex: 2;
}
#settings-panel {
flex: 1;
max-width: 280px;
}
#prompt-box textarea {
resize: none !important;
}
"""
# Gradio Blocks UI: result + inputs on top, examples and settings below.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# 🌍 DiT360: High-Fidelity Panoramic Image Generation with Outpainting
Here are our resources:
- 💻 **Code**: [https://github.com/Insta360-Research-Team/DiT360](https://github.com/Insta360-Research-Team/DiT360)
- 🌐 **Web Page**: [https://fenghora.github.io/DiT360-Page/](https://fenghora.github.io/DiT360-Page/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
"""
    )
    gr.Markdown("Official Gradio demo for **[DiT360](https://fenghora.github.io/DiT360-Page/)**, now with outpainting from a single image.")

    with gr.Row(elem_id="top-row"):
        with gr.Column(elem_id="top-panel"):
            result = gr.Image(label="Generated Panorama", show_label=False, type="pil", height=800)
            input_image = gr.Image(type="pil", label="Input Image (for outpainting)", height=300)
            prompt = gr.Textbox(
                elem_id="prompt-box",
                placeholder="Describe your panoramic scene here...",
                show_label=False,
                lines=2,
                container=False,
            )
            run_button = gr.Button("Generate Panorama", variant="primary")

    with gr.Row(elem_id="bottom-row"):
        with gr.Column(elem_id="example-panel"):
            gr.Markdown("### 📚 Examples")
            gr.Examples(
                examples=[
                    "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
                    "A futuristic cityscape under a starry night sky.",
                    "A futuristic city skyline reflects on the calm river at sunset, neon lights glowing against the twilight sky.",
                    "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
                ],
                inputs=[prompt],
            )
        with gr.Column(elem_id="settings-panel"):
            gr.Markdown("### ⚙️ Settings")
            gr.Markdown(
                "For better results, the output image is fixed at **2048×1024** (2:1 aspect ratio). "
            )
            seed_display = gr.Number(value=0, label="Seed", interactive=True)
            random_seed_button = gr.Button("🎲 Random Seed")
            random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
            num_inference_steps = gr.Slider(10, 100, value=15, step=1, label="Inference Steps")
            # BUG FIX: `infer` takes (..., guidance_scale, tau); the original
            # wired tau_slider as the 5th positional input, so tau values (0-100)
            # were consumed as guidance_scale and tau silently kept its default.
            # Expose guidance explicitly and pass both in the right order.
            guidance_slider = gr.Slider(1.0, 10.0, value=2.8, step=0.1, label="Guidance Scale")
            tau_slider = gr.Slider(0, 100, value=30, step=1, label="Tau (0=strictly follow input, 100=free generation)")
            gr.Markdown(
                "💡 *Tip: Upload an image and describe the scene. The model will extend it to a full 360° panorama using outpainting.*"
            )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_image, seed_display, num_inference_steps, guidance_slider, tau_slider],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()