Spaces:
Runtime error
Runtime error
Update predict.py
Browse files- predict.py +1 -1
predict.py
CHANGED
|
@@ -10,7 +10,7 @@ def predict_one_image(path) :
|
|
| 10 |
image = read_image(path)
|
| 11 |
image = get_valid_augs()(image=image)['image']
|
| 12 |
image = torch.tensor(image,dtype=torch.float)
|
| 13 |
-
image = image.reshape((1,3,
|
| 14 |
model = CustomModel()
|
| 15 |
#loading ckpt
|
| 16 |
model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))
|
|
|
|
| 10 |
image = read_image(path)
|
| 11 |
image = get_valid_augs()(image=image)['image']
|
| 12 |
image = torch.tensor(image,dtype=torch.float)
|
| 13 |
+
image = image.reshape((1,3,512,512))
|
| 14 |
model = CustomModel()
|
| 15 |
#loading ckpt
|
| 16 |
model.load_state_dict(torch.load(CKPT,map_location=torch.device('cpu')))
|