Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import asyncio | |
| import fal_client | |
| from dotenv import load_dotenv | |
| import os | |
| from pathlib import Path | |
| import time | |
| import json | |
load_dotenv()

# fal_client looks up its API key in the FAL_KEY environment variable, while this
# app's .env / secrets store it as FAL_API_KEY, so copy it across.
# os.environ only accepts str values: assigning os.getenv()'s None result when the
# variable is missing would raise TypeError at import time, so only copy when present
# (this also leaves an already-set FAL_KEY untouched).
_fal_api_key = os.getenv("FAL_API_KEY")
if _fal_api_key is not None:
    os.environ["FAL_KEY"] = _fal_api_key
elif not os.getenv("FAL_KEY"):
    print("Warning: neither FAL_API_KEY nor FAL_KEY is set; fal requests will fail to authenticate.")
def _hair_description(length, style, color):
    """Build the text substituted for ``{hair_feature}`` from hair dropdown values.

    A length of "Bald" yields "bald,". The dropdown choice "None" (offered for the
    man's style and colour) and empty/cleared values mean "not specified" and are
    omitted instead of being written into the prompt as the word "None".

    >>> _hair_description("Short", "Undercut", "Black")
    'Short Undercut Black hair,'
    >>> _hair_description("Short", "None", "Black")
    'Short Black hair,'
    >>> _hair_description("Bald", "None", "None")
    'bald,'
    """
    if length == "Bald":
        return "bald,"
    parts = [part for part in (length, style, color) if part and part != "None"]
    return " ".join(parts + ["hair,"])


def _output_image_urls(result, node_id):
    """Return the image URLs produced by one ComfyUI output node of a fal result.

    A missing "outputs" section, a missing node, or a node without "images"
    all yield an empty list instead of raising KeyError.
    """
    node = (result.get("outputs") or {}).get(node_id) or {}
    return [img["url"] for img in node.get("images", [])]


async def generate_paris_images(product_name: str, image1_path: str, image2_path: str, woman_prompt: str, man_prompt: str, girl_name: str, girl_hair_length: str, girl_hair_style: str, girl_hair_color: str, boy_name: str, boy_hair_length: str, boy_hair_style: str, boy_hair_color: str, batch_size: int, progress=gr.Progress()):
    """Generate Paris couple portraits through the fal ``comfy/LVE/paris-couple`` workflow.

    Uploads the two user photos plus the pose/mask templates, fills the
    ``{hair_feature}`` placeholder of both prompts, submits the workflow, waits
    for it, and collects the image URLs from output nodes "215" and "818".

    Args:
        product_name: Selected product. Not used by this function; it is kept
            because the UI passes it positionally.
        image1_path: Local file path of the woman's photo.
        image2_path: Local file path of the man's photo.
        woman_prompt: Prompt template for the woman; ``{hair_feature}`` is replaced.
        man_prompt: Prompt template for the man; ``{hair_feature}`` is replaced.
        girl_name: Passed to the workflow unchanged.
        girl_hair_length: Hair length dropdown value for the woman.
        girl_hair_style: Hair style dropdown value for the woman.
        girl_hair_color: Hair colour dropdown value for the woman.
        boy_name: Passed to the workflow unchanged.
        boy_hair_length: Hair length dropdown value for the man ("Bald" is special-cased).
        boy_hair_style: Hair style dropdown value for the man ("None" is omitted).
        boy_hair_color: Hair colour dropdown value for the man ("None" is omitted).
        batch_size: Number of images requested from the workflow.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        A tuple ``(urls_from_node_215, urls_from_node_818, "Processing time: N.NN seconds")``.

    Raises:
        gr.Error: If either photo is missing.
    """
    start_time = time.time()
    print("Progress: 5% - Starting Paris image generation...")
    progress(0.05, desc="Starting Paris image generation...")

    # Fail fast with a user-visible message instead of trying to upload the path "None".
    if not image1_path or not image2_path:
        raise gr.Error("Please upload both the woman image and the man image.")

    # Upload the two user photos and the four template assets concurrently.
    # The order of this list must match the unpacking below.
    upload_tasks = [
        fal_client.upload_file_async(str(image1_path)),
        fal_client.upload_file_async(str(image2_path)),
        fal_client.upload_file_async("template/man_pose.png"),
        fal_client.upload_file_async("template/woman_pose.png"),
        fal_client.upload_file_async("template/woman_clip_mask.png"),
        fal_client.upload_file_async("template/man_clip_mask.png"),
    ]
    (image1_url, image2_url, man_pose_img, woman_pose_img,
     woman_clip_mask, man_clip_mask) = await asyncio.gather(*upload_tasks)
    print("Progress: 40% - Uploaded all images")
    progress(0.4, desc="Uploaded all images")

    # Replace the {hair_feature} placeholders with the selected hair descriptions.
    woman_hair_desc = _hair_description(girl_hair_length, girl_hair_style, girl_hair_color)
    print(f"Final woman hair description: {woman_hair_desc}")
    man_hair_desc = _hair_description(boy_hair_length, boy_hair_style, boy_hair_color)
    print(f"Final man hair description: {man_hair_desc}")
    woman_prompt = woman_prompt.replace("{hair_feature}", woman_hair_desc)
    man_prompt = man_prompt.replace("{hair_feature}", man_hair_desc)
    print(f"Final woman prompt: {woman_prompt}")
    print(f"Final man prompt: {man_prompt}")

    handler = await fal_client.submit_async(
        "comfy/LVE/paris-couple",
        arguments={
            "loadimage_1": image1_url,
            "loadimage_2": image2_url,
            "loadimage_3": woman_pose_img,
            "loadimage_4": woman_clip_mask,
            "loadimage_5": man_clip_mask,
            "loadimage_6": man_pose_img,
            "woman_prompt": woman_prompt,
            "man_prompt": man_prompt,
            "girl_name": girl_name,
            "boy_name": boy_name,
            # Slider values may arrive as floats (e.g. 4.0); send an integer count.
            "batch_size": int(batch_size),
        },
    )
    print("Progress: 60% - Processing images...")
    progress(0.6, desc="Processing images...")
    result = await handler.get()
    print(result)

    processing_time = time.time() - start_time
    print(f"Progress: 100% - Generation completed in {processing_time:.2f} seconds")
    progress(1.0, desc=f"Generation completed in {processing_time:.2f} seconds")

    # Collect URLs from the two output nodes; absent nodes/images give empty lists.
    image_215 = _output_image_urls(result, "215")
    image_818 = _output_image_urls(result, "818")
    print(f"Image 215: {image_215}")
    print(f"Image 818: {image_818}")

    return (
        image_215,
        image_818,
        f"Processing time: {processing_time:.2f} seconds",
    )
def change_product_preview(product_name):
    """Return the preview thumbnail and prompt templates for the selected product.

    Reads ``prompt.json`` on every call, so edits to the file take effect
    without a restart. The file is expected to hold a list of objects with
    "title", "woman" and "man" keys.

    Args:
        product_name: Title chosen in the "Product Name" dropdown.

    Returns:
        ``(thumbnail_path, woman_prompt, man_prompt)`` for the first entry whose
        "title" equals ``product_name``, or ``(None, "", "")`` when none matches.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open('prompt.json', 'r', encoding='utf-8') as f:
        prompts = json.load(f)
    # .get() so an entry without a "title" is skipped instead of raising KeyError.
    prompt_data = next((item for item in prompts if item.get('title') == product_name), None)
    if prompt_data:
        return (
            f"thumbnail/{product_name}.png",
            prompt_data['woman'],
            prompt_data['man'],
        )
    return None, "", ""
# Gradio UI. Layout is determined by the order in which components are created
# inside the nested Row/Column context managers, so statement order matters here.
with gr.Blocks() as demo:
    # Product selector and its thumbnail preview.
    with gr.Row():
        product_name = gr.Dropdown(label="Product Name", choices=["Winter", "Classy", "Night Out", "Romantic"], value="Winter")
        product_preview = gr.Image(label="Product Preview", type="filepath", value="thumbnail/Winter.png", height=500, width=500)
    # User photos; type="filepath" makes Gradio hand the handler local file paths.
    with gr.Row():
        image1_input = gr.Image(label="Upload Woman Image", type="filepath", value="user3-f.jpg")
        image2_input = gr.Image(label="Upload Man Image", type="filepath", value="user3-m.jpg")
    with gr.Row():
        # Woman column: prompt template (contains a {hair_feature} placeholder), name, hair options.
        with gr.Column():
            woman_prompt = gr.Textbox(
                label="Woman Prompt",
                value="Close-up, portrait photo, a woman, {hair_feature} wearing a cream-colored wool coat, chunky knit scarf, and matching earmuffs, standing on the same snow-dusted cobblestone street, illuminated Eiffel Tower glowing golden in the background, snowflakes sparkling in the warm streetlight glow, looking at camera with gentle smile."
            )
            girl_name = gr.Textbox(
                label="Girl Name",
                value="julie delpy"
            )
            girl_hair_length = gr.Dropdown(label="Girl Hair Length", choices=["Short", "Medium", "Long"], value="Long")
            girl_hair_style = gr.Dropdown(label="Girl Hair Style", choices=["Straight", "Wavy", "Curly"], value="Straight")
            girl_hair_color = gr.Dropdown(label="Girl Hair Color", choices=["Blonde", "Brown", "Black", "Brunette", "Redhead", "Bronde"], value="Bronde")
        # Man column: prompt template (contains a {hair_feature} placeholder), name, hair options.
        with gr.Column():
            man_prompt = gr.Textbox(
                label="Man Prompt",
                value="Close-up, portrait photo, a man, {hair_feature} wearing a dark navy wool peacoat, cashmere scarf, and leather gloves, standing on a snow-dusted cobblestone street, illuminated Eiffel Tower in the background glowing golden against the night sky, gentle snowflakes catching the warm glow of vintage streetlamps, looking at camera with confident expression."
            )
            boy_name = gr.Textbox(
                label="Boy Name",
                value="ethan hawke"
            )
            # "Bald" is special-cased when the man's hair description is built.
            boy_hair_length = gr.Dropdown(label="Boy Hair Length", choices=["Short", "Medium", "Long", "Bald"], value="Short")
            boy_hair_style = gr.Dropdown(label="Boy Hair Style", choices=["None", "Undercut", "Mullet", "French Crop", "Slicked Back", "Fade", "Buzz Cut"], value="Undercut")
            boy_hair_color = gr.Dropdown(label="Boy Hair Color", choices=["None", "Blonde", "Brown", "Black", "Brunette", "Redhead"], value="Black")
    # Number of images requested per generation run.
    batch_size = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Batch Size")
    generate_btn = gr.Button("Generate")
    # Results: the two galleries receive the first and second lists returned by generate_paris_images.
    with gr.Row():
        image_output = gr.Gallery(label="Generated Image Raw")
        image_output_processed = gr.Gallery(label="Generated Image Final")
    time_output = gr.Textbox(label="Processing Time")
    # The inputs are passed positionally, so this order must match generate_paris_images' parameters.
    generate_btn.click(
        fn=generate_paris_images,
        inputs=[product_name, image1_input, image2_input, woman_prompt, man_prompt, girl_name, girl_hair_length, girl_hair_style, girl_hair_color, boy_name, boy_hair_length, boy_hair_style, boy_hair_color, batch_size],
        outputs=[image_output, image_output_processed, time_output]
    )
    # Changing the product swaps the preview thumbnail and both prompt templates.
    product_name.change(
        fn=change_product_preview,
        inputs=[product_name],
        outputs=[product_preview, woman_prompt, man_prompt]
    )
def _main():
    """Script entry point: announce startup, then serve the Gradio app defined above."""
    print("Starting Gradio interface...")
    demo.launch()


if __name__ == "__main__":
    _main()