File size: 4,368 Bytes
f5bb0c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
//
// This script converts the MNIST dataset to the leveldb format used
// by caffe to train a siamese network.
// Usage:
// convert_mnist_data input_image_file input_label_file output_db_file
// The MNIST dataset can be downloaded from
// http://yann.lecun.com/exdb/mnist/
#include <fstream> // NOLINT(readability/streams)
#include <string>
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "stdint.h"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/format.hpp"
#include "caffe/util/math_functions.hpp"
#ifdef USE_LEVELDB
#include "leveldb/db.h"
// Returns `val` with its four bytes in reverse order. Used to decode the
// 32-bit header fields read from the MNIST files.
uint32_t swap_endian(uint32_t val) {
  // Move each byte to its mirrored position, then recombine.
  const uint32_t lowest_to_top = (val & 0x000000FFu) << 24;
  const uint32_t low_to_high = (val & 0x0000FF00u) << 8;
  const uint32_t high_to_low = (val & 0x00FF0000u) >> 8;
  const uint32_t top_to_lowest = (val & 0xFF000000u) >> 24;
  return lowest_to_top | low_to_high | high_to_low | top_to_lowest;
}
// Reads the image at position `index` and its label from the two open
// MNIST streams.
//
//   image_file  stream over the image file; pixel data starts after a
//               16-byte header (the four 4-byte fields read by
//               convert_dataset).
//   label_file  stream over the label file; labels start after an 8-byte
//               header (two 4-byte fields), one byte per item.
//   index       zero-based item index.
//   rows, cols  image dimensions; rows * cols bytes are read.
//   pixels      output buffer with room for at least rows * cols bytes.
//   label       output location for the single label byte.
//
// No error is reported here; callers should inspect the stream state if they
// need to detect a short or failed read.
void read_image(std::ifstream* image_file, std::ifstream* label_file,
                uint32_t index, uint32_t rows, uint32_t cols,
                char* pixels, char* label) {
  // Compute sizes and offsets in the stream's own (wide) offset type so that
  // index * rows * cols cannot wrap around in 32-bit unsigned arithmetic.
  const std::streamsize image_bytes =
      static_cast<std::streamsize>(rows) * static_cast<std::streamsize>(cols);
  const std::streamoff image_offset =
      static_cast<std::streamoff>(index) * image_bytes + 16;
  const std::streamoff label_offset = static_cast<std::streamoff>(index) + 8;

  image_file->seekg(image_offset);
  image_file->read(pixels, image_bytes);
  label_file->seekg(label_offset);
  label_file->read(label, 1);
}
// Builds a leveldb of random MNIST image pairs for siamese training.
//
//   image_filename  path to the (uncompressed) MNIST image file.
//   label_filename  path to the matching MNIST label file.
//   db_filename     path of the leveldb to create; it must not exist yet
//                   (the database is opened with error_if_exists).
//
// For each of the num_items output records, two items are drawn at random
// (with replacement) and stored as a 2-channel Datum: channel 0 holds the
// first image, channel 1 the second. The Datum label is 1 when both images
// carry the same label and 0 otherwise. Keys are the record index formatted
// by caffe::format_int(itemid, 8).
//
// Any failure (unopenable file, bad magic, truncated data, leveldb open or
// write error) trips a glog CHECK.
void convert_dataset(const char* image_filename, const char* label_filename,
                     const char* db_filename) {
  // Open files
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
  CHECK(image_file) << "Unable to open file " << image_filename;
  CHECK(label_file) << "Unable to open file " << label_filename;

  // Read the magic and the meta data. Fields are initialized so that a short
  // read yields a deterministic value rather than indeterminate memory.
  uint32_t magic = 0;
  uint32_t num_items = 0;
  uint32_t num_labels = 0;
  uint32_t rows = 0;
  uint32_t cols = 0;

  image_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CHECK_EQ(magic, 2051u) << "Incorrect image file magic.";
  label_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CHECK_EQ(magic, 2049u) << "Incorrect label file magic.";
  image_file.read(reinterpret_cast<char*>(&num_items), 4);
  num_items = swap_endian(num_items);
  label_file.read(reinterpret_cast<char*>(&num_labels), 4);
  num_labels = swap_endian(num_labels);
  CHECK_EQ(num_items, num_labels);
  image_file.read(reinterpret_cast<char*>(&rows), 4);
  rows = swap_endian(rows);
  image_file.read(reinterpret_cast<char*>(&cols), 4);
  cols = swap_endian(cols);
  // A truncated header would otherwise go unnoticed and produce bogus sizes.
  CHECK(image_file) << "Failed to read header of " << image_filename;
  CHECK(label_file) << "Failed to read header of " << label_filename;

  // Open leveldb
  leveldb::DB* db;
  leveldb::Options options;
  options.create_if_missing = true;
  options.error_if_exists = true;
  leveldb::Status status = leveldb::DB::Open(
      options, db_filename, &db);
  CHECK(status.ok()) << "Failed to open leveldb " << db_filename
      << ". Is it already existing?";

  char label_i;
  char label_j;
  // Size computed in size_t so 2 * rows * cols cannot wrap in 32 bits; the
  // std::string owns the buffer, so no manual delete[] is needed.
  const size_t image_bytes = static_cast<size_t>(rows) * cols;
  std::string pixels(2 * image_bytes, '\0');
  std::string value;

  caffe::Datum datum;
  datum.set_channels(2);  // one channel for each image in the pair
  datum.set_height(rows);
  datum.set_width(cols);
  LOG(INFO) << "A total of " << num_items << " items.";
  LOG(INFO) << "Rows: " << rows << " Cols: " << cols;

  for (uint32_t itemid = 0; itemid < num_items; ++itemid) {
    // pick a random pair
    const uint32_t i = caffe::caffe_rng_rand() % num_items;
    const uint32_t j = caffe::caffe_rng_rand() % num_items;
    read_image(&image_file, &label_file, i, rows, cols,
        &pixels[0], &label_i);
    read_image(&image_file, &label_file, j, rows, cols,
        &pixels[0] + image_bytes, &label_j);
    // read_image does not report errors; a failed stream here means the pixel
    // buffer or labels hold stale data, so stop rather than write bad records.
    CHECK(image_file) << "Failed to read images " << i << " and " << j
        << " from " << image_filename;
    CHECK(label_file) << "Failed to read labels " << i << " and " << j
        << " from " << label_filename;

    datum.set_data(pixels.data(), pixels.size());
    datum.set_label(label_i == label_j ? 1 : 0);
    datum.SerializeToString(&value);
    const std::string key_str =
        caffe::format_int(static_cast<int>(itemid), 8);
    status = db->Put(leveldb::WriteOptions(), key_str, value);
    CHECK(status.ok()) << "Failed to write item " << itemid
        << " to leveldb " << db_filename << ": " << status.ToString();
  }

  delete db;
}
// Program entry point. Expects exactly three arguments (image file, label
// file, output db path). With any other argument count it prints the usage
// text and does nothing else; in both cases the exit code is 0.
int main(int argc, char** argv) {
  const bool has_three_arguments = (argc == 4);  // argv[0] + 3 paths
  if (!has_three_arguments) {
    printf("This script converts the MNIST dataset to the leveldb format used\n"
           "by caffe to train a siamese network.\n"
           "Usage:\n"
           " convert_mnist_data input_image_file input_label_file "
           "output_db_file\n"
           "The MNIST dataset could be downloaded at\n"
           " http://yann.lecun.com/exdb/mnist/\n"
           "You should gunzip them after downloading.\n");
    return 0;
  }

  google::InitGoogleLogging(argv[0]);
  const char* image_path = argv[1];
  const char* label_path = argv[2];
  const char* db_path = argv[3];
  convert_dataset(image_path, label_path, db_path);
  return 0;
}
#else
// Fallback entry point compiled when USE_LEVELDB is not defined. It ignores
// its arguments and only emits a FATAL-severity log message telling the user
// to rebuild with LevelDB support.
int main(int argc, char** argv) {
LOG(FATAL) << "This example requires LevelDB; compile with USE_LEVELDB.";
}
#endif // USE_LEVELDB
|