Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9d38503
1
Parent(s):
6138235
switch model
Browse files
- app.py +5 -14
- pipeline/util.py +1 -1
app.py
CHANGED
|
@@ -24,27 +24,18 @@ MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
|
| 24 |
def load_model(model_id):
|
| 25 |
global pipe, last_loaded_model
|
| 26 |
|
| 27 |
-
if model_id != last_loaded_model:
|
| 28 |
-
|
| 29 |
# Initialize the models and pipeline
|
| 30 |
controlnet = ControlNetUnionModel.from_pretrained(
|
| 31 |
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
| 32 |
-
)
|
| 33 |
-
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
| 34 |
-
if pipe is not None:
|
| 35 |
-
optionally_disable_offloading(pipe)
|
| 36 |
-
torch_gc()
|
| 37 |
pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
|
| 38 |
MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
| 39 |
-
)
|
| 40 |
-
pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
| 41 |
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
| 42 |
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
| 43 |
-
|
| 44 |
-
unet = UNet2DConditionModel.from_pretrained(MODELS[model_id], subfolder="unet", variant="fp16", use_safetensors=True)
|
| 45 |
-
quantize_8bit(unet) # << Enable this if you have limited VRAM
|
| 46 |
-
pipe.unet = unet
|
| 47 |
-
|
| 48 |
last_loaded_model = model_id
|
| 49 |
|
| 50 |
load_model("RealVisXL 5 Lightning")
|
|
|
|
| 24 |
def load_model(model_id):
|
| 25 |
global pipe, last_loaded_model
|
| 26 |
|
| 27 |
+
if model_id != last_loaded_model:
|
|
|
|
| 28 |
# Initialize the models and pipeline
|
| 29 |
controlnet = ControlNetUnionModel.from_pretrained(
|
| 30 |
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
| 31 |
+
).to(device)
|
| 32 |
+
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
|
|
|
|
|
|
|
|
|
|
| 33 |
pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
|
| 34 |
MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
| 35 |
+
).to(device)
|
| 36 |
+
#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
| 37 |
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
| 38 |
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
last_loaded_model = model_id
|
| 40 |
|
| 41 |
load_model("RealVisXL 5 Lightning")
|
pipeline/util.py
CHANGED
|
@@ -213,7 +213,7 @@ def torch_gc():
|
|
| 213 |
if torch.cuda.is_available():
|
| 214 |
with torch.cuda.device("cuda"):
|
| 215 |
torch.cuda.empty_cache()
|
| 216 |
-
|
| 217 |
|
| 218 |
gc.collect()
|
| 219 |
|
|
|
|
| 213 |
if torch.cuda.is_available():
|
| 214 |
with torch.cuda.device("cuda"):
|
| 215 |
torch.cuda.empty_cache()
|
| 216 |
+
torch.cuda.ipc_collect()
|
| 217 |
|
| 218 |
gc.collect()
|
| 219 |
|