File size: 17,230 Bytes
bca57e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1953a23
bca57e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1953a23
bca57e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
import streamlit as st
import os
from PIL import Image
import numpy as np
from image_generation import generate_image
from virtual_tryon import (
    apply_virtual_tryon,
    VALID_CLOTH_TYPES,
    VALID_IMAGE_SIZES,
    DEFAULT_IMAGE_SIZE,
    DEFAULT_NUM_STEPS,
    DEFAULT_GUIDANCE_SCALE,
    DEFAULT_SEED
)
import io
import urllib3
# NOTE(review): this silences urllib3's InsecureRequestWarning for the whole
# process, so any HTTPS request made without certificate verification
# (presumably inside one of the imported helper modules -- confirm) goes
# unreported. Prefer verifying certificates over hiding the warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Page-level Streamlit configuration; kept ahead of every other st.* call.
st.set_page_config(
    page_title="Virtual Try-On",
    layout="wide",
    initial_sidebar_state="expanded",
)

# --- Constants ---------------------------------------------------------------

# Largest signed 32-bit integer (2147483647). Presumably the upper bound for
# generation seeds -- it is not referenced by the code visible here.
MAX_SEED = (1 << 31) - 1

# Value pre-filled in the sidebar's "Custom Style Keyword" field.
DEFAULT_LORA_KEYWORD = ""

# Sample fashion prompts offered in the UI; the first one pre-fills the
# prompt text areas.
EXAMPLE_PROMPTS = [
    "fashion product showcase closeup of a luxury red evening dress, studio lighting",
    "wide shot fashion model wearing designer denim jacket, urban street background",
    "closeup product photography of handcrafted leather handbag, neutral background",
    "full body fashion showcase of summer collection dress, beach setting",
    "detailed closeup of haute couture embroidered blouse, professional studio setup",
    "wide angle fashion editorial of winter coat ensemble, city backdrop",
    "product closeup of designer sneakers, minimalist background",
    "editorial wide shot of evening wear collection, elegant interior setting",
]

# Aspect ratios offered by the general and person generators.
_ASPECT_RATIOS = ["1:1", "16:9", "3:2", "2:3", "4:5", "5:4"]
# The garment generator intentionally offers a shorter list.
_GARMENT_ASPECT_RATIOS = ["1:1", "16:9", "3:2"]
_ASPECT_HELP = "Select the aspect ratio for the generated image"
_MISSING_REPLICATE_KEY = "Please provide your Replicate API key in the sidebar!"


def _ensure_state(key, default):
    """Seed ``st.session_state[key]`` with *default* if it is not set yet."""
    if key not in st.session_state:
        st.session_state[key] = default


def _set_state(key, value):
    """Widget callback: store *value* under *key* in session state."""
    st.session_state[key] = value


def _apply_style_keyword(prompt, keyword):
    """Return *prompt* prefixed with the LoRA style *keyword*.

    A missing (``None``) or blank keyword leaves the prompt untouched, so the
    empty default keyword no longer yields a prompt with a stray leading space.
    """
    if keyword and keyword.strip():
        return f"{keyword.strip()} {prompt}"
    return prompt


def _png_bytes(image):
    """Encode *image* (a PIL image) as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _resize_to_max(img, max_size):
    """Resize *img* so its longer side equals *max_size*, keeping aspect ratio.

    The shorter side is clamped to at least 1 px: for extreme aspect ratios
    the truncated size could otherwise be 0, which PIL rejects.
    """
    aspect = float(img.size[0]) / float(img.size[1])
    if aspect > 1:
        new_size = (max_size, max(1, int(max_size / aspect)))
    else:
        new_size = (max(1, int(max_size * aspect)), max_size)
    return img.resize(new_size, Image.Resampling.LANCZOS)


def _render_sidebar():
    """Render the sidebar and return the API and generation settings.

    Returns:
        dict with keys:
          - ``replicate_api``, ``kling_access_key``, ``kling_secret_key``:
            the raw text inputs.
          - ``lora_url``: ``None`` when "Use LoRA" is unchecked, otherwise the
            "Custom LoRA URL" text (possibly empty).
          - ``style_keyword``: ``None`` when no keyword should be prepended,
            otherwise the "Custom Style Keyword" text.
    """
    lora_url = None
    style_keyword = None

    with st.sidebar:
        st.header("API Configuration")
        replicate_api = st.text_input("Replicate API Key", type="password")

        st.subheader("Kling API Configuration")
        kling_access_key = st.text_input("Kling Access Key", type="password")
        kling_secret_key = st.text_input("Kling Secret Key", type="password")

        st.header("Generation Settings")
        if st.checkbox("Use LoRA", value=True):
            lora_url = st.text_input("Custom LoRA URL", value="")
            use_keyword = st.checkbox(
                "Add style keyword to prompts",
                value=True,
                help="Automatically add a style keyword to your prompts",
            )
            if use_keyword:
                style_keyword = st.text_input(
                    "Custom Style Keyword",
                    value=DEFAULT_LORA_KEYWORD,
                    help="This keyword will be added to your prompts",
                )

    return {
        "replicate_api": replicate_api,
        "kling_access_key": kling_access_key,
        "kling_secret_key": kling_secret_key,
        "lora_url": lora_url,
        "style_keyword": style_keyword,
    }


def _render_example_prompts(target_key, key_prefix):
    """Render one button per example prompt that copies it into *target_key*.

    The copy happens in an ``on_click`` callback, which runs before the rerun
    re-creates the text area bound to *target_key*, so the chosen prompt
    actually appears in the widget. Rebinding a local variable after the widget
    was drawn (the previous approach) was lost on the next rerun.
    """
    with st.expander("Example Prompts"):
        for idx, example in enumerate(EXAMPLE_PROMPTS):
            st.button(
                example[:50] + "...",
                key=f"{key_prefix}_example_{idx}",
                on_click=_set_state,
                args=(target_key, example),
            )


def _render_generated_actions(image):
    """Show the latest generated image with "use as" and download controls.

    Called on every rerun while an image exists (not only inside the
    "Generate Image" click branch), so these buttons still exist on the rerun
    their own click triggers and their handlers actually run.
    """
    st.image(image, caption="Generated Image", use_container_width=True)

    col_person, col_garment = st.columns(2)
    with col_person:
        if st.button("Use as Person Image", key="use_as_person"):
            st.session_state.person_img = image
            st.success("Set as person image!")
    with col_garment:
        if st.button("Use as Garment Image", key="use_as_garment"):
            st.session_state.garment_img = image
            st.success("Set as garment image!")

    st.download_button(
        label="Download Image",
        data=_png_bytes(image),
        file_name="generated_image.png",
        mime="image/png",
        key="download_generated",
    )


def _render_generation_tab(settings):
    """Render the "Generate Images" tab: prompt form, latest result, history."""
    st.header("Image Generation")

    col_form, col_output = st.columns(2)

    with col_form:
        st.subheader("Generation Settings")
        _ensure_state("gen_prompt", EXAMPLE_PROMPTS[0])
        prompt = st.text_area(
            "Generation Prompt",
            height=100,
            help="Describe the image you want to generate",
            key="gen_prompt",
        )
        _render_example_prompts("gen_prompt", "gen")

        with st.expander("Advanced Settings", expanded=True):
            num_steps = st.slider(
                "Number of Steps",
                min_value=20,
                max_value=50,
                value=DEFAULT_NUM_STEPS,
                help="More steps = better quality but slower",
                key="gen_steps",
            )
            guidance = st.slider(
                "Guidance Scale",
                min_value=1.0,
                max_value=20.0,
                value=DEFAULT_GUIDANCE_SCALE,
                help="How closely to follow the prompt",
                key="gen_guidance",
            )
            aspect = st.radio(
                "Aspect Ratio",
                _ASPECT_RATIOS,
                help=_ASPECT_HELP,
                key="gen_aspect",
            )
            negative_prompt = st.text_area(
                "Negative Prompt",
                value="ugly, blurry, low quality, distorted, deformed",
                help="What to avoid in the generation",
                key="gen_negative",
            )

    with col_output:
        st.subheader("Generated Image")
        if st.button("Generate Image", type="primary", key="gen_submit"):
            if not settings["replicate_api"]:
                st.error(_MISSING_REPLICATE_KEY)
            else:
                with st.spinner("Generating image..."):
                    result_image, status = generate_image(
                        _apply_style_keyword(prompt, settings["style_keyword"]),
                        num_steps=num_steps,
                        guidance_scale=guidance,
                        aspect_ratio=aspect,
                        replicate_api_key=settings["replicate_api"],
                        lora_url=settings["lora_url"],
                        negative_prompt=negative_prompt,
                    )
                if result_image:
                    st.session_state.generated_image = result_image
                    # Only successful generations enter the history; the old
                    # code re-appended the previous image after a failure.
                    _ensure_state("generated_images", [])
                    st.session_state.generated_images.append(result_image)
                else:
                    st.error(status)

        if "generated_image" in st.session_state:
            _render_generated_actions(st.session_state.generated_image)

        history = st.session_state.get("generated_images", [])
        if history:
            with st.expander("Generation History"):
                # Newest first, at most five, numbered by generation order.
                for offset, img in enumerate(reversed(history[-5:])):
                    st.image(
                        img,
                        caption=f"Generated Image {len(history) - offset}",
                        use_container_width=True,
                    )


def _render_upload_input(*, prefix, uploader_label, uploader_help, caption, state_key):
    """Let the user upload an image, optionally resize it, and store it.

    Widget keys are namespaced by *prefix* so the person and garment columns
    never collide. The final image is stored in ``st.session_state[state_key]``.
    """
    preview = st.empty()
    uploaded = st.file_uploader(
        uploader_label,
        type=["png", "jpg", "jpeg"],
        key=f"{prefix}_upload",
        help=uploader_help,
    )
    if uploaded is None:
        return

    img = Image.open(uploaded)
    with st.expander("Image Settings"):
        if st.checkbox("Resize Image", value=True, key=f"{prefix}_resize"):
            max_size = st.slider(
                "Max Image Size", 256, 1024, 512, key=f"{prefix}_size"
            )
            img = _resize_to_max(img, max_size)

    preview.image(img, caption=caption, use_container_width=True)
    # NOTE: this runs on every rerun while a file is uploaded, and the
    # try-on tab renders after the generation tab, so an uploaded file wins
    # over an image picked via "Use as ... Image".
    st.session_state[state_key] = img


def _render_generator_input(
    settings,
    *,
    prefix,
    button_label,
    caption,
    state_key,
    aspect_ratios,
    aspect_help=None,
    show_examples=False,
):
    """Generate an image from a prompt and store it in ``st.session_state[state_key]``.

    All widget keys are namespaced by *prefix*. Two identically-configured
    unkeyed widgets previously caused Streamlit DuplicateWidgetID errors.
    """
    prompt_key = f"{prefix}_prompt"
    _ensure_state(prompt_key, EXAMPLE_PROMPTS[0])
    prompt = st.text_area("Generation Prompt", height=100, key=prompt_key)
    if show_examples:
        _render_example_prompts(prompt_key, prefix)

    with st.expander("Generation Settings"):
        num_steps = st.slider(
            "Steps", 20, 50, DEFAULT_NUM_STEPS, key=f"{prefix}_steps"
        )
        guidance = st.slider(
            "Guidance Scale", 1.0, 20.0, DEFAULT_GUIDANCE_SCALE,
            key=f"{prefix}_guidance",
        )
        aspect = st.radio(
            "Aspect Ratio", aspect_ratios, help=aspect_help, key=f"{prefix}_aspect"
        )

    if not st.button(button_label, key=f"{prefix}_generate"):
        return
    if not settings["replicate_api"]:
        st.error(_MISSING_REPLICATE_KEY)
        return

    with st.spinner("Generating image..."):
        result_image, status = generate_image(
            _apply_style_keyword(prompt, settings["style_keyword"]),
            num_steps=num_steps,
            guidance_scale=guidance,
            aspect_ratio=aspect,
            replicate_api_key=settings["replicate_api"],
            lora_url=settings["lora_url"],
        )
    if result_image:
        st.session_state[state_key] = result_image
        st.image(result_image, caption=caption)
    else:
        st.error(status)


def _render_person_column(settings):
    """Render the person-image column (upload or generate)."""
    st.header("Person Image")
    method = st.radio(
        "Choose input method",
        ["Upload Image", "Generate Image"],
        key="person_method",
    )
    if method == "Upload Image":
        _render_upload_input(
            prefix="person",
            uploader_label="Upload person image",
            uploader_help="Upload a clear full-body photo of the person",
            caption="Uploaded Person Image",
            state_key="person_img",
        )
    else:
        _render_generator_input(
            settings,
            prefix="person",
            button_label="Generate Person Image",
            caption="Generated Person Image",
            state_key="person_img",
            aspect_ratios=_ASPECT_RATIOS,
            aspect_help=_ASPECT_HELP,
            show_examples=True,
        )


def _render_garment_column(settings):
    """Render the garment-image column (upload or generate)."""
    st.header("Garment Image")
    method = st.radio(
        "Choose input method",
        ["Upload Image", "Generate Image"],
        key="garment_method",
    )
    if method == "Upload Image":
        _render_upload_input(
            prefix="garment",
            uploader_label="Upload garment image",
            uploader_help="Upload a clear photo of the garment on plain background",
            caption="Uploaded Garment Image",
            state_key="garment_img",
        )
    else:
        _render_generator_input(
            settings,
            prefix="garment",
            button_label="Generate Garment Image",
            caption="Generated Garment Image",
            state_key="garment_img",
            aspect_ratios=_GARMENT_ASPECT_RATIOS,
        )


def _render_tryon_column(settings):
    """Render the try-on trigger plus the persisted result and its download."""
    st.header("Virtual Try-On Result")

    if st.button("Generate Try-On", type="primary", key="tryon_generate"):
        if "person_img" not in st.session_state:
            st.error("Please provide a person image first!")
        elif "garment_img" not in st.session_state:
            st.error("Please provide a garment image first!")
        elif not settings["kling_access_key"] or not settings["kling_secret_key"]:
            st.error("Please provide both Kling Access Key and Secret Key!")
        else:
            with st.spinner("Generating try-on image..."):
                result_image, status = apply_virtual_tryon(
                    st.session_state.person_img,
                    st.session_state.garment_img,
                    settings["kling_access_key"],
                    settings["kling_secret_key"],
                )
            if result_image:
                st.session_state.result_image = result_image
                st.success("Try-on completed successfully!")
            else:
                st.error(status)

    # Shown on every rerun so the result and its download survive later clicks.
    if "result_image" in st.session_state:
        result = st.session_state.result_image
        st.image(result, caption="Try-on Result", use_container_width=True)
        st.download_button(
            label="Download Result",
            data=_png_bytes(result),
            file_name="tryon_result.png",
            mime="image/png",
            key="download_tryon",
        )


def main():
    """Render the whole Virtual Try-On app for one Streamlit script run.

    The app has four parts:
      - the sidebar, with API keys and LoRA settings;
      - a "Generate Images" tab;
      - a "Virtual Try-On" tab with person, garment and result columns;
      - cross-rerun state, kept in ``st.session_state`` (``person_img``,
        ``garment_img``, ``generated_image``, ``generated_images``,
        ``result_image``).
    """
    st.title("Virtual Try-On System")
    st.markdown("Upload images or generate new ones to try on different clothing items!")

    settings = _render_sidebar()

    generate_tab, tryon_tab = st.tabs(["Generate Images", "Virtual Try-On"])

    with generate_tab:
        _render_generation_tab(settings)

    with tryon_tab:
        col_left, col_mid, col_right = st.columns(3)
        with col_left:
            _render_person_column(settings)
        with col_mid:
            _render_garment_column(settings)
        with col_right:
            _render_tryon_column(settings)


if __name__ == "__main__":
    main()