File size: 5,720 Bytes
9dd3461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <bitset>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {
namespace functorch {

using Tensor = at::Tensor;

// Upper bound on the number of tensor dimensions tracked for vmap; it is the
// width of the bitset returned by createBatchDimBitset() below.
// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// The valid vmap levels range from [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps. It is the width of the bitset
// returned by createVmapLevelsBitset() below.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
// NOTE(review): nothing in this header reads this constant; confirm it still
// has users elsewhere.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchedTensorImpl holds an underlying Tensor and a single batch dim
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimension is treated as being "private"; it is not user-visible.
// For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
// dimension 0 is the batch dimension.
//
// bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 1 in the underlying ones(2, 3, 5, 7) tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  // Wraps `value`. `dim` and `level` are presumably stored as the batch
  // dimension index (bdim_) and the vmap level (level_); the constructor is
  // defined out of line, so confirm against its definition.
  explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);

  // Returns batch dimension of this tensor
  int64_t bdim() const { return bdim_; }

  // Returns the vmap level of this tensor
  int64_t level() const { return level_; }

  // Returns the wrapped Tensor (the one that physically contains the batch
  // dimension).
  const Tensor& value() const { return value_; }

  // Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 2
  // bt.actualDim(2) -> 3
  // bt.actualDim(3) -> Error
  // NOTE(review): `wrap_dim` presumably enables wrapping of negative indices;
  // the definition is not in this header -- confirm.
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
  // Only overridden when DEBUG is defined.
#ifdef DEBUG
  bool has_storage() const override;
#endif

  // Must be called after unsafe_set_bdim() (see the NB there). Presumably
  // recomputes the public metadata from value() and the current batch
  // dimension; defined out of line -- confirm.
  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
  // accomplishes this is a hack where it is able to modify the levels of
  // BatchedTensor to match the level of the current vmap transform.
  // Overwrites level_ directly; no validation is performed here.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in batching rule for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_)
  // Overwrites bdim_ directly; no validation is performed here.
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  // The wrapped tensor, exposed through value().
  Tensor value_;

  // vmap level of this tensor, exposed through level().
  int64_t level_;
  // Index of the batch dimension, exposed through bdim().
  int64_t bdim_;
};

// A "BatchedTensor" is a Tensor whose impl is a BatchedTensorImpl. This
// predicate reports whether the tensor's dispatch key set contains the
// FuncTorchBatched key.
inline bool isBatchedTensor(const Tensor& tensor) {
  const auto keys = tensor.unsafeGetTensorImpl()->key_set();
  return keys.has(DispatchKey::FuncTorchBatched);
}

// Reinterprets the tensor's impl pointer as a BatchedTensorImpl*.
//
// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl: the cast is an unchecked static_cast. Please use
// `maybeGetBatchedImpl` whenever possible.
//
// The tensor is taken by const reference so that merely inspecting the impl
// does not copy the Tensor handle. The returned pointer is non-owning; it is
// only valid while some Tensor keeps the impl alive.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

// Checked counterpart of unsafeGetBatchedImpl: returns the tensor's
// BatchedTensorImpl, or nullptr when isBatchedTensor(tensor) is false.
//
// The tensor is taken by const reference so that the check does not copy the
// Tensor handle. The returned pointer is non-owning; it is only valid while
// some Tensor keeps the impl alive.
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}

// Builds a bitset in which only bit `dim` is set; a set bit i means that
// dim i is a batch dim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  // set() returns a reference to the temporary, which is copied out as the result.
  return std::bitset<kVmapMaxTensorDims>().set(dim);
}

// Builds a bitset in which only the bit for the given vmap `level` is set.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  // set() returns a reference to the temporary, which is copied out as the result.
  return std::bitset<kVmapNumLevels>().set(level);
}

// Use this to construct a BatchedTensor from a regular Tensor
// NOTE(review): `dim` and `level` presumably become the batch dimension index
// and vmap level of the resulting BatchedTensorImpl; the definition is not in
// this header -- confirm.
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor
// NOTE(review): how this differs from makeBatched (e.g. when `tensor` is
// already batched) is not visible in this header -- see the definition.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);

// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
// This set is the default value of the `to_propagate` parameter of
// getKeysToPropagateToWrapper() below.
// TODO: should probably contain more (or all?) backend keys
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});

// Returns the subset of `tensor`'s dispatch keys that should be copied onto a
// wrapper tensor, i.e. the intersection of the tensor's key set with
// `to_propagate`.
//
// `to_propagate` defaults to kKeysToPropagateToWrapper. The mask applied is
// the caller-supplied argument, so a custom set passed here is honoured
// rather than being silently replaced by the default.
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return key_set & to_propagate;
}

}
}