Spaces:
Build error
Build error
Commit ·
45c81bf
1
Parent(s): 08df26a
Fixed normalization issue for EUROSAT_CUSTOM_MODEL
Browse files
app.py
CHANGED
|
@@ -38,11 +38,9 @@ def Pn(m, x):
|
|
| 38 |
def L(a,b,m,x):
|
| 39 |
return np.sqrt((2*m+1)/(b-a))*Pn(m, 2*(x-b)/(b-a)+1)
|
| 40 |
|
| 41 |
-
eurosat_transform =
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
v2.ToDtype(torch.float32),
|
| 45 |
-
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize images
|
| 46 |
])
|
| 47 |
|
| 48 |
class CNN(nn.Module):
|
|
@@ -81,7 +79,7 @@ def run_lime(input_image,
|
|
| 81 |
|
| 82 |
model, weights, preprocess, names = fetch_model(model_name)
|
| 83 |
|
| 84 |
-
input_image_processed = preprocess(torch.from_numpy(input_image.transpose(2,0,1))).unsqueeze(0)
|
| 85 |
logits = model(input_image_processed)
|
| 86 |
probs = F.softmax(logits, dim=1)
|
| 87 |
|
|
@@ -94,12 +92,12 @@ def run_lime(input_image,
|
|
| 94 |
def classifier_fn(images):
|
| 95 |
print('classifier_fn', type(images), images.shape)
|
| 96 |
|
| 97 |
-
zz = preprocess(torch.from_numpy(images[0].transpose(2,0,1)))
|
| 98 |
c, w, h = zz.shape
|
| 99 |
batch = torch.zeros(batch_size, c, w, h)
|
| 100 |
print('len(images)', len(images))
|
| 101 |
for i in range(batch_size):
|
| 102 |
-
batch[i] = preprocess(torch.from_numpy(images[i].transpose(2,0,1)))
|
| 103 |
print('batch', type(batch), batch.shape)
|
| 104 |
|
| 105 |
logits = model(batch)
|
|
@@ -190,7 +188,7 @@ def run_hdmr(input_image,
|
|
| 190 |
|
| 191 |
sam_segmented_image = segmented_image(input_image, masks, alpha=0.9)
|
| 192 |
|
| 193 |
-
batch = preprocess(torch.from_numpy(input_image.transpose(2,0,1))).unsqueeze(0)
|
| 194 |
# Unit normalize the logits
|
| 195 |
logits = model(batch)
|
| 196 |
logits = logits[0].detach().numpy()
|
|
@@ -211,7 +209,7 @@ def run_hdmr(input_image,
|
|
| 211 |
x_seg = mask['segmentation']
|
| 212 |
x_input[x_seg == 1] = x_input[x_seg == 1] * np.power(x[sample,i],2)
|
| 213 |
|
| 214 |
-
batch = preprocess(torch.from_numpy(x_input.transpose(2,0,1))).unsqueeze(0)
|
| 215 |
# Unit normalize the logits
|
| 216 |
logits_sample = model(batch)
|
| 217 |
probs = logits_sample.squeeze(0).softmax(0)
|
|
|
|
| 38 |
def L(a,b,m,x):
|
| 39 |
return np.sqrt((2*m+1)/(b-a))*Pn(m, 2*(x-b)/(b-a)+1)
|
| 40 |
|
| 41 |
+
eurosat_transform = transforms.Compose([
|
| 42 |
+
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[255.0, 255.0, 255.0]), # normalize to [0,1] first
|
| 43 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize images
|
|
|
|
|
|
|
| 44 |
])
|
| 45 |
|
| 46 |
class CNN(nn.Module):
|
|
|
|
| 79 |
|
| 80 |
model, weights, preprocess, names = fetch_model(model_name)
|
| 81 |
|
| 82 |
+
input_image_processed = preprocess(torch.from_numpy(input_image.astype(np.float32).transpose(2,0,1))).unsqueeze(0)
|
| 83 |
logits = model(input_image_processed)
|
| 84 |
probs = F.softmax(logits, dim=1)
|
| 85 |
|
|
|
|
| 92 |
def classifier_fn(images):
|
| 93 |
print('classifier_fn', type(images), images.shape)
|
| 94 |
|
| 95 |
+
zz = preprocess(torch.from_numpy(images[0].transpose(2,0,1).astype(np.float32)))
|
| 96 |
c, w, h = zz.shape
|
| 97 |
batch = torch.zeros(batch_size, c, w, h)
|
| 98 |
print('len(images)', len(images))
|
| 99 |
for i in range(batch_size):
|
| 100 |
+
batch[i] = preprocess(torch.from_numpy(images[i].transpose(2,0,1).astype(np.float32)))
|
| 101 |
print('batch', type(batch), batch.shape)
|
| 102 |
|
| 103 |
logits = model(batch)
|
|
|
|
| 188 |
|
| 189 |
sam_segmented_image = segmented_image(input_image, masks, alpha=0.9)
|
| 190 |
|
| 191 |
+
batch = preprocess(torch.from_numpy(input_image.astype(np.float32).transpose(2,0,1))).unsqueeze(0)
|
| 192 |
# Unit normalize the logits
|
| 193 |
logits = model(batch)
|
| 194 |
logits = logits[0].detach().numpy()
|
|
|
|
| 209 |
x_seg = mask['segmentation']
|
| 210 |
x_input[x_seg == 1] = x_input[x_seg == 1] * np.power(x[sample,i],2)
|
| 211 |
|
| 212 |
+
batch = preprocess(torch.from_numpy(x_input.astype(np.float32).transpose(2,0,1))).unsqueeze(0)
|
| 213 |
# Unit normalize the logits
|
| 214 |
logits_sample = model(batch)
|
| 215 |
probs = logits_sample.squeeze(0).softmax(0)
|