import gradio as gr
import numpy as np
import random
import time
import os
import gc

import psutil
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel
from huggingface_hub import login


def flush():
    """Free Python- and CUDA-side memory between pipeline reloads."""
    gc.collect()
    torch.cuda.empty_cache()


def bytes_to_giga_bytes(num_bytes):
    return num_bytes / 1024 / 1024 / 1024


def get_memory_usage():
    """Return this process's resident set size as a human-readable string."""
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return f"{mem_info.rss / (1024 ** 2):.2f} MB"


def log_memory(step):
    memory_log.append(f"{step}: {get_memory_usage()}")


device = "cuda" if torch.cuda.is_available() else "cpu"

# Read the Hugging Face token from the environment and log in
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
login(token=HUGGINGFACE_TOKEN)

# Base model repository and path to the LoRA safetensors weights
base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
lora_weights_path = "./pytorch_lora_weights.safetensors"

memory_log = []
log_memory("Before loading the model")

# Load the T5 text encoder in 8-bit (requires bitsandbytes)
text_encoder = T5EncoderModel.from_pretrained(
    base_model_repo,
    subfolder="text_encoder_3",
    load_in_8bit=True,
    device_map="auto",
)

# Load an encoder-only pipeline (no transformer/VAE) with the 8-bit text encoder
pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder_3=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)
log_memory("After loading the pipeline")

# Load and apply the LoRA weights
pipeline.load_lora_weights(lora_weights_path)
log_memory("After loading LoRA weights")

prompt = "a photo of a cat"

with torch.no_grad():
    # Warm-up runs so the timed loop is not skewed by lazy initialization
    for _ in range(3):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)

    # Timed prompt-encoding runs
    start = time.time()
    for _ in range(10):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    end = time.time()
    avg_prompt_encoding_time = (end - start) / 10

# Drop the text encoder and the encoder-only pipeline to free memory
del text_encoder
del pipeline
flush()

# Reload the pipeline without any text encoders or tokenizers; inference
# below consumes the precomputed prompt embeddings instead
pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    tokenizer=None,
    tokenizer_2=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.set_progress_bar_config(disable=True)
log_memory("After reloading the pipeline without text encoder")

# Load and apply the LoRA weights again for the reloaded pipeline
pipeline.load_lora_weights(lora_weights_path)
log_memory("After reloading LoRA weights for inference")

# Warm-up inference runs
for _ in range(3):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )

# Timed inference runs
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
end = time.time()
avg_inference_time = (end - start) / 10
log_memory("After inference")

print(f"Average prompt encoding time: {avg_prompt_encoding_time:.3f} seconds.")
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
print(f"Total time: {(avg_prompt_encoding_time + avg_inference_time):.3f} seconds.")
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated()):.2f} GB"
)

image = pipeline(
    prompt_embeds=prompt_embeds.half(),
    negative_prompt_embeds=negative_prompt_embeds.half(),
    pooled_prompt_embeds=pooled_prompt_embeds.half(),
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
).images[0]
image.save("output_8bit.png")
log_memory("After saving the image")

# The inference pipeline above has no text encoders, so the Gradio handler
# below cannot pass raw prompt strings to it. Reload an encoder-only
# pipeline (8-bit T5 again) that is used purely for prompt encoding in infer().
text_encoder = T5EncoderModel.from_pretrained(
    base_model_repo,
    subfolder="text_encoder_3",
    load_in_8bit=True,
    device_map="auto",
)
encoder_pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder_3=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)
log_memory("After reloading the prompt-encoding pipeline")

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 768  # Reduce the max image size to fit within memory constraints


def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Encode the prompts with the encoder-only pipeline, then hand the
    # embeddings to the inference pipeline (which has no text encoders)
    with torch.no_grad():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = encoder_pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            prompt_3=None,
            negative_prompt=negative_prompt,
        )

    image = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    log_memory("After inference")
    return image, "\n".join(memory_log)


examples = [
    ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
    ["An astronaut riding a green horse"],
    ["A delicious ceviche cheesecake slice"],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
#memory-log {
    white-space: pre-wrap;
    background: #f8f9fa;
    padding: 10px;
    border-radius: 5px;
}
"""

power_device = "GPU" if torch.cuda.is_available() else "CPU"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Text-to-Image Gradio Template
        Currently running on {power_device}.
        """)

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)
        memory_log_output = gr.Textbox(label="Memory Log", elem_id="memory-log", lines=10, interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,
                )

        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )

    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, memory_log_output],
    )

demo.queue().launch()
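
# Usage sketch (assumptions not in the original: the file is saved as app.py
# and the LoRA file pytorch_lora_weights.safetensors sits next to it):
#
#   export HUGGINGFACE_TOKEN=hf_...   # a Hugging Face token with access to the model
#   python app.py
#
# The benchmark section prints timing and peak-memory stats to stdout and
# writes output_8bit.png before the Gradio UI starts (by default on
# http://127.0.0.1:7860).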