File size: 3,085 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// workspace_pool.h — reusable aclnn workspace buffer pool.
//
// Problem: every aclnn op does `aclrtMalloc(workspace)` + `aclrtFree`. For decode at 94 layers
// × ~30 ops = 2820 mallocs/frees per token, this is significant overhead.
//
// Solution: a grow-only pool of DeviceBuffers. The pool returns a pointer to a buffer of at
// least the requested size. Most ops reuse the SAME buffer since they don't overlap on-stream
// (serial execution).
//
// Thread safety: not thread-safe. One pool per Runner (one thread).
#pragma once
#include "acl_common.h"
#include <algorithm>
#include <vector>

class WorkspacePool {
public:
    WorkspacePool() = default;
    ~WorkspacePool() = default;
    WorkspacePool(const WorkspacePool&) = delete;
    WorkspacePool& operator=(const WorkspacePool&) = delete;

    // Return a device pointer of at least `bytes`. Reuses the current buffer
    // if it's big enough; otherwise grows by allocating a new one and
    // **retaining old buffers** (async kernels may still be reading them —
    // freeing too early would corrupt in-flight workspaces).
    //
    // Periodically call `reset_after_sync()` when the stream is idle to
    // reclaim all-but-largest buffers and reset grow count.
    void* alloc(size_t bytes) {
        if (bytes == 0) return nullptr;
        if (current_size_ < bytes) {
            // Keep old buffer alive (don't free!) — aclnn kernels may still use it.
            old_bufs_.push_back(std::move(buf_));
            buf_.alloc(bytes);
            current_size_ = bytes;
            grow_count_++;
        }
        return buf_.get();
    }

    size_t current_size() const { return current_size_; }
    size_t grow_count() const { return grow_count_; }
    size_t retained_count() const { return old_bufs_.size(); }

    // Call only when the stream is guaranteed idle (e.g., after aclrtSynchronizeStream).
    // Drops all retained older buffers, freeing device memory. Current active buffer kept.
    void reset_after_sync() {
        old_bufs_.clear();
    }

    void clear() {
        old_bufs_.clear();
        buf_ = DeviceBuffer();
        current_size_ = 0;
        grow_count_ = 0;
    }

private:
    DeviceBuffer buf_;                       // current active (largest so far)
    std::vector<DeviceBuffer> old_bufs_;     // older, smaller — still live until stream sync
    size_t current_size_ = 0;
    size_t grow_count_ = 0;
};

// Workspace handle with a `get()` accessor, intended as a drop-in
// replacement for a local DeviceBuffer in the aclnn_ops.h wrappers.
//
// Two modes, selected by the `pool` constructor argument:
//   * pool != nullptr — the pointer comes from `pool->alloc(bytes)`; this
//     object does not touch its own `local_` buffer.
//   * pool == nullptr — fallback: a private DeviceBuffer is allocated when
//     `bytes > 0`, and the pointer refers to that buffer.
// In both modes a zero-byte request leaves `get()` returning nullptr
// (WorkspacePool::alloc returns nullptr for 0 bytes).
class PoolBuffer {
public:
    // Acquire `bytes` of workspace, from `pool` if one is given, otherwise
    // from a buffer owned by this object.
    PoolBuffer(WorkspacePool* pool, size_t bytes) {
        if (pool != nullptr) {
            // Pool mode: borrow the pool's shared buffer.
            ptr_ = pool->alloc(bytes);
            return;
        }
        if (bytes == 0) {
            // Nothing to allocate; ptr_ stays null.
            return;
        }
        // Fallback mode: allocate and expose our own buffer.
        local_.alloc(bytes);
        ptr_ = local_.get();
    }

    // Device pointer obtained at construction (may be nullptr).
    void* get() { return ptr_; }

private:
    DeviceBuffer local_;   // backing storage, used only in fallback mode
    void*        ptr_ = nullptr;
};