adsnap-studio / streamlit_app.py
Dewmike's picture
Initial commit with Hugging Face Whisper integration
0362b52
import os
import sys
import streamlit as st
from dotenv import load_dotenv
# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Now import from src
from src.services import (
lifestyle_shot_by_image,
lifestyle_shot_by_text,
add_shadow,
create_packshot,
enhance_prompt,
generative_fill,
generate_hd_image,
erase_foreground
)
from src.components.voice_to_image import render_voice_to_image_section
from PIL import Image
import io
import requests
import json
import time
import base64
from streamlit_drawable_canvas import st_canvas
import numpy as np
from src.services.erase_foreground import erase_foreground
import pkg_resources
# Debug aid: log the distributions visible to this interpreter at startup.
print([p.project_name for p in pkg_resources.working_set])

# Configure Streamlit page
st.set_page_config(
    page_title="AdSnap Studio",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Load environment variables
print("Loading environment variables...")
load_dotenv(verbose=True)  # verbose=True reports details while loading

# Debug: report whether the key was found, but never echo the secret itself --
# stdout ends up in terminal scrollback and hosted-app logs.
api_key = os.getenv("BRIA_API_KEY")
print(f"API Key present: {bool(api_key)}")
print(f"API Key value: {'<redacted>' if api_key else 'Not found'}")
print(f"Current working directory: {os.getcwd()}")
print(f".env file exists: {os.path.exists('.env')}")
def initialize_session_state():
    """Seed every session-state slot the app relies on, leaving existing values alone."""
    # Built per call so the list default is a fresh object for each new session.
    defaults = {
        'api_key': os.getenv('BRIA_API_KEY'),
        'generated_images': [],
        'current_image': None,
        'pending_urls': [],
        'edited_image': None,
        'original_prompt': "",
        'enhanced_prompt': None,
    }
    for slot, initial in defaults.items():
        if slot not in st.session_state:
            st.session_state[slot] = initial
def download_image(url):
    """Download the image at ``url`` and return its raw bytes.

    Args:
        url: Address of the image to fetch.

    Returns:
        The response body as ``bytes``, or ``None`` when the request fails
        (the failure is reported to the user through ``st.error``).
    """
    try:
        # A timeout keeps a stalled host from hanging the whole script run.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.content
    except requests.RequestException as e:
        st.error(f"Error downloading image: {str(e)}")
        return None
def apply_image_filter(image, filter_type):
    """Apply a named filter to an image and return the filtered PIL image.

    Args:
        image: Raw image ``bytes``, or anything ``Image.open`` accepts.
        filter_type: One of "Grayscale", "Sepia", "High Contrast" or "Blur".
            Any other value returns the opened image unchanged.

    Returns:
        The resulting PIL image, or ``None`` when opening or filtering fails
        (the failure is reported to the user through ``st.error``).
    """
    # BLUR is defined in PIL.ImageFilter, not PIL.Image. Imported locally
    # because the module-level import only brings in ``Image``.
    from PIL import ImageFilter

    try:
        img = Image.open(io.BytesIO(image)) if isinstance(image, bytes) else Image.open(image)

        if filter_type == "Grayscale":
            return img.convert('L')

        if filter_type == "Sepia":
            # Normalise to a colour mode so every pixel is an (r, g, b[, a])
            # tuple; single-band and palette images yield plain ints otherwise.
            # convert() also returns a new image, so the writes below never
            # touch the source file's buffer.
            img = img.convert("RGBA" if "A" in img.getbands() else "RGB")
            pixels = img.load()
            width, height = img.size
            for x in range(width):
                for y in range(height):
                    pixel = pixels[x, y]
                    r, g, b = pixel[:3]
                    tr = int(0.393 * r + 0.769 * g + 0.189 * b)
                    tg = int(0.349 * r + 0.686 * g + 0.168 * b)
                    tb = int(0.272 * r + 0.534 * g + 0.131 * b)
                    # pixel[3:] carries the alpha value through (empty for RGB).
                    pixels[x, y] = (min(tr, 255), min(tg, 255), min(tb, 255)) + pixel[3:]
            return img

        if filter_type == "High Contrast":
            return img.point(lambda x: x * 1.5)

        if filter_type == "Blur":
            return img.filter(ImageFilter.BLUR)

        return img
    except Exception as e:
        st.error(f"Error applying filter: {str(e)}")
        return None
def check_generated_images():
    """Poll the pending image URLs and promote the ones that are ready.

    Each URL in ``st.session_state.pending_urls`` is probed with a HEAD
    request; a 200 response counts as ready. URLs that are not ready (or whose
    probe fails) stay pending.

    Returns:
        ``True`` if at least one image became ready, in which case the first
        one is stored in ``st.session_state.edited_image``; otherwise ``False``.
    """
    pending = st.session_state.pending_urls
    if not pending:
        return False

    ready_images = []
    still_pending = []
    for url in pending:
        try:
            # Bounded wait so one unresponsive host cannot stall the app.
            response = requests.head(url, timeout=10)
        except requests.RequestException:
            still_pending.append(url)
            continue
        # Consider an image ready if we get a 200 response with any content length
        if response.status_code == 200:
            ready_images.append(url)
        else:
            still_pending.append(url)

    # Update the pending URLs list
    st.session_state.pending_urls = still_pending

    if not ready_images:
        return False

    st.session_state.edited_image = ready_images[0]  # Display the first ready image
    if len(ready_images) > 1:
        # generated_images holds dict records that main() reads with
        # img.get('source'). Append records in that shape instead of replacing
        # the list with bare URL strings, which would break those readers and
        # discard earlier results.
        if 'generated_images' not in st.session_state:
            st.session_state.generated_images = []
        for url in ready_images:
            st.session_state.generated_images.append({
                'url': url,
                'prompt': "",
                'source': 'async',
                'timestamp': time.time()
            })
    return True
def auto_check_images(status_container):
    """Poll for finished images a limited number of times.

    Waits two seconds before each poll and stops as soon as a poll succeeds or
    nothing is pending any more. On success a message is written to
    ``status_container`` and ``True`` is returned; otherwise ``False``.
    """
    for _ in range(3):
        if not st.session_state.pending_urls:
            break
        time.sleep(2)  # Wait 2 seconds between checks
        if check_generated_images():
            status_container.success("✨ Image ready!")
            return True
    return False
# Values of the 'source' field on generated-image records owned by the product tab.
_PRODUCT_SOURCES = ('packshot', 'shadow', 'lifestyle')

# Seconds to wait for an image download before giving up instead of hanging the app.
_DOWNLOAD_TIMEOUT = 30


def _record_generated_image(url, prompt, source):
    """Append one image record to the session-wide list of generated images."""
    if 'generated_images' not in st.session_state:
        st.session_state.generated_images = []
    st.session_state.generated_images.append({
        'url': url,
        'prompt': prompt,
        'source': source,
        'timestamp': time.time()
    })


def _report_api_error(prefix, error, subject):
    """Show an API failure and add a moderation hint when the message mentions 422."""
    st.error(f"{prefix}{str(error)}")
    if "422" in str(error):
        st.warning(f"Content moderation failed. Please ensure the {subject} is appropriate.")


def _first_image_url(result):
    """Return the first image URL found in a text-to-image response, or None."""
    if not isinstance(result, dict):
        return None
    payload = result.get('result')
    if payload:
        # Format 1: {"result": [{"urls": ["..."]}]} or {"result": [{"url": "..."}]}
        if isinstance(payload, list):
            first = payload[0]
            if isinstance(first, dict):
                if first.get('urls'):
                    return first['urls'][0]
                if 'url' in first:
                    return first['url']
        elif isinstance(payload, dict):
            if payload.get('urls'):
                return payload['urls'][0]
            if 'url' in payload:
                return payload['url']
        return None
    # Format 2: {"result_url": "..."}
    if 'result_url' in result:
        return result['result_url']
    # Format 3: {"url": "..."}
    if 'url' in result:
        return result['url']
    # Format 4: {"result_urls": ["..."]}
    if result.get('result_urls'):
        return result['result_urls'][0]
    return None


def _sync_result_url(result):
    """Return the first image URL from a synchronous lifestyle-shot response, or None."""
    if not isinstance(result, dict):
        return None
    if "result_url" in result:
        return result["result_url"]
    if "result_urls" in result:
        return result["result_urls"][0]
    if "result" in result and isinstance(result["result"], list):
        for item in result["result"]:
            if isinstance(item, dict) and "urls" in item:
                return item["urls"][0]
            if isinstance(item, list) and len(item) > 0:
                return item[0]
        return None
    if "urls" in result:
        return result["urls"][0]
    return None


def _async_result_urls(result, num_results):
    """Collect at most ``num_results`` placeholder URLs from an asynchronous response."""
    urls = []
    if isinstance(result, dict):
        if "urls" in result:
            urls.extend(result["urls"][:num_results])
        elif "result" in result and isinstance(result["result"], list):
            for item in result["result"]:
                if isinstance(item, dict) and "urls" in item:
                    urls.extend(item["urls"])
                elif isinstance(item, list):
                    urls.extend(item)
                # Stop once we have enough URLs
                if len(urls) >= num_results:
                    break
            urls = urls[:num_results]
    return urls


def _start_async_tracking(urls):
    """Store ``urls`` as pending and poll briefly; rerun the script once one is ready."""
    st.session_state.pending_urls = urls
    status_container = st.empty()
    status_container.info(f"🎨 Generation started! Waiting for {len(urls)} image{'s' if len(urls) > 1 else ''}...")
    if auto_check_images(status_container):
        st.rerun()
    # If nothing is ready yet, the result panel offers a manual refresh button.


def _handle_lifestyle_result(result, sync_mode, num_results):
    """Route a lifestyle-shot response to the result panel (sync) or the pending list (async)."""
    if not result:
        return
    # Debug logging
    st.write("Debug - Raw API Response:", result)
    if sync_mode:
        url = _sync_result_url(result)
        if url:
            st.session_state.edited_image = url
            st.success("✨ Image generated successfully!")
        else:
            st.error("No result URL in the API response. Please try again.")
        return
    urls = _async_result_urls(result, num_results)
    if urls:
        _start_async_tracking(urls)
    else:
        st.error("No result URL in the API response. Please try again.")


def _render_image_gallery(sources, heading, item_label, prompt_label, file_stem=None, show_type=False):
    """List stored image records whose 'source' is in ``sources``, each with a download button."""
    images = [
        record for record in (st.session_state.get('generated_images') or [])
        # Skip anything that is not a record dict so one malformed entry
        # cannot take the whole tab down.
        if isinstance(record, dict) and record.get('source') in sources
    ]
    if not images:
        return
    st.subheader(heading)
    for index, record in enumerate(images, start=1):
        with st.expander(f"{item_label} {index} - {record['prompt'][:50]}..."):
            preview_col, info_col = st.columns([2, 1])
            with preview_col:
                st.image(record['url'], caption="Generated Image", use_column_width=True)
            with info_col:
                st.markdown(prompt_label)
                st.text(record['prompt'])
                if show_type:
                    st.markdown(f"**Type:** {record['source'].title()}")
                stem = file_stem if file_stem is not None else record['source']
                # Download button. ``requests`` is the module-level import: a
                # function-local ``import requests`` would make the name local
                # to the whole function and unbound until that line has run.
                try:
                    response = requests.get(record['url'], timeout=_DOWNLOAD_TIMEOUT)
                    if response.status_code == 200:
                        st.download_button(
                            label="📥 Download Image",
                            data=response.content,
                            file_name=f"{stem}_image_{index}.png",
                            mime="image/png"
                        )
                    else:
                        st.warning("Image not yet ready for download")
                except Exception as e:
                    st.error(f"Download error: {str(e)}")


def _render_result_panel(caption, file_name, pending_message=None, download_key=None, refresh_key=None):
    """Show the latest edited image with a download button, or a refresh control while pending."""
    if st.session_state.edited_image:
        st.image(st.session_state.edited_image, caption=caption, use_column_width=True)
        image_data = download_image(st.session_state.edited_image)
        if image_data:
            st.download_button(
                "⬇️ Download Result",
                image_data,
                file_name,
                "image/png",
                key=download_key
            )
    elif pending_message and st.session_state.pending_urls:
        # The refresh button is rendered here, outside any other button's
        # branch: a button nested under another button's `if` disappears on the
        # rerun its own click triggers, so it could never be acted on.
        if st.button("🔄 Check for Generated Images", key=refresh_key):
            with st.spinner("Checking for completed images..."):
                ready = check_generated_images()
            if ready:
                st.rerun()
            st.warning("⏳ Still generating... Please check again in a moment.")
        st.info(pending_message)


def _load_canvas_background(uploaded_file):
    """Open an upload and return (RGB image scaled for the canvas, original (width, height))."""
    img = Image.open(uploaded_file)
    img_width, img_height = img.size
    # Cap the canvas at 800px wide and keep the aspect ratio.
    canvas_width = min(img_width, 800)
    canvas_height = max(1, int(canvas_width * (img_height / img_width)))
    img = img.resize((canvas_width, canvas_height))
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img, (img_width, img_height)


def _canvas_mask_bytes(canvas_result, target_size):
    """Encode the canvas drawing as a grayscale PNG mask sized like the uploaded image."""
    mask_img = Image.fromarray(canvas_result.image_data.astype('uint8'), mode='RGBA')
    mask_img = mask_img.convert('L')
    # The canvas may be smaller than the upload (800px cap, integer rounding),
    # while the image bytes sent alongside the mask keep their original size.
    if mask_img.size != target_size:
        mask_img = mask_img.resize(target_size, Image.NEAREST)
    buffer = io.BytesIO()
    mask_img.save(buffer, format='PNG')
    return buffer.getvalue()


def _render_generate_tab():
    """Render the text-to-image tab: prompt entry, enhancement, generation and history."""
    st.header("Generate Images")
    col1, col2 = st.columns([2, 1])

    with col1:
        # Prompt input
        prompt = st.text_area("Enter your prompt",
                              value="",
                              height=100,
                              key="prompt_input")

        # Store original prompt in session state when it changes
        if "original_prompt" not in st.session_state:
            st.session_state.original_prompt = prompt
        elif prompt != st.session_state.original_prompt:
            st.session_state.original_prompt = prompt
            st.session_state.enhanced_prompt = None  # Reset enhanced prompt when original changes

        # Enhanced prompt display
        if st.session_state.get('enhanced_prompt'):
            st.markdown("**Enhanced Prompt:**")
            st.markdown(f"*{st.session_state.enhanced_prompt}*")

        # Enhance Prompt button
        if st.button("✨ Enhance Prompt", key="enhance_button"):
            if not prompt:
                st.warning("Please enter a prompt to enhance.")
            else:
                with st.spinner("Enhancing prompt..."):
                    try:
                        result = enhance_prompt(st.session_state.api_key, prompt)
                        if result:
                            st.session_state.enhanced_prompt = result
                            st.success("Prompt enhanced!")
                            st.rerun()  # Rerun to update the display
                    except Exception as e:
                        st.error(f"Error enhancing prompt: {str(e)}")

        # Debug information
        st.write("Debug - Session State:", {
            "original_prompt": st.session_state.get("original_prompt"),
            "enhanced_prompt": st.session_state.get("enhanced_prompt")
        })

    with col2:
        num_images = st.slider("Number of images", 1, 4, 1)
        aspect_ratio = st.selectbox("Aspect ratio", ["1:1", "16:9", "9:16", "4:3", "3:4"])
        enhance_img = st.checkbox("Enhance image quality", value=True)

        # Style options
        st.subheader("Style Options")
        style = st.selectbox("Image Style", [
            "Realistic", "Artistic", "Cartoon", "Sketch",
            "Watercolor", "Oil Painting", "Digital Art"
        ])

    # Add style to prompt
    styled_prompt = prompt
    if style and style != "Realistic":
        styled_prompt = f"{prompt}, in {style.lower()} style"

    # Generate button
    if st.button("🎨 Generate Images", type="primary"):
        final_prompt = st.session_state.enhanced_prompt or styled_prompt
        if not st.session_state.api_key:
            st.error("Please enter your API key in the sidebar.")
        elif not (st.session_state.enhanced_prompt or prompt.strip()):
            # Without this guard an empty box still sends ", in <style> style".
            st.warning("Please enter a prompt before generating.")
        else:
            with st.spinner("🎨 Generating your masterpiece..."):
                try:
                    result = generate_hd_image(
                        prompt=final_prompt,
                        api_key=st.session_state.api_key,
                        num_results=num_images,
                        aspect_ratio=aspect_ratio,  # Already in correct format (e.g. "1:1")
                        sync=True,  # Wait for results
                        enhance_image=enhance_img,
                        medium="art" if style != "Realistic" else "photography",
                        prompt_enhancement=False,  # We're already using our own prompt enhancement
                        content_moderation=True  # Enable content moderation by default
                    )
                    # NOTE(review): only the first URL is kept even when
                    # num_images > 1 -- confirm whether all results should be stored.
                    image_url = _first_image_url(result) if result else None
                    if image_url:
                        _record_generated_image(image_url, final_prompt, 'text')
                        st.success("✨ Image generated successfully!")
                    else:
                        st.error("❌ Failed to generate image. Please try again.")
                except Exception as e:
                    st.error(f"Error generating images: {str(e)}")
                    st.write("Full error:", str(e))

    # Display generated images from text prompts
    _render_image_gallery(('text',), "🎨 Generated Images", "Generated Image",
                          "**Original Prompt:**", file_stem="generated")


def _render_packshot_editor(uploaded_file):
    """Render the packshot options and create a packshot when requested."""
    col_a, col_b = st.columns(2)
    with col_a:
        bg_color = st.color_picker("Background Color", "#FFFFFF")
        sku = st.text_input("SKU (optional)", "")
    with col_b:
        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

    if not st.button("Create Packshot"):
        return

    with st.spinner("Creating professional packshot..."):
        try:
            # First remove background if needed
            if force_rmbg:
                # The rest of this file imports services through the ``src``
                # package; try that first and fall back to the bare path.
                try:
                    from src.services.background_service import remove_background
                except ImportError:
                    from services.background_service import remove_background

                bg_result = remove_background(
                    st.session_state.api_key,
                    uploaded_file.getvalue(),
                    content_moderation=content_moderation
                )
                if not (bg_result and "result_url" in bg_result):
                    st.error("Background removal failed")
                    return
                # Download the background-removed image
                response = requests.get(bg_result["result_url"], timeout=_DOWNLOAD_TIMEOUT)
                if response.status_code != 200:
                    st.error("Failed to download background-removed image")
                    return
                image_data = response.content
            else:
                image_data = uploaded_file.getvalue()

            # Now create packshot
            result = create_packshot(
                st.session_state.api_key,
                image_data,
                background_color=bg_color,
                sku=sku if sku else None,
                force_rmbg=force_rmbg,
                content_moderation=content_moderation
            )
            if result and "result_url" in result:
                _record_generated_image(
                    result["result_url"],
                    f"Packshot with {bg_color} background" + (f" - SKU: {sku}" if sku else ""),
                    'packshot'
                )
                st.success("✨ Packshot created successfully!")
            else:
                st.error("No result URL in the API response. Please try again.")
        except Exception as e:
            _report_api_error("Error creating packshot: ", e, "image")


def _render_shadow_editor(uploaded_file):
    """Render the shadow options and add a shadow when requested."""
    # Values sent when the float-specific sliders are not shown.
    shadow_width = None
    shadow_height = 70

    col_a, col_b = st.columns(2)
    with col_a:
        # NOTE(review): the choices are "Natural"/"Drop", yet the code below
        # branches on "regular" and "Float" -- confirm the intended type names.
        shadow_type = st.selectbox("Shadow Type", ["Natural", "Drop"])
        bg_color = st.color_picker("Background Color (optional)", "#FFFFFF")
        use_transparent_bg = st.checkbox("Use Transparent Background", True)
        shadow_color = st.color_picker("Shadow Color", "#000000")
        sku = st.text_input("SKU (optional)", "")

        # Shadow offset
        st.subheader("Shadow Offset")
        offset_x = st.slider("X Offset", -50, 50, 0)
        offset_y = st.slider("Y Offset", -50, 50, 15)
    with col_b:
        shadow_intensity = st.slider("Shadow Intensity", 0, 100, 60)
        shadow_blur = st.slider("Shadow Blur", 0, 50, 15 if shadow_type.lower() == "regular" else 20)

        # Float shadow specific controls
        if shadow_type == "Float":
            st.subheader("Float Shadow Settings")
            shadow_width = st.slider("Shadow Width", -100, 100, 0)
            shadow_height = st.slider("Shadow Height", -100, 100, 70)

        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

    if not st.button("Add Shadow"):
        return

    with st.spinner("Adding shadow effect..."):
        try:
            result = add_shadow(
                api_key=st.session_state.api_key,
                image_data=uploaded_file.getvalue(),
                shadow_type=shadow_type.lower(),
                background_color=None if use_transparent_bg else bg_color,
                shadow_color=shadow_color,
                shadow_offset=[offset_x, offset_y],
                shadow_intensity=shadow_intensity,
                shadow_blur=shadow_blur,
                shadow_width=shadow_width,
                shadow_height=shadow_height,
                sku=sku if sku else None,
                force_rmbg=force_rmbg,
                content_moderation=content_moderation
            )
            if result and "result_url" in result:
                _record_generated_image(
                    result["result_url"],
                    f"{shadow_type} shadow effect" + (f" - SKU: {sku}" if sku else ""),
                    'shadow'
                )
                st.success("✨ Shadow added successfully!")
            else:
                st.error("No result URL in the API response. Please try again.")
        except Exception as e:
            _report_api_error("Error adding shadow: ", e, "image")


def _render_lifestyle_editor(uploaded_file):
    """Render the lifestyle-shot options and generate from a text prompt or a reference image."""
    shot_type = st.radio("Shot Type", ["Text Prompt", "Reference Image"])

    # Defaults for inputs that only appear for some placement types, so the
    # API call below never touches an undefined name.
    positions = ["Upper Left"]
    pad_left = pad_right = pad_top = pad_bottom = 0
    shot_width = shot_height = 1000
    foreground_size = None
    foreground_location = None
    exclude_elements = None

    # Common settings for both types. The columns get their own names so they
    # do not shadow the product tab's outer editor/result columns.
    settings_col, options_col = st.columns(2)
    with settings_col:
        placement_type = st.selectbox("Placement Type", [
            "Original", "Automatic", "Manual Placement",
            "Manual Padding", "Custom Coordinates"
        ])
        num_results = st.slider("Number of Results", 1, 8, 4)
        sync_mode = st.checkbox("Synchronous Mode", False,
                                help="Wait for results instead of getting URLs immediately")
        original_quality = st.checkbox("Original Quality", False,
                                       help="Maintain original image quality")

        if placement_type == "Manual Placement":
            positions = st.multiselect("Select Positions", [
                "Upper Left", "Upper Right", "Bottom Left", "Bottom Right",
                "Right Center", "Left Center", "Upper Center",
                "Bottom Center", "Center Vertical", "Center Horizontal"
            ], ["Upper Left"])
        elif placement_type == "Manual Padding":
            st.subheader("Padding Values (pixels)")
            pad_left = st.number_input("Left Padding", 0, 1000, 0)
            pad_right = st.number_input("Right Padding", 0, 1000, 0)
            pad_top = st.number_input("Top Padding", 0, 1000, 0)
            pad_bottom = st.number_input("Bottom Padding", 0, 1000, 0)

        # A separate `if`, not part of the chain above: "Manual Placement"
        # needs both the position picker and the shot size.
        if placement_type in ["Automatic", "Manual Placement", "Custom Coordinates"]:
            st.subheader("Shot Size")
            shot_width = st.number_input("Width", 100, 2000, 1000)
            shot_height = st.number_input("Height", 100, 2000, 1000)

    with options_col:
        if placement_type == "Custom Coordinates":
            st.subheader("Product Position")
            fg_width = st.number_input("Product Width", 50, 1000, 500)
            fg_height = st.number_input("Product Height", 50, 1000, 500)
            fg_x = st.number_input("X Position", -500, 1500, 0)
            fg_y = st.number_input("Y Position", -500, 1500, 0)
            foreground_size = [fg_width, fg_height]
            foreground_location = [fg_x, fg_y]

        sku = st.text_input("SKU (optional)")
        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

        if shot_type == "Text Prompt":
            fast_mode = st.checkbox("Fast Mode", True,
                                    help="Balance between speed and quality")
            optimize_desc = st.checkbox("Optimize Description", True,
                                        help="Enhance scene description using AI")
            if not fast_mode:
                exclude_elements = st.text_area("Exclude Elements (optional)",
                                                help="Elements to exclude from the generated scene")
        else:  # Reference Image
            enhance_ref = st.checkbox("Enhance Reference Image", True,
                                      help="Improve lighting, shadows, and texture")
            ref_influence = st.slider("Reference Influence", 0.0, 1.0, 1.0,
                                      help="Control similarity to reference image")

    def shared_kwargs():
        """Build the arguments common to both lifestyle-shot service calls."""
        # Convert placement selections to API format
        if placement_type == "Manual Placement":
            manual_placements = [p.lower().replace(" ", "_") for p in positions]
        else:
            manual_placements = ["upper_left"]
        return dict(
            api_key=st.session_state.api_key,
            image_data=uploaded_file.getvalue(),
            placement_type=placement_type.lower().replace(" ", "_"),
            num_results=num_results,
            sync=sync_mode,
            shot_size=[shot_width, shot_height],
            original_quality=original_quality,
            manual_placement_selection=manual_placements,
            padding_values=[pad_left, pad_right, pad_top, pad_bottom],
            foreground_image_size=foreground_size,
            foreground_image_location=foreground_location,
            force_rmbg=force_rmbg,
            content_moderation=content_moderation,
            sku=sku if sku else None,
        )

    if shot_type == "Text Prompt":
        prompt = st.text_area("Describe the environment")
        if st.button("Generate Lifestyle Shot") and prompt:
            with st.spinner("Generating lifestyle shot..."):
                try:
                    result = lifestyle_shot_by_text(
                        scene_description=prompt,
                        fast=fast_mode,
                        optimize_description=optimize_desc,
                        exclude_elements=exclude_elements,
                        **shared_kwargs()
                    )
                    _handle_lifestyle_result(result, sync_mode, num_results)
                except Exception as e:
                    _report_api_error("Error: ", e, "content")
    else:
        ref_image = st.file_uploader("Upload Reference Image", type=["png", "jpg", "jpeg"], key="ref_upload")
        if st.button("Generate Lifestyle Shot") and ref_image:
            with st.spinner("Generating lifestyle shot..."):
                try:
                    result = lifestyle_shot_by_image(
                        reference_image=ref_image.getvalue(),
                        enhance_ref_image=enhance_ref,
                        ref_image_influence=ref_influence,
                        **shared_kwargs()
                    )
                    _handle_lifestyle_result(result, sync_mode, num_results)
                except Exception as e:
                    _report_api_error("Error: ", e, "content")


def _render_product_tab():
    """Render the product photography tab: upload, edit options, result panel and history."""
    st.header("Product Photography")
    uploaded_file = st.file_uploader("Upload Product Image", type=["png", "jpg", "jpeg"], key="product_upload")
    if uploaded_file:
        editor_col, result_col = st.columns(2)
        with editor_col:
            st.image(uploaded_file, caption="Original Image", use_column_width=True)

            # Product editing options
            edit_option = st.selectbox("Select Edit Option", [
                "Create Packshot",
                "Add Shadow",
                "Lifestyle Shot"
            ])
            if edit_option == "Create Packshot":
                _render_packshot_editor(uploaded_file)
            elif edit_option == "Add Shadow":
                _render_shadow_editor(uploaded_file)
            elif edit_option == "Lifestyle Shot":
                _render_lifestyle_editor(uploaded_file)

        with result_col:
            _render_result_panel(
                "Edited Image",
                "edited_product.png",
                pending_message="Images are being generated. Click the refresh button above to check if they're ready.",
                refresh_key="product_refresh"
            )

    # Display generated images from product photography
    _render_image_gallery(_PRODUCT_SOURCES, "🎨 Generated Product Images", "Product Image",
                          "**Description:**", show_type=True)


def _render_generative_fill_tab():
    """Render the generative fill tab: mask drawing, generation options and result."""
    st.header("🎨 Generative Fill")
    st.markdown("Draw a mask on the image and describe what you want to generate in that area.")
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"], key="fill_upload")
    if not uploaded_file:
        return

    # Create columns for original image and canvas
    editor_col, result_col = st.columns(2)
    with editor_col:
        # Display original image
        st.image(uploaded_file, caption="Original Image", use_column_width=True)

        background, source_size = _load_canvas_background(uploaded_file)

        # Add drawing canvas using Streamlit's drawing canvas component
        stroke_width = st.slider("Brush width", 1, 50, 20)
        stroke_color = st.color_picker("Brush color", "#fff")
        canvas_result = st_canvas(
            fill_color="rgba(255, 255, 255, 0.0)",  # Transparent fill
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            drawing_mode="freedraw",
            background_color="",  # Transparent background
            background_image=background,
            height=background.height,
            width=background.width,
            key="canvas",
        )

        # Options for generation
        st.subheader("Generation Options")
        prompt = st.text_area("Describe what to generate in the masked area")
        negative_prompt = st.text_area("Describe what to avoid (optional)")
        col_a, col_b = st.columns(2)
        with col_a:
            num_results = st.slider("Number of variations", 1, 4, 1)
            sync_mode = st.checkbox("Synchronous Mode", False,
                                    help="Wait for results instead of getting URLs immediately",
                                    key="gen_fill_sync_mode")
        with col_b:
            seed = st.number_input("Seed (optional)", min_value=0, value=0,
                                   help="Use same seed to reproduce results")
            content_moderation = st.checkbox("Enable Content Moderation", False,
                                             key="gen_fill_content_mod")

        if st.button("🎨 Generate", type="primary"):
            if not prompt:
                st.error("Please enter a prompt describing what to generate.")
            elif canvas_result.image_data is None:
                st.error("Please draw a mask on the image first.")
            else:
                mask_bytes = _canvas_mask_bytes(canvas_result, source_size)
                image_bytes = uploaded_file.getvalue()
                with st.spinner("🎨 Generating..."):
                    try:
                        result = generative_fill(
                            st.session_state.api_key,
                            image_bytes,
                            mask_bytes,
                            prompt,
                            negative_prompt=negative_prompt if negative_prompt else None,
                            num_results=num_results,
                            sync=sync_mode,
                            seed=seed if seed != 0 else None,
                            content_moderation=content_moderation
                        )
                        if result:
                            st.write("Debug - API Response:", result)
                            if sync_mode:
                                if "urls" in result and result["urls"]:
                                    st.session_state.edited_image = result["urls"][0]
                                    if len(result["urls"]) > 1:
                                        # Keep every variation as a record dict;
                                        # replacing the list with bare URLs would
                                        # break the galleries and drop earlier results.
                                        for url in result["urls"]:
                                            _record_generated_image(url, prompt, 'generative_fill')
                                    st.success("✨ Generation complete!")
                                elif "result_url" in result:
                                    st.session_state.edited_image = result["result_url"]
                                    st.success("✨ Generation complete!")
                            elif "urls" in result and result["urls"]:
                                _start_async_tracking(result["urls"][:num_results])
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        st.write("Full error details:", str(e))

    with result_col:
        _render_result_panel(
            "Generated Result",
            "generated_fill.png",
            pending_message="Generation in progress. Click the refresh button above to check status.",
            refresh_key="fill_refresh"
        )


def _render_erase_tab():
    """Render the erase tab: drawing canvas, erase request and result."""
    st.header("🎨 Erase Elements")
    st.markdown("Upload an image and select the area you want to erase.")
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"], key="erase_upload")
    if not uploaded_file:
        return

    editor_col, result_col = st.columns(2)
    with editor_col:
        # Display original image
        st.image(uploaded_file, caption="Original Image", use_column_width=True)

        background, _ = _load_canvas_background(uploaded_file)

        # Add drawing canvas using Streamlit's drawing canvas component
        stroke_width = st.slider("Brush width", 1, 50, 20, key="erase_brush_width")
        stroke_color = st.color_picker("Brush color", "#fff", key="erase_brush_color")
        canvas_result = st_canvas(
            fill_color="rgba(255, 255, 255, 0.0)",  # Transparent fill
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            background_color="",  # Transparent background
            background_image=background,
            drawing_mode="freedraw",
            height=background.height,
            width=background.width,
            key="erase_canvas",
        )

        # Options for erasing
        st.subheader("Erase Options")
        content_moderation = st.checkbox("Enable Content Moderation", False, key="erase_content_mod")

        if st.button("🎨 Erase Selected Area", key="erase_btn"):
            if canvas_result.image_data is not None:
                with st.spinner("Erasing selected area..."):
                    try:
                        # NOTE(review): erase_foreground receives only the image;
                        # the drawn selection is not sent -- confirm whether the
                        # service is meant to take a mask.
                        result = erase_foreground(
                            st.session_state.api_key,
                            image_data=uploaded_file.getvalue(),
                            content_moderation=content_moderation
                        )
                        if result:
                            if "result_url" in result:
                                st.session_state.edited_image = result["result_url"]
                                st.success("✨ Area erased successfully!")
                            else:
                                st.error("No result URL in the API response. Please try again.")
                    except Exception as e:
                        _report_api_error("Error: ", e, "image")
            else:
                st.warning("Please draw on the image to select the area to erase.")

    with result_col:
        _render_result_panel("Result", "erased_image.png", download_key="erase_download")


def main():
    """Build the AdSnap Studio page: sidebar settings plus one section per tab.

    Each tab is rendered by its own helper, so a validation failure in one tab
    ends only that tab's rendering and the remaining tabs are still drawn.
    """
    st.title("AdSnap Studio")
    initialize_session_state()

    # Sidebar for API key
    with st.sidebar:
        st.header("Settings")
        api_key = st.text_input(
            "Enter your API key:",
            value=st.session_state.api_key if st.session_state.api_key else "",
            type="password"
        )
        if api_key:
            st.session_state.api_key = api_key

    # Main tabs
    tabs = st.tabs([
        "🎨 Generate Image",
        "🎤 Voice to Image",
        "🖼️ Lifestyle Shot",
        "🎨 Generative Fill",
        "🎨 Erase Elements"
    ])

    with tabs[0]:
        _render_generate_tab()
    with tabs[1]:
        render_voice_to_image_section()
    with tabs[2]:
        _render_product_tab()
    with tabs[3]:
        _render_generative_fill_tab()
    with tabs[4]:
        _render_erase_tab()
# Build the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()