|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <ATen/cuda/Sleep.h>
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <c10/cuda/CUDACachingAllocator.h>

#ifndef _WIN32
#include <cxxabi.h>
#endif

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <deque>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
|
|
|
|
|
|
namespace at::cuda::tunable {
|
|
|
|
|
|
template <typename ParamsT>
|
|
|
class Callable {
|
|
|
public:
|
|
|
virtual ~Callable() = default;
|
|
|
virtual TuningStatus Call(const ParamsT*) {
|
|
|
return FAIL;
|
|
|
}
|
|
|
virtual TuningStatus IsSupported(const ParamsT* params) {
|
|
|
return Call(params);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
|
|
// Online accumulator for count/sum/min/max/mean/variance of a sample stream,
// using Welford's numerically stable single-pass update.
// Members are public because TunableOp reads them directly (s._mean, s._min, ...).
class Stats {
  public:
    Stats() = default;

    // Incorporate one observation.
    void sample_value(const double x) {
      double delta = 0;
      _sum = _sum + x;
      if (0UL == _n) {
        _min = x;
        _max = x;
      }
      else {
        _min = _min < x ? _min : x;
        _max = _max > x ? _max : x;
      }
      _n = _n + 1UL;
      // Welford update: delta uses the old mean, the M2 term the new mean.
      delta = x - _mean;
      _mean = _mean + delta/_n;
      _M2 = _M2 + delta * (x - _mean);
    }

    // Unbiased sample variance. With fewer than two samples there is no
    // spread to estimate: return 0.0 rather than dividing by zero
    // (_n==1 would produce 0.0/0 == NaN; _n==0 would wrap the unsigned
    // denominator to ULONG_MAX).
    double variance() const {
      if (_n < 2UL) {
        return 0.0;
      }
      return _M2/(_n-1);
    }

    // Sample standard deviation (sqrt of variance()).
    double stddev() const {
      return std::sqrt(variance());
    }

    unsigned long _n{0UL};
    double _mean{0.0};
    double _M2{0.0};
    double _sum{0.0};
    double _min{0.0};
    double _max{0.0};
};
|
|
|
|
|
|
// Bounded history of strings: push() evicts the oldest entry once max_size
// entries are held; rbegin()/rend() iterate newest-to-oldest.
class FixedSizeStack {
  private:
    std::deque<std::string> stack;
    const size_t max_size;

  public:
    explicit FixedSizeStack(size_t size) : max_size(size) {}

    // Append `value`, dropping the oldest entry if the stack is full.
    void push(const std::string& value) {
      // The emptiness check guards against pop_front() on an empty deque
      // (undefined behavior) when max_size == 0, since size() >= 0 is
      // vacuously true.
      if (!stack.empty() && stack.size() >= max_size) {
        stack.pop_front();
      }
      stack.push_back(value);
    }

    // Reverse iterators: newest entry first.
    auto rbegin() { return stack.rbegin(); }
    auto rend() { return stack.rend(); }
};
|
|
|
|
|
|
}
|
|
|
|
|
|
template <typename ParamsT>
|
|
|
class TunableOp {
|
|
|
public:
|
|
|
virtual ~TunableOp() = default;
|
|
|
|
|
|
TuningStatus operator()(const ParamsT* params) {
|
|
|
ResultEntry result = ResultEntry::Null();
|
|
|
TuningContext* ctx = getTuningContext();
|
|
|
if (ctx->IsTunableOpEnabled()) {
|
|
|
auto& mgr = ctx->GetTuningResultsManager();
|
|
|
auto op_sig = Signature();
|
|
|
auto params_sig = params->Signature();
|
|
|
auto blas_sig = params->BLASSignature();
|
|
|
result = mgr.Lookup(op_sig, params_sig);
|
|
|
|
|
|
if (result == ResultEntry::Null()) {
|
|
|
if (ctx->IsTuningEnabled()) {
|
|
|
result = FindFastest(params);
|
|
|
mgr.Add(op_sig, params_sig, result);
|
|
|
}
|
|
|
else if (ctx->IsRecordUntunedEnabled()) {
|
|
|
|
|
|
mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
result = ResultEntry::Default();
|
|
|
}
|
|
|
if (result == ResultEntry::Null()) {
|
|
|
TUNABLE_LOG2("no result, using default");
|
|
|
result = ResultEntry::Default();
|
|
|
}
|
|
|
auto iter = ops_.find(result);
|
|
|
TORCH_CHECK(iter != ops_.end());
|
|
|
return iter->second->Call(params);
|
|
|
}
|
|
|
|
|
|
virtual std::string Signature() {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
c10::call_once(signature_init_once_, [this]() { signature_ = CreateSignature(); });
|
|
|
return signature_;
|
|
|
}
|
|
|
|
|
|
protected:
|
|
|
void RegisterOp(const std::string& name, std::unique_ptr<Callable<ParamsT>> op) {
|
|
|
this->op_names_.emplace_back(name);
|
|
|
this->ops_.emplace(name, std::move(op));
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
static void WarmUp(Callable<ParamsT> *op, const std::vector<ParamsT*> ¶m, size_t num_iter, size_t &offset) {
|
|
|
TuningContext* ctx = getTuningContext();
|
|
|
bool do_flush = ctx->IsICacheFlushEnabled();
|
|
|
for (size_t i = 0; i < num_iter; i++) {
|
|
|
if (do_flush) {
|
|
|
at::cuda::flush_icache();
|
|
|
}
|
|
|
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
static double ProfileSimple(Callable<ParamsT> *op, const std::vector<ParamsT*> ¶m, size_t num_iter, size_t &offset) {
|
|
|
TuningContext* ctx = getTuningContext();
|
|
|
bool do_flush = ctx->IsICacheFlushEnabled();
|
|
|
StreamTimerNoSync timer{};
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 2; i++) {
|
|
|
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
|
|
|
}
|
|
|
|
|
|
timer.Start();
|
|
|
for (size_t i = 0; i < num_iter; i++) {
|
|
|
if (do_flush) {
|
|
|
at::cuda::flush_icache();
|
|
|
}
|
|
|
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
|
|
|
}
|
|
|
timer.End();
|
|
|
return timer.Duration() / num_iter;
|
|
|
}
|
|
|
|
|
|
static Stats ProfileStats(Callable<ParamsT> *op, const std::vector<ParamsT*> ¶m, size_t num_iter, size_t &offset) {
|
|
|
TuningContext* ctx = getTuningContext();
|
|
|
bool do_flush = ctx->IsICacheFlushEnabled();
|
|
|
std::vector<StreamTimerNoSync> timer(num_iter);
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < 2; i++) {
|
|
|
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
|
|
|
}
|
|
|
|
|
|
for (size_t i = 0; i < num_iter; i++) {
|
|
|
timer[i].Start();
|
|
|
TORCH_CHECK(op->Call(param[(i+offset++)%param.size()]) == OK);
|
|
|
timer[i].End();
|
|
|
if (do_flush) {
|
|
|
at::cuda::flush_icache();
|
|
|
}
|
|
|
}
|
|
|
Stats s;
|
|
|
for (size_t i = 0; i < num_iter; i++) {
|
|
|
s.sample_value(timer[i].Duration());
|
|
|
}
|
|
|
return s;
|
|
|
}
|
|
|
|
|
|
protected:
|
|
|
virtual ResultEntry FindFastest(const ParamsT* params) {
|
|
|
TuningContext* ctx = getTuningContext();
|
|
|
auto op_sig = Signature();
|
|
|
auto params_sig = params->Signature();
|
|
|
auto blas_sig = params->BLASSignature();
|
|
|
TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
|
|
|
auto min_duration_ms = std::numeric_limits<double>::infinity();
|
|
|
std::string id_name = "Default";
|
|
|
ParamsT* reference_params = nullptr;
|
|
|
auto top_solns = FixedSizeStack(5);
|
|
|
|
|
|
|
|
|
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
|
|
|
|
|
|
|
|
|
if (do_numerics_check) {
|
|
|
reference_params = params->DeepCopy(false);
|
|
|
TORCH_CHECK(ops_[ResultEntry::Default()]->Call(reference_params) == OK);
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size_t rotating_size = ctx->GetRotatingBufferSize();
|
|
|
bool use_buffer_rotation = (rotating_size > 0);
|
|
|
size_t param_size = params->GetSize(use_buffer_rotation);
|
|
|
size_t param_count = (rotating_size / param_size) + 1;
|
|
|
constexpr size_t MB = 1024ull*1024;
|
|
|
if (use_buffer_rotation) {
|
|
|
TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ",
|
|
|
"Needed Size: ", param_size/MB, " MiB. ",
|
|
|
"Needed number of param copies: ", param_count);
|
|
|
}
|
|
|
TORCH_CHECK(param_count > 0);
|
|
|
|
|
|
std::vector<ParamsT*> reusable_params(param_count);
|
|
|
for (size_t i = 0; i < param_count; i++) {
|
|
|
reusable_params[i] = params->DeepCopy(use_buffer_rotation);
|
|
|
}
|
|
|
|
|
|
|
|
|
size_t offset = 0;
|
|
|
|
|
|
for (size_t i = 0; i < op_names_.size(); i++) {
|
|
|
auto* candidate = ops_[op_names_[i]].get();
|
|
|
|
|
|
if (do_numerics_check) {
|
|
|
ParamsT* numerical_params = params->DeepCopy(false);
|
|
|
auto status = candidate->Call(numerical_params);
|
|
|
if (status != OK) {
|
|
|
numerical_params->Delete();
|
|
|
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
|
|
continue;
|
|
|
}
|
|
|
status = reference_params->NumericalCheck(numerical_params);
|
|
|
numerical_params->Delete();
|
|
|
if (status != OK) {
|
|
|
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
|
|
continue;
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
auto status = candidate->Call(reusable_params[0]);
|
|
|
if (status != OK) {
|
|
|
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
|
|
continue;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
int approx_num_iter = 3;
|
|
|
auto s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
|
|
|
double approx_duration = s._mean;
|
|
|
|
|
|
if (approx_duration > 1.5 * min_duration_ms) {
|
|
|
TUNABLE_LOG3("├──skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
|
|
|
approx_num_iter = 10;
|
|
|
s = ProfileStats(candidate, reusable_params, approx_num_iter, offset);
|
|
|
approx_duration = s._mean;
|
|
|
|
|
|
if (approx_duration > 1.15 * min_duration_ms) {
|
|
|
TUNABLE_LOG3("├──2nd skip slow instance id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
|
|
|
continue;
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double max_warmup_duration = ctx->GetMaxWarmupDurationMs();
|
|
|
int max_warmup_iter = ctx->GetMaxWarmupIterations();
|
|
|
int warmup_iter = 0;
|
|
|
if (max_warmup_duration > 0) {
|
|
|
int duration_iters = max_warmup_duration / approx_duration;
|
|
|
if (max_warmup_iter > 0) {
|
|
|
warmup_iter = std::min(max_warmup_iter, duration_iters);
|
|
|
}
|
|
|
else {
|
|
|
warmup_iter = duration_iters;
|
|
|
}
|
|
|
}
|
|
|
else if (max_warmup_iter > 0) {
|
|
|
warmup_iter = max_warmup_iter;
|
|
|
}
|
|
|
|
|
|
|
|
|
double max_tuning_duration = ctx->GetMaxTuningDurationMs();
|
|
|
int max_tuning_iter = ctx->GetMaxTuningIterations();
|
|
|
int tuning_iter = 100;
|
|
|
if (max_tuning_duration > 0) {
|
|
|
int duration_iters = max_tuning_duration / approx_duration;
|
|
|
if (max_tuning_iter > 0) {
|
|
|
tuning_iter = std::min(max_tuning_iter, duration_iters);
|
|
|
}
|
|
|
else {
|
|
|
tuning_iter = duration_iters;
|
|
|
}
|
|
|
}
|
|
|
else if (max_tuning_iter > 0) {
|
|
|
tuning_iter = max_tuning_iter;
|
|
|
}
|
|
|
|
|
|
tuning_iter = std::max(1, tuning_iter);
|
|
|
|
|
|
|
|
|
double warmup_ms = warmup_iter * approx_duration;
|
|
|
double tuning_ms = tuning_iter * approx_duration;
|
|
|
TUNABLE_LOG3("├──tuning using "
|
|
|
"warmup iters ", warmup_iter, " [", warmup_ms, " ms] "
|
|
|
"and tuning iters ", tuning_iter, " [", tuning_ms, " ms] ",
|
|
|
"instance id=", i, ", ", op_sig, "(", params_sig, ") ", op_names_[i]);
|
|
|
TUNABLE_LOG3("├──offset at ", offset);
|
|
|
WarmUp(candidate, reusable_params, warmup_iter, offset);
|
|
|
s = ProfileStats(candidate, reusable_params, tuning_iter, offset);
|
|
|
auto s_stddev = s.stddev();
|
|
|
|
|
|
|
|
|
|
|
|
if (s._mean < min_duration_ms) {
|
|
|
TUNABLE_LOG3("├──found better instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
|
|
|
" min ", s._min,
|
|
|
" max ", s._max,
|
|
|
" mean ", s._mean,
|
|
|
" std ", s_stddev);
|
|
|
min_duration_ms = s._mean;
|
|
|
id_name = op_names_[i];
|
|
|
std::string current_soln = std::to_string(s._mean) + " " + op_names_[i];
|
|
|
top_solns.push(current_soln);
|
|
|
}
|
|
|
else {
|
|
|
TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
|
|
|
" min ", s._min,
|
|
|
" max ", s._max,
|
|
|
" mean ", s._mean,
|
|
|
" std ", s_stddev);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
for (size_t i = 0; i < reusable_params.size(); i++) {
|
|
|
reusable_params[i]->Delete();
|
|
|
}
|
|
|
if (reference_params) {
|
|
|
reference_params->Delete();
|
|
|
}
|
|
|
|
|
|
TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
|
|
|
TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") ");
|
|
|
for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) {
|
|
|
TUNABLE_LOG2(" ", *it);
|
|
|
}
|
|
|
return ResultEntry(id_name, min_duration_ms, blas_sig);
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
std::string CreateSignature() {
|
|
|
#ifndef _WIN32
|
|
|
const auto* name = typeid(*this).name();
|
|
|
|
|
|
char buf[256];
|
|
|
size_t buf_len = 256;
|
|
|
abi::__cxa_demangle(name, buf, &buf_len, nullptr);
|
|
|
buf[255] = '\0';
|
|
|
return buf;
|
|
|
#else
|
|
|
return typeid(*this).name();
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
mutable c10::once_flag signature_init_once_;
|
|
|
std::string signature_;
|
|
|
|
|
|
std::unordered_map<std::string, std::unique_ptr<Callable<ParamsT>>> ops_;
|
|
|
std::vector<std::string> op_names_;
|
|
|
};
|
|
|
|
|
|
// Abstract interface for the params objects consumed by TunableOp /
// Callable. Concrete params types provide the two string signatures used as
// lookup keys and for result recording.
struct OpParams {
  virtual ~OpParams() = default;
  // Signature identifying this particular problem instance (tuning key).
  virtual std::string Signature() const = 0;
  // BLAS-style signature recorded alongside tuning results.
  virtual std::string BLASSignature() const = 0;
};
|
|
|
|
|
|
}
|
|
|
|