# stable-cascade / app.py
# Uploaded by allyourtech ("Upload app.py", commit 1fffadf, verified).
import argparse
import os
import random
from typing import Iterator, List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil

from previewer.modules import Previewer
# Turn off the HuggingFace `tokenizers` library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Markdown/HTML header rendered at the top of the UI.
DESCRIPTION = (
    "# Stable Cascade"
    "\n<p style=\"text-align: center\">"
    "Demo by <a href='https://www.youtube.com/@allyourtechai/' target='_blank'>All Your Tech AI</a>"
    " using <a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Cascade</a>,"
    " a new high resolution text-to-image model by Stability AI"
    " <a href='https://huggingface.co/stabilityai/stable-cascade/blob/main/LICENSE' target='_blank'>non-commercial research license</a>"
    "</p>"
)

# Largest seed the seed slider / randomizer will produce (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Forwarded to gr.Examples(cache_examples=...).
CACHE_EXAMPLES = False
# Upper bound of the width/height sliders; overridable via the environment.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "1536"))
# When True, the prior and decoder networks are wrapped in torch.compile.
USE_TORCH_COMPILE = False
# Opt-in model CPU offloading: set ENABLE_CPU_OFFLOAD=1.
ENABLE_CPU_OFFLOAD = os.environ.get("ENABLE_CPU_OFFLOAD", "") == "1"
# When True, intermediate prior previews are streamed to the UI.
PREVIEW_IMAGES = True
parser = argparse.ArgumentParser(description='Gradio App Control')
parser.add_argument('--share', action='store_true', help='Create a public shareable URL')
parser.add_argument('--inbrowser', action='store_true', help='Automatically launch the application in a browser')
parser.add_argument('--server_port', type=int, default=7860, help='Server port')
args = parser.parse_args()
# Choose the compute device and weight dtype in order of preference: CUDA, then
# Apple MPS, then CPU. bfloat16 is the default. MPS uses float32, presumably
# because bfloat16 support on MPS is limited; confirm before changing.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
elif torch.backends.mps.is_available():
    device, dtype = "mps", torch.float32
else:
    device, dtype = "cpu", torch.bfloat16
print(f"device={device}")
# Load the two Stable Cascade stages (and the optional preview decoder) at import
# time. Pipelines are only loaded on an accelerator (CUDA or MPS). On CPU every
# module-level name that `generate` reads is still defined, as None, so callers get
# a clean "not loaded" condition instead of a NameError.
if device != "cpu":
    # Stage C ("prior") maps the prompt to image embeddings; the decoder turns those
    # embeddings into images.
    prior_pipeline = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", torch_dtype=dtype
    )
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", torch_dtype=dtype
    )
    if ENABLE_CPU_OFFLOAD:
        # Let diffusers move sub-models to the accelerator on demand to save memory.
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)

    if PREVIEW_IMAGES:
        # Small network that renders prior latents as preview images.
        previewer = Previewer()
        # NOTE(review): torch.load unpickles the checkpoint. This is a local file
        # shipped with the app; if it only contains tensors, consider
        # `weights_only=True` to avoid executing arbitrary pickled code.
        previewer_state_dict = torch.load(
            "previewer/previewer_v1_100k.pt", map_location=torch.device('cpu')
        )["state_dict"]
        previewer.load_state_dict(previewer_state_dict)

        def callback_prior(i, t, latents):
            """Render prior latents as preview images.

            Passed to the prior pipeline as `callback`. Presumably `i` is the
            step index and `t` the timestep (both unused here); confirm against
            the pinned diffusers version.

            Returns the list of PIL images produced by `numpy_to_pil`.
            """
            output = previewer(latents)
            # Clamp to [0, 1] and reorder the 4-D previewer output from channels-first
            # (presumably N, C, H, W) to channels-last for numpy_to_pil.
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
            return output

        # Invoke the preview callback on every prior step.
        callback_steps = 1
    else:
        previewer = None
        callback_prior = None
        callback_steps = None
else:
    prior_pipeline = None
    decoder_pipeline = None
    # Keep the preview-related names defined so code referencing them does not
    # fail with NameError on CPU-only machines.
    previewer = None
    callback_prior = None
    callback_steps = None
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    When `randomize_seed` is truthy, a fresh seed is drawn uniformly from
    [0, MAX_SEED]. Otherwise `seed` is returned unchanged.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 30,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 2,
) -> Iterator[PIL.Image.Image]:
    """Run the Stable Cascade prior and decoder, streaming images to the UI.

    This is a generator. When ``PREVIEW_IMAGES`` is enabled, it first yields
    preview images produced by ``callback_prior`` during the prior stage. It
    then yields the final decoded image.

    Args:
        prompt: Text prompt for both stages.
        negative_prompt: Negative prompt for both stages.
        seed: Seed for the ``torch.Generator`` shared by both stages. Cast to
            ``int`` because ``Generator.manual_seed`` rejects floats and the
            value comes from a UI slider.
        width: Output width in pixels, passed to the prior.
        height: Output height in pixels, passed to the prior.
        prior_num_inference_steps: Passed to the prior as
            ``num_inference_steps``. The prior also receives the fixed
            ``DEFAULT_STAGE_C_TIMESTEPS`` schedule, which presumably takes
            precedence; confirm against the pinned diffusers version.
        prior_guidance_scale: Guidance scale for the prior.
        decoder_num_inference_steps: Number of decoder inference steps.
        decoder_guidance_scale: Guidance scale for the decoder.
        num_images_per_prompt: Images generated per prompt. Only the first
            decoded image is yielded.

    Yields:
        PIL images: preview frames (if enabled), then the first decoded image.

    Raises:
        gr.Error: If the pipelines were not loaded (no CUDA or MPS device).
    """
    if prior_pipeline is None or decoder_pipeline is None:
        # Module setup only loads the pipelines on a CUDA or MPS device.
        raise gr.Error("Stable Cascade requires a CUDA or MPS device; the pipelines were not loaded.")

    if previewer is not None:
        # Inference mode, frozen weights, same device/dtype as the pipelines.
        previewer.eval().requires_grad_(False).to(device).to(dtype)

    if not ENABLE_CPU_OFFLOAD:
        # With model CPU offload enabled, diffusers moves each sub-model onto the
        # accelerator on demand. Forcing the whole pipeline onto `device` here
        # would undo that and throw away the memory savings.
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    generator = torch.Generator().manual_seed(int(seed))
    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=int(num_images_per_prompt),
        generator=generator,
        callback=callback_prior,
        callback_steps=callback_steps,
    )

    if PREVIEW_IMAGES:
        # NOTE(review): this assumes the installed diffusers build makes the prior
        # call return an iterator. The iterator should yield callback outputs
        # (lists of PIL images), then the final pipeline output. Stock diffusers
        # returns an output object directly; confirm the pinned version.
        for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
            step_output = next(prior_output)
            if isinstance(step_output, list):
                yield step_output[0]
        # The last item pulled from the iterator is the prior's final output.
        prior_output = step_output

    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    # The UI shows a single image, so only the first result is streamed.
    yield decoder_output[0]
# Sample prompts wired into the UI through gr.Examples (fills the prompt box).
examples = [
    "Food photography of a delicious steak dish",
    "Futuristic soldier with hi-tech armor and weaponry",
    "complex 3d render ultra detailed of a beautiful porcelain human brain",
    "a red heart in the clouds over water",
]
# Main UI. Components appear in the order they are created inside the nested
# layout context managers, so statement order here defines the layout.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    # Shown only when the SHOW_DUPLICATE_BUTTON environment variable is "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        # Single image output: `generate` streams previews and then the final
        # image into this component.
        result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a Negative Prompt",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        # When checked, randomize_seed_fn replaces the seed before each run.
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            # NOTE(review): up to 2 images can be generated, but `generate` only
            # yields the first decoded image to the single `result` component.
            num_images_per_prompt = gr.Slider(
                label="Number of Images",
                minimum=1,
                maximum=2,
                step=1,
                value=1,
            )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=4.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=10,
                maximum=30,
                step=1,
                value=20,
            )
            # minimum == maximum == 0, so decoder guidance is effectively fixed at 0.0.
            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=0,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=12,
                step=1,
                value=10,
            )
    # Example prompts. With CACHE_EXAMPLES False, `generate` is presumably not
    # run ahead of time; clicking an example fills the prompt box.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
    # Order must match the positional parameters of `generate`.
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        # prior_timesteps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        # decoder_timesteps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]
    # Submitting either text box or clicking Run first (optionally) randomizes
    # the seed: unqueued, no API endpoint. It then runs `generate`, which is
    # exposed under the API name "run".
    gr.on(
        triggers=[prompt.submit, negative_prompt.submit, run_button.click],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )
# Outer application: renders `demo` inside a single "App" tab, styled by style.css.
# The combined `with` enters gr.Blocks first and then creates/enters gr.Tab,
# exactly like two nested `with` statements.
with gr.Blocks(css="style.css") as demo_with_history, gr.Tab("App"):
    demo.render()
if __name__ == "__main__":
    # Start the server using the command-line flags held in `args`.
    demo_with_history.launch(
        inbrowser=args.inbrowser,
        share=args.share,
        server_port=args.server_port,
    )