# StudioWhite / app.py
# Uploaded by sfmilan10 — latest commit e49d0b1 "Update app.py" (verified)
import gradio as gr
import replicate
from PIL import Image
import io
import os
import tempfile
import requests
_REMOVE_BG_MODEL = "lucataco/remove-bg:95fcc2a26d3899cd6c2691c900465aaeff466285a65c14638cc5f36f34befaf1"
_FLAT_LAY_MODEL = "google/nano-banana"
_FLAT_LAY_PROMPT = "Lay this garment flat and centered as a professional e-commerce flat lay photo. Keep this EXACT clothing item completely unchanged - same exact colors, same design, same fabric. Pure white studio background. Top-down flat lay perspective."
# Seconds to wait for each result download before giving up.
_HTTP_TIMEOUT = 60


def _load_rgb(image):
    """Open an uploaded image (file path or numpy array) as a PIL image.

    RGBA input is converted to RGB; every other mode is returned unchanged.
    """
    img = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
    return img.convert("RGB") if img.mode == "RGBA" else img


def _flatten_on_white(img):
    """Composite a (possibly transparent) PIL image onto an opaque white canvas."""
    # Grey+alpha and palette images can also carry transparency; normalise them
    # to RGBA so their transparent pixels become white, not a palette colour.
    if img.mode in ("LA", "PA") or (img.mode == "P" and "transparency" in img.info):
        img = img.convert("RGBA")
    if img.mode != "RGBA":
        return img.convert("RGB")
    canvas = Image.new("RGB", img.size, (255, 255, 255))
    canvas.paste(img, mask=img.getchannel("A"))
    return canvas


def _extract_url(output):
    """Return the first URL in a Replicate ``run`` result, or None if there is none.

    Accepts a bare URL string, a single object exposing ``.url`` (e.g. a
    Replicate ``FileOutput``), or a list/iterator of either.
    """
    if isinstance(output, str):
        return output
    if hasattr(output, "url"):
        return output.url
    if hasattr(output, "__iter__"):
        for item in output:
            if isinstance(item, str):
                return item
            if hasattr(item, "url"):
                return item.url
    return None


def _fetch(url):
    """GET *url* with a bounded timeout so a stalled download cannot hang the UI."""
    return requests.get(url, timeout=_HTTP_TIMEOUT)


def generate_flat_lay(image, api_key):
    """Turn an uploaded clothing photo into a flat-lay e-commerce product shot.

    HYBRID APPROACH:
    1. Remove the background first (``lucataco/remove-bg``) so the garment
       pixels come from the user's photo instead of being re-generated.
    2. Flatten the cut-out onto white and ask Nano Banana
       (``google/nano-banana``) to restage it as a top-down flat lay.

    Args:
        image: The upload, either a file path (``str``) or an array accepted
            by ``PIL.Image.fromarray``. ``None`` means nothing was uploaded.
        api_key: Replicate API token typed by the user. A blank or
            whitespace-only value falls back to the ``REPLICATE_API_TOKEN``
            environment variable.

    Returns:
        ``(PIL.Image, "✅ Flat lay created!")`` on success, otherwise
        ``(None, message)``. Exceptions are reported through the status
        message and are never raised to the caller.
    """
    if image is None:
        return None, "Please upload a clothing image"
    token = api_key.strip() if api_key and api_key.strip() else os.environ.get("REPLICATE_API_TOKEN", "")
    if not token:
        return None, "Error: Please enter your Replicate API key"
    try:
        # Use a per-call client instead of writing the key into os.environ.
        # Otherwise one visitor's token becomes the fallback key for every
        # later visitor who leaves the field blank.
        client = replicate.Client(api_token=token)
        # TemporaryDirectory deletes both PNGs even when a model call raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            input_path = os.path.join(tmp_dir, "input.png")
            clean_path = os.path.join(tmp_dir, "input_clean.png")
            _load_rgb(image).save(input_path, format="PNG")

            # ========== STEP 1: REMOVE BACKGROUND (preserves exact clothing) ==========
            with open(input_path, "rb") as f:
                bg_removed = client.run(_REMOVE_BG_MODEL, input={"image": f})
            bg_url = _extract_url(bg_removed) if bg_removed else None
            if not bg_url:
                return None, "Error: Background removal failed"
            response = _fetch(bg_url)
            if response.status_code != 200:
                return None, f"Error downloading: {response.status_code}"
            cutout = Image.open(io.BytesIO(response.content))
            _flatten_on_white(cutout).save(clean_path, format="PNG")

            # ========== STEP 2: NANO BANANA FOR FLAT LAY STYLING ==========
            # Per the model's published input schema, google/nano-banana reads
            # reference images from the list-valued "image_input" field. With
            # an "image" key the garment photo would not reach the model.
            with open(clean_path, "rb") as f:
                output = client.run(
                    _FLAT_LAY_MODEL,
                    input={
                        "image_input": [f],
                        "prompt": _FLAT_LAY_PROMPT,
                        "output_format": "png",
                    },
                )
        if not output:
            return None, "Error: No output from Nano Banana"
        result_url = _extract_url(output)
        if not result_url:
            return None, "Error: Could not get result"
        response = _fetch(result_url)
        if response.status_code != 200:
            return None, f"Error: {response.status_code}"
        return Image.open(io.BytesIO(response.content)), "✅ Flat lay created!"
    except Exception as e:
        # UI handler boundary: report any failure in the status box.
        return None, f"Error: {e}"
# Create Gradio interface
# Single source for the per-image cost estimate: the two Markdown panels
# previously quoted different figures (~$0.02 vs ~$0.03).
# TODO(review): confirm against current Replicate pricing for both models.
COST_PER_IMAGE = "~$0.03"

# Blank lines separate the Markdown blocks so the "Speed" line is not absorbed
# into list item 2 and the footer's "Cost" line is not joined to the API-key line.
INTRO_MARKDOWN = f"""
# 👕 Studio White - Flat Lay E-Commerce Photos

Transform clothing photos into professional flat lay images.

**HYBRID APPROACH (Better accuracy):**

1. First removes background (preserves your EXACT clothing)
2. Then Nano Banana AI creates flat lay positioning

**Speed:** ~15-20 seconds | **Cost:** {COST_PER_IMAGE} per image
"""

FOOTER_MARKDOWN = f"""
---

**Get API key:** [replicate.com/account/api-tokens](https://replicate.com/account/api-tokens)

Cost: {COST_PER_IMAGE} per image
"""

with gr.Blocks(title="Studio White - Flat Lay Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            input_image = gr.Image(
                label="📷 Upload Clothing Photo",
                type="numpy",
                height=400,
            )
            api_key_input = gr.Textbox(
                label="🔑 Replicate API Key",
                placeholder="r8_xxxxxxxxxxxxxxxxx",
                type="password",
            )
            generate_btn = gr.Button("✨ Generate Flat Lay", variant="primary", size="lg")
        # Right column: result image plus a status line for errors.
        with gr.Column():
            output_image = gr.Image(
                label="🖼️ Flat Lay Result",
                height=400,
            )
            status_text = gr.Textbox(label="Status", interactive=False)
    # generate_flat_lay returns (image_or_None, status_message), matching outputs.
    generate_btn.click(
        fn=generate_flat_lay,
        inputs=[input_image, api_key_input],
        outputs=[output_image, status_text],
    )
    gr.Markdown(FOOTER_MARKDOWN)

# Launch only when run as a script; importing the module (tests, `gradio app.py`
# reload mode) just builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()