Spaces:
Paused
Paused
import base64
import io
import json
import os
import sys
from typing import Union, Any, Optional
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
import spaces
# NOTE(review): `json`, `Union`, and `spaces` are imported but never used in this
# file as shown; `spaces` usually pairs with an @spaces.GPU decorator — confirm
# whether ZeroGPU allocation is intended.

# Add the project root directory to the Python path so `cascade.*` resolves.
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.append(project_root)

# Credentials and model identifiers are injected via Space secrets / env vars.
hf_token = os.environ.get("CASCADE_PRIVATE_MODEL_HF_TOKEN")
secret_model = os.environ.get("MODEL_PATH")
# Pipeline loader filename inside the private repo (expected to be a .py file).
BASE_MODEL = os.environ.get("BASE_MODEL_ID")

from cascade.condition import Condition
from cascade.generate import generate
from cascade.lora_controller import set_lora_scale
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# NOTE(review): `load_file` is imported but unused in this file as shown.

# Global pipeline singleton, populated lazily by init_pipeline_if_needed().
_global_pipe = None

# Download the LoRA weights file from the private Space using the auth token.
# NOTE: this runs at import time — the module fails to import if the secret
# or the file is missing.
model_path = hf_hub_download(
    repo_id="Cascade-Inc/private_model",
    filename=secret_model,
    token=hf_token,
    repo_type="space"
)

# Point Gradio's temp storage at a writable per-user directory.
temp_dir = os.path.join(os.path.expanduser("~"), "gradio_temp")
os.makedirs(temp_dir, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = temp_dir

ADAPTER_NAME = "subject"          # LoRA adapter name used throughout
MODEL_PATH = model_path           # local path of the downloaded LoRA weights
ZEN_BG_ENDPOINT = "https://zen-inpaint-1066271267292.europe-west1.run.app/"
def get_gpu_memory_gb() -> float:
    """Return the total memory of CUDA device 0, in gibibytes."""
    props = torch.cuda.get_device_properties(0)
    return props.total_memory / (1024 ** 3)
def init_pipeline_if_needed():
    """
    Lazily build and cache the global diffusion pipeline (singleton).

    First call: downloads a pipeline-loader script named by BASE_MODEL_ID from
    the private Space, imports it dynamically, moves the resulting pipeline to
    CUDA, and attaches the Cascade LoRA weights. Later calls return the cache.

    Returns:
        The initialized pipeline object.

    Raises:
        ValueError: if BASE_MODEL_ID is unset, or the loader script cannot be
            fetched/executed.
    """
    global _global_pipe
    if _global_pipe is not None:
        return _global_pipe
    print("🚀 Initializing pipeline...")
    # If BASE_MODEL_ID is set, load a pre-configured pipeline from the private repo.
    if BASE_MODEL:
        print(f"Loading pipeline from: {BASE_MODEL}")
        try:
            # Download the pipeline-loader script from the private Space.
            pipeline_loader_path = hf_hub_download(
                repo_id="Cascade-Inc/private_model",
                filename=BASE_MODEL,  # expected to be a .py file, e.g. "pipeline_loader.py"
                token=hf_token,
                repo_type="space"
            )
            # Dynamically import the downloaded module.
            # NOTE(review): this executes remote code; acceptable only because the
            # source repo is private and trusted — confirm that assumption holds.
            import importlib.util
            spec = importlib.util.spec_from_file_location("pipeline_loader", pipeline_loader_path)
            pipeline_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(pipeline_module)
            # Ask the private module to construct the pipeline.
            _pipe = pipeline_module.get_pipeline(hf_token)
        except Exception as e:
            print(f"❌ Error loading pipeline from {BASE_MODEL}: {e}")
            raise ValueError(
                f"Failed to load pipeline loader from BASE_MODEL_ID='{BASE_MODEL}'. "
                f"Make sure:\n"
                f"1. The file exists in Cascade-Inc/private_model space\n"
                f"2. BASE_MODEL_ID should be a .py file name (e.g., 'pipeline_loader.py')\n"
                f"3. Not the LoRA path (that's MODEL_PATH)"
            )
    else:
        raise ValueError(
            "BASE_MODEL_ID environment variable is not set.\n"
            "Please set it to the pipeline loader filename (e.g., 'pipeline_loader.py')"
        )
    print("📦 Loading model to CUDA...")
    _pipe = _pipe.to("cuda")
    print("🎨 Loading Cascade weights...")
    # Attach and activate the LoRA adapter downloaded at module import time.
    _pipe.load_lora_weights(MODEL_PATH, adapter_name=ADAPTER_NAME)
    _pipe.set_adapters([ADAPTER_NAME])
    _global_pipe = _pipe
    print("✅ Pipeline initialized successfully!")
    return _global_pipe
| def _to_pil_rgba(img: Any) -> Image.Image: | |
| """Convert various inputs to PIL RGBA image""" | |
| pil: Optional[Image.Image] = None | |
| if isinstance(img, Image.Image): | |
| pil = img | |
| elif isinstance(img, np.ndarray): | |
| pil = Image.fromarray(img) | |
| elif isinstance(img, str) and os.path.exists(img): | |
| pil = Image.open(img) | |
| else: | |
| raise ValueError("Unsupported image type") | |
| if pil.mode != "RGBA": | |
| pil = pil.convert("RGBA") | |
| return pil | |
| def _center_subject_on_canvas(subject_rgba: Image.Image, canvas_width: int, canvas_height: int) -> Image.Image: | |
| """ | |
| Center the subject image on a transparent canvas that matches the requested size. | |
| """ | |
| if subject_rgba is None: | |
| return Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 0)) | |
| canvas = Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 0)) | |
| paste_x = (canvas_width - subject_rgba.width) // 2 | |
| paste_y = (canvas_height - subject_rgba.height) // 2 | |
| # If the subject is larger than canvas, crop to fit | |
| if subject_rgba.width > canvas_width or subject_rgba.height > canvas_height: | |
| subject_rgba = subject_rgba.crop( | |
| ( | |
| max(0, -paste_x), | |
| max(0, -paste_y), | |
| max(0, -paste_x) + min(canvas_width, subject_rgba.width), | |
| max(0, -paste_y) + min(canvas_height, subject_rgba.height), | |
| ) | |
| ) | |
| paste_x = max(0, paste_x) | |
| paste_y = max(0, paste_y) | |
| canvas.alpha_composite(subject_rgba, dest=(paste_x, paste_y)) | |
| return canvas | |
def _place_subject_on_canvas(
    subject_rgba: Image.Image,
    canvas_size: int,
    style: str,
    base_coverage: float = 0.7,
) -> Image.Image:
    """
    Place subject on a square transparent canvas with position and angle
    adjustments based on style.

    This used to be a line-for-line copy of `_place_subject_on_canvas_rect`
    restricted to a square canvas; it now delegates to that function so the
    style presets ("center", "tilt_left", "right", with fallback to "center")
    are defined in exactly one place. Behavior is identical: with
    width == height == canvas_size, the rect variant's scaling reference
    (min of the two sides) and its x/y offsets reduce to the square formulas.
    """
    return _place_subject_on_canvas_rect(
        subject_rgba, canvas_size, canvas_size, style, base_coverage
    )
| def _place_subject_on_canvas_rect( | |
| subject_rgba: Image.Image, | |
| canvas_width: int, | |
| canvas_height: int, | |
| style: str, | |
| base_coverage: float = 0.7, | |
| ) -> Image.Image: | |
| """ | |
| Place subject on rectangular transparent canvas with position and angle adjustments based on style | |
| """ | |
| canvas = Image.new("RGBA", (canvas_width, canvas_height), (0, 0, 0, 0)) | |
| # Define three styles | |
| styles = { | |
| "center": {"scale": 1.0, "rotation": 0, "pos": (0.0, 0.0)}, | |
| "tilt_left": {"scale": 0.95, "rotation": -15, "pos": (-0.1, 0.0)}, | |
| "right": {"scale": 0.95, "rotation": 0, "pos": (0.25, 0.0)}, | |
| } | |
| if style not in styles: | |
| style = "center" | |
| style_config = styles[style] | |
| # Calculate scaling based on smaller dimension | |
| subject_w, subject_h = subject_rgba.size | |
| max_dim = max(subject_w, subject_h) | |
| canvas_min_dim = min(canvas_width, canvas_height) | |
| desired_max_dim = max(1, int(canvas_min_dim * base_coverage * style_config["scale"])) | |
| scale = desired_max_dim / max(1, max_dim) | |
| new_w = max(1, int(subject_w * scale)) | |
| new_h = max(1, int(subject_h * scale)) | |
| resized = subject_rgba.resize((new_w, new_h), Image.LANCZOS) | |
| # Rotation | |
| rotated = resized.rotate(style_config["rotation"], expand=True, resample=Image.BICUBIC) | |
| rw, rh = rotated.size | |
| # Positioning | |
| cx = canvas_width // 2 | |
| cy = canvas_height // 2 | |
| dx = int(style_config["pos"][0] * canvas_width) | |
| dy = int(style_config["pos"][1] * canvas_height) | |
| paste_x = int(cx + dx - rw // 2) | |
| paste_y = int(cy + dy - rh // 2) | |
| canvas.alpha_composite(rotated, dest=(paste_x, paste_y)) | |
| return canvas | |
def apply_style(image: Image.Image, style: str, width: int = 1024, height: int = 1024) -> Image.Image:
    """
    Render *image* onto a ``width`` x ``height`` transparent canvas using the
    named style preset; a missing image becomes an empty transparent subject.
    """
    if image is None:
        # Default placeholder: fully transparent 512x512 subject.
        image = Image.new("RGBA", (512, 512), (255, 255, 255, 0))
    elif image.mode != "RGBA":
        image = image.convert("RGBA")
    return _place_subject_on_canvas_rect(image, width, height, style)
def generate_background_local(styled_image: Image.Image, prompt: str, steps: int = 10, width: int = 1024, height: int = 1024) -> Image.Image:
    """
    Generate a background for the subject using the local diffusion pipeline.

    The subject is centred on a width x height canvas, converted to RGB, and
    fed to the pipeline as a "subject" LoRA condition. Returns a plain white
    RGB image when no subject is supplied.
    """
    width = int(width)
    height = int(height)
    # Triggers the (slow) first-time pipeline load if it hasn't happened yet.
    pipe = init_pipeline_if_needed()
    if styled_image is None:
        return Image.new("RGB", (width, height), (255, 255, 255))
    # Ensure the subject image matches requested canvas size.
    styled_image = _center_subject_on_canvas(styled_image, width, height)
    # The condition image must be RGB (alpha dropped after compositing).
    img_rgb = styled_image.convert("RGB")
    condition = Condition(ADAPTER_NAME, img_rgb, position_delta=(0, 0))
    # Enable padding token orthogonalization for enhanced text-image alignment.
    model_config = {
        'padding_orthogonalization_enabled': True,
        'preserve_norm': True,
        'orthogonalize_all_tokens': False,
    }
    # Temporarily boost the LoRA scale for generation only.
    with set_lora_scale([ADAPTER_NAME], scale=3.0):
        result_img = generate(
            pipe,
            model_config=model_config,
            prompt=prompt.strip() if prompt else "",
            conditions=[condition],
            num_inference_steps=steps,
            height=height,
            width=width,
            default_lora=True,
        ).images[0]
    return result_img
def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as a base64 PNG string (forced to RGBA so alpha survives)."""
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    with io.BytesIO() as buffer:
        rgba.save(buffer, format="PNG")
        payload = buffer.getvalue()
    return base64.b64encode(payload).decode("utf-8")
def generate_background_api(
    styled_image: Image.Image,
    prompt: str,
    steps: int = 4,
    api_key: str = "",
    email: str = "",
    zen_mode: str = "bg_generation",
) -> Image.Image:
    """
    Generate a background via the remote Zen inpaint API.

    Error handling is deliberately best-effort: missing credentials,
    non-200 responses, unparsable bodies, and any exception all return a
    flat red-tinted image (instead of raising) so the UI always renders
    something.
    """
    if styled_image is None:
        return Image.new("RGB", (1024, 1024), (255, 255, 255))
    if not api_key or not email:
        return Image.new("RGB", styled_image.size, (255, 200, 200))  # Red tint to indicate error
    try:
        width, height = styled_image.size
        base64_image = image_to_base64(styled_image)
        # Ensure padding so the API always receives valid Base64 chunks.
        # NOTE(review): b64encode already emits padded output, so this is a
        # defensive no-op kept for safety.
        subject_base64 = base64_image + "=" * (-len(base64_image) % 4)
        # Map legacy UI modes to documented gen_mode values
        gen_mode = {
            "subject": "bg_generation",
            "canny": "bg_generation",
            "bg_generation": "bg_generation",
        }.get(zen_mode, "bg_generation")
        # Choose an upscale tier from the largest requested side.
        max_dim = max(width, height)
        if max_dim <= 1024:
            upscale = "1k"
        elif max_dim <= 1536:
            upscale = "1.5k"
        else:
            upscale = "2k"
        payload = {
            "gen_mode": gen_mode,
            "prompt": prompt.strip() if prompt else "professional product photography background",
            "subject": subject_base64,
            "subject_format": "base64",
            "background": "",
            "negative_prompt": "",
            "steps": int(steps),
            "seed": 42,
            "randomize_seed": True,
            "bg_upscale_choice": upscale,
            "max_bg_side_px": int(max_dim),
            "output_image_format": "base64",
            "use_bg_size_for_output": True,
        }
        headers = {
            "x-api-key": api_key,
            "x-email": email,
            "Content-Type": "application/json",
        }
        response = requests.post(
            ZEN_BG_ENDPOINT,
            headers=headers,
            json=payload,
            timeout=60,
        )
        if response.status_code == 200:
            try:
                result_data = response.json()
            except Exception:
                print(f"[API] Unable to parse response JSON: {response.text[:200]}")
                result_data = {}
            # The "image" field may be either a URL or inline base64 data.
            image_field = result_data.get("image")
            if image_field:
                if image_field.startswith("http"):
                    try:
                        img_resp = requests.get(image_field, timeout=60)
                        img_resp.raise_for_status()
                        return Image.open(io.BytesIO(img_resp.content))
                    except Exception as download_err:
                        print(f"[API] Failed to download image URL: {download_err}")
                else:
                    try:
                        img_data = base64.b64decode(image_field)
                        return Image.open(io.BytesIO(img_data))
                    except Exception as decode_err:
                        print(f"[API] Failed to decode base64 response: {decode_err}")
            print(f"[API] 200 response without image: {result_data}")
        else:
            print(f"[API] Non-200 response ({response.status_code}): {response.text[:500]}")
        # Any fall-through above (bad body, failed download/decode) lands here.
        return Image.new("RGB", styled_image.size, (255, 200, 200))
    except Exception as e:
        print(f"API Error: {e}")
        return Image.new("RGB", styled_image.size, (255, 200, 200))
def generate_background(
    styled_image: Image.Image,
    prompt: str,
    steps: int = 10,
    use_api: bool = False,
    api_key: str = "",
    email: str = "",
    width: int = 1024,
    height: int = 1024,
    mode: str = "subject",
) -> Image.Image:
    """Dispatch background generation to the remote API or the local pipeline."""
    if not use_api:
        # Local path honours the explicit width/height request.
        return generate_background_local(styled_image, prompt, steps, width, height)
    # API path derives its output size from the styled image itself.
    return generate_background_api(
        styled_image, prompt, steps, api_key, email, zen_mode=mode
    )
| # Gradio Interface | |
def create_simple_app() -> gr.Blocks:
    """
    Build the Gradio UI: image upload + size + prompt on the left, generated
    result on the right, with an accordion of clickable example prompts.
    """
    # Example prompts for reference
    example_prompts = [
        {
            "title": "Handcrafted Leather Wallet",
            "prompt": "A premium lifestyle advertisement for a hand-stitched dark brown leather wallet. The wallet is half-open on a timeworn walnut desk, revealing the suede interior and a few vintage travel tickets. Surround it with a rolled map, brass fountain pen, and antique compass to emphasize heritage craftsmanship. Soft amber light from a desk lamp on the right grazes the grainy leather and creates gentle shadow falloff, while a blurred wall of old books fills the background. Overall tone is classic, rustic, and aspirational."
        },
        {
            "title": "Sparkling Water with Fresh Lemons",
            "prompt": "A product hero shot for a premium sparkling water infused with fresh lemons. Place a dewy glass bottle at the center of a white marble countertop, with a tall tumbler filled with effervescent water, thin lemon wheels, and crystal-clear ice cubes beside it. Scatter a few lemon zest curls and condensation droplets for sensory detail. Use a soft-focus pale blue and white gradient background to communicate freshness, and bathe the scene in bright, cool, top-down lighting that creates sharp reflections. Keep the styling ultra-clean, crisp, and minimalist."
        },
        {
            "title": "High-tech Smartwatch",
            "prompt": "A cinematic tech advertisement for a titanium smartwatch with an always-on illuminated screen displaying futuristic UI graphics. Position the watch on a jagged slab of matte black slate to contrast its polished chamfered edges. Behind it, place a blurred nighttime cityscape with teal and magenta neon bokeh to suggest urban energy. Hit the product with a sharp, directional spotlight from the top left to carve out highlights along the bezel and bracelet, while subtle rim lighting separates it from the background. Mood is sleek, futuristic, and performance-driven."
        },
        {
            "title": "Japanese Ramen Bowl",
            "prompt": "A mouthwatering food advertisement for a ceramic bowl of tonkotsu ramen. Present silky broth with two slices of torched chashu pork, a jammy soft-boiled egg, nori sheets, scallions, and sesame seeds arranged artfully. Place the bowl on a rustic wooden table with lacquered chopsticks resting on a ceramic holder, plus a tiny dish of pickled ginger for color. Capture wisps of steam drifting upward in soft overhead light, while the background falls into a blurred, amber-toned izakaya interior with paper lanterns. Atmosphere is warm, authentic, and comforting."
        },
        {
            "title": "Japanese Peach Iced Tea",
            "prompt": "A commercial advertisement for a Japanese peach-flavored iced tea. The composition features the product bottle placed next to a tall, elegant glass filled with the tea and sparkling ice cubes. The background is a soft, warm gradient of peach and beige, creating a gentle and sophisticated atmosphere. The overall style is clean, minimalist, and refined, with bright, soft lighting that highlights the crisp, refreshing quality of the beverage."
        }
    ]
    with gr.Blocks(title="Ads Background Generation") as app:
        gr.Markdown("# Ads Background Generation App")
        gr.Markdown("Upload an image with transparent background → Enter prompt → Generate")
        # Example Prompts Section
        with gr.Accordion("📝 Example Prompts (Click to expand)", open=False):
            gr.Markdown("### Background Prompt Examples")
            gr.Markdown("Click any example below to copy it to the background description field:")
            # Create example buttons — appended in index order (row 1: 0-2,
            # row 2: 3-4) so example_buttons[i] matches example_prompts[i].
            example_buttons = []
            with gr.Row():
                for i, example in enumerate(example_prompts):
                    if i < 3:  # First row
                        example_btn = gr.Button(
                            f"📋 {example['title']}",
                            variant="secondary",
                            size="sm"
                        )
                        example_buttons.append(example_btn)
            with gr.Row():
                for i, example in enumerate(example_prompts):
                    if i >= 3:  # Second row
                        example_btn = gr.Button(
                            f"📋 {example['title']}",
                            variant="secondary",
                            size="sm"
                        )
                        example_buttons.append(example_btn)
            # Display area for selected prompt preview (shown once a button is clicked)
            selected_prompt_display = gr.Textbox(
                label="Selected Prompt Preview",
                lines=4,
                max_lines=8,
                interactive=False,
                visible=False
            )
        with gr.Row():
            # Left column — inputs
            with gr.Column(scale=1):
                # Image upload (top left)
                input_image = gr.Image(
                    label="Upload Image (Transparent Background)",
                    type="pil",
                    format="png",
                    image_mode="RGBA",
                    height=350
                )
                # Output image dimensions
                with gr.Row():
                    img_width = gr.Number(
                        value=1024,
                        label="Width",
                        precision=0,
                        minimum=256,
                        maximum=2048
                    )
                    img_height = gr.Number(
                        value=1024,
                        label="Height",
                        precision=0,
                        minimum=256,
                        maximum=2048
                    )
                # Background prompt (bottom left)
                bg_prompt = gr.Textbox(
                    label="Background Description",
                    placeholder="e.g.: Forest scene, soft lighting",
                    lines=3
                )
                use_api = gr.Checkbox(
                    label="Use API",
                    value=False
                )
                # Credentials group, only visible when "Use API" is checked.
                with gr.Group(visible=False) as api_group:
                    api_key = gr.Textbox(
                        label="API Key",
                        type="password",
                        placeholder="Enter your API key"
                    )
                    email = gr.Textbox(
                        label="Email",
                        placeholder="Enter your registered email"
                    )
                    # Single fixed choice; kept as a Radio for future modes.
                    mode = gr.Radio(
                        choices=["bg_generation"],
                        value="bg_generation",
                        label="API gen_mode",
                        interactive=False
                    )
                # Generation steps
                steps_slider = gr.Slider(
                    minimum=5,
                    maximum=20,
                    value=10,
                    step=1,
                    label="Generation Steps"
                )
                # Generate background button
                generate_bg_btn = gr.Button("Generate Background", variant="primary", size="lg")
            # Right column - Result display
            with gr.Column(scale=1):
                final_result = gr.Image(
                    label="Generated Result",
                    type="pil",
                    format="png",
                    height=700
                )
        # Show/hide the API credentials group with the checkbox.
        def toggle_api_group(use_api_flag):
            return gr.update(visible=use_api_flag)
        use_api.change(
            fn=toggle_api_group,
            inputs=[use_api],
            outputs=[api_group]
        )
        # Generate background directly from input image
        def generate_from_input(image, prompt, steps, width, height, use_api_flag, api_key_value, email_value, mode_value):
            if image is None:
                return None
            # Ensure image is RGBA
            if image.mode != "RGBA":
                image = image.convert("RGBA")
            width = int(width)
            height = int(height)
            # Center original uploaded image on a transparent canvas without resizing
            image = _center_subject_on_canvas(image, width, height)
            # Generate background using selected method
            return generate_background(
                image,
                prompt,
                steps,
                use_api_flag,
                api_key_value,
                email_value,
                width,
                height,
                mode_value,
            )
        # Event binding
        generate_bg_btn.click(
            fn=generate_from_input,
            inputs=[input_image, bg_prompt, steps_slider, img_width, img_height, use_api, api_key, email, mode],
            outputs=[final_result]
        )
        # Example prompt button handlers; the closure binds prompt_text per
        # button (avoids the late-binding-loop-variable pitfall).
        def create_example_handler(prompt_text):
            def handler():
                return prompt_text, gr.update(value=prompt_text, visible=True)
            return handler
        # Connect example buttons to background prompt field and preview
        for i, example_btn in enumerate(example_buttons):
            if i < len(example_prompts):
                example_btn.click(
                    fn=create_example_handler(example_prompts[i]['prompt']),
                    outputs=[bg_prompt, selected_prompt_display]
                )
    return app
# Pre-load the models before the app starts so the first request is fast.
# NOTE: runs at import time and will fail hard if secrets/CUDA are unavailable.
print("=" * 60)
print("🔧 Pre-loading models on startup...")
print("=" * 60)
init_pipeline_if_needed()
print("=" * 60)
print("✨ All models loaded and ready!")
print("=" * 60)
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at the standard Spaces port.
    demo = create_simple_app()
    launch_options = {
        "debug": True,
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)