| // The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt | |
| /* | |
| This example program shows you how to create your own custom binary classification | |
| trainer object and use it with the multiclass classification tools in the dlib C++ | |
| library. This example assumes you have already become familiar with the concepts | |
| introduced in the multiclass_classification_ex.cpp example program. | |
| In this example we will create a very simple trainer object that takes a binary | |
| classification problem and produces a decision rule which says a test point has the | |
| same class as whichever centroid it is closest to. | |
| The multiclass training dataset will consist of four classes. Each class will be a blob | |
| of points in one of the quadrants of the Cartesian plane. For fun, we will use | |
| std::string labels and therefore the labels of these classes will be the following: | |
| "upper_left", | |
| "upper_right", | |
| "lower_left", | |
| "lower_right" | |
| */ | |
| using namespace std; | |
| using namespace dlib; | |
| // Our data will be 2-dimensional data. So declare an appropriate type to contain these points. | |
| typedef matrix<double,2,1> sample_type; | |
| // ---------------------------------------------------------------------------------------- | |
| struct custom_decision_function | |
| { | |
| /*! | |
| WHAT THIS OBJECT REPRESENTS | |
| This object is the representation of our binary decision rule. | |
| !*/ | |
| // centers of the two classes | |
| sample_type positive_center, negative_center; | |
| double operator() ( | |
| const sample_type& x | |
| ) const | |
| { | |
| // if x is closer to the positive class then return +1 | |
| if (length(positive_center - x) < length(negative_center - x)) | |
| return +1; | |
| else | |
| return -1; | |
| } | |
| }; | |
| // Later on in this example we will save our decision functions to disk. This | |
| // pair of routines is needed for this functionality. | |
| void serialize (const custom_decision_function& item, std::ostream& out) | |
| { | |
| // write the state of item to the output stream | |
| serialize(item.positive_center, out); | |
| serialize(item.negative_center, out); | |
| } | |
| void deserialize (custom_decision_function& item, std::istream& in) | |
| { | |
| // read the data from the input stream and store it in item | |
| deserialize(item.positive_center, in); | |
| deserialize(item.negative_center, in); | |
| } | |
| // ---------------------------------------------------------------------------------------- | |
| class simple_custom_trainer | |
| { | |
| /*! | |
| WHAT THIS OBJECT REPRESENTS | |
| This is our example custom binary classifier trainer object. It simply | |
| computes the means of the +1 and -1 classes, puts them into our | |
| custom_decision_function, and returns the results. | |
| Below we define the train() function. I have also included the | |
| requires/ensures definition for a generic binary classifier's train() | |
| !*/ | |
| public: | |
| custom_decision_function train ( | |
| const std::vector<sample_type>& samples, | |
| const std::vector<double>& labels | |
| ) const | |
| /*! | |
| requires | |
| - is_binary_classification_problem(samples, labels) == true | |
| (e.g. labels consists of only +1 and -1 values, samples.size() == labels.size()) | |
| ensures | |
| - returns a decision function F with the following properties: | |
| - if (new_x is a sample predicted have +1 label) then | |
| - F(new_x) >= 0 | |
| - else | |
| - F(new_x) < 0 | |
| !*/ | |
| { | |
| sample_type positive_center, negative_center; | |
| // compute sums of each class | |
| positive_center = 0; | |
| negative_center = 0; | |
| for (unsigned long i = 0; i < samples.size(); ++i) | |
| { | |
| if (labels[i] == +1) | |
| positive_center += samples[i]; | |
| else // this is a -1 sample | |
| negative_center += samples[i]; | |
| } | |
| // divide by number of +1 samples | |
| positive_center /= sum(mat(labels) == +1); | |
| // divide by number of -1 samples | |
| negative_center /= sum(mat(labels) == -1); | |
| custom_decision_function df; | |
| df.positive_center = positive_center; | |
| df.negative_center = negative_center; | |
| return df; | |
| } | |
| }; | |
| // ---------------------------------------------------------------------------------------- | |
| void generate_data ( | |
| std::vector<sample_type>& samples, | |
| std::vector<string>& labels | |
| ); | |
| /*! | |
| ensures | |
| - make some four class data as described above. | |
| - each class will have 50 samples in it | |
| !*/ | |
| // ---------------------------------------------------------------------------------------- | |
| int main() | |
| { | |
| std::vector<sample_type> samples; | |
| std::vector<string> labels; | |
| // First, get our labeled set of training data | |
| generate_data(samples, labels); | |
| cout << "samples.size(): "<< samples.size() << endl; | |
| // Define the trainer we will use. The second template argument specifies the type | |
| // of label used, which is string in this case. | |
| typedef one_vs_one_trainer<any_trainer<sample_type>, string> ovo_trainer; | |
| ovo_trainer trainer; | |
| // Now tell the one_vs_one_trainer that, by default, it should use the simple_custom_trainer | |
| // to solve the individual binary classification subproblems. | |
| trainer.set_trainer(simple_custom_trainer()); | |
| // Next, to make things a little more interesting, we will setup the one_vs_one_trainer | |
| // to use kernel ridge regression to solve the upper_left vs lower_right binary classification | |
| // subproblem. | |
| typedef radial_basis_kernel<sample_type> rbf_kernel; | |
| krr_trainer<rbf_kernel> rbf_trainer; | |
| rbf_trainer.set_kernel(rbf_kernel(0.1)); | |
| trainer.set_trainer(rbf_trainer, "upper_left", "lower_right"); | |
| // Now let's do 5-fold cross-validation using the one_vs_one_trainer we just setup. | |
| // As an aside, always shuffle the order of the samples before doing cross validation. | |
| // For a discussion of why this is a good idea see the svm_ex.cpp example. | |
| randomize_samples(samples, labels); | |
| cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl; | |
| // This dataset is very easy and everything is correctly classified. Therefore, the output of | |
| // cross validation is the following confusion matrix. | |
| /* | |
| 50 0 0 0 | |
| 0 50 0 0 | |
| 0 0 50 0 | |
| 0 0 0 50 | |
| */ | |
| // We can also obtain the decision rule as always. | |
| one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels); | |
| cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl; | |
| cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl; | |
| // The output is: | |
| /* | |
| predicted label: upper_right, true label: upper_right | |
| predicted label: lower_left, true label: lower_left | |
| */ | |
| // Finally, let's save our multiclass decision rule to disk. Remember that we have | |
| // to specify the types of binary decision function used inside the one_vs_one_decision_function. | |
| one_vs_one_decision_function<ovo_trainer, | |
| custom_decision_function, // This is the output of the simple_custom_trainer | |
| decision_function<radial_basis_kernel<sample_type> > // This is the output of the rbf_trainer | |
| > df2, df3; | |
| df2 = df; | |
| // save to a file called df.dat | |
| serialize("df.dat") << df2; | |
| // load the function back in from disk and store it in df3. | |
| deserialize("df.dat") >> df3; | |
| // Test df3 to see that this worked. | |
| cout << endl; | |
| cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl; | |
| cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl; | |
| // Test df3 on the samples and labels and print the confusion matrix. | |
| cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl; | |
| } | |
| // ---------------------------------------------------------------------------------------- | |
| void generate_data ( | |
| std::vector<sample_type>& samples, | |
| std::vector<string>& labels | |
| ) | |
| { | |
| const long num = 50; | |
| sample_type m; | |
| dlib::rand rnd; | |
| // add some points in the upper right quadrant | |
| m = 10, 10; | |
| for (long i = 0; i < num; ++i) | |
| { | |
| samples.push_back(m + randm(2,1,rnd)); | |
| labels.push_back("upper_right"); | |
| } | |
| // add some points in the upper left quadrant | |
| m = -10, 10; | |
| for (long i = 0; i < num; ++i) | |
| { | |
| samples.push_back(m + randm(2,1,rnd)); | |
| labels.push_back("upper_left"); | |
| } | |
| // add some points in the lower right quadrant | |
| m = 10, -10; | |
| for (long i = 0; i < num; ++i) | |
| { | |
| samples.push_back(m + randm(2,1,rnd)); | |
| labels.push_back("lower_right"); | |
| } | |
| // add some points in the lower left quadrant | |
| m = -10, -10; | |
| for (long i = 0; i < num; ++i) | |
| { | |
| samples.push_back(m + randm(2,1,rnd)); | |
| labels.push_back("lower_left"); | |
| } | |
| } | |
| // ---------------------------------------------------------------------------------------- | |