Spaces:
Runtime error
Runtime error
Update Experiments/Resnet50_classification.py
Browse files
Experiments/Resnet50_classification.py
CHANGED
|
@@ -51,11 +51,11 @@ class HiddenLayer(nn.Module):
|
|
| 51 |
|
| 52 |
def predict(features_path,image):
|
| 53 |
batch1 = unpickle(r"Model/data/data_batch_1")
|
| 54 |
-
batch2 = unpickle(r"Model
|
| 55 |
-
batch3 = unpickle(r"Model
|
| 56 |
-
batch4 = unpickle(r"Model
|
| 57 |
-
batch5 = unpickle(r"Model
|
| 58 |
-
test_batch = unpickle(r"Model
|
| 59 |
train_batch = [batch1,batch2,batch3,batch4,batch5]
|
| 60 |
train_y = []
|
| 61 |
train_x = []
|
|
@@ -102,7 +102,7 @@ def predict(features_path,image):
|
|
| 102 |
|
| 103 |
|
| 104 |
|
| 105 |
-
def retrieve(image,k,feature_path=r"Model
|
| 106 |
print(image.shape)
|
| 107 |
test_label,z,features,class_images_dict,train_x = predict(feature_path,image)
|
| 108 |
class_indices = class_images_dict[test_label.item()]
|
|
|
|
| 51 |
|
| 52 |
def predict(features_path,image):
|
| 53 |
batch1 = unpickle(r"Model/data/data_batch_1")
|
| 54 |
+
batch2 = unpickle(r"Model/data/data_batch_2")
|
| 55 |
+
batch3 = unpickle(r"Model/data/data_batch_3")
|
| 56 |
+
batch4 = unpickle(r"Model/data/data_batch_4")
|
| 57 |
+
batch5 = unpickle(r"Model/data/data_batch_5")
|
| 58 |
+
test_batch = unpickle(r"Model/data/test_batch")
|
| 59 |
train_batch = [batch1,batch2,batch3,batch4,batch5]
|
| 60 |
train_y = []
|
| 61 |
train_x = []
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
|
| 105 |
+
def retrieve(image,k,feature_path=r"Model/Resnet50_train_features.pt"):
|
| 106 |
print(image.shape)
|
| 107 |
test_label,z,features,class_images_dict,train_x = predict(feature_path,image)
|
| 108 |
class_indices = class_images_dict[test_label.item()]
|