Spaces:
Sleeping
Sleeping
Pre-load all models in RAM
Browse files
app.py
CHANGED
|
@@ -13,6 +13,8 @@ class Model:
|
|
| 13 |
self.name = name
|
| 14 |
self.path = path
|
| 15 |
self.prefix = prefix
|
|
|
|
|
|
|
| 16 |
|
| 17 |
models = [
|
| 18 |
Model("Custom model", "", ""),
|
|
@@ -27,17 +29,24 @@ models = [
|
|
| 27 |
Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
|
| 28 |
Model("Robo Diffusion", "nousr/robo-diffusion", ""),
|
| 29 |
Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style "),
|
| 30 |
-
Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy
|
| 31 |
]
|
| 32 |
|
| 33 |
last_mode = "txt2img"
|
| 34 |
current_model = models[1]
|
| 35 |
current_model_path = current_model.path
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
| 43 |
|
|
@@ -69,7 +78,12 @@ def txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, g
|
|
| 69 |
if model_path != current_model_path or last_mode != "txt2img":
|
| 70 |
current_model_path = model_path
|
| 71 |
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if torch.cuda.is_available():
|
| 74 |
pipe = pipe.to("cuda")
|
| 75 |
last_mode = "txt2img"
|
|
@@ -95,7 +109,11 @@ def img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, w
|
|
| 95 |
if model_path != current_model_path or last_mode != "img2img":
|
| 96 |
current_model_path = model_path
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
if torch.cuda.is_available():
|
| 101 |
pipe = pipe.to("cuda")
|
|
|
|
| 13 |
self.name = name
|
| 14 |
self.path = path
|
| 15 |
self.prefix = prefix
|
| 16 |
+
self.pipe_t2i = None
|
| 17 |
+
self.pipe_i2i = None
|
| 18 |
|
| 19 |
models = [
|
| 20 |
Model("Custom model", "", ""),
|
|
|
|
| 29 |
Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""),
|
| 30 |
Model("Robo Diffusion", "nousr/robo-diffusion", ""),
|
| 31 |
Model("Cyberpunk Anime", "DGSpitzer/Cyberpunk-Anime-Diffusion", "dgs illustration style "),
|
| 32 |
+
Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy")
|
| 33 |
]
|
| 34 |
|
| 35 |
last_mode = "txt2img"
|
| 36 |
current_model = models[1]
|
| 37 |
current_model_path = current_model.path
|
| 38 |
+
|
| 39 |
+
if is_colab:
|
| 40 |
+
pipe = StableDiffusionPipeline.from_pretrained(current_model.path, torch_dtype=torch.float16)
|
| 41 |
+
if torch.cuda.is_available():
|
| 42 |
+
pipe = pipe.to("cuda")
|
| 43 |
+
|
| 44 |
+
else: # download all models
|
| 45 |
+
vae = AutoencoderKL.from_pretrained(current_model, subfolder="vae", torch_dtype=torch.float16)
|
| 46 |
+
for model in models[1:]:
|
| 47 |
+
unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet", torch_dtype=torch.float16)
|
| 48 |
+
model.pipe_t2i = StableDiffusionPipeline.from_pretrained(model, unet=unet, vae=vae, torch_dtype=torch.float16)
|
| 49 |
+
model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained(model, unet=unet, vae=vae, torch_dtype=torch.float16)
|
| 50 |
|
| 51 |
device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
| 52 |
|
|
|
|
| 78 |
if model_path != current_model_path or last_mode != "txt2img":
|
| 79 |
current_model_path = model_path
|
| 80 |
|
| 81 |
+
if is_colab:
|
| 82 |
+
pipe = StableDiffusionPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
|
| 83 |
+
else:
|
| 84 |
+
pipe = pipe.to("cpu")
|
| 85 |
+
pipe = current_model.pipe_t2i
|
| 86 |
+
|
| 87 |
if torch.cuda.is_available():
|
| 88 |
pipe = pipe.to("cuda")
|
| 89 |
last_mode = "txt2img"
|
|
|
|
| 109 |
if model_path != current_model_path or last_mode != "img2img":
|
| 110 |
current_model_path = model_path
|
| 111 |
|
| 112 |
+
if is_colab:
|
| 113 |
+
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(current_model_path, torch_dtype=torch.float16)
|
| 114 |
+
else:
|
| 115 |
+
pipe = pipe.to("cpu")
|
| 116 |
+
pipe = current_model.pipe_t2i
|
| 117 |
|
| 118 |
if torch.cuda.is_available():
|
| 119 |
pipe = pipe.to("cuda")
|