Spaces:
Build error
Build error
Commit ·
56fa2d2
1
Parent(s): 8f4bc3f
update
Browse files
app.py
CHANGED
|
@@ -43,8 +43,12 @@ if device != "cpu":
|
|
| 43 |
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
|
| 44 |
|
| 45 |
if ENABLE_CPU_OFFLOAD:
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
else:
|
| 49 |
prior_pipeline.to(device)
|
| 50 |
decoder_pipeline.to(device)
|
|
@@ -101,10 +105,10 @@ def generate(
|
|
| 101 |
print("")
|
| 102 |
|
| 103 |
#previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
generator = torch.Generator().manual_seed(seed)
|
| 109 |
print("prior_num_inference_steps: ", prior_num_inference_steps)
|
| 110 |
prior_output = prior_pipeline(
|
|
|
|
| 43 |
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
|
| 44 |
|
| 45 |
if ENABLE_CPU_OFFLOAD:
|
| 46 |
+
if device == "mps":
|
| 47 |
+
prior_pipeline.enable_attention_slicing()
|
| 48 |
+
decoder_pipeline.enable_attention_slicing()
|
| 49 |
+
else:
|
| 50 |
+
prior_pipeline.enable_model_cpu_offload()
|
| 51 |
+
decoder_pipeline.enable_model_cpu_offload()
|
| 52 |
else:
|
| 53 |
prior_pipeline.to(device)
|
| 54 |
decoder_pipeline.to(device)
|
|
|
|
| 105 |
print("")
|
| 106 |
|
| 107 |
#previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 108 |
+
if device != "cpu":
|
| 109 |
+
prior_pipeline.to(device)
|
| 110 |
+
decoder_pipeline.to(device)
|
| 111 |
+
|
| 112 |
generator = torch.Generator().manual_seed(seed)
|
| 113 |
print("prior_num_inference_steps: ", prior_num_inference_steps)
|
| 114 |
prior_output = prior_pipeline(
|