/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import amten.ml.NNParams;
import amten.ml.matrix.Matrix;
import amten.ml.matrix.MatrixElement;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;

public class NeuralNetwork
extends AbstractClassifier
implements Serializable {
    private NNParams myParams = new NNParams();
    private amten.ml.NeuralNetwork myNN = null;

    public NeuralNetwork() {
        this.setDebug(false);
    }

    public void buildClassifier(Instances instances) throws Exception {
        int numExamples = instances.numInstances();
        int numInputAttributes = instances.numAttributes() - 1;
        int classIndex = instances.classIndex();
        int numClasses = instances.numClasses();
        double[] classValues = instances.attributeToDoubleArray(classIndex);
        Matrix y = new Matrix(numExamples, 1);
        for (MatrixElement me : y) {
            me.set(classValues[me.row()]);
        }
        Matrix x = new Matrix(numExamples, numInputAttributes);
        int[] numCategories = new int[numInputAttributes];
        int col = 0;
        int attrIndex = 0;
        while (attrIndex < instances.numAttributes()) {
            Attribute attr = instances.attribute(attrIndex);
            if (attrIndex != classIndex) {
                int row = 0;
                while (row < numExamples) {
                    double value = instances.get(row).value(attrIndex);
                    boolean missing = instances.get(row).isMissing(attrIndex);
                    if (missing) {
                        value = attr.isNominal() ? -1.0 : 0.0;
                    }
                    x.set(row, col, value);
                    ++row;
                }
                numCategories[col] = attr.isNominal() ? attr.numValues() : 1;
                ++col;
            }
            ++attrIndex;
        }
        this.myParams.numClasses = numClasses;
        this.myParams.numCategories = numCategories;
        this.myNN = new amten.ml.NeuralNetwork(this.myParams);
        this.myNN.train(x, y);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        Matrix x = new Matrix(1, instance.numAttributes() - 1);
        int classIndex = instance.classIndex();
        int col = 0;
        int attrIndex = 0;
        while (attrIndex < instance.numAttributes()) {
            Attribute attr = instance.attribute(attrIndex);
            if (attrIndex != classIndex) {
                double value = instance.value(attrIndex);
                boolean missing = instance.isMissing(attrIndex);
                if (missing) {
                    value = attr.isNominal() ? -1.0 : 0.0;
                }
                x.set(0, col, value);
                ++col;
            }
            ++attrIndex;
        }
        return this.myNN.getPredictions(x).getRow(0);
    }

    public Enumeration listOptions() {
        ArrayList<Option> options = new ArrayList<Option>();
        options.add(new Option("\tNumber of examples in each mini-batch.", "BatchSize", 1, "-bs"));
        options.add(new Option("\tWeight penalty", "WeightPenalty", 1, "-wp"));
        options.add(new Option("\tLearning rate", "LearningRate", 1, "-lr"));
        options.add(new Option("\tMaximum number of training iterations over the entire data set. (epochs)", "MaxIterations", 1, "-mi"));
        options.add(new Option("\tNumber of threads to use for training the network.", "Threads", 1, "-th"));
        options.add(new Option("\tNumber of Units in the hidden layers. (comma-separated list)\ne.g. \"100,100\" for two layers with 100 units each.\nFor convolutional layers: <num feature maps>-<patch-width>-<patch-height>-<pool-width>-<pool-height> \ne.g. \"20-5-5-2-2,100-5-5-2-2\" for two convolutional layers, both with patch size 5x5 and pool size 2x2, each with 20 and 100 feature maps respectively.", "HiddenLayers", 1, "-hl"));
        options.add(new Option("\tFraction of units to dropout in the input layer during training.", "InputLayerDropoutRate", 1, "-di"));
        options.add(new Option("\tFraction of units to dropout in the hidden layers during training.", "HiddenLayersDropoutRate", 1, "-dh"));
        options.add(new Option("\tWidth of input image (only used for convolution) (0=Square image).", "InputWidth", 1, "-iw"));
        return Collections.enumeration(options);
    }

    public void setOptions(String[] options) throws Exception {
        String weightPenaltyString = Utils.getOption((String)"wp", (String[])options);
        this.myParams.weightPenalty = weightPenaltyString.equals("") ? this.myParams.weightPenalty : Double.parseDouble(weightPenaltyString);
        String lrString = Utils.getOption((String)"lr", (String[])options);
        this.myParams.learningRate = lrString.equals("") ? this.myParams.learningRate : Double.parseDouble(lrString);
        String maxIterationsString = Utils.getOption((String)"mi", (String[])options);
        this.myParams.maxIterations = maxIterationsString.equals("") ? this.myParams.maxIterations : Integer.parseInt(maxIterationsString);
        String threadsString = Utils.getOption((String)"th", (String[])options);
        this.myParams.numThreads = threadsString.equals("") ? this.myParams.numThreads : Integer.parseInt(threadsString);
        String batchSizeString = Utils.getOption((String)"bs", (String[])options);
        this.myParams.batchSize = batchSizeString.equals("") ? this.myParams.batchSize : Integer.parseInt(batchSizeString);
        String hiddenLayersString = Utils.getOption((String)"hl", (String[])options);
        this.myParams.hiddenLayerParams = hiddenLayersString.equals("") ? this.myParams.hiddenLayerParams : this.getHiddenLayers(hiddenLayersString);
        String inputLayerDropoutRateString = Utils.getOption((String)"di", (String[])options);
        this.myParams.inputLayerDropoutRate = inputLayerDropoutRateString.equals("") ? this.myParams.inputLayerDropoutRate : Double.parseDouble(inputLayerDropoutRateString);
        String hiddenLayersDropoutRateString = Utils.getOption((String)"dh", (String[])options);
        this.myParams.hiddenLayersDropoutRate = hiddenLayersDropoutRateString.equals("") ? this.myParams.hiddenLayersDropoutRate : Double.parseDouble(hiddenLayersDropoutRateString);
        String inputWidthString = Utils.getOption((String)"iw", (String[])options);
        this.myParams.inputWidth = inputWidthString.equals("") ? this.myParams.inputWidth : Integer.parseInt(inputWidthString);
    }

    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        options.add("-lr");
        options.add(Double.toString(this.myParams.learningRate));
        options.add("-wp");
        options.add(Double.toString(this.myParams.weightPenalty));
        options.add("-mi");
        options.add(Integer.toString(this.myParams.maxIterations));
        options.add("-bs");
        options.add(Integer.toString(this.myParams.batchSize));
        options.add("-th");
        options.add(Integer.toString(this.myParams.numThreads));
        options.add("-hl");
        options.add(this.getString(this.myParams.hiddenLayerParams));
        options.add("-di");
        options.add(Double.toString(this.myParams.inputLayerDropoutRate));
        options.add("-dh");
        options.add(Double.toString(this.myParams.hiddenLayersDropoutRate));
        options.add("-iw");
        options.add(Integer.toString(this.myParams.inputWidth));
        return options.toArray(new String[options.size()]);
    }

    public double getWeightPenalty() {
        return this.myParams.weightPenalty;
    }

    public void setWeightPenalty(double weightPenalty) {
        this.myParams.weightPenalty = weightPenalty;
    }

    public String weightPenaltyTipText() {
        return "Weight penalty parameter.";
    }

    public String getHiddenLayers() {
        return this.getString(this.myParams.hiddenLayerParams);
    }

    public void setHiddenLayers(String hiddenLayers) {
        this.myParams.hiddenLayerParams = this.getHiddenLayers(hiddenLayers);
    }

    public String hiddenLayersTipText() {
        return "Number of units in each hidden layer (comma-separated) (For convolutional layers: <num feature maps>-<patch-width>-<patch-height>-<pool-width>-<pool-height>).";
    }

    public int getMaxIterations() {
        return this.myParams.maxIterations;
    }

    public void setMaxIterations(int iterations) {
        this.myParams.maxIterations = iterations;
    }

    public String maxIterationsTipText() {
        return "Maximum number of training iterations over the entire data set (epochs)";
    }

    public double getInputLayerDropoutRate() {
        return this.myParams.inputLayerDropoutRate;
    }

    public void setInputLayerDropoutRate(double inputLayerDropoutRate) {
        this.myParams.inputLayerDropoutRate = inputLayerDropoutRate;
    }

    public String inputLayerDropoutRateTipText() {
        return "Fraction of units to dropout in the input layer during training.";
    }

    public double getHiddenLayersDropoutRate() {
        return this.myParams.hiddenLayersDropoutRate;
    }

    public void setHiddenLayersDropoutRate(double hiddenLayersDropoutRate) {
        this.myParams.hiddenLayersDropoutRate = hiddenLayersDropoutRate;
    }

    public String hiddenLayersDropoutRateTipText() {
        return "Fraction of units to dropout in the hidden layers during training.";
    }

    public int getBatchSize() {
        return this.myParams.batchSize;
    }

    public void setBatchSize(int batchSize) {
        this.myParams.batchSize = batchSize;
    }

    public String batchSizeTipText() {
        return "Number of training examples in each mini-batch (0=Auto-choose) .";
    }

    public int getThreads() {
        return this.myParams.numThreads;
    }

    public void setThreads(int threads) {
        this.myParams.numThreads = threads;
    }

    public String threadsTipText() {
        return "The number of threads to use for training the network (0=Auto-detect)";
    }

    public double getLearningRate() {
        return this.myParams.learningRate;
    }

    public void setLearningRate(double learningRate) {
        this.myParams.learningRate = learningRate;
    }

    public String learningRateTipText() {
        return "Learning rate (0=Auto-detect).";
    }

    public int getInputWidth() {
        return this.myParams.inputWidth;
    }

    public void setInputWidth(int width) {
        this.myParams.inputWidth = width;
    }

    public String inputWidthTipText() {
        return "Width of input image (only used for convolution) (0=Square image).";
    }

    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        return result;
    }

    private NNParams.NNLayerParams[] getHiddenLayers(String s) {
        String[] stringList = s.split(",");
        ArrayList<NNParams.NNLayerParams> layerList = new ArrayList<NNParams.NNLayerParams>();
        String[] stringArray = stringList;
        int n = stringList.length;
        int n2 = 0;
        while (n2 < n) {
            String layerString = stringArray[n2];
            if (layerString.contains("-")) {
                String[] convStringList = layerString.split("-");
                if (convStringList.length >= 3) {
                    int numFeatureMaps = Integer.parseInt(convStringList[0]);
                    int patchWidth = Integer.parseInt(convStringList[1]);
                    int patchHeight = Integer.parseInt(convStringList[2]);
                    int poolWidth = convStringList.length > 4 ? Integer.parseInt(convStringList[3]) : 0;
                    int poolHeight = convStringList.length > 4 ? Integer.parseInt(convStringList[4]) : 0;
                    layerList.add(new NNParams.NNLayerParams(numFeatureMaps, patchWidth, patchHeight, poolWidth, poolHeight));
                }
            } else if (!layerString.equals("")) {
                layerList.add(new NNParams.NNLayerParams(Integer.parseInt(layerString)));
            }
            ++n2;
        }
        return layerList.toArray(new NNParams.NNLayerParams[layerList.size()]);
    }

    private String getString(NNParams.NNLayerParams[] layerList) {
        String s = "";
        NNParams.NNLayerParams[] nNLayerParamsArray = layerList;
        int n = layerList.length;
        int n2 = 0;
        while (n2 < n) {
            NNParams.NNLayerParams layer = nNLayerParamsArray[n2];
            if (!s.equals("")) {
                s = String.valueOf(s) + ",";
            }
            if (layer.isConvolutional()) {
                s = String.valueOf(s) + layer.numFeatures + "-" + layer.patchWidth + "-" + layer.patchHeight;
                if (layer.isPooled()) {
                    s = String.valueOf(s) + "-" + layer.poolWidth + "-" + layer.poolHeight;
                }
            } else {
                s = String.valueOf(s) + layer.numFeatures;
            }
            ++n2;
        }
        return s;
    }

    public String globalInfo() {
        return "(Convolutional) Neural Network implementation with dropout regularization and Rectified Linear Units.\n\nTraining is done with multithreaded mini-batch gradient descent.\n\nRunning Weka with console window and with debug flag for this classifier on, you can monitor training cost in console window and halt training anytime by pressing enter.\n\nHidden layers are specified as comma-separated lists.\ne.g. \"100,100\" for two layers with 100 units each.\nFor convolutional layers: <num feature maps>-<patch-width>-<patch-height>-<pool-width>-<pool-height> \ne.g. \"20-5-5-2-2,100-5-5-2-2\" for two convolutional layers, both with patch size 5x5 and pool size 2x2, each with 20 and 100 feature maps respectively.";
    }

    public static void main(String[] argv) {
        NeuralNetwork.runClassifier((Classifier)new NeuralNetwork(), (String[])argv);
    }
}

