cnn_visualizer / src /useConvolutionProcessing.ts
Joel Woodfield
Resize large images to prevent convolution wait time
00fddb0
import { useEffect, useMemo, useState } from "react";
/** Longest edge, in pixels, that an image may keep before it is scaled down. */
const MAX_SIZE = 500;

/**
 * Fits a width/height pair inside a MAX_SIZE x MAX_SIZE box.
 *
 * Dimensions that already fit are returned untouched. Otherwise both sides
 * are multiplied by the same factor (so the aspect ratio is kept) and rounded,
 * with each side floored at 1 pixel.
 *
 * @param width - Source width in pixels.
 * @param height - Source height in pixels.
 * @returns The `[width, height]` pair to use.
 */
function getBoundedDimensions(width: number, height: number): [number, number] {
  const alreadyFits = width <= MAX_SIZE && height <= MAX_SIZE;
  if (alreadyFits) {
    return [width, height];
  }

  // The smaller of the two ratios is the one that brings the longer side to MAX_SIZE.
  const scale = Math.min(MAX_SIZE / width, MAX_SIZE / height);
  const shrink = (side: number): number => Math.max(1, Math.round(side * scale));

  return [shrink(width), shrink(height)];
}
/**
 * Loads the image at `imageUrl` and returns its pixels, downscaled so that
 * neither side exceeds the bound applied by `getBoundedDimensions`.
 *
 * @param imageUrl - URL (or data URL) of the image to load.
 * @returns RGBA pixel data of the (possibly resized) image.
 * @throws Error if the image fails to load or a 2D canvas context cannot be
 *   obtained (surfaced as a rejection of the returned promise).
 */
async function getImageData(imageUrl: string): Promise<ImageData> {
  const image = await new Promise<HTMLImageElement>((resolve, reject) => {
    const img = new Image();
    // Request the image with CORS so that pixels can be read back from the
    // canvas for cross-origin URLs.
    img.crossOrigin = "anonymous";
    img.onload = () => resolve(img);
    // The error handler receives a DOM Event, not an Error; reject with a
    // real Error so callers get a message and a stack trace.
    img.onerror = () => reject(new Error("Failed to load image"));
    img.src = imageUrl;
  });

  // naturalWidth/naturalHeight are preferred; width/height are the fallback
  // when they report 0.
  const sourceWidth = image.naturalWidth || image.width;
  const sourceHeight = image.naturalHeight || image.height;
  const [targetWidth, targetHeight] = getBoundedDimensions(sourceWidth, sourceHeight);

  const canvas = document.createElement("canvas");
  canvas.width = targetWidth;
  canvas.height = targetHeight;

  const ctx = canvas.getContext("2d");
  if (!ctx) {
    throw new Error("Failed to get canvas context");
  }

  // Drawing at the target size performs the resize.
  ctx.drawImage(image, 0, 0, targetWidth, targetHeight);
  return ctx.getImageData(0, 0, canvas.width, canvas.height);
}
/**
 * Encodes pixel data as a PNG data URL by painting it onto an off-screen
 * canvas of the same size.
 *
 * @param imageData - Pixels to encode.
 * @returns A `data:image/png` URL for the pixels.
 * @throws Error if a 2D canvas context cannot be obtained.
 */
function getImageUrl(imageData: ImageData): string {
  const { width, height } = imageData;

  const surface = document.createElement("canvas");
  surface.width = width;
  surface.height = height;

  const context = surface.getContext("2d");
  if (context === null) {
    throw new Error("Failed to get canvas context");
  }

  context.putImageData(imageData, 0, 0);
  return surface.toDataURL("image/png");
}
/**
 * Returns a grayscale copy of `imageData`; the input is not modified.
 *
 * Each pixel's red, green and blue channels are replaced by the weighted sum
 * 0.299 R + 0.587 G + 0.114 B. Alpha is carried over unchanged.
 *
 * @param imageData - Source RGBA pixels.
 * @returns A new ImageData with the same dimensions.
 */
function convertToGrayscale(imageData: ImageData): ImageData {
  // Work on a copy so the caller's pixels stay intact.
  const result = new ImageData(imageData.data.slice(), imageData.width, imageData.height);
  const pixels = result.data;

  let offset = 0;
  while (offset < pixels.length) {
    const red = pixels[offset];
    const green = pixels[offset + 1];
    const blue = pixels[offset + 2];

    const luminance = 0.299 * red + 0.587 * green + 0.114 * blue;

    pixels[offset] = luminance;
    pixels[offset + 1] = luminance;
    pixels[offset + 2] = luminance;

    offset += 4; // advance one RGBA pixel
  }

  return result;
}
/**
 * Convolves an image with a kernel, choosing the implementation from the
 * kernel's nesting depth: a 3D kernel goes to `convolveColor`, a 2D kernel
 * goes to `convolveGray`.
 *
 * @param imageData - Source RGBA pixels.
 * @param kernel - 2D (grayscale) or 3D (color) kernel weights.
 * @returns The convolved image.
 */
function convolve(imageData: ImageData, kernel: number[][] | number[][][]): ImageData {
  // If the first row's first entry is itself an array, the kernel has three levels.
  const isColorKernel = Array.isArray(kernel[0][0]);

  return isColorKernel
    ? convolveColor(imageData, kernel as number[][][])
    : convolveGray(imageData, kernel as number[][]);
}
/**
 * Slides a 2D kernel over the red channel of `image` and returns the result
 * as an opaque gray image.
 *
 * No padding is applied, so the output is smaller than the input by
 * (kernel size - 1) along each axis. Only the red channel is read. Each
 * result is clamped to [0, 255] and written to R, G and B, with alpha set
 * to 255.
 *
 * @param image - Source RGBA pixels.
 * @param kernel - Kernel weights indexed as `kernel[row][column]`.
 * @returns The convolved image.
 */
function convolveGray(image: ImageData, kernel: number[][]): ImageData {
  const kernelRows = kernel.length;
  const kernelCols = kernel[0].length;
  const { width: srcWidth, height: srcHeight, data: src } = image;

  const outWidth = srcWidth - kernelCols + 1;
  const outHeight = srcHeight - kernelRows + 1;
  const out = new Uint8ClampedArray(outWidth * outHeight * 4);

  // Output pixels are produced in row-major order, so a running cursor
  // always points at the next RGBA slot to fill.
  let cursor = 0;
  for (let row = 0; row < outHeight; row++) {
    for (let col = 0; col < outWidth; col++) {
      let acc = 0;
      for (let kRow = 0; kRow < kernelRows; kRow++) {
        const weights = kernel[kRow];
        const rowStart = (row + kRow) * srcWidth + col;
        for (let kCol = 0; kCol < kernelCols; kCol++) {
          // Red channel of the source pixel under this kernel cell.
          acc += src[(rowStart + kCol) * 4] * weights[kCol];
        }
      }

      const level = Math.max(0, Math.min(255, acc));
      out[cursor] = level; // R
      out[cursor + 1] = level; // G
      out[cursor + 2] = level; // B
      out[cursor + 3] = 255; // A
      cursor += 4;
    }
  }

  return new ImageData(out, outWidth, outHeight);
}
/**
 * Slides a three-plane kernel over the RGB channels of `image` and returns
 * the result as an opaque gray image.
 *
 * `kernel[0]`, `kernel[1]` and `kernel[2]` weight the red, green and blue
 * channels respectively. The three weighted channels are summed into a single
 * value per output pixel. No padding is applied, so the output is smaller
 * than the input by (kernel size - 1) along each axis. Each result is clamped
 * to [0, 255] and written to R, G and B, with alpha set to 255.
 *
 * @param image - Source RGBA pixels.
 * @param kernel - Weights indexed as `kernel[channel][row][column]`.
 * @returns The convolved image.
 */
function convolveColor(image: ImageData, kernel: number[][][]): ImageData {
  const kernelRows = kernel[0].length;
  const kernelCols = kernel[0][0].length;
  const { width: srcWidth, height: srcHeight, data: src } = image;

  const outWidth = srcWidth - kernelCols + 1;
  const outHeight = srcHeight - kernelRows + 1;
  const out = new Uint8ClampedArray(outWidth * outHeight * 4);

  // Output pixels are produced in row-major order, so a running cursor
  // always points at the next RGBA slot to fill.
  let cursor = 0;
  for (let row = 0; row < outHeight; row++) {
    for (let col = 0; col < outWidth; col++) {
      let acc = 0;
      for (let kRow = 0; kRow < kernelRows; kRow++) {
        const rowStart = (row + kRow) * srcWidth + col;
        for (let kCol = 0; kCol < kernelCols; kCol++) {
          const base = (rowStart + kCol) * 4;
          const red = src[base];
          const green = src[base + 1];
          const blue = src[base + 2];
          acc +=
            red * kernel[0][kRow][kCol] +
            green * kernel[1][kRow][kCol] +
            blue * kernel[2][kRow][kCol];
        }
      }

      const level = Math.max(0, Math.min(255, acc));
      out[cursor] = level; // R
      out[cursor + 1] = level; // G
      out[cursor + 2] = level; // B
      out[cursor + 3] = 255; // A
      cursor += 4;
    }
  }

  return new ImageData(out, outWidth, outHeight);
}
/**
 * React hook that loads an image, convolves it with `kernel`, and exposes
 * both the (pre-processed) input and the convolved output as PNG data URLs.
 *
 * A 3D kernel selects the color path; a 2D kernel converts the image to
 * grayscale first. Both returned URLs are `null` until the image has loaded,
 * and become `null` again if loading the current `rawInputImage` fails.
 *
 * @param rawInputImage - URL of the image to process.
 * @param kernel - 2D (grayscale) or 3D (color) convolution kernel. It is used
 *   as a memo dependency, so passing a new array identity triggers a new
 *   convolution.
 * @returns `[inputImageUrl, outputImageUrl]`.
 */
export default function useConvolutionProcessing(
  rawInputImage: string,
  kernel: number[][][] | number[][],
): [string | null, string | null] {
  const useColor = Array.isArray(kernel[0][0]); // true if 3D kernel, false if 2D kernel
  const [rawImageData, setRawImageData] = useState<ImageData | null>(null);

  // Extract the input image's pixel data whenever the source URL changes.
  useEffect(() => {
    // Set by the cleanup so a load that finishes after the URL changed (or
    // after unmount) does not write stale data into state.
    let cancelled = false;

    async function processImage(): Promise<void> {
      try {
        const imageData = await getImageData(rawInputImage);
        if (!cancelled) {
          setRawImageData(imageData);
        }
      } catch (error: unknown) {
        // Without this handler a failed load is an unhandled promise
        // rejection, and the previous image's pixels would stay in state.
        if (cancelled) return;
        console.error("Failed to load image for convolution", error);
        setRawImageData(null);
      }
    }

    void processImage();

    return () => {
      cancelled = true;
    };
  }, [rawInputImage]);

  // Pixels actually fed to the convolution: raw for color kernels,
  // grayscale otherwise.
  const processedImageData = useMemo(() => {
    if (!rawImageData) return null;
    return useColor ? rawImageData : convertToGrayscale(rawImageData);
  }, [rawImageData, useColor]);

  const outputImageData = useMemo(() => {
    if (!processedImageData) return null;
    return convolve(processedImageData, kernel);
  }, [processedImageData, kernel]);

  const inputImage = useMemo(() => {
    if (!processedImageData) return null;
    return getImageUrl(processedImageData);
  }, [processedImageData]);

  const outputImage = useMemo(() => {
    if (!outputImageData) return null;
    return getImageUrl(outputImageData);
  }, [outputImageData]);

  return [inputImage, outputImage];
}