# File size: 2,269 Bytes (revisions 565d784, 747a27b)
import gradio as gr
import torch
from diffusers import FluxPipeline
from PIL import Image
import spaces

def load_flux_model():
    """Load FLUX.1-dev (bfloat16, CUDA) with an IP-Adapter for image conditioning.

    ``FluxPipeline`` alone is text-to-image only: it has no image encoder and its
    ``__call__`` has no ``image_embeds`` argument. Reference-image conditioning goes
    through an IP-Adapter. Loading one also attaches the CLIP image encoder and
    feature extractor that ``FluxPipeline.encode_image`` relies on.

    This function is deliberately *not* decorated with ``@spaces.GPU``. On ZeroGPU
    Spaces the model is loaded once at import time, and only the inference function
    requests a GPU.

    Returns:
        The ready-to-use ``FluxPipeline``.
    """
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe.load_ip_adapter(
        "XLabs-AI/flux-ip-adapter",
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    # 1.0 = the reference images act at full adapter strength.
    pipe.set_ip_adapter_scale(1.0)
    return pipe


flux_pipe = load_flux_model()


@spaces.GPU(duration=120)  # several FLUX.1-dev images can exceed the default GPU slot
def generate_design(images, apparel_type, guidance_scale=7.0, steps=30, num_outputs=1):
    """Generate apparel designs conditioned on reference images and an apparel type.

    The CLIP image embeddings of all reference images are averaged into a single
    style/content reference. That reference is fed to the pipeline's IP-Adapter.

    Args:
        images: A list of PIL images, a single PIL image, or ``None``. ``None``
            entries inside the list are ignored.
        apparel_type: Free-text garment description, e.g. ``"summer dress"``.
        guidance_scale: Guidance strength passed to the pipeline.
        steps: Number of denoising steps (coerced to ``int``).
        num_outputs: Number of images to generate (coerced to ``int``).

    Returns:
        The list of generated PIL images.

    Raises:
        gr.Error: If no reference image is given or ``apparel_type`` is blank.
    """
    if images is None:
        images = []
    elif isinstance(images, Image.Image):
        # Single-image components (e.g. gr.Image) pass one image, not a list.
        images = [images]
    images = [img for img in images if img is not None]
    if not images:
        raise gr.Error("Please upload at least one image.")

    apparel_type = (apparel_type or "").strip()
    if not apparel_type:
        raise gr.Error("Please enter the type of apparel.")

    # Average the CLIP image embeddings of all references into one reference.
    # encode_image(image, device, num_images_per_prompt) uses the IP-Adapter's image
    # encoder loaded in load_flux_model(). It is not a no-grad method, hence the guard.
    device = flux_pipe.device
    with torch.no_grad():
        ref_embeds = [
            flux_pipe.encode_image(img.convert("RGB"), device, 1) for img in images
        ]
    # Assumes each embed is (1, embed_dim), so the mean is (1, embed_dim).
    # Re-check this when upgrading diffusers.
    combined_embed = torch.stack(ref_embeds).mean(dim=0)

    prompt = f"Design a {apparel_type} inspired by the given references, fashionable, photorealistic, detailed fabric textures."

    outputs = flux_pipe(
        prompt=prompt,
        # One entry per loaded IP-Adapter, shaped (batch, num_ref_images, embed_dim).
        # This matches what the pipeline itself builds from ``ip_adapter_image``.
        ip_adapter_image_embeds=[combined_embed.unsqueeze(0)],
        guidance_scale=float(guidance_scale),
        # Gradio sliders may deliver floats; both of these must be integers.
        num_inference_steps=int(steps),
        num_images_per_prompt=int(num_outputs),
    ).images

    return outputs

def _open_reference_images(files):
    """Open uploaded files as a list of RGB PIL images.

    ``gr.File(file_count="multiple")`` yields a list of file paths on Gradio 4+
    (temp-file objects with a ``.name`` on Gradio 3), or ``None`` when nothing
    has been uploaded.
    """
    if not files:
        return []
    opened = []
    for f in files:
        # The context manager closes the file handle once the converted copy exists.
        with Image.open(getattr(f, "name", f)) as img:
            opened.append(img.convert("RGB"))
    return opened


def _generate_from_uploads(files, apparel_type, guidance_scale, steps, num_outputs):
    """Click handler: turn the uploads into PIL images and call ``generate_design``."""
    return generate_design(
        _open_reference_images(files), apparel_type, guidance_scale, steps, num_outputs
    )


with gr.Blocks() as demo:
    gr.Markdown("# Fashion Design Generator 👗✨")

    with gr.Row():
        with gr.Column():
            # gr.Image holds a single picture; gr.File accepts several uploads at once.
            images = gr.File(
                label="Upload reference images",
                file_count="multiple",
                file_types=["image"],
                elem_id="multi_upload",
            )
            apparel_type = gr.Textbox(label="Type of Apparel", placeholder="e.g., summer dress, jacket, kurta")
            # `step` is keyword-only in gr.Slider, so all slider settings are named.
            guidance = gr.Slider(minimum=1, maximum=15, value=7.0, step=0.1, label="Guidance Scale")
            steps = gr.Slider(minimum=10, maximum=50, value=30, step=1, label="Steps")
            num_outputs = gr.Slider(minimum=1, maximum=4, value=1, step=1, label="Number of Designs")
            btn = gr.Button("Generate Designs", variant="primary")
        with gr.Column():
            gallery = gr.Gallery(label="Generated Designs", height=300)

    btn.click(
        _generate_from_uploads,
        inputs=[images, apparel_type, guidance, steps, num_outputs],
        outputs=gallery,
    )

if __name__ == "__main__":
    demo.launch()