Update codebase/inference/inference.py
Browse files
codebase/inference/inference.py
CHANGED
|
@@ -24,12 +24,12 @@ def build_model(model_name, ckpt_path, device):
|
|
| 24 |
if model_name == "ViT-B-32":
|
| 25 |
model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
|
| 26 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 27 |
-
msg = model.load_state_dict(checkpoint
|
| 28 |
|
| 29 |
elif model_name == "ViT-H-14":
|
| 30 |
model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="laion2b_s32b_b79k")
|
| 31 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 32 |
-
msg = model.load_state_dict(checkpoint
|
| 33 |
|
| 34 |
print(msg)
|
| 35 |
model = model.to(device)
|
|
|
|
| 24 |
if model_name == "ViT-B-32":
|
| 25 |
model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
|
| 26 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 27 |
+
msg = model.load_state_dict(checkpoint)
|
| 28 |
|
| 29 |
elif model_name == "ViT-H-14":
|
| 30 |
model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="laion2b_s32b_b79k")
|
| 31 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
| 32 |
+
msg = model.load_state_dict(checkpoint)
|
| 33 |
|
| 34 |
print(msg)
|
| 35 |
model = model.to(device)
|