BoooomNing's picture
Update app.py
e6260ad verified
import os
import io
import json
import base64
import time
import numpy as np
import logging
import gradio as gr
from PIL import Image
from scipy import ndimage
from gradio_client import Client
# Setup logging: configure the root logger at INFO level and expose a
# module-scoped logger used by every function in this app.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# ───────── Backend connection with health monitoring ─────────
# Token passed to gradio_client.Client when connecting to the backend Space;
# the app refuses to start without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# Backend connection state shared by the handlers in this module:
#   client        - the most recently created backend Client, or None
#   connected     - True after a Client was created successfully
#   last_check    - time.time() of the latest connection attempt (None before any)
#   error_message - text describing the last connection failure ("" when healthy)
backend_status = dict(
    client=None,
    connected=False,
    last_check=None,
    error_message="",
)
def check_backend_connection():
    """
    Try to (re)connect to the backend Space and record the outcome.

    Creates a fresh ``Client`` for "SnapwearAI/Snapwear_BGAI" and stores it,
    together with the connection flag, timestamp and error text, in the
    module-level ``backend_status`` dict.

    Returns:
        ``(ok, message)``. ``ok`` is True when the client was created.
        ``message`` is a human-readable status line. Timeouts are reported as
        a cold start ("starting up"), and any other error includes the
        exception text.
    """
    try:
        fresh_client = Client("SnapwearAI/Snapwear_BGAI", hf_token=HF_TOKEN)
        backend_status.update(
            client=fresh_client,
            connected=True,
            error_message="",
            last_check=time.time(),
        )
        logger.info("βœ… Backend connection established")
        return True, "🟒 Backend is ready for Create Background"
    except Exception as exc:
        backend_status.update(client=None, connected=False, last_check=time.time())
        reason = str(exc)
        lowered = reason.lower()
        # A timeout while connecting usually means the backend Space is still booting.
        cold_start = any(marker in lowered for marker in ("timeout", "read operation timed out"))
        if cold_start:
            backend_status["error_message"] = "Backend is starting up (5-6 minutes on first load)"
            return False, "🟑 Backend is starting up. Please wait 5-6 minutes and try again."
        backend_status["error_message"] = f"Connection error: {reason}"
        return False, f"πŸ”΄ Backend error: {reason}"
# Initial connection attempt, made once at import time so the status banner
# built in the UI can reflect whether the backend is already reachable.
try:
    success, status_msg = check_backend_connection()
    if not success:
        logger.warning(f"Initial backend connection failed: {status_msg}")
    else:
        logger.info("Backend client established")
except Exception as exc:
    # check_backend_connection() already catches its own errors; this is a
    # last-resort guard so a startup failure is logged instead of propagating.
    logger.error(f"Failed to connect to backend: {exc}")
    backend_status.update(connected=False, error_message=str(exc))
def update_backend_status():
    """
    Re-test the backend connection and render the result as a status banner.

    Returns:
        HTML for a ``status-banner`` div. Its modifier class is
        ``status-ready`` on success, ``status-starting`` when the message
        reports a cold start, and ``status-error`` otherwise.
    """
    is_ready, message = check_backend_connection()
    if is_ready:
        modifier = "status-ready"
    else:
        modifier = "status-starting" if "starting up" in message else "status-error"
    return f'<div class="status-banner {modifier}">{message}</div>'
# ───────── Styling ─────────
# Custom stylesheet passed to gr.Blocks(css=css). It targets component
# elem_id values (e.g. #col-left, #col-mid, #col-right, #button) and the
# class names used inside gr.HTML markup (e.g. .hero-section, .step-header,
# .status-banner with its status-ready / status-starting / status-error variants).
css = """
body, .gradio-container {
font-family: 'Inter', 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
}
#col-left, #col-mid, #col-right {
margin: 0 auto;
max-width: 430px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
#button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #ffffff;
font-weight: 600;
font-size: 18px;
border: none;
border-radius: 12px;
padding: 12px 24px;
transition: all 0.3s ease;
}
#button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(102,126,234,0.3);
}
#button:disabled {
background: #ccc !important;
cursor: not-allowed;
transform: none;
box-shadow: none;
}
.hero-section {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 40px 20px;
border-radius: 20px;
margin: 20px 0;
text-align: center;
}
.feature-box {
background: #f8fafc;
border: 1px solid #e2e8f0;
padding: 20px;
border-radius: 12px;
margin: 10px 0;
border-left: 4px solid #667eea;
}
.showcase-section {
background: #ffffff;
border: 1px solid #e2e8f0;
padding: 30px;
border-radius: 16px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
margin: 20px 0;
}
.step-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px;
border-radius: 12px;
text-align: center;
font-weight: 600;
margin: 10px 0;
}
.social-links {
text-align: center;
margin: 20px 0;
}
.social-links a {
margin: 0 10px;
padding: 8px 16px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
transition: all 0.3s ease;
}
.social-links a:hover {
background: #764ba2;
transform: translateY(-2px);
}
.error-message {
color: #dc3545;
font-weight: 500;
}
.success-message {
color: #28a745;
font-weight: 500;
}
.status-banner {
padding: 15px;
border-radius: 12px;
margin: 10px 0;
text-align: center;
font-weight: 600;
}
.status-ready {
background: #d4edda;
border: 1px solid #c3e6cb;
color: #155724;
}
.status-starting {
background: #fff3cd;
border: 1px solid #ffeaa7;
color: #856404;
}
.status-error {
background: #f8d7da;
border: 1px solid #f5c6cb;
color: #721c24;
}
.queue-info {
background: #e8f4fd;
border: 1px solid #bee5eb;
padding: 12px;
border-radius: 8px;
margin: 10px 0;
text-align: center;
font-size: 14px;
color: #0c5460;
}
"""
def image_to_base64(image: Image.Image) -> str:
    """
    Encode a PIL image as a base64 PNG string.

    Images whose mode is neither RGB nor RGBA are converted to RGB first.
    Returns "" when ``image`` is None.
    """
    if image is None:
        return ""
    encodable = image if image.mode in ("RGB", "RGBA") else image.convert("RGB")
    with io.BytesIO() as sink:
        encodable.save(sink, format="PNG", optimize=True)
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def base64_to_image(b64_str: str) -> Image.Image:
    """
    Decode a base64 string into an RGBA PIL image.

    A ``data:...,`` URL prefix is stripped if present. Returns None for an
    empty input or if decoding or opening fails; failures are logged.
    """
    if not b64_str:
        return None
    try:
        encoded = b64_str.split(",", 1)[1] if b64_str.startswith("data:") else b64_str
        raw_bytes = base64.b64decode(encoded)
        return Image.open(io.BytesIO(raw_bytes)).convert("RGBA")
    except Exception as err:
        logger.error(f"Failed to decode base64 image: {err}")
        return None
def _layer_to_binary_mask(layer: Image.Image) -> Image.Image:
    """
    Binarise one ImageEditor layer into an "L" mask (255 = painted, 0 = not).

    Layers with an alpha band are judged by opacity, so strokes drawn in black
    or any dark colour still count. Layers without alpha fall back to the
    "any non-black pixel → white" rule.
    """
    if "A" in layer.getbands():
        coverage = np.array(layer.getchannel("A"))
    else:
        coverage = np.array(layer.convert("L"))
    return Image.fromarray(np.where(coverage > 0, 255, 0).astype(np.uint8))


def prepare_editor_data(editor_data: dict) -> dict:
    """
    Convert Gradio ImageEditor output into a JSON-serializable dict.

    Args:
        editor_data: ImageEditor value, a dict with "background" (PIL image)
            and "layers" (list of PIL images).

    Returns:
        ``{"background": <b64 PNG or "">, "layers": [<b64 PNG or "">, ...]}``.
        Each layer is binarised with ``_layer_to_binary_mask`` before
        encoding. Non-image layers become "" so positions stay aligned with
        the editor's layer list. Returns {} for an empty input.
    """
    if not editor_data:
        return {}

    background = editor_data.get("background")
    encoded_background = image_to_base64(background) if isinstance(background, Image.Image) else ""

    # `or []` also covers an explicit None for "layers".
    encoded_layers = [
        image_to_base64(_layer_to_binary_mask(layer)) if isinstance(layer, Image.Image) else ""
        for layer in editor_data.get("layers") or []
    ]
    return {"background": encoded_background, "layers": encoded_layers}
def dots_to_points(editor_value):
    """
    Convert the brush dots in an ImageEditor value into SAM prompt points.

    Each connected blob of painted pixels on the first non-empty layer becomes
    one point, located at the blob's centroid.

    Args:
        editor_value: ImageEditor value dict with "background" (PIL image) and
            "layers" (PIL images, or dicts holding the image under "data").

    Returns:
        ``(bg_rgb, point_coords)``: the background converted to RGB and a list
        of ``(x, y)`` float coordinates, one per dot.

    Raises:
        gr.Error: if there is no background, no layers, no painted layer, or no dots.
    """
    bg = editor_value["background"]  # PIL.Image
    layers = editor_value["layers"]
    if bg is None:
        raise gr.Error("Upload an image first!")
    if not layers:
        raise gr.Error("Draw at least one dot with the brush first!")

    # ── find the first non-empty dot layer and keep its binary mask ─────
    bin_mask = None
    for lyr in layers:
        if lyr is None:
            continue
        layer_img = lyr if isinstance(lyr, Image.Image) else lyr["data"]
        # "Painted" means opaque on layers with an alpha band. Layers without
        # one fall back to non-black luminance instead of an arbitrary last band.
        if "A" in layer_img.getbands():
            coverage = np.array(layer_img.getchannel("A"))
        else:
            coverage = np.array(layer_img.convert("L"))
        painted = coverage > 0
        if painted.any():
            bin_mask = painted.astype(np.uint8)
            break
    if bin_mask is None:
        raise gr.Error("No non-empty brush layer found.")

    # ── label each connected blob and take centroids ───────────────────
    labelled, n = ndimage.label(bin_mask)
    if n == 0:
        raise gr.Error("No dots detected on the brush layer.")
    centroids = ndimage.center_of_mass(bin_mask, labelled, range(1, n + 1))  # (y, x)
    # flip to (x, y) order for SAM
    point_coords = [(float(x), float(y)) for y, x in centroids]
    return bg.convert("RGB"), point_coords
# ───────── Section 1: SAM Mask Generation ────────
def run_sam_frontend(editor_data):
    """
    Generate a segmentation mask on the backend from the dots drawn in the editor.

    1) Extract the dot centroids from the ImageEditor via dots_to_points().
    2) Build two JSON payloads:
       • image_payload_str  = JSON of {"background": …, "layers": […]}
       • labels_payload_str = JSON of {"point_coords": …, "point_labels": […]}
    3) Call the backend ``/run_sam`` endpoint with both JSON strings in one predict() call.
    4) Decode the returned base64 mask.

    Args:
        editor_data: value of the ``gr.ImageEditor`` component (dict with
            "background" and "layers").

    Returns:
        A 2-tuple ``(mask_image, mask_b64)`` matching the two outputs wired to
        this handler (mask preview, hidden base64 textbox). On any failure
        ``(None, "")`` is returned and a ``gr.Warning`` toast explains why.
    """
    # Validate the input before doing any work or touching the network.
    if not editor_data or editor_data.get("background") is None:
        gr.Warning("Please upload a model image first.")
        return None, ""

    # 1) Extract point_coords from the brush layers
    try:
        _, point_coords = dots_to_points(editor_data)
    except gr.Error as e:
        # dots_to_points raises gr.Error with a user-facing explanation; show it.
        logger.error(f"Error extracting points: {e}")
        gr.Warning(str(e))
        return None, ""
    except Exception as e:
        logger.error(f"Error extracting points: {e}")
        gr.Warning("Could not read the brush dots. Please redraw them and try again.")
        return None, ""

    # Reuse the client cached by check_backend_connection(); reconnect only when needed.
    client = backend_status["client"]
    if not backend_status["connected"] or client is None:
        success, status_msg = check_backend_connection()
        if not success:
            gr.Warning(status_msg)
            return None, ""
        client = backend_status["client"]

    # Every dot is a foreground click for SAM.
    point_labels = [1] * len(point_coords)

    # 2a) Image payload: background + binarised brush layers, base64-encoded.
    image_payload_str = json.dumps(prepare_editor_data(editor_data))
    # 2b) Prompt-point payload.
    labels_payload_str = json.dumps({
        "point_coords": point_coords,
        "point_labels": point_labels,
    })

    # 3) Call backend /run_sam with the TWO JSON strings as positional args.
    try:
        mask_b64 = client.predict(
            image_payload_str,
            labels_payload_str,
            api_name="/run_sam",
        )
    except Exception as e:
        logger.error(f"SAM call failed: {e}")
        # Mark the connection unhealthy so the next request reconnects.
        backend_status["connected"] = False
        gr.Warning("Mask generation failed. Please try again in a moment.")
        return None, ""

    # 4) Decode the returned base64 mask into a PIL image.
    if not mask_b64:
        return None, ""
    return base64_to_image(mask_b64), mask_b64
# ───────── Section 2: Flux Image Generation ─────────
def generate_images_frontend(editor_data, mask_b64, prompt):
    """
    Generate a new background around the preserved region via the backend.

    1. Convert the ImageEditor data to a JSON payload (prepare_editor_data).
    2. Use ``mask_b64`` (the mask produced by the mask step) as-is.
    3. Call the backend ``/generate_images`` endpoint.
    4. Decode the returned base64 image.

    Args:
        editor_data: value of the ``gr.ImageEditor`` component.
        mask_b64: base64 mask string from the hidden textbox ("" if none yet).
        prompt: text description of the desired background.

    Returns:
        The generated image (RGBA PIL image) or ``None``. Exactly one value is
        returned because this handler has a single output. On failure a
        ``gr.Warning`` toast explains why.
    """
    # Validate inputs before touching the network.
    if not editor_data or editor_data.get("background") is None:
        gr.Warning("Please upload a model image first.")
        return None
    if not mask_b64:
        gr.Warning("Please generate a mask first.")
        return None
    if not prompt or not prompt.strip():
        gr.Warning("Please describe the background you want.")
        return None

    # Reuse the client cached by check_backend_connection(); reconnect only when needed.
    client = backend_status["client"]
    if not backend_status["connected"] or client is None:
        success, status_msg = check_backend_connection()
        if not success:
            gr.Warning(status_msg)
            return None
        client = backend_status["client"]

    # 1) Prepare JSON payload
    payload_str = json.dumps(prepare_editor_data(editor_data))

    # 2) Invoke backend
    try:
        result_b64 = client.predict(
            payload_str,
            mask_b64,
            prompt,
            api_name="/generate_images",
        )
    except Exception as e:
        logger.error(f"Image generation call failed: {e}")
        # Mark the connection unhealthy so the next request reconnects.
        backend_status["connected"] = False
        gr.Warning("Image generation failed. Please try again in a moment.")
        return None

    # 3) Decode and return
    return base64_to_image(result_b64) if result_b64 else None
# ───────── Gradio App (Single Canvas) ─────────
# ───────── Main UI ─────────
# The whole page is built inside this single gr.Blocks context. The event
# listeners further down are registered inside the same `with` block.
with gr.Blocks(css=css, title="Snapwear Create Background") as demo:
    # ──────── Hero Section ────────
    gr.HTML("""
<div class="hero-section">
<h1 style="font-size:48px;margin:0;background:linear-gradient(45deg,#fff,#f0f8ff);-webkit-background-clip:text;-webkit-text-fill-color:transparent;">
πŸŒ„ Snapwear Create Background
</h1>
<h2 style="font-size:24px;margin:10px 0;opacity:0.9;">
Create a unique pose and setting for your photograph.
</h2>
<div class="social-links">
<a href="https://snapwear.io" target="_blank">🌐 Official Website</a>
<a href="https://www.instagram.com/snapwearai/" target="_blank">πŸ“Έ Instagram</a>
<a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Texture-Transfer" target="_blank">🎨 Pattern Transfer</a>
<a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Virtual-Try-On" target="_blank">πŸ‘— Snapwear Virtual TryOn</a>
</div>
<p style="font-size:13px; margin-top:15px; opacity:0.7;">
<b>Disclaimer:</b> This demo is free for trials only. Any solicitation
for payment based on the free features we provide on this HuggingFace Space
is a fraudulent act.
</p>
</div>
""")
    # ──────── Backend Status Section ────────
    with gr.Row():
        with gr.Column():
            # Initial status display: reflects backend_status at page-build
            # time (set by the import-time connection attempt).
            if backend_status["connected"]:
                initial_status = '<div class="status-banner status-ready">🟒 Create Background is ready!</div>'
            else:
                initial_status = '<div class="status-banner status-starting">🟑 Model may be starting up. Click "Check Status" to verify.</div>'
            status_display = gr.HTML(value=initial_status)
            # Status check button (re-tests the connection via update_backend_status)
            check_status_btn = gr.Button("πŸ”„ Check Status", size="sm")
    # ──────── Key Features ────────
    gr.HTML("""
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(250px,1fr));gap:20px;margin:30px 0;">
<div class="feature-box">
<h3>πŸš€ Instant Background Swap</h3>
<p>Change backgrounds in 10–20 seconds with a single click</p>
</div>
<div class="feature-box">
<h3>🎯 Seamless Blending</h3>
<p>Preserves subject edges, lighting, and shadows for natural integration</p>
</div>
<div class="feature-box">
<h3>πŸ’Ž High-Resolution Output</h3>
<p>Produce professional-grade images perfect for photography, e-commerce, and virtual presentations</p>
</div>
</div>
""")
    # ──────── Step Headers ────────
    # One header per column of the main interface row below.
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML('<div class="step-header">Step 1: Upload Image & Draw dots on the area you want to Preserve πŸ–ΌοΈπŸ–ŒοΈ</div>')
        with gr.Column(elem_id="col-mid"):
            gr.HTML('<div class="step-header">Step 2. Press Mask Button and Mask The Model image ⬇️</div>')
        with gr.Column(elem_id="col-right"):
            gr.HTML('<div class="step-header">Step 3. Press "Generate" to get your Background result βœ¨πŸŒ„</div>')
    # ──────── Main Interface ────────
    with gr.Row():
        # ① Person + Dot Mask
        with gr.Column(elem_id="col-left"):
            # RGBA editor: the brush layers are read by dots_to_points() and
            # prepare_editor_data().
            model_editor = gr.ImageEditor(
                label="Model Image",
                type="pil",
                brush=gr.Brush(color_mode="select", default_size=20),
                image_mode="RGBA",
                height=450
            )
            # NOTE(review): this <div> is never closed (no </div>); browsers
            # auto-close it, but consider adding the closing tag.
            gr.HTML('<div style="font-size:14px; color:#666; margin-top:8px; text-align:center;">'
                    '⚠️ <b>Important:</b> First Draw a mask on the area you want to Preserve<br/>')
            # NOTE(review): only the examples/ directory is checked, not each
            # model{i}.jpg file.
            gr.Examples(
                label="Example Model Images",
                inputs=model_editor,
                examples_per_page=12,
                examples=[f"examples/model{i}.jpg" for i in range(1, 5)] if os.path.exists("examples") else [],
            )
        # ② Mask Preview
        with gr.Column(elem_id="col-mid"):
            mask_preview = gr.Image(
                label="Mask Preview",
                height=450,
            )
            # Hidden carrier for the base64 mask; it is fed to the Generate step.
            mask_b64_hidden = gr.Textbox(label="Mask (base64)", visible=False)
            sam_button = gr.Button("πŸ–ŒοΈ Generate Mask", elem_id="button", size="md")
            gr.HTML('<p style="text-align:center;color:#888;font-size:13px;">A mask will be generated to segment the area you want to preserve.</p>')
        # ③ Generated Image
        with gr.Column():
            result_preview = gr.Image(label="Generated Image",show_share_button=True, height=450)
            # NOTE(review): nesting of this inner Column under the result
            # column is assumed; confirm against the intended layout.
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", placeholder="Describe the Background...")
                # ✅ Prompt examples that fill the prompt box
                gr.Examples(
                    label="Prompt Examples",
                    examples=[
                        "asian model standing in a busy street in new york",
                        "side pose of a female model wearing mini-malist earrings",
                        "wooden chair in home balcony with plants",
                        "A female model posing on a beach"
                    ],
                    inputs=prompt_box
                )
                gen_button = gr.Button("Generate", elem_id="button")
    # ──────── Event Handlers ────────
    # Status check button
    check_status_btn.click(
        fn=update_backend_status,
        outputs=[status_display]
    )
    # Mask step: two outputs, the preview image and the hidden base64 mask.
    sam_button.click(
        fn=run_sam_frontend,
        inputs=[model_editor],
        outputs=[mask_preview, mask_b64_hidden],
        show_progress=True
    )
    # Generate step: a single output (the result image).
    gen_button.click(
        fn=generate_images_frontend,
        inputs=[model_editor, mask_b64_hidden, prompt_box],
        outputs=[result_preview],
        concurrency_limit=1, # Match backend queue system
        show_progress=True
    )
    # ──────── Look-Book Grid ────────
    # Showcase rows of (model image, mask, result) file paths. The list is
    # empty unless a local "lookbook" directory exists.
    lookbook_rows = [
        [f"lookbook/model{i}.jpg",
         f"lookbook/mask{i}.jpg",
         f"lookbook/result{i}.jpg"]
        for i in range(1, 5) if os.path.exists("lookbook") # adjust range to your file count
    ]
    if lookbook_rows:
        gr.HTML("""
<div class="showcase-section">
<h2 style="text-align:center;color:#333;margin-bottom:30px;">
🌟 Create Background Showcase
</h2>
</div>
""")
        gr.Examples(
            examples=lookbook_rows,
            inputs=[model_editor, mask_preview, result_preview],
            label=None,
            examples_per_page=4,
        )
    # ──────── Model Comparison Grid ────────
    # Rendered only when the comparison image is present on disk.
    if os.path.exists("examples/Grid.jpg"):
        gr.HTML("""
<div class="showcase-section">
<h2 style="text-align:center;color:#333;margin-bottom:20px;">
πŸ”¬ Model Comparison Analysis
</h2>
<p style="text-align:center;color:#666;margin-bottom:30px;font-size:16px;">
See how Snapwear BGAI compares against leading Create Background models
</p>
</div>
""")
        # Display the comparison grid image (read-only)
        with gr.Row():
            with gr.Column():
                comparison_image = gr.Image(
                    value="examples/Grid.jpg",
                    label="Create Background Model Comparison",
                    show_label=True,
                    interactive=False,
                    height=600,
                    show_download_button=True,
                    show_share_button=False
                )
    # ──────── Use Cases ────────
    gr.HTML("""
<div style="background:#f8fafc;border:1px solid #e2e8f0;padding:30px;border-radius:16px;margin:30px 0;">
<h2 style="text-align:center;color:#333;margin-bottom:25px;">🎯 Perfect For</h2>
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(200px,1fr));gap:20px;">
<div style="text-align:center;padding:15px;">
<h3 style="color:#667eea;">πŸ“Έ Photographers</h3>
<p style="color:#666;">Replace or enhance backgrounds for professional-quality shots</p>
</div>
<div style="text-align:center;padding:15px;">
<h3 style="color:#667eea;">πŸŽ₯ Content Creators</h3>
<p style="color:#666;">Craft stunning visuals by swapping backgrounds instantly</p>
</div>
<div style="text-align:center;padding:15px;">
<h3 style="color:#667eea;">🏠 Real Estate Agents</h3>
<p style="color:#666;">Stage property photos with appealing environments</p>
</div>
<div style="text-align:center;padding:15px;">
<h3 style="color:#667eea;">πŸ’Ό Virtual Professionals</h3>
<p style="color:#666;">Set a polished backdrop for virtual meetings and presentations</p>
</div>
</div>
</div>
""")
    # ──────── Footer ────────
    gr.HTML("""
<div style="text-align:center;padding:40px 20px;background:#f8fafc;border:1px solid #e2e8f0;border-radius:16px;margin:30px 0;">
<h3 style="color:#333;">πŸš€ Powered by Snapwear AI</h3>
<p style="color:#666;">
Experience the future of virtual Photoshoot.
</p>
<div class="social-links">
<a href="https://snapwear.io" target="_blank">🌐 Website</a>
<a href="https://www.instagram.com/snapwearai/" target="_blank">πŸ“Έ Instagram</a>
<a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Texture-Transfer" target="_blank">🎨 Pattern Transfer</a>
<a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Virtual-Try-On" target="_blank">πŸ‘— Snapwear Virtual TryOn</a>
</div>
<p style="font-size:12px;color:#999;margin-top:20px;">
Β© 2024 Snapwear AI. Professional AI tools for fashion and design.
</p>
</div>
""")
# ───────── Launch App ─────────
if __name__ == "__main__":
    # Enable the request queue: at most 20 waiting, one running at a time to
    # match the backend, with the queue API closed. Then serve on all
    # interfaces at port 7860 without a share link or API page.
    queued_demo = demo.queue(
        max_size=20,
        default_concurrency_limit=1,
        api_open=False,
    )
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
    )