Update openshape/demo/caption.py
Browse files
openshape/demo/caption.py
CHANGED
|
@@ -158,7 +158,7 @@ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
|
| 158 |
prefix_length = 10
|
| 159 |
model = ClipCaptionModel(prefix_length)
|
| 160 |
# print(model.gpt_embedding_size)
|
| 161 |
-
model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'
|
| 162 |
model.eval()
|
| 163 |
if torch.cuda.is_available():
|
| 164 |
model = model.cuda()
|
|
|
|
| 158 |
prefix_length = 10
|
| 159 |
model = ClipCaptionModel(prefix_length)
|
| 160 |
# print(model.gpt_embedding_size)
|
| 161 |
+
model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'), map_location='cpu'))
|
| 162 |
model.eval()
|
| 163 |
if torch.cuda.is_available():
|
| 164 |
model = model.cuda()
|