nroggendorff commited on
Commit
5312c2f
·
verified ·
1 Parent(s): 87f084d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -6,29 +6,29 @@ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
6
  prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.float16)
7
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
8
 
9
- prior_pipeline.enable_model_cpu_offload()
10
- decoder_pipeline.enable_model_cpu_offload()
11
-
12
  @spaces.GPU
13
  def generate(prompt, negative_prompt, width, height, steps):
 
14
  prior_output = prior_pipeline(
15
  prompt=prompt,
 
 
16
  height=height,
17
  guidance_scale=4.0,
18
  num_images_per_prompt=1,
19
- width=width,
20
- num_inference_steps=steps,
21
- negative_prompt=negative_prompt
22
  )
 
 
23
  decoder_output = decoder_pipeline(
24
  image_embeddings=prior_output.image_embeddings.to(torch.float16),
25
  prompt=prompt,
26
  guidance_scale=0.0,
27
  output_type="pil",
28
- num_inference_steps=steps,
29
  negative_prompt=negative_prompt
30
  ).images[0]
31
- return torch.clamp(decoder_output, 0, 1)
32
 
33
  with gr.Blocks() as demo:
34
  with gr.Row():
 
6
  prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.float16)
7
  decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16)
8
 
 
 
 
9
  @spaces.GPU
10
  def generate(prompt, negative_prompt, width, height, steps):
11
+ prior_pipeline.enable_model_cpu_offload()
12
  prior_output = prior_pipeline(
13
  prompt=prompt,
14
+ negative_prompt=negative_prompt,
15
+ width=width,
16
  height=height,
17
  guidance_scale=4.0,
18
  num_images_per_prompt=1,
19
+ num_inference_steps=steps
 
 
20
  )
21
+
22
+ decoder_pipeline.enable_model_cpu_offload()
23
  decoder_output = decoder_pipeline(
24
  image_embeddings=prior_output.image_embeddings.to(torch.float16),
25
  prompt=prompt,
26
  guidance_scale=0.0,
27
  output_type="pil",
28
+ num_inference_steps=10,
29
  negative_prompt=negative_prompt
30
  ).images[0]
31
+ return decoder_output
32
 
33
  with gr.Blocks() as demo:
34
  with gr.Row():