Spaces:
Runtime error
Runtime error
File size: 2,121 Bytes
1601b51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import torch
import uuid
# Select the compute device: the GPU when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is used only on CUDA; the CPU path runs in float32 (float16 is
# generally only worthwhile / well supported on GPU).
dtype = torch.float16 if device == "cuda" else torch.float32

# Hugging Face Hub id of the Stable Diffusion v1.5 weights.
# NOTE: the original "runwayml/stable-diffusion-v1-5" repository was removed
# from the Hub, so loading it fails at startup; the maintained mirror is
# published under the "stable-diffusion-v1-5" organisation.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Load the img2img pipeline once at import time so every request reuses it,
# then move it to the chosen device.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    use_safetensors=True,
).to(device)
def resize_to_aspect(image, aspect_ratio):
    """Resize *image* so its dimensions follow the requested aspect ratio.

    The image is stretched (``image.resize``), not cropped, to the target size:

    * ``"1:1"``  -> a square whose side is the shorter original edge
    * ``"16:9"`` -> keeps the width, height = width * 9 / 16
    * ``"4:5"``  -> keeps the width, height = width * 5 / 4
    * ``"9:16"`` -> keeps the height, width = height * 9 / 16

    Any other ``aspect_ratio`` value keeps the original size.

    Args:
        image: A PIL-style image exposing ``size`` and ``resize``.
        aspect_ratio: One of the labels listed above.

    Returns:
        The object returned by ``image.resize`` for the target size.
    """
    width, height = image.size
    shorter = min(width, height)
    aspect_map = {
        "1:1": (shorter, shorter),
        "16:9": (width, int(width * 9 / 16)),
        "4:5": (width, int(width * 5 / 4)),
        "9:16": (int(height * 9 / 16), height),
    }
    target_w, target_h = aspect_map.get(aspect_ratio, (width, height))
    # For tiny inputs a derived edge truncates to 0 (e.g. int(1 * 9 / 16) == 0),
    # which PIL's resize rejects; clamp every edge to at least one pixel.
    return image.resize((max(1, target_w), max(1, target_h)))
def resize_to_512(image):
    """Return *image* resized to exactly 512x512 pixels (aspect ratio is not kept)."""
    edge = 512
    square = (edge, edge)
    return image.resize(square)
def _fit_to_model_resolution(image, max_side=512, multiple=8):
    """Scale *image* so its longer edge is about *max_side*, preserving aspect ratio.

    Both edges are rounded to the nearest multiple of *multiple* (never below
    one multiple) so the Stable Diffusion VAE's 8x downsampling divides evenly.
    """
    width, height = image.size
    scale = max_side / max(width, height)

    def _snap(edge):
        return max(multiple, round(edge * scale / multiple) * multiple)

    return image.resize((_snap(width), _snap(height)))


def generate_img(product_img, prompt, aspect_ratio):
    """Run img2img on the uploaded product photo and save the result as PNG.

    Args:
        product_img: The uploaded PIL image (Gradio passes ``None`` when the
            input is empty).
        prompt: Text prompt describing the desired output.
        aspect_ratio: Aspect-ratio label understood by ``resize_to_aspect``.

    Returns:
        ``(image, save_path)``: the first generated image (for the preview) and
        the path of the PNG written to the temp directory (for download).

    Raises:
        gr.Error: If no product image was uploaded.
    """
    import os
    import tempfile

    if product_img is None:
        # Show a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload a product image.")

    shaped = resize_to_aspect(product_img, aspect_ratio).convert("RGB")
    # Do not squash to a fixed 512x512: that would discard the aspect ratio the
    # user just selected. Scale to the model's working size instead.
    model_input = _fit_to_model_resolution(shaped)

    output = pipe(prompt=prompt, image=model_input, strength=0.75, guidance_scale=7.5)
    image = output.images[0]

    # Unique file name so concurrent requests never overwrite each other.
    save_path = os.path.join(tempfile.gettempdir(), f"generated_{uuid.uuid4().hex}.png")
    image.save(save_path)
    return image, save_path
# Build the Gradio interface (it is launched in the __main__ guard below).
_input_components = [
    gr.Image(type="pil", label="Upload Product Image", image_mode="RGB"),
    gr.Textbox(label="Prompt", placeholder="Describe what you want to generate"),
    gr.Dropdown(["1:1", "16:9", "4:5", "9:16"], label="Aspect Ratio", value="1:1"),
]
# generate_img returns (image, path): the image fills the preview, the path the download.
_output_components = [
    gr.Image(label="Preview"),
    gr.File(label="Download Image"),
]
demo = gr.Interface(
    fn=generate_img,
    inputs=_input_components,
    outputs=_output_components,
    title="Image-to-Image Product Generator",
    description="Upload a product image, describe your idea, and select the output aspect ratio.",
)
if __name__ == "__main__":
    _launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "debug": True,
        # NOTE(review): public share links are presumably not needed when hosted
        # on Hugging Face Spaces — confirm whether share=True is still wanted.
        "share": True,
        "ssr_mode": False,
    }
    demo.launch(**_launch_options)
|