# Hugging Face Space application (runs on ZeroGPU hardware).
import os
import random
import tempfile
import time
import traceback

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
from gradio_client import Client, handle_file
# ==================== Environment configuration ====================
# Fail fast if the Hugging Face token is missing: the model downloads and
# remote Space clients below require authentication.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("请设置 HF_TOKEN 环境变量")
# ==================== Multi-angle edit model ====================
# bfloat16 halves memory vs fp32; fall back to CPU when CUDA is unavailable.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe_angle = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2511",
    torch_dtype=dtype
).to(device)
# "lightning" LoRA enables 4-step fast inference (see the 4steps weight name);
# "angles" LoRA adds the camera-angle control vocabulary used in the prompts.
pipe_angle.load_lora_weights(
    "lightx2v/Qwen-Image-Edit-2511-Lightning",
    weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors",
    adapter_name="lightning"
)
pipe_angle.load_lora_weights(
    "fal/Qwen-Image-Edit-2511-Multiple-Angles-LoRA",
    weight_name="qwen-image-edit-2511-multiple-angles-lora.safetensors",
    adapter_name="angles"
)
# Activate both adapters at full strength simultaneously.
pipe_angle.set_adapters(["lightning", "angles"], adapter_weights=[1.0, 1.0])
# Upper bound for user-supplied / randomized seeds.
MAX_SEED = np.iinfo(np.int32).max
# Discrete camera presets; slider values are snapped to the nearest key by
# snap_to_nearest() before lookup.
AZIMUTH_MAP = {0: "front view", 45: "front-right quarter view", 90: "right side view",
               135: "back-right quarter view", 180: "back view", 225: "back-left quarter view",
               270: "left side view", 315: "front-left quarter view"}
ELEVATION_MAP = {-30: "low-angle shot", 0: "eye-level shot", 30: "elevated shot", 60: "high-angle shot"}
DISTANCE_MAP = {0.6: "close-up", 1.0: "medium shot", 1.8: "wide shot"}
def snap_to_nearest(value, options):
    """Return the element of *options* closest to *value* (first wins on ties)."""
    best = options[0]
    for candidate in options[1:]:
        if abs(candidate - value) < abs(best - value):
            best = candidate
    return best
def build_camera_prompt(azimuth, elevation, distance):
    """Compose the camera-control prompt string for the angles LoRA.

    Each slider value is snapped to its nearest supported preset before the
    textual phrase is looked up; "<sks>" is the LoRA trigger token.
    """
    azimuth_phrase = AZIMUTH_MAP[snap_to_nearest(azimuth, list(AZIMUTH_MAP))]
    elevation_phrase = ELEVATION_MAP[snap_to_nearest(elevation, list(ELEVATION_MAP))]
    distance_phrase = DISTANCE_MAP[snap_to_nearest(distance, list(DISTANCE_MAP))]
    return f"<sks> {azimuth_phrase} {elevation_phrase} {distance_phrase}"
def generate_image(image, azimuth=0.0, elevation=0.0, distance=1.0, seed=0, randomize_seed=True,
                   guidance_scale=1.0, num_inference_steps=4, height=1024, width=1024):
    """Run the multi-angle edit pipeline on *image* and return a single PIL image.

    The camera prompt is built from the (azimuth, elevation, distance) sliders.
    A height/width of 0 is forwarded as None so the pipeline picks defaults.
    Raises gr.Error when no image was supplied.
    """
    prompt = build_camera_prompt(azimuth, elevation, distance)
    print(f"Generated Prompt: {prompt}")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    if image is None:
        raise gr.Error("请上传图片")
    # Accept either an in-memory PIL image or a file path.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.open(image).convert("RGB")
    output = pipe_angle(
        image=[pil_image],
        prompt=prompt,
        height=None if height == 0 else height,
        width=None if width == 0 else width,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
    )
    return output.images[0]
def update_dimensions_on_upload(image):
    """Fit the uploaded image into a 1024px box, preserving aspect ratio.

    Returns (width, height) rounded down to multiples of 8 (a common
    diffusion-model size constraint); (1024, 1024) when *image* is None.
    """
    if image is None:
        return 1024, 1024
    w, h = image.size
    if w > h:
        target_w, target_h = 1024, int(1024 * h / w)
    else:
        target_w, target_h = int(1024 * w / h), 1024
    # Round each dimension down to the nearest multiple of 8.
    return target_w - target_w % 8, target_h - target_h % 8
# ==================== Panorama-generation helper clients ====================
# Remote Gradio Spaces used as services. NOTE(review): constructing a Client
# contacts the Space at import time, so module import can fail if a Space is
# down — confirm this is acceptable for the deployment.
outpaint_client = Client("fffiloni/diffusers-image-outpaint", verbose=False)
flux_client = Client("black-forest-labs/FLUX.2-dev", verbose=False)
inpaint_client = Client("diffusers/stable-diffusion-xl-inpainting", verbose=False)
def safe_outpaint(image_path, prompt, steps=8, overlap=5):
    """Outpaint *image_path* to 1280x720 via the remote outpainting Space.

    Best-effort: returns a PIL image on success, or None on any failure
    (the error is printed with a traceback rather than raised).
    """
    try:
        response = outpaint_client.predict(
            image=handle_file(image_path),
            width=1280, height=720,
            overlap_percentage=overlap,
            num_inference_steps=steps,
            resize_option="Full",
            custom_resize_percentage=50,
            prompt_input=prompt,
            alignment="Middle",
            overlap_left=True, overlap_right=True, overlap_top=True, overlap_bottom=True,
            api_name="/infer"
        )
        # The endpoint may return (preview, final) — prefer the second element.
        if isinstance(response, (tuple, list)) and len(response) >= 2:
            candidate = response[1]
        else:
            candidate = response
        # A string result is a file path produced by the remote Space.
        return Image.open(candidate) if isinstance(candidate, str) else candidate
    except Exception as exc:
        print(f"Outpaint 失败: {exc}")
        traceback.print_exc()
        return None
def safe_flux_call(image_path, prompt):
    """Generate a 1024x512 panorama from *image_path* via the remote FLUX Space.

    The input image is downscaled to fit 1024x1024 and re-encoded as JPEG
    before upload. Retries once (2 attempts total) on any failure, sleeping
    2 seconds between attempts. Returns a PIL image, or None if both
    attempts fail.
    """
    for attempt in range(2):
        pre_path = None
        try:
            # Preprocess: shrink and save as a high-quality JPEG to upload.
            img = Image.open(image_path).convert("RGB")
            img.thumbnail((1024, 1024), Image.LANCZOS)
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                img.save(tmp.name, quality=95)
                pre_path = tmp.name
            input_gallery = [{"image": handle_file(pre_path), "caption": None}]
            result, _ = flux_client.predict(
                prompt=prompt[:200],  # endpoint prompt-length limit
                input_images=input_gallery,
                seed=42,
                randomize_seed=False,
                width=1024, height=512,
                num_inference_steps=30,
                guidance_scale=4,
                prompt_upsampling=False,
                api_name="/infer"
            )
            # Normalize the various result shapes the client may return.
            if isinstance(result, dict):
                if 'path' in result:
                    img = Image.open(result['path'])
                elif 'url' in result:
                    # NOTE(review): Image.open cannot read remote URLs directly;
                    # this branch only works if 'url' is actually a local path.
                    img = Image.open(result['url'])
                else:
                    raise ValueError(f"unexpected FLUX result keys: {list(result)}")
            elif isinstance(result, str):
                img = Image.open(result)
            else:
                img = result
            return img
        except Exception as e:
            print(f"FLUX 调用失败 (尝试 {attempt+1}): {e}")
            time.sleep(2)
        finally:
            # Fix: the preprocessed temp file (delete=False) was never removed.
            if pre_path and os.path.exists(pre_path):
                os.unlink(pre_path)
    return None
def seam_fix(img, prompt="", seam_width=48, strength=0.7, debug_mode=True):
    """Repair the wrap-around seam of an equirectangular panorama.

    The left/right halves are swapped so the image's wrap border lands in the
    center, a feathered vertical mask is built over it, and a remote SDXL
    inpainting Space blends the seam away. The result is resized to 2048x1024.

    Args:
        img: Source panorama as a PIL RGB image.
        prompt: Currently UNUSED — the inpaint call uses a fixed prompt.
            Kept for interface compatibility with run_inpaint().
        seam_width: Width (px) of the feathered repair band; must be >= 2.
        strength: Inpainting strength (higher = stronger modification).
        debug_mode: When True, dump intermediates to /tmp/debug_inpaint.
    """
    w, h = img.size
    # Swap halves: the wrap-around seam becomes a visible center seam.
    left_half = img.crop((0, 0, w // 2, h))
    right_half = img.crop((w // 2, 0, w, h))
    swapped = Image.new("RGB", (w, h))
    swapped.paste(right_half, (0, 0))
    swapped.paste(left_half, (w // 2, 0))
    # Triangular (feathered) alpha profile centered on the seam.
    # (Redundant local `import numpy as np` / `import os` removed — both are
    # already imported at module level.)
    seam_x = w // 2 - seam_width // 2
    mask_arr = np.zeros((h, w), dtype=np.uint8)
    for i in range(seam_width):
        ratio = 1 - abs(i - seam_width // 2) / (seam_width // 2)
        mask_arr[:, seam_x + i] = int(255 * ratio)
    mask_full = Image.fromarray(mask_arr, mode='L')
    mask_rgba = Image.new("RGBA", (w, h), (0, 0, 0, 0))
    mask_rgba.putalpha(mask_full)
    if debug_mode:
        debug_dir = "/tmp/debug_inpaint"
        os.makedirs(debug_dir, exist_ok=True)
        swapped.save(os.path.join(debug_dir, "swapped.png"))
        mask_full.save(os.path.join(debug_dir, "mask_full.png"))
        print(f"Strength: {strength}, Seam width: {seam_width}, Seam_x: {seam_x}")
    tmp_paths = []
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_bg, \
             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_layer, \
             tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_composite:
            tmp_paths = [f_bg.name, f_layer.name, f_composite.name]
            swapped.save(f_bg.name)
            mask_rgba.save(f_layer.name)
            swapped.save(f_composite.name)
        result = inpaint_client.predict(
            input_image={
                "background": handle_file(tmp_paths[0]),
                "layers": [handle_file(tmp_paths[1])],
                "composite": handle_file(tmp_paths[2])
            },
            prompt="seamless transition, natural blending, no visible seam",
            negative_prompt="seam, line, border, cut, artifact, blur, low quality",
            guidance_scale=7.5,
            steps=30,
            strength=strength,
            scheduler="EulerDiscreteScheduler",
            api_name="/predict"
        )
    finally:
        # Fix: the three delete=False temp files were previously leaked.
        for p in tmp_paths:
            if os.path.exists(p):
                os.unlink(p)
    if isinstance(result, (tuple, list)) and len(result) >= 2:
        final = result[1]
    else:
        final = result
    if isinstance(final, str):
        final = Image.open(final)
    return final.resize((2048, 1024), Image.LANCZOS)
def get_last_image(state):
    """Persist the most recently generated image and return its file path.

    Args:
        state: PIL image object stored in gr.State.
    Returns:
        Path to a temporary PNG file (for the download component).
    Raises:
        gr.Error: when no image has been generated yet.
    """
    if state is None:
        raise gr.Error("没有已生成的图像,请先运行任意步骤")
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        state.save(tmp.name, format='PNG')
    finally:
        tmp.close()
    return tmp.name
# ==================== Standalone step functions ====================
def run_angle(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, steps):
    """Step 1: multi-angle edit. Returns (display image, state image)."""
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    edited = generate_image(image, azimuth, elevation, distance, seed,
                            randomize_seed, guidance_scale, steps, 1024, 1024)
    # Same image goes to both the output panel and last_image_state.
    return edited, edited
def run_outpaint(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps, outpaint_steps, outpaint_overlap):
    """Step 2: expand the image via the remote outpainting Space.

    Returns (display image, state image); raises gr.Error on missing input
    or outpaint failure. The input is round-tripped through a temp PNG file
    because the remote client needs a path.
    """
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp.name, format='PNG')
        img_path = tmp.name
    try:
        camera_prompt = f"extend scene naturally, {build_camera_prompt(azimuth, elevation, distance)}"
        expanded = safe_outpaint(img_path, camera_prompt, steps=outpaint_steps, overlap=outpaint_overlap)
        if expanded is None:
            raise gr.Error("Outpaint 失败")
        return expanded, expanded
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
def run_flux(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
             angle_steps, outpaint_steps, outpaint_overlap, keep_area, user_prompt):
    """Step 3: generate a 360° panorama via the remote FLUX Space.

    *keep_area* selects which region of the input to preserve ('left',
    'right', 'center'); everything else is replaced by neutral gray so the
    model regenerates it. Any other value keeps the full image. Returns
    (display image, state image).
    """
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    w, h = image.size
    # Determine the crop box to preserve; None means keep the whole image.
    region = None
    if keep_area == 'left':
        region = (0, 0, w // 2, h)
    elif keep_area == 'right':
        region = (w // 2, 0, w, h)
    elif keep_area == 'center':
        center_width = w // 3
        start_x = w // 2 - center_width // 2
        region = (start_x, 0, start_x + center_width, h)
    if region is None:
        prepared = image
    else:
        # Paste the kept region at its original position on a gray canvas.
        prepared = Image.new('RGB', (w, h), (128, 128, 128))
        prepared.paste(image.crop(region), (region[0], region[1]))
    # Save as PNG to preserve input quality for the upload.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        prepared.save(tmp.name, format='PNG')
        img_path = tmp.name
    try:
        base_prompt = f"360 equirectangular panorama, 2:1 aspect ratio, high quality, {build_camera_prompt(azimuth, elevation, distance)}"
        if user_prompt and user_prompt.strip():
            flux_prompt = f"{base_prompt}. {user_prompt.strip()}"
        else:
            flux_prompt = base_prompt
        # Region-specific instruction appended to discourage duplicated content.
        region_hints = {
            'left': " Keep the left part exactly as is, extend the scene to the right naturally without repeating content.",
            'right': " Keep the right part exactly as is, extend the scene to the left naturally without repeating content.",
            'center': " Keep the central part unchanged, extend both sides without repeating content.",
        }
        flux_prompt += region_hints.get(keep_area, "")
        panorama = safe_flux_call(img_path, flux_prompt)
        if panorama is None:
            raise gr.Error("FLUX 生成失败")
        return panorama, panorama
    finally:
        if os.path.exists(img_path):
            os.unlink(img_path)
def run_inpaint(image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps, outpaint_steps, outpaint_overlap, inpaint_width, inpaint_strength):
    """Step 4: repair the panorama's wrap-around seam via seam_fix().

    Returns (display image, state image). The camera prompt is passed along,
    though seam_fix currently uses a fixed inpainting prompt internally.
    """
    if image is None:
        raise gr.Error("当前没有输入图像,请上传或复制一个图像到输入框")
    camera_prompt = build_camera_prompt(azimuth, elevation, distance)
    repaired = seam_fix(image, camera_prompt, inpaint_width, inpaint_strength)
    return repaired, repaired
# ==================== Gradio UI ====================
with gr.Blocks(title="Flexible Panorama Pipeline", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 灵活全景生成工具")
    # Holds the most recently generated PIL image, fed by every step's second
    # output and consumed by the download button at the bottom.
    last_image_state = gr.State(value=None)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="当前输入图像", height=200, interactive=True)
            # Camera parameters (shared by all pipeline steps)
            azimuth = gr.Slider(0, 315, step=45, value=0, label="方位角 (°)")
            elevation = gr.Slider(-30, 60, step=30, value=0, label="仰角 (°)")
            distance = gr.Slider(0.6, 1.8, step=0.4, value=1.0, label="距离系数")
            # ========== Multi-angle edit advanced settings ==========
            with gr.Accordion("⚙️ 多角度编辑高级设置", open=False):
                seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="种子")
                randomize_seed = gr.Checkbox(value=True, label="随机种子")
                guidance_scale = gr.Slider(1.0, 10.0, step=0.1, value=1.0, label="引导系数")
                angle_steps = gr.Slider(1, 20, step=1, value=4, label="推理步数")
            # ========== Outpaint advanced settings ==========
            with gr.Accordion("🎨 Outpaint 高级设置", open=False):
                outpaint_steps = gr.Slider(4, 20, step=1, value=8, label="步数")
                outpaint_overlap = gr.Slider(5, 30, step=1, value=5, label="重叠百分比 (%)")
            # ========== FLUX advanced settings ==========
            with gr.Accordion("✨ FLUX 高级设置", open=False):
                keep_area = gr.Radio(
                    choices=['full', 'left', 'center', 'right'],
                    value='full',
                    label="保留区域(防止重复)",
                    info="选择要保留的图像区域,其他区域由 AI 生成"
                )
                flux_prompt_text = gr.Textbox(
                    label="自定义提示词(可选)",
                    placeholder="例如:fantasy landscape, cyberpunk style...",
                    info="附加到默认提示词后面"
                )
            # ========== Seam-repair advanced settings ==========
            with gr.Accordion("🔧 接缝修补高级设置", open=False):
                inpaint_width = gr.Slider(
                    16, 256, step=8, value=48,
                    label="修补宽度 (px)",
                    info="中央修补区域的宽度,值越大覆盖越宽"
                )
                inpaint_strength = gr.Slider(
                    0.2, 0.9, step=0.05, value=0.7,
                    label="修补强度",
                    info="值越高,AI 对中央区域的修改程度越大,适用于明显接缝;值越低越贴近原图。"
                )
            # Pipeline step buttons (each step can be run independently)
            with gr.Row():
                btn_angle = gr.Button("1. 多角度编辑", variant="primary")
                btn_outpaint = gr.Button("2. Outpaint 扩展", variant="secondary")
                btn_flux = gr.Button("3. FLUX 全景生成", variant="secondary")
                btn_inpaint = gr.Button("4. 接缝修补", variant="secondary")
        with gr.Column():
            angle_output = gr.Image(type="pil", label="1. 多角度结果", height=150)
            with gr.Row():
                angle_copy = gr.Button("📋 设为输入", size="sm")
            outpaint_output = gr.Image(type="pil", label="2. Outpaint 结果", height=150)
            with gr.Row():
                outpaint_copy = gr.Button("📋 设为输入", size="sm")
            flux_output = gr.Image(type="pil", label="3. FLUX 结果", height=150)
            with gr.Row():
                flux_copy = gr.Button("📋 设为输入", size="sm")
            final_output = gr.Image(type="pil", label="4. 最终全景图", height=150)
            with gr.Row():
                final_copy = gr.Button("📋 设为输入", size="sm")
    # Multi-angle edit (inputs: image, azimuth, elevation, distance, seed,
    # randomize_seed, guidance_scale, angle_steps). Each step writes its
    # result to both its own output panel and last_image_state.
    btn_angle.click(
        fn=run_angle,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale, angle_steps],
        outputs=[angle_output, last_image_state],
        api_name="run_angle"
    )
    # Outpaint step (adds outpaint_steps / outpaint_overlap).
    btn_outpaint.click(
        fn=run_outpaint,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap],
        outputs=[outpaint_output, last_image_state],
        api_name="run_outpaint"
    )
    # FLUX panorama step (adds keep_area / flux_prompt_text).
    btn_flux.click(
        fn=run_flux,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap, keep_area, flux_prompt_text],
        outputs=[flux_output, last_image_state],
        api_name="run_flux"
    )
    # Seam-repair step (adds inpaint_width / inpaint_strength).
    btn_inpaint.click(
        fn=run_inpaint,
        inputs=[input_image, azimuth, elevation, distance, seed, randomize_seed, guidance_scale,
                angle_steps, outpaint_steps, outpaint_overlap, inpaint_width, inpaint_strength],
        outputs=[final_output, last_image_state],
        api_name="run_inpaint"
    )
    # "Set as input" buttons: feed a step's output back into the input slot.
    angle_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[angle_output],
        outputs=[input_image]
    )
    outpaint_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[outpaint_output],
        outputs=[input_image]
    )
    flux_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[flux_output],
        outputs=[input_image]
    )
    final_copy.click(
        fn=lambda img: gr.update(value=img),
        inputs=[final_output],
        outputs=[input_image]
    )
    # Download: dump last_image_state to a temp PNG and expose it as a file.
    get_btn = gr.Button("📸 获取最终图像", variant="secondary")
    download_file = gr.File(label="点击下载最终图像")
    get_btn.click(
        fn=get_last_image,
        inputs=[last_image_state],
        outputs=download_file
    )
if __name__ == "__main__":
    # Queueing serializes GPU-bound requests from concurrent users.
    demo.queue()
    # 0.0.0.0 exposes the app inside the Space container on the standard port.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)