|
|
import gradio as gr |
|
|
import torch |
|
|
from diffusers import FluxPipeline |
|
|
from PIL import Image |
|
|
import spaces |
|
|
|
|
|
def load_flux_model(model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
    """Load the FLUX pipeline in bfloat16 and move it to *device*.

    Intentionally not decorated with ``@spaces.GPU``: the ZeroGPU pattern is
    to load models at module level and decorate only the functions that run
    inference (see ``generate_design`` below).

    Args:
        model_id: Hugging Face Hub repo id of the FLUX checkpoint.
        device: Target passed to the pipeline's ``.to()``.

    Returns:
        The loaded ``FluxPipeline``.
    """
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    return pipe.to(device)


# Loaded once at import time and shared by every request.
flux_pipe = load_flux_model()


def _as_pil_list(images):
    """Normalize uploaded references into a list of RGB PIL images.

    Accepts ``None``, a single image, or a list whose items are PIL images,
    file paths, or ``(image, caption)`` tuples (the shape ``gr.Gallery``
    produces). ``None`` items are skipped.
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]
    normalized = []
    for item in images:
        if isinstance(item, tuple):  # gallery-style (image, caption) pair
            item = item[0]
        if item is None:
            continue
        if isinstance(item, str):  # file-path based upload components
            with Image.open(item) as opened:
                normalized.append(opened.convert("RGB"))
        else:
            normalized.append(item.convert("RGB"))
    return normalized


# ZeroGPU: a GPU is attached only while this function runs. Consider
# @spaces.GPU(duration=...) if 4-image / 50-step runs hit the time limit.
@spaces.GPU
def generate_design(images, apparel_type, guidance_scale=7.0, steps=30, num_outputs=1):
    """Generate apparel designs conditioned on uploaded reference images.

    The image embeddings of all references are averaged into a single
    embedding. That embedding is fed to the pipeline's IP-Adapter together
    with a text prompt built from *apparel_type*.

    Args:
        images: Reference image(s): a single image, or a list of PIL images,
            file paths, or ``(image, caption)`` tuples.
        apparel_type: Free-text garment type, e.g. "summer dress". Blank
            input falls back to "garment".
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of denoising steps (cast to int; sliders may send floats).
        num_outputs: Number of images to generate (cast to int).

    Returns:
        List of generated PIL images (``.images`` of the pipeline output).

    Raises:
        gr.Error: If no reference image was supplied, or if the pipeline has no
            IP-Adapter image encoder loaded.
    """
    ref_images = _as_pil_list(images)
    if not ref_images:
        raise gr.Error("Please upload at least one image.")

    # encode_image() and ip_adapter_image_embeds need the image encoder that
    # load_ip_adapter(...) attaches; fail with an actionable message instead
    # of an opaque AttributeError/TypeError deep inside diffusers.
    if getattr(flux_pipe, "image_encoder", None) is None:
        raise gr.Error(
            "Reference-image conditioning requires an IP-Adapter; "
            "load one with flux_pipe.load_ip_adapter(...) before generating."
        )

    # Encoder call made outside the pipeline's __call__: disable autograd so no
    # graph (and its activation memory) is kept.
    with torch.no_grad():
        ref_embeds = flux_pipe.encode_image(ref_images, flux_pipe.device, 1)
    # Average over the references -> one embedding for the single adapter.
    # NOTE(review): the (1, 1, D) layout mirrors what
    # FluxPipeline.prepare_ip_adapter_image_embeds builds for one image per
    # adapter; confirm against the installed diffusers version.
    combined_embed = ref_embeds.mean(dim=0, keepdim=True)[None]

    apparel = (apparel_type or "").strip() or "garment"
    prompt = f"Design a {apparel} inspired by the given references, fashionable, photorealistic, detailed fabric textures."

    return flux_pipe(
        prompt=prompt,
        ip_adapter_image_embeds=[combined_embed],
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        num_images_per_prompt=int(num_outputs),
    ).images
|
|
|
|
|
def _gallery_to_images(items):
    """Turn ``gr.Gallery(type="pil")`` input into a plain list of PIL images.

    The gallery hands the handler a list of ``(image, caption)`` tuples, or an
    empty/``None`` value when nothing was uploaded.
    """
    if not items:
        return []
    return [item[0] if isinstance(item, tuple) else item for item in items]


def _on_generate(items, apparel_type, guidance_scale, steps, num_outputs):
    """Click handler: unpack gallery uploads and forward to ``generate_design``."""
    return generate_design(
        _gallery_to_images(items),
        apparel_type,
        guidance_scale,
        int(steps),  # slider values may arrive as floats
        int(num_outputs),
    )


with gr.Blocks() as demo:
    gr.Markdown("# Fashion Design Generator 👗✨")

    with gr.Row():
        with gr.Column():
            # gr.Image holds a single image and has no `allow_multiple`
            # option; an interactive Gallery accepts several uploads.
            images = gr.Gallery(
                type="pil",
                label="Upload reference images",
                interactive=True,
                height=300,
            )
            apparel_type = gr.Textbox(label="Type of Apparel", placeholder="e.g., summer dress, jacket, kurta")
            # `step` is passed by keyword: gr.Slider only takes
            # (minimum, maximum, value) positionally.
            guidance = gr.Slider(1, 15, value=7.0, step=0.1, label="Guidance Scale")
            steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
            num_outputs = gr.Slider(1, 4, value=1, step=1, label="Number of Designs")
            btn = gr.Button("Generate Designs", variant="primary")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Designs", height=300)

    btn.click(_on_generate, inputs=[images, apparel_type, guidance, steps, num_outputs], outputs=gallery)


if __name__ == "__main__":
    demo.launch()
|
|
|