File size: 7,168 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Some stateful GPU libraries, such as cuDNN and cuBLAS, use handles to store state.
// These handles are tied to a device, and these libraries require/recommend not
// sharing handles across host threads.
//
// These libraries recommend using one handle per host thread. We may not want to do
// this because threads are relatively light-weight, but creating and destroying
// handles is expensive (destroying the handle causes synchronizations). DataParallel,
// for example, creates new threads for each forward pass.
//
// This file implements a handle pool mechanism. The handle pool returns handles on
// demand as threads request them. If all existing handles in the pool are in use,
// it creates a new one. As threads terminate, they release handles back into the pool.
// In this way, the handle pool never creates more handles than the high-water mark of
// active threads, so it's efficient with DataParallel.

#pragma once

#include <unordered_map>
#include <vector>
#include <utility>
#include <mutex>
#include <memory>

#include <c10/util/Exception.h>

namespace at::cuda { namespace {

template <typename Handle_t, void Create(Handle_t *), void Destroy(Handle_t)>
struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThreadHandlePool<Handle_t, Create, Destroy>> {

    struct Handle {
    Handle_t handle;
    Handle(bool create = false) : handle(nullptr)
    {
        if(create) Create(&handle);
    }
    // std::vector.emplace() and push_back() may route through temporaries and call
    // copy/move constructors along the way.  If this is the case, we don't want
    // the destructors of temporaries to call cudnnDestroy on the handle.
    // We can achieve safety (for the narrow case of stashing within std::vectors)
    // by making Handle moveable but not copyable, and transferring handle ownership
    // to the latest constructed object.  This is not a substitute for full-blown
    // reference counting, but reference counting may be overkill here.
    // Another alternative is to wrap the saved Handles in unique_ptrs, i.e.,
    // unordered_map<int, vector<unique_ptr<Handle>>> created_handles;
    Handle(const Handle& rhs) = delete;
    // Following https://stackoverflow.com/questions/3279543/what-is-the-copy-and-swap-idiom
    Handle(Handle&& rhs) noexcept : Handle() { std::swap(handle, rhs.handle); }
    // operator= takes argument by value
    Handle& operator=(Handle rhs) { std::swap(handle, rhs.handle); return *this; }
    ~Handle() {
        if(handle) Destroy(handle);
    }
    };

    std::mutex mutex;

    // Handles are lazily created as different threads request them,
    // but are never destroyed until the end of the process.
    // The maximum number of handles this process will create for each device is equal
    // to the high-water mark of the number of concurrently active threads that request
    // handles for that device.
    // When threads terminate, they release their handles back into the pool for reuse.
    // Otherwise, new handles would be created every time new threads were spawned,
    // resulting in poor performance for Python modules that repeatedly or frequently
    // spawned new sets of threads (like DataParallel, which creates a new set of threads
    // for each forward pass).
    //
    // To prevent potential deadlocks, we explicitly choose not to cap the number
    // of handles that are created per device.
    // Example of danger: If we cap the max handles at 4, and 5 threads are sharing a device,
    // only 4 can make forward progress at any time. The other 4 will not release their
    // handles until they exit, so the fifth cannot make progress until then.  This is
    // not a problem...UNLESS all 5 threads attempt some sort of synchronization at an
    // intermediate point (ie, before any of them have exited).  We have no way to anticipate
    // or enforce that user threads will not attempt such intermediate synchronization.
    // The only way to ensure safety is to avoid imposing a cap on the number of handles.
    std::unordered_map<int, std::vector<Handle>> created_handles;
    std::unordered_map<int, std::vector<Handle_t>> available_handles;

    // PoolWindow lazily creates and caches the handles that a particular thread is using,
    // so in the common case handle access doesn't incur either handle creation or a mutex lock.
    class PoolWindow

    {
    public:
    PoolWindow(std::shared_ptr<DeviceThreadHandlePool> parent): weak_parent(std::move(parent)) {}
    ~PoolWindow(){ release(); }

    Handle_t reserve(int device)
    {
        // If this thread already has a handle for this device, return it
        if(my_handles.find(device) != my_handles.end())
        return my_handles[device];

        // otherwise, either grab a handle from the pool if one is available,
        // or if not, create a new one.
        auto parent = weak_parent.lock();
        TORCH_CHECK(parent, "Cannot create handle during program termination");
        std::lock_guard<std::mutex> guard(parent->mutex);

        if(parent->available_handles[device].size() > 0)
        {
        my_handles[device] = parent->available_handles[device].back();
        parent->available_handles[device].pop_back();
        }
        else
        {
        // In local testing, I do observe that emplace_back sometimes routes through temporaries
        // that incur move-constructor and destructor calls.  See comments in Handle above.
        parent->created_handles[device].emplace_back(true /*create*/);
        my_handles[device] = parent->created_handles[device].back().handle;
        }

        return my_handles[device];
    }

    private:
    // Stores the per-device handles currently owned by this thread
    std::unordered_map<int, Handle_t> my_handles;

    std::weak_ptr<DeviceThreadHandlePool> weak_parent;

    // Called by the destructor.  Releases this thread's handles back into the pool.
    void release() {
        if(my_handles.size() > 0) {
            auto parent = weak_parent.lock();
            if (!parent) {
                // If this thread exits after atexit handlers have completed, the
                // cuda context itself may be invalid, so we must leak the handles.
                return;
            }

            std::lock_guard<std::mutex> guard(parent->mutex);
            for(auto d_h : my_handles)
                parent->available_handles[d_h.first].push_back(d_h.second);
        }
    }
    };

    // Warning:
    // If you want to change this function, be aware that this function will be called
    // by multiple threads and there is no mutex guarding the call of this function, so
    // make sure your implementation is thread-safe.
    PoolWindow *newPoolWindow() {
        // The returned pointer will be owned by a thread local variable
        // so that different threads does not share the same PoolWindow.
        return new PoolWindow(this->shared_from_this());
    }
};

}}  // namespace at::cuda::<anonymous>