| // workspace_pool.h — reusable aclnn workspace buffer pool. | |
| // | |
| // Problem: every aclnn op does `aclrtMalloc(workspace)` + `aclrtFree`. For decode at 94 layers | |
| // × ~30 ops = 2820 mallocs/frees per token, this is significant overhead. | |
| // | |
| // Solution: pool of DeviceBuffers, grow-only. Pool returns a pointer to a buffer of at least the requested size. | |
| // Most ops reuse the SAME buffer since they don't overlap on-stream (serial execution). | |
| // | |
| // Thread safety: not thread-safe. One pool per Runner (one thread). | |
| class WorkspacePool { | |
| public: | |
| WorkspacePool() = default; | |
| ~WorkspacePool() = default; | |
| WorkspacePool(const WorkspacePool&) = delete; | |
| WorkspacePool& operator=(const WorkspacePool&) = delete; | |
| // Return a device pointer of at least `bytes`. Reuses the current buffer | |
| // if it's big enough; otherwise grows by allocating a new one and | |
| // **retaining old buffers** (async kernels may still be reading them — | |
| // freeing too early would corrupt in-flight workspaces). | |
| // | |
| // Periodically call `reset_after_sync()` when the stream is idle to | |
| // reclaim all-but-largest buffers and reset grow count. | |
| void* alloc(size_t bytes) { | |
| if (bytes == 0) return nullptr; | |
| if (current_size_ < bytes) { | |
| // Keep old buffer alive (don't free!) — aclnn kernels may still use it. | |
| old_bufs_.push_back(std::move(buf_)); | |
| buf_.alloc(bytes); | |
| current_size_ = bytes; | |
| grow_count_++; | |
| } | |
| return buf_.get(); | |
| } | |
| size_t current_size() const { return current_size_; } | |
| size_t grow_count() const { return grow_count_; } | |
| size_t retained_count() const { return old_bufs_.size(); } | |
| // Call only when the stream is guaranteed idle (e.g., after aclrtSynchronizeStream). | |
| // Drops all retained older buffers, freeing device memory. Current active buffer kept. | |
| void reset_after_sync() { | |
| old_bufs_.clear(); | |
| } | |
| void clear() { | |
| old_bufs_.clear(); | |
| buf_ = DeviceBuffer(); | |
| current_size_ = 0; | |
| grow_count_ = 0; | |
| } | |
| private: | |
| DeviceBuffer buf_; // current active (largest so far) | |
| std::vector<DeviceBuffer> old_bufs_; // older, smaller — still live until stream sync | |
| size_t current_size_ = 0; | |
| size_t grow_count_ = 0; | |
| }; | |
| // Convenience: per-stream RAII guard that acts like a `DeviceBuffer` but draws from pool. | |
| // Used in aclnn_ops.h wrappers as a drop-in replacement for the local DeviceBuffer. | |
| class PoolBuffer { | |
| public: | |
| // Fallback mode: if pool is nullptr, allocate own buffer (current behavior). | |
| // Pool mode: return pool's shared pointer. | |
| PoolBuffer(WorkspacePool* pool, size_t bytes) { | |
| if (pool) { | |
| ptr_ = pool->alloc(bytes); | |
| } else if (bytes > 0) { | |
| local_.alloc(bytes); | |
| ptr_ = local_.get(); | |
| } | |
| } | |
| void* get() { return ptr_; } | |
| private: | |
| DeviceBuffer local_; // only used when pool is null | |
| void* ptr_ = nullptr; | |
| }; | |