File size: 8,625 Bytes
6f3ebfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
package org.maltparser.parser.guide.instance;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Formatter;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureModel;
import org.maltparser.core.feature.FeatureVector;
import org.maltparser.core.feature.function.FeatureFunction;
import org.maltparser.core.feature.function.Modifiable;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.ml.LearningMethod;
import org.maltparser.ml.lib.LibLinear;
import org.maltparser.ml.lib.LibSvm;
import org.maltparser.parser.guide.ClassifierGuide;
import org.maltparser.parser.guide.GuideException;
import org.maltparser.parser.guide.Model;
import org.maltparser.parser.history.action.SingleDecision;


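/**
 * An instance model that wraps a single learning method and forwards
 * training and prediction requests to it.
 *
 * @author Johan Hall
 */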
public class AtomicModel implements InstanceModel {
	/**
	 * Constructor signature that a custom learner class must provide in order to be
	 * created reflectively: <code>(InstanceModel, Integer)</code>.
	 */
	public static final Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
	private final Model parent;
	private final String modelName;
	private final int index;
	private final LearningMethod method;
	private int frequency = 0;

	/**
	 * Constructs an atomic model.
	 * 
	 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 
	 * or the master divide model) and n is number of divide models.
	 * @param parent the parent guide model.
	 * @throws MaltChainedException if the configured learner cannot be created, or if the parent
	 * model or its guide reports an error
	 */
	public AtomicModel(int index, Model parent) throws MaltChainedException {
		this.parent = parent;
		this.index = index;
		if (index == -1) {
			this.modelName = parent.getModelName()+".";
		} else {
			// String.format produces the same zero-padded text as the previous
			// "new Formatter().format(...)" without leaving a Closeable Formatter unclosed.
			this.modelName = parent.getModelName()+"."+String.format("%03d", index)+".";
		}
		this.frequency = 0;

		// The learner mode stays null when the guide mode is neither CLASSIFY nor BATCH.
		// Batch mode is tracked separately as a boolean so that the check further down does
		// not rely on comparing a possibly-null boxed Integer with ==.
		Integer learnerMode = null;
		boolean batchMode = false;
		if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
			learnerMode = LearningMethod.CLASSIFY;
		} else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
			learnerMode = LearningMethod.BATCH;
			batchMode = true;
		}

		this.method = createLearningMethod(learnerMode);

		// Only the model with index -1 records the learner description, and only in batch mode.
		if (batchMode && index == -1 && getGuide().getConfiguration() != null) {
			getGuide().getConfiguration().writeInfoToConfigFile(method.toString());
		}
	}

	/**
	 * Creates the learner selected by the option ("guide", "learner"). LibSvm and LibLinear are
	 * constructed directly; any other class is instantiated reflectively through a public
	 * constructor matching {@link #argTypes}.
	 * 
	 * @param learnerMode the learner mode handed to the learner's constructor (may be null)
	 * @return the newly created learner
	 * @throws MaltChainedException if no learner class is configured or it cannot be instantiated
	 */
	private LearningMethod createLearningMethod(Integer learnerMode) throws MaltChainedException {
		final Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
		if (clazz == null) {
			// Without this guard the reflective branch below would fail with a bare NullPointerException.
			throw new GuideException("The learner class cannot be found. ");
		}
		if (clazz == org.maltparser.ml.lib.LibSvm.class) {
			return new LibSvm(this, learnerMode);
		}
		if (clazz == org.maltparser.ml.lib.LibLinear.class) {
			return new LibLinear(this, learnerMode);
		}

		final Object[] arguments = {this, learnerMode};
		final String errorMessage = "The learner class '"+clazz.getName()+"' cannot be initialized. ";
		try {
			final Constructor<?> constructor = clazz.getConstructor(argTypes);
			return (LearningMethod)constructor.newInstance(arguments);
		} catch (NoSuchMethodException e) {
			throw new GuideException(errorMessage, e);
		} catch (InstantiationException e) {
			throw new GuideException(errorMessage, e);
		} catch (IllegalAccessException e) {
			throw new GuideException(errorMessage, e);
		} catch (InvocationTargetException e) {
			throw new GuideException(errorMessage, e);
		}
	}

	/**
	 * Passes a training instance (decision plus feature vector) to the learner.
	 * 
	 * @param featureVector the feature vector of the instance
	 * @param decision the decision of the instance
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public void addInstance(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
		try {
			method.addInstance(decision, featureVector);
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Tells the learner that no more instances will be added.
	 * 
	 * @param featureModel not used by this implementation
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public void noMoreInstances(FeatureModel featureModel) throws MaltChainedException {
		try {
			method.noMoreInstances();
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Forwards the end-of-sentence notification to the learner.
	 * 
	 * @param dependencyGraph the dependency graph handed on to the learner
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
		try {
			method.finalizeSentence(dependencyGraph);
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Asks the learner to predict the decision for the given feature vector.
	 * 
	 * @param featureVector the feature vector to predict from
	 * @param decision the decision object passed to the learner
	 * @return the value returned by the learner's predict method
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public boolean predict(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
		try {
			return method.predict(featureVector, decision);
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Asks the learner to predict the decision and hands back the feature vector that was used.
	 * 
	 * @param featureVector the feature vector to predict from
	 * @param decision the decision object passed to the learner
	 * @return the given feature vector if the learner's predict method returns true, otherwise null
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public FeatureVector predictExtract(FeatureVector featureVector, SingleDecision decision) throws MaltChainedException {
		try {
			if (method.predict(featureVector, decision)) {
				return featureVector;
			}
			return null;
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Returns the given feature vector unchanged.
	 * 
	 * @param featureVector the feature vector
	 * @return the same feature vector
	 * @throws MaltChainedException declared by the signature; never thrown by this implementation
	 */
	public FeatureVector extract(FeatureVector featureVector) throws MaltChainedException {
		return featureVector;
	}

	/**
	 * Terminates the learner.
	 * 
	 * @throws MaltChainedException if the learner fails to terminate
	 */
	public void terminate() throws MaltChainedException {
		if (method != null) {
			method.terminate();
		}
	}

	/**
	 * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
	 * This method is used by the feature divide model to sum up all model below a certain threshold.
	 * The divide feature is set to this model's index before the move, and this model's learner
	 * is terminated afterwards.
	 * 
	 * @param model the destination atomic model 
	 * @param divideFeature the divide feature; must implement Modifiable
	 * @param divideFeatureIndexVector the divide feature index vector
	 * @throws MaltChainedException if an argument is missing, if the divide feature is not
	 * modifiable, or if the learner fails
	 */
	public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
		if (method == null) {
			throw new GuideException("The learner cannot be found. ");
		} else if (model == null) {
			throw new GuideException("The guide model cannot be found. ");
		} else if (divideFeature == null) {
			throw new GuideException("The divide feature cannot be found. ");
		} else if (divideFeatureIndexVector == null) {
			throw new GuideException("The divide feature index vector cannot be found. ");
		} else if (!(divideFeature instanceof Modifiable)) {
			// Report the problem in the same way as the other argument checks instead of
			// letting the cast below fail with a ClassCastException.
			throw new GuideException("The divide feature cannot be modified. ");
		}
		((Modifiable)divideFeature).setFeatureValue(index);
		method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
		method.terminate();
	}

	/**
	 * Invokes the train() of the learning method and then terminates the learner.
	 * 
	 * @throws MaltChainedException if the learner fails; a NullPointerException raised during
	 * the call is reported as a GuideException
	 */
	public void train() throws MaltChainedException {
		try {
			method.train();
			method.terminate();
		} catch (NullPointerException e) {
			throw new GuideException("The learner cannot be found. ", e);
		}
	}

	/**
	 * Returns the parent guide model
	 * 
	 * @return the parent guide model
	 * @throws MaltChainedException if this atomic model has no parent model
	 */
	public Model getParent() throws MaltChainedException {
		if (parent == null) {
			throw new GuideException("The atomic model can only be used by a parent model. ");
		}
		return parent;
	}

	/**
	 * Returns the name of this model: the parent's model name followed by "." and, unless the
	 * index is -1, the zero-padded three-digit index and another ".".
	 * 
	 * @return the model name
	 */
	public String getModelName() {
		return modelName;
	}

	/**
	 * Returns the classifier guide of the parent model
	 * 
	 * @return the classifier guide obtained from the parent model
	 */
	public ClassifierGuide getGuide() {
		return parent.getGuide();
	}

	/**
	 * Returns the index of the atomic model
	 * 
	 * @return the index of the atomic model
	 */
	public int getIndex() {
		return index;
	}

	/**
	 * Returns the frequency (number of instances)
	 * 
	 * @return the frequency (number of instances)
	 */
	public int getFrequency() {
		return frequency;
	}

	/**
	 * Increase the frequency by 1. If the parent is itself an instance model, its frequency
	 * is increased first.
	 */
	public void increaseFrequency() {
		if (parent instanceof InstanceModel) {
			((InstanceModel)parent).increaseFrequency();
		}
		frequency++;
	}

	/**
	 * Decrease the frequency by 1. If the parent is itself an instance model, its frequency
	 * is decreased first.
	 */
	public void decreaseFrequency() {
		if (parent instanceof InstanceModel) {
			((InstanceModel)parent).decreaseFrequency();
		}
		frequency--;
	}

	/**
	 * Sets the frequency (number of instances)
	 * 
	 * @param frequency (number of instances)
	 */
	protected void setFrequency(int frequency) {
		this.frequency = frequency;
	}

	/**
	 * Returns a learner object
	 * 
	 * @return a learner object
	 */
	public LearningMethod getMethod() {
		return method;
	}

	/**
	 * Returns the string representation of the learner.
	 * 
	 * @see java.lang.Object#toString()
	 */
	@Override
	public String toString() {
		return method.toString();
	}
}