thread optimizatin
Browse files
app.py
CHANGED
|
@@ -252,7 +252,6 @@
|
|
| 252 |
|
| 253 |
|
| 254 |
|
| 255 |
-
|
| 256 |
import gradio as gr
|
| 257 |
import cv2
|
| 258 |
import numpy as np
|
|
@@ -274,11 +273,11 @@ logger = logging.getLogger(__name__)
|
|
| 274 |
|
| 275 |
class FeatureExtractor:
|
| 276 |
def __init__(self):
|
| 277 |
-
|
| 278 |
-
backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 279 |
-
self.
|
| 280 |
-
|
| 281 |
-
|
| 282 |
self.transform = transforms.Compose([
|
| 283 |
transforms.Resize((224, 224)),
|
| 284 |
transforms.ToTensor(),
|
|
@@ -298,33 +297,24 @@ class FeatureExtractor:
|
|
| 298 |
if len(rgb.shape) == 2:
|
| 299 |
rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
|
| 300 |
|
| 301 |
-
# We want the layer BEFORE the global pooling to get spatial info
|
| 302 |
-
# resnet.layer4 is the last block
|
| 303 |
-
# self.model is nn.Sequential(*list(backbone.children())[:-1])
|
| 304 |
-
# children()[:-1] = [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4]
|
| 305 |
-
|
| 306 |
input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
|
| 307 |
|
| 308 |
-
#
|
| 309 |
with torch.no_grad():
|
| 310 |
-
#
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
backbone.
|
| 314 |
-
|
| 315 |
-
x = backbone.
|
| 316 |
-
x = backbone.
|
| 317 |
-
x = backbone.
|
| 318 |
-
|
| 319 |
-
x = backbone.layer1(x)
|
| 320 |
-
x = backbone.layer2(x)
|
| 321 |
-
x = backbone.layer3(x)
|
| 322 |
-
features_spatial = backbone.layer4(x) # [1, 2048, 7, 7]
|
| 323 |
|
| 324 |
-
# Global Average Pooling
|
| 325 |
feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
|
| 326 |
|
| 327 |
-
#
|
| 328 |
amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
|
| 329 |
amap = np.maximum(amap, 0)
|
| 330 |
amap /= (np.max(amap) + 1e-8)
|
|
@@ -332,11 +322,11 @@ class FeatureExtractor:
|
|
| 332 |
amap = np.uint8(255 * amap)
|
| 333 |
heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
|
| 334 |
|
| 335 |
-
# Overlay heatmap on
|
| 336 |
-
# Convert BGR heatmap to RGB
|
| 337 |
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 338 |
overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
|
| 339 |
|
|
|
|
| 340 |
norm = np.linalg.norm(feat)
|
| 341 |
return (feat / norm if norm > 1e-8 else feat), overlay
|
| 342 |
|
|
|
|
| 252 |
|
| 253 |
|
| 254 |
|
|
|
|
| 255 |
import gradio as gr
|
| 256 |
import cv2
|
| 257 |
import numpy as np
|
|
|
|
| 273 |
|
| 274 |
class FeatureExtractor:
|
| 275 |
def __init__(self):
|
| 276 |
+
|
| 277 |
+
self.backbone = models.resnet50(weights="IMAGENET1K_V1")
|
| 278 |
+
self.backbone.eval()
|
| 279 |
+
|
| 280 |
+
# We transform to standard ImageNet resolution
|
| 281 |
self.transform = transforms.Compose([
|
| 282 |
transforms.Resize((224, 224)),
|
| 283 |
transforms.ToTensor(),
|
|
|
|
| 297 |
if len(rgb.shape) == 2:
|
| 298 |
rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
input_tensor = self.transform(Image.fromarray(rgb)).unsqueeze(0)
|
| 301 |
|
| 302 |
+
# Optimized inference: Use the pre-loaded backbone
|
| 303 |
with torch.no_grad():
|
| 304 |
+
# Walk through layers to capture spatial activations before global pooling
|
| 305 |
+
x = self.backbone.conv1(input_tensor)
|
| 306 |
+
x = self.backbone.bn1(x)
|
| 307 |
+
x = self.backbone.relu(x)
|
| 308 |
+
x = self.backbone.maxpool(x)
|
| 309 |
+
x = self.backbone.layer1(x)
|
| 310 |
+
x = self.backbone.layer2(x)
|
| 311 |
+
x = self.backbone.layer3(x)
|
| 312 |
+
features_spatial = self.backbone.layer4(x) # [1, 2048, 7, 7]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
+
# Global Average Pooling (L2 distance is more effective on normalized vectors)
|
| 315 |
feat = torch.mean(features_spatial, dim=[2, 3]).squeeze().cpu().numpy()
|
| 316 |
|
| 317 |
+
# Heatmap generation: Sum across channels to highlight activated regions
|
| 318 |
amap = torch.sum(features_spatial, dim=1).squeeze().cpu().numpy()
|
| 319 |
amap = np.maximum(amap, 0)
|
| 320 |
amap /= (np.max(amap) + 1e-8)
|
|
|
|
| 322 |
amap = np.uint8(255 * amap)
|
| 323 |
heatmap = cv2.applyColorMap(amap, cv2.COLORMAP_JET)
|
| 324 |
|
| 325 |
+
# Overlay BGR heatmap on RGB image properly
|
|
|
|
| 326 |
heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 327 |
overlay = cv2.addWeighted(rgb, 0.6, heatmap_rgb, 0.4, 0)
|
| 328 |
|
| 329 |
+
# Vector normalization for Cosine Similarity
|
| 330 |
norm = np.linalg.norm(feat)
|
| 331 |
return (feat / norm if norm > 1e-8 else feat), overlay
|
| 332 |
|