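# NOTE(review): the next three lines look like web-page residue captured with
# the file (uploader avatar caption, commit message, commit hash); kept as comments.
# VJyzCELERY's picture
# Update
# 5e96bc9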
import torch
import torch.nn as nn
import cv2
import numpy as np
from dataclasses import dataclass
from skimage.feature import hog,local_binary_pattern
import itertools
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import io
from PIL import Image
@dataclass
class Config:
    """Hyper-parameters shared by the feature extractors and the classifier.

    Every attribute is an annotated dataclass field, so a value can be
    overridden at construction time (``Config(img_size=(128, 128))``) while
    ``Config()`` still yields the defaults below.  Without the annotations
    ``@dataclass`` would register no fields at all and the generated
    ``__init__``/``__eq__``/``__repr__`` would ignore every setting.
    """

    # Input image size.  Consumers in this file pass it to ``cv2.resize`` and
    # index it as (width, height).
    img_size: tuple = (256, 256)
    # Channel count of the images fed to CNNFeatureExtractor.
    in_channels: int = 3
    # Number of hidden Linear blocks in FullyConnectedHead.
    fc_num_layers: int = 3
    # Number of Conv-BN-ReLU-MaxPool blocks in CNNFeatureExtractor
    # (despite the name this is a block count, not a channel width).
    conv_hidden_dim: int = 2
    conv_kernel_size: int = 3
    # Dropout probability used by FullyConnectedHead.
    dropout: float = 0.2
    # Number of ``cv2.pyrDown`` passes applied by ClassicalFeatureExtractor.
    classical_downsample: int = 1

    # HOG (the HOG feature path is commented out in ClassicalFeatureExtractor,
    # so these values are currently not consumed by the visible code).
    hog_orientations: int = 9
    hog_pixels_per_cell: tuple = (16, 16)
    hog_cells_per_block: tuple = (2, 2)
    hog_block_norm: str = 'L2-Hys'

    # Canny (``canny_sigma`` is not read by the visible code).
    canny_sigma: float = 1.0
    canny_low: int = 100
    canny_high: int = 200

    # Gaussian blur applied to the grayscale image before feature extraction.
    gaussian_ksize: tuple = (3, 3)
    gaussian_sigmaX: float = 1.0
    gaussian_sigmaY: float = 1.0

    # Harris corners
    harris_block_size: int = 2
    harris_ksize: int = 3
    harris_k: float = 0.04

    # LBP
    lbp_P: int = 8
    lbp_R: int = 1

    # Gabor filters (``gabor_theta`` is not read by the visible code, which
    # iterates over a fixed list of orientations instead).
    gabor_ksize: int = 21
    gabor_sigma: float = 5
    gabor_theta: float = 0
    gabor_lambda: float = 10
    gabor_gamma: float = 0.5

    # Sobel
    sobel_ksize: int = 3
class CNNFeatureExtractor(nn.Module):
    """Convolutional backbone built from Conv-BN-ReLU-MaxPool blocks.

    ``config.conv_hidden_dim`` blocks are stacked.  The first block outputs 32
    channels and every following block doubles the channel count.  Each
    convolution uses stride 1 and ``kernel_size // 2`` padding, and each block
    ends with a 2x2 max-pool.
    """

    def __init__(self, config: "Config"):
        super().__init__()
        self.in_channels = config.in_channels
        self.img_size = config.img_size
        layers = []
        in_channel = config.in_channels
        out_channel = 32
        for _ in range(config.conv_hidden_dim):
            layers.append(
                nn.Conv2d(
                    in_channels=in_channel,
                    out_channels=out_channel,
                    kernel_size=config.conv_kernel_size,
                    stride=1,
                    padding=config.conv_kernel_size // 2,
                )
            )
            layers.append(nn.BatchNorm2d(out_channel))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d((2, 2)))
            in_channel = out_channel
            out_channel *= 2
        self.layers = nn.Sequential(*layers)

    def get_device(self):
        """Return the device of the first parameter (CPU if there are none)."""
        first = next(self.parameters(), None)
        return first.device if first is not None else torch.device('cpu')

    def forward(self, x, **kwargs):
        """Run the convolutional stack and return the feature map.

        Accepted inputs:
          * ``np.ndarray`` of shape (H, W), (H, W, C) or (B, H, W, C);
          * ``torch.Tensor`` of shape (C, H, W) or (B, C, H, W);
          * a non-empty ``list`` of arrays or of tensors, stacked on a new
            leading batch axis.

        Note that a list of 2-D arrays stacks to a 3-D array and is therefore
        read as a single (H, W, C) image.  ``**kwargs`` is accepted so callers
        can forward extra options; it is ignored here.

        Raises:
            ValueError: if ``x`` is an empty list.
            TypeError: if ``x`` is neither an array, a tensor nor such a list.
        """
        if isinstance(x, list):
            if not x:
                raise ValueError("input list must not be empty")
            if isinstance(x[0], np.ndarray):
                x = np.stack(x, axis=0)
            elif isinstance(x[0], torch.Tensor):
                # Previously a list of tensors was left as a list and failed
                # on ``.to()`` below.
                x = torch.stack(x, dim=0)

        if isinstance(x, np.ndarray):
            if x.ndim == 2:
                # (H, W) -> (1, 1, H, W).  The old code expanded to 4-D first
                # and then applied a 3-axis transpose, which always raised.
                x = x[None, None, :, :]
            elif x.ndim == 3:
                x = np.expand_dims(x.transpose(2, 0, 1), 0)  # HWC -> (1,C,H,W)
            elif x.ndim == 4:
                x = x.transpose(0, 3, 1, 2)  # BHWC -> (B,C,H,W)
            x = torch.from_numpy(x).float()
        elif isinstance(x, torch.Tensor):
            if x.ndim == 3:
                x = x.unsqueeze(0)
        else:
            raise TypeError(
                "x must be np.ndarray, torch.Tensor or a list of those, "
                f"got {type(x).__name__}"
            )

        x = x.to(self.get_device())
        return self.layers(x)  # Always expects (B,C,H,W)

    def output(self):
        """Return the feature map produced by an all-zero dummy image.

        The dummy has shape ``(1, in_channels, img_size[1], img_size[0])``.
        The pass runs in eval mode without gradients; the module's previous
        train/eval mode is restored afterwards so that probing the output
        shape does not silently leave the network in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(
                    (1, self.in_channels, self.img_size[1], self.img_size[0]),
                    device=self.get_device(),
                )
                return self(dummy)
        finally:
            self.train(was_training)

    def visualize(
        self,
        input_image,
        max_channels=8,
        couple=False,
        show=True,
        **kwargs
    ):
        """Render the activations after every ReLU as one figure per layer.

        Args:
            input_image: array or tensor in any layout ``forward`` accepts;
                only the first sample of the batch is drawn.
            max_channels: maximum number of channels drawn per layer.
            couple: if true, channels are grouped three at a time and shown as
                RGB tiles, with left-over channels shown in grayscale;
                otherwise every channel is a grayscale tile.
            show: call ``plt.show()`` for each figure before it is captured.
            **kwargs: accepted and ignored.

        Returns:
            A list with one RGB ``np.ndarray`` (the rendered PNG) per ReLU.

        Raises:
            TypeError: if ``input_image`` is not an array or tensor.
            ValueError: if ``max_channels`` is smaller than 1.

        The module is switched to eval mode and left there.
        """
        if not isinstance(input_image, (np.ndarray, torch.Tensor)):
            raise TypeError("input_image must be np.ndarray or torch.Tensor")
        if max_channels < 1:
            raise ValueError("max_channels must be at least 1")
        self.eval()

        relu_layers = [
            (name, module)
            for name, module in self.named_modules()
            if isinstance(module, nn.ReLU)
        ]

        # Hook every ReLU and run a single forward pass instead of one full
        # pass per layer; the captured activations are identical.
        captured = {}
        handles = []
        for name, module in relu_layers:
            def hook_fn(_module, _inputs, output, _name=name):
                captured[_name] = output.detach().cpu()
            handles.append(module.register_forward_hook(hook_fn))
        try:
            with torch.no_grad():
                self(input_image)
        finally:
            for handle in handles:
                handle.remove()

        all_layer_images = []
        for name, _ in relu_layers:
            act = captured[name][0]  # (C, H, W) of the first sample
            if couple:
                fig = self._plot_coupled(act, name, max_channels)
            else:
                fig = self._plot_grayscale(act, name, max_channels)
            if show:
                plt.show()
            all_layer_images.append(self._figure_to_array(fig))
            plt.close(fig)
        return all_layer_images

    @staticmethod
    def _plot_coupled(act, name, max_channels):
        """Draw channel triples as RGB tiles plus left-over grayscale tiles."""
        num_total = act.shape[0]
        num_rgb = min(num_total // 3, max_channels // 3)
        rem = min(num_total - num_rgb * 3, max_channels - num_rgb * 3)
        total_tiles = num_rgb + rem
        cols = min(4, total_tiles)
        rows = int(np.ceil(total_tiles / cols))
        fig, axes = plt.subplots(
            rows, cols,
            figsize=(3 * cols, 3 * rows),
            squeeze=False,
        )
        tiles = axes.ravel()  # row-major: tiles[k] is axes[k // cols, k % cols]

        for i in range(num_rgb):
            rgb = act[i * 3:(i + 1) * 3].clone()
            # Min-max normalise each channel independently for display.
            for ch in range(3):
                v = rgb[ch]
                rgb[ch] = (v - v.min()) / (v.max() - v.min() + 1e-8)
            ax = tiles[i]
            ax.imshow(rgb.permute(1, 2, 0).numpy())
            ax.axis("off")
            ax.set_title(f"RGB {i*3}-{i*3+2}", fontsize=9)

        start = num_rgb * 3
        for j in range(rem):
            ch = act[start + j]
            ch = (ch - ch.min()) / (ch.max() - ch.min() + 1e-8)
            ax = tiles[num_rgb + j]
            ax.imshow(ch.numpy(), cmap="gray")
            ax.axis("off")
            ax.set_title(f"Ch {start + j}", fontsize=9)

        for ax in tiles[total_tiles:]:
            ax.axis("off")
        fig.suptitle(f"Layer: {name} (Coupled RGB + Grayscale)", fontsize=14)
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_grayscale(act, name, max_channels):
        """Draw up to ``max_channels`` channels as individual grayscale tiles."""
        num_channels = min(act.shape[0], max_channels)
        cols = min(8, num_channels)
        rows = int(np.ceil(num_channels / cols))
        fig, axes = plt.subplots(
            rows, cols,
            figsize=(3 * cols, 3 * rows),
            squeeze=False,
        )
        tiles = axes.ravel()
        for idx in range(num_channels):
            tiles[idx].imshow(act[idx].numpy(), cmap="gray")
            tiles[idx].axis("off")
        for ax in tiles[num_channels:]:
            ax.axis("off")
        fig.suptitle(f"Layer: {name}", fontsize=14)
        fig.tight_layout()
        return fig

    @staticmethod
    def _figure_to_array(fig):
        """Render ``fig`` to an in-memory PNG and return it as an RGB array."""
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
            buf.seek(0)
            return np.array(Image.open(buf).convert("RGB"))
class ClassicalFeatureExtractor(nn.Module):
    """Backbone that stacks hand-crafted OpenCV / scikit-image feature maps.

    The module has no learnable parameters.  For every image it computes, in
    this channel order: Canny edges, Harris corners, LBP, three Gabor
    responses (0, pi/4, pi/2), Sobel X, Sobel Y, Laplacian and gradient
    magnitude.  HOG features are deprecated and no longer computed, so the
    ``hog_*`` config values are not used for feature extraction.
    """

    def __init__(self, config: "Config"):
        super().__init__()
        # Indexed as (W, H): ``output()`` builds its dummy image as
        # (img_size[1], img_size[0], 3) and Classifier passes it to cv2.resize.
        self.img_size = config.img_size
        self.hog_orientations = config.hog_orientations
        self.num_downsample = config.classical_downsample
        self.config = config
        self.device = 'cpu'
        self.convolution = None

    def get_device(self):
        """Return the first parameter's device, or ``self.device`` if none."""
        first = next(self.parameters(), None)
        return first.device if first is not None else self.device

    def render_subplots(self, items, max_cols=8, figsize_per_cell=3):
        """Draw ``items`` on a grid of subplots and return the figure.

        Args:
            items: non-empty sequence of ``(image, title, cmap)`` tuples.
            max_cols: maximum number of columns in the grid.
            figsize_per_cell: edge length of one grid cell in inches.

        Raises:
            ValueError: if ``items`` is empty.
        """
        n = len(items)
        if n == 0:
            raise ValueError("items must not be empty")
        cols = min(max_cols, n)
        rows = int(np.ceil(n / cols))
        fig, axes = plt.subplots(
            rows, cols,
            figsize=(cols * figsize_per_cell, rows * figsize_per_cell),
            squeeze=False,
        )
        tiles = axes.ravel()  # row-major: tiles[k] is axes[k // cols, k % cols]
        for ax, (img, title, cmap) in zip(tiles, items):
            ax.imshow(img, cmap=cmap)
            ax.set_title(title, fontsize=9)
            ax.axis("off")
        # Hide the unused cells of the last row.
        for ax in tiles[n:]:
            ax.axis("off")
        plt.tight_layout()
        return fig

    @staticmethod
    def _to_rgb(img):
        """Return ``img`` as an (H, W, 3) array.

        (H, W) and (H, W, 1) inputs are replicated across three channels and
        (H, W, 3) inputs are returned unchanged.

        Raises:
            ValueError: for any other shape.
        """
        if img.ndim == 2:
            img = img[:, :, None]
        if img.ndim == 3:
            if img.shape[2] == 3:
                return img
            if img.shape[2] == 1:
                return np.repeat(img, 3, axis=2)
        raise ValueError(
            f"Expected image of shape HW, HW1 or HW3, got {img.shape}"
        )

    def extract_features(self, img, visualize=False, **kwargs):
        """Compute the stacked classical feature maps for one RGB image.

        Args:
            img: (H, W, 3) array.  It is multiplied by 255 and cast to uint8
                before grayscale conversion, so values are presumably expected
                in [0, 1] -- TODO confirm against callers.
            visualize: if true, also build a figure showing every feature map.
            **kwargs: accepted and ignored.

        Returns:
            A float32 array of shape (10, H', W'), where H' x W' is the size
            after ``num_downsample`` ``cv2.pyrDown`` passes.  With
            ``visualize=True`` a tuple ``(features, [figure])`` is returned.
        """
        cfg = self.config

        # Grayscale, optional pyramid down-sampling, then a light blur.
        gray = cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
        for _ in range(self.num_downsample):
            gray = cv2.pyrDown(gray)
        gray = cv2.GaussianBlur(
            gray, cfg.gaussian_ksize,
            sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY,
        )

        # (feature map, title) pairs; the order defines the channel order.
        named_maps = []

        # 2. Canny edges
        edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
        named_maps.append((edges, "Canny Edge"))

        # 3. Harris corners (dilated, then clipped to [0, 1])
        harris = cv2.cornerHarris(
            gray,
            blockSize=cfg.harris_block_size,
            ksize=cfg.harris_ksize,
            k=cfg.harris_k,
        )
        harris = cv2.dilate(harris, None)
        harris = np.clip(harris, 0, 1)
        named_maps.append((harris, "Harris Corner"))

        # 4. LBP, scaled by its maximum when that is non-zero
        lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
        lbp_peak = lbp.max()
        if lbp_peak != 0:
            lbp = lbp / lbp_peak
        named_maps.append((lbp, "LBP"))

        # 5. Gabor filter at three fixed orientations
        for theta in [0, np.pi/4, np.pi/2]:
            kernel = cv2.getGaborKernel(
                (cfg.gabor_ksize, cfg.gabor_ksize),
                cfg.gabor_sigma, theta,
                cfg.gabor_lambda, cfg.gabor_gamma
            )
            g = np.abs(cv2.filter2D(gray, cv2.CV_32F, kernel))
            g /= g.max() + 1e-8
            named_maps.append((g, f"Gabor θ={theta:.2f}"))

        # 6. Sobel.  The raw derivatives are computed once and reused for the
        # gradient magnitude in step 8.
        gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=cfg.sobel_ksize)
        gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=cfg.sobel_ksize)
        sobelx = np.abs(gx)
        sobely = np.abs(gy)
        sobelx /= sobelx.max() + 1e-8
        sobely /= sobely.max() + 1e-8
        named_maps.append((sobelx, "Sobel X"))
        named_maps.append((sobely, "Sobel Y"))

        # 7. Laplacian
        lap = np.abs(cv2.Laplacian(gray, cv2.CV_32F))
        lap /= lap.max() + 1e-8
        named_maps.append((lap, "Laplacian"))

        # 8. Gradient Magnitude
        grad_mag = np.sqrt(gx**2 + gy**2)
        grad_mag /= grad_mag.max() + 1e-8
        named_maps.append((grad_mag, "Gradient Magnitude"))

        # Stack all features along channel axis
        features = np.stack([fmap for fmap, _ in named_maps], axis=0)
        features = features.astype(np.float32)
        if visualize:
            vis_items = [(fmap, title, "gray") for fmap, title in named_maps]
            return features, [self.render_subplots(vis_items, max_cols=8)]
        return features

    def forward(self, x, **kwargs):
        """Extract features for one image or a batch of images.

        Args:
            x: (H, W, C) or (B, H, W, C) array, a list of (H, W, C) arrays, or
                a tensor.  Tensors are converted with
                ``.detach().cpu().numpy()`` and read with the same channel-last
                layout; no axes are permuted.  Single-channel images are
                replicated to three channels.
            **kwargs: accepted and ignored.

        Returns:
            Float tensor of shape (B, 10, H', W') on ``get_device()``.

        Raises:
            ValueError: if an array input is not 3-D or 4-D, or an image has a
                channel count other than 1 or 3.
        """
        if isinstance(x, list):
            x = np.stack(x, axis=0)
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if isinstance(x, np.ndarray):
            if x.ndim == 3:
                x = x[None]
            elif x.ndim != 4:
                raise ValueError(
                    f"Expected input of shape HWC or BHWC, got {x.shape}"
                )
        # The previous fix-up indexed ``img[:, :, None]`` on an already 3-D
        # image, yielding a 4-D array for any non-RGB input.
        feats = [self.extract_features(self._to_rgb(img)) for img in x]
        feats = np.stack(feats, axis=0)
        return torch.from_numpy(feats).float().to(self.get_device())

    @staticmethod
    def _fig_to_pil(fig):
        """Render ``fig`` to an in-memory PNG, close it, return a PIL image."""
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
        buf.seek(0)
        pil_img = Image.open(buf).copy()  # copy() so the buffer can be closed
        buf.close()
        plt.close(fig)
        return pil_img

    def visualize(self, img, show_original=True, show=True, **kwargs):
        """Render the original image and its feature maps.

        Args:
            img: (H, W), (H, W, 1) or (H, W, 3) array.
            show_original: include a figure of the input image itself.
            show: call ``plt.show()`` before each figure is captured.
            **kwargs: accepted and ignored, so ``Classifier.visualize_feature``
                can forward options meant for other backbones.

        Returns:
            A list of PIL images: the optional original followed by the
            feature grid.
        """
        img = self._to_rgb(img)
        outputs = []
        if show_original:
            fig = plt.figure(figsize=(4, 4))
            plt.imshow(img)
            plt.title("Original")
            plt.axis("off")
            if show:
                plt.show()
            outputs.append(self._fig_to_pil(fig))
        _, figs = self.extract_features(img, visualize=True)
        if show:
            plt.show()
        outputs.extend(self._fig_to_pil(fig) for fig in figs)
        return outputs

    def output(self):
        """Return the features of an all-zero (img_size[1], img_size[0], 3) image."""
        dummy = np.zeros(
            (self.img_size[1], self.img_size[0], 3),
            dtype=np.float32
        )
        return self.forward(dummy)
class FullyConnectedHead(nn.Module):
    """MLP classification head.

    ``config.fc_num_layers`` hidden blocks (Linear -> BatchNorm1d -> ReLU ->
    Dropout) are followed by a final Linear layer with one output per class.
    The hidden width starts at 1024 and halves after each block, never
    dropping below twice the number of classes.
    """

    def __init__(self, in_features, classes, config: "Config"):
        super().__init__()
        self.classes = classes
        n_outputs = len(classes)

        blocks = []
        fan_in, width = in_features, 1024
        for _ in range(config.fc_num_layers):
            blocks += [
                nn.Linear(fan_in, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            ]
            # The next block consumes this width and is half as wide.
            fan_in, width = width, max(width // 2, n_outputs * 2)
        blocks.append(nn.Linear(fan_in, n_outputs))
        self.layers = nn.Sequential(*blocks)

    def get_device(self):
        """Return the device holding the first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, x: torch.Tensor, **kwargs):
        """Move ``x`` to the head's device and return the class logits."""
        return self.layers(x.to(self.get_device()))
class Classifier(nn.Module):
    """Image classifier: a feature backbone followed by a fully connected head.

    The head's input width is found by flattening ``backbone.output()``, so
    the backbone must provide an ``output()`` method returning a batched
    feature tensor.
    """

    def __init__(self, backbone, classes, config: "Config"):
        super().__init__()
        self.config = config
        self.classes = classes
        self.backbone = backbone
        self.flatten = nn.Flatten()

        # Probing the output shape must not change the backbone's train/eval
        # mode, so remember and restore it around the probe.
        was_training = backbone.training if isinstance(backbone, nn.Module) else None
        feat = backbone.output()
        if was_training is not None:
            backbone.train(was_training)

        in_features = self.flatten(feat).shape[1]
        self.head = FullyConnectedHead(in_features, classes, config)

    def get_device(self):
        """Return the device holding the first parameter."""
        return next(self.parameters()).device

    @torch.no_grad()
    def predict(self, x):
        """Classify a single image and return its class label.

        ``x`` is resized to ``config.img_size`` with ``cv2.resize`` before the
        forward pass.  The model is switched to eval mode and left there.
        Only one image is supported because the prediction is read with
        ``.item()``.
        """
        self.eval()
        x = cv2.resize(x, self.config.img_size)
        logits = self.forward(x)
        probs = torch.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        return self.classes[pred_idx]

    def forward(self, x, **kwargs):
        """Return class logits for ``x``.

        ``**kwargs`` is forwarded to the backbone and the head.  It is not
        passed to ``nn.Flatten``, whose ``forward`` takes no keyword
        arguments and raised ``TypeError`` whenever any were supplied.
        """
        feat = self.backbone(x, **kwargs)
        feat = self.flatten(feat)
        return self.head(feat, **kwargs)

    def visualize_feature(self, img, return_img=True, **kwargs):
        """Resize ``img`` to ``config.img_size`` and visualise the backbone.

        Returns whatever ``backbone.visualize`` returns when ``return_img`` is
        true, otherwise ``None``.  ``**kwargs`` is forwarded to the backbone.
        """
        img = cv2.resize(img, self.config.img_size)
        rendered = self.backbone.visualize(img, **kwargs)
        return rendered if return_img else None

    def save(self, path: str):
        """Write the state dict, class list and config to ``path``.

        Missing parent directories are created.  A bare file name (no
        directory part) is written to the current working directory;
        previously this case failed because ``os.makedirs('')`` raises.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'model_state_dict': self.state_dict(),
            'classes': self.classes,
            'config': self.config
        }, path)
        print(f"Model saved to {path}")

    @staticmethod
    def load(path: str, backbone_class, device='cpu'):
        """Rebuild a classifier from a checkpoint written by ``save``.

        Args:
            path: checkpoint file.
            backbone_class: callable taking the stored config and returning
                the backbone module.
            device: device the checkpoint and the model are mapped to.

        Returns:
            The restored model in eval mode.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects (needed
        # here because the checkpoint stores the config object).  Only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        config = checkpoint['config']
        classes = checkpoint['classes']
        backbone = backbone_class(config).to(device)
        model = Classifier(backbone, classes, config).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"Model loaded from {path}")
        return model