"""AdSnap Studio Streamlit app.

Tabs: text-to-image generation, voice to image, product photography
(packshot / shadow / lifestyle shot), generative fill and element erasing.
Image work is delegated to the functions imported from ``src.services``.
"""
import os
import sys

import streamlit as st
from dotenv import load_dotenv

# Add the project root to the Python path so the `src` package is importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Load environment variables before importing the services, in case any of
# them read configuration at import time.
print("Loading environment variables...")
load_dotenv(verbose=True)  # verbose=True reports problems locating the .env file

# Now import from src
from src.services import (
    lifestyle_shot_by_image,
    lifestyle_shot_by_text,
    add_shadow,
    create_packshot,
    enhance_prompt,
    generative_fill,
    generate_hd_image,
    erase_foreground,
)
from src.components.voice_to_image import render_voice_to_image_section
from PIL import Image, ImageFilter
import io
import requests
import json
import time
import base64
from streamlit_drawable_canvas import st_canvas
import numpy as np

# Keep this explicit import after the package import above: if `src.services`
# does not re-export `erase_foreground`, the package-level import would bind
# the submodule instead of the function.
from src.services.erase_foreground import erase_foreground

# Configure Streamlit page (must be the first Streamlit command).
st.set_page_config(
    page_title="AdSnap Studio",
    page_icon="🎨",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Debug: report environment status without ever printing the secret itself.
api_key = os.getenv("BRIA_API_KEY")
print(f"API Key present: {bool(api_key)}")
print(f"Current working directory: {os.getcwd()}")
print(f".env file exists: {os.path.exists('.env')}")

# Timeouts (seconds) for HTTP calls made directly by the UI; without them a
# stalled server would block the Streamlit script indefinitely.
DOWNLOAD_TIMEOUT = 30
POLL_TIMEOUT = 10

# Set ADSNAP_DEBUG=1 to show raw API responses and session state in the UI.
DEBUG = os.getenv("ADSNAP_DEBUG", "").strip().lower() in ("1", "true", "yes")

# Standard sepia tone matrix (rows produce the new R, G, B channels).
_SEPIA_MATRIX = np.array([
    [0.393, 0.769, 0.189],
    [0.349, 0.686, 0.168],
    [0.272, 0.534, 0.131],
])


def initialize_session_state():
    """Initialize session state variables that are not set yet."""
    defaults = {
        'api_key': os.getenv('BRIA_API_KEY'),
        # Gallery records: {'url', 'prompt', 'source', 'timestamp'} dicts.
        'generated_images': [],
        'current_image': None,
        # Async result URLs not yet confirmed as downloadable.
        'pending_urls': [],
        # {'source', 'prompt'} describing the entries in pending_urls.
        'pending_info': {},
        # URL shown in the product / fill / erase result columns.
        'edited_image': None,
        'original_prompt': "",
        'enhanced_prompt': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def download_image(url):
    """Download image from URL and return as bytes.

    Returns None (after showing an error in the UI) when the request fails or
    the server answers with an HTTP error status.
    """
    try:
        response = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException as e:
        st.error(f"Error downloading image: {str(e)}")
        return None
    return response.content


def _sepia(img):
    """Return a sepia-toned copy of `img`; alpha is kept when `img` has one."""
    rgb = np.asarray(img.convert("RGB"), dtype=np.float64)
    # Vectorised matrix product instead of a per-pixel Python loop; values are
    # capped at 255 and truncated like the classic int()-based formula.
    toned = np.minimum(rgb @ _SEPIA_MATRIX.T, 255).astype(np.uint8)
    result = Image.fromarray(toned)
    if "A" in img.getbands():
        result.putalpha(img.getchannel("A"))
    return result


def apply_image_filter(image, filter_type):
    """Apply one of the named filters to an image.

    Args:
        image: Raw image bytes, or a path / file-like object accepted by
            ``PIL.Image.open``.
        filter_type: "Grayscale", "Sepia", "High Contrast" or "Blur"; any other
            value returns the opened image unchanged.

    Returns:
        The filtered PIL image, or None (after showing an error) on failure.
    """
    try:
        img = Image.open(io.BytesIO(image)) if isinstance(image, bytes) else Image.open(image)
        if filter_type == "Grayscale":
            return img.convert('L')
        if filter_type == "Sepia":
            return _sepia(img)
        if filter_type == "High Contrast":
            return img.point(lambda x: x * 1.5)
        if filter_type == "Blur":
            # The BLUR kernel lives in PIL.ImageFilter, not PIL.Image.
            return img.filter(ImageFilter.BLUR)
        return img
    except Exception as e:
        st.error(f"Error applying filter: {str(e)}")
        return None


def _record_generated_image(url, prompt, source):
    """Append one result record to the shared gallery list in session state."""
    if 'generated_images' not in st.session_state:
        st.session_state.generated_images = []
    st.session_state.generated_images.append({
        'url': url,
        'prompt': prompt,
        'source': source,
        'timestamp': time.time(),
    })


def check_generated_images():
    """Poll the pending image URLs once and promote the ones that are ready.

    A URL counts as ready when a HEAD request (following redirects) returns
    HTTP 200. Ready URLs leave ``pending_urls``. The first one becomes
    ``edited_image``, and each is appended to ``generated_images`` as a
    gallery record; the existing list is never replaced.

    Returns:
        True if at least one image became ready on this check.
    """
    pending = st.session_state.pending_urls
    if not pending:
        return False

    ready, still_pending = [], []
    for url in pending:
        try:
            response = requests.head(url, timeout=POLL_TIMEOUT, allow_redirects=True)
        except requests.RequestException:
            still_pending.append(url)
            continue
        (ready if response.status_code == 200 else still_pending).append(url)

    st.session_state.pending_urls = still_pending
    if not ready:
        return False

    st.session_state.edited_image = ready[0]  # Display the first ready image
    info = st.session_state.get('pending_info') or {}
    known = {img.get('url') for img in st.session_state.generated_images if isinstance(img, dict)}
    for url in ready:
        if url not in known:
            _record_generated_image(url, info.get('prompt', ''), info.get('source', 'async'))
    return True


def auto_check_images(status_container, max_attempts=3, interval=2.0):
    """Poll for pending images a few times, waiting `interval` seconds before each check.

    Returns:
        True (after showing a success message) as soon as an image is ready.
        False when the attempts run out or nothing is pending.
    """
    attempt = 0
    while attempt < max_attempts and st.session_state.pending_urls:
        time.sleep(interval)
        if check_generated_images():
            status_container.success("✨ Image ready!")
            return True
        attempt += 1
    return False


def _rerun():
    """Rerun the script via `st.rerun`, falling back to `st.experimental_rerun` on older Streamlit."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _debug_write(label, value):
    """Show a debug value in the UI only when ADSNAP_DEBUG is enabled."""
    if DEBUG:
        st.write(label, value)


def _extract_urls(result):
    """Collect image URLs from a service response, in order and without duplicates.

    Recognised shapes:
      {"result": [{"urls": [...]}]} / {"result": [{"url": ...}]}
      {"result": {"urls": [...]}}   / {"result": {"url": ...}}
      {"result": [[url, ...], ...]}  (only http(s) strings are taken from these
                                      inner lists, so values returned next to
                                      the URL are not mistaken for URLs)
      {"result_url": ...}, {"url": ...}, {"result_urls": [...]}, {"urls": [...]}
    Returns [] when `result` is not a dict.
    """
    if not isinstance(result, dict):
        return []

    urls = []
    nested = result.get('result')
    if isinstance(nested, dict):
        nested = [nested]
    if isinstance(nested, list):
        for item in nested:
            if isinstance(item, dict):
                if item.get('urls'):
                    urls.extend(item['urls'])
                elif item.get('url'):
                    urls.append(item['url'])
            elif isinstance(item, (list, tuple)):
                urls.extend(
                    value for value in item
                    if isinstance(value, str) and value.startswith(("http://", "https://"))
                )
    for key in ('result_url', 'url'):
        if result.get(key):
            urls.append(result[key])
    for key in ('result_urls', 'urls'):
        if isinstance(result.get(key), (list, tuple)):
            urls.extend(result[key])
    return list(dict.fromkeys(urls))


def _handle_generation_result(result, sync_mode, num_results, source, description, success_message):
    """Display or queue the images returned by a lifestyle / generative-fill request.

    Sync mode: the first URL becomes ``edited_image`` and every returned URL is
    added to the gallery. Async mode: the URLs are stored in ``pending_urls``
    and polled a few times. The refresh button rendered next to the result
    keeps polling on later reruns.
    """
    urls = _extract_urls(result)[:num_results]
    if not urls:
        st.warning("No image URLs found in the API response. Please try again.")
        return

    if sync_mode:
        st.session_state.edited_image = urls[0]
        for url in urls:
            _record_generated_image(url, description, source)
        st.success(success_message)
        return

    st.session_state.pending_urls = urls
    st.session_state.pending_info = {'source': source, 'prompt': description}
    status_container = st.empty()
    plural = 's' if len(urls) > 1 else ''
    status_container.info(f"🎨 Generation started! Waiting for {len(urls)} image{plural}...")
    if auto_check_images(status_container):
        _rerun()


def _render_refresh_button(key):
    """Render a button that polls pending images.

    The button is rendered on every run so that its click is seen on the rerun
    it triggers.
    """
    if not st.button("🔄 Check for Generated Images", key=key):
        return
    with st.spinner("Checking for completed images..."):
        ready = check_generated_images()
    if ready:
        st.success("✨ Image ready!")
        _rerun()
    else:
        count = len(st.session_state.pending_urls)
        st.warning(f"⏳ Still generating your image{'s' if count > 1 else ''}... Please check again in a moment.")


def _render_result(caption, file_name, download_key, pending_message=None, refresh_key=None):
    """Show ``edited_image`` with a download button, or the pending-status controls."""
    if st.session_state.edited_image:
        st.image(st.session_state.edited_image, caption=caption, use_column_width=True)
        image_data = download_image(st.session_state.edited_image)
        if image_data:
            st.download_button("⬇️ Download Result", image_data, file_name, "image/png", key=download_key)
    elif st.session_state.pending_urls and pending_message:
        st.info(pending_message)
    if st.session_state.pending_urls and refresh_key:
        _render_refresh_button(refresh_key)


def _gallery_items(sources):
    """Return the gallery records whose 'source' is one of `sources`, in insertion order."""
    return [
        img for img in st.session_state.get('generated_images', [])
        if isinstance(img, dict) and img.get('source') in sources
    ]


def _render_gallery(images, title, item_label, text_label, key_prefix, file_prefix=None, show_type=False):
    """Render gallery records as expanders with a preview, description and download button.

    `file_prefix` defaults to each record's 'source' when building the file name.
    """
    st.subheader(title)
    for i, image_data in enumerate(images):
        with st.expander(f"{item_label} {i+1} - {image_data['prompt'][:50]}..."):
            preview_col, info_col = st.columns([2, 1])
            with preview_col:
                st.image(image_data['url'], caption="Generated Image", use_column_width=True)
            with info_col:
                st.markdown(f"**{text_label}:**")
                st.text(image_data['prompt'])
                if show_type:
                    st.markdown(f"**Type:** {image_data['source'].title()}")
                file_name = f"{file_prefix or image_data['source']}_image_{i+1}.png"
                try:
                    response = requests.get(image_data['url'], timeout=DOWNLOAD_TIMEOUT)
                    if response.status_code == 200:
                        st.download_button(
                            label="📥 Download Image",
                            data=response.content,
                            file_name=file_name,
                            mime="image/png",
                            key=f"{key_prefix}_{i}",
                        )
                    else:
                        st.warning("Image not yet ready for download")
                except requests.RequestException as e:
                    st.error(f"Download error: {str(e)}")


def _generate_images(prompt, num_images, aspect_ratio, enhance_img, style):
    """Call generate_hd_image and add every returned URL to the text gallery."""
    if not st.session_state.api_key:
        st.error("Please enter your API key in the sidebar.")
        return
    base_prompt = st.session_state.enhanced_prompt or prompt
    if not base_prompt:
        st.warning("Please enter a prompt to generate images.")
        return
    # Apply the style to whichever prompt is actually sent (original or enhanced).
    final_prompt = base_prompt if style == "Realistic" else f"{base_prompt}, in {style.lower()} style"

    with st.spinner("🎨 Generating your masterpiece..."):
        try:
            result = generate_hd_image(
                prompt=final_prompt,
                api_key=st.session_state.api_key,
                num_results=num_images,
                aspect_ratio=aspect_ratio,  # Already in correct format (e.g. "1:1")
                sync=True,  # Wait for results
                enhance_image=enhance_img,
                medium="art" if style != "Realistic" else "photography",
                prompt_enhancement=False,  # We're already using our own prompt enhancement
                content_moderation=True,  # Enable content moderation by default
            )
        except Exception as e:
            st.error(f"Error generating images: {str(e)}")
            _debug_write("Full error:", str(e))
            return

    urls = _extract_urls(result)[:num_images]
    if not urls:
        st.error("❌ Failed to generate image. Please try again.")
        return
    for url in urls:
        _record_generated_image(url, final_prompt, 'text')
    st.success("✨ Image generated successfully!")


def _render_generate_tab():
    """Text-to-image tab: prompt (optionally enhanced), generation options and gallery."""
    st.header("Generate Images")
    col1, col2 = st.columns([2, 1])

    with col1:
        prompt = st.text_area("Enter your prompt", value="", height=100, key="prompt_input")
        # An enhanced prompt only applies to the text it was derived from.
        if prompt != st.session_state.original_prompt:
            st.session_state.original_prompt = prompt
            st.session_state.enhanced_prompt = None

        if st.session_state.get('enhanced_prompt'):
            st.markdown("**Enhanced Prompt:**")
            st.markdown(f"*{st.session_state.enhanced_prompt}*")

        if st.button("✨ Enhance Prompt", key="enhance_button"):
            if not prompt:
                st.warning("Please enter a prompt to enhance.")
            else:
                with st.spinner("Enhancing prompt..."):
                    try:
                        result = enhance_prompt(st.session_state.api_key, prompt)
                        if result:
                            st.session_state.enhanced_prompt = result
                            st.success("Prompt enhanced!")
                            _rerun()  # Rerun to update the display
                    except Exception as e:
                        st.error(f"Error enhancing prompt: {str(e)}")

        _debug_write("Debug - Session State:", {
            "original_prompt": st.session_state.get("original_prompt"),
            "enhanced_prompt": st.session_state.get("enhanced_prompt"),
        })

    with col2:
        num_images = st.slider("Number of images", 1, 4, 1)
        aspect_ratio = st.selectbox("Aspect ratio", ["1:1", "16:9", "9:16", "4:3", "3:4"])
        enhance_img = st.checkbox("Enhance image quality", value=True)

        st.subheader("Style Options")
        style = st.selectbox("Image Style", [
            "Realistic", "Artistic", "Cartoon", "Sketch",
            "Watercolor", "Oil Painting", "Digital Art",
        ])

        if st.button("🎨 Generate Images", type="primary"):
            _generate_images(prompt, num_images, aspect_ratio, enhance_img, style)

    text_images = _gallery_items(('text',))
    if text_images:
        _render_gallery(
            text_images, "🎨 Generated Images", "Generated Image", "Original Prompt",
            key_prefix="text_gallery_download", file_prefix="generated",
        )


def _packshot_source_image(uploaded_file, force_rmbg, content_moderation):
    """Return the bytes to send to create_packshot.

    When `force_rmbg` is set, the background is removed first. Returns None
    (after showing an error) if that step fails.
    """
    if not force_rmbg:
        return uploaded_file.getvalue()
    try:
        # Try the same `src.services` package used by the rest of this file;
        # fall back to the legacy top-level path.
        from src.services.background_service import remove_background
    except ImportError:
        from services.background_service import remove_background
    bg_result = remove_background(
        st.session_state.api_key,
        uploaded_file.getvalue(),
        content_moderation=content_moderation,
    )
    if not (bg_result and "result_url" in bg_result):
        st.error("Background removal failed")
        return None
    response = requests.get(bg_result["result_url"], timeout=DOWNLOAD_TIMEOUT)
    if response.status_code != 200:
        st.error("Failed to download background-removed image")
        return None
    return response.content


def _render_packshot_options(uploaded_file):
    """Packshot controls; the result is added to the product gallery."""
    col_a, col_b = st.columns(2)
    with col_a:
        bg_color = st.color_picker("Background Color", "#FFFFFF")
        sku = st.text_input("SKU (optional)", "")
    with col_b:
        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

    if not st.button("Create Packshot"):
        return
    with st.spinner("Creating professional packshot..."):
        try:
            image_data = _packshot_source_image(uploaded_file, force_rmbg, content_moderation)
            if image_data is None:
                return
            result = create_packshot(
                st.session_state.api_key,
                image_data,
                background_color=bg_color,
                sku=sku if sku else None,
                force_rmbg=force_rmbg,
                content_moderation=content_moderation,
            )
            if result and "result_url" in result:
                _record_generated_image(
                    result["result_url"],
                    f"Packshot with {bg_color} background" + (f" - SKU: {sku}" if sku else ""),
                    'packshot',
                )
                st.success("✨ Packshot created successfully!")
            else:
                st.error("No result URL in the API response. Please try again.")
        except Exception as e:
            st.error(f"Error creating packshot: {str(e)}")
            if "422" in str(e):
                st.warning("Content moderation failed. Please ensure the image is appropriate.")


def _render_shadow_options(uploaded_file):
    """Shadow controls; the result is added to the product gallery."""
    col_a, col_b = st.columns(2)
    with col_a:
        shadow_type = st.selectbox("Shadow Type", ["Natural", "Drop"])
        bg_color = st.color_picker("Background Color (optional)", "#FFFFFF")
        use_transparent_bg = st.checkbox("Use Transparent Background", True)
        shadow_color = st.color_picker("Shadow Color", "#000000")
        sku = st.text_input("SKU (optional)", "")

        st.subheader("Shadow Offset")
        offset_x = st.slider("X Offset", -50, 50, 0)
        offset_y = st.slider("Y Offset", -50, 50, 15)

    # Values sent when the float-shadow controls are not shown.
    shadow_width, shadow_height = None, 70
    with col_b:
        shadow_intensity = st.slider("Shadow Intensity", 0, 100, 60)
        # NOTE(review): "regular" and "Float" never match the offered options
        # ("Natural"/"Drop"), so the blur default is always 20 and the float
        # controls never render. Confirm the shadow type names the service expects.
        shadow_blur = st.slider("Shadow Blur", 0, 50, 15 if shadow_type.lower() == "regular" else 20)
        if shadow_type == "Float":
            st.subheader("Float Shadow Settings")
            shadow_width = st.slider("Shadow Width", -100, 100, 0)
            shadow_height = st.slider("Shadow Height", -100, 100, 70)
        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

    if not st.button("Add Shadow"):
        return
    with st.spinner("Adding shadow effect..."):
        try:
            result = add_shadow(
                api_key=st.session_state.api_key,
                image_data=uploaded_file.getvalue(),
                shadow_type=shadow_type.lower(),
                background_color=None if use_transparent_bg else bg_color,
                shadow_color=shadow_color,
                shadow_offset=[offset_x, offset_y],
                shadow_intensity=shadow_intensity,
                shadow_blur=shadow_blur,
                shadow_width=shadow_width,
                shadow_height=shadow_height,
                sku=sku if sku else None,
                force_rmbg=force_rmbg,
                content_moderation=content_moderation,
            )
            if result and "result_url" in result:
                _record_generated_image(
                    result["result_url"],
                    f"{shadow_type} shadow effect" + (f" - SKU: {sku}" if sku else ""),
                    'shadow',
                )
                st.success("✨ Shadow added successfully!")
            else:
                st.error("No result URL in the API response. Please try again.")
        except Exception as e:
            st.error(f"Error adding shadow: {str(e)}")
            if "422" in str(e):
                st.warning("Content moderation failed. Please ensure the image is appropriate.")


def _run_lifestyle(service, kwargs, description):
    """Call a lifestyle-shot service and hand its response to _handle_generation_result."""
    with st.spinner("Generating lifestyle shot..."):
        try:
            result = service(**kwargs)
        except Exception as e:
            st.error(f"Error: {str(e)}")
            if "422" in str(e):
                st.warning("Content moderation failed. Please ensure the content is appropriate.")
            return
    if result:
        _debug_write("Debug - Raw API Response:", result)
        _handle_generation_result(
            result, kwargs["sync"], kwargs["num_results"], 'lifestyle', description,
            "✨ Image generated successfully!",
        )


def _render_lifestyle_options(uploaded_file):
    """Lifestyle-shot controls (text prompt or reference image) and generation."""
    shot_type = st.radio("Shot Type", ["Text Prompt", "Reference Image"])

    # Defaults for inputs that only some placement / shot types expose.
    positions = ["Upper Left"]
    pad_left = pad_right = pad_top = pad_bottom = 0
    shot_width = shot_height = 1000
    fg_width = fg_height = fg_x = fg_y = 0
    fast_mode = optimize_desc = True
    exclude_elements = None
    enhance_ref, ref_influence = True, 1.0

    # Named distinctly so the product tab's own result column is not shadowed.
    opt_left, opt_right = st.columns(2)
    with opt_left:
        placement_type = st.selectbox("Placement Type", [
            "Original", "Automatic", "Manual Placement",
            "Manual Padding", "Custom Coordinates",
        ])
        num_results = st.slider("Number of Results", 1, 8, 4)
        sync_mode = st.checkbox("Synchronous Mode", False,
                                help="Wait for results instead of getting URLs immediately")
        original_quality = st.checkbox("Original Quality", False,
                                       help="Maintain original image quality")

        if placement_type == "Manual Placement":
            positions = st.multiselect("Select Positions", [
                "Upper Left", "Upper Right", "Bottom Left", "Bottom Right",
                "Right Center", "Left Center", "Upper Center",
                "Bottom Center", "Center Vertical", "Center Horizontal",
            ], ["Upper Left"])
        elif placement_type == "Manual Padding":
            st.subheader("Padding Values (pixels)")
            pad_left = st.number_input("Left Padding", 0, 1000, 0)
            pad_right = st.number_input("Right Padding", 0, 1000, 0)
            pad_top = st.number_input("Top Padding", 0, 1000, 0)
            pad_bottom = st.number_input("Bottom Padding", 0, 1000, 0)

        # Separate check (not elif) so "Manual Placement" also gets its shot size.
        if placement_type in ("Automatic", "Manual Placement", "Custom Coordinates"):
            st.subheader("Shot Size")
            shot_width = st.number_input("Width", 100, 2000, 1000)
            shot_height = st.number_input("Height", 100, 2000, 1000)

    with opt_right:
        if placement_type == "Custom Coordinates":
            st.subheader("Product Position")
            fg_width = st.number_input("Product Width", 50, 1000, 500)
            fg_height = st.number_input("Product Height", 50, 1000, 500)
            fg_x = st.number_input("X Position", -500, 1500, 0)
            fg_y = st.number_input("Y Position", -500, 1500, 0)

        sku = st.text_input("SKU (optional)")
        force_rmbg = st.checkbox("Force Background Removal", False)
        content_moderation = st.checkbox("Enable Content Moderation", False)

        if shot_type == "Text Prompt":
            fast_mode = st.checkbox("Fast Mode", True,
                                    help="Balance between speed and quality")
            optimize_desc = st.checkbox("Optimize Description", True,
                                        help="Enhance scene description using AI")
            if not fast_mode:
                exclude_elements = st.text_area("Exclude Elements (optional)",
                                                help="Elements to exclude from the generated scene")
        else:
            enhance_ref = st.checkbox("Enhance Reference Image", True,
                                      help="Improve lighting, shadows, and texture")
            ref_influence = st.slider("Reference Influence", 0.0, 1.0, 1.0,
                                      help="Control similarity to reference image")

    is_custom = placement_type == "Custom Coordinates"
    # Arguments shared by lifestyle_shot_by_text and lifestyle_shot_by_image.
    common = dict(
        api_key=st.session_state.api_key,
        image_data=uploaded_file.getvalue(),
        placement_type=placement_type.lower().replace(" ", "_"),
        num_results=num_results,
        sync=sync_mode,
        shot_size=[shot_width, shot_height],
        original_quality=original_quality,
        manual_placement_selection=[p.lower().replace(" ", "_") for p in positions],
        padding_values=[pad_left, pad_right, pad_top, pad_bottom],
        foreground_image_size=[fg_width, fg_height] if is_custom else None,
        foreground_image_location=[fg_x, fg_y] if is_custom else None,
        force_rmbg=force_rmbg,
        content_moderation=content_moderation,
        sku=sku if sku else None,
    )

    if shot_type == "Text Prompt":
        prompt = st.text_area("Describe the environment")
        if st.button("Generate Lifestyle Shot") and prompt:
            kwargs = dict(
                common,
                scene_description=prompt,
                fast=fast_mode,
                optimize_description=optimize_desc,
                exclude_elements=exclude_elements if not fast_mode else None,
            )
            _run_lifestyle(lifestyle_shot_by_text, kwargs, prompt)
    else:
        ref_image = st.file_uploader("Upload Reference Image", type=["png", "jpg", "jpeg"], key="ref_upload")
        if st.button("Generate Lifestyle Shot") and ref_image:
            kwargs = dict(
                common,
                reference_image=ref_image.getvalue(),
                enhance_ref_image=enhance_ref,
                ref_image_influence=ref_influence,
            )
            _run_lifestyle(lifestyle_shot_by_image, kwargs, "Lifestyle shot from reference image")


def _render_product_tab():
    """Product photography tab: packshot, shadow and lifestyle shot, plus gallery."""
    st.header("Product Photography")
    uploaded_file = st.file_uploader("Upload Product Image", type=["png", "jpg", "jpeg"], key="product_upload")

    if uploaded_file:
        col1, col2 = st.columns(2)
        with col1:
            st.image(uploaded_file, caption="Original Image", use_column_width=True)
            edit_option = st.selectbox("Select Edit Option", [
                "Create Packshot", "Add Shadow", "Lifestyle Shot",
            ])
            if edit_option == "Create Packshot":
                _render_packshot_options(uploaded_file)
            elif edit_option == "Add Shadow":
                _render_shadow_options(uploaded_file)
            else:
                _render_lifestyle_options(uploaded_file)
        with col2:
            _render_result(
                "Edited Image", "edited_product.png", "product_download",
                pending_message="Images are being generated. Click the refresh button below to check if they're ready.",
                refresh_key="product_refresh",
            )

    product_images = _gallery_items(('packshot', 'shadow', 'lifestyle'))
    if product_images:
        _render_gallery(
            product_images, "🎨 Generated Product Images", "Product Image", "Description",
            key_prefix="product_gallery_download", show_type=True,
        )


def _prepare_canvas_image(uploaded_file, max_width=800):
    """Open an upload and scale it for the drawing canvas, keeping the aspect ratio.

    Returns:
        (rgb_image, canvas_width, canvas_height, original_size), where
        original_size is the (width, height) of the uploaded image.
    """
    img = Image.open(uploaded_file)
    original_size = img.size
    img_width, img_height = original_size
    aspect_ratio = img_height / img_width
    canvas_width = min(img_width, max_width)
    canvas_height = max(1, int(canvas_width * aspect_ratio))
    img = img.resize((canvas_width, canvas_height))
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img, canvas_width, canvas_height, original_size


def _has_strokes(canvas_result):
    """True when the RGBA canvas contains at least one painted (non-transparent) pixel."""
    data = canvas_result.image_data
    return data is not None and bool(np.asarray(data)[..., 3].any())


def _canvas_mask_png(canvas_result, size):
    """Encode the painted canvas pixels as a black/white PNG mask of `size`.

    Pixels with non-zero alpha (anything the user drew, whatever the brush
    colour) become white and the rest black. The canvas shows a scaled
    preview, so the mask is resized back to the full image `size`
    (width, height) with nearest-neighbour sampling to keep it binary.
    """
    alpha = np.asarray(canvas_result.image_data)[..., 3]
    mask = Image.fromarray(np.where(alpha > 0, 255, 0).astype(np.uint8))
    mask = mask.resize(size, Image.NEAREST)
    buffer = io.BytesIO()
    mask.save(buffer, format='PNG')
    return buffer.getvalue()


def _run_generative_fill(uploaded_file, canvas_result, image_size, prompt, negative_prompt,
                         num_results, sync_mode, seed, content_moderation):
    """Validate inputs, call generative_fill and hand the response to _handle_generation_result."""
    if not prompt:
        st.error("Please enter a prompt describing what to generate.")
        return
    if not _has_strokes(canvas_result):
        st.error("Please draw a mask on the image first.")
        return

    mask_bytes = _canvas_mask_png(canvas_result, image_size)
    image_bytes = uploaded_file.getvalue()

    with st.spinner("🎨 Generating..."):
        try:
            result = generative_fill(
                st.session_state.api_key,
                image_bytes,
                mask_bytes,
                prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                num_results=num_results,
                sync=sync_mode,
                seed=seed if seed != 0 else None,
                content_moderation=content_moderation,
            )
        except Exception as e:
            st.error(f"Error: {str(e)}")
            _debug_write("Full error details:", str(e))
            return

    if result:
        _debug_write("Debug - API Response:", result)
        _handle_generation_result(result, sync_mode, num_results, 'fill', prompt, "✨ Generation complete!")


def _render_fill_tab():
    """Generative fill tab: draw a mask, describe the fill, show the result."""
    st.header("🎨 Generative Fill")
    st.markdown("Draw a mask on the image and describe what you want to generate in that area.")
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"], key="fill_upload")
    if not uploaded_file:
        return

    col1, col2 = st.columns(2)
    with col1:
        st.image(uploaded_file, caption="Original Image", use_column_width=True)
        img, canvas_width, canvas_height, original_size = _prepare_canvas_image(uploaded_file)

        stroke_width = st.slider("Brush width", 1, 50, 20)
        stroke_color = st.color_picker("Brush color", "#fff")
        canvas_result = st_canvas(
            fill_color="rgba(255, 255, 255, 0.0)",  # Transparent fill
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            drawing_mode="freedraw",
            background_color="",  # Transparent background
            background_image=img,  # Already converted to RGB
            height=canvas_height,
            width=canvas_width,
            key="canvas",
        )

        st.subheader("Generation Options")
        prompt = st.text_area("Describe what to generate in the masked area")
        negative_prompt = st.text_area("Describe what to avoid (optional)")

        col_a, col_b = st.columns(2)
        with col_a:
            num_results = st.slider("Number of variations", 1, 4, 1)
            sync_mode = st.checkbox("Synchronous Mode", False,
                                    help="Wait for results instead of getting URLs immediately",
                                    key="gen_fill_sync_mode")
        with col_b:
            seed = st.number_input("Seed (optional)", min_value=0, value=0,
                                   help="Use same seed to reproduce results")
            content_moderation = st.checkbox("Enable Content Moderation", False,
                                             key="gen_fill_content_mod")

        if st.button("🎨 Generate", type="primary"):
            _run_generative_fill(
                uploaded_file, canvas_result, original_size, prompt, negative_prompt,
                num_results, sync_mode, seed, content_moderation,
            )

    with col2:
        _render_result(
            "Generated Result", "generated_fill.png", "fill_download",
            pending_message="Generation in progress. Click the refresh button below to check status.",
            refresh_key="fill_refresh",
        )


def _run_erase(uploaded_file, canvas_result, content_moderation):
    """Call erase_foreground for the uploaded image and store the result URL."""
    if not _has_strokes(canvas_result):
        st.warning("Please draw on the image to select the area to erase.")
        return
    with st.spinner("Erasing selected area..."):
        try:
            # NOTE(review): the drawn selection is not sent. erase_foreground
            # receives only the image; confirm whether the service can accept
            # a mask built with _canvas_mask_png.
            result = erase_foreground(
                st.session_state.api_key,
                image_data=uploaded_file.getvalue(),
                content_moderation=content_moderation,
            )
        except Exception as e:
            st.error(f"Error: {str(e)}")
            if "422" in str(e):
                st.warning("Content moderation failed. Please ensure the image is appropriate.")
            return
    if result:
        if "result_url" in result:
            st.session_state.edited_image = result["result_url"]
            st.success("✨ Area erased successfully!")
        else:
            st.error("No result URL in the API response. Please try again.")


def _render_erase_tab():
    """Erase elements tab: canvas selection, erase request and result."""
    st.header("🎨 Erase Elements")
    st.markdown("Upload an image and select the area you want to erase.")
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"], key="erase_upload")
    if not uploaded_file:
        return

    col1, col2 = st.columns(2)
    with col1:
        st.image(uploaded_file, caption="Original Image", use_column_width=True)
        img, canvas_width, canvas_height, _ = _prepare_canvas_image(uploaded_file)

        stroke_width = st.slider("Brush width", 1, 50, 20, key="erase_brush_width")
        stroke_color = st.color_picker("Brush color", "#fff", key="erase_brush_color")
        canvas_result = st_canvas(
            fill_color="rgba(255, 255, 255, 0.0)",  # Transparent fill
            stroke_width=stroke_width,
            stroke_color=stroke_color,
            background_color="",  # Transparent background
            background_image=img,  # Pass PIL Image directly
            drawing_mode="freedraw",
            height=canvas_height,
            width=canvas_width,
            key="erase_canvas",
        )

        st.subheader("Erase Options")
        content_moderation = st.checkbox("Enable Content Moderation", False, key="erase_content_mod")

        if st.button("🎨 Erase Selected Area", key="erase_btn"):
            _run_erase(uploaded_file, canvas_result, content_moderation)

    with col2:
        _render_result("Result", "erased_image.png", "erase_download")


def main():
    """Render the app.

    Each tab is drawn by its own helper, so an early `return` inside one tab
    no longer stops the remaining tabs from rendering.
    """
    st.title("AdSnap Studio")
    initialize_session_state()

    # Sidebar for API key
    with st.sidebar:
        st.header("Settings")
        api_key = st.text_input("Enter your API key:",
                                value=st.session_state.api_key or "",
                                type="password")
        if api_key:
            st.session_state.api_key = api_key

    tabs = st.tabs([
        "🎨 Generate Image",
        "🎤 Voice to Image",
        "🖼️ Lifestyle Shot",
        "🎨 Generative Fill",
        "🎨 Erase Elements",
    ])
    with tabs[0]:
        _render_generate_tab()
    with tabs[1]:
        render_voice_to_image_section()
    with tabs[2]:
        _render_product_tab()
    with tabs[3]:
        _render_fill_tab()
    with tabs[4]:
        _render_erase_tab()


if __name__ == "__main__":
    main()