# sd3-shecodes / app.py
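"""Gradio demo for Stable Diffusion 3 Medium with custom LoRA weights.

Loads the T5-XXL text encoder in 8-bit to fit within limited GPU memory,
benchmarks prompt encoding and denoising at startup, then serves a
text-to-image UI.
"""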
import gradio as gr
import numpy as np
import random
import time
from diffusers import StableDiffusion3Pipeline
import torch
from transformers import T5EncoderModel
from huggingface_hub import login
import os
import gc
import psutil
def flush():
    # Release Python-level garbage and cached CUDA memory between pipeline loads
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def bytes_to_giga_bytes(num_bytes):
    return num_bytes / 1024 / 1024 / 1024

def get_memory_usage():
    # Resident set size of this process, in megabytes
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return f"{mem_info.rss / (1024 ** 2):.2f} MB"

def log_memory(step):
    memory_log.append(f"{step}: {get_memory_usage()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face token from the environment (required for the gated SD3 weights)
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
login(token=HUGGINGFACE_TOKEN)
# Path to your model repository and safetensors weights
base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
lora_weights_path = "./pytorch_lora_weights.safetensors"
memory_log = []
log_memory("Before loading the model")
# Load text encoder in 8-bit
text_encoder = T5EncoderModel.from_pretrained(
    base_model_repo,
    subfolder="text_encoder_3",
    load_in_8bit=True,
    device_map="auto",
)
# Encode-only pipeline: the 8-bit text encoder without transformer or VAE
text_encoding_pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder_3=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)
log_memory("After loading the pipeline")
# Load and apply the LoRA weights
text_encoding_pipeline.load_lora_weights(lora_weights_path)
log_memory("After loading LoRA weights")
# Benchmark prompt encoding: 3 warm-up passes, then 10 timed passes
with torch.no_grad():
    prompt = "a photo of a cat"
    for _ in range(3):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = text_encoding_pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    start = time.time()
    for _ in range(10):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = text_encoding_pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    end = time.time()
avg_prompt_encoding_time = (end - start) / 10
# Drop the direct encoder reference; text_encoding_pipeline stays loaded so
# the Gradio handler below can still encode user prompts.
del text_encoder
flush()
# Reload the pipeline in fp16 with no text encoders or tokenizers; this copy
# only runs the denoising transformer and VAE decode.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    tokenizer=None,
    tokenizer_2=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.set_progress_bar_config(disable=True)
log_memory("After reloading the pipeline without text encoder")
# Load and apply the LoRA weights again for the reloaded pipeline
pipeline.load_lora_weights(lora_weights_path)
log_memory("After reloading LoRA weights for inference")
# Benchmark denoising: 3 warm-up passes, then 10 timed passes
for _ in range(3):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
end = time.time()
avg_inference_time = (end - start) / 10
log_memory("After inference")
print(f"Average prompt encoding time: {avg_prompt_encoding_time:.3f} seconds.")
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
print(f"Total time: {(avg_prompt_encoding_time + avg_inference_time):.3f} seconds.")
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated()):.3f} GB"
)
image = pipeline(
    prompt_embeds=prompt_embeds.half(),
    negative_prompt_embeds=negative_prompt_embeds.half(),
    pooled_prompt_embeds=pooled_prompt_embeds.half(),
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
).images[0]
image.save("output_8bit.png")
log_memory("After saving the image")
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 768 # Reduce max image size to fit within memory constraints
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    # The denoising pipeline carries no text encoders, so encode the prompt
    # with the 8-bit text-encoding pipeline kept alive above.
    with torch.no_grad():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = text_encoding_pipeline.encode_prompt(
            prompt=prompt, prompt_2=None, prompt_3=None, negative_prompt=negative_prompt
        )
    image = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    log_memory("After inference")
    return image, "\n".join(memory_log)
examples = [
    ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
    ["An astronaut riding a green horse"],
    ["A delicious ceviche cheesecake slice"],
]
css = """
#col-container {
margin: 0 auto;
max-width: 520px;
}
#memory-log {
white-space: pre-wrap;
background: #f8f9fa;
padding: 10px;
border-radius: 5px;
}
"""
if torch.cuda.is_available():
    power_device = "GPU"
else:
    power_device = "CPU"
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Text-to-Image Gradio Template
Currently running on {power_device}.
""")
        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        memory_log_output = gr.Textbox(label="Memory Log", elem_id="memory-log", lines=10, interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,
                )
        gr.Examples(
            examples=examples,
            inputs=[prompt]
        )
    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, memory_log_output]
    )
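# queue() processes events through a queue (one at a time by default), which
# keeps concurrent requests from hitting the single GPU pipeline at once.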
demo.queue().launch()