| // The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt | |
| /* | |
| This is an example illustrating the use of the dlib machine learning tools for | |
| learning to solve the assignment problem. | |
| Many tasks in computer vision or natural language processing can be thought of | |
| as assignment problems. For example, in a computer vision application where | |
| you are trying to track objects moving around in video, you likely need to solve | |
| an association problem every time you get a new video frame. That is, each new | |
| frame will contain objects (e.g. people, cars, etc.) and you will want to | |
| determine which of these objects are actually things you have seen in previous | |
| frames. | |
The assignment problem can be optimally solved using the well-known Hungarian
algorithm. However, this algorithm requires the user to supply a function
that measures the "goodness" of an individual association. In many cases the
best way to measure this goodness isn't obvious, and therefore machine learning
methods are used.
| The remainder of this example will show you how to learn a goodness function | |
| which is optimal, in a certain sense, for use with the Hungarian algorithm. To | |
| do this, we will make a simple dataset of example associations and use them to | |
| train a supervised machine learning method. | |
| Finally, note that there is a whole example program dedicated to assignment | |
| learning problems where you are trying to make an object tracker. So if that is | |
| what you are interested in then take a look at the learning_to_track_ex.cpp | |
| example program. | |
| */ | |
| using namespace std; | |
| using namespace dlib; | |
| // ---------------------------------------------------------------------------------------- | |
| /* | |
| In an association problem, we will talk about the "Left Hand Set" (LHS) and the | |
| "Right Hand Set" (RHS). The task will be to learn to map all elements of LHS to | |
| unique elements of RHS. If an element of LHS can't be mapped to a unique element of | |
| RHS for some reason (e.g. LHS is bigger than RHS) then it can also be mapped to the | |
| special -1 output, indicating no mapping to RHS. | |
| So the first step is to define the type of elements in each of these sets. In the | |
| code below we will use column vectors in both LHS and RHS. However, in general, | |
| they can each contain any type you like. LHS can even contain a different type | |
| than RHS. | |
| */ | |
| typedef dlib::matrix<double,0,1> column_vector; | |
| // This type represents a pair of LHS and RHS. That is, sample_type::first | |
| // contains a left hand set and sample_type::second contains a right hand set. | |
| typedef std::pair<std::vector<column_vector>, std::vector<column_vector> > sample_type; | |
| // This type will contain the association information between LHS and RHS. That is, | |
| // it will determine which elements of LHS map to which elements of RHS. | |
| typedef std::vector<long> label_type; | |
| // In this example, all our LHS and RHS elements will be 3-dimensional vectors. | |
| const unsigned long num_dims = 3; | |
| void make_data ( | |
| std::vector<sample_type>& samples, | |
| std::vector<label_type>& labels | |
| ); | |
| /*! | |
| ensures | |
| - This function creates a training dataset of 5 example associations. | |
| - #samples.size() == 5 | |
| - #labels.size() == 5 | |
| - for all valid i: | |
| - #samples[i].first == a left hand set | |
| - #samples[i].second == a right hand set | |
| - #labels[i] == a set of integers indicating how to map LHS to RHS. To be | |
| precise: | |
| - #samples[i].first.size() == #labels[i].size() | |
| - for all valid j: | |
| -1 <= #labels[i][j] < #samples[i].second.size() | |
| (A value of -1 indicates that #samples[i].first[j] isn't associated with anything. | |
| All other values indicate the associating element of #samples[i].second) | |
| - All elements of #labels[i] which are not equal to -1 are unique. That is, | |
| multiple elements of #samples[i].first can't associate to the same element | |
| in #samples[i].second. | |
| !*/ | |
| // ---------------------------------------------------------------------------------------- | |
| struct feature_extractor | |
| { | |
| /*! | |
| Recall that our task is to learn the "goodness of assignment" function for | |
| use with the Hungarian algorithm. The dlib tools assume this function | |
| can be written as: | |
| match_score(l,r) == dot(w, PSI(l,r)) + bias | |
| where l is an element of LHS, r is an element of RHS, w is a parameter vector, | |
| bias is a scalar value, and PSI() is a user supplied feature extractor. | |
| This feature_extractor is where we implement PSI(). How you implement this | |
| is highly problem dependent. | |
| !*/ | |
| // The type of feature vector returned from get_features(). This must be either | |
| // a dlib::matrix or a sparse vector. | |
| typedef column_vector feature_vector_type; | |
| // The types of elements in the LHS and RHS sets | |
| typedef column_vector lhs_element; | |
| typedef column_vector rhs_element; | |
| unsigned long num_features() const | |
| { | |
| // Return the dimensionality of feature vectors produced by get_features() | |
| return num_dims; | |
| } | |
| void get_features ( | |
| const lhs_element& left, | |
| const rhs_element& right, | |
| feature_vector_type& feats | |
| ) const | |
| /*! | |
| ensures | |
| - #feats == PSI(left,right) | |
| (i.e. This function computes a feature vector which, in some sense, | |
| captures information useful for deciding if matching left to right | |
| is "good"). | |
| !*/ | |
| { | |
| // Let's just use the squared difference between each vector as our features. | |
| // However, it should be emphasized that how to compute the features here is very | |
| // problem dependent. | |
| feats = squared(left - right); | |
| } | |
| }; | |
| // We need to define serialize() and deserialize() for our feature extractor if we want | |
| // to be able to serialize and deserialize our learned models. In this case the | |
| // implementation is empty since our feature_extractor doesn't have any state. But you | |
| // might define more complex feature extractors which have state that needs to be saved. | |
| void serialize (const feature_extractor& , std::ostream& ) {} | |
| void deserialize (feature_extractor& , std::istream& ) {} | |
| // ---------------------------------------------------------------------------------------- | |
| int main() | |
| { | |
| try | |
| { | |
| // Get a small bit of training data. | |
| std::vector<sample_type> samples; | |
| std::vector<label_type> labels; | |
| make_data(samples, labels); | |
| structural_assignment_trainer<feature_extractor> trainer; | |
| // This is the common SVM C parameter. Larger values encourage the | |
| // trainer to attempt to fit the data exactly but might overfit. | |
| // In general, you determine this parameter by cross-validation. | |
| trainer.set_c(10); | |
| // This trainer can use multiple CPU cores to speed up the training. | |
| // So set this to the number of available CPU cores. | |
| trainer.set_num_threads(4); | |
| // Do the training and save the results in assigner. | |
| assignment_function<feature_extractor> assigner = trainer.train(samples, labels); | |
| // Test the assigner on our data. The output will indicate that it makes the | |
| // correct associations on all samples. | |
| cout << "Test the learned assignment function: " << endl; | |
| for (unsigned long i = 0; i < samples.size(); ++i) | |
| { | |
| // Predict the assignments for the LHS and RHS in samples[i]. | |
| std::vector<long> predicted_assignments = assigner(samples[i]); | |
| cout << "true labels: " << trans(mat(labels[i])); | |
| cout << "predicted labels: " << trans(mat(predicted_assignments)) << endl; | |
| } | |
| // We can also use this tool to compute the percentage of assignments predicted correctly. | |
| cout << "training accuracy: " << test_assignment_function(assigner, samples, labels) << endl; | |
| // Since testing on your training data is a really bad idea, we can also do 5-fold cross validation. | |
| // Happily, this also indicates that all associations were made correctly. | |
| randomize_samples(samples, labels); | |
| cout << "cv accuracy: " << cross_validate_assignment_trainer(trainer, samples, labels, 5) << endl; | |
| // Finally, the assigner can be serialized to disk just like most dlib objects. | |
| serialize("assigner.dat") << assigner; | |
| // recall from disk | |
| deserialize("assigner.dat") >> assigner; | |
| } | |
| catch (std::exception& e) | |
| { | |
| cout << "EXCEPTION THROWN" << endl; | |
| cout << e.what() << endl; | |
| } | |
| } | |
| // ---------------------------------------------------------------------------------------- | |
| void make_data ( | |
| std::vector<sample_type>& samples, | |
| std::vector<label_type>& labels | |
| ) | |
| { | |
| // Make four different vectors. We will use them to make example assignments. | |
| column_vector A(num_dims), B(num_dims), C(num_dims), D(num_dims); | |
| A = 1,0,0; | |
| B = 0,1,0; | |
| C = 0,0,1; | |
| D = 0,1,1; | |
| std::vector<column_vector> lhs; | |
| std::vector<column_vector> rhs; | |
| label_type mapping; | |
| // In all the assignments to follow, we will only say an element of the LHS | |
| // matches an element of the RHS if the two are equal. So A matches with A, | |
| // B with B, etc. But never A with C, for example. | |
| // ------------------------ | |
| lhs.resize(3); | |
| lhs[0] = A; | |
| lhs[1] = B; | |
| lhs[2] = C; | |
| rhs.resize(3); | |
| rhs[0] = B; | |
| rhs[1] = A; | |
| rhs[2] = C; | |
| mapping.resize(3); | |
| mapping[0] = 1; // lhs[0] matches rhs[1] | |
| mapping[1] = 0; // lhs[1] matches rhs[0] | |
| mapping[2] = 2; // lhs[2] matches rhs[2] | |
| samples.push_back(make_pair(lhs,rhs)); | |
| labels.push_back(mapping); | |
| // ------------------------ | |
| lhs[0] = C; | |
| lhs[1] = A; | |
| lhs[2] = B; | |
| rhs[0] = A; | |
| rhs[1] = B; | |
| rhs[2] = D; | |
| mapping[0] = -1; // The -1 indicates that lhs[0] doesn't match anything in rhs. | |
| mapping[1] = 0; // lhs[1] matches rhs[0] | |
| mapping[2] = 1; // lhs[2] matches rhs[1] | |
| samples.push_back(make_pair(lhs,rhs)); | |
| labels.push_back(mapping); | |
| // ------------------------ | |
| lhs[0] = A; | |
| lhs[1] = B; | |
| lhs[2] = C; | |
| rhs.resize(4); | |
| rhs[0] = C; | |
| rhs[1] = B; | |
| rhs[2] = A; | |
| rhs[3] = D; | |
| mapping[0] = 2; | |
| mapping[1] = 1; | |
| mapping[2] = 0; | |
| samples.push_back(make_pair(lhs,rhs)); | |
| labels.push_back(mapping); | |
| // ------------------------ | |
| lhs.resize(2); | |
| lhs[0] = B; | |
| lhs[1] = C; | |
| rhs.resize(3); | |
| rhs[0] = C; | |
| rhs[1] = A; | |
| rhs[2] = D; | |
| mapping.resize(2); | |
| mapping[0] = -1; | |
| mapping[1] = 0; | |
| samples.push_back(make_pair(lhs,rhs)); | |
| labels.push_back(mapping); | |
| // ------------------------ | |
| lhs.resize(3); | |
| lhs[0] = D; | |
| lhs[1] = B; | |
| lhs[2] = C; | |
| // rhs will be empty. So none of the items in lhs can match anything. | |
| rhs.resize(0); | |
| mapping.resize(3); | |
| mapping[0] = -1; | |
| mapping[1] = -1; | |
| mapping[2] = -1; | |
| samples.push_back(make_pair(lhs,rhs)); | |
| labels.push_back(mapping); | |
| } | |