// rain1024's picture
// update
// 6f3ebfa
package org.maltparser.parser.guide.instance;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Formatter;
import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureModel;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.feature.function.FeatureFunction;
import org.maltparser.core.feature.function.Modifiable;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.ml.LearningMethod;
import org.maltparser.ml.lib.LibLinear;
import org.maltparser.ml.lib.LibSvm;
import org.maltparser.parser.guide.ClassifierGuide;
import org.maltparser.parser.guide.GuideException;
import org.maltparser.parser.guide.Model;
import org.maltparser.parser.history.action.SingleDecision;
/**
 * An atomic instance model: every training and prediction request is delegated to a single
 * {@link LearningMethod} that is created when the model is constructed. The model also keeps
 * a frequency counter (number of instances) that is propagated to the parent model when the
 * parent is itself an {@link InstanceModel}.
 *
 * @author Johan Hall
 */
public class AtomicModel implements InstanceModel {
    /** Parameter types of the constructor that is looked up reflectively on custom learner classes. */
    public static final Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };

    /** Message used whenever a call delegated to the learner fails with a NullPointerException. */
    private static final String LEARNER_NOT_FOUND = "The learner cannot be found. ";

    private final Model parent;
    private final String modelName;
    private final int index;
    private final LearningMethod method;
    private int frequency = 0;

    /**
     * Constructs an atomic model.
     *
     * <p>The model name is the parent's model name followed by {@code "."}; when {@code index}
     * is not -1 the zero-padded three-digit index and another {@code "."} are appended.
     * In batch mode, the model with index -1 also writes the learner's description to the
     * configuration info file.
     *
     * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model
     * or the master divide model) and n is number of divide models.
     * @param parent the parent guide model.
     * @throws MaltChainedException if the learner cannot be initialized
     */
    public AtomicModel(int index, Model parent) throws MaltChainedException {
        this.parent = parent;
        this.index = index;
        if (index == -1) {
            this.modelName = parent.getModelName() + ".";
        } else {
            // String.format avoids allocating a java.util.Formatter that is never closed.
            this.modelName = parent.getModelName() + "." + String.format("%03d", index) + ".";
        }

        // Map the guide mode onto the learner mode; stays null for any other guide mode.
        Integer learnerMode = null;
        if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
            learnerMode = LearningMethod.CLASSIFY;
        } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
            learnerMode = LearningMethod.BATCH;
        }

        this.method = createLearningMethod(learnerMode);

        // Null-safe value comparison: learnerMode is a boxed Integer that may still be null here.
        final boolean batchMode = learnerMode != null && learnerMode.intValue() == LearningMethod.BATCH;
        if (batchMode && index == -1 && getGuide().getConfiguration() != null) {
            getGuide().getConfiguration().writeInfoToConfigFile(method.toString());
        }
    }

    /**
     * Creates the learner configured under the option ("guide", "learner"). LibSvm and LibLinear
     * are instantiated directly; any other class is instantiated reflectively through a
     * constructor matching {@link #argTypes}.
     *
     * @param learnerMode the learner mode handed to the learner's constructor (may be null)
     * @return the newly created learner
     * @throws MaltChainedException if the learner class cannot be instantiated
     */
    private LearningMethod createLearningMethod(Integer learnerMode) throws MaltChainedException {
        final Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
        if (clazz == org.maltparser.ml.lib.LibSvm.class) {
            return new LibSvm(this, learnerMode);
        }
        if (clazz == org.maltparser.ml.lib.LibLinear.class) {
            return new LibLinear(this, learnerMode);
        }
        final Object[] arguments = { this, learnerMode };
        try {
            final Constructor<?> constructor = clazz.getConstructor(argTypes);
            return (LearningMethod)constructor.newInstance(arguments);
        } catch (NoSuchMethodException e) {
            throw learnerInitFailure(clazz, e);
        } catch (InstantiationException e) {
            throw learnerInitFailure(clazz, e);
        } catch (IllegalAccessException e) {
            throw learnerInitFailure(clazz, e);
        } catch (InvocationTargetException e) {
            throw learnerInitFailure(clazz, e);
        }
    }

    /** Builds the exception reported when a reflectively loaded learner class cannot be instantiated. */
    private static GuideException learnerInitFailure(Class<?> clazz, Throwable cause) {
        return new GuideException("The learner class '" + clazz.getName() + "' cannot be initialized. ", cause);
    }

    /**
     * Passes a training instance to the learner.
     *
     * @param featureVector the feature vector of the instance
     * @param decision the decision associated with the instance
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public void addInstance(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
        try {
            method.addInstance(decision, featureVector);
        } catch (NullPointerException e) {
            // Kept for backward compatibility: callers receive a GuideException instead of a raw NPE.
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Tells the learner that no more instances will be added.
     *
     * @param featureModel not used by this implementation
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public void noMoreInstances(FeatureModel featureModel) throws MaltChainedException {
        try {
            method.noMoreInstances();
        } catch (NullPointerException e) {
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Forwards the end-of-sentence notification to the learner.
     *
     * @param dependencyGraph the dependency structure handed to the learner
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
        try {
            method.finalizeSentence(dependencyGraph);
        } catch (NullPointerException e) {
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Asks the learner to predict the decision for the given feature vector.
     *
     * @param featureVector the feature vector to classify
     * @param decision the decision object handed to the learner
     * @return the value returned by the learner's predict method
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
        try {
            return method.predict(featureVector, decision);
        } catch (NullPointerException e) {
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Runs a prediction and hands back the feature vector that was used.
     *
     * @param featureVector the feature vector to classify
     * @param decision the decision object handed to the learner
     * @return {@code featureVector} if the learner's predict method returned true, otherwise null
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public FeatureVector predictExtract(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
        try {
            if (method.predict(featureVector, decision)) {
                return featureVector;
            }
            return null;
        } catch (NullPointerException e) {
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Returns the given feature vector unchanged.
     *
     * @param featureVector the feature vector
     * @return the same feature vector
     * @throws MaltChainedException declared by the signature; never thrown by this implementation
     */
    public FeatureVector extract(FeatureVector featureVector) throws MaltChainedException {
        return featureVector;
    }

    /**
     * Terminates the learner, if there is one.
     *
     * @throws MaltChainedException if the learner fails to terminate
     */
    public void terminate() throws MaltChainedException {
        if (method != null) {
            method.terminate();
        }
    }

    /**
     * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
     * This method is used by the feature divide model to sum up all model below a certain threshold.
     * The divide feature is set to this model's index before the move, and this model's learner
     * is terminated afterwards.
     *
     * @param model the destination atomic model
     * @param divideFeature the divide feature; it is cast to {@link Modifiable}
     * @param divideFeatureIndexVector the divide feature index vector
     * @throws MaltChainedException if any argument (or the learner) is null, or if the learner fails
     */
    public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
        if (method == null) {
            throw new GuideException("The learner cannot be found. ");
        } else if (model == null) {
            throw new GuideException("The guide model cannot be found. ");
        } else if (divideFeature == null) {
            throw new GuideException("The divide feature cannot be found. ");
        } else if (divideFeatureIndexVector == null) {
            throw new GuideException("The divide feature index vector cannot be found. ");
        }
        ((Modifiable)divideFeature).setFeatureValue(index);
        method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
        method.terminate();
    }

    /**
     * Invokes the train() of the learning method and then terminates it.
     *
     * @throws MaltChainedException if the learner fails; a NullPointerException raised during
     * the call is reported as a GuideException
     */
    public void train() throws MaltChainedException {
        try {
            method.train();
            method.terminate();
        } catch (NullPointerException e) {
            throw new GuideException(LEARNER_NOT_FOUND, e);
        }
    }

    /**
     * Returns the parent guide model
     *
     * @return the parent guide model
     * @throws MaltChainedException if this atomic model has no parent
     */
    public Model getParent() throws MaltChainedException {
        if (parent == null) {
            throw new GuideException("The atomic model can only be used by a parent model. ");
        }
        return parent;
    }

    /**
     * Returns the name of this model as built by the constructor.
     *
     * @return the model name
     */
    public String getModelName() {
        return modelName;
    }

    /**
     * Returns the classifier guide of the parent model.
     *
     * @return the classifier guide obtained from the parent
     */
    public ClassifierGuide getGuide() {
        return parent.getGuide();
    }

    /**
     * Returns the index of the atomic model
     *
     * @return the index of the atomic model
     */
    public int getIndex() {
        return index;
    }

    /**
     * Returns the frequency (number of instances)
     *
     * @return the frequency (number of instances)
     */
    public int getFrequency() {
        return frequency;
    }

    /**
     * Increase the frequency by 1. The increase is first propagated to the parent when the
     * parent is an {@link InstanceModel}.
     */
    public void increaseFrequency() {
        if (parent instanceof InstanceModel) {
            ((InstanceModel)parent).increaseFrequency();
        }
        frequency++;
    }

    /**
     * Decrease the frequency by 1. The decrease is first propagated to the parent when the
     * parent is an {@link InstanceModel}.
     */
    public void decreaseFrequency() {
        if (parent instanceof InstanceModel) {
            ((InstanceModel)parent).decreaseFrequency();
        }
        frequency--;
    }

    /**
     * Sets the frequency (number of instances)
     *
     * @param frequency (number of instances)
     */
    protected void setFrequency(int frequency) {
        this.frequency = frequency;
    }

    /**
     * Returns a learner object
     *
     * @return a learner object
     */
    public LearningMethod getMethod() {
        return method;
    }

    /**
     * Returns the string representation of the underlying learner.
     *
     * @see java.lang.Object#toString()
     */
    @Override
    public String toString() {
        return method.toString();
    }
}