allow inference of hyperpcm on cpu
Browse files
app.py
CHANGED
|
@@ -143,7 +143,7 @@ def retrieval():
|
|
| 143 |
memory = dataset
|
| 144 |
model = HyperPCM(memory=memory).to(device)
|
| 145 |
model = torch.nn.DataParallel(model)
|
| 146 |
-
model.load_state_dict(torch.load(checkpoint_path))
|
| 147 |
model.eval()
|
| 148 |
|
| 149 |
with torch.set_grad_enabled(False):
|
|
|
|
| 143 |
memory = dataset
|
| 144 |
model = HyperPCM(memory=memory).to(device)
|
| 145 |
model = torch.nn.DataParallel(model)
|
| 146 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
|
| 147 |
model.eval()
|
| 148 |
|
| 149 |
with torch.set_grad_enabled(False):
|