| | import torch |
| | import torch.nn as nn |
| | import cv2 |
| | import numpy as np |
| | from dataclasses import dataclass |
| | from skimage.feature import hog,local_binary_pattern |
| | import matplotlib.pyplot as plt |
| | import os |
| | import io |
| | from PIL import Image |
| |
|
@dataclass
class Config:
    """Hyper-parameters for the CNN/classical feature extractors and the FC head.

    Bug fix: the original attributes had no type annotations, so ``@dataclass``
    generated no fields at all (no ``__init__``/``__eq__``/``__repr__``) and the
    values were plain class variables. Annotating them makes the dataclass
    functional while keeping identical attribute access and default values.
    All defaults are immutable (tuples/scalars), so no ``default_factory`` is
    needed.
    """

    # --- model / training ---
    img_size: tuple = (256, 256)        # (width, height) expected by the backbones
    in_channels: int = 3                # input image channels for the CNN
    fc_num_layers: int = 3              # max hidden layers in the FC head
    conv_hidden_dim: int = 3            # number of Conv-BN-ReLU-MaxPool blocks
    conv_kernel_size: int = 3
    dropout: float = 0.2
    classical_downsample: int = 1       # pyrDown steps before classical features

    # --- HoG ---
    hog_orientations: int = 9
    hog_pixels_per_cell: tuple = (16, 16)
    hog_cells_per_block: tuple = (2, 2)
    hog_block_norm: str = 'L2-Hys'

    # --- Canny edge detector ---
    canny_sigma: float = 1.0            # NOTE(review): unused by extract_features — confirm
    canny_low: int = 100
    canny_high: int = 200

    # --- Gaussian pre-blur ---
    gaussian_ksize: tuple = (3, 3)
    gaussian_sigmaX: float = 1.0
    gaussian_sigmaY: float = 1.0

    # --- Harris corner detector ---
    harris_block_size: int = 2
    harris_ksize: int = 3
    harris_k: float = 0.04

    # --- Shi-Tomasi corner detector ---
    shi_max_corners: int = 100
    shi_quality_level: float = 0.01
    shi_min_distance: int = 10

    # --- Local Binary Patterns ---
    lbp_P: int = 8
    lbp_R: int = 1

    # --- Gabor filter ---
    gabor_ksize: int = 21
    gabor_sigma: float = 5
    gabor_theta: float = 0
    gabor_lambda: float = 10
    gabor_gamma: float = 0.5
| |
|
class CNNFeatureExtractor(nn.Module):
    """Conv backbone: ``conv_hidden_dim`` blocks of Conv-BN-ReLU-MaxPool.

    Channel width starts at 32 and doubles per block; each MaxPool halves the
    spatial resolution. ``forward`` accepts HW / HWC / BHWC numpy arrays,
    CHW / BCHW torch tensors, or a list of either.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.in_channels = config.in_channels
        self.img_size = config.img_size  # (width, height)
        layers = []
        in_channel = config.in_channels
        out_channel = 32
        for _ in range(config.conv_hidden_dim):
            layers.append(nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                                    kernel_size=config.conv_kernel_size, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(out_channel))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(2))
            in_channel = out_channel
            out_channel *= 2
        self.layers = nn.Sequential(*layers)

    def get_device(self):
        """Device of the module's parameters."""
        return next(self.parameters()).device

    def forward(self, x):
        """Normalize input to a BCHW float tensor and run the conv stack.

        Args:
            x: np.ndarray (HW, HWC or BHWC), torch.Tensor (CHW or BCHW),
               or a list of arrays/tensors forming a batch.

        Returns:
            BCHW feature tensor from the conv stack.
        """
        if isinstance(x, list):
            if isinstance(x[0], np.ndarray):
                x = np.stack(x, axis=0)
            elif isinstance(x[0], torch.Tensor):
                # Fix: the original left a list of tensors unhandled and
                # crashed downstream; stack into a batch instead.
                x = torch.stack(x, dim=0)
        if isinstance(x, np.ndarray):
            if x.ndim == 2:
                # Fix: the original called expand_dims(0) *before* a 3-axis
                # transpose, which raised on any grayscale HW input.
                x = x[None, None, :, :]                      # HW   -> 1x1xHxW
            elif x.ndim == 3:
                x = np.expand_dims(x.transpose(2, 0, 1), 0)  # HWC  -> 1xCxHxW
            elif x.ndim == 4:
                x = x.transpose(0, 3, 1, 2)                  # BHWC -> BxCxHxW
            x = torch.from_numpy(x).float()
        elif isinstance(x, torch.Tensor):
            if x.ndim == 3:
                x = x.unsqueeze(0)                           # CHW  -> 1xCxHxW
        x = x.to(self.get_device())
        return self.layers(x)

    def output(self):
        """Run a zero dummy batch through the net; used to size the FC head."""
        self.eval()
        with torch.no_grad():
            x = torch.zeros(
                (1, self.in_channels, self.img_size[1], self.img_size[0]),
                device=self.get_device()
            )
            out = self(x)
        return out

    def visualize(self, input_image, max_channels=8, show=True):
        """Render up to ``max_channels`` activation maps for each Conv2d layer.

        Improvement: a single hooked forward pass replaces the original's one
        full forward pass *per* conv layer; hook removal is now guarded by
        try/finally so a failing forward cannot leak hooks.

        Returns:
            list of RGB np.ndarray images, one rendered figure per conv layer.
        """
        self.eval()
        device = self.get_device()

        if isinstance(input_image, np.ndarray):
            x = torch.from_numpy(input_image).permute(2, 0, 1).float().unsqueeze(0).to(device)
        elif isinstance(input_image, torch.Tensor):
            x = input_image.unsqueeze(0).to(device) if input_image.ndim == 3 else input_image.to(device)
        else:
            raise TypeError("input_image must be np.ndarray or torch.Tensor")

        conv_layers = [(name, module) for name, module in self.named_modules()
                       if isinstance(module, nn.Conv2d)]

        activations = {}
        handles = []

        def make_hook(key):
            # Bind `key` per layer (avoids the late-binding-closure pitfall).
            def hook_fn(module, inputs, output):
                activations[key] = output.detach().cpu()
            return hook_fn

        for name, layer in conv_layers:
            handles.append(layer.register_forward_hook(make_hook(name)))
        try:
            with torch.no_grad():
                _ = self(x)
        finally:
            for handle in handles:
                handle.remove()

        all_layer_images = []
        for name, _ in conv_layers:
            act = activations[name][0]
            num_channels = min(act.shape[0], max_channels)

            fig, axes = plt.subplots(1, num_channels, figsize=(3 * num_channels, 3))
            if num_channels == 1:
                axes = [axes]  # subplots returns a bare Axes for a single panel
            for i in range(num_channels):
                axes[i].imshow(act[i], cmap='gray')
                axes[i].axis('off')
            fig.suptitle(f'Layer: {name}', fontsize=14)
            if show:
                plt.show()

            # Rasterize the figure into an RGB array for the caller.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            img = Image.open(buf).convert("RGB")
            all_layer_images.append(np.array(img))
            plt.close(fig)
        return all_layer_images
| | |
class ClassicalFeatureExtractor(nn.Module):
    """Hand-crafted feature bank: HoG, Canny, Harris, Shi-Tomasi, LBP, Gabor.

    Produces an (h, w, 6) float32 stack per image so it can be swapped in for
    the CNN backbone in ``Classifier``. Has no learnable parameters.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.img_size = config.img_size
        self.hog_orientations = config.hog_orientations
        self.num_downsample = config.classical_downsample
        self.config = config
        # Order matches the channel order produced by extract_features.
        self.feature_names = ['HoG', 'Canny Edge', 'Harris Corner', 'Shi-Tomasi corners', 'LBP', 'Gabor Filters']
        # Fallback device: this module has no parameters to infer one from.
        self.device = 'cpu'

    def get_device(self):
        """Parameter device if any parameters exist, else the 'cpu' fallback."""
        return next(self.parameters()).device if len(list(self.parameters())) > 0 else self.device

    def extract_features(self, img):
        """Compute the 6-channel classical feature stack for one RGB image.

        Args:
            img: HxWx3 RGB image, float (assumed in [0, 1]) or uint8 in [0, 255].

        Returns:
            float32 array of shape (h, w, 6), where h/w reflect any pyramid
            downsampling applied.
        """
        cfg = self.config

        # Smallest size at which HoG can still form one full block.
        min_h = cfg.hog_pixels_per_cell[0] * cfg.hog_cells_per_block[0]
        min_w = cfg.hog_pixels_per_cell[1] * cfg.hog_cells_per_block[1]

        # Bug fix: the original did (img*255).astype(np.uint8) unconditionally,
        # which wrapped modulo 256 when the input was already uint8 in [0, 255].
        if img.dtype == np.uint8:
            img_u8 = img
        else:
            img_u8 = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        gray = cv2.cvtColor(img_u8, cv2.COLOR_RGB2GRAY)

        # Optional pyramid downsampling, stopping before HoG becomes degenerate.
        for _ in range(self.num_downsample):
            h, w = gray.shape
            if h <= min_h or w <= min_w:
                break
            gray = cv2.pyrDown(gray)

        gray = cv2.GaussianBlur(gray, cfg.gaussian_ksize, sigmaX=cfg.gaussian_sigmaX, sigmaY=cfg.gaussian_sigmaY)

        feature_list = []

        # 1) HoG visualization image (the rendered gradient histogram, not the descriptor).
        _, hog_image = hog(
            gray,
            orientations=cfg.hog_orientations,
            pixels_per_cell=cfg.hog_pixels_per_cell,
            cells_per_block=cfg.hog_cells_per_block,
            block_norm=cfg.hog_block_norm,
            visualize=True
        )
        feature_list.append(hog_image)

        # 2) Canny edges, scaled from {0, 255} to {0, 1}.
        edges = cv2.Canny(gray, cfg.canny_low, cfg.canny_high) / 255.0
        feature_list.append(edges)

        # 3) Harris corner response, dilated to mark corners, clipped to [0, 1].
        harris = cv2.cornerHarris(gray, blockSize=cfg.harris_block_size, ksize=cfg.harris_ksize, k=cfg.harris_k)
        harris = cv2.dilate(harris, None)
        harris = np.clip(harris, 0, 1)
        feature_list.append(harris)

        # 4) Shi-Tomasi corners rendered as a sparse binary mask.
        shi_corners = np.zeros_like(gray, dtype=np.float32)
        keypoints = cv2.goodFeaturesToTrack(gray, maxCorners=cfg.shi_max_corners, qualityLevel=cfg.shi_quality_level, minDistance=cfg.shi_min_distance)
        if keypoints is not None:
            for kp in keypoints:
                x, y = kp.ravel()
                shi_corners[int(y), int(x)] = 1.0
        feature_list.append(shi_corners)

        # 5) Uniform LBP, normalized by its max (left untouched if all-zero).
        lbp = local_binary_pattern(gray, P=cfg.lbp_P, R=cfg.lbp_R, method='uniform')
        lbp = lbp / lbp.max() if lbp.max() != 0 else lbp
        feature_list.append(lbp)

        # 6) Single Gabor filter response, min-max normalized (eps avoids /0).
        g_kernel = cv2.getGaborKernel((cfg.gabor_ksize, cfg.gabor_ksize), cfg.gabor_sigma, cfg.gabor_theta, cfg.gabor_lambda, cfg.gabor_gamma)
        gabor_feat = cv2.filter2D(gray, cv2.CV_32F, g_kernel)
        gabor_feat = (gabor_feat - gabor_feat.min()) / (gabor_feat.max() - gabor_feat.min() + 1e-8)
        feature_list.append(gabor_feat)

        features = np.stack(feature_list, axis=2)
        return features.astype(np.float32)

    def forward(self, x):
        """Extract classical features for a batch.

        Args:
            x: HWC or BHWC np.ndarray (or torch.Tensor / list of HWC arrays).

        Returns:
            (B, h, w, 6) float tensor on get_device().
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()
        if isinstance(x, np.ndarray):
            if x.ndim == 3:
                x = np.expand_dims(x, 0)
            elif x.ndim != 4:
                raise ValueError(f"Expected input of shape HWC or BHWC, got {x.shape}")
        elif isinstance(x, list):
            x = np.stack(x, axis=0)

        batch_features = []
        for img in x:
            if img.ndim != 3 or img.shape[2] != 3:
                # Bug fix: the original always did img[:, :, None] first, which
                # turned an HxWx1 image into a 4-D array. Only add the channel
                # axis when it is missing, then replicate to 3 channels.
                img = np.repeat(img if img.ndim == 3 else img[:, :, None], 3, axis=2)
            feat = self.extract_features(img)
            batch_features.append(feat)
        batch_features = np.stack(batch_features, axis=0)
        return torch.from_numpy(batch_features).float().to(self.get_device())

    def visualize(self, img, show_original=True, show=True):
        """Render each feature channel (and optionally the input) as PIL images.

        Returns:
            list of PIL.Image objects, one per rendered figure.
        """
        if img.ndim != 3 or img.shape[2] != 3:
            # Same single-channel promotion fix as in forward().
            img = np.repeat(img if img.ndim == 3 else img[:, :, None], 3, axis=2)

        feature_stack = self.extract_features(img)
        num_channels = feature_stack.shape[2]

        outputs = []

        def fig_to_pil(fig):
            # Rasterize a figure to a PIL image and release the figure.
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
            buf.seek(0)
            pil_img = Image.open(buf).copy()  # copy so the buffer can be closed
            buf.close()
            plt.close(fig)
            return pil_img

        if show_original:
            fig = plt.figure(figsize=(4, 4))
            plt.imshow(img)
            plt.title("Original")
            plt.axis("off")
            if show:
                plt.show()
            outputs.append(fig_to_pil(fig))

        for c in range(num_channels):
            fig = plt.figure(figsize=(4, 4))
            plt.imshow(feature_stack[:, :, c], cmap="gray")
            plt.title(f"Feature {self.feature_names[c]}")
            plt.axis("off")
            if show:
                plt.show()
            outputs.append(fig_to_pil(fig))

        return outputs

    def output(self):
        """Return dummy output to compute in_features for FC head"""
        dummy_img = np.zeros((1, self.img_size[1], self.img_size[0], 3), dtype=np.float32)
        feat = self.forward(dummy_img)
        return feat
| |
|
| |
|
| |
|
class FullyConnectedHead(nn.Module):
    """MLP classification head.

    Builds up to ``config.fc_num_layers`` Linear-BatchNorm-ReLU-Dropout blocks,
    starting at width 256 and halving each block; stops early once the next
    width would not exceed the number of classes, then appends the final
    logits layer.
    """

    def __init__(self, in_features, classes, config: Config):
        super().__init__()
        self.classes = classes
        num_classes = len(classes)
        blocks = []
        width = 256
        for _ in range(config.fc_num_layers):
            blocks.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            ])
            in_features = width
            width //= 2
            # Stop shrinking once another halving would dip to/below the class count.
            if width <= num_classes:
                break
        blocks.append(nn.Linear(in_features, num_classes))
        self.layers = nn.Sequential(*blocks)

    def get_device(self):
        """Device of the head's parameters."""
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor):
        """Move x to the head's device and return raw class logits."""
        return self.layers(x.to(self.get_device()))
| | |
class Classifier(nn.Module):
    """Image classifier: backbone (CNN or classical) + flatten + FC head.

    The backbone must expose ``output()`` returning a dummy forward result so
    the FC head's input size can be inferred at construction time.
    """

    def __init__(self, backbone, classes, config: Config):
        super().__init__()
        self.config = config
        self.classes = classes
        self.backbone = backbone
        self.flatten = nn.Flatten()
        # Probe the backbone once with a dummy input to size the FC head.
        feat = backbone.output()
        flat = self.flatten(feat)
        in_features = flat.shape[1]
        self.head = FullyConnectedHead(in_features, classes, config)

    def get_device(self):
        """Device of the classifier's parameters."""
        return next(self.parameters()).device

    @torch.no_grad()
    def predict(self, x):
        """Classify a single HWC image (np.ndarray) and return its class label."""
        self.eval()
        target_size = self.config.img_size
        x = cv2.resize(x, target_size)
        logits = self.forward(x)
        probs = torch.softmax(logits, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        return self.classes[pred_idx]

    def forward(self, x):
        """Backbone features -> flatten -> logits."""
        feat = self.backbone(x)
        feat = self.flatten(feat)
        return self.head(feat)

    def visualize_feature(self, img, return_img=True, **kwargs):
        """Resize img to the configured size and delegate to backbone.visualize."""
        target_size = self.config.img_size
        img = cv2.resize(img, target_size)
        if return_img:
            return self.backbone.visualize(img, **kwargs)
        else:
            self.backbone.visualize(img, **kwargs)

    def save(self, path: str):
        """Save weights, class list and config to ``path``.

        Bug fix: the original called os.makedirs(os.path.dirname(path)), which
        raises FileNotFoundError when path is a bare filename (dirname == '').
        """
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        torch.save({
            'model_state_dict': self.state_dict(),
            'classes': self.classes,
            'config': self.config
        }, path)
        print(f"Model saved to {path}")

    @staticmethod
    def load(path: str, backbone_class, device='cpu'):
        """Rebuild a Classifier from a checkpoint written by ``save``.

        SECURITY NOTE: weights_only=False unpickles arbitrary objects — needed
        here because the checkpoint stores a Config instance — so only load
        checkpoints from trusted sources.
        """
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        config = checkpoint['config']
        classes = checkpoint['classes']
        backbone = backbone_class(config).to(device)
        model = Classifier(backbone, classes, config).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"Model loaded from {path}")
        return model