Spaces:
Runtime error
Runtime error
cuda
Browse files
app.py
CHANGED
|
@@ -10,12 +10,12 @@ from decomp_diffusion.diffusion.respace import SpacedDiffusion
|
|
| 10 |
from decomp_diffusion.gen_image import *
|
| 11 |
from download import download_model
|
| 12 |
|
|
|
|
|
|
|
| 13 |
# fix randomness
|
| 14 |
th.manual_seed(0)
|
| 15 |
np.random.seed(0)
|
| 16 |
|
| 17 |
-
import gradio as gr
|
| 18 |
-
|
| 19 |
|
| 20 |
def get_pil_im(im, resolution=64):
|
| 21 |
im = imresize(im, (resolution, resolution))[:, :, :3]
|
|
@@ -112,7 +112,7 @@ model_kwargs.update(dict(
|
|
| 112 |
clevr_model = create_diffusion_model(**model_kwargs)
|
| 113 |
clevr_model.eval()
|
| 114 |
|
| 115 |
-
device = 'cuda'
|
| 116 |
clevr_model.to(device)
|
| 117 |
|
| 118 |
print(f'loading from {ckpt_path}')
|
|
|
|
| 10 |
from decomp_diffusion.gen_image import *
|
| 11 |
from download import download_model
|
| 12 |
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
# fix randomness
|
| 16 |
th.manual_seed(0)
|
| 17 |
np.random.seed(0)
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def get_pil_im(im, resolution=64):
|
| 21 |
im = imresize(im, (resolution, resolution))[:, :, :3]
|
|
|
|
| 112 |
clevr_model = create_diffusion_model(**model_kwargs)
|
| 113 |
clevr_model.eval()
|
| 114 |
|
| 115 |
+
device = 'cuda' if th.cuda.is_available() else 'cpu'
|
| 116 |
clevr_model.to(device)
|
| 117 |
|
| 118 |
print(f'loading from {ckpt_path}')
|