Spaces:
Runtime error
Runtime error
| /** | |
| * Copyright (c) 2016-present, Facebook, Inc. | |
| * All rights reserved. | |
| * | |
| * This source code is licensed under the MIT license found in the | |
| * LICENSE file in the root directory of this source tree. | |
| */ | |
| namespace fasttext { | |
| Args::Args() { | |
| lr = 0.05; | |
| dim = 100; | |
| ws = 5; | |
| epoch = 5; | |
| minCount = 5; | |
| minCountLabel = 0; | |
| neg = 5; | |
| wordNgrams = 1; | |
| loss = loss_name::ns; | |
| model = model_name::sg; | |
| bucket = 2000000; | |
| minn = 3; | |
| maxn = 6; | |
| thread = 12; | |
| lrUpdateRate = 100; | |
| t = 1e-4; | |
| label = "__label__"; | |
| verbose = 2; | |
| pretrainedVectors = ""; | |
| saveOutput = false; | |
| seed = 0; | |
| qout = false; | |
| retrain = false; | |
| qnorm = false; | |
| cutoff = 0; | |
| dsub = 2; | |
| autotuneValidationFile = ""; | |
| autotuneMetric = "f1"; | |
| autotunePredictions = 1; | |
| autotuneDuration = 60 * 5; // 5 minutes | |
| autotuneModelSize = ""; | |
| } | |
| std::string Args::lossToString(loss_name ln) const { | |
| switch (ln) { | |
| case loss_name::hs: | |
| return "hs"; | |
| case loss_name::ns: | |
| return "ns"; | |
| case loss_name::softmax: | |
| return "softmax"; | |
| case loss_name::ova: | |
| return "one-vs-all"; | |
| } | |
| return "Unknown loss!"; // should never happen | |
| } | |
| std::string Args::boolToString(bool b) const { | |
| if (b) { | |
| return "true"; | |
| } else { | |
| return "false"; | |
| } | |
| } | |
| std::string Args::modelToString(model_name mn) const { | |
| switch (mn) { | |
| case model_name::cbow: | |
| return "cbow"; | |
| case model_name::sg: | |
| return "sg"; | |
| case model_name::sup: | |
| return "sup"; | |
| } | |
| return "Unknown model name!"; // should never happen | |
| } | |
| std::string Args::metricToString(metric_name mn) const { | |
| switch (mn) { | |
| case metric_name::f1score: | |
| return "f1score"; | |
| case metric_name::f1scoreLabel: | |
| return "f1scoreLabel"; | |
| case metric_name::precisionAtRecall: | |
| return "precisionAtRecall"; | |
| case metric_name::precisionAtRecallLabel: | |
| return "precisionAtRecallLabel"; | |
| case metric_name::recallAtPrecision: | |
| return "recallAtPrecision"; | |
| case metric_name::recallAtPrecisionLabel: | |
| return "recallAtPrecisionLabel"; | |
| } | |
| return "Unknown metric name!"; // should never happen | |
| } | |
| void Args::parseArgs(const std::vector<std::string>& args) { | |
| const std::string& command(args[1]); | |
| if (command == "supervised") { | |
| model = model_name::sup; | |
| loss = loss_name::softmax; | |
| minCount = 1; | |
| minn = 0; | |
| maxn = 0; | |
| lr = 0.1; | |
| } else if (command == "cbow") { | |
| model = model_name::cbow; | |
| } | |
| for (int ai = 2; ai < args.size(); ai += 2) { | |
| if (args[ai][0] != '-') { | |
| std::cerr << "Provided argument without a dash! Usage:" << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } | |
| try { | |
| setManual(args[ai].substr(1)); | |
| if (args[ai] == "-h") { | |
| std::cerr << "Here is the help! Usage:" << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } else if (args[ai] == "-input") { | |
| input = std::string(args.at(ai + 1)); | |
| } else if (args[ai] == "-output") { | |
| output = std::string(args.at(ai + 1)); | |
| } else if (args[ai] == "-lr") { | |
| lr = std::stof(args.at(ai + 1)); | |
| } else if (args[ai] == "-lrUpdateRate") { | |
| lrUpdateRate = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-dim") { | |
| dim = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-ws") { | |
| ws = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-epoch") { | |
| epoch = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-minCount") { | |
| minCount = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-minCountLabel") { | |
| minCountLabel = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-neg") { | |
| neg = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-wordNgrams") { | |
| wordNgrams = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-loss") { | |
| if (args.at(ai + 1) == "hs") { | |
| loss = loss_name::hs; | |
| } else if (args.at(ai + 1) == "ns") { | |
| loss = loss_name::ns; | |
| } else if (args.at(ai + 1) == "softmax") { | |
| loss = loss_name::softmax; | |
| } else if ( | |
| args.at(ai + 1) == "one-vs-all" || args.at(ai + 1) == "ova") { | |
| loss = loss_name::ova; | |
| } else { | |
| std::cerr << "Unknown loss: " << args.at(ai + 1) << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } | |
| } else if (args[ai] == "-bucket") { | |
| bucket = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-minn") { | |
| minn = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-maxn") { | |
| maxn = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-thread") { | |
| thread = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-t") { | |
| t = std::stof(args.at(ai + 1)); | |
| } else if (args[ai] == "-label") { | |
| label = std::string(args.at(ai + 1)); | |
| } else if (args[ai] == "-verbose") { | |
| verbose = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-pretrainedVectors") { | |
| pretrainedVectors = std::string(args.at(ai + 1)); | |
| } else if (args[ai] == "-saveOutput") { | |
| saveOutput = true; | |
| ai--; | |
| } else if (args[ai] == "-seed") { | |
| seed = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-qnorm") { | |
| qnorm = true; | |
| ai--; | |
| } else if (args[ai] == "-retrain") { | |
| retrain = true; | |
| ai--; | |
| } else if (args[ai] == "-qout") { | |
| qout = true; | |
| ai--; | |
| } else if (args[ai] == "-cutoff") { | |
| cutoff = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-dsub") { | |
| dsub = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-autotune-validation") { | |
| autotuneValidationFile = std::string(args.at(ai + 1)); | |
| } else if (args[ai] == "-autotune-metric") { | |
| autotuneMetric = std::string(args.at(ai + 1)); | |
| getAutotuneMetric(); // throws exception if not able to parse | |
| getAutotuneMetricLabel(); // throws exception if not able to parse | |
| } else if (args[ai] == "-autotune-predictions") { | |
| autotunePredictions = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-autotune-duration") { | |
| autotuneDuration = std::stoi(args.at(ai + 1)); | |
| } else if (args[ai] == "-autotune-modelsize") { | |
| autotuneModelSize = std::string(args.at(ai + 1)); | |
| } else { | |
| std::cerr << "Unknown argument: " << args[ai] << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } | |
| } catch (std::out_of_range) { | |
| std::cerr << args[ai] << " is missing an argument" << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } | |
| } | |
| if (input.empty() || output.empty()) { | |
| std::cerr << "Empty input or output path." << std::endl; | |
| printHelp(); | |
| exit(EXIT_FAILURE); | |
| } | |
| if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) { | |
| bucket = 0; | |
| } | |
| } | |
| void Args::printHelp() { | |
| printBasicHelp(); | |
| printDictionaryHelp(); | |
| printTrainingHelp(); | |
| printAutotuneHelp(); | |
| printQuantizationHelp(); | |
| } | |
| void Args::printBasicHelp() { | |
| std::cerr << "\nThe following arguments are mandatory:\n" | |
| << " -input training file path\n" | |
| << " -output output file path\n" | |
| << "\nThe following arguments are optional:\n" | |
| << " -verbose verbosity level [" << verbose << "]\n"; | |
| } | |
| void Args::printDictionaryHelp() { | |
| std::cerr << "\nThe following arguments for the dictionary are optional:\n" | |
| << " -minCount minimal number of word occurences [" | |
| << minCount << "]\n" | |
| << " -minCountLabel minimal number of label occurences [" | |
| << minCountLabel << "]\n" | |
| << " -wordNgrams max length of word ngram [" << wordNgrams | |
| << "]\n" | |
| << " -bucket number of buckets [" << bucket << "]\n" | |
| << " -minn min length of char ngram [" << minn | |
| << "]\n" | |
| << " -maxn max length of char ngram [" << maxn | |
| << "]\n" | |
| << " -t sampling threshold [" << t << "]\n" | |
| << " -label labels prefix [" << label << "]\n"; | |
| } | |
| void Args::printTrainingHelp() { | |
| std::cerr | |
| << "\nThe following arguments for training are optional:\n" | |
| << " -lr learning rate [" << lr << "]\n" | |
| << " -lrUpdateRate change the rate of updates for the learning " | |
| "rate [" | |
| << lrUpdateRate << "]\n" | |
| << " -dim size of word vectors [" << dim << "]\n" | |
| << " -ws size of the context window [" << ws << "]\n" | |
| << " -epoch number of epochs [" << epoch << "]\n" | |
| << " -neg number of negatives sampled [" << neg << "]\n" | |
| << " -loss loss function {ns, hs, softmax, one-vs-all} [" | |
| << lossToString(loss) << "]\n" | |
| << " -thread number of threads (set to 1 to ensure " | |
| "reproducible results) [" | |
| << thread << "]\n" | |
| << " -pretrainedVectors pretrained word vectors for supervised " | |
| "learning [" | |
| << pretrainedVectors << "]\n" | |
| << " -saveOutput whether output params should be saved [" | |
| << boolToString(saveOutput) << "]\n" | |
| << " -seed random generator seed [" << seed << "]\n"; | |
| } | |
| void Args::printAutotuneHelp() { | |
| std::cerr << "\nThe following arguments are for autotune:\n" | |
| << " -autotune-validation validation file to be used " | |
| "for evaluation\n" | |
| << " -autotune-metric metric objective {f1, " | |
| "f1:labelname} [" | |
| << autotuneMetric << "]\n" | |
| << " -autotune-predictions number of predictions used " | |
| "for evaluation [" | |
| << autotunePredictions << "]\n" | |
| << " -autotune-duration maximum duration in seconds [" | |
| << autotuneDuration << "]\n" | |
| << " -autotune-modelsize constraint model file size [" | |
| << autotuneModelSize << "] (empty = do not quantize)\n"; | |
| } | |
| void Args::printQuantizationHelp() { | |
| std::cerr | |
| << "\nThe following arguments for quantization are optional:\n" | |
| << " -cutoff number of words and ngrams to retain [" | |
| << cutoff << "]\n" | |
| << " -retrain whether embeddings are finetuned if a cutoff " | |
| "is applied [" | |
| << boolToString(retrain) << "]\n" | |
| << " -qnorm whether the norm is quantized separately [" | |
| << boolToString(qnorm) << "]\n" | |
| << " -qout whether the classifier is quantized [" | |
| << boolToString(qout) << "]\n" | |
| << " -dsub size of each sub-vector [" << dsub << "]\n"; | |
| } | |
| void Args::save(std::ostream& out) { | |
| out.write((char*)&(dim), sizeof(int)); | |
| out.write((char*)&(ws), sizeof(int)); | |
| out.write((char*)&(epoch), sizeof(int)); | |
| out.write((char*)&(minCount), sizeof(int)); | |
| out.write((char*)&(neg), sizeof(int)); | |
| out.write((char*)&(wordNgrams), sizeof(int)); | |
| out.write((char*)&(loss), sizeof(loss_name)); | |
| out.write((char*)&(model), sizeof(model_name)); | |
| out.write((char*)&(bucket), sizeof(int)); | |
| out.write((char*)&(minn), sizeof(int)); | |
| out.write((char*)&(maxn), sizeof(int)); | |
| out.write((char*)&(lrUpdateRate), sizeof(int)); | |
| out.write((char*)&(t), sizeof(double)); | |
| } | |
| void Args::load(std::istream& in) { | |
| in.read((char*)&(dim), sizeof(int)); | |
| in.read((char*)&(ws), sizeof(int)); | |
| in.read((char*)&(epoch), sizeof(int)); | |
| in.read((char*)&(minCount), sizeof(int)); | |
| in.read((char*)&(neg), sizeof(int)); | |
| in.read((char*)&(wordNgrams), sizeof(int)); | |
| in.read((char*)&(loss), sizeof(loss_name)); | |
| in.read((char*)&(model), sizeof(model_name)); | |
| in.read((char*)&(bucket), sizeof(int)); | |
| in.read((char*)&(minn), sizeof(int)); | |
| in.read((char*)&(maxn), sizeof(int)); | |
| in.read((char*)&(lrUpdateRate), sizeof(int)); | |
| in.read((char*)&(t), sizeof(double)); | |
| } | |
| void Args::dump(std::ostream& out) const { | |
| out << "dim" | |
| << " " << dim << std::endl; | |
| out << "ws" | |
| << " " << ws << std::endl; | |
| out << "epoch" | |
| << " " << epoch << std::endl; | |
| out << "minCount" | |
| << " " << minCount << std::endl; | |
| out << "neg" | |
| << " " << neg << std::endl; | |
| out << "wordNgrams" | |
| << " " << wordNgrams << std::endl; | |
| out << "loss" | |
| << " " << lossToString(loss) << std::endl; | |
| out << "model" | |
| << " " << modelToString(model) << std::endl; | |
| out << "bucket" | |
| << " " << bucket << std::endl; | |
| out << "minn" | |
| << " " << minn << std::endl; | |
| out << "maxn" | |
| << " " << maxn << std::endl; | |
| out << "lrUpdateRate" | |
| << " " << lrUpdateRate << std::endl; | |
| out << "t" | |
| << " " << t << std::endl; | |
| } | |
| bool Args::hasAutotune() const { | |
| return !autotuneValidationFile.empty(); | |
| } | |
| bool Args::isManual(const std::string& argName) const { | |
| return (manualArgs_.count(argName) != 0); | |
| } | |
| void Args::setManual(const std::string& argName) { | |
| manualArgs_.emplace(argName); | |
| } | |
| metric_name Args::getAutotuneMetric() const { | |
| if (autotuneMetric.substr(0, 3) == "f1:") { | |
| return metric_name::f1scoreLabel; | |
| } else if (autotuneMetric == "f1") { | |
| return metric_name::f1score; | |
| } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") { | |
| size_t semicolon = autotuneMetric.find(':', 18); | |
| if (semicolon != std::string::npos) { | |
| return metric_name::precisionAtRecallLabel; | |
| } | |
| return metric_name::precisionAtRecall; | |
| } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") { | |
| size_t semicolon = autotuneMetric.find(':', 18); | |
| if (semicolon != std::string::npos) { | |
| return metric_name::recallAtPrecisionLabel; | |
| } | |
| return metric_name::recallAtPrecision; | |
| } | |
| throw std::runtime_error("Unknown metric : " + autotuneMetric); | |
| } | |
| std::string Args::getAutotuneMetricLabel() const { | |
| metric_name metric = getAutotuneMetric(); | |
| std::string label; | |
| if (metric == metric_name::f1scoreLabel) { | |
| label = autotuneMetric.substr(3); | |
| } else if ( | |
| metric == metric_name::precisionAtRecallLabel || | |
| metric == metric_name::recallAtPrecisionLabel) { | |
| size_t semicolon = autotuneMetric.find(':', 18); | |
| label = autotuneMetric.substr(semicolon + 1); | |
| } else { | |
| return label; | |
| } | |
| if (label.empty()) { | |
| throw std::runtime_error("Empty metric label : " + autotuneMetric); | |
| } | |
| return label; | |
| } | |
| double Args::getAutotuneMetricValue() const { | |
| metric_name metric = getAutotuneMetric(); | |
| double value = 0.0; | |
| if (metric == metric_name::precisionAtRecallLabel || | |
| metric == metric_name::precisionAtRecall || | |
| metric == metric_name::recallAtPrecisionLabel || | |
| metric == metric_name::recallAtPrecision) { | |
| size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:" | |
| size_t secondSemicolon = autotuneMetric.find(':', firstSemicolon); | |
| const std::string valueStr = | |
| autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon); | |
| value = std::stof(valueStr) / 100.0; | |
| } | |
| return value; | |
| } | |
| int64_t Args::getAutotuneModelSize() const { | |
| std::string modelSize = autotuneModelSize; | |
| if (modelSize.empty()) { | |
| return Args::kUnlimitedModelSize; | |
| } | |
| std::unordered_map<char, int> units = { | |
| {'k', 1000}, | |
| {'K', 1000}, | |
| {'m', 1000000}, | |
| {'M', 1000000}, | |
| {'g', 1000000000}, | |
| {'G', 1000000000}, | |
| }; | |
| uint64_t multiplier = 1; | |
| char lastCharacter = modelSize.back(); | |
| if (units.count(lastCharacter)) { | |
| multiplier = units[lastCharacter]; | |
| modelSize = modelSize.substr(0, modelSize.size() - 1); | |
| } | |
| uint64_t size = 0; | |
| size_t nonNumericCharacter = 0; | |
| bool parseError = false; | |
| try { | |
| size = std::stol(modelSize, &nonNumericCharacter); | |
| } catch (std::invalid_argument&) { | |
| parseError = true; | |
| } | |
| if (!parseError && nonNumericCharacter != modelSize.size()) { | |
| parseError = true; | |
| } | |
| if (parseError) { | |
| throw std::invalid_argument( | |
| "Unable to parse model size " + autotuneModelSize); | |
| } | |
| return size * multiplier; | |
| } | |
| } // namespace fasttext | |