Spaces:
Build error
Build error
modify device to cpu
Browse files- stable_diffusion.py +3 -1
stable_diffusion.py
CHANGED
|
@@ -18,7 +18,7 @@ from torch import Tensor
|
|
| 18 |
from torch.nn import functional as F
|
| 19 |
from torch.nn.parameter import Parameter
|
| 20 |
|
| 21 |
-
device = "
|
| 22 |
|
| 23 |
def apply_seq(seqs, x):
|
| 24 |
for seq in seqs:
|
|
@@ -830,6 +830,7 @@ class Args(object):
|
|
| 830 |
self.input_mask = input_mask
|
| 831 |
self.input_image_strength = input_image_strength
|
| 832 |
self.unphrase = unphrase
|
|
|
|
| 833 |
|
| 834 |
from PIL import Image
|
| 835 |
|
|
@@ -999,6 +1000,7 @@ class Generate2img(Module):
|
|
| 999 |
@lru_cache()
|
| 1000 |
def generate2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device, input_image, input_mask, input_image_strength=0.5, unphrase=""):
|
| 1001 |
try:
|
|
|
|
| 1002 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file, input_image, input_mask, input_image_strength, unphrase)
|
| 1003 |
im = Generate2img.instance(args).forward(args.phrase)
|
| 1004 |
im = Generate2img.instance(args).decode_latent2img(im)
|
|
|
|
| 18 |
from torch.nn import functional as F
|
| 19 |
from torch.nn.parameter import Parameter
|
| 20 |
|
| 21 |
+
device = "cpu"
|
| 22 |
|
| 23 |
def apply_seq(seqs, x):
|
| 24 |
for seq in seqs:
|
|
|
|
| 830 |
self.input_mask = input_mask
|
| 831 |
self.input_image_strength = input_image_strength
|
| 832 |
self.unphrase = unphrase
|
| 833 |
+
device = self.device
|
| 834 |
|
| 835 |
from PIL import Image
|
| 836 |
|
|
|
|
| 1000 |
@lru_cache()
|
| 1001 |
def generate2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device, input_image, input_mask, input_image_strength=0.5, unphrase=""):
|
| 1002 |
try:
|
| 1003 |
+
|
| 1004 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file, input_image, input_mask, input_image_strength, unphrase)
|
| 1005 |
im = Generate2img.instance(args).forward(args.phrase)
|
| 1006 |
im = Generate2img.instance(args).decode_latent2img(im)
|