/*
 * Decompiled with CFR 0.152.
 */
package amten.ml;

import amten.ml.matrix.Matrix;

public class Convolutions {

    /**
     * Applies non-overlapping max pooling to a hidden-layer activation matrix.
     *
     * <p>Layout of {@code inputs}: one row per (example, pixel), at row index
     * {@code example * inputWidth * inputHeight + y * inputWidth + x}, and one column per channel.
     * The number of examples is {@code inputs.numRows() / (inputWidth * inputHeight)}, using
     * integer division. This presumably expects numRows to be an exact multiple; verify at the
     * call sites.
     *
     * <p>Pooling windows are {@code poolWidth x poolHeight} and tile the input with a stride equal
     * to the window size. When an input dimension is not a multiple of the pool dimension, the last
     * window in that direction is truncated at the input edge and still produces an output. The
     * output is therefore {@code ceil(inputWidth / poolWidth) x ceil(inputHeight / poolHeight)} per
     * example and uses the same row layout as the input.
     *
     * @param inputs      activations, one row per (example, pixel), one column per channel
     * @param inputWidth  width of each example's feature map
     * @param inputHeight height of each example's feature map
     * @param poolWidth   width of a pooling window
     * @param poolHeight  height of a pooling window
     * @return the pooled activations, plus a matrix of the same shape holding, for each pooled cell
     *         and channel, the row index in {@code inputs} of the element that was selected
     *         (stored as a double)
     */
    public static PoolingResult maxPool(Matrix inputs, int inputWidth, int inputHeight, int poolWidth, int poolHeight) {
        int inputSize = inputHeight * inputWidth;
        int numExamples = inputs.numRows() / inputSize;
        int numChannels = inputs.numColumns();
        // Ceiling division: a partial window at the bottom/right edge still yields an output cell.
        int outputHeight = inputHeight / poolHeight + (inputHeight % poolHeight == 0 ? 0 : 1);
        int outputWidth = inputWidth / poolWidth + (inputWidth % poolWidth == 0 ? 0 : 1);
        int outputSize = outputHeight * outputWidth;
        Matrix outputs = new Matrix(numExamples * outputSize, numChannels);
        Matrix prePoolRowIndexes = new Matrix(outputs.numRows(), outputs.numColumns());
        for (int example = 0; example < numExamples; example++) {
            int exampleInputOffset = example * inputSize;
            int exampleOutputOffset = example * outputSize;
            for (int outputY = 0; outputY < outputHeight; outputY++) {
                int windowStartY = outputY * poolHeight;
                int windowEndY = Math.min(windowStartY + poolHeight, inputHeight);
                for (int outputX = 0; outputX < outputWidth; outputX++) {
                    int windowStartX = outputX * poolWidth;
                    int windowEndX = Math.min(windowStartX + poolWidth, inputWidth);
                    int outputsRowIndex = exampleOutputOffset + outputY * outputWidth + outputX;
                    for (int channel = 0; channel < numChannels; channel++) {
                        double maxValue = Double.NEGATIVE_INFINITY;
                        // Seed with the window's first element rather than 0. Otherwise, a window
                        // where nothing compares greater than -infinity (all NaN or all -infinity)
                        // would record row 0 of the whole matrix, which lies outside this window.
                        int maxInputsRowIndex = exampleInputOffset + windowStartY * inputWidth + windowStartX;
                        for (int inputY = windowStartY; inputY < windowEndY; inputY++) {
                            for (int inputX = windowStartX; inputX < windowEndX; inputX++) {
                                int inputsRowIndex = exampleInputOffset + inputY * inputWidth + inputX;
                                double value = inputs.get(inputsRowIndex, channel);
                                if (value > maxValue) {
                                    maxValue = value;
                                    maxInputsRowIndex = inputsRowIndex;
                                }
                            }
                        }
                        outputs.set(outputsRowIndex, channel, maxValue);
                        prePoolRowIndexes.set(outputsRowIndex, channel, maxInputsRowIndex);
                    }
                }
            }
        }
        return new PoolingResult(outputs, prePoolRowIndexes);
    }

    /**
     * Routes pooled-layer deltas back to the pre-pool positions recorded by
     * {@link #maxPool(Matrix, int, int, int, int)}.
     *
     * <p>For every cell {@code (row, col)} of {@code delta}, the value is added to
     * {@code result((int) prePoolRowIndexes(row, col), col)}. Cells that receive no delta keep the
     * value produced by the {@code Matrix} constructor, presumably zero; confirm in {@code Matrix}.
     *
     * @param delta             deltas for the pooled layer
     * @param prePoolRowIndexes selected-element row indexes, same shape as {@code delta}
     * @param numRowsPrePool    number of rows of the pre-pool activation matrix
     * @return a {@code numRowsPrePool x delta.numColumns()} matrix of routed deltas
     */
    public static Matrix antiPoolDelta(Matrix delta, Matrix prePoolRowIndexes, int numRowsPrePool) {
        int numColumns = delta.numColumns();
        Matrix result = new Matrix(numRowsPrePool, numColumns);
        for (int row = 0; row < delta.numRows(); row++) {
            for (int col = 0; col < numColumns; col++) {
                int targetRow = (int) prePoolRowIndexes.get(row, col);
                // Accumulate instead of overwrite, so that several deltas routed to the same
                // pre-pool element are summed rather than silently clobbering each other.
                result.set(targetRow, col, result.get(targetRow, col) + delta.get(row, col));
            }
        }
        return result;
    }

    /**
     * Extracts every {@code patchWidth x patchHeight} patch from an input-layer matrix. Patches are
     * taken at every position where they fit fully inside the image, with stride 1.
     *
     * <p>Layout of {@code inputs}: one row per example. Columns are channel-major, at
     * {@code channel * inputWidth * inputHeight + y * inputWidth + x}. The channel count is
     * {@code inputs.numColumns() / (inputWidth * inputHeight)}.
     *
     * <p>Layout of the result: one row per (example, patch), at
     * {@code example * numPatchesPerExample + startY * (inputWidth - patchWidth + 1) + startX}.
     * Columns are at {@code channel * patchSize + patchY * patchWidth + patchX}.
     *
     * @param inputs      input images, one row per example
     * @param inputWidth  image width
     * @param inputHeight image height
     * @param patchWidth  patch width
     * @param patchHeight patch height
     * @return the patch matrix described above
     */
    public static Matrix generatePatchesFromInputLayer(Matrix inputs, int inputWidth, int inputHeight, int patchWidth, int patchHeight) {
        int inputSize = inputHeight * inputWidth;
        int patchSize = patchHeight * patchWidth;
        int numChannels = inputs.numColumns() / inputSize;
        int numPatchesX = inputWidth - patchWidth + 1;
        int numPatchesY = inputHeight - patchHeight + 1;
        int numPatchesPerExample = numPatchesX * numPatchesY;
        int numExamples = inputs.numRows();
        Matrix output = new Matrix(numExamples * numPatchesPerExample, numChannels * patchSize);
        for (int example = 0; example < numExamples; example++) {
            for (int inputStartY = 0; inputStartY < numPatchesY; inputStartY++) {
                for (int inputStartX = 0; inputStartX < numPatchesX; inputStartX++) {
                    int outputRow = example * numPatchesPerExample + inputStartY * numPatchesX + inputStartX;
                    for (int channel = 0; channel < numChannels; channel++) {
                        int inputChannelOffset = channel * inputSize;
                        int outputChannelOffset = channel * patchSize;
                        for (int patchPixelY = 0; patchPixelY < patchHeight; patchPixelY++) {
                            int inputY = inputStartY + patchPixelY;
                            for (int patchPixelX = 0; patchPixelX < patchWidth; patchPixelX++) {
                                int inputX = inputStartX + patchPixelX;
                                double value = inputs.get(example, inputChannelOffset + inputY * inputWidth + inputX);
                                output.set(outputRow, outputChannelOffset + patchPixelY * patchWidth + patchPixelX, value);
                            }
                        }
                    }
                }
            }
        }
        return output;
    }

    /**
     * Extracts every {@code patchWidth x patchHeight} patch from a hidden-layer matrix. Patches are
     * taken at every position where they fit fully inside the feature map, with stride 1.
     *
     * <p>Layout of {@code inputs}: one row per (example, pixel), at
     * {@code example * inputWidth * inputHeight + y * inputWidth + x}, and one column per channel.
     * The result uses the same layout as
     * {@link #generatePatchesFromInputLayer(Matrix, int, int, int, int)}.
     *
     * @param inputs      hidden activations, one row per (example, pixel)
     * @param inputWidth  feature-map width
     * @param inputHeight feature-map height
     * @param patchWidth  patch width
     * @param patchHeight patch height
     * @return one row per (example, patch), columns {@code channel * patchSize + patchY * patchWidth + patchX}
     */
    public static Matrix generatePatchesFromHiddenLayer(Matrix inputs, int inputWidth, int inputHeight, int patchWidth, int patchHeight) {
        int numPatchesX = inputWidth - patchWidth + 1;
        int numPatchesY = inputHeight - patchHeight + 1;
        int numPatchesPerExample = numPatchesX * numPatchesY;
        int inputSize = inputHeight * inputWidth;
        int numExamples = inputs.numRows() / inputSize;
        int numChannels = inputs.numColumns();
        int patchSize = patchHeight * patchWidth;
        Matrix output = new Matrix(numExamples * numPatchesPerExample, numChannels * patchSize);
        for (int example = 0; example < numExamples; example++) {
            int exampleInputOffset = example * inputSize;
            for (int inputStartY = 0; inputStartY < numPatchesY; inputStartY++) {
                for (int inputStartX = 0; inputStartX < numPatchesX; inputStartX++) {
                    int outputRow = example * numPatchesPerExample + inputStartY * numPatchesX + inputStartX;
                    for (int channel = 0; channel < numChannels; channel++) {
                        int outputChannelOffset = channel * patchSize;
                        for (int patchPixelY = 0; patchPixelY < patchHeight; patchPixelY++) {
                            int inputY = inputStartY + patchPixelY;
                            for (int patchPixelX = 0; patchPixelX < patchWidth; patchPixelX++) {
                                int inputX = inputStartX + patchPixelX;
                                double value = inputs.get(exampleInputOffset + inputY * inputWidth + inputX, channel);
                                output.set(outputRow, outputChannelOffset + patchPixelY * patchWidth + patchPixelX, value);
                            }
                        }
                    }
                }
            }
        }
        return output;
    }

    /**
     * Scatters patch-level values back onto the hidden-layer grid. This is the adjoint of
     * {@link #generatePatchesFromHiddenLayer(Matrix, int, int, int, int)}: every patch element is
     * added to the pixel it was read from, so overlapping patches sum.
     *
     * <p>The number of examples is {@code output.numRows() / numPatchesPerExample}. The channel
     * count is {@code output.numColumns() / (patchWidth * patchHeight)}. Accumulation starts from
     * the values produced by the {@code Matrix} constructor, presumably zeros; confirm in
     * {@code Matrix}.
     *
     * @param output      patch matrix, one row per (example, patch)
     * @param inputWidth  feature-map width
     * @param inputHeight feature-map height
     * @param patchWidth  patch width
     * @param patchHeight patch height
     * @return one row per (example, pixel), one column per channel
     */
    public static Matrix antiPatchDeltas(Matrix output, int inputWidth, int inputHeight, int patchWidth, int patchHeight) {
        int numPatchesX = inputWidth - patchWidth + 1;
        int numPatchesY = inputHeight - patchHeight + 1;
        int numPatchesPerExample = numPatchesX * numPatchesY;
        int inputSize = inputHeight * inputWidth;
        int numExamples = output.numRows() / numPatchesPerExample;
        int patchSize = patchHeight * patchWidth;
        int numChannels = output.numColumns() / patchSize;
        Matrix inputs = new Matrix(numExamples * inputSize, numChannels);
        for (int example = 0; example < numExamples; example++) {
            int exampleInputOffset = example * inputSize;
            for (int inputStartY = 0; inputStartY < numPatchesY; inputStartY++) {
                for (int inputStartX = 0; inputStartX < numPatchesX; inputStartX++) {
                    int patchRow = example * numPatchesPerExample + inputStartY * numPatchesX + inputStartX;
                    for (int channel = 0; channel < numChannels; channel++) {
                        int patchChannelOffset = channel * patchSize;
                        for (int patchPixelY = 0; patchPixelY < patchHeight; patchPixelY++) {
                            int inputY = inputStartY + patchPixelY;
                            for (int patchPixelX = 0; patchPixelX < patchWidth; patchPixelX++) {
                                int inputX = inputStartX + patchPixelX;
                                double value = output.get(patchRow, patchChannelOffset + patchPixelY * patchWidth + patchPixelX);
                                int inputRow = exampleInputOffset + inputY * inputWidth + inputX;
                                inputs.set(inputRow, channel, inputs.get(inputRow, channel) + value);
                            }
                        }
                    }
                }
            }
        }
        return inputs;
    }

    /**
     * Reshapes a matrix with one row per (example, patch) and one column per feature map into a
     * matrix with one row per example.
     *
     * <p>The element at {@code inputs(example * numPatches + patch, featureMap)} moves to
     * {@code output(example, featureMap * numPatches + patch)}.
     *
     * @param inputs         source matrix, rows {@code example * numPatches + patch}
     * @param numExamples    number of examples
     * @param numFeatureMaps number of feature maps (columns of {@code inputs})
     * @param numPatches     number of patches per example
     * @return a {@code numExamples x (numFeatureMaps * numPatches)} matrix
     */
    public static Matrix movePatchesToColumns(Matrix inputs, int numExamples, int numFeatureMaps, int numPatches) {
        Matrix output = new Matrix(numExamples, numFeatureMaps * numPatches);
        for (int example = 0; example < numExamples; example++) {
            int exampleRowOffset = example * numPatches;
            for (int featureMap = 0; featureMap < numFeatureMaps; featureMap++) {
                int featureMapColumnOffset = featureMap * numPatches;
                for (int patch = 0; patch < numPatches; patch++) {
                    output.set(example, featureMapColumnOffset + patch, inputs.get(exampleRowOffset + patch, featureMap));
                }
            }
        }
        return output;
    }

    /**
     * Inverse of {@link #movePatchesToColumns(Matrix, int, int, int)}.
     *
     * <p>The element at {@code x(example, featureMap * numPatches + patch)} moves to
     * {@code output(example * numPatches + patch, featureMap)}.
     *
     * @param x              source matrix, one row per example
     * @param numExamples    number of examples
     * @param numFeatureMaps number of feature maps
     * @param numPatches     number of patches per example
     * @return a {@code (numExamples * numPatches) x numFeatureMaps} matrix
     */
    public static Matrix movePatchesToRows(Matrix x, int numExamples, int numFeatureMaps, int numPatches) {
        Matrix output = new Matrix(numExamples * numPatches, numFeatureMaps);
        for (int example = 0; example < numExamples; example++) {
            int exampleRowOffset = example * numPatches;
            for (int featureMap = 0; featureMap < numFeatureMaps; featureMap++) {
                int featureMapColumnOffset = featureMap * numPatches;
                for (int patch = 0; patch < numPatches; patch++) {
                    output.set(exampleRowOffset + patch, featureMap, x.get(example, featureMapColumnOffset + patch));
                }
            }
        }
        return output;
    }

    /**
     * Result of {@link #maxPool(Matrix, int, int, int, int)}: the pooled activations, plus the
     * row index in the pre-pool matrix of each selected element.
     */
    public static class PoolingResult {
        /** Pooled activations, one row per (example, pooled cell), one column per channel. */
        public Matrix pooledActivations;
        /** Same shape as {@link #pooledActivations}; each entry is a pre-pool row index stored as a double. */
        public Matrix prePoolRowIndexes;

        public PoolingResult(Matrix pooledActivations, Matrix prePoolRowIndexes) {
            this.pooledActivations = pooledActivations;
            this.prePoolRowIndexes = prePoolRowIndexes;
        }
    }
}

