| namespace c10 { | |
| /// RAII guard that sets a certain default device in its constructor, and | |
| /// changes it back to the device that was originally active upon destruction. | |
| /// | |
| /// The device is always reset to the one that was active at the time of | |
| /// construction of the guard. Even if you `set_device` after construction, the | |
| /// destructor will still reset the device to the one that was active at | |
| /// construction time. | |
| /// | |
| /// This device guard does NOT have an uninitialized state; it is guaranteed | |
| /// to reset a device on exit. If you are in a situation where you *might* | |
| /// want to setup a guard (i.e., are looking for the moral equivalent | |
| /// of optional<DeviceGuard>), see OptionalDeviceGuard. | |
| class DeviceGuard { | |
| public: | |
| /// No default constructor; see Note [Omitted default constructor from RAII] | |
| explicit DeviceGuard() = delete; | |
| /// Set the current device to the passed Device. | |
| explicit DeviceGuard(Device device) : guard_(device) {} | |
| /// This constructor is for testing only. | |
| explicit DeviceGuard( | |
| Device device, | |
| const impl::DeviceGuardImplInterface* impl) | |
| : guard_(device, impl) {} | |
| /// Copy is disallowed | |
| DeviceGuard(const DeviceGuard&) = delete; | |
| DeviceGuard& operator=(const DeviceGuard&) = delete; | |
| /// Move is disallowed, as DeviceGuard does not have an uninitialized state, | |
| /// which is required for moves on types with nontrivial destructors. | |
| DeviceGuard(DeviceGuard&& other) = delete; | |
| DeviceGuard& operator=(DeviceGuard&& other) = delete; | |
| /// Sets the device to the given one. The specified device must be consistent | |
| /// with the device type originally specified during guard construction. | |
| /// | |
| /// TODO: The consistency check here is inconsistent with StreamGuard's | |
| /// behavior with set_stream, where a stream on a different device than | |
| /// the original one isn't an error; we just reset the stream and then | |
| /// switch devices. | |
| void reset_device(at::Device device) { | |
| guard_.reset_device(device); | |
| } | |
| /// This method is for testing only. | |
| void reset_device( | |
| at::Device device, | |
| const impl::DeviceGuardImplInterface* impl) { | |
| guard_.reset_device(device, impl); | |
| } | |
| /// Sets the device index to the given one. The device type is inferred | |
| /// from the original device type the guard was constructed with. | |
| void set_index(DeviceIndex index) { | |
| guard_.set_index(index); | |
| } | |
| /// Returns the device that was set at the time the guard was constructed. | |
| Device original_device() const { | |
| return guard_.original_device(); | |
| } | |
| /// Returns the most recent device that was set using this device guard, | |
| /// either from construction, or via set_device. | |
| Device current_device() const { | |
| return guard_.current_device(); | |
| } | |
| private: | |
| impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_; | |
| }; | |
| /** | |
| * A OptionalDeviceGuard is an RAII class that sets a device to some value on | |
| * initialization, and resets the device to its original value on destruction. | |
| * Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but | |
| * with extra constructors and methods as appropriate. | |
| * | |
| * Besides its obvious use (optionally applying a DeviceGuard), | |
| * OptionalDeviceGuard is often also used for the following idiom: | |
| * | |
| * OptionalDeviceGuard g; | |
| * for (const auto& t : tensors) { | |
| * g.set_device(t.device()); | |
| * do_something_with(t); | |
| * } | |
| * | |
| * This usage is marginally more efficient than constructing a DeviceGuard every | |
| * iteration of the for loop, as it avoids an unnecessary device reset. | |
| * | |
| * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized. This occurs | |
| * when you use the nullary constructor, or pass a nullopt to the constructor. | |
| * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the | |
| * original device was and they do not reset on destruction. This is why | |
| * original_device() and current_device() return optional<Device> rather than | |
| * Device (as they do in DeviceGuard), and also is why we didn't just | |
| * provide OptionalDeviceGuard by default and hide DeviceGuard from users. | |
| * | |
| * The semantics of an OptionalDeviceGuard are exactly explained by thinking | |
| * of it as an optional<DeviceGuard>. In particular, an initialized | |
| * OptionalDeviceGuard doesn't restore device to its value at construction; it | |
| * restores device to its value *at initialization*. So if you have the | |
| * program: | |
| * | |
| * setDevice(1); | |
| * OptionalDeviceGuard g; | |
| * setDevice(2); | |
| * g.reset_device(Device(DeviceType::CUDA, 3)); // initializes! | |
| * | |
| * On destruction, g will reset device to 2, rather than 1. | |
| * | |
| * An uninitialized OptionalDeviceGuard is distinct from a (initialized) | |
| * DeviceGuard whose original_device_ and current_device_ match, since the | |
| * DeviceGuard will still reset the device to original_device_. | |
| */ | |
| class OptionalDeviceGuard { | |
| public: | |
| /// Create an uninitialized guard. Set the guard later using reset_device. | |
| explicit OptionalDeviceGuard() : guard_() {} | |
| /// Initialize the guard, setting the current device to the passed Device. | |
| explicit OptionalDeviceGuard(Device device) : guard_(device) {} | |
| /// Initialize the guard if a Device is passed; otherwise leave the | |
| /// guard uninitialized. | |
| explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {} | |
| /// Constructor for testing only. | |
| explicit OptionalDeviceGuard( | |
| Device device, | |
| const impl::DeviceGuardImplInterface* impl) | |
| : guard_(device, impl) {} | |
| /// Copy is disallowed | |
| OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; | |
| OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; | |
| /// Move is disallowed | |
| /// See Note [Explicit initialization of optional fields] | |
| /// and // Note [Move construction for RAII guards is tricky] | |
| /// for rationale. | |
| OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; | |
| OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; | |
| /// Sets the device to the given one. The specified device must be consistent | |
| /// with the device type originally specified during guard construction. | |
| void reset_device(at::Device device) { | |
| guard_.reset_device(device); | |
| } | |
| /// For testing only | |
| void reset_device( | |
| at::Device device, | |
| const impl::DeviceGuardImplInterface* impl) { | |
| guard_.reset_device(device, impl); | |
| } | |
| /// Returns the device that was set at the time the guard was constructed. | |
| optional<Device> original_device() const { | |
| return guard_.original_device(); | |
| } | |
| /// Returns the most recent device that was set using this device guard, | |
| /// either from construction, or via reset_device. | |
| optional<Device> current_device() const { | |
| return guard_.current_device(); | |
| } | |
| private: | |
| impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_; | |
| }; | |
| // Note [Whither the DeviceGuard boilerplate] | |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
| // Design note: in principle, we could avoid these wrappers using: | |
| // | |
| // using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>; | |
| // using OptionalDeviceGuard = | |
| // impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>; | |
| // | |
| // But the error messages are worse, and our users can't just look at the | |
| // header file to find out what's going on. Furthermore, for specializations | |
| // like CUDAStreamGuard, it can be profitable to replace some interfaces with | |
| // refined types (e.g., return CUDAStream instead of Stream). So, we eat | |
| // the boilerplate and write out the API explicitly. | |
| } // namespace c10 | |