Spaces:
Runtime error
Runtime error
Update utils/common_viz.py
Browse files- utils/common_viz.py +6 -6
utils/common_viz.py
CHANGED
|
@@ -106,23 +106,23 @@ def get_batch(
|
|
| 106 |
|
| 107 |
def init(
|
| 108 |
config_name: str,
|
| 109 |
-
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset
|
| 110 |
with initialize(version_base="1.3", config_path="../configs"):
|
| 111 |
config = compose(config_name=config_name)
|
| 112 |
|
| 113 |
OmegaConf.register_new_resolver("eval", eval)
|
| 114 |
|
| 115 |
# Initialize model
|
| 116 |
-
device = torch.device(config.compnode.device)
|
| 117 |
diffuser = instantiate(config.diffuser)
|
| 118 |
-
state_dict = torch.load(config.checkpoint_path, map_location=
|
| 119 |
state_dict["ema.initted"] = diffuser.ema.initted
|
| 120 |
state_dict["ema.step"] = diffuser.ema.step
|
| 121 |
diffuser.load_state_dict(state_dict, strict=False)
|
| 122 |
-
diffuser.to(
|
| 123 |
|
| 124 |
# Initialize CLIP model
|
| 125 |
-
clip_model = load_clip_model("ViT-B/32",
|
| 126 |
|
| 127 |
# Initialize dataset
|
| 128 |
config.dataset.char.load_vertices = True
|
|
@@ -133,4 +133,4 @@ def init(
|
|
| 133 |
diffuser.get_matrix = dataset.get_matrix
|
| 134 |
diffuser.v_get_matrix = dataset.get_matrix
|
| 135 |
|
| 136 |
-
return diffuser, clip_model, dataset, device
|
|
|
|
| 106 |
|
| 107 |
def init(
|
| 108 |
config_name: str,
|
| 109 |
+
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset]:
|
| 110 |
with initialize(version_base="1.3", config_path="../configs"):
|
| 111 |
config = compose(config_name=config_name)
|
| 112 |
|
| 113 |
OmegaConf.register_new_resolver("eval", eval)
|
| 114 |
|
| 115 |
# Initialize model
|
| 116 |
+
# device = torch.device(config.compnode.device)
|
| 117 |
diffuser = instantiate(config.diffuser)
|
| 118 |
+
state_dict = torch.load(config.checkpoint_path, map_location="cpu")["state_dict"]
|
| 119 |
state_dict["ema.initted"] = diffuser.ema.initted
|
| 120 |
state_dict["ema.step"] = diffuser.ema.step
|
| 121 |
diffuser.load_state_dict(state_dict, strict=False)
|
| 122 |
+
diffuser.to("cpu").eval()
|
| 123 |
|
| 124 |
# Initialize CLIP model
|
| 125 |
+
clip_model = load_clip_model("ViT-B/32", "cpu")
|
| 126 |
|
| 127 |
# Initialize dataset
|
| 128 |
config.dataset.char.load_vertices = True
|
|
|
|
| 133 |
diffuser.get_matrix = dataset.get_matrix
|
| 134 |
diffuser.v_get_matrix = dataset.get_matrix
|
| 135 |
|
| 136 |
+
return diffuser, clip_model, dataset, config.compnode.device
|