# Source: sd3-shecodes / app.py (Hugging Face Space)
# Last commit: f53ed44 "Update app.py" by Aditibaheti (verified); file size 5.49 kB.
import spaces
import os
import random
import gradio as gr
import numpy as np
from diffusers import DiffusionPipeline
import torch
from huggingface_hub import login
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token, read from the environment (may be unset).
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    # Only log in when a token is configured: login(token=None) falls back to an
    # interactive prompt, which cannot be answered in a headless deployment.
    login(token=HUGGINGFACE_TOKEN)

# Path to the base model repository and the local LoRA safetensors weights.
base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
lora_weights_path = "./pytorch_lora_weights.safetensors"

# Load the base model: half precision on CUDA, full precision on CPU.
# `token=` is the current name of the deprecated `use_auth_token=` argument.
pipeline = DiffusionPipeline.from_pretrained(
    base_model_repo,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    token=HUGGINGFACE_TOKEN,
)
# Apply the fine-tuned LoRA on top of the base weights, then move to the device.
pipeline.load_lora_weights(lora_weights_path)
pipeline = pipeline.to(device)
# Largest seed value: the upper bound of a signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
# Maximum image size in pixels, reduced to fit within memory constraints.
MAX_IMAGE_SIZE = 1024
@spaces.GPU
def infer(
    prompt,
    negative_prompt="",
    seed=0,
    randomize_seed=True,
    width=512,
    height=512,
    guidance_scale=7.5,
    num_inference_steps=30,
):
    """Generate one image for ``prompt`` with the LoRA-adapted pipeline.

    The defaults mirror the initial values of the UI controls, so that
    ``gr.Examples`` (which only supplies the prompt) can call this function
    when caching example outputs.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt passed through to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed through to the pipeline.
        num_inference_steps: Number of inference steps for the pipeline.

    Returns:
        The first image of the pipeline's output.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; manual_seed and the pipeline's
    # size/step arguments require integers.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    )
    return output.images[0]
# Sample prompts for gr.Examples; each row holds the value of the single
# prompt input component.
examples = [
    [text]
    for text in (
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
        "An astronaut riding a green horse",
        "A delicious ceviche cheesecake slice",
    )
]
# Custom stylesheet passed to gr.Blocks(css=...). Colours follow the Myntra
# palette noted in the inline CSS comments. `#header` targets the gr.HTML div
# with id='header', and `#col-container` targets the column with that elem_id.
# NOTE(review): the `.gr-*` class selectors look like older Gradio class names;
# confirm they still match elements in the installed Gradio version.
css = """
body {
background-color: #ffffff; /* Myntra's white background */
color: #282c3f; /* Myntra's primary text color */
font-family: 'Arial', sans-serif;
margin: 0;
padding: 0;
}
#header {
background-color: #ff3f6c; /* Myntra's pink color */
color: white;
text-align: center;
padding: 20px;
font-size: 24px;
font-weight: bold;
}
#col-container {
margin: 0 auto;
max-width: 720px;
padding: 20px;
border: 1px solid #ebebeb;
border-radius: 8px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.gr-button {
background-color: #ff3f6c; /* Myntra's pink color */
color: white;
border: none;
padding: 10px 20px;
font-size: 16px;
border-radius: 5px;
cursor: pointer;
margin-top: 10px;
}
.gr-button:hover {
background-color: #e62e5c; /* Darker shade for hover effect */
}
.gr-textbox, .gr-slider, .gr-checkbox, .gr-accordion {
margin-bottom: 20px;
}
.gr-markdown {
text-align: center;
font-size: 24px;
margin-bottom: 20px;
}
.gr-image {
border: 1px solid #ebebeb;
border-radius: 8px;
margin-top: 20px;
}
"""
# Human-readable label for the hardware in use, shown in the app's header text.
power_device = "GPU" if torch.cuda.is_available() else "CPU"
# Build the Gradio UI: prompt row, result image, advanced settings, examples.
with gr.Blocks(css=css) as demo:
    gr.HTML("<div id='header'>Myntra Text-to-Image Generation</div>")
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
Currently running on {power_device}.
""")
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Both dimensions share the same memory-bounded limit
                # (previously width capped at 512 and height at 2048).
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,
                )
        # Examples only fill the prompt; with cache_examples=True Gradio calls
        # `infer` with that single input when caching.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            outputs=[result],
            fn=infer,
            cache_examples=True,
        )
    # Event listeners must be registered inside the Blocks context.
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result],
    )

demo.queue().launch()