package org.maltparser.parser;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.feature.FeatureModel;
import org.maltparser.core.symbol.SymbolTableHandler;
import org.maltparser.core.syntaxgraph.DependencyStructure;
import org.maltparser.parser.guide.ClassifierGuide;
import org.maltparser.parser.guide.OracleGuide;
import org.maltparser.parser.guide.SingleGuide;
import org.maltparser.parser.history.action.GuideDecision;
import org.maltparser.parser.history.action.GuideUserAction;

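/**
 * Batch-mode trainer that also records, for every sentence, the sequence of
 * transitions it applies. The record goes to the diagnostics file named by
 * the "singlemalt"/"diafile" option.
 */
public class BatchTrainerWithDiagnostics extends Trainer {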
	private final Diagnostics diagnostics;
	private final OracleGuide oracleGuide;
	private int parseCount;
	private final FeatureModel featureModel;
	
	public BatchTrainerWithDiagnostics(DependencyParserConfig manager, SymbolTableHandler symbolTableHandler) throws MaltChainedException {
		super(manager,symbolTableHandler);
		this.diagnostics = new Diagnostics(manager.getOptionValue("singlemalt", "diafile").toString());
		registry.setAlgorithm(this);
		setGuide(new SingleGuide(this,  ClassifierGuide.GuideMode.BATCH));
		String featureModelFileName = manager.getOptionValue("guide", "features").toString().trim();
		// Log at the same level that the guard checks (info).
		if (manager.isLoggerInfoEnabled()) {
			manager.logInfoMessage("  Feature model        : " + featureModelFileName+"\n");
			manager.logInfoMessage("  Learner              : " + manager.getOptionValueString("guide", "learner")+"\n");
		}
		String dataSplitColumn = manager.getOptionValue("guide", "data_split_column").toString().trim();
		String dataSplitStructure = manager.getOptionValue("guide", "data_split_structure").toString().trim();
		this.featureModel = manager.getFeatureModelManager().getFeatureModel(SingleGuide.findURL(featureModelFileName, manager), 0, getParserRegistry(), dataSplitColumn, dataSplitStructure);

		manager.writeInfoToConfigFile("\nFEATURE MODEL\n");
		manager.writeInfoToConfigFile(featureModel.toString());
		oracleGuide = parserState.getFactory().makeOracleGuide(parserState.getHistory());
	}
	
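	/**
	 * Replays the gold-standard derivation for one sentence, adding a training
	 * instance for every oracle-predicted transition and writing one line to
	 * the diagnostics file: the sentence counter followed by each action taken.
	 */
	public DependencyStructure parse(DependencyStructure goldDependencyGraph, DependencyStructure parseDependencyGraph) throws MaltChainedException {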
		parserState.clear();
		parserState.initialize(parseDependencyGraph);
		currentParserConfiguration = parserState.getConfiguration();
		parseCount++;

		// Start the diagnostics line for this sentence with its running number.
		diagnostics.writeToDiaFile(Integer.toString(parseCount));

		TransitionSystem transitionSystem = parserState.getTransitionSystem();
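		while (!parserState.isTerminalState()) {
			// A non-null result means the transition system forces this action; no training instance is needed.
			GuideUserAction action = transitionSystem.getDeterministicAction(parserState.getHistory(), currentParserConfiguration);
			if (action == null) {
				// Otherwise ask the oracle for the gold action and record it as a training instance.
				action = oracleGuide.predict(goldDependencyGraph, currentParserConfiguration);
				try {
					classifierGuide.addInstance(featureModel,(GuideDecision)action);
				} catch (NullPointerException e) {
					throw new MaltChainedException("The guide cannot be found. ", e);
				}
			} else {
				// Mark deterministic actions with an asterisk in the diagnostics output.
				diagnostics.writeToDiaFile(" *");
			}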

			diagnostics.writeToDiaFile(" " + transitionSystem.getActionString(action));

			parserState.apply(action);
		}
		copyEdges(currentParserConfiguration.getDependencyGraph(), parseDependencyGraph);
		parseDependencyGraph.linkAllTreesToRoot();
		oracleGuide.finalizeSentence(parseDependencyGraph);

		diagnostics.writeToDiaFile("\n");

		return parseDependencyGraph;
	}
	
	public OracleGuide getOracleGuide() {
		return oracleGuide;
	}
	
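	// Intentionally empty: this class only collects instances in parse().
	public void train() throws MaltChainedException { }

	// Closes the diagnostics file writer.
	public void terminate() throws MaltChainedException {
		diagnostics.closeDiaWriter();
	}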
}