import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
import spaces
import os
import gc
from pa_src.pipeline import RFPanoInversionParallelFluxPipeline
from pa_src.attn_processor import PersonalizeAnythingAttnProcessor, set_flux_transformer_attn_processor
from pa_src.utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): upstream FLUX.1-dev examples load the model in bfloat16; float16 is
# kept here as in the original — confirm it does not yield NaN/black outputs on the target GPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Loaded once at import time; the DiT360 LoRA adapts FLUX.1-dev to panorama generation.
pipe = RFPanoInversionParallelFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=dtype, low_cpu_mem_usage=True
).to(device)
pipe.load_lora_weights("Insta360-Research/DiT360-Panorama-Image-Generation")

MAX_SEED = np.iinfo(np.int32).max


def generate_seed():
    """Return a random integer seed in ``[0, MAX_SEED]``."""
    return random.randint(0, MAX_SEED)


def create_outpainting_mask(image, target_size=(2048, 1024), fill_color=(128, 128, 128)):
    """Center *image* on a blank canvas and build the matching preserve mask.

    Args:
        image: PIL image to place at the center of the canvas. It is assumed to
            fit inside ``target_size`` (``infer`` downscales it first).
        target_size: ``(width, height)`` of the canvas and mask. The default is
            the full target resolution (the model was trained on 1024x2048).
        fill_color: RGB colour of the region that will be generated.

    Returns:
        ``(canvas, mask_img)``: an RGB canvas with *image* pasted at the center,
        and an ``"L"`` mask that is 255 (preserve) over the pasted region and
        0 (generate) everywhere else.
    """
    w, h = image.size
    target_w, target_h = target_size

    canvas = Image.new("RGB", (target_w, target_h), fill_color)
    # Centered paste so the outpainting extends symmetrically on all sides.
    paste_x = (target_w - w) // 2
    paste_y = (target_h - h) // 2
    canvas.paste(image, (paste_x, paste_y))

    # MASK: 255 = preserve (white), 0 = generate (black).
    mask_img = Image.new("L", (target_w, target_h), 0)
    mask_img.paste(255, (paste_x, paste_y, paste_x + w, paste_y + h))
    return canvas, mask_img


def prepare_mask_for_pipeline(mask_img, latent_w, latent_h):
    """Resize the pixel mask to the latent grid and flatten it into a column.

    The mask is resized to ``(latent_w, latent_h)`` and scaled to ``[0, 1]``.
    It is then padded with one repeated edge column on each side, so it has
    ``latent_h * (latent_w + 2)`` entries, matching ``img_dims`` in ``infer``.
    NOTE(review): the +2 columns presumably mirror the pipeline's panorama
    wrap-around padding — confirm against ``pa_src.pipeline``.

    Returns:
        A float tensor on ``device`` with shape ``(latent_h * (latent_w + 2), 1)``.
    """
    mask = np.array(mask_img.resize((latent_w, latent_h))) / 255.0
    mask = torch.from_numpy(mask).float().to(device)
    mask = torch.cat([mask[:, 0:1], mask, mask[:, -1:]], dim=-1).view(-1, 1)
    return mask


@spaces.GPU
def infer(
    prompt,
    input_image,
    seed,
    num_inference_steps,
    guidance_scale=2.8,
    tau=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint *input_image* into a 2048x1024 equirectangular panorama.

    Args:
        prompt: User scene description, appended to a fixed panorama prompt.
        input_image: PIL image to extend; required.
        seed: Seed for the torch generator.
        num_inference_steps: Step count used for both inversion and generation.
        guidance_scale: Guidance scale forwarded to the pipeline call.
        tau: Value in ``[0, 100]``; divided by 100 and given to
            ``PersonalizeAnythingAttnProcessor``.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        The second image of the pipeline output. The pipeline is called with
        ``[source_prompt, full_prompt]``, so this is presumably the image for
        ``full_prompt``.

    Raises:
        gr.Error: If no input image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image for outpainting.")

    # Gradio sliders may deliver floats; the pipelines expect an integer step count.
    num_inference_steps = int(num_inference_steps)
    target_height = 1024
    target_width = 2048

    with torch.inference_mode():
        torch.cuda.empty_cache()
        generator = torch.Generator(device=device).manual_seed(int(seed))

        # ── Downscale input so it leaves room to outpaint ────────────────
        max_input_side = 640
        input_w, input_h = input_image.size
        if max(input_w, input_h) > max_input_side:
            scale = max_input_side / max(input_w, input_h)
            # max(1, ...) avoids a zero-sized side for extreme aspect ratios.
            input_image = input_image.resize(
                (max(1, int(input_w * scale)), max(1, int(input_h * scale))),
                Image.LANCZOS,
            )

        # ── Canvas + preserve mask ───────────────────────────────────────
        canvas, mask_img = create_outpainting_mask(
            input_image,
            target_size=(target_width, target_height),
            fill_color=(127, 127, 127),
        )

        # ── Latent sizes (VAE downscale x 2 for the 2x2 patch packing) ───
        scale_factor = pipe.vae_scale_factor
        latent_h = target_height // (scale_factor * 2)
        latent_w = target_width // (scale_factor * 2)
        img_dims = latent_h * (latent_w + 2)

        # ── Source & full prompt ─────────────────────────────────────────
        source_prompt = (
            "a high-quality historical or modern photograph, "
            "realistic scene, natural lighting, detailed architecture and landscape"
        )
        full_prompt = "A seamless 360° equirectangular panorama, photorealistic, high detail"
        user_prompt = (prompt or "").strip()
        if user_prompt:
            full_prompt = f"{full_prompt}, {user_prompt}"

        # ── Inversion of the canvas under the source prompt ──────────────
        inverted_latents, image_latents, latent_image_ids = pipe.invert(
            source_prompt=source_prompt,
            image=canvas,
            height=target_height,
            width=target_width,
            num_inference_steps=num_inference_steps,
            gamma=1.2,
        )

        # ── Mask prep & attention processor ──────────────────────────────
        mask = prepare_mask_for_pipeline(mask_img, latent_w, latent_h)
        set_flux_transformer_attn_processor(
            pipe.transformer,
            set_attn_proc_func=lambda name, dh, nh, ap: PersonalizeAnythingAttnProcessor(
                name=name, tau=tau / 100.0, mask=mask, device=device, img_dims=img_dims
            ),
        )

        # ── Generation ───────────────────────────────────────────────────
        result_images = pipe(
            [source_prompt, full_prompt],
            inverted_latents=inverted_latents,
            image_latents=image_latents,
            latent_image_ids=latent_image_ids,
            height=target_height,
            width=target_width,
            start_timestep=0.0,
            stop_timestep=0.99,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=1.0,
            generator=generator,
            mask=mask,
            use_timestep=True,
        ).images

    # Return the panorama so Gradio can display it in the output image component.
    return result_images[1]


# -------------------- Gradio interface --------------------
css = """
#main-container { display: flex; flex-direction: column; gap: 2rem; margin-top: 1rem; }
#top-row { display: flex; flex-direction: row; justify-content: center; align-items: flex-start; gap: 2rem; }
#bottom-row { display: flex; flex-direction: row; gap: 2rem; }
#image-panel { flex: 2; max-width: 1200px; margin: 0 auto; }
#input-panel { flex: 1; }
#example-panel { flex: 2; }
#settings-panel { flex: 1; max-width: 280px; }
#prompt-box textarea { resize: none !important; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# 🌀 DiT360: High-Fidelity Panoramic Image Generation with Outpainting

Here are our resources:

- 💻 **Code**: [https://github.com/Insta360-Research-Team/DiT360](https://github.com/Insta360-Research-Team/DiT360)
- 🌐 **Web Page**: [https://fenghora.github.io/DiT360-Page/](https://fenghora.github.io/DiT360-Page/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
"""
    )
    gr.Markdown(
        "Official Gradio demo for **[DiT360](https://fenghora.github.io/DiT360-Page/)**, "
        "now with outpainting from a single image."
    )

    with gr.Row(elem_id="top-row"):
        with gr.Column(elem_id="top-panel"):
            result = gr.Image(label="Generated Panorama", show_label=False, type="pil", height=800)
            input_image = gr.Image(type="pil", label="Input Image (for outpainting)", height=300)
            prompt = gr.Textbox(
                elem_id="prompt-box",
                placeholder="Describe your panoramic scene here...",
                show_label=False,
                lines=2,
                container=False,
            )
            run_button = gr.Button("Generate Panorama", variant="primary")

    with gr.Row(elem_id="bottom-row"):
        with gr.Column(elem_id="example-panel"):
            gr.Markdown("### 📚 Examples")
            gr.Examples(
                examples=[
                    "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
                    "A futuristic cityscape under a starry night sky.",
                    "A futuristic city skyline reflects on the calm river at sunset, neon lights glowing against the twilight sky.",
                    "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
                ],
                inputs=[prompt],
            )

        with gr.Column(elem_id="settings-panel"):
            gr.Markdown("### ⚙️ Settings")
            gr.Markdown(
                "For better results, the output image is fixed at **2048×1024** (2:1 aspect ratio). "
            )
            seed_display = gr.Number(value=0, label="Seed", interactive=True)
            random_seed_button = gr.Button("🎲 Random Seed")
            random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
            num_inference_steps = gr.Slider(10, 100, value=15, step=1, label="Inference Steps")
            # Defaults to infer()'s own default so behavior matches when untouched.
            guidance_slider = gr.Slider(1.0, 10.0, value=2.8, step=0.1, label="Guidance Scale")
            tau_slider = gr.Slider(
                0, 100, value=30, step=1, label="Tau (0=strictly follow input, 100=free generation)"
            )
            gr.Markdown(
                "💡 *Tip: Upload an image and describe the scene. The model will extend it to a full 360° panorama using outpainting.*"
            )

    # Gradio maps `inputs` to infer()'s parameters by position. The guidance
    # control must therefore precede tau, or tau would land in guidance_scale.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_image, seed_display, num_inference_steps, guidance_slider, tau_slider],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()