# DiT360_edit / app.py
# Gradio demo app (Hugging Face Space snapshot, commit a8dbf58)
import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
import spaces
import os
import gc
from pa_src.pipeline import RFPanoInversionParallelFluxPipeline
from pa_src.attn_processor import PersonalizeAnythingAttnProcessor, set_flux_transformer_attn_processor
from pa_src.utils import *
# ── Global model setup (runs once at import time) ───────────────────────────
# Prefer GPU; on CPU fall back to float32 (float16 compute is poorly
# supported on CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Base model: FLUX.1-dev wrapped in the project's panorama-inversion pipeline.
pipe = RFPanoInversionParallelFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=dtype,
low_cpu_mem_usage=True
).to(device)
# DiT360 panorama LoRA weights layered on top of the base model.
pipe.load_lora_weights("Insta360-Research/DiT360-Panorama-Image-Generation")
# Upper bound for UI seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max


def generate_seed():
    """Return a fresh random seed in [0, MAX_SEED] for the seed number box."""
    # randrange(n + 1) draws from the same stream as randint(0, n).
    return random.randrange(MAX_SEED + 1)
def create_outpainting_mask(image, target_size=(2048, 1024)):
    """Center *image* on a gray canvas and build the matching keep-mask.

    The default target is 2048x1024 (the model works at that resolution).
    Returns a tuple (canvas, mask):
      * canvas — RGB image of size ``target_size`` with *image* pasted in the
        middle over a neutral gray (128, 128, 128) background;
      * mask — mode-"L" image where 255 (white) covers the pasted input
        (preserve) and 0 (black) covers the area to generate.
    """
    target_w, target_h = target_size
    src_w, src_h = image.size

    # Offsets that center the input on the target canvas.
    left = (target_w - src_w) // 2
    top = (target_h - src_h) // 2

    canvas = Image.new("RGB", (target_w, target_h), (128, 128, 128))
    canvas.paste(image, (left, top))

    # Start all-black (generate everywhere), then whiten the input rectangle.
    mask = Image.new("L", (target_w, target_h), 0)
    mask.paste(255, (left, top, left + src_w, top + src_h))
    return canvas, mask
def prepare_mask_for_pipeline(mask_img, latent_w, latent_h):
    """Downsample the PIL keep-mask to latent resolution and flatten it.

    The mask is resized to (latent_w, latent_h) and scaled to [0, 1], then the
    first and last columns are replicated onto the left/right edges — the
    pipeline operates on a horizontally padded width of latent_w + 2.
    Returns a float tensor of shape (latent_h * (latent_w + 2), 1) on the
    module-level ``device``.
    """
    resized = np.array(mask_img.resize((latent_w, latent_h))) / 255.0
    mask = torch.from_numpy(resized).float().to(device)
    # Edge-pad one column on each side, then flatten to a column vector.
    padded = torch.cat([mask[:, 0:1], mask, mask[:, -1:]], dim=-1)
    return padded.view(-1, 1)
@spaces.GPU
def infer(
    prompt,
    input_image,
    seed,
    num_inference_steps,
    tau=50,
    guidance_scale=2.8,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint *input_image* into a 2048x1024 equirectangular panorama.

    Parameters
    ----------
    prompt : str
        Scene description; combined with a fixed panorama prompt prefix.
    input_image : PIL.Image.Image
        Image to extend; required (raises ``gr.Error`` when missing).
    seed : int-like
        Seed for the diffusion generator.
    num_inference_steps : int
        Step count used for both inversion and generation.
    tau : int
        0-100 UI slider; forwarded to the attention processor as tau/100
        (0 = strictly follow input, 100 = free generation).
    guidance_scale : float
        Classifier-free guidance strength.
    progress : gr.Progress
        Gradio progress tracker (tqdm-linked); supplied by the framework.

    Returns
    -------
    PIL.Image.Image
        The generated panorama (the full-prompt branch of the 2-prompt batch).

    NOTE(fix): two defects corrected here.
    1. ``tau`` now precedes ``guidance_scale`` so the fifth positional input
       wired in ``gr.on`` (the tau slider) controls tau instead of silently
       overriding guidance_scale.
    2. The result image is now *returned* — previously it was assigned to a
       local and dropped, so the output component always received None.
    Also removed: an unreachable ``if True/else`` dummy-latents test branch.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image for outpainting.")

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        generator = torch.Generator(device=device).manual_seed(int(seed))

        # Output resolution is fixed at 2:1 equirectangular.
        target_height, target_width = 1024, 2048

        # ── Downscale input so its longest side is at most 640 px ────────
        input_image = _downscale(input_image, max_side=640)

        # ── Center the input on a gray canvas + build the keep-mask ──────
        canvas, mask_img = _centered_canvas_and_mask(
            input_image, target_width, target_height
        )

        # ── Latent-space sizes (Flux packs 2x2 latent patches) ───────────
        scale_factor = pipe.vae_scale_factor
        latent_h = target_height // (scale_factor * 2)
        latent_w = target_width // (scale_factor * 2)
        # +2: mask/attention operate on a horizontally edge-padded width,
        # matching prepare_mask_for_pipeline.
        img_dims = latent_h * (latent_w + 2)

        # ── Prompts ──────────────────────────────────────────────────────
        source_prompt = (
            "a high-quality historical or modern photograph, "
            "realistic scene, natural lighting, detailed architecture and landscape"
        )
        full_prompt = f"A seamless 360Β° equirectangular panorama, photorealistic, high detail, {prompt.strip()}"

        # ── Invert the canvas into the model's latent space ──────────────
        inverted_latents, image_latents, latent_image_ids = pipe.invert(
            source_prompt=source_prompt,
            image=canvas,
            height=target_height,
            width=target_width,
            num_inference_steps=num_inference_steps,
            gamma=1.2,
        )

        # ── Mask prep + personalized attention processors ────────────────
        mask = prepare_mask_for_pipeline(mask_img, latent_w, latent_h)
        set_flux_transformer_attn_processor(
            pipe.transformer,
            set_attn_proc_func=lambda name, dh, nh, ap: PersonalizeAnythingAttnProcessor(
                name=name, tau=tau / 100.0, mask=mask, device=device, img_dims=img_dims
            ),
        )

        # ── Generation ───────────────────────────────────────────────────
        result_images = pipe(
            [source_prompt, full_prompt],
            inverted_latents=inverted_latents,
            image_latents=image_latents,
            latent_image_ids=latent_image_ids,
            height=target_height,
            width=target_width,
            start_timestep=0.0,
            stop_timestep=0.99,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=1.0,
            generator=generator,
            mask=mask,
            use_timestep=True,
        ).images

    # Index 0 is the source-prompt branch; index 1 is the full-prompt result.
    return result_images[1]


def _downscale(image, max_side=640):
    """Return *image* resized (LANCZOS) so max(w, h) <= max_side, else as-is."""
    w, h = image.size
    if max(w, h) <= max_side:
        return image
    scale = max_side / max(w, h)
    return image.resize((int(w * scale), int(h * scale)), Image.LANCZOS)


def _centered_canvas_and_mask(image, target_w, target_h):
    """Paste *image* centered on a gray target_w x target_h canvas.

    Returns (canvas, mask): mask is mode "L", 255 over the pasted input
    (preserve) and 0 elsewhere (generate).
    """
    canvas = Image.new("RGB", (target_w, target_h), (127, 127, 127))
    x = (target_w - image.width) // 2
    y = (target_h - image.height) // 2
    canvas.paste(image, (x, y))
    mask = Image.new("L", (target_w, target_h), 0)
    mask.paste(255, (x, y, x + image.width, y + image.height))
    return canvas, mask
# -------------------- Gradio UI --------------------
# CSS for the Blocks UI below: flex rows/columns, panel sizing, and a
# non-resizable prompt textarea. Element ids match the elem_id values used
# when the components are created.
css = """
#main-container {
display: flex;
flex-direction: column;
gap: 2rem;
margin-top: 1rem;
}
#top-row {
display: flex;
flex-direction: row;
justify-content: center;
align-items: flex-start;
gap: 2rem;
}
#bottom-row {
display: flex;
flex-direction: row;
gap: 2rem;
}
#image-panel {
flex: 2;
max-width: 1200px;
margin: 0 auto;
}
#input-panel {
flex: 1;
}
#example-panel {
flex: 2;
}
#settings-panel {
flex: 1;
max-width: 280px;
}
#prompt-box textarea {
resize: none !important;
}
"""
# Build the demo UI. Component creation order defines the page layout.
with gr.Blocks(css=css) as demo:
    # Header: title plus project resource links.
    gr.Markdown(
        """
# πŸŒ€ DiT360: High-Fidelity Panoramic Image Generation with Outpainting
Here are our resources:
- πŸ’» **Code**: [https://github.com/Insta360-Research-Team/DiT360](https://github.com/Insta360-Research-Team/DiT360)
- 🌐 **Web Page**: [https://fenghora.github.io/DiT360-Page/](https://fenghora.github.io/DiT360-Page/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
"""
    )
    gr.Markdown("Official Gradio demo for **[DiT360](https://fenghora.github.io/DiT360-Page/)**, now with outpainting from a single image.")

    # Top row: output panorama, input image, prompt box, run button.
    with gr.Row(elem_id="top-row"):
        with gr.Column(elem_id="top-panel"):
            result = gr.Image(label="Generated Panorama", show_label=False, type="pil", height=800)
            input_image = gr.Image(type="pil", label="Input Image (for outpainting)", height=300)
            prompt = gr.Textbox(
                elem_id="prompt-box",
                placeholder="Describe your panoramic scene here...",
                show_label=False,
                lines=2,
                container=False,
            )
            run_button = gr.Button("Generate Panorama", variant="primary")

    # Bottom row: example prompts on the left, settings on the right.
    with gr.Row(elem_id="bottom-row"):
        with gr.Column(elem_id="example-panel"):
            gr.Markdown("### πŸ“š Examples")
            gr.Examples(
                examples=[
                    "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
                    "A futuristic cityscape under a starry night sky.",
                    "A futuristic city skyline reflects on the calm river at sunset, neon lights glowing against the twilight sky.",
                    "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
                ],
                inputs=[prompt],
            )
        with gr.Column(elem_id="settings-panel"):
            gr.Markdown("### βš™οΈ Settings")
            gr.Markdown(
                "For better results, the output image is fixed at **2048Γ—1024** (2:1 aspect ratio). "
            )
            seed_display = gr.Number(value=0, label="Seed", interactive=True)
            random_seed_button = gr.Button("🎲 Random Seed")
            random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
            num_inference_steps = gr.Slider(10, 100, value=15, step=1, label="Inference Steps")
            tau_slider = gr.Slider(0, 100, value=30, step=1, label="Tau (0=strictly follow input, 100=free generation)")
            gr.Markdown(
                "πŸ’‘ *Tip: Upload an image and describe the scene. The model will extend it to a full 360Β° panorama using outpainting.*"
            )

    # Wire both the button click and textbox submit to the inference function.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_image, seed_display, num_inference_steps, tau_slider],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()