Spaces: Running on Zero
Commit · 0bb8ff5
1 Parent(s): 3a104e9
change device
app.py CHANGED

@@ -12,6 +12,7 @@ from src.utils import (
 )
 from diffusers import StableDiffusionXLPipeline
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def get_model_param_summary(model, verbose=False):
     params_dict = dict()
@@ -29,7 +30,6 @@ def get_model_param_summary(model, verbose=False):
 @dataclass
 class GradioArgs:
     ckpt: str = "./mask/ff.pt"
-    device: str = "cuda:0"
     seed: list = None
     prompt: str = None
     mix_precision: str = "bf16"
@@ -95,9 +95,8 @@ def binary_mask_eval(args):
     # load sdxl model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(args.device)
+    ).to(device)
 
-    device = args.device
     torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
     mask_pipe, hookers = create_pipeline(
         pipe,
@@ -132,7 +131,7 @@ def binary_mask_eval(args):
     # reload the original model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-    ).to(args.device)
+    ).to(device)
 
     # get model param summary
     print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
@@ -143,9 +142,9 @@ def binary_mask_eval(args):
 @spaces.GPU
 def generate_images(prompt, seed, steps, pipe, pruned_pipe):
     # Run the model and return images directly
-    g_cpu = torch.Generator(
+    g_cpu = torch.Generator(device).manual_seed(seed)
     original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
-    g_cpu = torch.Generator(
+    g_cpu = torch.Generator(device).manual_seed(seed)
     ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
     return original_image, ecodiff_image
 
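What the commit does: instead of carrying a hard-coded `device: str = "cuda:0"` field on `GradioArgs` and copying it into a local via `device = args.device`, the app now resolves the device once at module level and reuses it at every call site. A minimal sketch of that pattern in isolation, assuming only `torch` and `diffusers` are installed (the model id and dtype match the diff; everything else is illustrative):

import torch
from diffusers import StableDiffusionXLPipeline

# Resolve the device once at import time. On a machine with no visible GPU
# (e.g., the CPU phase of a ZeroGPU Space), .to("cuda:0") would raise, so
# the CPU fallback keeps startup safe.
device = "cuda" if torch.cuda.is_available() else "cpu"

# bfloat16 matches the dtype used in the diff and halves memory vs. float32.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to(device)

A single module-level `device` also means the `@spaces.GPU`-decorated `generate_images` no longer needs access to a `GradioArgs` instance just to build its generators.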
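The generator lines are what keep the side-by-side comparison fair: a fresh generator seeded with the same value before each pipeline call gives the original and the pruned model identical starting noise, so any visible difference comes from pruning rather than sampling randomness. A small self-contained sketch of the effect (tensor shape and seed are illustrative, not from the app):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 42  # illustrative value; the app takes it from the UI

# The first sampling run advances the generator's state, so reusing one
# generator for both runs would give the second run different noise.
g = torch.Generator(device).manual_seed(seed)
noise_a = torch.randn(4, 64, 64, generator=g, device=device)

g = torch.Generator(device).manual_seed(seed)  # re-seed: state is reset
noise_b = torch.randn(4, 64, 64, generator=g, device=device)

assert torch.equal(noise_a, noise_b)  # identical initial noise

Note that `torch.randn` requires the generator's device to match the `device` argument, which the shared module-level `device` guarantees here.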