// Author: rain1024 — commit 6f3ebfa ("update")
package org.maltparser.ml.lib;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.LinkedHashMap;
import org.maltparser.core.config.Configuration;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.helper.NoPrintStream;
import org.maltparser.ml.lib.FeatureList;
import org.maltparser.ml.lib.MaltLibModel;
import org.maltparser.ml.lib.MaltLibsvmModel;
import org.maltparser.ml.lib.MaltFeatureNode;
import org.maltparser.ml.lib.LibException;
import org.maltparser.parser.guide.instance.InstanceModel;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
/**
 * MaltParser learner that trains and loads LIBSVM models.
 *
 * <p>In {@code CLASSIFY} mode the constructor loads the model stored under the {@code ".moo"}
 * configuration entry. Training reads the binarized instances from {@code ".ins"} and serializes a
 * {@link MaltLibsvmModel} to {@code ".moo"}. This happens either by calling LIBSVM in-process
 * ({@link #trainInternal}) or by running an external training binary ({@link #trainExternal}).
 */
public class LibSvm extends Lib {
    /**
     * Creates the LIBSVM learner.
     *
     * @param owner the instance model this learner belongs to
     * @param learnerMode learner mode; when it equals {@code CLASSIFY}, the model stored under the
     *     {@code ".moo"} configuration entry is loaded immediately
     * @throws MaltChainedException if superclass initialisation or model loading fails
     */
    public LibSvm(InstanceModel owner, Integer learnerMode) throws MaltChainedException {
        super(owner, learnerMode, "libsvm");
        if (learnerMode == CLASSIFY) {
            model = (MaltLibModel) getConfigFileEntryObject(".moo");
        }
    }

    /**
     * Trains a LIBSVM model in-process and serializes it to the {@code ".moo"} file.
     *
     * <p>The {@code ".ins"} file is deleted afterwards unless the {@code lib/save_instance_files}
     * option is set.
     *
     * @param libOptions learner options keyed by LIBSVM flag letter
     * @throws MaltChainedException if the parameters are invalid, or if reading, training or saving fails
     */
    protected void trainInternal(LinkedHashMap<String, String> libOptions) throws MaltChainedException {
        try {
            final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
            final svm_parameter param = getLibSvmParameters(libOptions);
            // Must happen before the parameter check and training; previously this default was
            // computed on a throw-away parameter object and never took effect.
            applyDefaultGamma(param);
            final String parameterError = svm.svm_check_parameter(prob, param);
            if (parameterError != null) {
                throw new LibException(parameterError);
            }
            final Configuration config = getConfiguration();
            if (config.isLoggerInfoEnabled()) {
                config.logInfoMessage("Creating LIBSVM model " + getFile(".moo").getName() + "\n");
            }
            final svm_model trained = trainSilently(prob, param);
            saveModelFile(new MaltLibsvmModel(trained, prob));
            deleteInstanceFilesUnlessSaved(".ins");
        } catch (OutOfMemoryError e) {
            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
        } catch (IllegalArgumentException e) {
            throw new LibException("The LIBSVM learner was not able to redirect Standard Error stream. ", e);
        } catch (SecurityException e) {
            throw new LibException("The LIBSVM learner cannot remove the instance file. ", e);
        } catch (IOException e) {
            // This path writes the ".moo" file, so report that file (not ".mod").
            throw new LibException("The LIBSVM learner cannot save the model file '" + getFile(".moo").getAbsolutePath() + "'. ", e);
        }
    }

    /**
     * Trains a model by running an external LIBSVM-compatible training binary.
     *
     * <p>The binary's output model ({@code ".mod"}) is loaded back and serialized as a
     * {@link MaltLibsvmModel} to {@code ".moo"}. The {@code ".ins"}, {@code ".mod"} and
     * {@code ".ins.tmp"} files are deleted afterwards unless {@code lib/save_instance_files} is set.
     *
     * @param pathExternalTrain path of the external training executable
     * @param libOptions learner options keyed by LIBSVM flag letter
     * @throws MaltChainedException if conversion, training, loading or saving fails
     */
    protected void trainExternal(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException {
        try {
            binariesInstances2SVMFileFormat(getInstanceInputStreamReader(".ins"), getInstanceOutputStreamWriter(".ins.tmp"));
            final Configuration config = getConfiguration();
            if (config.isLoggerInfoEnabled()) {
                config.logInfoMessage("Creating learner model (external) " + getFile(".mod").getName());
            }
            final svm_problem prob = readProblem(getInstanceInputStreamReader(".ins"));
            final String[] arrayCommands = buildExternalCommand(pathExternalTrain, libOptions);
            if (verbosity == Verbostity.ALL) {
                config.logInfoMessage('\n');
            }
            runExternalTrainer(arrayCommands, config);
            final svm_model loaded = svm.svm_load_model(getFile(".mod").getAbsolutePath());
            saveModelFile(new MaltLibsvmModel(loaded, prob));
            deleteInstanceFilesUnlessSaved(".ins", ".mod", ".ins.tmp");
            if (config.isLoggerInfoEnabled()) {
                config.logInfoMessage('\n');
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new LibException("Learner is interrupted. ", e);
        } catch (IllegalArgumentException e) {
            throw new LibException("The learner was not able to redirect Standard Error stream. ", e);
        } catch (SecurityException e) {
            throw new LibException("The learner cannot remove the instance file. ", e);
        } catch (IOException e) {
            throw new LibException("The learner cannot save the model file '" + getFile(".mod").getAbsolutePath() + "'. ", e);
        } catch (OutOfMemoryError e) {
            throw new LibException("Out of memory. Please increase the Java heap size (-Xmx<size>). ", e);
        }
    }

    /**
     * Delegates to the superclass; this learner holds no extra resources of its own.
     *
     * @throws MaltChainedException if the superclass termination fails
     */
    @Override
    public void terminate() throws MaltChainedException {
        super.terminate();
    }

    /**
     * Returns the default LIBSVM options, keyed by flag letter.
     *
     * <p>Defaults: C-SVC with a degree-2 polynomial kernel, gamma 0.2, coef0 0, nu 0.5,
     * cache 100, C 1, eps 1.0, p 0.1, shrinking on, probability estimates off.
     *
     * @return a new, mutable map of default options
     */
    public LinkedHashMap<String, String> getDefaultLibOptions() {
        LinkedHashMap<String, String> libOptions = new LinkedHashMap<String, String>();
        libOptions.put("s", Integer.toString(svm_parameter.C_SVC));
        libOptions.put("t", Integer.toString(svm_parameter.POLY));
        libOptions.put("d", Integer.toString(2));
        libOptions.put("g", Double.toString(0.2));
        libOptions.put("r", Double.toString(0));
        libOptions.put("n", Double.toString(0.5));
        libOptions.put("m", Integer.toString(100));
        libOptions.put("c", Double.toString(1));
        libOptions.put("e", Double.toString(1.0));
        libOptions.put("p", Double.toString(0.1));
        libOptions.put("h", Integer.toString(1));
        libOptions.put("b", Integer.toString(0));
        return libOptions;
    }

    /**
     * Returns the option flag letters a user may set.
     *
     * <p>Includes {@code h} (shrinking). That option is one of the defaults and is read by
     * {@link #getLibSvmParameters}, so users must be able to override it too.
     *
     * @return the allowed flag letters
     */
    public String getAllowedLibOptionFlags() {
        return "stdgrnmcepbh";
    }

    /**
     * Converts the string option map into a LIBSVM {@link svm_parameter}. No class weights are used.
     *
     * @param libOptions options keyed by flag letter; every flag read here must be present
     * @return a freshly built parameter object
     * @throws NumberFormatException if an option value is not a valid number
     */
    private svm_parameter getLibSvmParameters(LinkedHashMap<String, String> libOptions) throws MaltChainedException {
        svm_parameter param = new svm_parameter();
        param.svm_type = Integer.parseInt(libOptions.get("s"));
        param.kernel_type = Integer.parseInt(libOptions.get("t"));
        param.degree = Integer.parseInt(libOptions.get("d"));
        param.gamma = Double.parseDouble(libOptions.get("g"));
        param.coef0 = Double.parseDouble(libOptions.get("r"));
        param.nu = Double.parseDouble(libOptions.get("n"));
        param.cache_size = Double.parseDouble(libOptions.get("m"));
        param.C = Double.parseDouble(libOptions.get("c"));
        param.eps = Double.parseDouble(libOptions.get("e"));
        param.p = Double.parseDouble(libOptions.get("p"));
        param.shrinking = Integer.parseInt(libOptions.get("h"));
        param.probability = Integer.parseInt(libOptions.get("b"));
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        return param;
    }

    /**
     * Replaces a gamma of 0 with {@code 1 / featureMap.getFeatureCounter()}.
     *
     * <p>If the feature counter is not positive, gamma is left unchanged to avoid dividing by zero.
     * Call this after the instances have been read.
     */
    private void applyDefaultGamma(svm_parameter param) throws MaltChainedException {
        if (param.gamma == 0 && featureMap.getFeatureCounter() > 0) {
            param.gamma = 1.0 / featureMap.getFeatureCounter();
        }
    }

    /**
     * Runs {@code svm.svm_train} with {@code System.out} and {@code System.err} silenced.
     *
     * <p>Both original streams are restored in a {@code finally}, even when training throws.
     */
    private svm_model trainSilently(svm_problem prob, svm_parameter param) {
        final PrintStream out = System.out;
        final PrintStream err = System.err;
        try {
            System.setOut(NoPrintStream.NO_PRINTSTREAM);
            System.setErr(NoPrintStream.NO_PRINTSTREAM);
            return svm.svm_train(prob, param);
        } finally {
            System.setOut(out);
            System.setErr(err);
        }
    }

    /**
     * Serializes {@code xmodel} to the {@code ".moo"} file.
     *
     * <p>The underlying file stream is closed even if the object stream cannot be created.
     */
    private void saveModelFile(MaltLibsvmModel xmodel) throws IOException, MaltChainedException {
        final BufferedOutputStream sink = new BufferedOutputStream(new FileOutputStream(getFile(".moo").getAbsolutePath()));
        try {
            final ObjectOutputStream output = new ObjectOutputStream(sink);
            output.writeObject(xmodel);
            output.flush();
        } finally {
            sink.close();
        }
    }

    /** Deletes the files with the given suffixes unless {@code lib/save_instance_files} is true. */
    private void deleteInstanceFilesUnlessSaved(String... suffixes) throws MaltChainedException {
        final boolean saveInstanceFiles = ((Boolean) getConfiguration().getOptionValue("lib", "save_instance_files")).booleanValue();
        if (!saveInstanceFiles) {
            for (String suffix : suffixes) {
                getFile(suffix).delete();
            }
        }
    }

    /** Builds the external command: {@code [executable, options..., ".ins.tmp" path, ".mod" path]}. */
    private String[] buildExternalCommand(String pathExternalTrain, LinkedHashMap<String, String> libOptions) throws MaltChainedException {
        final String[] params = getLibParamStringArray(libOptions);
        final String[] command = new String[params.length + 3];
        command[0] = pathExternalTrain;
        System.arraycopy(params, 0, command, 1, params.length);
        command[params.length + 1] = getFile(".ins.tmp").getAbsolutePath();
        command[params.length + 2] = getFile(".mod").getAbsolutePath();
        return command;
    }

    /**
     * Runs the external trainer and echoes its output according to {@code verbosity}.
     *
     * <p>Stdout is echoed only at ALL; stderr is echoed at ALL or ERROR. A non-zero exit code is
     * logged as an error but is not thrown. Both child streams are closed on every path.
     */
    private void runExternalTrainer(String[] command, Configuration config) throws IOException, InterruptedException, MaltChainedException {
        final Process child = Runtime.getRuntime().exec(command);
        final InputStream in = child.getInputStream();
        final InputStream err = child.getErrorStream();
        try {
            // NOTE(review): stdout is drained to EOF before stderr is read. A child that fills the
            // stderr pipe buffer first could block; consider draining stderr concurrently.
            int c;
            while ((c = in.read()) != -1) {
                if (verbosity == Verbostity.ALL) {
                    config.logInfoMessage((char) c);
                }
            }
            while ((c = err.read()) != -1) {
                if (verbosity == Verbostity.ALL || verbosity == Verbostity.ERROR) {
                    config.logInfoMessage((char) c);
                }
            }
            final int exitValue = child.waitFor();
            if (exitValue != 0) {
                config.logErrorMessage(" FAILED (" + exitValue + ")");
            }
        } finally {
            in.close();
            err.close();
        }
    }

    /**
     * Parses the binarized instance file into a LIBSVM problem.
     *
     * <p>Each line is converted with {@code binariesInstance}, and lines for which it returns -1
     * are skipped. Row arrays are pre-sized from {@code getNumberOfInstances()}. If fewer usable
     * rows are read, the arrays are trimmed so that {@code problem.l} matches the populated rows
     * and no {@code null} rows reach LIBSVM.
     *
     * @param isr reader over the instance file; always closed by this method
     * @return the populated problem
     * @throws MaltChainedException if the file cannot be read or holds more rows than expected
     */
    private svm_problem readProblem(InputStreamReader isr) throws MaltChainedException {
        final svm_problem problem = new svm_problem();
        final FeatureList featureList = new FeatureList();
        final BufferedReader fp = new BufferedReader(isr);
        try {
            try {
                problem.l = getNumberOfInstances();
                problem.x = new svm_node[problem.l][];
                problem.y = new double[problem.l];
                int i = 0;
                String line;
                while ((line = fp.readLine()) != null) {
                    final int y = binariesInstance(line, featureList);
                    if (y == -1) {
                        continue;
                    }
                    try {
                        problem.y[i] = y;
                        final svm_node[] row = new svm_node[featureList.size()];
                        for (int k = 0; k < row.length; k++) {
                            final MaltFeatureNode feature = featureList.get(k);
                            row[k] = new svm_node();
                            row[k].value = feature.getValue();
                            row[k].index = feature.getIndex();
                        }
                        problem.x[i] = row;
                        i++;
                    } catch (ArrayIndexOutOfBoundsException e) {
                        throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
                    }
                }
                if (i < problem.l) {
                    // Skipped lines left unused trailing slots; shrink to the rows actually read.
                    final svm_node[][] rows = new svm_node[i][];
                    final double[] labels = new double[i];
                    System.arraycopy(problem.x, 0, rows, 0, i);
                    System.arraycopy(problem.y, 0, labels, 0, i);
                    problem.x = rows;
                    problem.y = labels;
                    problem.l = i;
                }
            } finally {
                fp.close();
            }
        } catch (IOException e) {
            throw new LibException("Couldn't read libsvm problem from the instance file. ", e);
        }
        return problem;
    }
}