| package org.maltparser.ml.lib; |
|
|
| import java.io.BufferedReader; |
| import java.io.EOFException; |
| import java.io.File; |
| import java.io.FileInputStream; |
| import java.io.IOException; |
| import java.io.InputStreamReader; |
| import java.io.ObjectInputStream; |
| import java.io.ObjectOutputStream; |
| import java.io.Reader; |
| import java.io.Serializable; |
| import java.nio.charset.Charset; |
| import java.util.Arrays; |
| import java.util.regex.Pattern; |
|
|
| import org.maltparser.core.helper.Util; |
|
|
| import de.bwaldvogel.liblinear.SolverType; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| public class MaltLiblinearModel implements Serializable, MaltLibModel { |
| private static final long serialVersionUID = 7526471155622776147L; |
| private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1"); |
| private double bias; |
| |
| private int[] labels; |
| private int nr_class; |
| private int nr_feature; |
| private SolverType solverType; |
| |
| private double[][] w; |
|
|
| public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) { |
| this.labels = labels; |
| this.nr_class = nr_class; |
| this.nr_feature = nr_feature; |
| this.w = w; |
| this.solverType = solverType; |
| } |
| |
| public MaltLiblinearModel(Reader inputReader) throws IOException { |
| loadModel(inputReader); |
| } |
| |
| public MaltLiblinearModel(File modelFile) throws IOException { |
| BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET)); |
| loadModel(inputReader); |
| } |
| |
| |
| |
| |
| public int getNrClass() { |
| return nr_class; |
| } |
|
|
| |
| |
| |
| public int getNrFeature() { |
| return nr_feature; |
| } |
|
|
| public int[] getLabels() { |
| return Util.copyOf(labels, nr_class); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| public boolean isProbabilityModel() { |
| return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR); |
| } |
| |
| public double getBias() { |
| return bias; |
| } |
| |
| public int[] predict(MaltFeatureNode[] x) { |
| final double[] dec_values = new double[nr_class]; |
| final int n = (bias >= 0)?nr_feature + 1:nr_feature; |
| final int xlen = x.length; |
| |
| for (int i=0; i < xlen; i++) { |
| if (x[i].index <= n) { |
| final int t = (x[i].index - 1); |
| if (w[t] != null) { |
| for (int j = 0; j < w[t].length; j++) { |
| dec_values[j] += w[t][j] * x[i].value; |
| } |
| } |
| } |
| } |
|
|
| |
| double tmpDec; |
| int tmpObj; |
| int iMax; |
| final int[] predictionList = new int[nr_class]; |
| System.arraycopy(labels, 0, predictionList, 0, nr_class); |
| final int nc = nr_class-1; |
| for (int i=0; i < nc; i++) { |
| iMax = i; |
| for (int j=i+1; j < nr_class; j++) { |
| if (dec_values[j] > dec_values[iMax]) { |
| iMax = j; |
| } |
| } |
| if (iMax != i) { |
| tmpDec = dec_values[iMax]; |
| dec_values[iMax] = dec_values[i]; |
| dec_values[i] = tmpDec; |
| tmpObj = predictionList[iMax]; |
| predictionList[iMax] = predictionList[i]; |
| predictionList[i] = tmpObj; |
| } |
| } |
| return predictionList; |
| } |
| |
| public int predict_one(MaltFeatureNode[] x) { |
| final double[] dec_values = new double[nr_class]; |
| final int n = (bias >= 0)?nr_feature + 1:nr_feature; |
| final int xlen = x.length; |
| |
| for (int i=0; i < xlen; i++) { |
| if (x[i].index <= n) { |
| final int t = (x[i].index - 1); |
| if (w[t] != null) { |
| for (int j = 0; j < w[t].length; j++) { |
| dec_values[j] += w[t][j] * x[i].value; |
| } |
| } |
| } |
| } |
| |
| double max = dec_values[0]; |
| int max_index = 0; |
| for (int i = 1; i < dec_values.length; i++) { |
| if (dec_values[i] > max) { |
| max = dec_values[i]; |
| max_index = i; |
| } |
| } |
|
|
| return labels[max_index]; |
| } |
| |
| private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException { |
| is.defaultReadObject(); |
| } |
|
|
| private void writeObject(ObjectOutputStream os) throws IOException { |
| os.defaultWriteObject(); |
| } |
| |
| private void loadModel(Reader inputReader) throws IOException { |
| labels = null; |
| Pattern whitespace = Pattern.compile("\\s+"); |
| BufferedReader reader = null; |
| if (inputReader instanceof BufferedReader) { |
| reader = (BufferedReader)inputReader; |
| } else { |
| reader = new BufferedReader(inputReader); |
| } |
|
|
| try { |
| String line = null; |
| while ((line = reader.readLine()) != null) { |
| String[] split = whitespace.split(line); |
| if (split[0].equals("solver_type")) { |
| SolverType solver = SolverType.valueOf(split[1]); |
| if (solver == null) { |
| throw new RuntimeException("unknown solver type"); |
| } |
| solverType = solver; |
| } else if (split[0].equals("nr_class")) { |
| nr_class = Util.atoi(split[1]); |
| Integer.parseInt(split[1]); |
| } else if (split[0].equals("nr_feature")) { |
| nr_feature = Util.atoi(split[1]); |
| } else if (split[0].equals("bias")) { |
| bias = Util.atof(split[1]); |
| } else if (split[0].equals("w")) { |
| break; |
| } else if (split[0].equals("label")) { |
| labels = new int[nr_class]; |
| for (int i = 0; i < nr_class; i++) { |
| labels[i] = Util.atoi(split[i + 1]); |
| } |
| } else { |
| throw new RuntimeException("unknown text in model file: [" + line + "]"); |
| } |
| } |
|
|
| int w_size = nr_feature; |
| if (bias >= 0) w_size++; |
|
|
| int nr_w = nr_class; |
| if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1; |
| w = new double[w_size][nr_w]; |
| int[] buffer = new int[128]; |
|
|
| for (int i = 0; i < w_size; i++) { |
| for (int j = 0; j < nr_w; j++) { |
| int b = 0; |
| while (true) { |
| int ch = reader.read(); |
| if (ch == -1) { |
| throw new EOFException("unexpected EOF"); |
| } |
| if (ch == ' ') { |
| w[i][j] = Util.atof(new String(buffer, 0, b)); |
| break; |
| } else { |
| buffer[b++] = ch; |
| } |
| } |
| } |
| } |
| } |
| finally { |
| Util.closeQuietly(reader); |
| } |
| } |
|
|
| public int hashCode() { |
| final int prime = 31; |
| long temp = Double.doubleToLongBits(bias); |
| int result = prime * 1 + (int)(temp ^ (temp >>> 32)); |
| result = prime * result + Arrays.hashCode(labels); |
| result = prime * result + nr_class; |
| result = prime * result + nr_feature; |
| result = prime * result + ((solverType == null) ? 0 : solverType.hashCode()); |
| for (int i = 0; i < w.length; i++) { |
| result = prime * result + Arrays.hashCode(w[i]); |
| } |
| return result; |
| } |
|
|
| public boolean equals(Object obj) { |
| if (this == obj) return true; |
| if (obj == null) return false; |
| if (getClass() != obj.getClass()) return false; |
| MaltLiblinearModel other = (MaltLiblinearModel)obj; |
| if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false; |
| if (!Arrays.equals(labels, other.labels)) return false; |
| if (nr_class != other.nr_class) return false; |
| if (nr_feature != other.nr_feature) return false; |
| if (solverType == null) { |
| if (other.solverType != null) return false; |
| } else if (!solverType.equals(other.solverType)) return false; |
| for (int i = 0; i < w.length; i++) { |
| if (other.w.length <= i) return false; |
| if (!Util.equals(w[i], other.w[i])) return false; |
| } |
| return true; |
| } |
| |
| public String toString() { |
| final StringBuilder sb = new StringBuilder("Model"); |
| sb.append(" bias=").append(bias); |
| sb.append(" nr_class=").append(nr_class); |
| sb.append(" nr_feature=").append(nr_feature); |
| sb.append(" solverType=").append(solverType); |
| return sb.toString(); |
| } |
| } |
|
|