# DiT360 / app.py
# Last commit: f4d642d "Update app.py" (author: asd755)
import gradio as gr
import torch
import numpy as np
import random
from PIL import Image
import spaces
import os
import gc
from pa_src.pipeline import RFPanoInversionParallelFluxPipeline
from pa_src.attn_processor import PersonalizeAnythingAttnProcessor, set_flux_transformer_attn_processor
from pa_src.utils import *
# Use the GPU in half precision when one is present; otherwise fall back to
# full-precision fp32 on the CPU.
# NOTE(review): FLUX.1 checkpoints are commonly loaded in bfloat16; fp16 is kept
# here to match the original behaviour -- confirm output quality before changing.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# FLUX.1-dev base weights loaded through the project's inversion pipeline,
# with the DiT360 panorama LoRA applied on top.
pipe = RFPanoInversionParallelFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
).to(device)
pipe.load_lora_weights("Insta360-Research/DiT360-Panorama-Image-Generation")
# Upper bound (inclusive) for user-visible seeds: the int32 maximum, 2**31 - 1.
MAX_SEED = np.iinfo(np.int32).max


def generate_seed():
    """Draw a fresh seed uniformly from the inclusive range [0, MAX_SEED]."""
    # randrange(n) samples [0, n), so n = MAX_SEED + 1 keeps MAX_SEED reachable.
    return random.randrange(MAX_SEED + 1)
def create_outpainting_mask(image, target_size=(2048, 1024)):
    """Center *image* on a gray canvas and build the matching outpainting mask.

    Args:
        image: PIL image to place in the middle of the canvas.
        target_size: ``(width, height)`` of the canvas. The default is the
            full 2048x1024 panorama resolution the original comment cites as
            the model's training resolution.

    Returns:
        ``(canvas, mask_img)`` where ``canvas`` is an RGB image filled with
        (128, 128, 128) outside the pasted input, and ``mask_img`` is an "L"
        image that is 255 (preserve) over the input's footprint and 0
        (generate) everywhere else.

    Note:
        The input is not resized. If it is larger than ``target_size`` the
        offsets go negative and the input is presumably cropped by the canvas
        bounds. ``infer`` in this file builds an equivalent canvas inline
        (with fill 127) rather than calling this helper.
    """
    canvas_w, canvas_h = target_size
    src_w, src_h = image.size

    # Symmetric margins so outpainting extends the input equally on both sides.
    left = (canvas_w - src_w) // 2
    top = (canvas_h - src_h) // 2

    canvas = Image.new("RGB", (canvas_w, canvas_h), (128, 128, 128))
    canvas.paste(image, (left, top))

    # 0 = generate everywhere, then mark the input's rectangle as 255 = preserve.
    mask_img = Image.new("L", (canvas_w, canvas_h), 0)
    footprint = (left, top, left + src_w, top + src_h)
    mask_img.paste(255, footprint)

    return canvas, mask_img
def prepare_mask_for_pipeline(mask_img, latent_w, latent_h):
    """Convert a pixel-space mask into the flattened per-token mask tensor.

    The mask is resized to ``(latent_w, latent_h)`` with PIL's default
    resampling and scaled from 0..255 to 0..1. It is then widened by one
    column on each side by repeating its first and last columns, and reshaped
    to a single column.

    Args:
        mask_img: PIL mask image. Assumed single-channel ("L") so that
            ``np.array`` yields a 2-D array -- TODO confirm for other modes.
        latent_w: Target width in tokens.
        latent_h: Target height in tokens.

    Returns:
        float32 tensor on the module-level ``device`` with shape
        ``(latent_h * (latent_w + 2), 1)``.
    """
    resized = mask_img.resize((latent_w, latent_h))
    weights = torch.from_numpy(np.array(resized) / 255.0).float().to(device)

    # Replicate the edge columns so the mask matches a latent grid that is
    # two columns wider than (latent_w, latent_h).
    left_edge = weights[:, :1]
    right_edge = weights[:, -1:]
    padded = torch.cat((left_edge, weights, right_edge), dim=-1)

    return padded.view(-1, 1)
def _downscale_to_max_side(image, max_side):
    """Return *image* shrunk with LANCZOS so its longer side is at most *max_side*.

    Images already within the limit are returned unchanged. Each output
    dimension is clamped to at least 1 px, so extreme aspect ratios cannot
    yield a zero-sized image.
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= max_side:
        return image
    scale = max_side / longest
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.LANCZOS)


@spaces.GPU
def infer(
    prompt,
    input_image,
    seed,
    num_inference_steps,
    guidance_scale=2.8,
    tau=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint *input_image* into a 2048x1024 panorama with DiT360.

    The input is downscaled if needed, centered on a gray canvas, inverted
    with ``pipe.invert``, and then regenerated under ``full_prompt`` while
    ``PersonalizeAnythingAttnProcessor`` uses the mask to keep the pasted
    region.

    Args:
        prompt: User scene description, appended to a fixed panorama prefix.
            None or blank is treated as "no extra description".
        input_image: PIL image to extend. Required.
        seed: Seed for the torch generator. Cast to int; None becomes 0.
        num_inference_steps: Used for both inversion and generation. Cast to int.
        guidance_scale: Forwarded to the pipeline's generation call.
        tau: 0..100. Divided by 100 and handed to the attention processors.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        The second image returned by the pipeline, i.e. the ``full_prompt``
        result (the original author labels it "the conditioned / edited one").

    Raises:
        gr.Error: If no image was uploaded, or if generation fails for any
            reason. A ZeroGPU/CUDA-looking failure gets a retry hint instead of
            the raw message.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image for outpainting.")

    try:
        with torch.inference_mode():
            torch.cuda.empty_cache()
            num_inference_steps = int(num_inference_steps)
            seed = 0 if seed is None else int(seed)
            generator = torch.Generator(device=device).manual_seed(seed)

            # 1. Fixed output resolution (2:1 equirectangular), which the
            #    original author states is mandatory for DiT360/Flux.
            target_height = 1024
            target_width = 2048

            # Shrink large inputs so most of the canvas is left to outpaint and
            # the latent size stays manageable on ZeroGPU.
            input_image = _downscale_to_max_side(input_image, 640)

            # 2. Canvas with the input centered. Mask: 255 = preserve, 0 = generate.
            canvas = Image.new("RGB", (target_width, target_height), (127, 127, 127))
            paste_x = (target_width - input_image.width) // 2
            paste_y = (target_height - input_image.height) // 2
            canvas.paste(input_image, (paste_x, paste_y))
            mask_img = Image.new("L", (target_width, target_height), 0)
            mask_img.paste(
                255,
                (paste_x, paste_y, paste_x + input_image.width, paste_y + input_image.height),
            )

            # 3. Prompts. The source prompt describes the image being inverted.
            # NOTE(review): this text is hard-coded to one specific photo and is
            # used for every upload -- consider captioning the input instead.
            source_prompt = "A historical black and white photograph of an old industrial power station building with tall chimney and surrounding structures in Hong Kong"
            full_prompt = "A seamless 360° equirectangular panorama, photorealistic, high detail"
            user_prompt = (prompt or "").strip()
            if user_prompt:
                full_prompt = f"{full_prompt}, {user_prompt}"

            # 4. Invert the pasted canvas.
            inverted_latents, image_latents, latent_image_ids = pipe.invert(
                source_prompt=source_prompt,
                image=canvas,
                height=target_height,
                width=target_width,
                num_inversion_steps=num_inference_steps,
                gamma=1.2,  # original author: 1.0-1.5 range usually good
            )

            # 5. Mask at token resolution. The "+ 2" matches the one column of
            #    padding that prepare_mask_for_pipeline adds on each side.
            scale_factor = pipe.vae_scale_factor
            latent_h = target_height // (scale_factor * 2)
            latent_w = target_width // (scale_factor * 2)
            img_dims = latent_h * (latent_w + 2)
            mask = prepare_mask_for_pipeline(mask_img, latent_w, latent_h)

            # 6. Re-install the attention processors with this call's tau.
            tau_fraction = tau / 100.0
            set_flux_transformer_attn_processor(
                pipe.transformer,
                set_attn_proc_func=lambda name, dh, nh, ap: PersonalizeAnythingAttnProcessor(
                    name=name,
                    tau=tau_fraction,
                    mask=mask,
                    device=device,
                    img_dims=img_dims,
                ),
            )

            # 7. Generate for [source_prompt, full_prompt]. Index 1 is
            #    presumably the full_prompt result.
            result_images = pipe(
                prompt=[source_prompt, full_prompt],
                inverted_latents=inverted_latents,
                image_latents=image_latents,
                latent_image_ids=latent_image_ids,
                height=target_height,
                width=target_width,
                start_timestep=0.0,
                stop_timestep=0.99,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                eta=1.0,
                generator=generator,
                mask=mask,
                use_timestep=True,
            ).images
            return result_images[1]
    except gr.Error:
        # Already user-facing; don't wrap it in "Generation failed: ...".
        raise
    except Exception as e:
        import traceback

        message = str(e)
        print(message + "\n" + traceback.format_exc())  # full trace to the Space logs
        # Returning two gr.update values would mismatch the single output
        # component, so the hint is surfaced via gr.Error instead.
        if "aborted" in message.lower() or "CUDA" in message:
            raise gr.Error(
                "ZeroGPU aborted (common glitch). Try duplicating the Space or lower steps/tau."
            ) from e
        raise gr.Error(f"Generation failed: {message}") from e
    finally:
        # Release cached GPU memory whether generation succeeded or not.
        torch.cuda.empty_cache()
        gc.collect()
# -------------------- Gradio interface --------------------
css = """
#main-container {
display: flex;
flex-direction: column;
gap: 2rem;
margin-top: 1rem;
}
#top-row {
display: flex;
flex-direction: row;
justify-content: center;
align-items: flex-start;
gap: 2rem;
}
#bottom-row {
display: flex;
flex-direction: row;
gap: 2rem;
}
#image-panel {
flex: 2;
max-width: 1200px;
margin: 0 auto;
}
#input-panel {
flex: 1;
}
#example-panel {
flex: 2;
}
#settings-panel {
flex: 1;
max-width: 280px;
}
#prompt-box textarea {
resize: none !important;
}
"""
# Gradio UI: result + input + prompt on top; examples and settings below.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# 🌀 DiT360: High-Fidelity Panoramic Image Generation with Outpainting
Here are our resources:
- 💻 **Code**: [https://github.com/Insta360-Research-Team/DiT360](https://github.com/Insta360-Research-Team/DiT360)
- 🌐 **Web Page**: [https://fenghora.github.io/DiT360-Page/](https://fenghora.github.io/DiT360-Page/)
- 🧠 **Pretrained Model**: [https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation](https://huggingface.co/Insta360-Research/DiT360-Panorama-Image-Generation)
"""
    )
    gr.Markdown("Official Gradio demo for **[DiT360](https://fenghora.github.io/DiT360-Page/)**, now with outpainting from a single image.")

    with gr.Row(elem_id="top-row"):
        with gr.Column(elem_id="top-panel"):
            result = gr.Image(label="Generated Panorama", show_label=False, type="pil", height=800)
            input_image = gr.Image(type="pil", label="Input Image (for outpainting)", height=300)
            prompt = gr.Textbox(
                elem_id="prompt-box",
                placeholder="Describe your panoramic scene here...",
                show_label=False,
                lines=2,
                container=False,
            )
            run_button = gr.Button("Generate Panorama", variant="primary")

    with gr.Row(elem_id="bottom-row"):
        with gr.Column(elem_id="example-panel"):
            gr.Markdown("### 📚 Examples")
            gr.Examples(examples=[
                "A medieval castle stands proudly on a hilltop surrounded by autumn forests, with golden light spilling across the landscape.",
                "A futuristic cityscape under a starry night sky.",
                "A futuristic city skyline reflects on the calm river at sunset, neon lights glowing against the twilight sky.",
                "A snowy mountain village under northern lights, with cozy cabins and smoke rising from chimneys.",
            ], inputs=[prompt])
        with gr.Column(elem_id="settings-panel"):
            gr.Markdown("### ⚙️ Settings")
            gr.Markdown(
                "For better results, the output image is fixed at **2048×1024** (2:1 aspect ratio). "
            )
            seed_display = gr.Number(value=0, label="Seed", interactive=True)
            random_seed_button = gr.Button("🎲 Random Seed")
            random_seed_button.click(fn=generate_seed, inputs=[], outputs=seed_display)
            num_inference_steps = gr.Slider(10, 100, value=10, step=1, label="Inference Steps")
            # infer() takes guidance_scale *before* tau. Without this slider in
            # the inputs list, the tau slider's value landed in guidance_scale
            # and tau was stuck at its default.
            guidance_slider = gr.Slider(1.0, 10.0, value=2.8, step=0.1, label="Guidance Scale")
            tau_slider = gr.Slider(0, 100, value=30, step=1, label="Tau (0=strictly follow input, 100=free generation)")
            gr.Markdown(
                "💡 *Tip: Upload an image and describe the scene. The model will extend it to a full 360° panorama using outpainting.*"
            )

    # Inputs are matched to infer()'s parameters by position, so this order
    # must follow its signature: prompt, input_image, seed,
    # num_inference_steps, guidance_scale, tau.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_image, seed_display, num_inference_steps, guidance_slider, tau_slider],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()