/*
 * Decompiled with CFR 0.152.
 */
package amten.ml.matrix;

import amten.ml.matrix.Matrix;
import amten.ml.matrix.MatrixElement;
import au.com.bytecode.opencsv.CSVReader;
import au.com.bytecode.opencsv.CSVWriter;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/**
 * Static helpers operating on {@link Matrix}: CSV input/output, nominal-attribute
 * (one-hot) encoding, data-set splitting, activation functions with their gradients,
 * and per-column normalization.
 */
public class MatrixUtils {

    /**
     * Reads a numeric CSV file into a new matrix.
     *
     * <p>Fields may be quoted with a double quote and escaped with a backslash. Every field
     * is parsed with {@link Double#parseDouble(String)}. The column count is taken from the
     * first data row. Each later row must have at least that many fields; extra fields are
     * ignored.
     *
     * @param filename    path of the file to read (decoded with the platform default charset)
     * @param separator   field separator character
     * @param headerLines number of leading lines to skip before the data
     * @return a matrix with one row per data row of the file
     * @throws IOException if the file cannot be read, has no data rows, or a row is too short
     * @throws NumberFormatException if a field is not a valid number
     */
    public static Matrix readCSV(String filename, char separator, int headerLines) throws IOException {
        List<String[]> values;
        BufferedReader br = new BufferedReader(new FileReader(filename));
        try {
            CSVReader cr = new CSVReader(br, separator, '"', '\\', headerLines);
            values = cr.readAll();
            cr.close();
        } finally {
            // Release the file handle even when parsing fails part-way through.
            br.close();
        }
        if (values.isEmpty()) {
            throw new IOException("No data rows in " + filename + " after skipping "
                    + headerLines + " header line(s)");
        }
        int numRows = values.size();
        int numCols = values.get(0).length;
        Matrix m = new Matrix(numRows, numCols);
        for (int row = 0; row < numRows; row++) {
            String[] rowValues = values.get(row);
            if (rowValues.length < numCols) {
                throw new IOException("Data row " + (row + 1) + " of " + filename + " has "
                        + rowValues.length + " field(s), expected " + numCols);
            }
            for (int col = 0; col < numCols; col++) {
                m.set(row, col, Double.parseDouble(rowValues[col]));
            }
        }
        return m;
    }

    /**
     * Writes a matrix to a comma-separated file, one line per row, with unquoted fields
     * formatted by {@link Double#toString(double)}. An existing file is overwritten.
     *
     * @param m        matrix to write
     * @param filename path of the file to create (encoded with the platform default charset)
     * @throws IOException if the file cannot be written
     */
    public static void writeCSV(Matrix m, String filename) throws IOException {
        int numRows = m.numRows();
        int numCols = m.numColumns();
        // Format everything first so a failure here does not leave a truncated file behind.
        List<String[]> rows = new ArrayList<String[]>(numRows);
        for (int row = 0; row < numRows; row++) {
            String[] rowValues = new String[numCols];
            for (int col = 0; col < numCols; col++) {
                rowValues[col] = Double.toString(m.get(row, col));
            }
            rows.add(rowValues);
        }
        BufferedWriter bw = new BufferedWriter(new FileWriter(filename));
        try {
            // A NUL quote character is opencsv's NO_QUOTE_CHARACTER: fields are written unquoted.
            CSVWriter cw = new CSVWriter(bw, ',', '\u0000');
            cw.writeAll(rows);
            cw.close();
        } finally {
            // Closing an already-closed writer is a no-op; this guards the failure path.
            bw.close();
        }
    }

    /**
     * Creates a matrix filled with pseudo-random values from {@link Random#nextDouble()},
     * i.e. uniformly distributed in [0, 1).
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return the new random matrix
     */
    public static Matrix random(int rows, int cols) {
        Random rnd = new Random();
        Matrix m = new Matrix(rows, cols);
        for (MatrixElement me : m) {
            me.set(rnd.nextDouble());
        }
        return m;
    }

    /**
     * Builds a single column of ones with as many rows as {@code m} and combines it with
     * {@code m} via {@code bias.addColumns(m)}. Presumably this places the bias column
     * first — confirm against {@link Matrix#addColumns}.
     *
     * @param m input matrix (not modified here)
     * @return the matrix returned by {@code addColumns}
     */
    public static Matrix addBiasColumn(Matrix m) {
        Matrix bias = new Matrix(m.numRows(), 1);
        bias.fill(1.0);
        return bias.addColumns(m);
    }

    /**
     * Expands nominal (categorical) columns into one-hot encoded column groups.
     *
     * <p>For compressed column {@code c} with {@code numCategories[c] > 1}, the output gets
     * {@code numCategories[c]} columns. The column for category {@code k} is 1.0 where the
     * input value equals {@code k}, otherwise 0.0. Columns with {@code numCategories[c] <= 1}
     * are copied unchanged as a single column.
     *
     * @param mCompressed   input matrix with one column per attribute
     * @param numCategories number of categories per column; {@code null} means no column is nominal
     * @return a new, expanded matrix
     * @throws IllegalArgumentException if {@code numCategories} has fewer entries than
     *                                  {@code mCompressed} has columns
     */
    public static Matrix expandNominalAttributes(Matrix mCompressed, int[] numCategories) {
        int numCompressedCols = mCompressed.numColumns();
        if (numCategories == null) {
            // Java zero-initializes int arrays: every column is treated as numeric.
            numCategories = new int[numCompressedCols];
        }
        if (numCategories.length < numCompressedCols) {
            throw new IllegalArgumentException("numCategories has " + numCategories.length
                    + " entries but the matrix has " + numCompressedCols + " columns");
        }
        int numExamples = mCompressed.numRows();
        int numColumnsExpanded = 0;
        for (int numCat : numCategories) {
            numColumnsExpanded += numCat > 1 ? numCat : 1;
        }
        Matrix mExpanded = new Matrix(numExamples, numColumnsExpanded);
        int expandedCol = 0;
        for (int compressedCol = 0; compressedCol < numCompressedCols; compressedCol++) {
            int numCat = numCategories[compressedCol];
            if (numCat <= 1) {
                for (int row = 0; row < numExamples; row++) {
                    mExpanded.set(row, expandedCol, mCompressed.get(row, compressedCol));
                }
                expandedCol++;
            } else {
                for (int cat = 0; cat < numCat; cat++) {
                    for (int row = 0; row < numExamples; row++) {
                        double value = (double) cat == mCompressed.get(row, compressedCol) ? 1.0 : 0.0;
                        mExpanded.set(row, expandedCol, value);
                    }
                    expandedCol++;
                }
            }
        }
        return mExpanded;
    }

    /**
     * Inverse of {@link #expandNominalAttributes}: collapses one-hot column groups back into
     * a single categorical column.
     *
     * <p>For an attribute with {@code numCategories[c] > 1}, the next {@code numCategories[c]}
     * expanded columns are scanned. The compressed value is the index of the column holding
     * exactly 1.0; if several do, the last one wins. It is -1.0 if none does. Attributes with
     * {@code numCategories[c] <= 1} are copied from a single expanded column, matching
     * {@link #expandNominalAttributes}.
     *
     * @param mExpanded     one-hot encoded matrix
     * @param numCategories number of categories per compressed column; {@code null} means
     *                      every expanded column is copied unchanged
     * @return a new matrix with {@code numCategories.length} columns
     * @throws IllegalArgumentException if {@code mExpanded} has fewer columns than
     *                                  {@code numCategories} requires
     */
    public static Matrix compressNominalAttributes(Matrix mExpanded, int[] numCategories) {
        if (numCategories == null) {
            numCategories = new int[mExpanded.numColumns()];
        }
        int requiredExpandedCols = 0;
        for (int numCat : numCategories) {
            requiredExpandedCols += numCat > 1 ? numCat : 1;
        }
        if (requiredExpandedCols > mExpanded.numColumns()) {
            throw new IllegalArgumentException("numCategories requires " + requiredExpandedCols
                    + " expanded columns but the matrix has " + mExpanded.numColumns());
        }
        int numExamples = mExpanded.numRows();
        int numColumnsCompressed = numCategories.length;
        Matrix mCompressed = new Matrix(numExamples, numColumnsCompressed);
        int expandedCol = 0;
        for (int compressedCol = 0; compressedCol < numColumnsCompressed; compressedCol++) {
            int numCat = numCategories[compressedCol];
            if (numCat <= 1) {
                for (int row = 0; row < numExamples; row++) {
                    mCompressed.set(row, compressedCol, mExpanded.get(row, expandedCol));
                }
                expandedCol++;
            } else {
                for (int row = 0; row < numExamples; row++) {
                    double category = -1.0;
                    for (int cat = 0; cat < numCat; cat++) {
                        // Each category has its own column inside this attribute's group.
                        if (mExpanded.get(row, expandedCol + cat) == 1.0) {
                            category = cat;
                        }
                    }
                    mCompressed.set(row, compressedCol, category);
                }
                // Advance past the whole group once all rows are done, not once per row.
                expandedCol += numCat;
            }
        }
        return mCompressed;
    }

    /**
     * Randomly partitions the rows of {@code m} into training, cross-validation and test
     * matrices.
     *
     * <p>The cross-validation and test row counts are the percentages of the row count,
     * rounded. If rounding both would exceed the row count, the test share is trimmed. All
     * remaining rows go to training.
     *
     * @param m                      source matrix (not modified)
     * @param crossValidationPercent share of rows for cross-validation, in percent (0-100)
     * @param testPercent            share of rows for testing, in percent (0-100)
     * @return {@code {train, crossValidation, test}}
     * @throws IllegalArgumentException if a percentage is negative/NaN or their sum exceeds 100
     */
    public static Matrix[] split(Matrix m, float crossValidationPercent, float testPercent) {
        if (!(crossValidationPercent >= 0.0f) || !(testPercent >= 0.0f)
                || crossValidationPercent + testPercent > 100.0f + 1.0e-3f) {
            throw new IllegalArgumentException("Invalid split percentages: cv=" + crossValidationPercent
                    + ", test=" + testPercent);
        }
        int numRows = m.numRows();
        int numCVRows = Math.min(Math.round((float) numRows * crossValidationPercent / 100.0f), numRows);
        // Rounding both shares up can overshoot (e.g. 3 rows at 50%/50%), so trim the test share.
        int numTestRows = Math.min(Math.round((float) numRows * testPercent / 100.0f), numRows - numCVRows);
        int numTrainRows = numRows - numCVRows - numTestRows;

        List<Integer> rowIndexes = shuffledRowIndexes(numRows);
        Matrix trainMatrix = new Matrix(numTrainRows, m.numColumns());
        Matrix cvMatrix = new Matrix(numCVRows, m.numColumns());
        Matrix testMatrix = new Matrix(numTestRows, m.numColumns());

        int next = 0;
        for (int row = 0; row < numTrainRows; row++) {
            copyRow(m, rowIndexes.get(next++), trainMatrix, row);
        }
        for (int row = 0; row < numCVRows; row++) {
            copyRow(m, rowIndexes.get(next++), cvMatrix, row);
        }
        for (int row = 0; row < numTestRows; row++) {
            copyRow(m, rowIndexes.get(next++), testMatrix, row);
        }
        return new Matrix[]{trainMatrix, cvMatrix, testMatrix};
    }

    /**
     * Shuffles the rows of {@code x} and {@code y} with the same permutation and splits them
     * into mini-batches of up to {@code batchSize} rows; the last batch may be smaller.
     *
     * <p>If {@code batchesX} is empty, new batch matrices are created and appended to
     * {@code batchesX} / {@code batchesY}. Otherwise the existing matrices in those lists are
     * overwritten in order. They are assumed to have the shapes a previous call produced.
     *
     * @param x         input matrix (not modified)
     * @param y         target matrix with the same number of rows as {@code x} (not modified)
     * @param batchSize maximum rows per batch; must be positive
     * @param batchesX  receives, or supplies for reuse, the batches of {@code x}
     * @param batchesY  receives, or supplies for reuse, the batches of {@code y}
     * @throws IllegalArgumentException if {@code batchSize <= 0} or the row counts differ
     */
    public static void split(Matrix x, Matrix y, int batchSize, List<Matrix> batchesX, List<Matrix> batchesY) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        if (x.numRows() != y.numRows()) {
            throw new IllegalArgumentException("x and y must have the same number of rows: "
                    + x.numRows() + " != " + y.numRows());
        }
        boolean createMatrices = batchesX.isEmpty();
        List<Integer> rowIndexes = shuffledRowIndexes(x.numRows());
        int numRows = rowIndexes.size();
        int batchNr = 0;
        Matrix batchX = createMatrices
                ? new Matrix(Math.min(batchSize, numRows), x.numColumns()) : batchesX.get(batchNr);
        Matrix batchY = createMatrices
                ? new Matrix(Math.min(batchSize, numRows), y.numColumns()) : batchesY.get(batchNr);
        for (int i = 0; i < numRows; i++) {
            int sourceRow = rowIndexes.get(i);
            int batchRow = i % batchSize;
            copyRow(x, sourceRow, batchX, batchRow);
            copyRow(y, sourceRow, batchY, batchRow);

            int rowsLeft = numRows - i - 1;
            boolean batchComplete = (i + 1) % batchSize == 0 || rowsLeft == 0;
            if (!batchComplete) {
                continue;
            }
            if (createMatrices) {
                batchesX.add(batchX);
                batchesY.add(batchY);
                if (rowsLeft > 0) {
                    batchX = new Matrix(Math.min(batchSize, rowsLeft), x.numColumns());
                    batchY = new Matrix(Math.min(batchSize, rowsLeft), y.numColumns());
                }
            } else if (rowsLeft > 0) {
                batchNr++;
                batchX = batchesX.get(batchNr);
                batchY = batchesY.get(batchNr);
            }
        }
    }

    /** Returns the indexes {@code 0..numRows-1} in random order ({@link Collections#shuffle}). */
    private static List<Integer> shuffledRowIndexes(int numRows) {
        List<Integer> indexes = new ArrayList<Integer>(numRows);
        for (int i = 0; i < numRows; i++) {
            indexes.add(i);
        }
        Collections.shuffle(indexes);
        return indexes;
    }

    /** Copies every column of {@code source} row {@code sourceRow} into {@code target} row {@code targetRow}. */
    private static void copyRow(Matrix source, int sourceRow, Matrix target, int targetRow) {
        for (int col = 0; col < source.numColumns(); col++) {
            target.set(targetRow, col, source.get(sourceRow, col));
        }
    }

    /**
     * Logistic sigmoid {@code 1 / (1 + e^-x)}.
     *
     * @param x input value
     * @return sigmoid of {@code x}, in [0, 1]
     */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Applies {@link #sigmoid(double)} element-wise, modifying {@code m} in place.
     *
     * @param m matrix to transform
     * @return {@code m} itself
     */
    public static Matrix sigmoid(Matrix m) {
        for (MatrixElement me : m) {
            me.set(MatrixUtils.sigmoid(me.value()));
        }
        return m;
    }

    /**
     * Applies the softmax function to each row of {@code m} in place, so that each row
     * becomes non-negative and sums to 1.
     *
     * @param m matrix to transform
     * @return {@code m} itself
     */
    public static Matrix softmax(Matrix m) {
        int numCols = m.numColumns();
        for (int row = 0; row < m.numRows(); row++) {
            // Subtract the row maximum before exponentiating. Seeding with -Infinity (not 0)
            // keeps the largest exponent at 0 even when all values are very negative;
            // otherwise every exp() underflows to 0 and the row becomes 0/0 = NaN.
            double max = Double.NEGATIVE_INFINITY;
            for (int col = 0; col < numCols; col++) {
                double value = m.get(row, col);
                if (value > max) {
                    max = value;
                }
            }
            double sum = 0.0;
            for (int col = 0; col < numCols; col++) {
                double value = Math.exp(m.get(row, col) - max);
                m.set(row, col, value);
                sum += value;
            }
            for (int col = 0; col < numCols; col++) {
                m.set(row, col, m.get(row, col) / sum);
            }
        }
        return m;
    }

    /**
     * Computes one softmax Jacobian per row (example) of {@code m}.
     *
     * <p>With {@code h = softmax(row)}, entry {@code (i, j)} of the example's Jacobian is
     * {@code h[i] * (delta(i, j) - h[j])}. {@code m} itself is not modified.
     *
     * @param m matrix of softmax inputs, one example per row
     * @return array of {@code numColumns x numColumns} matrices, one per row of {@code m}
     */
    public static Matrix[] softMaxJacobians(Matrix m) {
        Matrix h = MatrixUtils.softmax(m.copy());
        Matrix[] results = new Matrix[m.numRows()];
        for (int example = 0; example < m.numRows(); example++) {
            Matrix res = new Matrix(m.numColumns(), m.numColumns());
            for (MatrixElement me : res) {
                int row = me.row();
                int col = me.col();
                double delta = col == row ? 1.0 : 0.0;
                me.set(h.get(example, row) * (delta - h.get(example, col)));
            }
            results[example] = res;
        }
        return results;
    }

    /**
     * Computes {@code s * (1 - s)} with {@code s = sigmoid(m)}, working on copies so that
     * {@code m} is not modified. This assumes {@link Matrix#multElements} is an element-wise
     * product — confirm against Matrix.
     *
     * @param m sigmoid inputs
     * @return the matrix returned by {@code multElements}
     */
    public static Matrix sigmoidGradient(Matrix m) {
        Matrix t1 = MatrixUtils.sigmoid(m.copy());
        Matrix t2 = t1.copy();
        t2.scale(-1.0);
        t2.add(1.0);
        return t1.multElements(t2);
    }

    /**
     * Applies the rectifier {@code max(0, x)} element-wise, modifying {@code m} in place.
     *
     * @param m matrix to transform
     * @return {@code m} itself
     */
    public static Matrix rectify(Matrix m) {
        for (MatrixElement me : m) {
            me.set(Math.max(0.0, me.value()));
        }
        return m;
    }

    /**
     * Returns a new matrix holding the derivative of {@code max(0, x)} for each element of
     * {@code m}: 1.0 where the value is {@code >= 0} (including exactly 0), otherwise 0.0.
     * {@code m} is not modified.
     *
     * @param m rectifier inputs
     * @return new gradient matrix of the same shape
     */
    public static Matrix rectifyGradient(Matrix m) {
        Matrix gradient = new Matrix(m.numRows(), m.numColumns());
        for (MatrixElement me : m) {
            double g = me.value() >= 0.0 ? 1.0 : 0.0;
            gradient.set(me.row(), me.col(), g);
        }
        return gradient;
    }

    /**
     * Applies the natural logarithm element-wise, modifying {@code m} in place. Zero gives
     * -Infinity and negative values give NaN, as with {@link Math#log(double)}.
     *
     * @param m matrix to transform
     * @return {@code m} itself
     */
    public static Matrix log(Matrix m) {
        for (MatrixElement me : m) {
            me.set(Math.log(me.value()));
        }
        return m;
    }

    /**
     * Returns the arithmetic mean of column {@code col}; NaN for a matrix without rows.
     *
     * @param m   source matrix
     * @param col column index
     * @return mean of the column
     */
    public static double getAverage(Matrix m, int col) {
        double sum = 0.0;
        for (int row = 0; row < m.numRows(); row++) {
            sum += m.get(row, col);
        }
        return sum / (double) m.numRows();
    }

    /**
     * Returns the <em>range</em> ({@code max - min}) of column {@code col}.
     *
     * <p>Despite the name, this is not the statistical standard deviation. The behavior is
     * kept as-is because changing it would alter the scaling of any existing normalization
     * that uses it. For a matrix without rows the result is -Infinity.
     *
     * @param m   source matrix
     * @param col column index
     * @return largest minus smallest value in the column
     */
    public static double getStandardDeviation(Matrix m, int col) {
        double largestValue = Double.NEGATIVE_INFINITY;
        double smallestValue = Double.POSITIVE_INFINITY;
        for (int row = 0; row < m.numRows(); row++) {
            double value = m.get(row, col);
            if (value > largestValue) {
                largestValue = value;
            }
            if (value < smallestValue) {
                smallestValue = value;
            }
        }
        return largestValue - smallestValue;
    }

    /**
     * Normalizes a single value as {@code (x - average) / standardDeviation}. A zero
     * divisor is replaced by 1.0, so constant columns are only centered.
     *
     * @param x                 value to normalize
     * @param average           value subtracted from {@code x}
     * @param standardDeviation divisor; 0.0 is treated as 1.0
     * @return the normalized value
     */
    public static double normalizeData(double x, double average, double standardDeviation) {
        if (standardDeviation == 0.0) {
            standardDeviation = 1.0;
        }
        return (x - average) / standardDeviation;
    }

    /**
     * Normalizes column {@code col} of {@code x} in place as
     * {@code (value - average) / standardDeviation}. A zero divisor is replaced by 1.0.
     *
     * @param x                 matrix to modify
     * @param col               column index
     * @param average           value subtracted from each entry
     * @param standardDeviation divisor; 0.0 is treated as 1.0
     */
    public static void normalizeData(Matrix x, int col, double average, double standardDeviation) {
        if (standardDeviation == 0.0) {
            standardDeviation = 1.0;
        }
        for (int row = 0; row < x.numRows(); row++) {
            x.set(row, col, (x.get(row, col) - average) / standardDeviation);
        }
    }
}

