# fnew / app.py
# Last commit 747a27b ("Update app.py") by Revrse, verified.
import gradio as gr
import torch
from diffusers import FluxPipeline
from PIL import Image
import spaces
def load_flux_model():
    """Load FLUX.1-dev in bfloat16 on CUDA, with an IP-Adapter for image conditioning.

    This function is intentionally NOT decorated with ``@spaces.GPU``. The
    ZeroGPU pattern is to build the pipeline and move it to ``"cuda"`` once at
    import time, and decorate only the function that runs inference.

    Returns:
        The ``FluxPipeline``, with the XLabs IP-Adapter (and its CLIP image
        encoder) loaded and its scale set to 1.0.
    """
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # FluxPipeline has no image encoder of its own. The IP-Adapter supplies the
    # CLIP vision encoder and projection needed to condition on reference images.
    pipe.load_ip_adapter(
        "XLabs-AI/flux-ip-adapter",
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    pipe.set_ip_adapter_scale(1.0)
    return pipe


# Built once at import; shared by every request.
flux_pipe = load_flux_model()
def _as_rgb_images(images):
    """Normalize Gradio image input into a list of RGB PIL images.

    Accepts None, a single item, or a list of items. Each item may be a PIL
    image, an array-like image, a file path, or an ``(image, caption)`` pair
    as yielded by ``gr.Gallery``. ``None`` entries are skipped.
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]
    rgb = []
    for item in images:
        if isinstance(item, (tuple, list)):
            # gr.Gallery hands back (image, caption) pairs; keep only the image.
            item = item[0]
        if item is None:
            continue
        if isinstance(item, Image.Image):
            pil = item
        elif hasattr(item, "__array_interface__"):
            pil = Image.fromarray(item)
        else:
            pil = Image.open(item)
        rgb.append(pil.convert("RGB"))
    return rgb


def _average_ip_adapter_embeds(refs):
    """Encode each reference with the pipeline's IP-Adapter and average the results.

    Uses the pipeline's own ``prepare_ip_adapter_image_embeds``. That keeps the
    embedding layout identical to what the pipeline builds internally when it is
    given ``ip_adapter_image``. The result is a list with one averaged tensor
    per loaded IP-Adapter.
    """
    # Called outside the pipeline's own no-grad __call__, so disable autograd here.
    with torch.no_grad():
        per_image = [
            flux_pipe.prepare_ip_adapter_image_embeds(
                ip_adapter_image=img,
                ip_adapter_image_embeds=None,
                device=flux_pipe.device,
                num_images_per_prompt=1,
            )
            for img in refs
        ]
        # zip(*...) groups the embeddings of all references adapter-by-adapter.
        return [torch.stack(group).mean(dim=0) for group in zip(*per_image)]


@spaces.GPU
def generate_design(images, apparel_type, guidance_scale=7.0, steps=30, num_outputs=1):
    """Generate apparel designs conditioned on reference images and a garment type.

    The reference images are encoded with the IP-Adapter image encoder. Their
    embeddings are averaged into a single style/content reference, which is
    passed to FLUX as ``ip_adapter_image_embeds`` alongside a text prompt built
    from ``apparel_type``. Decorated with ``@spaces.GPU`` so that ZeroGPU
    allocates a GPU for each call.

    Args:
        images: Reference images. A single image or a list of PIL images,
            array-like images, file paths, or ``(image, caption)`` pairs.
        apparel_type: Free-text garment type inserted into the prompt.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps (coerced to int; sliders may send floats).
        num_outputs: Number of images to generate (coerced to int).

    Returns:
        The list of generated images from the pipeline.

    Raises:
        gr.Error: if no image is supplied, the apparel type is blank, or the
            pipeline has no IP-Adapter image encoder loaded.
    """
    refs = _as_rgb_images(images)
    if not refs:
        raise gr.Error("Please upload at least one image.")

    apparel_type = (apparel_type or "").strip()
    if not apparel_type:
        raise gr.Error("Please enter the type of apparel.")

    if getattr(flux_pipe, "image_encoder", None) is None:
        raise gr.Error(
            "Reference-image conditioning needs an IP-Adapter; "
            "load one with flux_pipe.load_ip_adapter(...) first."
        )

    combined_embeds = _average_ip_adapter_embeds(refs)

    prompt = f"Design a {apparel_type} inspired by the given references, fashionable, photorealistic, detailed fabric textures."
    return flux_pipe(
        prompt=prompt,
        ip_adapter_image_embeds=combined_embeds,
        guidance_scale=guidance_scale,
        num_inference_steps=int(steps),
        num_images_per_prompt=int(num_outputs),
    ).images
with gr.Blocks() as demo:
    gr.Markdown("# Fashion Design Generator 👗✨")
    with gr.Row():
        with gr.Column():
            # gr.Image holds a single image, and `tool` / `allow_multiple` are not
            # gr.Image arguments in Gradio 4+ (construction raised TypeError).
            # An interactive Gallery lets the user upload several references.
            images = gr.Gallery(
                label="Upload reference images",
                type="pil",
                interactive=True,
                height=300,
                elem_id="multi_upload",
            )
            apparel_type = gr.Textbox(label="Type of Apparel", placeholder="e.g., summer dress, jacket, kurta")
            # `step` is keyword-only in gr.Slider, so it must not be passed positionally.
            guidance = gr.Slider(1, 15, value=7.0, step=0.1, label="Guidance Scale")
            steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
            num_outputs = gr.Slider(1, 4, value=1, step=1, label="Number of Designs")
            btn = gr.Button("Generate Designs", variant="primary")
        with gr.Column():
            gallery = gr.Gallery(label="Generated Designs", height=300)

    def _on_generate(gallery_items, apparel, guidance_scale, n_steps, n_outputs):
        """Unwrap Gallery (image, caption) pairs and forward to generate_design."""
        refs = [
            item[0] if isinstance(item, (tuple, list)) else item
            for item in (gallery_items or [])
        ]
        return generate_design(refs, apparel, guidance_scale, int(n_steps), int(n_outputs))

    btn.click(_on_generate, inputs=[images, apparel_type, guidance, steps, num_outputs], outputs=gallery)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly;
    # importing the module just builds `demo` without launching it.
    demo.launch()