hkayabilisim commited on
Commit
08df26a
·
1 Parent(s): ba447ac

Added custom model trained on Eurosat

Browse files

Model performans on eurosat is not very good.
It will be improved in the next commit.

Files changed (3) hide show
  1. EUROSAT_CUSTOM_MODEL.pth +3 -0
  2. app.py +54 -12
  3. requirements.txt +2 -1
EUROSAT_CUSTOM_MODEL.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68981cd12c767b20d0a98970d41d823bbb0a91ea23debf85a70dc2643ab5d619
3
+ size 33657519
app.py CHANGED
@@ -14,6 +14,16 @@ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
14
  import wget
15
  import cv2
16
  matplotlib.use('agg')
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Vanilla Legendre between [0,1]
19
  def Pn(m, x):
@@ -28,6 +38,32 @@ def Pn(m, x):
28
  def L(a,b,m,x):
29
  return np.sqrt((2*m+1)/(b-a))*Pn(m, 2*(x-b)/(b-a)+1)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def run_lime(input_image,
32
  model_name: str,
33
  top_labels: int,
@@ -43,13 +79,11 @@ def run_lime(input_image,
43
  print('batch_size', batch_size)
44
  print('input image', type(input_image), input_image.shape)
45
 
46
- model, weights = fetch_model(model_name)
47
- preprocess = weights.transforms(antialias=True)
48
 
49
  input_image_processed = preprocess(torch.from_numpy(input_image.transpose(2,0,1))).unsqueeze(0)
50
  logits = model(input_image_processed)
51
  probs = F.softmax(logits, dim=1)
52
- names = weights.meta['categories']
53
 
54
  top_10_classes = []
55
  print('probs', type(probs), probs.shape)
@@ -149,9 +183,7 @@ def run_hdmr(input_image,
149
  print('num_features_hdmr', num_features_hdmr)
150
  print('input image', type(input_image), input_image.shape)
151
 
152
- model, weights = fetch_model(model_name)
153
- preprocess = weights.transforms(antialias=True)
154
-
155
  sam_model = fetch_sam_model(sam_model_name)
156
  mask_generator = SamAutomaticMaskGenerator(sam_model)
157
  masks = mask_generator.generate(input_image)
@@ -192,7 +224,7 @@ def run_hdmr(input_image,
192
  l2_distance = np.linalg.norm(logits_normalized - logits_sample_normalized)
193
  class_id = probs.argmax().item()
194
  score = probs[class_id].item()
195
- category_name = weights.meta["categories"][class_id]
196
  print(f"sample:{sample:2d} cosine: {cosine_distance:.5f} l1: {l1_distance:.5f} l2: {l2_distance:.5f} {category_name}: {100 * score:.1f}%")
197
 
198
  y[:,sample] = [cosine_distance, l1_distance, l2_distance]
@@ -225,18 +257,29 @@ def fetch_sam_model(sam_model_name_checkpoint):
225
  return sam
226
 
227
  def fetch_model_names():
228
- return models.list_models(module=torchvision.models)
 
229
 
230
  def fetch_model(model_name):
231
  print('Retrieving model ', model_name)
 
 
 
 
 
 
 
 
232
  weights_enum = models.get_model_weights(model_name)
233
  for w in weights_enum:
234
  if "IMAGENET1K" in w.name:
235
  weights = w
236
  model = models.get_model(model_name, weights=weights)
237
  print('Model weights loaded', w.name)
238
- return model, weights
239
- return None, None
 
 
240
 
241
  with gd.Blocks() as demo:
242
  with gd.Column():
@@ -255,7 +298,7 @@ with gd.Blocks() as demo:
255
  Select the image classification model to use for LIME.
256
  The list is automatically populated by using torchvision library.
257
  ''',
258
- value='convnext_tiny',
259
  choices=fetch_model_names())
260
  sam_model_name = gd.Dropdown(label="SAM model",
261
  info='Select the SAM model',
@@ -325,4 +368,3 @@ with gd.Blocks() as demo:
325
 
326
  if __name__ == "__main__":
327
  demo.launch()
328
-
 
14
  import wget
15
  import cv2
16
  matplotlib.use('agg')
17
+ import os
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.optim as optim
21
+ from torch.utils.data import DataLoader, random_split
22
+ from torchvision import datasets, transforms, models
23
+ import torch.optim as optim
24
+ import torch.nn.functional as F
25
+ import matplotlib.pyplot as plt
26
+ from torchvision.transforms import v2
27
 
28
  # Vanilla Legendre between [0,1]
29
  def Pn(m, x):
 
38
  def L(a,b,m,x):
39
  return np.sqrt((2*m+1)/(b-a))*Pn(m, 2*(x-b)/(b-a)+1)
40
 
41
+ eurosat_transform = v2.Compose([
42
+ v2.Resize((64, 64)), # Resize images to 64x64
43
+ #transforms.ToTensor(), # Convert images to PyTorch tensors,
44
+ v2.ToDtype(torch.float32),
45
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize images
46
+ ])
47
+
48
+ class CNN(nn.Module):
49
+ def __init__(self, num_classes=10): # Modify num_classes based on the number of your classes
50
+ super(CNN, self).__init__()
51
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
52
+ self.pool = nn.MaxPool2d(2, 2)
53
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
54
+ self.fc1 = nn.Linear(64 * 16 * 16, 512)
55
+ self.fc2 = nn.Linear(512, num_classes)
56
+ self.dropout = nn.Dropout(0.25)
57
+
58
+ def forward(self, x):
59
+ x = self.pool(F.relu(self.conv1(x)))
60
+ x = self.pool(F.relu(self.conv2(x)))
61
+ x = x.view(-1, 64 * 16 * 16)
62
+ x = self.dropout(x)
63
+ x = F.relu(self.fc1(x))
64
+ x = self.fc2(x)
65
+ return x
66
+
67
  def run_lime(input_image,
68
  model_name: str,
69
  top_labels: int,
 
79
  print('batch_size', batch_size)
80
  print('input image', type(input_image), input_image.shape)
81
 
82
+ model, weights, preprocess, names = fetch_model(model_name)
 
83
 
84
  input_image_processed = preprocess(torch.from_numpy(input_image.transpose(2,0,1))).unsqueeze(0)
85
  logits = model(input_image_processed)
86
  probs = F.softmax(logits, dim=1)
 
87
 
88
  top_10_classes = []
89
  print('probs', type(probs), probs.shape)
 
183
  print('num_features_hdmr', num_features_hdmr)
184
  print('input image', type(input_image), input_image.shape)
185
 
186
+ model, weights, preprocess, names = fetch_model(model_name)
 
 
187
  sam_model = fetch_sam_model(sam_model_name)
188
  mask_generator = SamAutomaticMaskGenerator(sam_model)
189
  masks = mask_generator.generate(input_image)
 
224
  l2_distance = np.linalg.norm(logits_normalized - logits_sample_normalized)
225
  class_id = probs.argmax().item()
226
  score = probs[class_id].item()
227
+ category_name = names[class_id]
228
  print(f"sample:{sample:2d} cosine: {cosine_distance:.5f} l1: {l1_distance:.5f} l2: {l2_distance:.5f} {category_name}: {100 * score:.1f}%")
229
 
230
  y[:,sample] = [cosine_distance, l1_distance, l2_distance]
 
257
  return sam
258
 
259
  def fetch_model_names():
260
+ model_names = models.list_models(module=torchvision.models)
261
+ return ['EUROSAT_CUSTOM_MODEL'] + model_names
262
 
263
  def fetch_model(model_name):
264
  print('Retrieving model ', model_name)
265
+ if model_name == "EUROSAT_CUSTOM_MODEL":
266
+ model = CNN()
267
+ weights = torch.load('EUROSAT_CUSTOM_MODEL.pth')
268
+ model.load_state_dict(weights)
269
+ return (model, weights, eurosat_transform,
270
+ ['AnnualCrop','Forest','HerbaceousVegetation','Highway',
271
+ 'Industrial','Pasture','PermanentCrop','Residential',
272
+ 'River','SeaLake'])
273
  weights_enum = models.get_model_weights(model_name)
274
  for w in weights_enum:
275
  if "IMAGENET1K" in w.name:
276
  weights = w
277
  model = models.get_model(model_name, weights=weights)
278
  print('Model weights loaded', w.name)
279
+ return (model, weights,
280
+ weights.transforms(antialias=True),
281
+ weights.meta['categories'])
282
+ return None, None, None, None
283
 
284
  with gd.Blocks() as demo:
285
  with gd.Column():
 
298
  Select the image classification model to use for LIME.
299
  The list is automatically populated by using torchvision library.
300
  ''',
301
+ value='EUROSAT_CUSTOM_MODEL',
302
  choices=fetch_model_names())
303
  sam_model_name = gd.Dropdown(label="SAM model",
304
  info='Select the SAM model',
 
368
 
369
  if __name__ == "__main__":
370
  demo.launch()
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ opencv-python==4.7.0.72
4
  lime==0.2.0.1
5
  scikit-image==0.20.0
6
  torch==2.0.0
7
- wget==3.2
 
 
4
  lime==0.2.0.1
5
  scikit-image==0.20.0
6
  torch==2.0.0
7
+ wget==3.2
8
+ torchaudio==2.0.1