Update src/pipeline.py
Browse files- src/pipeline.py +3 -1
src/pipeline.py
CHANGED
|
@@ -19,11 +19,13 @@ def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
|
| 19 |
return callback_kwargs
|
| 20 |
|
| 21 |
def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
|
|
|
|
| 22 |
if not pipeline:
|
|
|
|
| 23 |
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 24 |
"stablediffusionapi/newdream-sdxl-20",
|
| 25 |
torch_dtype=torch.float16,
|
| 26 |
-
).to(
|
| 27 |
|
| 28 |
pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
|
| 29 |
|
|
|
|
| 19 |
return callback_kwargs
|
| 20 |
|
| 21 |
def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
|
| 22 |
+
device = 'cpu'
|
| 23 |
if not pipeline:
|
| 24 |
+
device = "cuda"
|
| 25 |
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 26 |
"stablediffusionapi/newdream-sdxl-20",
|
| 27 |
torch_dtype=torch.float16,
|
| 28 |
+
).to(device)
|
| 29 |
|
| 30 |
pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
|
| 31 |
|