eho69 committed on
Commit
7c3f78b
·
verified ·
1 Parent(s): cfd55eb

thread optimization

Browse files
Files changed (1) hide show
  1. app.py +19 -29
app.py CHANGED
@@ -252,7 +252,6 @@
252
 
253
 
254
 
255
-
256
  import gradio as gr
257
  import cv2
258
  import numpy as np
@@ -274,11 +273,11 @@ logger = logging.getLogger(__name__)
274
 
275
  class FeatureExtractor:
276
  def __init__(self):
277
- # Using ResNet50 for 2048-D feature vectors
278
- backbone = models.resnet50(weights="IMAGENET1K_V1")
279
- self.model = nn.Sequential(*list(backbone.children())[:-1])
280
- self.model.eval()
281
-
282
  self.transform = transforms.Compose([
283
  transforms.Resize((224, 224)),
284
  transforms.ToTensor(),
@@ -298,33 +297,24 @@ class FeatureExtractor:
298
  if len(rgb.shape) == 2:
299
  rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
300
 
301
- # We want the layer BEFORE the global pooling to get spatial info
302
- # resnet.layer4 is the last block
303
- # self.model is nn.Sequential(*list(backbone.children())[:-1])
304
- # children()[:-1] = [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4]
305
-
306
  input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
307
 
308
- # Get activations from the last conv layer (Layer 4)
309
  with torch.no_grad():
310
- # Run through the layers up to global pooling
311
- # Using the original backbone for Easier Access to sub-layers
312
- backbone = models.resnet50(weights="IMAGENET1K_V1")
313
- backbone.eval()
314
-
315
- x = backbone.conv1(input_tensor)
316
- x = backbone.bn1(x)
317
- x = backbone.relu(x)
318
- x = backbone.maxpool(x)
319
- x = backbone.layer1(x)
320
- x = backbone.layer2(x)
321
- x = backbone.layer3(x)
322
- features_spatial = backbone.layer4(x) # [1, 2048, 7, 7]
323
 
324
- # Global Average Pooling to get the vector
325
  feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
326
 
327
- # Create Heatmap: sum across channels to see "hot" regions
328
  amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
329
  amap = np.maximum(amap, 0)
330
  amap /= (np.max(amap) + 1e-8)
@@ -332,11 +322,11 @@ class FeatureExtractor:
332
  amap = np.uint8(255 * amap)
333
  heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
334
 
335
- # Overlay heatmap on original image
336
- # Convert BGR heatmap to RGB
337
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
338
  overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
339
 
 
340
  norm = np.linalg.norm(feat)
341
  return (feat / norm if norm > 1e-8 else feat), overlay
342
 
 
252
 
253
 
254
 
 
255
  import gradio as gr
256
  import cv2
257
  import numpy as np
 
273
 
274
  class FeatureExtractor:
275
  def __init__(self):
276
+
277
+ self.backbone = models.resnet50(weights="IMAGENET1K_V1")
278
+ self.backbone.eval()
279
+
280
+ # We transform to standard ImageNet resolution
281
  self.transform = transforms.Compose([
282
  transforms.Resize((224, 224)),
283
  transforms.ToTensor(),
 
297
  if len(rgb.shape) == 2:
298
  rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
299
 
 
 
 
 
 
300
  input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
301
 
302
+ # Optimized inference: Use the pre-loaded backbone
303
  with torch.no_grad():
304
+ # Walk through layers to capture spatial activations before global pooling
305
+ x = self.backbone.conv1(input_tensor)
306
+ x = self.backbone.bn1(x)
307
+ x = self.backbone.relu(x)
308
+ x = self.backbone.maxpool(x)
309
+ x = self.backbone.layer1(x)
310
+ x = self.backbone.layer2(x)
311
+ x = self.backbone.layer3(x)
312
+ features_spatial = self.backbone.layer4(x) # [1, 2048, 7, 7]
 
 
 
 
313
 
314
+ # Global Average Pooling (L2 distance is more effective on normalized vectors)
315
  feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
316
 
317
+ # Heatmap generation: Sum across channels to highlight activated regions
318
  amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
319
  amap = np.maximum(amap, 0)
320
  amap /= (np.max(amap) + 1e-8)
 
322
  amap = np.uint8(255 * amap)
323
  heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
324
 
325
+ # Overlay BGR heatmap on RGB image properly
 
326
  heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
327
  overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
328
 
329
+ # Vector normalization for Cosine Similarity
330
  norm = np.linalg.norm(feat)
331
  return (feat / norm if norm > 1e-8 else feat), overlay
332