"""Streamlit front end for generating fashion images and running a virtual try-on.

The sidebar collects API credentials and LoRA options. The "Generate Images" tab
produces images with ``generate_image``; the "Virtual Try-On" tab combines a
person image and a garment image with ``apply_virtual_tryon``.
"""
import io
import os
from typing import NamedTuple, Optional

import numpy as np
import streamlit as st
import urllib3
from PIL import Image

from image_generation import generate_image
from virtual_tryon import (
    apply_virtual_tryon,
    VALID_CLOTH_TYPES,
    VALID_IMAGE_SIZES,
    DEFAULT_IMAGE_SIZE,
    DEFAULT_NUM_STEPS,
    DEFAULT_GUIDANCE_SCALE,
    DEFAULT_SEED,
)

# NOTE(review): this silences urllib3's InsecureRequestWarning process-wide,
# presumably because a helper module makes unverified HTTPS requests -- confirm
# this is intended rather than masking a TLS problem.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set page config (Streamlit requires this before any other st.* call).
st.set_page_config(
    page_title="Virtual Try-On",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Constants
MAX_SEED = 2147483647
DEFAULT_LORA_KEYWORD = ""
# Aspect ratios offered by every generation form.
ASPECT_RATIOS = ["1:1", "16:9", "3:2", "2:3", "4:5", "5:4"]
# Number of past generations kept for the "Generation History" expander.
HISTORY_LENGTH = 5
EXAMPLE_PROMPTS = [
    "fashion product showcase closeup of a luxury red evening dress, studio lighting",
    "wide shot fashion model wearing designer denim jacket, urban street background",
    "closeup product photography of handcrafted leather handbag, neutral background",
    "full body fashion showcase of summer collection dress, beach setting",
    "detailed closeup of haute couture embroidered blouse, professional studio setup",
    "wide angle fashion editorial of winter coat ensemble, city backdrop",
    "product closeup of designer sneakers, minimalist background",
    "editorial wide shot of evening wear collection, elegant interior setting",
]


class _SidebarSettings(NamedTuple):
    """Values collected from the sidebar on the current script run."""

    replicate_api: str
    kling_access_key: str
    kling_secret_key: str
    # The "Custom LoRA URL" text (may be "") when LoRA is enabled, else None.
    lora_url: Optional[str]
    # The style keyword when keyword prefixing is enabled, else None.
    keyword: Optional[str]


def _set_state(key, value):
    """Widget callback: store ``value`` under ``key`` in session state.

    Callbacks run before the next rerun builds the widgets, which is the only
    point where a widget's own key may be reassigned.
    """
    st.session_state[key] = value


def _png_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _build_prompt(prompt, keyword):
    """Return ``prompt`` prefixed by ``keyword`` when the keyword is non-blank.

    A blank or missing keyword leaves the prompt untouched instead of
    prepending a stray space.
    """
    keyword = (keyword or "").strip()
    return f"{keyword} {prompt}" if keyword else prompt


def _fit_within(img, max_size):
    """Resize ``img`` so its longer side is ``max_size``, keeping aspect ratio.

    Images smaller than ``max_size`` are scaled up too. The short side is
    clamped to at least 1 px so extreme aspect ratios cannot produce a
    zero-sized image.
    """
    width, height = img.size
    ratio = float(width) / float(height)
    if ratio > 1:
        new_size = (max_size, max(1, int(max_size / ratio)))
    else:
        new_size = (max(1, int(max_size * ratio)), max_size)
    return img.resize(new_size, Image.Resampling.LANCZOS)


def _generate(settings, prompt, num_steps, guidance, aspect, negative_prompt=None):
    """Call ``generate_image`` with the sidebar's LoRA and keyword settings applied.

    Returns what ``generate_image`` returns; callers unpack it as
    ``(image, status)`` and show ``status`` as the error when no image came back.
    """
    options = {
        "num_steps": num_steps,
        "guidance_scale": guidance,
        "aspect_ratio": aspect,
        "replicate_api_key": settings.replicate_api,
        "lora_url": settings.lora_url,
    }
    # Forward a negative prompt only when the form has one, so other forms keep
    # relying on generate_image's own default.
    if negative_prompt is not None:
        options["negative_prompt"] = negative_prompt
    return generate_image(_build_prompt(prompt, settings.keyword), **options)


def _render_sidebar():
    """Render the sidebar inputs and return them as a ``_SidebarSettings``."""
    with st.sidebar:
        st.header("API Configuration")
        replicate_api = st.text_input("Replicate API Key", type="password")

        # Kling API Configuration
        st.subheader("Kling API Configuration")
        kling_access_key = st.text_input("Kling Access Key", type="password")
        kling_secret_key = st.text_input("Kling Secret Key", type="password")

        st.header("Generation Settings")
        lora_url = None
        keyword = None
        if st.checkbox("Use LoRA", value=True):
            lora_url = st.text_input("Custom LoRA URL", value="")
            use_keyword = st.checkbox(
                "Add style keyword to prompts",
                value=True,
                help="Automatically add a style keyword to your prompts",
            )
            if use_keyword:
                keyword = st.text_input(
                    "Custom Style Keyword",
                    value=DEFAULT_LORA_KEYWORD,
                    help="This keyword will be added to your prompts",
                )

    return _SidebarSettings(
        replicate_api=replicate_api,
        kling_access_key=kling_access_key,
        kling_secret_key=kling_secret_key,
        lora_url=lora_url,
        keyword=keyword,
    )


def _render_prompt_input(key_prefix):
    """Render a prompt text area plus clickable example prompts; return the prompt.

    The text area's value lives in ``st.session_state[f"{key_prefix}_prompt"]``,
    so clicking an example actually replaces the prompt used on the next run.
    """
    state_key = f"{key_prefix}_prompt"
    if state_key not in st.session_state:
        st.session_state[state_key] = EXAMPLE_PROMPTS[0]
    prompt = st.text_area(
        "Generation Prompt",
        key=state_key,
        height=100,
        help="Describe the image you want to generate",
    )

    # Example prompts
    with st.expander("Example Prompts"):
        for idx, example in enumerate(EXAMPLE_PROMPTS):
            st.button(
                example[:50] + "...",
                key=f"{key_prefix}_example_{idx}",
                on_click=_set_state,
                args=(state_key, example),
            )
    return prompt


def _render_sampling_controls(key_prefix, title, expanded=False, with_negative_prompt=False):
    """Render step, guidance and aspect inputs (plus an optional negative prompt).

    Every widget is keyed with ``key_prefix`` so identical forms on the same
    page do not collide. Returns ``(num_steps, guidance, aspect, negative_prompt)``;
    ``negative_prompt`` is ``None`` unless ``with_negative_prompt`` is true.
    """
    negative_prompt = None
    with st.expander(title, expanded=expanded):
        num_steps = st.slider(
            "Number of Steps",
            min_value=20,
            max_value=50,
            value=int(DEFAULT_NUM_STEPS),
            key=f"{key_prefix}_steps",
            help="More steps = better quality but slower",
        )
        guidance = st.slider(
            "Guidance Scale",
            min_value=1.0,
            max_value=20.0,
            # Cast so an int default cannot clash with the float bounds.
            value=float(DEFAULT_GUIDANCE_SCALE),
            key=f"{key_prefix}_guidance",
            help="How closely to follow the prompt",
        )
        aspect = st.radio(
            "Aspect Ratio",
            ASPECT_RATIOS,
            key=f"{key_prefix}_aspect",
            help="Select the aspect ratio for the generated image",
        )
        if with_negative_prompt:
            negative_prompt = st.text_area(
                "Negative Prompt",
                value="ugly, blurry, low quality, distorted, deformed",
                key=f"{key_prefix}_negative",
                help="What to avoid in the generation",
            )
    return num_steps, guidance, aspect, negative_prompt


def _remember_generation(image):
    """Make ``image`` the current generation and append it to the bounded history."""
    st.session_state.generated_image = image
    st.session_state.generated_count = st.session_state.get("generated_count", 0) + 1
    history = st.session_state.get("generated_images", [])
    # Keep only the most recent entries so per-session memory stays bounded.
    st.session_state.generated_images = (history + [image])[-HISTORY_LENGTH:]


def _render_generation_tab(settings):
    """Render the "Generate Images" tab."""
    st.header("Image Generation")
    col_settings, col_output = st.columns(2)

    with col_settings:
        st.subheader("Generation Settings")
        prompt = _render_prompt_input("gen")
        num_steps, guidance, aspect, negative_prompt = _render_sampling_controls(
            "gen", "Advanced Settings", expanded=True, with_negative_prompt=True
        )

    with col_output:
        st.subheader("Generated Image")
        if st.button("Generate Image", type="primary", key="gen_generate"):
            if not settings.replicate_api:
                st.error("Please provide your Replicate API key in the sidebar!")
            else:
                with st.spinner("Generating image..."):
                    result_image, status = _generate(
                        settings, prompt, num_steps, guidance, aspect,
                        negative_prompt=negative_prompt,
                    )
                if result_image:
                    _remember_generation(result_image)
                else:
                    st.error(status)

        # Rendered on every run (not only in the click run) so the follow-up
        # buttons below can actually register their clicks.
        current = st.session_state.get("generated_image")
        if current is not None:
            st.image(current, caption="Generated Image", use_container_width=True)

            use_person, use_garment = st.columns(2)
            with use_person:
                if st.button("Use as Person Image", key="gen_use_person"):
                    st.session_state.person_img = current
                    st.success("Set as person image!")
            with use_garment:
                if st.button("Use as Garment Image", key="gen_use_garment"):
                    st.session_state.garment_img = current
                    st.success("Set as garment image!")

            st.download_button(
                label="Download Image",
                data=_png_bytes(current),
                file_name="generated_image.png",
                mime="image/png",
                key="gen_download",
            )

        # Display history of generated images (most recent first).
        history = st.session_state.get("generated_images", [])
        if history:
            total = st.session_state.get("generated_count", len(history))
            with st.expander("Generation History"):
                for offset, img in enumerate(reversed(history)):
                    st.image(
                        img,
                        caption=f"Generated Image {total - offset}",
                        use_container_width=True,
                    )


def _render_image_source(role, label, settings, upload_help):
    """Render the upload-or-generate column for one try-on input.

    ``role`` ("person" or "garment") prefixes every widget key and names the
    session-state slot ``f"{role}_img"`` consumed by the try-on step; ``label``
    is its display form ("Person" / "Garment").
    """
    state_key = f"{role}_img"
    st.header(f"{label} Image")
    method = st.radio(
        "Choose input method",
        ["Upload Image", "Generate Image"],
        key=f"{role}_method",
    )

    if method == "Upload Image":
        preview_container = st.empty()
        uploaded = st.file_uploader(
            f"Upload {role} image",
            type=["png", "jpg", "jpeg"],
            key=f"{role}_upload",
            help=upload_help,
        )
        if uploaded is not None:
            img = Image.open(uploaded)
            with st.expander("Image Settings"):
                if st.checkbox("Resize Image", value=True, key=f"{role}_resize"):
                    max_size = st.slider(
                        "Max Image Size", 256, 1024, 512, key=f"{role}_size"
                    )
                    img = _fit_within(img, max_size)
            preview_container.image(
                img, caption=f"Uploaded {label} Image", use_container_width=True
            )
            st.session_state[state_key] = img
        return

    prompt = _render_prompt_input(role)
    num_steps, guidance, aspect, _ = _render_sampling_controls(role, "Generation Settings")
    if st.button(f"Generate {label} Image", key=f"{role}_generate"):
        if not settings.replicate_api:
            st.error("Please provide your Replicate API key in the sidebar!")
        else:
            with st.spinner("Generating image..."):
                result_image, status = _generate(settings, prompt, num_steps, guidance, aspect)
            if result_image:
                st.session_state[state_key] = result_image
                st.session_state[f"{role}_generated_img"] = result_image
            else:
                st.error(status)

    generated = st.session_state.get(f"{role}_generated_img")
    if generated is not None:
        st.image(generated, caption=f"Generated {label} Image")


def _render_tryon_column(settings):
    """Render the try-on trigger, result image and result download."""
    st.header("Virtual Try-On Result")

    if st.button("Generate Try-On", type="primary"):
        if "person_img" not in st.session_state:
            st.error("Please provide a person image first!")
        elif "garment_img" not in st.session_state:
            st.error("Please provide a garment image first!")
        elif not settings.kling_access_key or not settings.kling_secret_key:
            st.error("Please provide both Kling Access Key and Secret Key!")
        else:
            with st.spinner("Generating try-on image..."):
                result_image, status = apply_virtual_tryon(
                    st.session_state.person_img,
                    st.session_state.garment_img,
                    settings.kling_access_key,
                    settings.kling_secret_key,
                )
            if result_image:
                st.session_state.result_image = result_image
                st.success("Try-on completed successfully!")
            else:
                st.error(status)

    # Result display area: shown on every run so the download stays available.
    if "result_image" in st.session_state:
        st.image(
            st.session_state.result_image,
            caption="Try-on Result",
            use_container_width=True,
        )
        st.download_button(
            label="Download Result",
            data=_png_bytes(st.session_state.result_image),
            file_name="tryon_result.png",
            mime="image/png",
            key="tryon_download",
        )


def main():
    """Render the Virtual Try-On app (Streamlit re-executes this on every interaction)."""
    st.title("Virtual Try-On System")
    st.markdown("Upload images or generate new ones to try on different clothing items!")

    settings = _render_sidebar()

    tab_generate, tab_tryon = st.tabs(["Generate Images", "Virtual Try-On"])
    with tab_generate:
        _render_generation_tab(settings)

    with tab_tryon:
        col_left, col_mid, col_right = st.columns(3)
        with col_left:
            _render_image_source(
                "person", "Person", settings,
                "Upload a clear full-body photo of the person",
            )
        with col_mid:
            _render_image_source(
                "garment", "Garment", settings,
                "Upload a clear photo of the garment on plain background",
            )
        with col_right:
            _render_tryon_column(settings)


if __name__ == "__main__":
    main()