File size: 8,169 Bytes
9dd3461 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/functorch/Macros.h>
#include <ATen/functorch/BatchedTensorImpl.h>

#include <utility>
namespace at {
namespace functorch {
// This file contains the legacy (now-deprecated) batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md)
// This file contains abstractions used for transforming *logical* vmap arguments
// into *physical* arguments. (Keep reading for definitions of these terms).
// NOTE: [Logical vs physical args]
// Consider the following vmap.
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0 but the physical dim is dim 1
// (the first non-batch dimension)
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;
// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
// Vector of VmapPhysicalViews; this is the return type of the list overloads
// of `logicalToPhysical` declared below. kVmapTransformStaticInputSize is
// passed as the SmallVector size parameter.
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
// PyTorch generally advertises good performance for <= 5 dims.
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
// Vector of int64_t dim indices or sizes; it is the return type of
// VmapPhysicalView::getPhysicalDims and of the IntArrayRef overload of
// VmapPhysicalView::getPhysicalShape.
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
// NOTE: [What is an VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.
// MultiBatchVmapTransform: the VmapTransform for operators whose tensor
// inputs may carry multiple batch dims.
//
// `logicalToPhysical` accepts one logical view on a Tensor (or a list of
// them) and:
// - permutes all of the batch dims to the front of the tensor,
// - aligns and expands the batch dims so that they match each other
//   (matching is done according to their `level`),
// and then returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  // Single-tensor form.
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);

  // List form; the resulting views are returned in a VmapPhysicalViewVec.
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
// BroadcastingVmapTransform: the VmapTransform for operators that broadcast
// all of their inputs.
//
// Given some logical views on Tensors, `logicalToPhysical`:
// 1. Permutes all of the batch dims to the front of each tensor.
// 2. Aligns the batch dims of every tensor to the collective vmap levels of
//    all of the tensors. A tensor that has no batch dim for some vmap level
//    receives a size-one dimension for that level.
// 3. Aligns the non-batch dims so that they have the same dimensionality.
//    The extra size-1 dimensions are added in between the batch dimensions
//    and the non-batch dimensions.
//
// Examples (B is the batch dimension):
// - Inputs of size (B, 2) and (B, 3, 2) yield VmapPhysicalViews that wrap
//   tensors of size (B, 1, 2) and (B, 3, 2).
// - Inputs of size (B, 2) and (2,) yield VmapPhysicalViews that wrap tensors
//   of size (B, 2) and (1, 2). Returning a tensor of size (1, 2) for the
//   second input is not actually *needed*, because the broadcasting
//   operation takes care of that for us, but it is done anyway to keep
//   things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
// Forward declared, if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;
// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is at index 3, indicating that the number
// of nested vmaps is less than or equal to 3.
// bitset: 010100
//             ^
//             |
// levels: 012345
// A physical tensor paired with the bitset of vmap levels that its batch
// dimensions correspond to.
struct TORCH_API VmapPhysicalView {
  // Stores `levels` and takes over `tensor`.
  //
  // `tensor` is an rvalue-reference parameter, but as a named variable it is
  // an lvalue inside the constructor; without std::move the member would be
  // copy-constructed instead of move-constructed.
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    // NOTE: `tensor` has been moved from at this point; if this check is ever
    // re-enabled it must inspect the member instead.
    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  }

  // Accessors for the wrapped physical tensor.
  Tensor& tensor() { return tensor_; }
  const Tensor& tensor() const { return tensor_; }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the size of levels tell us that the first two dimensions
  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
  // a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
  // Single-index form of getPhysicalDims.
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor)
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by pre-pending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;

  // Number of batch dimensions of the view (defined out of line; presumably
  // the number of set bits in `levels_` -- confirm against the definition).
  int64_t numBatchDims() const;

 private:
  // Defined out of line; presumably the number of non-batch dims of
  // `tensor_` -- confirm against the definition.
  int64_t numLogicalDims() const;

  // vmap levels associated with the batch dims of `tensor_`.
  std::bitset<kVmapNumLevels> levels_;
  // The wrapped physical tensor.
  Tensor tensor_;
};
// Helper for turning a physical tensor (i.e. a non-BatchedTensor) back into
// a logical one (a BatchedTensor). It carries the levels used for the mapping
// and assumes that every batch dimension of the physical tensor sits at the
// front of that tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
      : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor), assuming
  // all of the "batch dimensions" are at the front of the physical tensor.
  //
  // Example:
  //   input:  x = rank-4 Tensor with size 2, 3, 5, 7
  //           levels = (2, 4)
  //   output: BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Vector form of `apply`:
  //   1. each physical tensor is mapped to a new logical tensor (again
  //      assuming the "batch dimensions" are at the front of each one);
  //   2. the new logical tensors are stored back into the passed-in vector,
  //      which avoids additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  // Levels used to do the mapping.
  std::bitset<kVmapNumLevels> levels_;
};
}
} // namespace at
|