# app.py — multi-angle image editing + 360° panorama generation pipeline (Hugging Face Space)
import gradio as gr
import numpy as np
import random
import torch
import spaces
import os
import time
import tempfile
import traceback
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
from gradio_client import Client, handle_file
# ==================== Environment configuration ====================
# Hugging Face access token; required to download the gated/remote models below.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast at import time rather than mid-pipeline.
    raise ValueError("请设置 HF_TOKEN 环境变量")
# ==================== Multi-angle edit model ====================
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Qwen image-edit pipeline with two stacked LoRAs:
#  - "lightning": distilled weights enabling 4-step inference
#  - "angles": camera-angle control (triggered by the "<sks>" prompt prefix
#    built in build_camera_prompt below)
pipe_angle = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2511",
    torch_dtype=dtype
).to(device)
pipe_angle.load_lora_weights(
    "lightx2v/Qwen-Image-Edit-2511-Lightning",
    weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors",
    adapter_name="lightning"
)
pipe_angle.load_lora_weights(
    "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
    weight_name="qwen-image-edit-2511-multiple-angles-lora.safetensors",
    adapter_name="angles"
)
# Activate both adapters at full strength.
pipe_angle.set_adapters(["lightning", "angles"], adapter_weights=[1.0, 1.0])
# Upper bound for the seed slider and for random seed generation.
MAX_SEED = np.iinfo(np.int32).max
# Discrete camera vocabularies understood by the "angles" LoRA.
# Arbitrary slider values are snapped to the nearest key (see snap_to_nearest).
AZIMUTH_MAP = {0: "front view", 45: "front-right quarter view", 90: "right side view",
               135: "back-right quarter view", 180: "back view", 225: "back-left quarter view",
               270: "left side view", 315: "front-left quarter view"}
ELEVATION_MAP = {-30: "low-angle shot", 0: "eye-level shot", 30: "elevated shot", 60: "high-angle shot"}
DISTANCE_MAP = {0.6: "close-up", 1.0: "medium shot", 1.8: "wide shot"}
def snap_to_nearest(value, options):
    """Return the element of *options* closest to *value* (first wins on ties)."""
    best = None
    best_gap = float("inf")
    for candidate in options:
        gap = abs(candidate - value)
        if gap < best_gap:
            best, best_gap = candidate, gap
    return best
def build_camera_prompt(azimuth, elevation, distance):
    """Compose the LoRA trigger prompt ("<sks> ...") from snapped camera params."""
    phrases = []
    for raw, table in ((azimuth, AZIMUTH_MAP),
                       (elevation, ELEVATION_MAP),
                       (distance, DISTANCE_MAP)):
        key = snap_to_nearest(raw, list(table.keys()))
        phrases.append(table[key])
    return "<sks> " + " ".join(phrases)
@spaces.GPU
def generate_image(image, azimuth=0.0, elevation=0.0, distance=1.0, seed=0, randomize_seed=True,
                   guidance_scale=1.0, num_inference_steps=4, height=1024, width=1024):
    """Run the multi-angle edit pipeline on *image* for the requested camera pose.

    height/width of 0 are treated as "let the pipeline decide" (passed as None).
    Returns a single PIL image.
    """
    camera_prompt = build_camera_prompt(azimuth, elevation, distance)
    print(f"Generated Prompt: {camera_prompt}")
    # Seed handling happens before the image guard to keep the original
    # side-effect order (stdout + global RNG consumption) unchanged.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    rng = torch.Generator(device=device).manual_seed(seed)
    if image is None:
        raise gr.Error("请上传图片")
    if isinstance(image, Image.Image):
        source = image.convert("RGB")
    else:
        source = Image.open(image).convert("RGB")
    outputs = pipe_angle(
        image=[source],
        prompt=camera_prompt,
        height=None if height == 0 else height,
        width=None if width == 0 else width,
        num_inference_steps=num_inference_steps,
        generator=rng,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
    )
    return outputs.images[0]
def update_dimensions_on_upload(image):
    """Suggest generation (width, height) for an uploaded image.

    Scales the longer side to 1024 px, preserves aspect ratio, and rounds
    both sides down to a multiple of 8 (diffusion pipelines require /8 sizes).

    Args:
        image: object with a ``.size`` (width, height) tuple (e.g. PIL image),
            or None.

    Returns:
        (width, height) ints; (1024, 1024) when no image is given.
    """
    if image is None:
        return 1024, 1024
    w, h = image.size
    if w > h:
        new_w = 1024
        new_h = int(1024 * h / w)
    else:
        new_h = 1024
        new_w = int(1024 * w / h)
    # Round down to a multiple of 8, but never below 8: extremely elongated
    # inputs previously rounded the short side to 0, an invalid dimension.
    new_w = max((new_w // 8) * 8, 8)
    new_h = max((new_h // 8) * 8, 8)
    return new_w, new_h
# ==================== Panorama generation helpers ====================
# Remote Gradio Spaces used by the pipeline steps.
# NOTE: constructing these performs network requests at import time.
outpaint_client = Client("fffiloni/diffusers-image-outpaint", verbose=False)
flux_client = Client("black-forest-labs/FLUX.2-dev", verbose=False)
inpaint_client = Client("diffusers/stable-diffusion-xl-inpainting", verbose=False)
def safe_outpaint(image_path, prompt, steps=8, overlap=5):
    """Call the hosted outpainting Space; return a PIL image, or None on any failure.

    The target canvas is fixed at 1280x720 with the source centered and
    extended on all four sides.
    """
    try:
        raw = outpaint_client.predict(
            image=handle_file(image_path),
            width=1280, height=720,
            overlap_percentage=overlap,
            num_inference_steps=steps,
            resize_option="Full",
            custom_resize_percentage=50,
            prompt_input=prompt,
            alignment="Middle",
            overlap_left=True, overlap_right=True, overlap_top=True, overlap_bottom=True,
            api_name="/infer"
        )
        # The Space may return (status, image) or a bare image/path.
        picked = raw[1] if isinstance(raw, (tuple, list)) and len(raw) >= 2 else raw
        return Image.open(picked) if isinstance(picked, str) else picked
    except Exception as e:
        print(f"Outpaint 失败: {e}")
        traceback.print_exc()
        return None
def safe_flux_call(image_path, prompt):
    """Send an image plus prompt to the FLUX.2 Space, retrying once on failure.

    The input is downscaled to fit 1024x1024 and re-encoded as JPEG before
    upload (smaller payload). The prompt is truncated to 200 characters.

    Args:
        image_path: path to the input image file.
        prompt: text prompt; only the first 200 chars are sent.

    Returns:
        The generated PIL image, or None if both attempts fail.
    """
    attempts = 2
    for attempt in range(attempts):
        pre_path = None
        try:
            img = Image.open(image_path).convert("RGB")
            img.thumbnail((1024, 1024), Image.LANCZOS)
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                img.save(tmp.name, quality=95)
                pre_path = tmp.name
            input_gallery = [{"image": handle_file(pre_path), "caption": None}]
            result, _ = flux_client.predict(
                prompt=prompt[:200],
                input_images=input_gallery,
                seed=42,
                randomize_seed=False,
                width=1024, height=512,
                num_inference_steps=30,
                guidance_scale=4,
                prompt_upsampling=False,
                api_name="/infer"
            )
            # Normalize the Space's return value to a PIL image.
            if isinstance(result, dict):
                if 'path' in result:
                    img = Image.open(result['path'])
                elif 'url' in result:
                    img = Image.open(result['url'])
                else:
                    raise ValueError(f"unexpected FLUX result keys: {list(result)}")
            elif isinstance(result, str):
                img = Image.open(result)
            else:
                img = result
            return img
        except Exception as e:
            print(f"FLUX 调用失败 (尝试 {attempt+1}): {e}")
            if attempt < attempts - 1:
                time.sleep(2)  # brief backoff; previously also slept after the final attempt
        finally:
            # Remove the re-encoded upload copy (was previously leaked every call).
            if pre_path and os.path.exists(pre_path):
                os.unlink(pre_path)
    return None
def seam_fix(img, prompt="", seam_width=48, strength=0.7, debug_mode=True):
    """Blend the wrap-around seam of an equirectangular panorama.

    The left/right halves are swapped so the panorama's wrap edge lands in
    the middle, a feathered vertical mask is built over that junction, and
    the SDXL inpainting Space repaints the masked strip. The result is
    resized to 2048x1024 (2:1 panorama).

    Args:
        img: source PIL image.
        prompt: currently unused — the inpaint prompt is fixed below.
            NOTE(review): callers pass a camera prompt here; confirm whether
            it should be forwarded to the Space.
        seam_width: width (px) of the repainted strip.
        strength: inpainting strength; higher deviates more from the input.
        debug_mode: dump the swapped image and mask to /tmp/debug_inpaint.

    Returns:
        PIL image, 2048x1024.
    """
    w, h = img.size
    left = img.crop((0, 0, w//2, h))
    right = img.crop((w//2, 0, w, h))
    swapped = Image.new("RGB", (w, h))
    swapped.paste(right, (0, 0))
    swapped.paste(left, (w//2, 0))
    seam_x = w//2 - seam_width//2
    # Feathered mask: full opacity at the seam center, fading to 0 at the
    # strip edges. (Module-level numpy is used; a shadowing local import
    # was removed.)
    mask_arr = np.zeros((h, w), dtype=np.uint8)
    for i in range(seam_width):
        ratio = 1 - abs(i - seam_width//2) / (seam_width//2)
        mask_arr[:, seam_x + i] = int(255 * ratio)
    mask_full = Image.fromarray(mask_arr, mode='L')
    mask_rgba = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    mask_rgba.putalpha(mask_full)
    if debug_mode:
        debug_dir = "/tmp/debug_inpaint"
        os.makedirs(debug_dir, exist_ok=True)
        swapped.save(os.path.join(debug_dir, "swapped.png"))
        mask_full.save(os.path.join(debug_dir, "mask_full.png"))
        print(f"Strength: {strength}, Seam width: {seam_width}, Seam_x: {seam_x}")
    tmp_paths = []
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_bg, \
             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_layer, \
             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_composite:
            swapped.save(f_bg.name)
            mask_rgba.save(f_layer.name)
            swapped.save(f_composite.name)
            tmp_paths = [f_bg.name, f_layer.name, f_composite.name]
        result = inpaint_client.predict(
            input_image={
                "background": handle_file(tmp_paths[0]),
                "layers": [handle_file(tmp_paths[1])],
                "composite": handle_file(tmp_paths[2])
            },
            prompt="seamless transition, natural blending, no visible seam",
            negative_prompt="seam, line, border, cut, artifact, blur, low quality",
            guidance_scale=7.5,
            steps=30,
            strength=strength,  # caller-supplied repaint strength
            scheduler="EulerDiscreteScheduler",
            api_name="/predict"
        )
    finally:
        # Temp upload files were previously leaked; clean them up.
        for p in tmp_paths:
            if os.path.exists(p):
                os.unlink(p)
    if isinstance(result, (tuple, list)) and len(result) >= 2:
        final = result[1]
    else:
        final = result
    if isinstance(final, str):
        final = Image.open(final)
    return final.resize((2048, 1024), Image.LANCZOS)
def get_last_image(state):
    """Persist the most recently generated image to a temp PNG for download.

    Args:
        state: PIL image held in gr.State, or None if nothing was generated.

    Returns:
        Path to a temporary .png file.

    Raises:
        gr.Error: when no image has been generated yet.
    """
    if state is None:
        raise gr.Error("没有已生成的图像,请先运行任意步骤")
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        state.save(tmp.name, format='PNG')
    finally:
        tmp.close()
    return tmp.name
# ==================== Independent pipeline steps ====================
def run_angle(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, steps):
    """Step 1: multi-angle edit. Returns (display image, state update)."""
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    edited = generate_image(image, azimuth, elevation, distance, seed, randomize_seed,
                            guidance_scale, steps, 1024, 1024)
    # Same image feeds both the output gallery slot and the shared state.
    return edited, edited
def run_outpaint(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps, outpaint_steps, outpaint_overlap):
    """Step 2: extend the scene via the outpainting Space. Returns (image, state)."""
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    with tmp:
        image.save(tmp.name, format='PNG')
    img_path = tmp.name
    try:
        prompt_text = f"extend scene naturally, {build_camera_prompt(azimuth, elevation, distance)}"
        expanded = safe_outpaint(img_path, prompt_text, steps=outpaint_steps, overlap=outpaint_overlap)
        if expanded is None:
            raise gr.Error("Outpaint 失败")
        return expanded, expanded
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
def run_flux(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
             angle_steps, outpaint_steps, outpaint_overlap, keep_area, user_prompt):
    """Step 3: generate a 2:1 equirectangular panorama via the FLUX Space.

    keep_area selects which region of the input survives verbatim ('left',
    'right', 'center'); everything else is replaced with neutral gray so FLUX
    re-synthesizes it. 'full' passes the image through unchanged.
    Returns (image, state).
    """
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    w, h = image.size
    # Pre-process: gray out the non-kept region (lossless copy of the kept part).
    if keep_area == 'left':
        new_img = Image.new('RGB', (w, h), (128,128,128))
        new_img.paste(image.crop((0, 0, w//2, h)), (0, 0))
    elif keep_area == 'right':
        new_img = Image.new('RGB', (w, h), (128,128,128))
        new_img.paste(image.crop((w//2, 0, w, h)), (w//2, 0))
    elif keep_area == 'center':
        center_width = w // 3
        start_x = w//2 - center_width//2
        new_img = Image.new('RGB', (w, h), (128,128,128))
        new_img.paste(image.crop((start_x, 0, start_x + center_width, h)), (start_x, 0))
    else:  # 'full'
        new_img = image
    # Save as PNG so the preserved region reaches the Space losslessly.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    with tmp:
        new_img.save(tmp.name, format='PNG')
    img_path = tmp.name
    try:
        # Base prompt + optional user text + keep-area instruction.
        base_prompt = f"360 equirectangular panorama, 2:1 aspect ratio, high quality, {build_camera_prompt(azimuth, elevation, distance)}"
        flux_prompt = f"{base_prompt}. {user_prompt.strip()}" if user_prompt and user_prompt.strip() else base_prompt
        keep_suffix = {
            'left': " Keep the left part exactly as is, extend the scene to the right naturally without repeating content.",
            'right': " Keep the right part exactly as is, extend the scene to the left naturally without repeating content.",
            'center': " Keep the central part unchanged, extend both sides without repeating content.",
        }
        flux_prompt += keep_suffix.get(keep_area, "")
        panorama = safe_flux_call(img_path, flux_prompt)
        if panorama is None:
            raise gr.Error("FLUX 生成失败")
        return panorama, panorama
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
def run_inpaint(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps, outpaint_steps, outpaint_overlap, inpaint_width, inpaint_strength):
    """Step 4: repair the panorama wrap-around seam. Returns (image, state)."""
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    camera_prompt = build_camera_prompt(azimuth, elevation, distance)
    fixed = seam_fix(image, camera_prompt, inpaint_width, inpaint_strength)
    return fixed, fixed
# ==================== Gradio UI ====================
with gr.Blocks(title="Flexible Panorama Pipeline", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 灵活全景生成工具")
    # Shared state holding the most recently generated PIL image (any step).
    last_image_state = gr.State(value=None)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="当前输入图像", height=200, interactive=True)
            # Camera parameters shared by all steps.
            azimuth = gr.Slider(0, 315, step=45, value=0, label="方位角 (°)")
            elevation = gr.Slider(-30, 60, step=30, value=0, label="仰角 (°)")
            distance = gr.Slider(0.6, 1.8, step=0.4, value=1.0, label="距离系数")
            # ========== Multi-angle edit advanced settings ==========
            with gr.Accordion("⚙️ 多角度编辑高级设置", open=False):
                seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="种子")
                randomize_seed = gr.Checkbox(value=True, label="随机种子")
                guidance_scale = gr.Slider(1.0, 10.0, step=0.1, value=1.0, label="引导系数")
                angle_steps = gr.Slider(1, 20, step=1, value=4, label="推理步数")
            # ========== Outpaint advanced settings ==========
            with gr.Accordion("🎨 Outpaint 高级设置", open=False):
                outpaint_steps = gr.Slider(4, 20, step=1, value=8, label="步数")
                outpaint_overlap = gr.Slider(5, 30, step=1, value=5, label="重叠百分比 (%)")
            # ========== FLUX advanced settings ==========
            with gr.Accordion("✨ FLUX 高级设置", open=False):
                keep_area = gr.Radio(
                    choices=['full', 'left', 'center', 'right'],
                    value='full',
                    label="保留区域(防止重复)",
                    info="选择要保留的图像区域,其他区域由 AI 生成"
                )
                flux_prompt_text = gr.Textbox(
                    label="自定义提示词(可选)",
                    placeholder="例如:fantasy landscape, cyberpunk style...",
                    info="附加到默认提示词后面"
                )
            # ========== Seam-repair advanced settings ==========
            with gr.Accordion("🔧 接缝修补高级设置", open=False):
                inpaint_width = gr.Slider(
                    16, 256, step=8, value=48,
                    label="修补宽度 (px)",
                    info="中央修补区域的宽度,值越大覆盖越宽"
                )
                inpaint_strength = gr.Slider(
                    0.2, 0.9, step=0.05, value=0.7,
                    label="修补强度",
                    info="值越高,AI 对中央区域的修改程度越大,适用于明显接缝;值越低越贴近原图。"
                )
            # Step buttons (each step runs independently on the current input).
            with gr.Row():
                btn_angle = gr.Button("1. 多角度编辑", variant="primary")
                btn_outpaint = gr.Button("2. Outpaint 扩展", variant="secondary")
                btn_flux = gr.Button("3. FLUX 全景生成", variant="secondary")
                btn_inpaint = gr.Button("4. 接缝修补", variant="secondary")
        with gr.Column():
            angle_output = gr.Image(type="pil", label="1. 多角度结果", height=150)
            with gr.Row():
                angle_copy = gr.Button("📋 设为输入", size="sm")
            outpaint_output = gr.Image(type="pil", label="2. Outpaint 结果", height=150)
            with gr.Row():
                outpaint_copy = gr.Button("📋 设为输入", size="sm")
            flux_output = gr.Image(type="pil", label="3. FLUX 结果", height=150)
            with gr.Row():
                flux_copy = gr.Button("📋 设为输入", size="sm")
            final_output = gr.Image(type="pil", label="4. 最终全景图", height=150)
            with gr.Row():
                final_copy = gr.Button("📋 设为输入", size="sm")
    # Multi-angle edit (inputs: image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps)
    btn_angle.click(
        fn=run_angle,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps],
        outputs=[angle_output, last_image_state],
        api_name="run_angle"
    )
    # Outpaint (adds outpaint_steps, outpaint_overlap)
    btn_outpaint.click(
        fn=run_outpaint,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap],
        outputs=[outpaint_output, last_image_state],
        api_name="run_outpaint"
    )
    # FLUX panorama (adds keep_area, flux_prompt_text)
    btn_flux.click(
        fn=run_flux,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap, keep_area, flux_prompt_text],
        outputs=[flux_output, last_image_state],
        api_name="run_flux"
    )
    # Seam repair (adds inpaint_width, inpaint_strength)
    btn_inpaint.click(
        fn=run_inpaint,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap, inpaint_width, inpaint_strength],
        outputs=[final_output, last_image_state],
        api_name="run_inpaint"
    )
    # "Set as input" copy buttons: feed a step's output back into the input slot.
    angle_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[angle_output],
        outputs=[input_image]
    )
    outpaint_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[outpaint_output],
        outputs=[input_image]
    )
    flux_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[flux_output],
        outputs=[input_image]
    )
    final_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[final_output],
        outputs=[input_image]
    )
    # Download: export the last generated image (from shared state) as a PNG file.
    get_btn = gr.Button("📸 获取最终图像", variant="secondary")
    download_file = gr.File(label="点击下载最终图像")
    get_btn.click(
        fn=get_last_image,
        inputs=[last_image_state],
        outputs=download_file
    )
if __name__ == "__main__":
    # Queue serializes requests (required for @spaces.GPU workloads); bind to
    # all interfaces on the standard HF Spaces port.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)