Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| from image_generation import generate_image | |
| from virtual_tryon import ( | |
| apply_virtual_tryon, | |
| VALID_CLOTH_TYPES, | |
| VALID_IMAGE_SIZES, | |
| DEFAULT_IMAGE_SIZE, | |
| DEFAULT_NUM_STEPS, | |
| DEFAULT_GUIDANCE_SCALE, | |
| DEFAULT_SEED | |
| ) | |
| import io | |
| import urllib3 | |
# Silence urllib3's InsecureRequestWarning for the whole process.
# NOTE(review): this hides the warning urllib3 emits for unverified HTTPS
# requests; which call in this project triggers it is not visible here --
# confirm the suppression is still required.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

# Page-level Streamlit configuration, applied at import time (before main()).
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Virtual Try-On",
)
# ---------------------------------------------------------------------------
# Module constants
# ---------------------------------------------------------------------------

# Style keyword pre-filled in the sidebar's "Custom Style Keyword" field.
DEFAULT_LORA_KEYWORD = ""

# Largest signed 32-bit integer (2147483647); not referenced by the code
# visible in this file -- presumably an upper bound for random seeds.
MAX_SEED = 2**31 - 1

# Sample prompts: the first entry is the default text of the prompt boxes and
# the whole list is offered to the user as clickable examples.
EXAMPLE_PROMPTS = [
    "fashion product showcase closeup of a luxury red evening dress, studio lighting",
    "wide shot fashion model wearing designer denim jacket, urban street background",
    "closeup product photography of handcrafted leather handbag, neutral background",
    "full body fashion showcase of summer collection dress, beach setting",
    "detailed closeup of haute couture embroidered blouse, professional studio setup",
    "wide angle fashion editorial of winter coat ensemble, city backdrop",
    "product closeup of designer sneakers, minimalist background",
    "editorial wide shot of evening wear collection, elegant interior setting",
]
def _compose_prompt(prompt, keyword):
    """Return *prompt* prefixed with *keyword*, or unchanged when it is blank."""
    keyword = (keyword or "").strip()
    return f"{keyword} {prompt}" if keyword else prompt


def _png_bytes(image):
    """Encode *image* (anything with a PIL-style ``save``) as PNG bytes."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return buf.getvalue()


def _resize_to_fit(image, max_size):
    """Resize *image* so its longer side equals *max_size*, keeping proportions."""
    width, height = image.size
    aspect = float(width) / float(height)
    # max(1, ...) keeps the shorter side from truncating to zero pixels on
    # extremely elongated images.
    if aspect > 1:
        new_size = (max_size, max(1, int(max_size / aspect)))
    else:
        new_size = (max(1, int(max_size * aspect)), max_size)
    return image.resize(new_size, Image.Resampling.LANCZOS)


def _set_prompt(target_key, example):
    """Button callback: load *example* into the prompt widget keyed *target_key*."""
    st.session_state[target_key] = example
    st.session_state.current_prompt = example


def _prompt_input(key, **text_area_kwargs):
    """Render a "Generation Prompt" text area bound to session key *key*."""
    # The default is seeded through session state (not ``value=``) so that
    # _set_prompt can overwrite the text from a button callback.
    if key not in st.session_state:
        st.session_state[key] = EXAMPLE_PROMPTS[0]
    return st.text_area("Generation Prompt", height=100, key=key, **text_area_kwargs)


def _render_example_prompts(target_key):
    """Render one button per example prompt; a click fills *target_key*'s text area."""
    with st.expander("Example Prompts"):
        for idx, example in enumerate(EXAMPLE_PROMPTS):
            st.button(
                example[:50] + "...",
                key=f"{target_key}_example_{idx}",
                on_click=_set_prompt,
                args=(target_key, example),
            )


def _generation_settings(kind, aspect_options, aspect_help=None):
    """Render the steps / guidance / aspect widgets for *kind*.

    Returns a dict of keyword arguments (``num_steps``, ``guidance_scale``,
    ``aspect_ratio``) ready to be forwarded to ``generate_image``.
    """
    with st.expander("Generation Settings"):
        return {
            "num_steps": st.slider(
                "Steps", 20, 50, DEFAULT_NUM_STEPS, key=f"{kind}_steps"
            ),
            "guidance_scale": st.slider(
                "Guidance Scale", 1.0, 20.0, DEFAULT_GUIDANCE_SCALE,
                key=f"{kind}_guidance",
            ),
            "aspect_ratio": st.radio(
                "Aspect Ratio", aspect_options, help=aspect_help,
                key=f"{kind}_aspect",
            ),
        }


def _generate(prompt, replicate_api, lora_url, **settings):
    """Run ``generate_image`` and return the image, or None after showing an error.

    *settings* are forwarded to ``generate_image`` unchanged as keyword
    arguments.
    """
    if not replicate_api:
        st.error("Please provide your Replicate API key in the sidebar!")
        return None
    with st.spinner("Generating image..."):
        result_image, status = generate_image(
            prompt,
            replicate_api_key=replicate_api,
            lora_url=lora_url,
            **settings,
        )
    if not result_image:
        st.error(status)
        return None
    return result_image


def _render_image_upload(kind, help_text):
    """Render the upload widgets for *kind* ("person" or "garment").

    A successfully opened upload is optionally resized, previewed, and stored
    in ``st.session_state["<kind>_img"]``.
    """
    preview_container = st.empty()
    uploaded = st.file_uploader(
        f"Upload {kind} image",
        type=['png', 'jpg', 'jpeg'],
        key=f"{kind}_upload",
        help=help_text,
    )
    if uploaded is None:
        return
    try:
        img = Image.open(uploaded)
    except OSError:
        # PIL signals an unreadable or unrecognised file with an OSError subclass.
        st.error(f"Could not read the uploaded {kind} image.")
        return
    with st.expander("Image Settings"):
        if st.checkbox("Resize Image", value=True, key=f"{kind}_resize"):
            max_size = st.slider(
                "Max Image Size", 256, 1024, 512, key=f"{kind}_size"
            )
            img = _resize_to_fit(img, max_size)
    preview_container.image(
        img, caption=f"Uploaded {kind.capitalize()} Image", use_container_width=True
    )
    st.session_state[f"{kind}_img"] = img


def _render_latest_generation(image):
    """Show the most recent tab-1 image with its re-use and download actions."""
    st.image(image, caption="Generated Image", use_container_width=True)
    person_col, garment_col = st.columns(2)
    with person_col:
        if st.button("Use as Person Image"):
            st.session_state.person_img = image
            st.success("Set as person image!")
    with garment_col:
        if st.button("Use as Garment Image"):
            st.session_state.garment_img = image
            st.success("Set as garment image!")
    st.download_button(
        label="Download Image",
        data=_png_bytes(image),
        file_name="generated_image.png",
        mime="image/png",
    )


def main():
    """Render the Virtual Try-On Streamlit page.

    Layout:
      * sidebar -- Replicate / Kling credentials and LoRA options;
      * "Generate Images" tab -- prompt-driven generation, re-use/download of
        the latest image and a history of the last five results;
      * "Virtual Try-On" tab -- person image, garment image (each uploaded or
        generated) and the try-on result.

    State that must survive Streamlit reruns is kept in ``st.session_state``
    under ``generated_image``, ``generated_images``, ``person_img``,
    ``garment_img`` and ``result_image``.
    """
    st.title("Virtual Try-On System")
    st.markdown("Upload images or generate new ones to try on different clothing items!")

    if 'generated_images' not in st.session_state:
        st.session_state.generated_images = []

    # Sidebar for API keys and LoRA settings
    with st.sidebar:
        st.header("API Configuration")
        replicate_api = st.text_input("Replicate API Key", type="password")

        # Kling API Configuration
        st.subheader("Kling API Configuration")
        kling_access_key = st.text_input("Kling Access Key", type="password")
        kling_secret_key = st.text_input("Kling Secret Key", type="password")

        st.header("Generation Settings")
        # Defaults make these names defined on every path, so no locals() probing.
        custom_lora = ""
        custom_keyword = ""
        use_lora = st.checkbox("Use LoRA", value=True)
        if use_lora:
            custom_lora = st.text_input("Custom LoRA URL", value="")
            use_keyword = st.checkbox(
                "Add style keyword to prompts",
                value=True,
                help="Automatically add a style keyword to your prompts",
            )
            if use_keyword:
                custom_keyword = st.text_input(
                    "Custom Style Keyword",
                    value=DEFAULT_LORA_KEYWORD,
                    help="This keyword will be added to your prompts",
                )
    lora_url = custom_lora if use_lora else None

    tab1, tab2 = st.tabs(["Generate Images", "Virtual Try-On"])

    with tab1:
        st.header("Image Generation")
        settings_col, output_col = st.columns(2)

        with settings_col:
            st.subheader("Generation Settings")
            prompt = _prompt_input(
                "generation_prompt",
                help="Describe the image you want to generate",
            )
            _render_example_prompts("generation_prompt")

            with st.expander("Advanced Settings", expanded=True):
                num_steps = st.slider(
                    "Number of Steps",
                    min_value=20,
                    max_value=50,
                    value=DEFAULT_NUM_STEPS,
                    help="More steps = better quality but slower",
                )
                guidance = st.slider(
                    "Guidance Scale",
                    min_value=1.0,
                    max_value=20.0,
                    value=DEFAULT_GUIDANCE_SCALE,
                    help="How closely to follow the prompt",
                )
                aspect = st.radio(
                    "Aspect Ratio",
                    ["1:1", "16:9", "3:2", "2:3", "4:5", "5:4"],
                    help="Select the aspect ratio for the generated image",
                    key="generation_aspect",
                )
                negative_prompt = st.text_area(
                    "Negative Prompt",
                    value="ugly, blurry, low quality, distorted, deformed",
                    help="What to avoid in the generation",
                )

        with output_col:
            st.subheader("Generated Image")
            if st.button("Generate Image", type="primary"):
                result_image = _generate(
                    _compose_prompt(prompt, custom_keyword),
                    replicate_api,
                    lora_url,
                    num_steps=num_steps,
                    guidance_scale=guidance,
                    aspect_ratio=aspect,
                    negative_prompt=negative_prompt,
                )
                if result_image is not None:
                    st.session_state.generated_image = result_image
                    # Record history only when a new image is produced, not
                    # on every rerun of the script.
                    st.session_state.generated_images.append(result_image)

            # Rendered outside the button branch so the action buttons below
            # still exist on the rerun that their own click triggers.
            if 'generated_image' in st.session_state:
                _render_latest_generation(st.session_state.generated_image)

            # Display history of generated images (newest first, last five)
            history = st.session_state.generated_images
            if history:
                with st.expander("Generation History"):
                    for idx, img in enumerate(reversed(history[-5:])):
                        st.image(
                            img,
                            caption=f"Generated Image {len(history) - idx}",
                            use_container_width=True,
                        )

    with tab2:
        col_left, col_mid, col_right = st.columns(3)

        with col_left:
            st.header("Person Image")
            upload_method = st.radio(
                "Choose input method",
                ["Upload Image", "Generate Image"],
                key="person_method",
            )
            if upload_method == "Upload Image":
                _render_image_upload(
                    "person", "Upload a clear full-body photo of the person"
                )
            else:
                prompt = _prompt_input("person_prompt")
                _render_example_prompts("person_prompt")
                settings = _generation_settings(
                    "person",
                    ["1:1", "16:9", "3:2", "2:3", "4:5", "5:4"],
                    aspect_help="Select the aspect ratio for the generated image",
                )
                if st.button("Generate Person Image"):
                    result_image = _generate(
                        _compose_prompt(prompt, custom_keyword),
                        replicate_api,
                        lora_url,
                        **settings,
                    )
                    if result_image is not None:
                        st.session_state.person_img = result_image
                        st.image(result_image, caption="Generated Person Image")

        with col_mid:
            st.header("Garment Image")
            upload_method = st.radio(
                "Choose input method",
                ["Upload Image", "Generate Image"],
                key="garment_method",
            )
            if upload_method == "Upload Image":
                _render_image_upload(
                    "garment",
                    "Upload a clear photo of the garment on plain background",
                )
            else:
                prompt = _prompt_input("garment_prompt")
                settings = _generation_settings("garment", ["1:1", "16:9", "3:2"])
                if st.button("Generate Garment Image"):
                    # The style keyword is not added to garment prompts.
                    result_image = _generate(
                        prompt, replicate_api, lora_url, **settings
                    )
                    if result_image is not None:
                        st.session_state.garment_img = result_image
                        st.image(result_image, caption="Generated Garment Image")

        with col_right:
            st.header("Virtual Try-On Result")
            if st.button("Generate Try-On", type="primary"):
                if 'person_img' not in st.session_state:
                    st.error("Please provide a person image first!")
                elif 'garment_img' not in st.session_state:
                    st.error("Please provide a garment image first!")
                elif not kling_access_key or not kling_secret_key:
                    st.error("Please provide both Kling Access Key and Secret Key!")
                else:
                    with st.spinner("Generating try-on image..."):
                        result_image, status = apply_virtual_tryon(
                            st.session_state.person_img,
                            st.session_state.garment_img,
                            kling_access_key,
                            kling_secret_key,
                        )
                    if result_image:
                        st.session_state.result_image = result_image
                        st.success("Try-on completed successfully!")
                    else:
                        st.error(status)

            # Result display area: read from session state so the image and
            # its download button persist across reruns.
            if 'result_image' in st.session_state:
                st.image(
                    st.session_state.result_image,
                    caption="Try-on Result",
                    use_container_width=True,
                )
                st.download_button(
                    label="Download Result",
                    data=_png_bytes(st.session_state.result_image),
                    file_name="tryon_result.png",
                    mime="image/png",
                )


if __name__ == "__main__":
    main()