TransformSuite / transforms.py
Ritabrata04's picture
fixes to high pass and low pass,laplacian and sobel (#2)
08f3349 verified
import cv2
import numpy as np
from matplotlib import pyplot as plt
def load_image(image):
    """Classify an already-decoded image array by its dimensionality.

    Nothing is read from disk; the array is returned unchanged alongside a tag.

    Args:
        image: Array-like object exposing ``.shape``.

    Returns:
        ``(image, 'grayscale')`` for 2-D input, ``(image, 'rgb')`` for any 3-D
        input (the channel count/order is not inspected).

    Raises:
        ValueError: If the input is neither 2-D nor 3-D.
    """
    kind_by_ndim = {2: 'grayscale', 3: 'rgb'}
    kind = kind_by_ndim.get(len(image.shape))
    if kind is None:
        raise ValueError("Unsupported image format")
    return image, kind
# 1. Histogram Equalization
def histogram_equalization(image):
    """Return a histogram-equalized, single-channel version of ``image``.

    A 3-D input is first collapsed to one channel with ``cv2.COLOR_BGR2GRAY``,
    i.e. its channels are treated as BGR.
    NOTE(review): nearby comments call such inputs "RGB"; confirm the
    caller's channel order, since R/B weights differ in the gray conversion.

    The result comes from ``cv2.equalizeHist``. Presumably the input is
    8-bit (OpenCV's equalizeHist rejects other depths); verify upstream.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    return cv2.equalizeHist(gray)
# 2. Sobel Edge Detection
def sobel_edge_detection(image):
    """Return the Sobel gradient magnitude of ``image`` rescaled to [0, 255].

    A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY`` first.
    Horizontal and vertical 3x3 Sobel derivatives are computed at float64
    depth (``cv2.CV_64F``), which keeps signed derivative values. They are
    combined as ``sqrt(gx**2 + gy**2)`` and min-max normalized to [0, 255].
    The returned array stays float64; it is not cast to uint8.
    """
    gray = image
    if len(gray.shape) == 3:
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(np.square(grad_x) + np.square(grad_y))
    return cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
# 3. Gaussian Blur
def gaussian_blur(image, kernel_size=5):
    """Smooth ``image`` with a square Gaussian kernel.

    Args:
        image: Image array passed straight to ``cv2.GaussianBlur`` (any
            channel count; no grayscale conversion is done here).
        kernel_size: Side length of the square kernel. OpenCV expects an
            odd, positive value.

    Returns:
        The blurred image as produced by ``cv2.GaussianBlur``.
    """
    ksize = (kernel_size,) * 2
    # Passing sigmaX=0 asks OpenCV to derive sigma from the kernel size.
    return cv2.GaussianBlur(image, ksize, 0)
# 4. Laplacian of Gaussian (LoG)
def laplacian_of_gaussian(image):
    """Return a min-max normalized Laplacian-of-Gaussian response of ``image``.

    Steps:
        1. A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY``.
        2. The image is smoothed with a 3x3 Gaussian (sigma derived by OpenCV).
        3. The Laplacian is taken at float64 depth (``cv2.CV_64F``), so both
           signs of the response survive.
        4. The result is rescaled to [0, 255]; it stays float64.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    smoothed = cv2.GaussianBlur(gray, (3, 3), 0)
    response = cv2.Laplacian(smoothed, cv2.CV_64F)
    return cv2.normalize(response, None, 0, 255, cv2.NORM_MINMAX)
# 5. Median Filtering
def median_filter(image, kernel_size=5):
    """Apply OpenCV's median blur with a square aperture.

    Args:
        image: Image array passed straight to ``cv2.medianBlur``.
        kernel_size: Aperture side length forwarded to ``cv2.medianBlur``
            (OpenCV expects an odd value greater than 1).

    Returns:
        The median-filtered image.
    """
    return cv2.medianBlur(image, kernel_size)
# Frequency Domain Transforms
def fourier_transform(image):
    """Return the log-scaled, centered magnitude spectrum of ``image``.

    A 3-D input is first converted to one channel with ``cv2.COLOR_BGR2GRAY``.
    The 2-D FFT is ``fftshift``-ed so the zero-frequency (DC) term sits at
    ``(rows // 2, cols // 2)``. The magnitude is then mapped through
    ``20 * log(1 + |F|)``.

    ``log1p`` is used instead of ``log`` so the output is always finite.
    With plain ``log``, any exactly-zero coefficient became ``-inf`` (with a
    divide-by-zero RuntimeWarning). For example, every non-DC bin of a flat
    image is zero, and the ``-inf`` values break later normalization and
    plotting.

    Returns:
        float64 array with the same 2-D shape as the (grayscale) input.
    """
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    spectrum = np.fft.fftshift(np.fft.fft2(image))
    return 20 * np.log1p(np.abs(spectrum))
def discrete_cosine_transform(image):
    """Return the 2-D DCT of ``image`` after scaling pixels by 1/255.

    A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY`` first.
    Pixels are cast to float32 and divided by 255 before ``cv2.dct``.

    ``cv2.dct`` does not implement odd-length transforms: it raises for any
    axis longer than 1 with odd length. Such axes are padded with one
    edge-replicated row/column, following the padding approach OpenCV's docs
    suggest. The effect on the output:
        * Inputs that already worked (even axes, or length-1 axes) are
          transformed exactly as before.
        * Odd-sized inputs now return an array one element larger along each
          padded axis, instead of raising.
    """
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    pixels = np.float32(image) / 255.0
    # Only odd axes longer than 1 are rejected by cv2.dct; leave others alone.
    pad_width = [(0, size % 2 if size > 1 else 0) for size in pixels.shape]
    if any(extra for _, extra in pad_width):
        pixels = np.pad(pixels, pad_width, mode='edge')
    return cv2.dct(pixels)
def high_pass_filter(image, cutoff=30):
    """Remove low spatial frequencies from ``image`` with an ideal square mask.

    A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY`` first.
    In the centered (``fftshift``-ed) spectrum, the square of rows
    ``crow-cutoff .. crow+cutoff-1`` and the same column range around the DC
    term is zeroed. The magnitude of the inverse FFT is returned.

    Args:
        image: 2-D grayscale or 3-D image array.
        cutoff: Half-width, in frequency bins, of the blocked square.
            Defaults to 30, the previously hard-coded value.

    Returns:
        float64 array with the same 2-D shape as the (grayscale) input.

    Raises:
        ValueError: If ``cutoff`` is negative.
    """
    if cutoff < 0:
        raise ValueError("cutoff must be non-negative")
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2
    spectrum = np.fft.fftshift(np.fft.fft2(image))
    mask = np.ones((rows, cols), np.float64)
    # Clamp the start index. A negative start would wrap to the far edge
    # (40 rows: [-10:50] == [30:40]) and leave the DC term unmasked.
    mask[max(crow - cutoff, 0):crow + cutoff, max(ccol - cutoff, 0):ccol + cutoff] = 0
    filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask))
    return np.abs(filtered)
def low_pass_filter(image, cutoff=30):
    """Keep only low spatial frequencies of ``image`` with an ideal square mask.

    A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY`` first.
    In the centered (``fftshift``-ed) spectrum, only the square of rows
    ``crow-cutoff .. crow+cutoff-1`` and the same column range around the DC
    term is kept; everything else is zeroed. The magnitude of the inverse FFT
    is returned.

    Args:
        image: 2-D grayscale or 3-D image array.
        cutoff: Half-width, in frequency bins, of the passed square.
            Defaults to 30, the previously hard-coded value.

    Returns:
        float64 array with the same 2-D shape as the (grayscale) input.

    Raises:
        ValueError: If ``cutoff`` is negative.
    """
    if cutoff < 0:
        raise ValueError("cutoff must be non-negative")
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2
    spectrum = np.fft.fftshift(np.fft.fft2(image))
    mask = np.zeros((rows, cols), np.float64)
    # Clamp the start index. A negative start would wrap to the far edge
    # (40 rows: [-10:50] == [30:40]) and drop the DC term entirely.
    mask[max(crow - cutoff, 0):crow + cutoff, max(ccol - cutoff, 0):ccol + cutoff] = 1
    filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask))
    return np.abs(filtered)
def wavelet_transform(image):
    """Return the approximation (LL) band of a single-level 2-D Haar DWT.

    A 3-D input is converted to one channel with ``cv2.COLOR_BGR2GRAY`` first.
    The three detail bands returned by ``pywt.dwt2`` are unpacked and
    discarded. ``pywt`` is imported inside the function, so importing this
    module does not itself require PyWavelets.
    """
    import pywt

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    approximation, (_horizontal, _vertical, _diagonal) = pywt.dwt2(gray, 'haar')
    return approximation