|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
#include <array>
|
|
|
#include <stdexcept>
|
|
|
#include <string>
|
|
|
#include <type_traits>
|
|
|
|
|
|
namespace fbgemm {

/**
 * @brief Terminal overload of array_of_ones: selected once the value pack
 * already holds exactly N entries.
 *
 * @tparam N    Length of the resulting array.
 * @tparam Vals Values accumulated so far (explicitly supplied values, if any,
 *              followed by the 1s appended by the recursive overload).
 * @return std::array<int, N> holding Vals in order.
 */
template <int N, int... Vals>
constexpr
    typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
    array_of_ones() {
  return std::array<int, N>{{Vals...}};
}

/**
 * @brief Recursive overload of array_of_ones: appends one more 1 to the value
 * pack until it holds N entries, e.g. array_of_ones<3>() yields {1, 1, 1}.
 *
 * @tparam N    Length of the resulting array.
 * @tparam Vals Values accumulated so far.
 * @return std::array<int, N> whose unspecified entries are all 1.
 */
template <int N, int... Vals>
constexpr
    typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
    array_of_ones() {
  return array_of_ones<N, Vals..., 1>();
}

/**
 * @brief Terminal overload of array_of_zeroes: selected once the value pack
 * already holds exactly N entries.
 *
 * @tparam N    Length of the resulting array.
 * @tparam Vals Values accumulated so far (explicitly supplied values, if any,
 *              followed by the 0s appended by the recursive overload).
 * @return std::array<int, N> holding Vals in order.
 */
template <int N, int... Vals>
constexpr
    typename std::enable_if<N == sizeof...(Vals), std::array<int, N>>::type
    array_of_zeroes() {
  return std::array<int, N>{{Vals...}};
}

/**
 * @brief Recursive overload of array_of_zeroes: appends one more 0 to the
 * value pack until it holds N entries, e.g. array_of_zeroes<2>() yields
 * {0, 0}.
 *
 * @tparam N    Length of the resulting array.
 * @tparam Vals Values accumulated so far.
 * @return std::array<int, N> whose unspecified entries are all 0.
 */
template <int N, int... Vals>
constexpr
    typename std::enable_if<N != sizeof...(Vals), std::array<int, N>>::type
    array_of_zeroes() {
  return array_of_zeroes<N, Vals..., 0>();
}

/**
 * @brief Parameters of a (possibly grouped, dilated, transposed) convolution
 * together with the padded-input and output extents derived from them.
 *
 * @tparam SPATIAL_DIM Number of spatial dimensions (defaults to 2).
 */
template <int SPATIAL_DIM = 2>
struct conv_param_t {
  int MB; ///< Taken verbatim from `mb` (presumably the mini-batch size).
  int IC; ///< Number of input channels.
  int OC; ///< Number of output channels.
  std::array<int, SPATIAL_DIM> IN_DIM; ///< Input extent per spatial axis.
  int G; ///< Number of groups; must divide both IC and OC.
  std::array<int, SPATIAL_DIM> K; ///< Kernel extent per spatial axis.
  std::array<int, SPATIAL_DIM> stride; ///< Stride per spatial axis.
  /// Padding: entries d and SPATIAL_DIM + d are the two paddings of axis d
  /// (presumably leading and trailing, respectively).
  std::array<int, SPATIAL_DIM * 2> pad;
  std::array<int, SPATIAL_DIM> dilation; ///< Dilation per spatial axis.

  // The two members below are derived in the constructor.
  std::array<int, SPATIAL_DIM> OUT_DIM; ///< Output extent per spatial axis.
  std::array<int, SPATIAL_DIM> IN_DIMP; ///< Padded input extent per axis.

  /// Extra output size per axis; only read when `transposed` is true.
  std::array<int, SPATIAL_DIM> output_pad;
  bool transposed; ///< True for a transposed convolution.

  /**
   * @brief Stores the parameters and derives IN_DIMP and OUT_DIM.
   *
   * For a regular convolution, per axis d:
   *   IN_DIMP = IN_DIM + pad[d] + pad[SPATIAL_DIM + d]
   *   OUT_DIM = (IN_DIMP - dilation * (K - 1) - 1) / stride + 1
   * For a transposed convolution, per axis d:
   *   IN_DIMP = IN_DIM + (dilation * (K - 1) - pad[d])
   *                    + (dilation * (K - 1) - pad[SPATIAL_DIM + d])
   *   OUT_DIM = (IN_DIM - 1) * stride - pad[d] - pad[SPATIAL_DIM + d]
   *             + dilation * (K - 1) + output_pad + 1
   *
   * @param mb         Stored in MB.
   * @param ic         Number of input channels.
   * @param oc         Number of output channels.
   * @param in_dim     Input extent per spatial axis.
   * @param g          Number of groups.
   * @param k          Kernel extent per spatial axis.
   * @param strd       Stride per spatial axis.
   * @param pd         Padding, SPATIAL_DIM * 2 entries (see `pad`).
   * @param dilations  Dilation per spatial axis (defaults to all 1).
   * @param otpt_pd    Output padding per spatial axis (defaults to all 0).
   * @param transposed Whether this describes a transposed convolution.
   *
   * @throws std::runtime_error if `g` is zero, if `g` does not divide `ic` or
   *         `oc`, or if a stride is zero for a non-transposed convolution.
   */
  conv_param_t(
      int mb,
      int ic,
      int oc,
      std::array<int, SPATIAL_DIM> in_dim,
      int g,
      std::array<int, SPATIAL_DIM> k,
      std::array<int, SPATIAL_DIM> strd,
      std::array<int, SPATIAL_DIM * 2> pd,
      std::array<int, SPATIAL_DIM> dilations = array_of_ones<SPATIAL_DIM>(),
      std::array<int, SPATIAL_DIM> otpt_pd = array_of_zeroes<SPATIAL_DIM>(),
      bool transposed = false)
      : MB(mb),
        IC(ic),
        OC(oc),
        IN_DIM(in_dim),
        G(g),
        K(k),
        stride(strd),
        pad(pd),
        dilation(dilations),
        output_pad(otpt_pd),
        transposed(transposed) {
    // `x % 0` is undefined behaviour, so reject g == 0 before the
    // divisibility checks below.
    if (g == 0) {
      throw std::runtime_error(
          "groups = " + std::to_string(g) + " must be non-zero");
    }
    if (ic % g != 0) {
      throw std::runtime_error(
          "groups = " + std::to_string(g) +
          " does not divide number of input channels = " + std::to_string(ic));
    }
    if (oc % g != 0) {
      throw std::runtime_error(
          "groups = " + std::to_string(g) +
          " does not divide number of output channels = " + std::to_string(oc));
    }

    for (int d = 0; d < SPATIAL_DIM; ++d) {
      if (transposed) {
        IN_DIMP[d] = IN_DIM[d] + (dilation[d] * (K[d] - 1) - pad[d]) +
            (dilation[d] * (K[d] - 1) - pad[SPATIAL_DIM + d]);
        OUT_DIM[d] = (IN_DIM[d] - 1) * stride[d] - pad[d] -
            pad[SPATIAL_DIM + d] + dilation[d] * (K[d] - 1) + output_pad[d] +
            1;
      } else {
        // The output-size formula divides by the stride; a zero stride would
        // be undefined behaviour.
        if (stride[d] == 0) {
          throw std::runtime_error(
              "stride along spatial dimension " + std::to_string(d) +
              " must be non-zero");
        }
        IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d];
        OUT_DIM[d] =
            (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1;
      }
    }
  }

  /**
   * @brief Renders the parameters as "name:value" fields joined by ", ".
   *
   * With at most three spatial dimensions the axes are labelled with the
   * trailing entries of T/H/W (e.g. H and W for two dimensions); with more
   * they are labelled by index. All SPATIAL_DIM * 2 padding values are
   * printed. The output_padding fields are appended only for transposed
   * convolutions. No trailing separator is emitted.
   *
   * @return The formatted description.
   */
  std::string toString() const {
    const std::string dim_string[3] = {"T", "H", "W"};
    const bool named_axes = SPATIAL_DIM <= 3;

    // Label suffix for spatial axis d.
    auto axis = [&](int d) -> std::string {
      return named_axes ? dim_string[3 - SPATIAL_DIM + d] : std::to_string(d);
    };

    std::string out;
    // Appends "label:value", preceded by ", " for every field but the first.
    auto add = [&out](const std::string& label, int value) {
      if (!out.empty()) {
        out += ", ";
      }
      out += label;
      out += ":";
      out += std::to_string(value);
    };

    add("MB", MB);
    add("IC", IC);
    add("OC", OC);
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      add("I" + axis(d), IN_DIM[d]);
    }
    add("G", G);
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      add("K" + axis(d), K[d]);
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      add("stride_" + axis(d), stride[d]);
    }
    for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
      // Named axes repeat per half of `pad`; indexed labels stay unique.
      add("pad_" + (named_axes ? axis(d % SPATIAL_DIM) : std::to_string(d)),
          pad[d]);
    }
    for (int d = 0; d < SPATIAL_DIM; ++d) {
      add("dilation_" + axis(d), dilation[d]);
    }
    if (transposed) {
      for (int d = 0; d < SPATIAL_DIM; ++d) {
        add("output_padding_" + std::to_string(d), output_pad[d]);
      }
    }
    return out;
  }
};

} // namespace fbgemm
|
|
|
|