Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,10 +3,14 @@ import torch
|
|
| 3 |
from annotator.util import resize_image, HWC3
|
| 4 |
from cldm.model import create_model, load_state_dict
|
| 5 |
from cldm.ddim_hacked import DDIMSampler
|
|
|
|
| 6 |
|
| 7 |
# Initialize the model and other components
|
|
|
|
| 8 |
model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
model = model.cuda()
|
| 11 |
ddim_sampler = DDIMSampler(model)
|
| 12 |
|
|
|
|
| 3 |
from annotator.util import resize_image, HWC3
|
| 4 |
from cldm.model import create_model, load_state_dict
|
| 5 |
from cldm.ddim_hacked import DDIMSampler
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
|
| 8 |
# Initialize the model and other components
|
| 9 |
+
# config = "./models/cldm_v21_512_latctrl_coltrans.yaml'"
|
| 10 |
model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
|
| 11 |
+
ckpt = hf_hub_download(repo_id="xywwww/scene_diffusion", filename="checkpoints/epoch=25-step=112553.ckpt")
|
| 12 |
+
# model.load_state_dict(load_state_dict('xywwww/scene_diffusion/checkpoints/epoch=25-step=112553.ckpt', location='cuda'), strict=False)
|
| 13 |
+
model = load_model_checkpoint(model, ckpt)
|
| 14 |
model = model.cuda()
|
| 15 |
ddim_sampler = DDIMSampler(model)
|
| 16 |
|