import os
import io
import json
import base64
import time
import logging

import numpy as np
import gradio as gr
from PIL import Image
from scipy import ndimage
from gradio_client import Client

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ───────── Backend connection with health monitoring ─────────
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# Single source of truth for the backend Space id (previously duplicated
# inside every request handler).
BACKEND_SPACE = "SnapwearAI/Snapwear_BGAI"

# Backend connection state shared by all handlers.
backend_status = {
    "client": None,        # cached gradio_client.Client, reused by every request
    "connected": False,
    "last_check": None,    # epoch seconds of the most recent connection attempt
    "error_message": "",
}


def check_backend_connection():
    """Attempt to (re)connect to the backend Space and update ``backend_status``.

    Returns:
        tuple[bool, str]: success flag and a human-readable status message.
    """
    global backend_status
    try:
        test_client = Client(BACKEND_SPACE, hf_token=HF_TOKEN)
        backend_status["client"] = test_client
        backend_status["connected"] = True
        backend_status["error_message"] = ""
        backend_status["last_check"] = time.time()
        logger.info("✅ Backend connection established")
        return True, "🟢 Backend is ready for Create Background"
    except Exception as e:
        backend_status["client"] = None
        backend_status["connected"] = False
        backend_status["last_check"] = time.time()
        error_str = str(e).lower()
        # Cold-start timeouts are expected while the Space spins up.
        if "timeout" in error_str or "read operation timed out" in error_str:
            backend_status["error_message"] = "Backend is starting up (5-6 minutes on first load)"
            return False, "🟡 Backend is starting up. Please wait 5-6 minutes and try again."
        else:
            backend_status["error_message"] = f"Connection error: {str(e)}"
            return False, f"🔴 Backend error: {str(e)}"


# Initial connection attempt (best effort; the UI offers a manual re-check button).
try:
    success, status_msg = check_backend_connection()
    if success:
        logger.info("Backend client established")
    else:
        logger.warning(f"Initial backend connection failed: {status_msg}")
except Exception as e:
    logger.error(f"Failed to connect to backend: {e}")
    backend_status["connected"] = False
    backend_status["error_message"] = str(e)


def update_backend_status():
    """Re-check the backend and return an HTML status banner for the UI."""
    success, status_msg = check_backend_connection()
    if success:
        css_class = "status-ready"
    elif "starting up" in status_msg:
        css_class = "status-starting"
    else:
        css_class = "status-error"
    # NOTE(review): banner markup reconstructed — the original HTML was lost
    # during extraction; classes match the .status-* rules in the stylesheet.
    status_html = f'<div class="status-banner {css_class}">{status_msg}</div>'
    return status_html


# ───────── Styling ─────────
css = """
body, .gradio-container {
    font-family: 'Inter', 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
}
#col-left, #col-mid, #col-right { margin: 0 auto; max-width: 430px; }
#col-showcase { margin: 0 auto; max-width: 1100px; }
#button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: #ffffff;
    font-weight: 600;
    font-size: 18px;
    border: none;
    border-radius: 12px;
    padding: 12px 24px;
    transition: all 0.3s ease;
}
#button:hover { transform: translateY(-2px); box-shadow: 0 8px 25px rgba(102,126,234,0.3); }
#button:disabled { background: #ccc !important; cursor: not-allowed; transform: none; box-shadow: none; }
.hero-section {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 40px 20px;
    border-radius: 20px;
    margin: 20px 0;
    text-align: center;
}
.feature-box {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    padding: 20px;
    border-radius: 12px;
    margin: 10px 0;
    border-left: 4px solid #667eea;
}
.showcase-section {
    background: #ffffff;
    border: 1px solid #e2e8f0;
    padding: 30px;
    border-radius: 16px;
    box-shadow: 0 4px 20px rgba(0,0,0,0.1);
    margin: 20px 0;
}
.step-header {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 15px;
    border-radius: 12px;
    text-align: center;
    font-weight: 600;
    margin: 10px 0;
}
.social-links { text-align: center; margin: 20px 0; }
.social-links a {
    margin: 0 10px;
    padding: 8px 16px;
    background: #667eea;
    color: white;
    text-decoration: none;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.social-links a:hover { background: #764ba2; transform: translateY(-2px); }
.error-message { color: #dc3545; font-weight: 500; }
.success-message { color: #28a745; font-weight: 500; }
.status-banner { padding: 15px; border-radius: 12px; margin: 10px 0; text-align: center; font-weight: 600; }
.status-ready { background: #d4edda; border: 1px solid #c3e6cb; color: #155724; }
.status-starting { background: #fff3cd; border: 1px solid #ffeaa7; color: #856404; }
.status-error { background: #f8d7da; border: 1px solid #f5c6cb; color: #721c24; }
.queue-info {
    background: #e8f4fd;
    border: 1px solid #bee5eb;
    padding: 12px;
    border-radius: 8px;
    margin: 10px 0;
    text-align: center;
    font-size: 14px;
    color: #0c5460;
}
"""


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string ("" for None input)."""
    if image is None:
        return ""
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="PNG", optimize=True)
    buffer.seek(0)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def base64_to_image(b64_str: str) -> Image.Image:
    """Decode a base64 string (with or without a data-URL prefix) into an RGBA PIL Image.

    Returns None for empty input or on any decode failure (logged).
    """
    if not b64_str:
        return None
    try:
        if b64_str.startswith("data:"):
            b64_str = b64_str.split(",", 1)[1]
        data = base64.b64decode(b64_str)
        return Image.open(io.BytesIO(data)).convert("RGBA")
    except Exception as e:
        logger.error(f"Failed to decode base64 image: {e}")
        return None


def prepare_editor_data(editor_data: dict) -> dict:
    """Convert Gradio ImageEditor output (a dict with 'background' and 'layers')
    into a JSON-serializable dict where each image is base64-encoded.
    """
    if not editor_data:
        return {}

    result = {}

    # Convert the background PIL image to a base64 string ("" if missing).
    bg = editor_data.get("background", None)
    result["background"] = image_to_base64(bg) if isinstance(bg, Image.Image) else ""

    # Convert each layer (mask) to a base64 string.
    encoded_layers = []
    for layer in editor_data.get("layers", []):
        if isinstance(layer, Image.Image):
            # Binarize the mask: any non-black pixel becomes white.
            arr = np.array(layer.convert("L"))
            arr[arr > 0] = 255
            bin_mask = Image.fromarray(arr.astype(np.uint8))
            encoded_layers.append(image_to_base64(bin_mask))
        else:
            encoded_layers.append("")
    result["layers"] = encoded_layers
    return result


def dots_to_points(editor_value):
    """Convert a white-dot brush layer into (RGB background, [(x, y), ...]).

    Expects at least one layer with opaque dots on a transparent background.

    Raises:
        gr.Error: when there are no layers, no opaque layer, or no dots.
    """
    bg = editor_value["background"]  # PIL.Image
    layers = editor_value["layers"]
    if not layers:
        raise gr.Error("Draw at least one dot with the brush first!")

    # ── find the first non-empty dot layer ─────────────────────────────
    for lyr in layers:
        layer_img = lyr if isinstance(lyr, Image.Image) else lyr["data"]
        alpha = np.array(layer_img.split()[-1])  # alpha channel
        if alpha.max() > 0:
            dot_layer = layer_img
            break
    else:
        raise gr.Error("No non-empty brush layer found.")

    # ── binarize (opaque => 1) ─────────────────────────────────────────
    bin_mask = (np.array(dot_layer.split()[-1]) > 0).astype(np.uint8)

    # ── label each connected blob and take centroids ───────────────────
    labelled, n = ndimage.label(bin_mask)
    if n == 0:
        raise gr.Error("No dots detected on the brush layer.")
    centroids = ndimage.center_of_mass(bin_mask, labelled, range(1, n + 1))  # (y, x)

    # Flip to (x, y) order for SAM.
    point_coords = [(float(x), float(y)) for y, x in centroids]
    return bg.convert("RGB"), point_coords


# ───────── Section 1: SAM Mask Generation ─────────
def run_sam_frontend(editor_data):
    """Generate a SAM mask for the dots drawn on the editor image.

    1) Extract (bg_image, point_coords) from the ImageEditor via dots_to_points().
    2) Build two JSON payloads:
       - image payload:  {"background": ..., "layers": [...]}
       - labels payload: {"point_coords": ..., "point_labels": [...]}
    3) Call the backend /run_sam endpoint with both JSONs in one predict() call.
    4) Decode the returned mask and return it as (PIL.Image, base64_str).
    """
    # Ensure the backend is reachable before doing any work.
    if not backend_status["connected"] or not backend_status["client"]:
        success, status_msg = check_backend_connection()
        if not success:
            # BUGFIX: the original returned a 3-tuple here although this
            # handler feeds exactly 2 outputs; raise so the user sees why.
            raise gr.Error(status_msg)

    if not editor_data or not editor_data.get("background"):
        return None, ""

    # 1) Extract point_coords from the brush layers.
    try:
        _, point_coords = dots_to_points(editor_data)
    except Exception as e:
        logger.error(f"Error extracting points: {e}")
        return None, ""

    # Every dot is a "foreground" prompt for SAM.
    point_labels = [1] * len(point_coords)

    # 2a) Build the image JSON.
    image_payload_str = json.dumps(prepare_editor_data(editor_data))

    # 2b) Build the labels JSON.
    labels_payload_str = json.dumps(
        {"point_coords": point_coords, "point_labels": point_labels}
    )

    # 3) Call the backend, reusing the cached client instead of
    #    reconnecting on every request.
    client = backend_status["client"]
    try:
        mask_b64 = client.predict(
            image_payload_str,
            labels_payload_str,
            api_name="/run_sam",
        )
    except Exception as e:
        logger.error(f"SAM call failed: {e}")
        return None, ""

    # 4) Decode the returned base64 mask into a PIL.Image.
    mask_image = base64_to_image(mask_b64) if mask_b64 else None
    return mask_image, mask_b64


# ───────── Section 2: Flux Image Generation ─────────
def generate_images_frontend(editor_data, mask_b64, prompt):
    """Generate the final background-swapped image via the backend.

    1. Convert ImageEditor data to a JSON payload.
    2. Pass ``mask_b64`` through unchanged.
    3. Call the backend ``/generate_images`` endpoint.
    4. Decode the returned base64 into a PIL Image (None on failure).
    """
    # Ensure the backend is reachable before doing any work.
    if not backend_status["connected"] or not backend_status["client"]:
        success, status_msg = check_backend_connection()
        if not success:
            # BUGFIX: the original returned a 3-tuple here although this
            # handler has a single output; raise so the user sees why.
            raise gr.Error(status_msg)

    # Validate inputs — a missing piece simply clears the output.
    if not editor_data or not editor_data.get("background"):
        return None
    if not mask_b64:
        return None
    if not prompt:
        return None

    # 1) Prepare the JSON payload.
    payload_str = json.dumps(prepare_editor_data(editor_data))

    # 2) Invoke the backend, reusing the cached client.
    client = backend_status["client"]
    try:
        result_b64 = client.predict(
            payload_str,
            mask_b64,
            prompt,
            api_name="/generate_images",
        )
    except Exception as e:
        logger.error(f"Image generation call failed: {e}")
        return None

    # 3) Decode and return.
    return base64_to_image(result_b64) if result_b64 else None


# ───────── Gradio App (Single Canvas) ─────────
# NOTE(review): all static HTML below was reconstructed from the page copy —
# the original markup was lost during extraction; classes match the CSS above.
with gr.Blocks(css=css, title="Snapwear Create Background") as demo:
    # ──────── Hero Section ────────
    gr.HTML("""
    <div class="hero-section">
        <h1>🌄 Snapwear Create Background</h1>
        <p>Create a unique pose and setting for your photograph.</p>
        <p><b>Disclaimer:</b> This demo is free for trials only. Any solicitation
        for payment based on the free features we provide on this HuggingFace
        Space is a fraudulent act.</p>
    </div>
    """)

    # ──────── Backend Status Section ────────
    with gr.Row():
        with gr.Column():
            # Initial status display
            if backend_status["connected"]:
                initial_status = (
                    '<div class="status-banner status-ready">'
                    '🟢 Create Background is ready!</div>'
                )
            else:
                initial_status = (
                    '<div class="status-banner status-starting">'
                    '🟡 Model may be starting up. Click "Check Status" to verify.</div>'
                )
            status_display = gr.HTML(value=initial_status)
            # Status check button
            check_status_btn = gr.Button("🔄 Check Status", size="sm")

    # ──────── Key Features ────────
    gr.HTML("""
    <div class="feature-box">
        <h3>🚀 Instant Background Swap</h3>
        <p>Change backgrounds in 10–20 seconds with a single click</p>
    </div>
    <div class="feature-box">
        <h3>🎯 Seamless Blending</h3>
        <p>Preserves subject edges, lighting, and shadows for natural integration</p>
    </div>
    <div class="feature-box">
        <h3>💎 High-Resolution Output</h3>
        <p>Produce professional-grade images perfect for photography, e-commerce,
        and virtual presentations</p>
    </div>
    """)

    # ──────── Step Headers ────────
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML('<div class="step-header">Step 1: Upload Image & Draw dots on the area you want to Preserve 🖼️🖌️</div>')
        with gr.Column(elem_id="col-mid"):
            gr.HTML('<div class="step-header">Step 2. Press Mask Button and Mask The Model image ⬇️</div>')
        with gr.Column(elem_id="col-right"):
            gr.HTML('<div class="step-header">Step 3. Press "Generate" to get your Background result ✨🌄</div>')

    # ──────── Main Interface ────────
    with gr.Row():
        # ① Person + Dot Mask
        with gr.Column(elem_id="col-left"):
            model_editor = gr.ImageEditor(
                label="Model Image",
                type="pil",
                brush=gr.Brush(color_mode="select", default_size=20),
                image_mode="RGBA",
                height=450,
            )
            gr.HTML('<div>⚠️ <b>Important:</b> First Draw a mask on the area you want to Preserve</div>')
            gr.Examples(
                label="Example Model Images",
                inputs=model_editor,
                examples_per_page=12,
                examples=[f"examples/model{i}.jpg" for i in range(1, 5)] if os.path.exists("examples") else [],
            )

        # ② Mask Preview
        with gr.Column(elem_id="col-mid"):
            mask_preview = gr.Image(
                label="Mask Preview",
                height=450,
            )
            mask_b64_hidden = gr.Textbox(label="Mask (base64)", visible=False)
            sam_button = gr.Button("🖌️ Generate Mask", elem_id="button", size="md")
            gr.HTML('<div class="queue-info">A mask will be generated to segment the area you want to preserve.</div>')

        # ③ Generated Image
        with gr.Column():
            result_preview = gr.Image(label="Generated Image", show_share_button=True, height=450)
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the Background...")
            # Prompt examples
            gr.Examples(
                label="Prompt Examples",
                examples=[
                    "asian model standing in a busy street in new york",
                    "side pose of a female model wearing minimalist earrings",
                    "wooden chair in home balcony with plants",
                    "A female model posing on a beach",
                ],
                inputs=prompt_box,
            )
            gen_button = gr.Button("Generate", elem_id="button")

    # ──────── Event Handlers ────────
    # Status check button
    check_status_btn.click(
        fn=update_backend_status,
        outputs=[status_display],
    )

    sam_button.click(
        fn=run_sam_frontend,
        inputs=[model_editor],
        outputs=[mask_preview, mask_b64_hidden],
        show_progress=True,
    )

    gen_button.click(
        fn=generate_images_frontend,
        inputs=[model_editor, mask_b64_hidden, prompt_box],
        outputs=[result_preview],
        concurrency_limit=1,  # Match backend queue system
        show_progress=True,
    )

    # ──────── Look-Book Grid ────────
    # Virtual try-on examples
    lookbook_rows = [
        [f"lookbook/model{i}.jpg", f"lookbook/mask{i}.jpg", f"lookbook/result{i}.jpg"]
        for i in range(1, 5)  # adjust range to your file count
        if os.path.exists("lookbook")
    ]
    if lookbook_rows:
        gr.HTML("""
        <div class="showcase-section">
            <h2>🌟 Create Background Showcase</h2>
        </div>
        """)
        gr.Examples(
            examples=lookbook_rows,
            inputs=[model_editor, mask_preview, result_preview],
            label=None,
            examples_per_page=4,
        )

    # ──────── Model Comparison Grid ────────
    if os.path.exists("examples/Grid.jpg"):
        gr.HTML("""
        <div class="showcase-section">
            <h2>🔬 Model Comparison Analysis</h2>
            <p>See how Snapwear BGAI compares against leading Create Background models</p>
        </div>
        """)
        # Display the comparison grid image
        with gr.Row():
            with gr.Column():
                comparison_image = gr.Image(
                    value="examples/Grid.jpg",
                    label="Create Background Model Comparison",
                    show_label=True,
                    interactive=False,
                    height=600,
                    show_download_button=True,
                    show_share_button=False,
                )

    # ──────── Use Cases ────────
    gr.HTML("""
    <div class="showcase-section">
        <h2>🎯 Perfect For</h2>
        <div class="feature-box">
            <h3>📸 Photographers</h3>
            <p>Replace or enhance backgrounds for professional-quality shots</p>
        </div>
        <div class="feature-box">
            <h3>🎥 Content Creators</h3>
            <p>Craft stunning visuals by swapping backgrounds instantly</p>
        </div>
        <div class="feature-box">
            <h3>🏠 Real Estate Agents</h3>
            <p>Stage property photos with appealing environments</p>
        </div>
        <div class="feature-box">
            <h3>💼 Virtual Professionals</h3>
            <p>Set a polished backdrop for virtual meetings and presentations</p>
        </div>
    </div>
    """)

    # ──────── Footer ────────
    gr.HTML("""
    <div class="hero-section">
        <h3>🚀 Powered by Snapwear AI</h3>
        <p>Experience the future of virtual Photoshoot.</p>
        <p>© 2024 Snapwear AI. Professional AI tools for fashion and design.</p>
    </div>
    """)

# ───────── Launch App ─────────
if __name__ == "__main__":
    demo.queue(
        max_size=20,
        default_concurrency_limit=1,  # Single concurrent request to match backend
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )