|
|
#pragma once |
|
|
|
|
|
#include <c10/core/AutogradState.h> |
|
|
#include <c10/macros/Macros.h> |
|
|
|
|
|
namespace c10 { |
|
|
|
|
|
/// Static accessors for the global "grad mode" switch.
///
/// Only the declarations live here; the definitions are provided in another
/// translation unit. NOTE(review): presumably backed by the thread-local
/// AutogradState included above — confirm against the .cpp definition.
struct TORCH_API GradMode {
  /// @return The current grad-mode setting.
  static auto is_enabled() -> bool;

  /// Overwrite the current grad-mode setting.
  /// @param enabled New grad-mode value.
  static auto set_enabled(bool enabled) -> void;
};
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_API AutoGradMode { |
|
|
AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { |
|
|
GradMode::set_enabled(enabled); |
|
|
} |
|
|
~AutoGradMode() { |
|
|
GradMode::set_enabled(prev_mode); |
|
|
} |
|
|
bool prev_mode; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
/// Scope guard that turns grad mode off for its lifetime.
///
/// It is an AutoGradMode constructed with `enabled = false`. The previous
/// setting is restored when the guard is destroyed, via the base destructor.
struct TORCH_API NoGradGuard : public AutoGradMode {
  NoGradGuard() : AutoGradMode{/*enabled=*/false} {}
};
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_API AutoFwGradMode { |
|
|
AutoFwGradMode(bool enabled) |
|
|
: prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { |
|
|
AutogradState::get_tls_state().set_fw_grad_mode(enabled); |
|
|
} |
|
|
~AutoFwGradMode() { |
|
|
AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); |
|
|
} |
|
|
bool prev_mode; |
|
|
}; |
|
|
|
|
|
} |
|
|
|