koichi12 committed on
Commit db5dd97 · verified
1 Parent(s): 9521f07

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/vllm/core/block_manager.py +520 -0
  4. .venv/lib/python3.11/site-packages/vllm/core/evictor.py +156 -0
  5. .venv/lib/python3.11/site-packages/vllm/device_allocator/__init__.py +0 -0
  6. .venv/lib/python3.11/site-packages/vllm/device_allocator/__pycache__/__init__.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/vllm/device_allocator/cumem.py +256 -0
  8. .venv/lib/python3.11/site-packages/vllm/distributed/__init__.py +5 -0
  9. .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/__init__.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/communication_op.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/parallel_state.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/utils.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/vllm/distributed/communication_op.py +34 -0
  14. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__init__.py +0 -0
  15. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/cuda_wrapper.py +173 -0
  18. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce.py +305 -0
  19. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce_utils.py +257 -0
  20. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/hpu_communicator.py +50 -0
  21. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py +217 -0
  22. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/shm_broadcast.py +530 -0
  23. .venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/xpu_communicator.py +49 -0
  24. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__init__.py +0 -0
  25. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/__init__.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/kv_transfer_agent.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  28. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/__init__.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/simple_connector.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/base.py +123 -0
  33. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/factory.py +50 -0
  34. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +314 -0
  35. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  36. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/__init__.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/base.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/simple_buffer.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +109 -0
  40. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +243 -0
  41. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  42. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/__init__.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/mooncake_pipe.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/pynccl_pipe.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
  47. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +274 -0
  48. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +277 -0
  49. .venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_transfer_agent.py +76 -0
  50. .venv/lib/python3.11/site-packages/vllm/distributed/parallel_state.py +1285 -0
.gitattributes CHANGED
@@ -199,3 +199,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
199
  .venv/lib/python3.11/site-packages/google/_upb/_message.abi3.so filter=lfs diff=lfs merge=lfs -text
200
  .venv/lib/python3.11/site-packages/google/protobuf/__pycache__/descriptor_pb2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
201
  .venv/lib/python3.11/site-packages/jinja2/__pycache__/compiler.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
202
+ .venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/grpc/__pycache__/_channel.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4a55c7e16388e486f345b7c775a758b4a05a398d378d7491610665c89805e0f
3
+ size 103674
.venv/lib/python3.11/site-packages/vllm/core/block_manager.py ADDED
@@ -0,0 +1,520 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """A block manager that manages token blocks."""
3
+ from typing import Dict, List, Optional
4
+ from typing import Sequence as GenericSequence
5
+ from typing import Tuple
6
+
7
+ from vllm.core.block.block_table import BlockTable
8
+ from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
9
+ from vllm.core.block.interfaces import Block
10
+ from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
11
+ LastAccessBlocksTracker)
12
+ from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
13
+ from vllm.core.interfaces import AllocStatus, BlockSpaceManager
14
+ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
15
+ from vllm.utils import Device
16
+
17
+ SeqId = int
18
+ EncoderSeqId = str
19
+
20
+
21
+ class SelfAttnBlockSpaceManager(BlockSpaceManager):
22
+ """BlockSpaceManager which manages the allocation of KV cache.
23
+
24
+ It owns responsibility for allocation, swapping, allocating memory for
25
+ autoregressively-generated tokens, and other advanced features such as
26
+ prefix caching, forking/copy-on-write, and sliding-window memory allocation.
27
+
28
+ This class implements the design described in
29
+ https://github.com/vllm-project/vllm/pull/3492.
30
+
31
+ Lookahead slots
32
+ The block manager has the notion of a "lookahead slot". These are slots
33
+ in the KV cache that are allocated for a sequence. Unlike the other
34
+ allocated slots, the content of these slots is undefined -- the worker
35
+ may use the memory allocations in any way.
36
+
37
+ In practice, a worker could use these lookahead slots to run multiple
38
+ forward passes for a single scheduler invocation. Each successive
39
+ forward pass would write KV activations to the corresponding lookahead
40
+ slot. This allows low inter-token latency use-cases, where the overhead
41
+ of continuous batching scheduling is amortized over >1 generated tokens.
42
+
43
+ Speculative decoding uses lookahead slots to store KV activations of
44
+ proposal tokens.
45
+
46
+ See https://github.com/vllm-project/vllm/pull/3250 for more information
47
+ on lookahead scheduling.
48
+
49
+ Args:
50
+ block_size (int): The size of each memory block.
51
+ num_gpu_blocks (int): The number of memory blocks allocated on GPU.
52
+ num_cpu_blocks (int): The number of memory blocks allocated on CPU.
53
+ watermark (float, optional): The threshold used for memory swapping.
54
+ Defaults to 0.01.
55
+ sliding_window (Optional[int], optional): The size of the sliding
56
+ window. Defaults to None.
57
+ enable_caching (bool, optional): Flag indicating whether caching is
58
+ enabled. Defaults to False.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ block_size: int,
64
+ num_gpu_blocks: int,
65
+ num_cpu_blocks: int,
66
+ watermark: float = 0.01,
67
+ sliding_window: Optional[int] = None,
68
+ enable_caching: bool = False,
69
+ ) -> None:
70
+ self.block_size = block_size
71
+ self.num_total_gpu_blocks = num_gpu_blocks
72
+ self.num_total_cpu_blocks = num_cpu_blocks
73
+
74
+ self.sliding_window = sliding_window
75
+ # max_block_sliding_window is the max number of blocks that need to be
76
+ # allocated
77
+ self.max_block_sliding_window = None
78
+ if sliding_window is not None:
79
+ # +1 here because // rounds down
80
+ num_blocks = sliding_window // block_size + 1
81
+ # +1 here because the last block may not be full,
82
+ # and so the sequence stretches one more block at the beginning
83
+ # For example, if sliding_window is 3 and block_size is 4,
84
+ # we may need 2 blocks when the second block only holds 1 token.
85
+ self.max_block_sliding_window = num_blocks + 1
86
+
87
+ self.watermark = watermark
88
+ assert watermark >= 0.0
89
+
90
+ self.enable_caching = enable_caching
91
+
92
+ self.watermark_blocks = int(watermark * num_gpu_blocks)
93
+
94
+ self.block_allocator = CpuGpuBlockAllocator.create(
95
+ allocator_type="prefix_caching" if enable_caching else "naive",
96
+ num_gpu_blocks=num_gpu_blocks,
97
+ num_cpu_blocks=num_cpu_blocks,
98
+ block_size=block_size,
99
+ )
100
+
101
+ self.block_tables: Dict[SeqId, BlockTable] = {}
102
+ self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
103
+
104
+ self._computed_blocks_tracker = ComputedBlocksTracker(
105
+ self.block_allocator, self.block_size, self.enable_caching)
106
+ self._last_access_blocks_tracker = LastAccessBlocksTracker(
107
+ self.block_allocator)
108
+
109
+ def can_allocate(self,
110
+ seq_group: SequenceGroup,
111
+ num_lookahead_slots: int = 0) -> AllocStatus:
112
+ # FIXME(woosuk): Here we assume that all sequences in the group share
113
+ # the same prompt. This may not be true for preempted sequences.
114
+
115
+ check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
116
+
117
+ seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
118
+ num_required_blocks = BlockTable.get_num_required_blocks(
119
+ seq.get_token_ids(),
120
+ block_size=self.block_size,
121
+ num_lookahead_slots=num_lookahead_slots,
122
+ )
123
+
124
+ if seq_group.is_encoder_decoder():
125
+ encoder_seq = seq_group.get_encoder_seq()
126
+ assert encoder_seq is not None
127
+ num_required_blocks += BlockTable.get_num_required_blocks(
128
+ encoder_seq.get_token_ids(),
129
+ block_size=self.block_size,
130
+ )
131
+
132
+ if self.max_block_sliding_window is not None:
133
+ num_required_blocks = min(num_required_blocks,
134
+ self.max_block_sliding_window)
135
+
136
+ num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
137
+ device=Device.GPU)
138
+
139
+ # Use watermark to avoid frequent cache eviction.
140
+ if (self.num_total_gpu_blocks - num_required_blocks
141
+ < self.watermark_blocks):
142
+ return AllocStatus.NEVER
143
+ if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
144
+ return AllocStatus.OK
145
+ else:
146
+ return AllocStatus.LATER
147
+
148
+ def _allocate_sequence(self, seq: Sequence) -> BlockTable:
149
+ block_table = BlockTable(
150
+ block_size=self.block_size,
151
+ block_allocator=self.block_allocator,
152
+ max_block_sliding_window=self.max_block_sliding_window,
153
+ )
154
+ if seq.get_token_ids():
155
+ # NOTE: If there are any factors affecting the block besides
156
+ # token_ids, they should be added as input to extra_hash.
157
+ extra_hash = seq.extra_hash()
158
+
159
+ # Add blocks to the block table only if the sequence is non empty.
160
+ block_table.allocate(token_ids=seq.get_token_ids(),
161
+ extra_hash=extra_hash)
162
+
163
+ return block_table
164
+
165
+ def allocate(self, seq_group: SequenceGroup) -> None:
166
+
167
+ # Allocate self-attention block tables for decoder sequences
168
+ waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
169
+ assert not (set(seq.seq_id for seq in waiting_seqs)
170
+ & self.block_tables.keys()), "block table already exists"
171
+
172
+ # NOTE: Here we assume that all sequences in the group have the same
173
+ # prompt.
174
+ seq = waiting_seqs[0]
175
+ block_table: BlockTable = self._allocate_sequence(seq)
176
+ self.block_tables[seq.seq_id] = block_table
177
+
178
+ # Track seq
179
+ self._last_access_blocks_tracker.add_seq(seq.seq_id)
180
+
181
+ # Assign the block table for each sequence.
182
+ for seq in waiting_seqs[1:]:
183
+ self.block_tables[seq.seq_id] = block_table.fork()
184
+
185
+ # Track seq
186
+ self._last_access_blocks_tracker.add_seq(seq.seq_id)
187
+
188
+ # Allocate cross-attention block table for encoder sequence
189
+ #
190
+ # NOTE: Here we assume that all sequences in the group have the same
191
+ # encoder prompt.
192
+ request_id = seq_group.request_id
193
+
194
+ assert (request_id
195
+ not in self.cross_block_tables), \
196
+ "block table already exists"
197
+
198
+ check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
199
+
200
+ if seq_group.is_encoder_decoder():
201
+ encoder_seq = seq_group.get_encoder_seq()
202
+ assert encoder_seq is not None
203
+ block_table = self._allocate_sequence(encoder_seq)
204
+ self.cross_block_tables[request_id] = block_table
205
+
206
+ def can_append_slots(self, seq_group: SequenceGroup,
207
+ num_lookahead_slots: int) -> bool:
208
+ """Determine if there is enough space in the GPU KV cache to continue
209
+ generation of the specified sequence group.
210
+
211
+ We use a worst-case heuristic: assume each touched block will require a
212
+ new allocation (either via CoW or new block). We can append slots if the
213
+ number of touched blocks does not exceed the number of free blocks.
214
+
215
+ "Lookahead slots" are slots that are allocated in addition to the slots
216
+ for known tokens. The contents of the lookahead slots are not defined.
217
+ This is used by speculative decoding when speculating future tokens.
218
+ """
219
+
220
+ num_touched_blocks = 0
221
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
222
+ block_table = self.block_tables[seq.seq_id]
223
+
224
+ num_touched_blocks += (
225
+ block_table.get_num_blocks_touched_by_append_slots(
226
+ token_ids=block_table.get_unseen_token_ids(
227
+ seq.get_token_ids()),
228
+ num_lookahead_slots=num_lookahead_slots,
229
+ ))
230
+
231
+ num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
232
+ Device.GPU)
233
+ return num_touched_blocks <= num_free_gpu_blocks
234
+
235
+ def append_slots(
236
+ self,
237
+ seq: Sequence,
238
+ num_lookahead_slots: int,
239
+ ) -> List[Tuple[int, int]]:
240
+
241
+ block_table = self.block_tables[seq.seq_id]
242
+
243
+ block_table.append_token_ids(
244
+ token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
245
+ num_lookahead_slots=num_lookahead_slots,
246
+ num_computed_slots=seq.data.get_num_computed_tokens(),
247
+ extra_hash=seq.extra_hash(),
248
+ )
249
+ # Return any new copy-on-writes.
250
+ new_cows = self.block_allocator.clear_copy_on_writes()
251
+ return new_cows
252
+
253
+ def free(self, seq: Sequence) -> None:
254
+ seq_id = seq.seq_id
255
+
256
+ if seq_id not in self.block_tables:
257
+ # Already freed or haven't been scheduled yet.
258
+ return
259
+
260
+ # Update seq block ids with the latest access time
261
+ self._last_access_blocks_tracker.update_seq_blocks_last_access(
262
+ seq_id, self.block_tables[seq.seq_id].physical_block_ids)
263
+
264
+ # Untrack seq
265
+ self._last_access_blocks_tracker.remove_seq(seq_id)
266
+ self._computed_blocks_tracker.remove_seq(seq_id)
267
+
268
+ # Free table/blocks
269
+ self.block_tables[seq_id].free()
270
+ del self.block_tables[seq_id]
271
+
272
+ def free_cross(self, seq_group: SequenceGroup) -> None:
273
+ request_id = seq_group.request_id
274
+ if request_id not in self.cross_block_tables:
275
+ # Already freed or hasn't been scheduled yet.
276
+ return
277
+ self.cross_block_tables[request_id].free()
278
+ del self.cross_block_tables[request_id]
279
+
280
+ def get_block_table(self, seq: Sequence) -> List[int]:
281
+ block_ids = self.block_tables[seq.seq_id].physical_block_ids
282
+ return block_ids # type: ignore
283
+
284
+ def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
285
+ request_id = seq_group.request_id
286
+ assert request_id in self.cross_block_tables
287
+ block_ids = self.cross_block_tables[request_id].physical_block_ids
288
+ assert all(b is not None for b in block_ids)
289
+ return block_ids # type: ignore
290
+
291
+ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
292
+ if self.enable_caching:
293
+ # Record the latest access time for the sequence. The actual update
294
+ # of the block ids is deferred to the sequence free(..) call, since
295
+ # only during freeing of block ids, the blocks are actually added to
296
+ # the evictor (which is when the most updated time is required)
297
+ # (This avoids expensive calls to mark_blocks_as_accessed(..))
298
+ self._last_access_blocks_tracker.update_last_access(
299
+ seq.seq_id, now)
300
+
301
+ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
302
+ token_chunk_size: int):
303
+ # If prefix caching is enabled, mark immutable blocks as computed
304
+ # right after they have been scheduled (for prefill). This assumes
305
+ # the scheduler is synchronous so blocks are actually computed when
306
+ # scheduling the next batch.
307
+ self.block_allocator.mark_blocks_as_computed([])
308
+
309
+ def get_common_computed_block_ids(
310
+ self, seqs: List[Sequence]) -> GenericSequence[int]:
311
+ """Determine the blocks for which we can skip prefill.
312
+
313
+ With prefix caching we can skip prefill for previously-generated blocks.
314
+ Currently, the attention implementation only supports skipping cached
315
+ blocks if they are a contiguous prefix of cached blocks.
316
+
317
+ This method determines which blocks can be safely skipped for all
318
+ sequences in the sequence group.
319
+ """
320
+ computed_seq_block_ids = []
321
+ for seq in seqs:
322
+ all_blocks = self.block_tables[seq.seq_id].physical_block_ids
323
+ num_cached_tokens = (
324
+ self._computed_blocks_tracker.get_num_cached_tokens(seq))
325
+ assert num_cached_tokens % self.block_size == 0
326
+ num_cached_blocks = num_cached_tokens // self.block_size
327
+ computed_block_ids = all_blocks[:num_cached_blocks]
328
+ computed_seq_block_ids.append(computed_block_ids)
329
+
330
+ # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
331
+ return self.block_allocator.get_common_computed_block_ids(
332
+ computed_seq_block_ids) # type: ignore
333
+
334
+ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
335
+ if parent_seq.seq_id not in self.block_tables:
336
+ # Parent sequence has either been freed or never existed.
337
+ return
338
+ src_block_table = self.block_tables[parent_seq.seq_id]
339
+ self.block_tables[child_seq.seq_id] = src_block_table.fork()
340
+
341
+ # Track child seq
342
+ self._last_access_blocks_tracker.add_seq(child_seq.seq_id)
343
+
344
+ def can_swap_in(self, seq_group: SequenceGroup,
345
+ num_lookahead_slots: int) -> AllocStatus:
346
+ """Returns the AllocStatus for the given sequence_group
347
+ with num_lookahead_slots.
348
+
349
+ Args:
350
+ sequence_group (SequenceGroup): The sequence group to swap in.
351
+ num_lookahead_slots (int): Number of lookahead slots used in
352
+ speculative decoding, default to 0.
353
+
354
+ Returns:
355
+ AllocStatus: The AllocStatus for the given sequence group.
356
+ """
357
+ return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
358
+ num_lookahead_slots)
359
+
360
+ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
361
+ """Returns the block id mapping (from CPU to GPU) generated by
362
+ swapping in the given seq_group with num_lookahead_slots.
363
+
364
+ Args:
365
+ seq_group (SequenceGroup): The sequence group to swap in.
366
+
367
+ Returns:
368
+ List[Tuple[int, int]]: The mapping of swapping block from CPU
369
+ to GPU.
370
+ """
371
+ physical_block_id_mapping = []
372
+ for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
373
+ blocks = self.block_tables[seq.seq_id].blocks
374
+ if len(blocks) == 0:
375
+ continue
376
+
377
+ seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
378
+ src_device=Device.CPU,
379
+ dst_device=Device.GPU)
380
+
381
+ # Refresh the block ids of the table (post-swap)
382
+ self.block_tables[seq.seq_id].update(blocks)
383
+
384
+ seq_physical_block_id_mapping = {
385
+ self.block_allocator.get_physical_block_id(
386
+ Device.CPU, cpu_block_id):
387
+ self.block_allocator.get_physical_block_id(
388
+ Device.GPU, gpu_block_id)
389
+ for cpu_block_id, gpu_block_id in seq_swap_mapping.items()
390
+ }
391
+
392
+ physical_block_id_mapping.extend(
393
+ list(seq_physical_block_id_mapping.items()))
394
+
395
+ return physical_block_id_mapping
396
+
397
+ def can_swap_out(self, seq_group: SequenceGroup) -> bool:
398
+ """Returns whether we can swap out the given sequence_group
399
+ with num_lookahead_slots.
400
+
401
+ Args:
402
+ seq_group (SequenceGroup): The sequence group to swap out.
403
+ num_lookahead_slots (int): Number of lookahead slots used in
404
+ speculative decoding, default to 0.
405
+
406
+ Returns:
407
+ bool: Whether it's possible to swap out current sequence group.
408
+ """
409
+ alloc_status = self._can_swap(seq_group, Device.CPU,
410
+ SequenceStatus.RUNNING)
411
+ return alloc_status == AllocStatus.OK
412
+
413
+ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
414
+ """Returns the block id mapping (from GPU to CPU) generated by
415
+ swapping out the given sequence_group with num_lookahead_slots.
416
+
417
+ Args:
418
+ sequence_group (SequenceGroup): The sequence group to swap out.
419
+
420
+ Returns:
421
+ List[Tuple[int, int]]: The mapping of swapping block from
422
+ GPU to CPU.
423
+ """
424
+ physical_block_id_mapping = []
425
+ for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
426
+ blocks = self.block_tables[seq.seq_id].blocks
427
+ if len(blocks) == 0:
428
+ continue
429
+
430
+ seq_swap_mapping = self.block_allocator.swap(blocks=blocks,
431
+ src_device=Device.GPU,
432
+ dst_device=Device.CPU)
433
+
434
+ # Refresh the block ids of the table (post-swap)
435
+ self.block_tables[seq.seq_id].update(blocks)
436
+
437
+ seq_physical_block_id_mapping = {
438
+ self.block_allocator.get_physical_block_id(
439
+ Device.GPU, gpu_block_id):
440
+ self.block_allocator.get_physical_block_id(
441
+ Device.CPU, cpu_block_id)
442
+ for gpu_block_id, cpu_block_id in seq_swap_mapping.items()
443
+ }
444
+
445
+ physical_block_id_mapping.extend(
446
+ list(seq_physical_block_id_mapping.items()))
447
+
448
+ return physical_block_id_mapping
449
+
450
+ def get_num_free_gpu_blocks(self) -> int:
451
+ return self.block_allocator.get_num_free_blocks(Device.GPU)
452
+
453
+ def get_num_free_cpu_blocks(self) -> int:
454
+ return self.block_allocator.get_num_free_blocks(Device.CPU)
455
+
456
+ def get_prefix_cache_hit_rate(self, device: Device) -> float:
457
+ return self.block_allocator.get_prefix_cache_hit_rate(device)
458
+
459
+ def reset_prefix_cache(self) -> bool:
460
+ return self.block_allocator.reset_prefix_cache()
461
+
462
+ def _can_swap(self,
463
+ seq_group: SequenceGroup,
464
+ device: Device,
465
+ status: SequenceStatus,
466
+ num_lookahead_slots: int = 0) -> AllocStatus:
467
+ """Returns the AllocStatus for swapping in/out the given sequence_group
468
+ on to the 'device'.
469
+
470
+ Args:
471
+ sequence_group (SequenceGroup): The sequence group to swap in/out.
472
+ device (Device): device to swap the 'seq_group' on.
473
+ status (SequenceStatus): The status of sequence which is needed
474
+ for action. RUNNING for swap out and SWAPPED for swap in
475
+ num_lookahead_slots (int): Number of lookahead slots used in
476
+ speculative decoding, default to 0.
477
+
478
+ Returns:
479
+ AllocStatus: The AllocStatus for swapping in/out the given
480
+ sequence_group on to the 'device'.
481
+ """
482
+ # First determine the number of blocks that will be touched by this
483
+ # swap. Then verify if there are available blocks in the device
484
+ # to perform the swap.
485
+ num_blocks_touched = 0
486
+ blocks: List[Block] = []
487
+ for seq in seq_group.get_seqs(status=status):
488
+ block_table = self.block_tables[seq.seq_id]
489
+ if block_table.blocks is not None:
490
+ # Compute the number of blocks to touch for the tokens to be
491
+ # appended. This does NOT include the full blocks that need
492
+ # to be touched for the swap.
493
+ num_blocks_touched += \
494
+ block_table.get_num_blocks_touched_by_append_slots(
495
+ block_table.get_unseen_token_ids(seq.get_token_ids()),
496
+ num_lookahead_slots=num_lookahead_slots)
497
+ blocks.extend(block_table.blocks)
498
+ # Compute the number of full blocks to touch and add it to the
499
+ # existing count of blocks to touch.
500
+ num_blocks_touched += self.block_allocator.get_num_full_blocks_touched(
501
+ blocks, device=device)
502
+
503
+ watermark_blocks = 0
504
+ if device == Device.GPU:
505
+ watermark_blocks = self.watermark_blocks
506
+
507
+ if self.block_allocator.get_num_total_blocks(
508
+ device) < num_blocks_touched:
509
+ return AllocStatus.NEVER
510
+ elif self.block_allocator.get_num_free_blocks(
511
+ device) - num_blocks_touched >= watermark_blocks:
512
+ return AllocStatus.OK
513
+ else:
514
+ return AllocStatus.LATER
515
+
516
+ def get_num_cached_tokens(self, seq: Sequence) -> int:
517
+ """Get the number of tokens in blocks that are already computed and
518
+ cached in the block manager for the sequence.
519
+ """
520
+ return self._computed_blocks_tracker.get_num_cached_tokens(seq)
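Note: the watermark check in can_allocate above decides between NEVER, OK, and LATER. Below is a minimal standalone sketch of that logic; the helper name admission_status is hypothetical and not part of vLLM.

def admission_status(num_required_blocks: int,
                     num_free_gpu_blocks: int,
                     num_total_gpu_blocks: int,
                     watermark: float = 0.01) -> str:
    # Hypothetical helper mirroring SelfAttnBlockSpaceManager.can_allocate.
    watermark_blocks = int(watermark * num_total_gpu_blocks)
    if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
        return "NEVER"   # the request can never fit, even with an empty cache
    if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
        return "OK"      # enough free blocks remain above the watermark
    return "LATER"       # retry once other sequences free their blocks

# Example: 1000 total blocks (watermark = 10 blocks), 30 free, request needs 25.
print(admission_status(25, 30, 1000))  # -> "LATER"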
.venv/lib/python3.11/site-packages/vllm/core/evictor.py ADDED
@@ -0,0 +1,156 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import enum
4
+ import heapq
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, List, Tuple
7
+
8
+
9
+ class EvictionPolicy(enum.Enum):
10
+ """Enum for eviction policy used by make_evictor to instantiate the correct
11
+ Evictor subclass.
12
+ """
13
+ LRU = enum.auto()
14
+
15
+
16
+ class Evictor(ABC):
17
+ """The Evictor subclasses should be used by the BlockAllocator class to
18
+ handle eviction of freed Blocks.
19
+ """
20
+
21
+ @abstractmethod
22
+ def __init__(self):
23
+ pass
24
+
25
+ @abstractmethod
26
+ def __contains__(self, block_id: int) -> bool:
27
+ pass
28
+
29
+ @abstractmethod
30
+ def evict(self) -> Tuple[int, int]:
31
+ """Runs the eviction algorithm and returns the evicted block's
32
+ content hash along with its physical block id
33
+ """
34
+ pass
35
+
36
+ @abstractmethod
37
+ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
38
+ last_accessed: float):
39
+ """Adds block to the evictor, making it a candidate for eviction"""
40
+ pass
41
+
42
+ @abstractmethod
43
+ def update(self, block_id: int, last_accessed: float):
44
+ """Update corresponding block's access time in metadata"""
45
+ pass
46
+
47
+ @abstractmethod
48
+ def remove(self, block_id: int):
49
+ """Remove a given block id from the cache."""
50
+ pass
51
+
52
+ @property
53
+ @abstractmethod
54
+ def num_blocks(self) -> int:
55
+ pass
56
+
57
+
58
+ class BlockMetaData:
59
+ """Data structure for storing key data describe cached block, so that
60
+ """Data structure for storing key data describing a cached block, so that the
61
+
62
+ Here we use physical block id as the dict key, as there maybe several
63
+ blocks with the same content hash, but their physical id is unique.
64
+ """
65
+
66
+ def __init__(self, content_hash: int, num_hashed_tokens: int,
67
+ last_accessed: float):
68
+ self.content_hash = content_hash
69
+ self.num_hashed_tokens = num_hashed_tokens
70
+ self.last_accessed = last_accessed
71
+
72
+
73
+ class LRUEvictor(Evictor):
74
+ """Evicts in a least-recently-used order using the last_accessed timestamp
75
+ that's recorded in the Block. If there are multiple blocks with
76
+ the same last_accessed time, then the one with the largest num_hashed_tokens
77
+ will be evicted. If two blocks each have the lowest last_accessed time and
78
+ highest num_hashed_tokens value, then one will be chosen arbitrarily
79
+ """
80
+
81
+ # CLEANUP_THRESHOLD determines the maximum allowable size of the priority
82
+ # queue relative to the free table size. When this threshold is exceeded,
83
+ # a cleanup operation is triggered to reduce memory usage.
84
+ CLEANUP_THRESHOLD = 50
85
+
86
+ def __init__(self):
87
+ self.free_table: Dict[int, BlockMetaData] = {}
88
+ self.priority_queue = []
89
+
90
+ def __contains__(self, block_id: int) -> bool:
91
+ return block_id in self.free_table
92
+
93
+ def evict(self) -> Tuple[int, int]:
94
+ if len(self.free_table) == 0:
95
+ raise ValueError("No usable cache memory left")
96
+
97
+ while self.priority_queue:
98
+ # We do not remove outdated entries from the priority queue at the
99
+ # time of updating the last_accessed timestamp. Instead, outdated
100
+ # entries are filtered out here during eviction. Outdated entries
101
+ # would either not be in the free table, or have an older last-accessed
102
+ # time.
103
+ last_accessed, _, block_id, content_hash = heapq.heappop(
104
+ self.priority_queue)
105
+ if (block_id in self.free_table and
106
+ self.free_table[block_id].last_accessed == last_accessed):
107
+ self.free_table.pop(block_id)
108
+ return block_id, content_hash
109
+
110
+ raise ValueError("No usable cache memory left")
111
+
112
+ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
113
+ last_accessed: float):
114
+ self.free_table[block_id] = BlockMetaData(content_hash,
115
+ num_hashed_tokens,
116
+ last_accessed)
117
+ heapq.heappush(
118
+ self.priority_queue,
119
+ (last_accessed, -num_hashed_tokens, block_id, content_hash))
120
+ self._cleanup_if_necessary()
121
+
122
+ def update(self, block_id: int, last_accessed: float):
123
+ self.free_table[block_id].last_accessed = last_accessed
124
+
125
+ def _cleanup_if_necessary(self):
126
+ if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len(
127
+ self.free_table):
128
+ self._cleanup()
129
+
130
+ def _cleanup(self):
131
+ new_priority_queue: List[Tuple[float, int, int, int]] = []
132
+
133
+ for block_id, block in self.free_table.items():
134
+ new_priority_queue.append(
135
+ (block.last_accessed, -block.num_hashed_tokens, block_id,
136
+ block.content_hash))
137
+ heapq.heapify(new_priority_queue)
138
+
139
+ self.priority_queue = new_priority_queue
140
+
141
+ def remove(self, block_id: int):
142
+ if block_id not in self.free_table:
143
+ raise ValueError(
144
+ "Attempting to remove block that's not in the evictor")
145
+ self.free_table.pop(block_id)
146
+
147
+ @property
148
+ def num_blocks(self) -> int:
149
+ return len(self.free_table)
150
+
151
+
152
+ def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
153
+ if eviction_policy == EvictionPolicy.LRU:
154
+ return LRUEvictor()
155
+ else:
156
+ raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
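Note: a minimal usage sketch of the LRUEvictor defined above, assuming the module is importable as vllm.core.evictor; ties on last_accessed are broken by the larger num_hashed_tokens.

from vllm.core.evictor import EvictionPolicy, make_evictor

evictor = make_evictor(EvictionPolicy.LRU)
# add(block_id, content_hash, num_hashed_tokens, last_accessed)
evictor.add(0, 111, 16, last_accessed=1.0)
evictor.add(1, 222, 32, last_accessed=1.0)
evictor.add(2, 333, 16, last_accessed=2.0)

# Oldest last_accessed wins; the 1.0 tie is broken by the larger num_hashed_tokens.
print(evictor.evict())     # -> (1, 222)
print(evictor.evict())     # -> (0, 111)
print(evictor.num_blocks)  # -> 1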
.venv/lib/python3.11/site-packages/vllm/device_allocator/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/device_allocator/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes).
 
.venv/lib/python3.11/site-packages/vllm/device_allocator/cumem.py ADDED
@@ -0,0 +1,256 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # cumem-based pytorch pluggable allocator to implement sleep mode.
4
+ # other approaches tried but failed:
5
+ # - cuda-python package binding
6
+ # - custom libcuda driver ctypes wrapper
7
+ # both of them failed because of cuda context mismatch.
8
+ # not sure why, they are created from a different context.
9
+ # the only successful approach is to call cuda driver API in C.
10
+ import dataclasses
11
+ from contextlib import contextmanager
12
+ from typing import Callable, Dict, Optional, Tuple, Union
13
+
14
+ import torch
15
+
16
+ from vllm.utils import is_pin_memory_available
17
+
18
+
19
+ def find_loaded_library(lib_name) -> Optional[str]:
20
+ """
21
+ According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
22
+ the file `/proc/self/maps` contains the memory maps of the process, which includes the
23
+ shared libraries loaded by the process. We can use this file to find the path of the
24
+ a loaded library.
25
+ """ # noqa
26
+ found_line = None
27
+ with open("/proc/self/maps") as f:
28
+ for line in f:
29
+ if lib_name in line:
30
+ found_line = line
31
+ break
32
+ if found_line is None:
33
+ # the library is not loaded in the current process
34
+ return None
35
+ # if lib_name is libcudart, we need to match a line with:
36
+ # address /path/to/libcudart-hash.so.11.0
37
+ start = found_line.index("/")
38
+ path = found_line[start:].strip()
39
+ filename = path.split("/")[-1]
40
+ assert filename.rpartition(".so")[0].startswith(lib_name), \
41
+ f"Unexpected filename: {filename} for library {lib_name}"
42
+ return path
43
+
44
+
45
+ cumem_available = False
46
+ try:
47
+ from vllm.cumem_allocator import (init_module, python_create_and_map,
48
+ python_unmap_and_release)
49
+ from vllm.distributed.device_communicators.cuda_wrapper import (
50
+ CudaRTLibrary)
51
+ lib_name = find_loaded_library("cumem_allocator")
52
+ libcudart = CudaRTLibrary()
53
+ cumem_available = True
54
+ except ModuleNotFoundError:
55
+ # rocm platform does not support cumem allocator
56
+ init_module = None
57
+ python_create_and_map = None
58
+ python_unmap_and_release = None
59
+ CudaRTLibrary = None
60
+ lib_name = None
61
+ libcudart = None
62
+
63
+ # py_device, py_alignedSize, py_d_mem, py_p_memHandle
64
+ HandleType = Tuple[int, int, int, int]
65
+
66
+
67
+ @dataclasses.dataclass
68
+ class AllocationData:
69
+ handle: HandleType
70
+ tag: str
71
+ cpu_backup_tensor: Optional[torch.Tensor] = None
72
+
73
+
74
+ def create_and_map(allocation_handle: HandleType) -> None:
75
+ python_create_and_map(*allocation_handle)
76
+
77
+
78
+ def unmap_and_release(allocation_handle: HandleType) -> None:
79
+ python_unmap_and_release(*allocation_handle)
80
+
81
+
82
+ def get_pluggable_allocator(
83
+ python_malloc_fn: Callable[[int],
84
+ int], python_free_func: Callable[[int, int],
85
+ None]
86
+ ) -> torch.cuda.memory.CUDAPluggableAllocator:
87
+ init_module(python_malloc_fn, python_free_func)
88
+ new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
89
+ lib_name, 'my_malloc', 'my_free')
90
+ return new_alloc
91
+
92
+
93
+ @contextmanager
94
+ def use_memory_pool_with_allocator(
95
+ python_malloc_fn: Callable[[int], int],
96
+ python_free_func: Callable[[int, int], None]) -> None:
97
+ new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
98
+ mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
99
+ with torch.cuda.memory.use_mem_pool(mem_pool):
100
+ yield mem_pool
101
+
102
+
103
+ class CuMemAllocator:
104
+ """
105
+ A singleton class that manages a memory pool for CUDA tensors.
106
+ The memory in this pool can be offloaded or discarded when the
107
+ allocator sleeps.
108
+
109
+ Inside the `use_memory_pool(tag)` context, all tensors created will
110
+ be allocated in the memory pool, and has the same tag as the
111
+ tag passed to the context.
112
+
113
+ When we call `sleep`, all tensors with the specified tag will be
114
+ offloaded to CPU memory, and the rest of the tensors will be discarded.
115
+ When we call `wake_up`, all tensors that were previously offloaded
116
+ will be loaded back to GPU memory, and the rest of the tensors will
117
+ have empty memory.
118
+
119
+ Why does it need to be a singleton?
120
+ When allocated tensors are garbage collected, PyTorch will call
121
+ the free callback, which will call the `python_free_callback` method.
122
+ The C-extension uses a global variable to store the function of an
123
+ instance of this class. If we create multiple instances of this class,
124
+ the global variable will be overwritten and the free callback will
125
+ not work as expected.
126
+ """
127
+ instance: "CuMemAllocator" = None
128
+ default_tag: str = "default"
129
+
130
+ @staticmethod
131
+ def get_instance() -> "CuMemAllocator":
132
+ """
133
+ CuMemAllocator is a singleton class.
134
+ We cannot call the constructor directly.
135
+ Call this method to get the instance.
136
+ """
137
+ assert cumem_available, "cumem allocator is not available"
138
+ if CuMemAllocator.instance is None:
139
+ CuMemAllocator.instance = CuMemAllocator()
140
+ return CuMemAllocator.instance
141
+
142
+ def __init__(self):
143
+ self.pointer_to_data: Dict[int, AllocationData] = {}
144
+ self.current_tag: str = CuMemAllocator.default_tag
145
+
146
+ def python_malloc_callback(self, allocation_handle: HandleType) -> None:
147
+ """
148
+ Internal method to store the allocation data
149
+ when memory is allocated in the memory pool."""
150
+ py_d_mem = allocation_handle[2]
151
+ self.pointer_to_data[py_d_mem] = AllocationData(
152
+ allocation_handle, self.current_tag)
153
+ return
154
+
155
+ def python_free_callback(self, ptr: int) -> HandleType:
156
+ """
157
+ Internal method to look up the allocation data
158
+ when memory is freed in the memory pool."""
159
+ data = self.pointer_to_data.pop(ptr)
160
+ if data.cpu_backup_tensor is not None:
161
+ data.cpu_backup_tensor = None
162
+ return data.handle
163
+
164
+ def sleep(
165
+ self,
166
+ offload_tags: Optional[Union[Tuple[str, ...],
167
+ str]] = None) -> None:
168
+ """
169
+ Put the allocator in sleep mode.
170
+ All data in the memory allocation with the specified tag will be
171
+ offloaded to CPU memory, and others will be discarded.
172
+
173
+ :param offload_tags: The tags of the memory allocation that will be
174
+ offloaded. The rest of the memory allocation will be discarded.
175
+ """
176
+ if offload_tags is None:
177
+ # by default, allocated tensors are offloaded
178
+ # when the allocator sleeps
179
+ offload_tags = (CuMemAllocator.default_tag, )
180
+ elif isinstance(offload_tags, str):
181
+ offload_tags = (offload_tags, )
182
+
183
+ assert isinstance(offload_tags, tuple)
184
+
185
+ for ptr, data in self.pointer_to_data.items():
186
+ handle = data.handle
187
+ if data.tag in offload_tags:
188
+ size_in_bytes = handle[1]
189
+ cpu_backup_tensor = torch.empty(
190
+ size_in_bytes,
191
+ dtype=torch.uint8,
192
+ device='cpu',
193
+ pin_memory=is_pin_memory_available())
194
+ cpu_ptr = cpu_backup_tensor.data_ptr()
195
+ libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
196
+ data.cpu_backup_tensor = cpu_backup_tensor
197
+ unmap_and_release(handle)
198
+
199
+ def wake_up(self):
200
+ """
201
+ Wake up the allocator from sleep mode.
202
+ All data that is previously offloaded will be loaded back to GPU
203
+ memory, and the rest of the data will have empty memory."""
204
+ for ptr, data in self.pointer_to_data.items():
205
+ handle = data.handle
206
+ create_and_map(handle)
207
+ if data.cpu_backup_tensor is not None:
208
+ cpu_backup_tensor = data.cpu_backup_tensor
209
+ if cpu_backup_tensor is not None:
210
+ size_in_bytes = cpu_backup_tensor.numel(
211
+ ) * cpu_backup_tensor.element_size()
212
+ cpu_ptr = cpu_backup_tensor.data_ptr()
213
+ libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
214
+ data.cpu_backup_tensor = None
215
+
216
+ @contextmanager
217
+ def use_memory_pool(self, tag: Optional[str] = None):
218
+ """
219
+ A context manager to use the memory pool.
220
+ All memory allocation created inside the context will be allocated
221
+ in the memory pool, and has the specified tag.
222
+
223
+ :param tag: The tag of the memory allocation. If None, the default tag
224
+ will be used.
225
+ """
226
+ if tag is None:
227
+ tag = CuMemAllocator.default_tag
228
+
229
+ assert isinstance(tag, str)
230
+
231
+ old_tag = self.current_tag
232
+ self.current_tag = tag
233
+ with use_memory_pool_with_allocator(self.python_malloc_callback,
234
+ self.python_free_callback):
235
+ yield
236
+ # Due to a PyTorch bug, calling torch.cuda.empty_cache() will error
237
+ # when using a pluggable allocator; see
238
+ # https://github.com/pytorch/pytorch/issues/145168 .
239
+ # if we have some memory allocated and then freed,
240
+ # the memory will not be released.
241
+ # right now it is fine, because we only use this allocator
242
+ # during weight loading and kv cache creation, where we only
243
+ # allocate memory.
244
+ # TODO: we need to find a way to release the memory,
245
+ # i.e. calling torch.cuda.empty_cache()
246
+ self.current_tag = old_tag
247
+
248
+ def get_current_usage(self) -> int:
249
+ """
250
+ Get the total number of bytes allocated in the memory pool.
251
+ """
252
+ sum_bytes: int = 0
253
+ for ptr, data in self.pointer_to_data.items():
254
+ handle = data.handle
255
+ sum_bytes += handle[1]
256
+ return sum_bytes
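Note: a minimal sketch of how CuMemAllocator is intended to be used, assuming a CUDA build where the cumem_allocator extension is available (it is not supported on ROCm).

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Tensors created inside the pool context are tracked under the given tag.
with allocator.use_memory_pool(tag="weights"):
    weights = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")

# sleep() offloads "weights"-tagged allocations to (pinned) CPU memory and
# discards the rest; wake_up() remaps the memory and restores the copies.
allocator.sleep(offload_tags="weights")
allocator.wake_up()
print(allocator.get_current_usage())  # total bytes tracked by the pool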
.venv/lib/python3.11/site-packages/vllm/distributed/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from .communication_op import *
4
+ from .parallel_state import *
5
+ from .utils import *
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (295 Bytes).
 
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/communication_op.cpython-311.pyc ADDED
Binary file (2.22 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/parallel_state.cpython-311.pyc ADDED
Binary file (55.5 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.1 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/communication_op.py ADDED
@@ -0,0 +1,34 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ import torch
6
+ import torch.distributed
7
+
8
+ from .parallel_state import get_tp_group
9
+
10
+
11
+ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
12
+ """All-reduce the input tensor across model parallel group."""
13
+ return get_tp_group().all_reduce(input_)
14
+
15
+
16
+ def tensor_model_parallel_all_gather(input_: torch.Tensor,
17
+ dim: int = -1) -> torch.Tensor:
18
+ """All-gather the input tensor across model parallel group."""
19
+ return get_tp_group().all_gather(input_, dim)
20
+
21
+
22
+ def tensor_model_parallel_gather(input_: torch.Tensor,
23
+ dst: int = 0,
24
+ dim: int = -1) -> Optional[torch.Tensor]:
25
+ """Gather the input tensor across model parallel group."""
26
+ return get_tp_group().gather(input_, dst, dim)
27
+
28
+
29
+ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
30
+ Any]]] = None,
31
+ src: int = 0):
32
+ if not torch.distributed.is_initialized():
33
+ return tensor_dict
34
+ return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
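Note: these helpers are thin wrappers over the tensor-parallel group; a sketch of how they would be called inside a worker where vLLM's parallel state has already been initialized (that setup is assumed, not shown).

import torch
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)

# Assumes the TP process group has already been set up by vLLM's parallel_state.
x = torch.ones(4, device="cuda")
summed = tensor_model_parallel_all_reduce(x)            # elementwise sum across TP ranks
gathered = tensor_model_parallel_all_gather(x, dim=-1)  # concatenate along the last dim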
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-311.pyc ADDED
Binary file (16 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-311.pyc ADDED
Binary file (11.6 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/cuda_wrapper.py ADDED
@@ -0,0 +1,173 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """This file is a pure Python wrapper for the cudart library.
3
+ It avoids the need to compile a separate shared library, and is
4
+ convenient for use when we just need to call a few functions.
5
+ """
6
+
7
+ import ctypes
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ # this line makes it possible to directly load `libcudart.so` using `ctypes`
12
+ import torch # noqa
13
+
14
+ from vllm.logger import init_logger
15
+
16
+ logger = init_logger(__name__)
17
+
18
+ # === export types and functions from cudart to Python ===
19
+ # for the original cudart definition, please check
20
+ # https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
21
+
22
+ cudaError_t = ctypes.c_int
23
+ cudaMemcpyKind = ctypes.c_int
24
+
25
+
26
+ class cudaIpcMemHandle_t(ctypes.Structure):
27
+ _fields_ = [("internal", ctypes.c_byte * 128)]
28
+
29
+
30
+ @dataclass
31
+ class Function:
32
+ name: str
33
+ restype: Any
34
+ argtypes: List[Any]
35
+
36
+
37
+ def find_loaded_library(lib_name) -> Optional[str]:
38
+ """
39
+ According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
40
+ the file `/proc/self/maps` contains the memory maps of the process, which includes the
41
+ shared libraries loaded by the process. We can use this file to find the path of
42
+ a loaded library.
43
+ """ # noqa
44
+ found = False
45
+ with open("/proc/self/maps") as f:
46
+ for line in f:
47
+ if lib_name in line:
48
+ found = True
49
+ break
50
+ if not found:
51
+ # the library is not loaded in the current process
52
+ return None
53
+ # if lib_name is libcudart, we need to match a line with:
54
+ # address /path/to/libcudart-hash.so.11.0
55
+ start = line.index("/")
56
+ path = line[start:].strip()
57
+ filename = path.split("/")[-1]
58
+ assert filename.rpartition(".so")[0].startswith(lib_name), \
59
+ f"Unexpected filename: {filename} for library {lib_name}"
60
+ return path
61
+
62
+
63
+ class CudaRTLibrary:
64
+ exported_functions = [
65
+ # ​cudaError_t cudaSetDevice ( int device )
66
+ Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
67
+ # cudaError_t cudaDeviceSynchronize ( void )
68
+ Function("cudaDeviceSynchronize", cudaError_t, []),
69
+ # ​cudaError_t cudaDeviceReset ( void )
70
+ Function("cudaDeviceReset", cudaError_t, []),
71
+
72
+ # const char* cudaGetErrorString ( cudaError_t error )
73
+ Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
74
+
75
+ # ​cudaError_t cudaMalloc ( void** devPtr, size_t size )
76
+ Function("cudaMalloc", cudaError_t,
77
+ [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
78
+ # ​cudaError_t cudaFree ( void* devPtr )
79
+ Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
80
+ # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
81
+ Function("cudaMemset", cudaError_t,
82
+ [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
83
+ # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
84
+ Function("cudaMemcpy", cudaError_t, [
85
+ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
86
+ ]),
87
+
88
+ # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
89
+ Function("cudaIpcGetMemHandle", cudaError_t,
90
+ [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
91
+ # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
92
+ Function("cudaIpcOpenMemHandle", cudaError_t, [
93
+ ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
94
+ ]),
95
+ ]
96
+
97
+ # class attribute to store the mapping from the path to the library
98
+ # to avoid loading the same library multiple times
99
+ path_to_library_cache: Dict[str, Any] = {}
100
+
101
+ # class attribute to store the mapping from library path
102
+ # to the corresponding dictionary
103
+ path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
104
+
105
+ def __init__(self, so_file: Optional[str] = None):
106
+ if so_file is None:
107
+ so_file = find_loaded_library("libcudart")
108
+ assert so_file is not None, \
109
+ "libcudart is not loaded in the current process"
110
+ if so_file not in CudaRTLibrary.path_to_library_cache:
111
+ lib = ctypes.CDLL(so_file)
112
+ CudaRTLibrary.path_to_library_cache[so_file] = lib
113
+ self.lib = CudaRTLibrary.path_to_library_cache[so_file]
114
+
115
+ if so_file not in CudaRTLibrary.path_to_dict_mapping:
116
+ _funcs = {}
117
+ for func in CudaRTLibrary.exported_functions:
118
+ f = getattr(self.lib, func.name)
119
+ f.restype = func.restype
120
+ f.argtypes = func.argtypes
121
+ _funcs[func.name] = f
122
+ CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
123
+ self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
124
+
125
+ def CUDART_CHECK(self, result: cudaError_t) -> None:
126
+ if result != 0:
127
+ error_str = self.cudaGetErrorString(result)
128
+ raise RuntimeError(f"CUDART error: {error_str}")
129
+
130
+ def cudaGetErrorString(self, error: cudaError_t) -> str:
131
+ return self.funcs["cudaGetErrorString"](error).decode("utf-8")
132
+
133
+ def cudaSetDevice(self, device: int) -> None:
134
+ self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
135
+
136
+ def cudaDeviceSynchronize(self) -> None:
137
+ self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
138
+
139
+ def cudaDeviceReset(self) -> None:
140
+ self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
141
+
142
+ def cudaMalloc(self, size: int) -> ctypes.c_void_p:
143
+ devPtr = ctypes.c_void_p()
144
+ self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
145
+ return devPtr
146
+
147
+ def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
148
+ self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
149
+
150
+ def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
151
+ count: int) -> None:
152
+ self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
153
+
154
+ def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
155
+ count: int) -> None:
156
+ cudaMemcpyDefault = 4
157
+ kind = cudaMemcpyDefault
158
+ self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
159
+
160
+ def cudaIpcGetMemHandle(self,
161
+ devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
162
+ handle = cudaIpcMemHandle_t()
163
+ self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
164
+ ctypes.byref(handle), devPtr))
165
+ return handle
166
+
167
+ def cudaIpcOpenMemHandle(self,
168
+ handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
169
+ cudaIpcMemLazyEnablePeerAccess = 1
170
+ devPtr = ctypes.c_void_p()
171
+ self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
172
+ ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
173
+ return devPtr
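Note: a minimal sketch of using CudaRTLibrary directly, assuming a CUDA-capable environment where importing torch has already loaded libcudart into the process.

import torch  # importing torch loads libcudart into the process
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

cudart = CudaRTLibrary()
cudart.cudaSetDevice(0)

# Allocate 1 KiB on the device, zero it, then copy it back into a host tensor.
dev_ptr = cudart.cudaMalloc(1024)
cudart.cudaMemset(dev_ptr, 0, 1024)
host_buf = torch.empty(1024, dtype=torch.uint8)
cudart.cudaMemcpy(host_buf.data_ptr(), dev_ptr, 1024)
cudart.cudaFree(dev_ptr)
print(int(host_buf.sum()))  # -> 0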
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce.py ADDED
@@ -0,0 +1,305 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import ctypes
4
+ from contextlib import contextmanager
5
+ from typing import List, Optional, Union
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.distributed import ProcessGroup
10
+
11
+ import vllm.envs as envs
12
+ from vllm import _custom_ops as ops
13
+ from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
14
+ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
15
+ gpu_p2p_access_check)
16
+ from vllm.distributed.parallel_state import in_the_same_node_as
17
+ from vllm.logger import init_logger
18
+ from vllm.platforms import current_platform
19
+ from vllm.utils import cuda_device_count_stateless
20
+
21
+ try:
22
+ ops.meta_size()
23
+ custom_ar = True
24
+ except Exception:
25
+ # For AMD GPUs and CPUs
26
+ custom_ar = False
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
+ def _can_p2p(rank: int, world_size: int) -> bool:
32
+ for i in range(world_size):
33
+ if i == rank:
34
+ continue
35
+ if envs.VLLM_SKIP_P2P_CHECK:
36
+ logger.info(
37
+ "Skipping P2P check and trusting the driver's P2P report.")
38
+ return torch.cuda.can_device_access_peer(rank, i)
39
+ if not gpu_p2p_access_check(rank, i):
40
+ return False
41
+ return True
42
+
43
+
44
+ def is_weak_contiguous(inp: torch.Tensor):
45
+ return inp.is_contiguous() or (inp.storage().nbytes() -
46
+ inp.storage_offset() * inp.element_size()
47
+ == inp.numel() * inp.element_size())
48
+
49
+
50
+ class CustomAllreduce:
51
+
52
+ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
53
+
54
+ # max_size: max supported allreduce size
55
+ def __init__(self,
56
+ group: ProcessGroup,
57
+ device: Union[int, str, torch.device],
58
+ max_size=8192 * 1024) -> None:
59
+ """
60
+ Args:
61
+ group: the process group to work on. If None, it will use the
62
+ default process group.
63
+ device: the device to bind the CustomAllreduce to. If None,
64
+ it will be bind to f"cuda:{local_rank}".
65
+ It is the caller's responsibility to make sure each communicator
66
+ is bound to a unique device, and all communicators in this group
67
+ are in the same node.
68
+ """
69
+ self._IS_CAPTURING = False
70
+ self.disabled = True
71
+
72
+ if not custom_ar:
73
+ # disable because of missing custom allreduce library
74
+ # e.g. in a non-cuda environment
75
+ return
76
+
77
+ self.group = group
78
+
79
+ assert dist.get_backend(group) != dist.Backend.NCCL, (
80
+ "CustomAllreduce should be attached to a non-NCCL group.")
81
+
82
+ if not all(in_the_same_node_as(group, source_rank=0)):
83
+ # No need to initialize custom allreduce for multi-node case.
84
+ logger.warning(
85
+ "Custom allreduce is disabled because this process group"
86
+ " spans across nodes.")
87
+ return
88
+
89
+ rank = dist.get_rank(group=self.group)
90
+ world_size = dist.get_world_size(group=self.group)
91
+ if world_size == 1:
92
+ # No need to initialize custom allreduce for single GPU case.
93
+ return
94
+
95
+ if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
96
+ logger.warning(
97
+ "Custom allreduce is disabled due to an unsupported world"
98
+ " size: %d. Supported world sizes: %s. To silence this "
99
+ "warning, specify disable_custom_all_reduce=True explicitly.",
100
+ world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
101
+ return
102
+
103
+ if isinstance(device, int):
104
+ device = torch.device(f"cuda:{device}")
105
+ elif isinstance(device, str):
106
+ device = torch.device(device)
107
+ # now `device` is a `torch.device` object
108
+ assert isinstance(device, torch.device)
109
+ self.device = device
110
+
111
+ cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
112
+ if cuda_visible_devices:
113
+ device_ids = list(map(int, cuda_visible_devices.split(",")))
114
+ else:
115
+ device_ids = list(range(cuda_device_count_stateless()))
116
+
117
+ physical_device_id = device_ids[device.index]
118
+ tensor = torch.tensor([physical_device_id],
119
+ dtype=torch.int,
120
+ device="cpu")
121
+ gather_list = [
122
+ torch.tensor([0], dtype=torch.int, device="cpu")
123
+ for _ in range(world_size)
124
+ ]
125
+ dist.all_gather(gather_list, tensor, group=self.group)
126
+ physical_device_ids = [t.item() for t in gather_list]
127
+
128
+ # test NVLink first; this will filter out most of the cases
129
+ # where custom allreduce is not supported
130
+ # this checks hardware and driver support for NVLink
131
+ assert current_platform.is_cuda()
132
+ from vllm.platforms.cuda import CudaPlatform
133
+ cuda_platform: CudaPlatform = current_platform
134
+ full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
135
+ if world_size > 2 and not full_nvlink:
136
+ logger.warning(
137
+ "Custom allreduce is disabled because it's not supported on"
138
+ " more than two PCIe-only GPUs. To silence this warning, "
139
+ "specify disable_custom_all_reduce=True explicitly.")
140
+ return
141
+ # test P2P capability, this checks software/cudaruntime support
142
+ # this is expensive to compute the first time,
143
+ # then we cache the result
144
+ if not _can_p2p(rank, world_size):
145
+ logger.warning(
146
+ "Custom allreduce is disabled because your platform lacks "
147
+ "GPU P2P capability or P2P test failed. To silence this "
148
+ "warning, specify disable_custom_all_reduce=True explicitly.")
149
+ return
150
+
151
+ self.disabled = False
152
+ # Buffer memory is owned by this Python class and passed to C++.
153
+ # Metadata consists of two parts: metadata for synchronization and a
154
+ # temporary buffer for storing intermediate allreduce results.
155
+ self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
156
+ group=group)
157
+ # This is a pre-registered IPC buffer. In eager mode, input tensors
158
+ # are first copied into this buffer before allreduce is performed
159
+ self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
160
+ # This is a buffer for storing the tuples of pointers pointing to
161
+ # IPC buffers from all ranks. Each registered tuple has size of
162
+ # 8*world_size bytes where world_size is at most 8. Allocating 8MB
163
+ # is enough for 131072 such tuples. The largest model I've seen only
164
+ # needs fewer than 10000 registered tuples.
165
+ self.rank_data = torch.empty(8 * 1024 * 1024,
166
+ dtype=torch.uint8,
167
+ device=self.device)
168
+ self.max_size = max_size
169
+ self.rank = rank
170
+ self.world_size = world_size
171
+ self.full_nvlink = full_nvlink
172
+ self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
173
+ self.full_nvlink)
174
+ ops.register_buffer(self._ptr, self.buffer_ptrs)
175
+
176
+ @staticmethod
177
+ def create_shared_buffer(
178
+ size_in_bytes: int,
179
+ group: Optional[ProcessGroup] = None) -> List[int]:
180
+ """
181
+ Creates a shared buffer and returns a list of pointers
182
+ representing the buffer on all processes in the group.
183
+ """
184
+ lib = CudaRTLibrary()
185
+ pointer = lib.cudaMalloc(size_in_bytes)
186
+ handle = lib.cudaIpcGetMemHandle(pointer)
187
+ world_size = dist.get_world_size(group=group)
188
+ rank = dist.get_rank(group=group)
189
+ handles = [None] * world_size
190
+ dist.all_gather_object(handles, handle, group=group)
191
+
192
+ pointers: List[int] = []
193
+ for i, h in enumerate(handles):
194
+ if i == rank:
195
+ pointers.append(pointer.value) # type: ignore
196
+ else:
197
+ pointers.append(
198
+ lib.cudaIpcOpenMemHandle(h).value) # type: ignore
199
+
200
+ return pointers
201
+
202
+ @staticmethod
203
+ def free_shared_buffer(pointers: List[int],
204
+ group: Optional[ProcessGroup] = None) -> None:
205
+ rank = dist.get_rank(group=group)
206
+ lib = CudaRTLibrary()
207
+ lib.cudaFree(ctypes.c_void_p(pointers[rank]))
208
+
209
+ @contextmanager
210
+ def capture(self):
211
+ """
212
+ The main responsibility of this context manager is the
213
+ `register_graph_buffers` call at the end of the context.
214
+ It records all the buffer addresses used in the CUDA graph.
215
+ """
216
+ try:
217
+ self._IS_CAPTURING = True
218
+ yield
219
+ finally:
220
+ self._IS_CAPTURING = False
221
+ if not self.disabled:
222
+ self.register_graph_buffers()
223
+
224
+ def register_graph_buffers(self):
225
+ handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
226
+ logger.info("Registering %d cuda graph addresses", len(offset))
227
+ # We cannot directly use `dist.all_gather_object` here
228
+ # because it is incompatible with `gloo` backend under inference mode.
229
+ # see https://github.com/pytorch/pytorch/issues/126032 for details.
230
+ all_data = [[None, None]
231
+ for _ in range(dist.get_world_size(group=self.group))]
232
+ all_data[self.rank] = [handle, offset]
233
+ ranks = sorted(dist.get_process_group_ranks(group=self.group))
234
+ for i, rank in enumerate(ranks):
235
+ dist.broadcast_object_list(all_data[i],
236
+ src=rank,
237
+ group=self.group,
238
+ device="cpu")
239
+ # Unpack list of tuples to tuple of lists.
240
+ handles = [d[0] for d in all_data] # type: ignore
241
+ offsets = [d[1] for d in all_data] # type: ignore
242
+ ops.register_graph_buffers(self._ptr, handles, offsets)
243
+
244
+ def should_custom_ar(self, inp: torch.Tensor):
245
+ if self.disabled:
246
+ return False
247
+ inp_size = inp.numel() * inp.element_size()
248
+ # custom allreduce requires the input byte size to be a multiple of 16
249
+ if inp_size % 16 != 0:
250
+ return False
251
+ if not is_weak_contiguous(inp):
252
+ return False
253
+ # for 4 or more non-NVLink-capable GPUs, custom allreduce provides
254
+ # little performance improvement over NCCL.
255
+ if self.world_size == 2 or self.full_nvlink:
256
+ return inp_size < self.max_size
257
+ return False
258
+
259
+ def all_reduce(self,
260
+ inp: torch.Tensor,
261
+ *,
262
+ out: torch.Tensor = None,
263
+ registered: bool = False):
264
+ """Performs an out-of-place all reduce.
265
+
266
+ If registered is True, this assumes inp's pointer is already
267
+ IPC-registered. Otherwise, inp is first copied into a pre-registered
268
+ buffer.
269
+ """
270
+ if out is None:
271
+ out = torch.empty_like(inp)
272
+ if registered:
273
+ ops.all_reduce(self._ptr, inp, out, 0, 0)
274
+ else:
275
+ ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
276
+ self.max_size)
277
+ return out
278
+
279
+ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
280
+ """The main allreduce API that provides support for cuda graph."""
281
+ # When custom allreduce is disabled, this will be None.
282
+ if self.disabled or not self.should_custom_ar(input):
283
+ return None
284
+ if self._IS_CAPTURING:
285
+ if torch.cuda.is_current_stream_capturing():
286
+ return self.all_reduce(input, registered=True)
287
+ else:
288
+ # If warm up, mimic the allocation pattern since custom
289
+ # allreduce is out-of-place.
290
+ return torch.empty_like(input)
291
+ else:
292
+ # Note: outside of cuda graph context, custom allreduce incurs a
293
+ # cost of cudaMemcpy, which should be small (<=1% of overall
294
+ # latency) compared to the performance gain of using custom kernels
295
+ return self.all_reduce(input, registered=False)
296
+
297
+ def close(self):
298
+ if not self.disabled and self._ptr:
299
+ ops.dispose(self._ptr)
300
+ self._ptr = 0
301
+ self.free_shared_buffer(self.meta_ptrs)
302
+ self.free_shared_buffer(self.buffer_ptrs)
303
+
304
+ def __del__(self):
305
+ self.close()
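A hedged usage sketch for the class above: it assumes torch.distributed is already initialized, that `cpu_group` is a gloo group whose ranks all live on one node, that `local_rank` is defined, and that the hardware/P2P checks in __init__ succeed (otherwise custom_all_reduce simply returns None):

import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce

ca = CustomAllreduce(group=cpu_group, device=local_rank)  # cpu_group, local_rank assumed
x = torch.randn(4096, dtype=torch.float16, device=f"cuda:{local_rank}")
out = ca.custom_all_reduce(x)   # None when disabled or the input is unsupported
if out is None:
    dist.all_reduce(x)          # fall back to the regular NCCL path
    out = x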
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/custom_all_reduce_utils.py ADDED
@@ -0,0 +1,257 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import ctypes
4
+ import json
5
+ import os
6
+ import pickle
7
+ import subprocess
8
+ import sys
9
+ import tempfile
10
+ from itertools import product
11
+ from typing import Dict, List, Optional, Sequence
12
+
13
+ import torch.distributed as dist
14
+ import torch.multiprocessing as mp
15
+
16
+ import vllm.envs as envs
17
+ from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
18
+ from vllm.logger import init_logger
19
+ from vllm.utils import (cuda_device_count_stateless,
20
+ update_environment_variables)
21
+
22
+ logger = init_logger(__name__)
23
+
24
+
25
+ def producer(batch_src: Sequence[int],
26
+ producer_queue,
27
+ consumer_queue,
28
+ result_queue,
29
+ cuda_visible_devices: Optional[str] = None):
30
+ if cuda_visible_devices is not None:
31
+ update_environment_variables(
32
+ {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
33
+
34
+ lib = CudaRTLibrary()
35
+ for i in batch_src:
36
+ lib.cudaSetDevice(i)
37
+ pointer = lib.cudaMalloc(1024)
38
+ lib.cudaMemset(pointer, 1, 1024)
39
+ lib.cudaDeviceSynchronize()
40
+ handle = lib.cudaIpcGetMemHandle(pointer)
41
+ producer_queue.put(handle)
42
+ open_success = consumer_queue.get()
43
+ if open_success:
44
+ # use two queues to simulate barrier
45
+ producer_queue.put(0)
46
+ consumer_queue.get()
47
+ # check if the memory is modified
48
+ host_data = (ctypes.c_char * 1024)()
49
+ lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
50
+ for i in range(1024):
51
+ if ord(host_data[i]) != 2:
52
+ open_success = False
53
+ break
54
+ result_queue.put(open_success)
55
+ lib.cudaDeviceReset()
56
+
57
+
58
+ def consumer(batch_tgt: Sequence[int],
59
+ producer_queue,
60
+ consumer_queue,
61
+ result_queue,
62
+ cuda_visible_devices: Optional[str] = None):
63
+ if cuda_visible_devices is not None:
64
+ update_environment_variables(
65
+ {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
66
+
67
+ lib = CudaRTLibrary()
68
+ for j in batch_tgt:
69
+ lib.cudaSetDevice(j)
70
+ handle = producer_queue.get()
71
+ open_success = False
72
+ try:
73
+ pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
74
+ open_success = True
75
+ except RuntimeError:
76
+ # cannot error out here, because the producer process
77
+ # is still waiting for the response.
78
+ pass
79
+ consumer_queue.put(open_success)
80
+ if open_success:
81
+ # modify the memory
82
+ lib.cudaMemset(pointer, 2, 1024)
83
+ lib.cudaDeviceSynchronize()
84
+ # use two queues to simulate barrier
85
+ producer_queue.get()
86
+ consumer_queue.put(0)
87
+ # check if the memory is modified
88
+ host_data = (ctypes.c_char * 1024)()
89
+ lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
90
+ for i in range(1024):
91
+ if ord(host_data[i]) != 2:
92
+ open_success = False
93
+ break
94
+ result_queue.put(open_success)
95
+ lib.cudaDeviceReset()
96
+
97
+
98
+ def can_actually_p2p(
99
+ batch_src: Sequence[int],
100
+ batch_tgt: Sequence[int],
101
+ ) -> Sequence[bool]:
102
+ """
103
+ Usually, checking if P2P access is enabled can be done by
104
+ `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
105
+ the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
106
+ returns `True` even if P2P access is not actually possible.
107
+ See https://github.com/vllm-project/vllm/issues/2728 and
108
+ https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
109
+ Therefore, we have to perform a real P2P access to check if it is actually
110
+ possible.
111
+
112
+ Note on p2p and cuda IPC:
113
+ Usually, one process uses one GPU:
114
+ GPU src --> cuda context src --> tensor src --> process src
115
+
116
+ We need to combine p2p and cuda IPC, so that:
117
+ GPU src --> cuda context src --> tensor src --> process src
118
+ |shared|
119
+ GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
120
+ That is to say, process src creates a tensor in GPU src, passes IPC handle to
121
+ process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
122
+ tensor in process tgt will be reflected in the tensor in process src, because
123
+ they are the same memory segment.
124
+ It is important to note that process tgt accesses the tensor in GPU tgt, not
125
+ GPU src. That's why we need p2p access.
126
+
127
+ The most time-consuming part is the process creation. To avoid creating
128
+ processes for every pair of GPUs, we use batched testing. We create two
129
+ processes for testing all pairs of GPUs in batch. The trick is to reset
130
+ the device after each test (which is not available in PyTorch).
131
+ """ # noqa
132
+ cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
133
+ # pass the CUDA_VISIBLE_DEVICES to the child process
134
+ # to make sure they see the same set of GPUs
135
+
136
+ # make sure the processes are spawned
137
+ smp = mp.get_context("spawn")
138
+ producer_queue = smp.Queue()
139
+ consumer_queue = smp.Queue()
140
+ result_queue = smp.Queue()
141
+ p_src = smp.Process(target=producer,
142
+ args=(batch_src, producer_queue, consumer_queue,
143
+ result_queue, cuda_visible_devices))
144
+ p_tgt = smp.Process(target=consumer,
145
+ args=(batch_tgt, producer_queue, consumer_queue,
146
+ result_queue, cuda_visible_devices))
147
+ p_src.start()
148
+ p_tgt.start()
149
+ p_src.join()
150
+ p_tgt.join()
151
+ assert p_src.exitcode == 0 and p_tgt.exitcode == 0
152
+ result: List[bool] = []
153
+ for src, tgt in zip(batch_src, batch_tgt):
154
+ a = result_queue.get()
155
+ b = result_queue.get()
156
+ if a != b:
157
+ logger.warning(
158
+ "Two processes do not agree on the P2P access"
159
+ " status on %d -> %d, treat as disabled.", src, tgt)
160
+ result.append(False)
161
+ else:
162
+ result.append(a)
163
+ return result
164
+
165
+
166
+ # why do we need this cache?
167
+ # we are testing peer-to-peer (p2p) access between GPUs, across processes.
168
+ # if we test it every time, it will be very slow, because we need to create
169
+ # N * N * 2 processes, where N is the world size.
170
+ # to reduce the time, we use a cache file to store the p2p access status.
171
+ # the cache file is generated by the master process if it does not exist.
172
+ # then all the processes can read the cache file to check the p2p access status.
173
+ # Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
174
+ # can have different cache files for different CUDA_VISIBLE_DEVICES settings,
175
+ # e.g. used by different vllm engines. The device id in the cache file is a
176
+ # **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
177
+ # of visible devices in the vllm engine.
178
+ _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
179
+
180
+
181
+ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
182
+ """Check if GPU src can access GPU tgt."""
183
+
184
+ # if the cache variable is already calculated,
185
+ # read from the cache instead of checking it again
186
+ global _gpu_p2p_access_cache
187
+ if _gpu_p2p_access_cache is not None:
188
+ return _gpu_p2p_access_cache[f"{src}->{tgt}"]
189
+
190
+ is_distributed = dist.is_initialized()
191
+
192
+ num_dev = cuda_device_count_stateless()
193
+ cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
194
+ if cuda_visible_devices is None:
195
+ cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
196
+
197
+ path = os.path.join(
198
+ envs.VLLM_CACHE_ROOT,
199
+ f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
200
+ os.makedirs(os.path.dirname(path), exist_ok=True)
201
+ from vllm.distributed.parallel_state import get_world_group
202
+ if ((not is_distributed or get_world_group().local_rank == 0)
203
+ and (not os.path.exists(path))):
204
+ # only the local master process (with local_rank == 0) can
205
+ # enter this block to calculate the cache
206
+ logger.info("generating GPU P2P access cache in %s", path)
207
+ cache: Dict[str, bool] = {}
208
+ ids = list(range(num_dev))
209
+ # batch of all pairs of GPUs
210
+ batch_src, batch_tgt = zip(*list(product(ids, ids)))
211
+ # NOTE: we use `subprocess` rather than `multiprocessing` here
212
+ # because the caller might not have `if __name__ == "__main__":`,
213
+ # in which case we cannot use the spawn method in multiprocessing.
214
+ # However, `can_actually_p2p` requires spawn method.
215
+ # The fix is to use `subprocess` to call the function,
216
+ # where we have `if __name__ == "__main__":` in this file.
217
+
218
+ # use a temporary file to store the result
219
+ # we don't use the output of the subprocess directly,
220
+ # because the subprocess might produce logging output
221
+ with tempfile.NamedTemporaryFile() as output_file:
222
+ input_bytes = pickle.dumps(
223
+ (batch_src, batch_tgt, output_file.name))
224
+ returned = subprocess.run([sys.executable, __file__],
225
+ input=input_bytes,
226
+ capture_output=True)
227
+ # check if the subprocess is successful
228
+ try:
229
+ returned.check_returncode()
230
+ except Exception as e:
231
+ # wrap raised exception to provide more information
232
+ raise RuntimeError(
233
+ f"Error happened when batch testing "
234
+ f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
235
+ f"{returned.stderr.decode()}") from e
236
+ with open(output_file.name, "rb") as f:
237
+ result = pickle.load(f)
238
+ for _i, _j, r in zip(batch_src, batch_tgt, result):
239
+ cache[f"{_i}->{_j}"] = r
240
+ with open(path, "w") as f:
241
+ json.dump(cache, f, indent=4)
242
+ if is_distributed:
243
+ get_world_group().barrier()
244
+ logger.info("reading GPU P2P access cache from %s", path)
245
+ with open(path) as f:
246
+ cache = json.load(f)
247
+ _gpu_p2p_access_cache = cache
248
+ return _gpu_p2p_access_cache[f"{src}->{tgt}"]
249
+
250
+
251
+ __all__ = ["gpu_p2p_access_check"]
252
+
253
+ if __name__ == "__main__":
254
+ batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
255
+ result = can_actually_p2p(batch_src, batch_tgt)
256
+ with open(output_file, "wb") as f:
257
+ f.write(pickle.dumps(result))
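Callers only ever go through gpu_p2p_access_check; everything else above exists to build and cache its answer. An illustrative sketch (the cache contents shown in the comment are made up; real values depend on the hardware):

from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)

# The JSON cache written above looks roughly like:
#   {"0->0": true, "0->1": true, "1->0": true, "1->1": true}
same_dev = gpu_p2p_access_check(0, 0)   # same-device access, normally True
cross_dev = gpu_p2p_access_check(0, 1)  # first call runs (or reads) the batched test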
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/hpu_communicator.py ADDED
@@ -0,0 +1,50 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from torch.distributed import ProcessGroup
6
+
7
+ from vllm.platforms import current_platform
8
+
9
+ if current_platform.is_hpu():
10
+ import habana_frameworks.torch as htorch # noqa: F401
11
+
12
+
13
+ class HpuCommunicator:
14
+
15
+ def __init__(self, group: ProcessGroup):
16
+ if not current_platform.is_hpu():
17
+ self.disabled = True
18
+ return
19
+ self.disabled = False
20
+ self.group = group
21
+ self.world_size = dist.get_world_size(self.group)
22
+
23
+ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
24
+ # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
25
+ # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
26
+ # (which is required for tensor parallel HPUGraph inference)
27
+ htorch.core.mark_step()
28
+ dist.all_reduce(x, group=self.group)
29
+ return x
30
+
31
+ def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
32
+ world_size = self.world_size
33
+ if dim < 0:
34
+ # Convert negative dim to positive.
35
+ dim += x.dim()
36
+ input_size = x.size()
37
+ # Allocate output tensor.
38
+ output_tensor = torch.empty((world_size, ) + input_size,
39
+ dtype=x.dtype,
40
+ device=x.device)
41
+ # All-gather.
42
+ htorch.core.mark_step()
43
+ dist.all_gather_into_tensor(output_tensor, x, group=self.group)
44
+ # Reshape
45
+ output_tensor = output_tensor.movedim(0, dim)
46
+ output_tensor = output_tensor.reshape(input_size[:dim] +
47
+ (world_size *
48
+ input_size[dim], ) +
49
+ input_size[dim + 1:])
50
+ return output_tensor
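The only subtle part of all_gather above is the final reshape; the shape arithmetic can be checked with plain torch on CPU, no HPU needed (a sketch, not production code):

import torch

world_size, dim = 4, 1
x = torch.arange(6.).reshape(2, 3)
gathered = x.expand(world_size, *x.size()).clone()  # stands in for the gathered tensor
out = gathered.movedim(0, dim).reshape(
    x.size()[:dim] + (world_size * x.size(dim), ) + x.size()[dim + 1:])
assert out.shape == (2, 12)  # dim=-1 of a (2, 3) input grows by world_size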
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/pynccl.py ADDED
@@ -0,0 +1,217 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Optional, Union
4
+
5
+ # ===================== import region =====================
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch.distributed import ProcessGroup, ReduceOp
9
+
10
+ from vllm.distributed.device_communicators.pynccl_wrapper import (
11
+ NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
12
+ ncclRedOpTypeEnum, ncclUniqueId)
13
+ from vllm.distributed.utils import StatelessProcessGroup
14
+ from vllm.logger import init_logger
15
+ from vllm.utils import current_stream
16
+
17
+ logger = init_logger(__name__)
18
+
19
+
20
+ class PyNcclCommunicator:
21
+
22
+ def __init__(
23
+ self,
24
+ group: Union[ProcessGroup, StatelessProcessGroup],
25
+ device: Union[int, str, torch.device],
26
+ library_path: Optional[str] = None,
27
+ ):
28
+ """
29
+ Args:
30
+ group: the process group to work on. If None, it will use the
31
+ default process group.
32
+ device: the device to bind the PyNcclCommunicator to. If None,
33
+ it will be bound to f"cuda:{local_rank}".
34
+ library_path: the path to the NCCL library. If None, it will
35
+ use the default library path.
36
+ It is the caller's responsibility to make sure each communicator
37
+ is bound to a unique device.
38
+ """
39
+ if not isinstance(group, StatelessProcessGroup):
40
+ assert dist.is_initialized()
41
+ assert dist.get_backend(group) != dist.Backend.NCCL, (
42
+ "PyNcclCommunicator should be attached to a non-NCCL group.")
43
+ # note: this rank is the rank in the group
44
+ self.rank = dist.get_rank(group)
45
+ self.world_size = dist.get_world_size(group)
46
+ else:
47
+ self.rank = group.rank
48
+ self.world_size = group.world_size
49
+
50
+ self.group = group
51
+
52
+ # if world_size == 1, no need to create communicator
53
+ if self.world_size == 1:
54
+ self.available = False
55
+ self.disabled = True
56
+ return
57
+ try:
58
+ self.nccl = NCCLLibrary(library_path)
59
+ except Exception:
60
+ # disable because of missing NCCL library
61
+ # e.g. in a non-GPU environment
62
+ self.available = False
63
+ self.disabled = True
64
+ return
65
+
66
+ self.available = True
67
+ self.disabled = False
68
+
69
+ logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
70
+
71
+ if self.rank == 0:
72
+ # get the unique id from NCCL
73
+ self.unique_id = self.nccl.ncclGetUniqueId()
74
+ else:
75
+ # construct an empty unique id
76
+ self.unique_id = ncclUniqueId()
77
+
78
+ if not isinstance(group, StatelessProcessGroup):
79
+ tensor = torch.ByteTensor(list(self.unique_id.internal))
80
+ ranks = dist.get_process_group_ranks(group)
81
+ # arg `src` in `broadcast` is the global rank
82
+ dist.broadcast(tensor, src=ranks[0], group=group)
83
+ byte_list = tensor.tolist()
84
+ for i, byte in enumerate(byte_list):
85
+ self.unique_id.internal[i] = byte
86
+ else:
87
+ self.unique_id = group.broadcast_obj(self.unique_id, src=0)
88
+ if isinstance(device, int):
89
+ device = torch.device(f"cuda:{device}")
90
+ elif isinstance(device, str):
91
+ device = torch.device(device)
92
+ # now `device` is a `torch.device` object
93
+ assert isinstance(device, torch.device)
94
+ self.device = device
95
+ # nccl communicator and stream will use this device
96
+ # `torch.cuda.device` is a context manager that changes the
97
+ # current cuda device to the specified one
98
+ with torch.cuda.device(device):
99
+ self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
100
+ self.world_size, self.unique_id, self.rank)
101
+
102
+ stream = current_stream()
103
+ # A small all_reduce for warmup.
104
+ data = torch.zeros(1, device=device)
105
+ self.all_reduce(data)
106
+ stream.synchronize()
107
+ del data
108
+
109
+ def all_reduce(self,
110
+ in_tensor: torch.Tensor,
111
+ op: ReduceOp = ReduceOp.SUM,
112
+ stream=None) -> torch.Tensor:
113
+ if self.disabled:
114
+ return None
115
+ # nccl communicator created on a specific device
116
+ # will only work on tensors on the same device
117
+ # otherwise it will cause "illegal memory access"
118
+ assert in_tensor.device == self.device, (
119
+ f"this nccl communicator is created to work on {self.device}, "
120
+ f"but the input tensor is on {in_tensor.device}")
121
+
122
+ out_tensor = torch.empty_like(in_tensor)
123
+
124
+ if stream is None:
125
+ stream = current_stream()
126
+ self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
127
+ buffer_type(out_tensor.data_ptr()),
128
+ in_tensor.numel(),
129
+ ncclDataTypeEnum.from_torch(in_tensor.dtype),
130
+ ncclRedOpTypeEnum.from_torch(op), self.comm,
131
+ cudaStream_t(stream.cuda_stream))
132
+ return out_tensor
133
+
134
+ def all_gather(self,
135
+ output_tensor: torch.Tensor,
136
+ input_tensor: torch.Tensor,
137
+ stream=None):
138
+ if self.disabled:
139
+ return
140
+ # nccl communicator created on a specific device
141
+ # will only work on tensors on the same device
142
+ # otherwise it will cause "illegal memory access"
143
+ assert input_tensor.device == self.device, (
144
+ f"this nccl communicator is created to work on {self.device}, "
145
+ f"but the input tensor is on {input_tensor.device}")
146
+ if stream is None:
147
+ stream = current_stream()
148
+ self.nccl.ncclAllGather(
149
+ buffer_type(input_tensor.data_ptr()),
150
+ buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
151
+ ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
152
+ cudaStream_t(stream.cuda_stream))
153
+
154
+ def reduce_scatter(self,
155
+ output_tensor: torch.Tensor,
156
+ input_tensor: torch.Tensor,
157
+ op: ReduceOp = ReduceOp.SUM,
158
+ stream=None):
159
+ if self.disabled:
160
+ return
161
+ # nccl communicator created on a specific device
162
+ # will only work on tensors on the same device
163
+ # otherwise it will cause "illegal memory access"
164
+ assert input_tensor.device == self.device, (
165
+ f"this nccl communicator is created to work on {self.device}, "
166
+ f"but the input tensor is on {input_tensor.device}")
167
+ if stream is None:
168
+ stream = current_stream()
169
+ self.nccl.ncclReduceScatter(
170
+ buffer_type(input_tensor.data_ptr()),
171
+ buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
172
+ ncclDataTypeEnum.from_torch(input_tensor.dtype),
173
+ ncclRedOpTypeEnum.from_torch(op), self.comm,
174
+ cudaStream_t(stream.cuda_stream))
175
+
176
+ def send(self, tensor: torch.Tensor, dst: int, stream=None):
177
+ if self.disabled:
178
+ return
179
+ assert tensor.device == self.device, (
180
+ f"this nccl communicator is created to work on {self.device}, "
181
+ f"but the input tensor is on {tensor.device}")
182
+ if stream is None:
183
+ stream = current_stream()
184
+ self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
185
+ ncclDataTypeEnum.from_torch(tensor.dtype), dst,
186
+ self.comm, cudaStream_t(stream.cuda_stream))
187
+
188
+ def recv(self, tensor: torch.Tensor, src: int, stream=None):
189
+ if self.disabled:
190
+ return
191
+ assert tensor.device == self.device, (
192
+ f"this nccl communicator is created to work on {self.device}, "
193
+ f"but the input tensor is on {tensor.device}")
194
+ if stream is None:
195
+ stream = current_stream()
196
+ self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
197
+ ncclDataTypeEnum.from_torch(tensor.dtype), src,
198
+ self.comm, cudaStream_t(stream.cuda_stream))
199
+
200
+ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
201
+ if self.disabled:
202
+ return
203
+ assert tensor.device == self.device, (
204
+ f"this nccl communicator is created to work on {self.device}, "
205
+ f"but the input tensor is on {tensor.device}")
206
+ if stream is None:
207
+ stream = current_stream()
208
+ if src == self.rank:
209
+ sendbuff = buffer_type(tensor.data_ptr())
210
+ # NCCL requires the sender also to have a receive buffer
211
+ recvbuff = buffer_type(tensor.data_ptr())
212
+ else:
213
+ sendbuff = buffer_type()
214
+ recvbuff = buffer_type(tensor.data_ptr())
215
+ self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
216
+ ncclDataTypeEnum.from_torch(tensor.dtype), src,
217
+ self.comm, cudaStream_t(stream.cuda_stream))
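A hedged sketch of how a worker might drive PyNcclCommunicator; `cpu_group` (a gloo process group) and `local_rank` are assumed to exist, and the NCCL library must be present for the constructor to enable itself:

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

comm = PyNcclCommunicator(group=cpu_group, device=local_rank)
x = torch.ones(16, device=f"cuda:{local_rank}")
if not comm.disabled:
    y = comm.all_reduce(x)      # out-of-place: returns a new tensor
    torch.cuda.synchronize()
    assert torch.allclose(y, x * comm.world_size)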
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/shm_broadcast.py ADDED
@@ -0,0 +1,530 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ import pickle
5
+ import sys
6
+ import time
7
+ from contextlib import contextmanager
8
+ from dataclasses import dataclass, field
9
+ from multiprocessing import shared_memory
10
+ from typing import List, Optional, Tuple, Union
11
+ from unittest.mock import patch
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.distributed import ProcessGroup
16
+ from zmq import IPV6 # type: ignore
17
+ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
18
+
19
+ import vllm.envs as envs
20
+ from vllm.distributed.utils import StatelessProcessGroup
21
+ from vllm.logger import init_logger
22
+ from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
23
+
24
+ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
25
+
26
+ logger = init_logger(__name__)
27
+
28
+ # We prefer to use os.sched_yield as it results in tighter polling loops,
29
+ # measured to be around 3e-7 seconds. However on earlier versions of Python
30
+ # os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
31
+ USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
32
+ or (sys.version_info[:2] == (3, 10)
33
+ and sys.version_info[2] >= 8))
34
+
35
+
36
+ def sched_yield():
37
+ if USE_SCHED_YIELD:
38
+ os.sched_yield()
39
+ else:
40
+ time.sleep(0)
41
+
42
+
43
+ class ShmRingBuffer:
44
+
45
+ def __init__(self,
46
+ n_reader: int,
47
+ max_chunk_bytes: int,
48
+ max_chunks: int,
49
+ name: Optional[str] = None):
50
+ """
51
+ A shared memory ring buffer implementation for broadcast communication.
52
+ Essentially, it is a queue where only one will `enqueue` and multiple
53
+ will `dequeue`. The max size of each item, together with the max number
54
+ of items that can be stored in the buffer are known in advance.
55
+ In this case, we don't need to synchronize the access to
56
+ the buffer.
57
+
58
+ Buffer memory layout:
59
+ data metadata
60
+ | |
61
+ | (current_idx) | (current_idx)
62
+ v v
63
+ +-------------------------------+----------------------------------------+
64
+ | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
65
+ +-------------------------------+----------------------------------------+
66
+ | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
67
+
68
+ metadata memory layout: each byte is a flag, the first byte is the written
69
+ flag, and the rest are reader flags. The flags are set to 0 by default.
70
+ +--------------+--------------+--------------+-----+--------------+
71
+ | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
72
+ +--------------+--------------+--------------+-----+--------------+
73
+
74
+ The state of metadata is as follows:
75
+
76
+ (case 1) 0???...???: the block is not written yet, cannot read, can write
77
+ (case 2) 1000...000: the block is just written, can read, cannot write
78
+ (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
79
+ (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
80
+
81
+ State transition for readers:
82
+
83
+ When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
84
+ Only after the caller finishes reading the block, the reader can mark the block as read.
85
+ Readers only mark the block as read (from 0 to 1); only the writer resets the reader flags (from 1 to 0) when it marks the block as ready to read again.
86
+
87
+ State transition for writer:
88
+
89
+ When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
90
+ to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
91
+ can reset the reader flags to 0, and mark the block as written (from 0 to 1).
92
+ NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
93
+
94
+ During creation, `name` is None and the buffer is created. We can pass the
95
+ created object to other processes by pickling it. The other processes will
96
+ get the name of the shared memory and open it, so that they can access the
97
+ same shared memory buffer.
98
+ """# noqa
99
+ self.n_reader = n_reader
100
+ self.metadata_size = 1 + n_reader
101
+ self.max_chunk_bytes = max_chunk_bytes
102
+ self.max_chunks = max_chunks
103
+ self.total_bytes_of_buffer = (self.max_chunk_bytes +
104
+ self.metadata_size) * self.max_chunks
105
+ self.data_offset = 0
106
+ self.metadata_offset = self.max_chunk_bytes * self.max_chunks
107
+
108
+ if name is None:
109
+ # we are creating a buffer
110
+ self.is_creator = True
111
+ self.shared_memory = shared_memory.SharedMemory(
112
+ create=True, size=self.total_bytes_of_buffer)
113
+ # initialize the metadata section to 0
114
+ with memoryview(self.shared_memory.buf[self.metadata_offset:]
115
+ ) as metadata_buffer:
116
+ torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
117
+ else:
118
+ # we are opening an existing buffer
119
+ self.is_creator = False
120
+ # fix to https://stackoverflow.com/q/62748654/9191338
121
+ # Python incorrectly tracks shared memory even if it is not
122
+ # created by the process. The following patch is a workaround.
123
+ with patch("multiprocessing.resource_tracker.register",
124
+ lambda *args, **kwargs: None):
125
+ try:
126
+ self.shared_memory = shared_memory.SharedMemory(name=name)
127
+ assert (
128
+ self.shared_memory.size == self.total_bytes_of_buffer)
129
+ except FileNotFoundError:
130
+ # we might deserialize the object in a different node
131
+ # in this case, this object is not used,
132
+ # and we should suppress the error
133
+ pass
134
+
135
+ def handle(self):
136
+ return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
137
+ self.shared_memory.name)
138
+
139
+ def __reduce__(self):
140
+ return (
141
+ self.__class__,
142
+ self.handle(),
143
+ )
144
+
145
+ def __del__(self):
146
+ if hasattr(self, "shared_memory"):
147
+ self.shared_memory.close()
148
+ if self.is_creator:
149
+ self.shared_memory.unlink()
150
+
151
+ @contextmanager
152
+ def get_data(self, current_idx: int):
153
+ start = self.data_offset + current_idx * self.max_chunk_bytes
154
+ end = start + self.max_chunk_bytes
155
+ with memoryview(self.shared_memory.buf[start:end]) as buf:
156
+ yield buf
157
+
158
+ @contextmanager
159
+ def get_metadata(self, current_idx: int):
160
+ start = self.metadata_offset + current_idx * self.metadata_size
161
+ end = start + self.metadata_size
162
+ with memoryview(self.shared_memory.buf[start:end]) as buf:
163
+ yield buf
164
+
165
+
166
+ @dataclass
167
+ class Handle:
168
+ connect_ip: str
169
+ local_reader_ranks: List[int] = field(default_factory=list)
170
+
171
+ buffer_handle: Optional[Tuple[int, int, int, str]] = None
172
+ local_subscribe_port: Optional[int] = None
173
+ remote_subscribe_port: Optional[int] = None
174
+
175
+
176
+ class MessageQueue:
177
+
178
+ def __init__(
179
+ self,
180
+ n_reader, # number of all readers
181
+ n_local_reader, # number of local readers through shared memory
182
+ local_reader_ranks: Optional[List[int]] = None,
183
+ max_chunk_bytes: int = 1024 * 1024 * 10,
184
+ max_chunks: int = 10,
185
+ connect_ip: Optional[str] = None,
186
+ ):
187
+ if local_reader_ranks is None:
188
+ local_reader_ranks = list(range(n_local_reader))
189
+ else:
190
+ assert len(local_reader_ranks) == n_local_reader
191
+ self.n_local_reader = n_local_reader
192
+ n_remote_reader = n_reader - n_local_reader
193
+ self.n_remote_reader = n_remote_reader
194
+
195
+ if connect_ip is None:
196
+ connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
197
+
198
+ context = Context()
199
+
200
+ if n_local_reader > 0:
201
+ # for local readers, we will:
202
+ # 1. create a shared memory ring buffer to communicate small data
203
+ # 2. create a publish-subscribe socket to communicate large data
204
+ self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
205
+ max_chunks)
206
+
207
+ # XPUB is very similar to PUB,
208
+ # except that it can receive subscription messages
209
+ # to confirm the number of subscribers
210
+ self.local_socket = context.socket(XPUB)
211
+ # set the verbose option so that we can receive every subscription
212
+ # message. otherwise, we will only receive the first subscription
213
+ # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
214
+ self.local_socket.setsockopt(XPUB_VERBOSE, True)
215
+ local_subscribe_port = get_open_port()
216
+ socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
217
+ logger.debug("Binding to %s", socket_addr)
218
+ self.local_socket.bind(socket_addr)
219
+
220
+ self.current_idx = 0
221
+
222
+ else:
223
+ self.buffer = None # type: ignore
224
+ local_subscribe_port = None
225
+ self.local_socket = None
226
+ self.current_idx = -1
227
+
228
+ if n_remote_reader > 0:
229
+ # for remote readers, we will:
230
+ # create a publish-subscribe socket to communicate large data
231
+ self.remote_socket = context.socket(XPUB)
232
+ self.remote_socket.setsockopt(XPUB_VERBOSE, True)
233
+ remote_subscribe_port = get_open_port()
234
+ if is_valid_ipv6_address(connect_ip):
235
+ self.remote_socket.setsockopt(IPV6, 1)
236
+ socket_addr = f"tcp://*:{remote_subscribe_port}"
237
+ self.remote_socket.bind(socket_addr)
238
+
239
+ else:
240
+ remote_subscribe_port = None
241
+ self.remote_socket = None
242
+
243
+ self._is_writer = True
244
+ self._is_local_reader = False
245
+ self.local_reader_rank = -1
246
+ # rank does not matter for remote readers
247
+ self._is_remote_reader = False
248
+
249
+ self.handle = Handle(
250
+ connect_ip=connect_ip,
251
+ local_reader_ranks=local_reader_ranks,
252
+ buffer_handle=self.buffer.handle()
253
+ if self.buffer is not None else None,
254
+ local_subscribe_port=local_subscribe_port,
255
+ remote_subscribe_port=remote_subscribe_port,
256
+ )
257
+
258
+ logger.info("vLLM message queue communication handle: %s", self.handle)
259
+
260
+ def export_handle(self) -> Handle:
261
+ return self.handle
262
+
263
+ @staticmethod
264
+ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
265
+ self = MessageQueue.__new__(MessageQueue)
266
+ self.handle = handle
267
+ self._is_writer = False
268
+
269
+ context = Context()
270
+
271
+ if rank in handle.local_reader_ranks:
272
+ assert handle.buffer_handle is not None
273
+ self.buffer = ShmRingBuffer(*handle.buffer_handle)
274
+ self.current_idx = 0
275
+ self.local_reader_rank = handle.local_reader_ranks.index(rank)
276
+ self._is_local_reader = True
277
+ self._is_remote_reader = False
278
+
279
+ self.local_socket = context.socket(SUB)
280
+ self.local_socket.setsockopt_string(SUBSCRIBE, "")
281
+ socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
282
+ logger.debug("Connecting to %s", socket_addr)
283
+ self.local_socket.connect(socket_addr)
284
+
285
+ self.remote_socket = None
286
+ else:
287
+ self.buffer = None # type: ignore
288
+ self.current_idx = -1
289
+ self.local_reader_rank = -1
290
+ self._is_local_reader = False
291
+ self._is_remote_reader = True
292
+
293
+ self.local_socket = None
294
+
295
+ self.remote_socket = context.socket(SUB)
296
+ self.remote_socket.setsockopt_string(SUBSCRIBE, "")
297
+ if is_valid_ipv6_address(handle.connect_ip):
298
+ self.remote_socket.setsockopt(IPV6, 1)
299
+ socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
300
+ logger.debug("Connecting to %s", socket_addr)
301
+ self.remote_socket.connect(socket_addr)
302
+
303
+ return self
304
+
305
+ def wait_until_ready(self):
306
+ """This is a collective operation. All processes (including the
307
+ readers and the writer) should call this function.
308
+ """
309
+ if self._is_writer:
310
+ # wait for all readers to connect
311
+
312
+ # local readers
313
+ for i in range(self.n_local_reader):
314
+ # wait for subscription messages from all local readers
315
+ self.local_socket.recv()
316
+ if self.n_local_reader > 0:
317
+ # send a message to all local readers
318
+ # to make sure the publish channel is working
319
+ self.local_socket.send(b"READY")
320
+
321
+ # remote readers
322
+ for i in range(self.n_remote_reader):
323
+ # wait for subscription messages from all remote readers
324
+ self.remote_socket.recv()
325
+ if self.n_remote_reader > 0:
326
+ # send a message to all remote readers
327
+ # to make sure the publish channel is working
328
+ self.remote_socket.send(b"READY")
329
+ elif self._is_local_reader:
330
+ # wait for the writer to send a message
331
+ recv = self.local_socket.recv()
332
+ assert recv == b"READY"
333
+ elif self._is_remote_reader:
334
+ # wait for the writer to send a message
335
+ recv = self.remote_socket.recv()
336
+ assert recv == b"READY"
337
+
338
+ @contextmanager
339
+ def acquire_write(self, timeout: Optional[float] = None):
340
+ assert self._is_writer, "Only writers can acquire write"
341
+ start_time = time.monotonic()
342
+ n_warning = 1
343
+ while True:
344
+ with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
345
+ read_count = sum(metadata_buffer[1:])
346
+ written_flag = metadata_buffer[0]
347
+ if written_flag and read_count != self.buffer.n_reader:
348
+ # this block is written and not read by all readers
349
+ # for writers, `self.current_idx` is the next block to write
350
+ # if this block is not ready to write,
351
+ # we need to wait until it is read by all readers
352
+
353
+ # Release the processor to other threads
354
+ sched_yield()
355
+
356
+ # if we wait for a long time, log a message
357
+ if (time.monotonic() - start_time
358
+ > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
359
+ logger.debug("No available block found in %s second. ",
360
+ VLLM_RINGBUFFER_WARNING_INTERVAL)
361
+ n_warning += 1
362
+
363
+ # if we time out, raise an exception
364
+ if (timeout is not None
365
+ and time.monotonic() - start_time > timeout):
366
+ raise TimeoutError
367
+
368
+ continue
369
+ # found a block that is either
370
+ # (1) not written
371
+ # (2) read by all readers
372
+
373
+ # mark the block as not written
374
+ metadata_buffer[0] = 0
375
+ # let caller write to the buffer
376
+ with self.buffer.get_data(self.current_idx) as buf:
377
+ yield buf
378
+
379
+ # caller has written to the buffer
380
+ # NOTE: order is important here
381
+ # first set the read flags to 0
382
+ # then set the written flag to 1
383
+ # otherwise, the readers may think they already read the block
384
+ for i in range(1, self.buffer.n_reader + 1):
385
+ # set read flag to 0, meaning it is not read yet
386
+ metadata_buffer[i] = 0
387
+ # mark the block as written
388
+ metadata_buffer[0] = 1
389
+ self.current_idx = (self.current_idx +
390
+ 1) % self.buffer.max_chunks
391
+ break
392
+
393
+ @contextmanager
394
+ def acquire_read(self, timeout: Optional[float] = None):
395
+ assert self._is_local_reader, "Only readers can acquire read"
396
+ start_time = time.monotonic()
397
+ n_warning = 1
398
+ while True:
399
+ with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
400
+ read_flag = metadata_buffer[self.local_reader_rank + 1]
401
+ written_flag = metadata_buffer[0]
402
+ if not written_flag or read_flag:
403
+ # this block is either
404
+ # (1) not written
405
+ # (2) already read by this reader
406
+
407
+ # for readers, `self.current_idx` is the next block to read
408
+ # if this block is not ready,
409
+ # we need to wait until it is written
410
+
411
+ # Release the processor to other threads
412
+ sched_yield()
413
+
414
+ # if we wait for a long time, log a message
415
+ if (time.monotonic() - start_time
416
+ > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
417
+ logger.debug("No available block found in %s second. ",
418
+ VLLM_RINGBUFFER_WARNING_INTERVAL)
419
+ n_warning += 1
420
+
421
+ # if we time out, raise an exception
422
+ if (timeout is not None
423
+ and time.monotonic() - start_time > timeout):
424
+ raise TimeoutError
425
+
426
+ continue
427
+ # found a block that is not read by this reader
428
+ # let caller read from the buffer
429
+ with self.buffer.get_data(self.current_idx) as buf:
430
+ yield buf
431
+
432
+ # caller has read from the buffer
433
+ # set the read flag
434
+ metadata_buffer[self.local_reader_rank + 1] = 1
435
+ self.current_idx = (self.current_idx +
436
+ 1) % self.buffer.max_chunks
437
+ break
438
+
439
+ def enqueue(self, obj, timeout: Optional[float] = None):
440
+ """ Write to message queue with optional timeout (in seconds) """
441
+ assert self._is_writer, "Only writers can enqueue"
442
+ serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
443
+ if self.n_local_reader > 0:
444
+ if len(serialized_obj) >= self.buffer.max_chunk_bytes:
445
+ with self.acquire_write(timeout) as buf:
446
+ buf[0] = 1 # overflow
447
+ self.local_socket.send(serialized_obj)
448
+ else:
449
+ with self.acquire_write(timeout) as buf:
450
+ buf[0] = 0 # not overflow
451
+ buf[1:len(serialized_obj) + 1] = serialized_obj
452
+ if self.n_remote_reader > 0:
453
+ self.remote_socket.send(serialized_obj)
454
+
455
+ def dequeue(self, timeout: Optional[float] = None):
456
+ """ Read from message queue with optional timeout (in seconds) """
457
+ if self._is_local_reader:
458
+ with self.acquire_read(timeout) as buf:
459
+ overflow = buf[0] == 1
460
+ if not overflow:
461
+ # no need to know the size of serialized object
462
+ # pickle format contains the size information internally
463
+ # see https://docs.python.org/3/library/pickle.html
464
+ obj = pickle.loads(buf[1:])
465
+ if overflow:
466
+ recv = self.local_socket.recv()
467
+ obj = pickle.loads(recv)
468
+ elif self._is_remote_reader:
469
+ recv = self.remote_socket.recv()
470
+ obj = pickle.loads(recv)
471
+ else:
472
+ raise RuntimeError("Only readers can dequeue")
473
+ return obj
474
+
475
+ def broadcast_object(self, obj=None):
476
+ if self._is_writer:
477
+ self.enqueue(obj)
478
+ return obj
479
+ else:
480
+ return self.dequeue()
481
+
482
+ @staticmethod
483
+ def create_from_process_group(pg: Union[ProcessGroup,
484
+ StatelessProcessGroup],
485
+ max_chunk_bytes,
486
+ max_chunks,
487
+ writer_rank=0) -> "MessageQueue":
488
+ if isinstance(pg, ProcessGroup):
489
+ group_rank = dist.get_rank(pg)
490
+ group_world_size = dist.get_world_size(pg)
491
+ global_ranks = dist.get_process_group_ranks(pg)
492
+ else:
493
+ group_rank = pg.rank
494
+ group_world_size = pg.world_size
495
+ global_ranks = list(range(pg.world_size))
496
+
497
+ from vllm.distributed.parallel_state import in_the_same_node_as
498
+ status = in_the_same_node_as(pg, source_rank=writer_rank)
499
+ same_node_ranks = [i for i, s in enumerate(status) if s]
500
+ n_reader = group_world_size - 1
501
+ n_local_reader = len(same_node_ranks) - 1
502
+ local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
503
+ buffer_io: MessageQueue
504
+ if group_rank == writer_rank:
505
+ buffer_io = MessageQueue(
506
+ n_reader=n_reader,
507
+ n_local_reader=n_local_reader,
508
+ local_reader_ranks=local_reader_ranks,
509
+ max_chunk_bytes=max_chunk_bytes,
510
+ max_chunks=max_chunks,
511
+ )
512
+ handle = buffer_io.export_handle()
513
+ if isinstance(pg, ProcessGroup):
514
+ dist.broadcast_object_list([handle],
515
+ src=global_ranks[writer_rank],
516
+ group=pg)
517
+ else:
518
+ pg.broadcast_obj(handle, writer_rank)
519
+ else:
520
+ if isinstance(pg, ProcessGroup):
521
+ recv = [None]
522
+ dist.broadcast_object_list(recv,
523
+ src=global_ranks[writer_rank],
524
+ group=pg)
525
+ handle = recv[0] # type: ignore
526
+ else:
527
+ handle = pg.broadcast_obj(None, writer_rank)
528
+ buffer_io = MessageQueue.create_from_handle(handle, group_rank)
529
+ buffer_io.wait_until_ready()
530
+ return buffer_io
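A hedged sketch of the intended writer/reader flow; `pg` is assumed to be an already-initialized gloo ProcessGroup shared by every participant:

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

mq = MessageQueue.create_from_process_group(pg,
                                            max_chunk_bytes=1024 * 1024,
                                            max_chunks=10,
                                            writer_rank=0)
# create_from_process_group already calls wait_until_ready() on every rank.
if mq._is_writer:
    mq.enqueue({"step": 1, "payload": [0, 1, 2]})  # small objects travel via shm
else:
    obj = mq.dequeue()                             # readers block until written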
.venv/lib/python3.11/site-packages/vllm/distributed/device_communicators/xpu_communicator.py ADDED
@@ -0,0 +1,49 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ from torch.distributed import ProcessGroup
6
+
7
+ from vllm.platforms import current_platform
8
+
9
+
10
+ class XpuCommunicator:
11
+
12
+ def __init__(self, group: ProcessGroup):
13
+ if not current_platform.is_xpu():
14
+ self.disabled = True
15
+ return
16
+ self.disabled = False
17
+ self.group = group
18
+ self.world_size = dist.get_world_size(self.group)
19
+
20
+ def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
21
+ dist.all_reduce(x, group=self.group)
22
+ return x
23
+
24
+ def gather(self,
25
+ input_: torch.Tensor,
26
+ rank_in_group: int,
27
+ dst: int = 0,
28
+ dim: int = -1):
29
+ # For the XPU path, gather doesn't work properly together with Ray
30
+ # clusters, so we use all_gather instead for now.
31
+ input_size = input_.size()
32
+ # Allocate output tensor.
33
+ output_tensor = torch.empty((self.world_size, ) + input_size,
34
+ dtype=input_.dtype,
35
+ device=input_.device)
36
+ # All-gather.
37
+ torch.distributed.all_gather_into_tensor(output_tensor,
38
+ input_,
39
+ group=self.group)
40
+ if rank_in_group == dst:
41
+ # Reshape
42
+ output_tensor = output_tensor.movedim(0, dim)
43
+ output_tensor = output_tensor.reshape(input_size[:dim] +
44
+ (self.world_size *
45
+ input_size[dim], ) +
46
+ input_size[dim + 1:])
47
+ else:
48
+ output_tensor = None
49
+ return output_tensor
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (201 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/__pycache__/kv_transfer_agent.cpython-311.pyc ADDED
Binary file (3.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (214 Bytes). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-311.pyc ADDED
Binary file (5.36 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-311.pyc ADDED
Binary file (2.78 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/__pycache__/simple_connector.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/base.py ADDED
@@ -0,0 +1,123 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ KVConnectorBase Class for Distributed KV Cache & Hidden State communication
4
+
5
+ The class provides two primary abstract methods:
6
+ 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
7
+ 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ from typing import TYPE_CHECKING, List, Tuple, Union
12
+
13
+ import torch
14
+
15
+ from vllm.sequence import IntermediateTensors
16
+
17
+ if TYPE_CHECKING:
18
+ from vllm.config import VllmConfig
19
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
20
+
21
+
22
+ class KVConnectorBase(ABC):
23
+ """
24
+ Abstract base class for a KV connector.
25
+
26
+ The class provides two primary abstract methods:
27
+ 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
28
+ 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
29
+ """
30
+
31
+ @abstractmethod
32
+ def __init__(
33
+ self,
34
+ rank: int,
35
+ local_rank: int,
36
+ config: "VllmConfig",
37
+ ):
38
+ raise NotImplementedError
39
+
40
+ @abstractmethod
41
+ def close(self) -> None:
42
+ """Close the buffer and release resources.
43
+
44
+ This method is responsible for cleaning up resources related to the
45
+ connector when it is no longer needed.
46
+
47
+ Raises:
48
+ NotImplementedError: This method must be implemented in subclasses.
49
+ """
50
+ raise NotImplementedError
51
+
52
+ @abstractmethod
53
+ def send_kv_caches_and_hidden_states(
54
+ self,
55
+ model_executable: torch.nn.Module,
56
+ model_input: "ModelInputForGPUWithSamplingMetadata",
57
+ kv_caches: List[torch.Tensor],
58
+ hidden_or_intermediate_states: Union[torch.Tensor,
59
+ IntermediateTensors],
60
+ ) -> None:
61
+ """
62
+ Send KV caches and hidden states to the connector.
63
+
64
+ This method processes the input tokens, KV caches, and
65
+ hidden/intermediate states for a given model and sends the data to the
66
+ decode instance.
67
+
68
+ Args:
69
+ model_executable (torch.nn.Module): The model executable containing
70
+ start and end layer information.
71
+ model_input (ModelInputForGPUWithSamplingMetadata): The input
72
+ metadata from vLLM.
73
+ kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
74
+ for each layer.
75
+ hidden_or_intermediate_states (Union[torch.Tensor,
76
+ IntermediateTensors]):
77
+ The hidden or intermediate states associated with the tokens.
78
+
79
+ Returns:
80
+ None
81
+
82
+ """
83
+
84
+ raise NotImplementedError
85
+
86
+ @abstractmethod
87
+ def recv_kv_caches_and_hidden_states(
88
+ self, model_executable: torch.nn.Module,
89
+ model_input: "ModelInputForGPUWithSamplingMetadata",
90
+ kv_caches: List[torch.Tensor]
91
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
92
+ "ModelInputForGPUWithSamplingMetadata"]:
93
+ """
94
+ Receive KV caches and hidden states from the connector.
95
+
96
+ This method attempts to retrieve KV caches and hidden states for input
97
+ tokens. If all required KV caches and hidden states are received, it
98
+ will bypass model execution, else it will fall back to normal vLLM model
99
+ forwarding.
100
+
101
+ Args:
102
+ model_executable (torch.nn.Module):
103
+ The model executable from vLLM modelrunner.
104
+ model_input (ModelInputForGPUWithSamplingMetadata):
105
+ The model input from vLLM modelrunner.
106
+ kv_caches (List[torch.Tensor]):
107
+ List of KV caches for each layer.
108
+
109
+ Returns:
110
+ - hidden_or_intermediate_states (torch.Tensor or
111
+ IntermediateTensors):
112
+ Concatenated hidden states if all required data is retrieved,
113
+ otherwise `None`.
114
+ - bypass_model_exec (bool):
115
+ Indicates whether the model execution can be skipped (True) or
116
+ needs to be redone (False).
117
+ - model_input (ModelInputForGPUWithSamplingMetadata):
118
+ Optionally adjusted input metadata for re-execution when
119
+ `bypass_model_exec=False`.
120
+
121
+ """
122
+
123
+ raise NotImplementedError
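To make the contract above concrete, here is a minimal sketch of a subclass; `NoOpConnector` is purely hypothetical and not part of vLLM, it only satisfies the four abstract methods.

```
from typing import List, Union

import torch

from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.sequence import IntermediateTensors


class NoOpConnector(KVConnectorBase):
    """Hypothetical connector that never transfers any KV cache."""

    def __init__(self, rank: int, local_rank: int, config) -> None:
        self.rank = rank
        self.local_rank = local_rank
        self.config = config

    def close(self) -> None:
        # Nothing to release: no pipes or buffers were created.
        pass

    def send_kv_caches_and_hidden_states(
            self, model_executable: torch.nn.Module, model_input,
            kv_caches: List[torch.Tensor],
            hidden_or_intermediate_states: Union[torch.Tensor,
                                                 IntermediateTensors]) -> None:
        # A real connector would push KV blocks and hidden states here.
        pass

    def recv_kv_caches_and_hidden_states(
            self, model_executable: torch.nn.Module, model_input,
            kv_caches: List[torch.Tensor]):
        # Nothing was received, so the caller must run the model as usual.
        return None, False, model_input
```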
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/factory.py ADDED
@@ -0,0 +1,50 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import importlib
4
+ from typing import TYPE_CHECKING, Callable, Dict, Type
5
+
6
+ from .base import KVConnectorBase
7
+
8
+ if TYPE_CHECKING:
9
+ from vllm.config import VllmConfig
10
+
11
+
12
+ class KVConnectorFactory:
13
+ _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
14
+
15
+ @classmethod
16
+ def register_connector(cls, name: str, module_path: str,
17
+ class_name: str) -> None:
18
+ """Register a connector with a lazy-loading module and class name."""
19
+ if name in cls._registry:
20
+ raise ValueError(f"Connector '{name}' is already registered.")
21
+
22
+ def loader() -> Type[KVConnectorBase]:
23
+ module = importlib.import_module(module_path)
24
+ return getattr(module, class_name)
25
+
26
+ cls._registry[name] = loader
27
+
28
+ @classmethod
29
+ def create_connector(cls, rank: int, local_rank: int,
30
+ config: "VllmConfig") -> KVConnectorBase:
31
+ connector_name = config.kv_transfer_config.kv_connector
32
+ if connector_name not in cls._registry:
33
+ raise ValueError(f"Unsupported connector type: {connector_name}")
34
+
35
+ connector_cls = cls._registry[connector_name]()
36
+ return connector_cls(rank, local_rank, config)
37
+
38
+
39
+ # Register various connectors here.
40
+ # The registration should not be done in each individual file, as we want to
41
+ # only load the files corresponding to the current connector.
42
+ KVConnectorFactory.register_connector(
43
+ "PyNcclConnector",
44
+ "vllm.distributed.kv_transfer.kv_connector.simple_connector",
45
+ "SimpleConnector")
46
+
47
+ KVConnectorFactory.register_connector(
48
+ "MooncakeConnector",
49
+ "vllm.distributed.kv_transfer.kv_connector.simple_connector",
50
+ "SimpleConnector")
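Based on the registry above, registering and instantiating an out-of-tree connector would presumably look like the sketch below; the name, module path, and class are made-up placeholders.

```
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory

# Lazy registration: the module is imported only when create_connector() runs.
KVConnectorFactory.register_connector(
    "MyConnector",                     # hypothetical connector name
    "my_package.my_connector_module",  # hypothetical module path
    "MyConnector")                     # hypothetical class inside that module

# Later, given a VllmConfig whose kv_transfer_config.kv_connector == "MyConnector":
# connector = KVConnectorFactory.create_connector(rank=0, local_rank=0, config=vllm_config)
```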
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_connector/simple_connector.py ADDED
@@ -0,0 +1,314 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Simple KV Cache Connector for Distributed Machine Learning Inference
4
+
5
+ The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
6
+ producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
7
+ MooncakePipe.
8
+
9
+ But the logic can be extended to support other pipes and lookup buffers.
10
+ """
11
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
12
+
13
+ import torch
14
+
15
+ from vllm import _custom_ops as ops
16
+ from vllm.config import VllmConfig
17
+ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
18
+ from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
19
+ SimpleBuffer)
20
+ from vllm.logger import init_logger
21
+ from vllm.sequence import IntermediateTensors
22
+
23
+ if TYPE_CHECKING:
24
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
+ class SimpleConnector(KVConnectorBase):
30
+
31
+ def __init__(
32
+ self,
33
+ rank: int,
34
+ local_rank: int,
35
+ config: VllmConfig,
36
+ ):
37
+
38
+ self.config = config.kv_transfer_config
39
+ self.tp_size = config.parallel_config.tensor_parallel_size
40
+
41
+ if self.config.kv_connector == "PyNcclConnector":
42
+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
43
+ PyNcclPipe)
44
+ logger.info(
45
+ "Initializing PyNcclConfig under kv_transfer_config %s",
46
+ self.config)
47
+ elif self.config.kv_connector == "MooncakeConnector":
48
+ # Check if MOONCAKE_CONFIG_PATH is set
49
+ import os
50
+ use_mooncake_distributed_pipe = os.getenv(
51
+ 'MOONCAKE_CONFIG_PATH') is not None
52
+
53
+ if not use_mooncake_distributed_pipe:
54
+ raise ValueError(
55
+ "To use MooncakeConnector, you need to pass the ENV: "
56
+ "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.")
57
+ else:
58
+ from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501
59
+ MooncakePipe)
60
+ logger.info(
61
+ "Initializing MooncakeConfig under kv_transfer_config %s",
62
+ self.config)
63
+
64
+ self.lookup_buffer_size = self.config.kv_buffer_size
65
+
66
+ self.producer_buffer: Optional[SimpleBuffer] = None
67
+ self.consumer_buffer: Optional[SimpleBuffer] = None
68
+
69
+ self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe]
70
+ self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe]
71
+ self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
72
+ self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
73
+
74
+ # 2 pipes for every rank in the world
75
+ port_offset_base = 2 * rank
76
+
77
+ # In disaggregated prefill, the prefill vLLM only uses send pipe
78
+ # and the decode vLLM only uses recv pipe
79
+ if self.config.is_kv_producer:
80
+
81
+ if self.config.kv_connector == "PyNcclConnector":
82
+ self.producer_data_pipe = PyNcclPipe(
83
+ local_rank=local_rank,
84
+ config=self.config,
85
+ port_offset=port_offset_base,
86
+ )
87
+ self.producer_signal_pipe = PyNcclPipe(
88
+ local_rank=local_rank,
89
+ config=self.config,
90
+ port_offset=port_offset_base + 1,
91
+ device="cpu",
92
+ )
93
+ elif self.config.kv_connector == "MooncakeConnector":
94
+ self.producer_data_pipe = MooncakePipe(
95
+ local_rank=local_rank,
96
+ config=self.config,
97
+ )
98
+ # We only need to initialize MooncakePipe once
99
+ self.producer_signal_pipe = self.producer_data_pipe
100
+
101
+ self.producer_buffer = SimpleBuffer(self.producer_signal_pipe,
102
+ self.producer_data_pipe,
103
+ self.config.kv_buffer_size)
104
+
105
+ else:
106
+
107
+ # the current vLLM instance is KV consumer, so it needs to connect
108
+ # its recv pipe to the send pipe of KV producer
109
+ if self.config.kv_connector == "PyNcclConnector":
110
+ self.consumer_data_pipe = PyNcclPipe(
111
+ local_rank=local_rank,
112
+ config=self.config,
113
+ port_offset=port_offset_base,
114
+ )
115
+ self.consumer_signal_pipe = PyNcclPipe(
116
+ local_rank=local_rank,
117
+ config=self.config,
118
+ port_offset=port_offset_base + 1,
119
+ device="cpu",
120
+ )
121
+ elif self.config.kv_connector == "MooncakeConnector":
122
+ self.consumer_data_pipe = MooncakePipe(
123
+ local_rank=local_rank,
124
+ config=self.config,
125
+ )
126
+ self.consumer_signal_pipe = self.consumer_data_pipe
127
+
128
+ self.consumer_buffer = SimpleBuffer(
129
+ self.consumer_signal_pipe,
130
+ self.consumer_data_pipe,
131
+ self.config.kv_buffer_size,
132
+ )
133
+
134
+ def select(self, input_tokens: Optional[torch.Tensor],
135
+ roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
136
+
137
+ assert self.consumer_buffer is not None, "Please initialize the "\
138
+ "consumer buffer before calling select."
139
+ return self.consumer_buffer.drop_select(input_tokens, roi)
140
+
141
+ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
142
+ key: torch.Tensor, value: torch.Tensor,
143
+ hidden: torch.Tensor) -> None:
144
+
145
+ assert self.producer_buffer is not None, "Please initialize the "\
146
+ "producer buffer before calling insert."
147
+
148
+ self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
149
+
150
+ def send_kv_caches_and_hidden_states(
151
+ self,
152
+ model_executable: torch.nn.Module,
153
+ model_input: "ModelInputForGPUWithSamplingMetadata",
154
+ kv_caches: List[torch.Tensor],
155
+ hidden_or_intermediate_states: Union[torch.Tensor,
156
+ IntermediateTensors],
157
+ ) -> None:
158
+
159
+ input_tokens_tensor = model_input.input_tokens
160
+ seq_lens = model_input.attn_metadata.seq_lens
161
+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
162
+ start_layer = model_executable.model.start_layer
163
+ end_layer = model_executable.model.end_layer
164
+
165
+ model_config = model_executable.model.config
166
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
167
+ hidden_size = model_config.hidden_size
168
+ num_attention_heads = model_config.num_attention_heads
169
+ head_size = int(hidden_size / num_attention_heads)
170
+
171
+ # query_lens contains new KV caches that are added to vLLM.
172
+ # so we will send them to decode instance
173
+ # FIXME(Kuntai): This assumes that all requests are prefill.
174
+ for idx, slen in enumerate(seq_lens):
175
+ start_pos = sum(seq_lens[:idx])
176
+ end_pos = start_pos + slen
177
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
178
+
179
+ keys, values = [], []
180
+
181
+ for layer_id in range(start_layer, end_layer):
182
+ kv_cache = kv_caches[layer_id - start_layer]
183
+
184
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
185
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
186
+
187
+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
188
+
189
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
190
+ values.append(value_cache[current_slot_mapping].unsqueeze(0))
191
+
192
+ keys = torch.cat(keys, dim=0)
193
+ values = torch.cat(values, dim=0)
194
+
195
+ self.insert(current_tokens,
196
+ torch.ones_like(current_tokens,
197
+ dtype=bool), keys, values,
198
+ hidden_or_intermediate_states[start_pos:end_pos])
199
+
200
+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
201
+
202
+ def recv_kv_caches_and_hidden_states(
203
+ self, model_executable: torch.nn.Module,
204
+ model_input: "ModelInputForGPUWithSamplingMetadata",
205
+ kv_caches: List[torch.Tensor]
206
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
207
+ "ModelInputForGPUWithSamplingMetadata"]:
208
+
209
+ # When bypass_model_exec is set to False, it means that at least for one
210
+ # request its corresponding KV cache or hidden state is missing.
211
+ # In this case we need to do prefilling to recompute missing KV cache
212
+ # and hidden states.
213
+ bypass_model_exec = True
214
+
215
+ input_tokens_tensor = model_input.input_tokens
216
+ seq_lens = model_input.attn_metadata.seq_lens
217
+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
218
+
219
+ hidden_or_intermediate_states_for_one_req = []
220
+
221
+ input_tokens_list = []
222
+ num_computed_tokens_list = []
223
+ start_pos_list = []
224
+
225
+ # enumerate different requests
226
+ # FIXME(Kuntai): This impl assumes that all requests are prefill.
227
+ for idx, slen in enumerate(seq_lens):
228
+
229
+ start_pos = sum(seq_lens[:idx])
230
+ end_pos = start_pos + slen
231
+ current_tokens = input_tokens_tensor[start_pos:end_pos]
232
+ num_tokens = slen
233
+
234
+ # collecting data for rebuilding the input
235
+ input_tokens_list.append(current_tokens)
236
+ start_pos_list.append(start_pos)
237
+
238
+ ret = self.select(current_tokens,
239
+ torch.ones_like(current_tokens, dtype=bool))
240
+ if ret[0] is None:
241
+ # didn't find any match.
242
+ bypass_model_exec = False
243
+ num_computed_tokens_list.append(0)
244
+ continue
245
+
246
+ roi: torch.Tensor = ret[1]
247
+ keys: torch.Tensor = ret[2]
248
+ values: torch.Tensor = ret[3]
249
+ hidden: torch.Tensor = ret[4]
250
+
251
+ num_computed_tokens = roi.shape[0]
252
+ num_computed_tokens_list.append(num_computed_tokens)
253
+
254
+ # check if both KV cache and the hidden states are received
255
+ # If not, need to redo the forwarding to compute missing states
256
+ if not all([(num_computed_tokens == num_tokens), hidden is not None
257
+ ]):
258
+ bypass_model_exec = False
259
+
260
+ # update the end position based on how many tokens are cached.
261
+ end_pos = start_pos + num_computed_tokens
262
+
263
+ # put received KV caches into paged memory
264
+ for i in range(model_executable.model.start_layer,
265
+ model_executable.model.end_layer):
266
+
267
+ kv_cache = kv_caches[i - model_executable.model.start_layer]
268
+ layer = model_executable.model.layers[i]
269
+
270
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
271
+ ops.reshape_and_cache_flash(
272
+ keys[i - model_executable.model.start_layer].to(
273
+ key_cache.device),
274
+ values[i - model_executable.model.start_layer].to(
275
+ value_cache.device),
276
+ key_cache,
277
+ value_cache,
278
+ slot_mapping[start_pos:end_pos],
279
+ layer.self_attn.attn.kv_cache_dtype,
280
+ layer.self_attn.attn._k_scale,
281
+ layer.self_attn.attn._v_scale,
282
+ )
283
+
284
+ hidden_or_intermediate_states_for_one_req.append(hidden)
285
+
286
+ if not bypass_model_exec:
287
+ # Some of the KV cache is not retrieved
288
+ # Here we will fall back to normal model forwarding
289
+ # But optionally you can adjust model_input so that you only do
290
+ # prefilling on those tokens that are missing KV caches.
291
+ logger.debug(
292
+ "[rank%d]: Failed to receive all KVs and hidden "
293
+ "states, redo model forwarding.", torch.distributed.get_rank())
294
+ hidden_or_intermediate_states = None
295
+
296
+ else:
297
+ logger.debug(
298
+ "[rank%d]: Successfully received all KVs and hidden "
299
+ "states, skip model forwarding.", torch.distributed.get_rank())
300
+ hidden_or_intermediate_states = torch.cat(
301
+ hidden_or_intermediate_states_for_one_req, dim=0)
302
+
303
+ return hidden_or_intermediate_states, bypass_model_exec, model_input
304
+
305
+ def close(self):
306
+ self.producer_data_pipe.close()
307
+ self.consumer_data_pipe.close()
308
+ if self.config.kv_connector == "PyNcclConnector":
309
+ self.producer_signal_pipe.close()
310
+ self.consumer_signal_pipe.close()
311
+ elif self.config.kv_connector == "MooncakeConnector":
312
+ # MooncakePipe reuses data_pipe for signal_pipe, so we only have to
313
+ # close the data_pipe.
314
+ pass
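The per-request gather in `send_kv_caches_and_hidden_states` above amounts to indexing a flattened paged cache with each request's slot mapping; a toy-sized sketch of just that step (all shapes and values are arbitrary):

```
import torch

num_blocks, block_size, num_heads, head_size = 4, 2, 2, 8
# Paged key cache as a vLLM-style backend might store it, then flattened per slot.
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size)
key_cache_flat = key_cache.reshape(-1, num_heads, head_size)  # (num_slots, heads, head)

seq_lens = [3, 2]                             # two prefill requests
slot_mapping = torch.tensor([0, 1, 4, 6, 7])  # one physical slot per token

start_pos = 0
for slen in seq_lens:
    current_slots = slot_mapping[start_pos:start_pos + slen]
    keys_for_request = key_cache_flat[current_slots]  # (slen, heads, head)
    print(keys_for_request.shape)
    start_pos += slen
```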
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (218 Bytes).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/base.cpython-311.pyc ADDED
Binary file (5.13 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/__pycache__/simple_buffer.cpython-311.pyc ADDED
Binary file (11.6 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py ADDED
@@ -0,0 +1,109 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ This file contains a new class `KVLookupBufferBase` that allows developers to
4
+ think of KV cache operations as inserting new KV cache entries (`insert`)
5
+ into the lookup buffer and querying existing KV caches (`drop_select`)
6
+ from the lookup buffer.
7
+
8
+ All distributed communications are abstracted behind this class.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import List, Optional
13
+
14
+ import torch
15
+
16
+
17
+ class KVLookupBufferBase(ABC):
18
+ """
19
+ Abstract base class for a lookup buffer.
20
+
21
+ This class provides an abstraction for a key-value (KV) cache lookup buffer.
22
+
23
+ The key of the lookup buffer:
24
+ - input_tokens: token IDs of the request
25
+ - roi: a binary mask on top of input_tokens.
26
+ - Purpose of roi: Since KV cache may only be available for a subset of
27
+ tokens in the input (for example, when vLLM is connected to an external
28
+ KV cache service), roi specifies the subset of tokens that the KV cache
29
+ is associated with.
30
+ - NOTE: roi can be further extended to describe which part of KV the
31
+ current process is holding (each process may only hold a part of KV
32
+ due to TP and PP). This is not implemented for now.
33
+
34
+ The value of the lookup buffer:
35
+ - key: the key tensor in the KV cache
36
+ - value: the value tensor in the KV cache
37
+ - hidden: the final hidden state generated by model forwarding. This allows
38
+ vLLM to bypass further model forwarding by transmitting the hidden state.
39
+ """
40
+
41
+ @abstractmethod
42
+ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
43
+ key: torch.Tensor, value: torch.Tensor,
44
+ hidden: torch.Tensor) -> None:
45
+ """Insert into the lookup buffer.
46
+
47
+ The functionality is similar to the following python statement
48
+ ```
49
+ buffer[input_tokens, roi] = [key, value, hidden]
50
+ ```
51
+
52
+ FIXME: in the future, we should only have two arguments, key and value,
53
+ where key is a tensor dict and value is a tensor dict.
54
+
55
+ FIXME: we should transmit both sampler outputs and the hidden states.
56
+
57
+ Args:
58
+ input_tokens (torch.Tensor): token IDs.
59
+ roi (torch.Tensor): A binary mask on top of the input tokens
60
+ key (torch.Tensor): The key tensor in the KV cache.
61
+ value (torch.Tensor): The value tensor in the KV cache.
62
+ hidden (torch.Tensor): The final hidden state tensor generated
63
+ during model forwarding to bypass model
64
+ forwarding.
65
+
66
+ Raises:
67
+ NotImplementedError: This method must be implemented in subclasses.
68
+ """
69
+ raise NotImplementedError
70
+
71
+ @abstractmethod
72
+ def drop_select(
73
+ self, input_tokens: Optional[torch.Tensor],
74
+ roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
75
+ """Select and *drop* KV cache entries from the lookup buffer.
76
+
77
+ The functionality is similar to the following python statements
78
+ ```
79
+ ret = buffer.pop(input_tokens, roi)
80
+ return ret
81
+ ```
82
+
83
+ If `input_tokens` and `roi` are `None`, it means selecting any of the
84
+ KV caches in the buffer, return, and remove it from the buffer, useful
85
+ when offloading KV cache to KV cache storage service.
86
+
87
+ Args:
88
+ input_tokens (torch.Tensor): token IDs.
89
+ roi (torch.Tensor): A binary mask on top of the input tokens
90
+
91
+ Returns:
92
+ List[Optional[torch.Tensor]]: A list of tensors. Can be None.
93
+
94
+ Raises:
95
+ NotImplementedError: This method must be implemented in subclasses.
96
+ """
97
+ raise NotImplementedError
98
+
99
+ @abstractmethod
100
+ def close(self) -> None:
101
+ """Close the buffer and release resources.
102
+
103
+ This method is responsible for cleaning up resources related to the
104
+ lookup buffer when it is no longer needed.
105
+
106
+ Raises:
107
+ NotImplementedError: This method must be implemented in subclasses.
108
+ """
109
+ raise NotImplementedError
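A toy, single-process analogue can illustrate the `insert`/`drop_select` semantics described above; this sketch ignores the distributed and prefix-matching details and is not how the real buffer works.

```
from typing import List, Optional, Tuple

import torch


class InMemoryLookupBuffer:
    """Toy analogue of KVLookupBufferBase: a pop-on-match store."""

    def __init__(self) -> None:
        self._store: List[Tuple[torch.Tensor, ...]] = []

    def insert(self, input_tokens, roi, key, value, hidden) -> None:
        self._store.append((input_tokens, roi, key, value, hidden))

    def drop_select(self, input_tokens: Optional[torch.Tensor],
                    roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
        for i, (tok, r, k, v, h) in enumerate(self._store):
            # None query means "give me anything"; otherwise match masked tokens.
            if input_tokens is None or torch.equal(tok[r], input_tokens[roi]):
                self._store.pop(i)
                return [tok, r, k, v, h]
        return [None, None, None, None, None]
```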
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py ADDED
@@ -0,0 +1,243 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Implements a distributed key-value (KV) cache transfer mechanism.
4
+
5
+ Key Features:
6
+ - Distributed KV cache transmission using PyNccl pipes.
7
+ - Non-blocking `insert`, blocking `drop_select`.
8
+ - Use a CPU signal pipe to avoid race conditions
9
+ - Handles buffer size constraints and provides a backpressure mechanism to
10
+ stop the prefill instance when the decode instance is slow.
11
+ """
12
+ import threading
13
+ import time
14
+ from collections import deque
15
+ from typing import Deque, List, Optional, Union
16
+
17
+ import torch
18
+
19
+ from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
20
+ KVLookupBufferBase)
21
+ from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
22
+ from vllm.logger import init_logger
23
+
24
+ logger = init_logger(__name__)
25
+
26
+
27
+ class SimpleBuffer(KVLookupBufferBase):
28
+
29
+ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
30
+ buffer_size_thresh: float):
31
+ """
32
+ signal_pipe: on CPU
33
+
34
+ NOTE: on-device recv will block all threads in the process, making the
35
+ KV cache producer unable to listen to new requests while transmitting
36
+ KV cache. Luckily CPU recv only blocks the current thread so we use
37
+ CPU recv to listen to new requests.
38
+
39
+ data_pipe: on device (e.g. GPU)
40
+ """
41
+
42
+ self.buffer: Deque[List[torch.Tensor]] = deque()
43
+
44
+ self.buffer_size = 0
45
+ self.buffer_size_threshold = buffer_size_thresh
46
+ self.buffer_lock = threading.Lock()
47
+ self.signal_pipe = signal_pipe
48
+ self.data_pipe = data_pipe
49
+ self.request_handling_thread: Optional[threading.Thread] = None
50
+
51
+ self.normal_signal = torch.tensor([0], device="cpu")
52
+ self.end_signal = None
53
+
54
+ def _matches(self, tokens_roi_sender: List[torch.Tensor],
55
+ tokens_roi_recver: List[torch.Tensor]):
56
+
57
+ # tokens_roi_sender: tokens and roi of the producer (in the buffer)
58
+ # tokens_roi_recver: tokens and roi of the consumer (query)
59
+
60
+ tokens_sender = tokens_roi_sender[0]
61
+ tokens_recver = tokens_roi_recver[0]
62
+ roi_sender = tokens_roi_sender[1]
63
+ roi_recver = tokens_roi_recver[1]
64
+
65
+ if tokens_recver is None:
66
+ # consumer sends an empty request
67
+ # semantics: DROP SELECT * LIMIT 1
68
+ # so any of the data in the buffer can be drop-selected
69
+ return True
70
+
71
+ # Assuming that roi is a binary mask on tokens
72
+ tokens_sender = tokens_sender[roi_sender]
73
+ tokens_recver = tokens_recver[roi_recver]
74
+
75
+ # simple common prefix matching
76
+ min_length = min(len(tokens_sender), len(tokens_recver))
77
+ if torch.allclose(tokens_sender[:min_length],
78
+ tokens_recver[:min_length]):
79
+ return min_length
80
+
81
+ return 0
82
+
83
+ def _send_tensor_and_dec_size(self,
84
+ tensor: Optional[torch.Tensor]) -> None:
85
+
86
+ assert tensor is not None, "Use self.data_pipe.send(None) instead"
87
+ self.buffer_size -= tensor.element_size() * tensor.numel()
88
+ if tensor.dtype == torch.bool:
89
+ tensor = tensor.float()
90
+ self.data_pipe.send_tensor(tensor)
91
+
92
+ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
93
+
94
+ if isinstance(data, torch.Tensor):
95
+ return data.element_size() * data.numel()
96
+ if not data:
97
+ # cannot perform `not data` on a tensor
98
+ # so this check needs to go after the check above
99
+ return 0
100
+
101
+ raise AssertionError(f"Unknown data type {type(data)}")
102
+
103
+ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
104
+ key: torch.Tensor, value: torch.Tensor,
105
+ hidden: torch.Tensor):
106
+
107
+ if isinstance(input_tokens, torch.Tensor):
108
+ input_tokens = input_tokens.clone()
109
+ if isinstance(roi, torch.Tensor):
110
+ roi = roi.clone()
111
+ if isinstance(key, torch.Tensor):
112
+ key = key.clone()
113
+ if isinstance(value, torch.Tensor):
114
+ value = value.clone()
115
+ if isinstance(hidden, torch.Tensor):
116
+ hidden = hidden.clone()
117
+
118
+ buffer_item = [input_tokens, roi, key, value, hidden]
119
+
120
+ with self.buffer_lock:
121
+ for data in buffer_item:
122
+ self.buffer_size += self._get_element_size(data)
123
+ self.buffer.append(buffer_item)
124
+
125
+ def _is_end_signal(self, signal):
126
+ return signal is None
127
+
128
+ def drop_select_handler(self):
129
+
130
+ try:
131
+
132
+ while True:
133
+ signal = self.signal_pipe.recv_tensor()
134
+ if self._is_end_signal(signal):
135
+ logger.info("Received end signal!")
136
+ break
137
+
138
+ input_tokens = self.data_pipe.recv_tensor()
139
+
140
+ roi = self.data_pipe.recv_tensor()
141
+ assert roi is not None, "Please provide the roi when sending "\
142
+ "drop-select request"
143
+ roi = (roi > 0.5)
144
+ tokens_roi_recver = [input_tokens, roi]
145
+
146
+ matched_length = 0
147
+
148
+ # perform input tokens and roi matching
149
+ # FIXME: this matching is O(n), ideally it should be O(1)
150
+ # but this buffer size won't (and shouldn't) be too large so
151
+ # the fix is not urgent.
152
+ with self.buffer_lock:
153
+
154
+ for _ in range(len(self.buffer)):
155
+
156
+ temp_length = self._matches(self.buffer[0],
157
+ tokens_roi_recver)
158
+ if temp_length > 0:
159
+ matched_length = temp_length
160
+ break
161
+ # rotate the element we just accessed to the end
162
+ self.buffer.rotate(-1)
163
+
164
+ if matched_length > 0:
165
+ # need to clone the tensor
166
+ # in case the tensor is freed before sending finishes
167
+ matched_item = self.buffer.popleft()
168
+ for tensor in matched_item:
169
+ self._send_tensor_and_dec_size(tensor)
170
+
171
+ else:
172
+ # no match, just send None
173
+ for _ in range(5):
174
+ self.data_pipe.send_tensor(None)
175
+
176
+ except RuntimeError as e:
177
+ if 'Connection closed by peer' not in str(e):
178
+ raise e
179
+
180
+ logger.debug("Closing drop_select_handler")
181
+
182
+ def drop_select(
183
+ self, input_tokens: Optional[torch.Tensor],
184
+ roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
185
+
186
+ assert self.request_handling_thread is None, \
187
+ "drop_select should be called by the KV cache consumer "\
188
+ "(e.g. the decode vLLM instance)"
189
+
190
+ if isinstance(input_tokens, torch.Tensor):
191
+ input_tokens = input_tokens.clone()
192
+ if isinstance(roi, torch.Tensor):
193
+ roi = roi.clone().float()
194
+
195
+ self.signal_pipe.send_tensor(self.normal_signal)
196
+ self.data_pipe.send_tensor(input_tokens)
197
+ self.data_pipe.send_tensor(roi)
198
+
199
+ input_tokens = self.data_pipe.recv_tensor()
200
+ roi = self.data_pipe.recv_tensor()
201
+ if roi is not None:
202
+ # convert from float tensor to bool tensor
203
+ # as PyNccl does not support sending bool tensor
204
+ roi = (roi > 0.5)
205
+ key = self.data_pipe.recv_tensor()
206
+ value = self.data_pipe.recv_tensor()
207
+ hidden = self.data_pipe.recv_tensor()
208
+
209
+ return [input_tokens, roi, key, value, hidden]
210
+
211
+ def full_handler(self):
212
+ time.sleep(0.001)
213
+
214
+ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
215
+ key: torch.Tensor, value: torch.Tensor,
216
+ hidden: torch.Tensor) -> None:
217
+
218
+ if self.buffer_size > self.buffer_size_threshold:
219
+ # log outside the while loop to avoid this message being logged
220
+ # repeatedly.
221
+ logger.debug("KV transfer buffer is full. Handling...")
222
+ while self.buffer_size > self.buffer_size_threshold:
223
+ self.full_handler()
224
+
225
+ self._add_to_buffer(input_tokens, roi, key, value, hidden)
226
+
227
+ # when calling the insert, the current process is a sender
228
+ # need to launch the request handler and start listening to requests.
229
+ if self.request_handling_thread is None:
230
+ self.request_handling_thread = threading.Thread(
231
+ target=self.drop_select_handler)
232
+ self.request_handling_thread.start()
233
+
234
+ def close(self):
235
+
236
+ if hasattr(self, "request_handling_thread"
237
+ ) and self.request_handling_thread is not None:
238
+ self.request_handling_thread.join()
239
+
240
+ else:
241
+ # TODO: have an explicit close signal and an explicit way to
242
+ # check if it's the requester
243
+ self.signal_pipe.send_tensor(self.end_signal)
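The common-prefix check in `_matches` above can be exercised in isolation; a rough sketch with made-up token IDs:

```
import torch

tokens_sender = torch.tensor([11, 12, 13, 14])
roi_sender = torch.ones_like(tokens_sender, dtype=torch.bool)
tokens_recver = torch.tensor([11, 12, 13])
roi_recver = torch.ones_like(tokens_recver, dtype=torch.bool)

# Apply the ROI masks, then compare the shared prefix, as _matches does.
masked_sender = tokens_sender[roi_sender]
masked_recver = tokens_recver[roi_recver]
min_length = min(len(masked_sender), len(masked_recver))
matched = torch.allclose(masked_sender[:min_length], masked_recver[:min_length])
print(min_length if matched else 0)  # -> 3
```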
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (209 Bytes).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-311.pyc ADDED
Binary file (2.95 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/mooncake_pipe.cpython-311.pyc ADDED
Binary file (18.3 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/__pycache__/pynccl_pipe.cpython-311.pyc ADDED
Binary file (13.8 kB).
 
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/base.py ADDED
@@ -0,0 +1,66 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ This file defines an interface `KVPipeBase`
4
+ that provides an abstraction for sending and receiving tensors, or None, via
5
+ distributed communications.
6
+
7
+ All classes instantiated from this interface are assumed to be a FIFO pipe.
8
+
9
+ If your distributed communication platform already supports key-value lookup,
10
+ you can bypass this interface and directly start from `kv_lookup_buffer`.
11
+ """
12
+
13
+ from abc import ABC, abstractmethod
14
+ from typing import Optional
15
+
16
+ import torch
17
+
18
+
19
+ class KVPipeBase(ABC):
20
+ """
21
+ This class provides an interface for sending and receiving tensors, or
22
+ None, by distributed communications.
23
+ """
24
+
25
+ @abstractmethod
26
+ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
27
+ """Send a tensor, or None, via the pipe.
28
+
29
+ Need to support sending None -- important for error handling.
30
+
31
+ TODO: add a `key` argument so that we can use traditional
32
+ key-value database as the distributed communication mechanism behind
33
+ the pipe.
34
+
35
+ Args:
36
+ tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
37
+
38
+ Raises:
39
+ NotImplementedError: This method must be implemented in subclasses.
40
+ """
41
+ raise NotImplementedError
42
+
43
+ @abstractmethod
44
+ def recv_tensor(self) -> Optional[torch.Tensor]:
45
+ """Receive a tensor (can be None) from the pipeline.
46
+
47
+ Returns:
48
+ Optional[torch.Tensor]: The tensor received from the pipeline. Can
49
+ be None.
50
+
51
+ Raises:
52
+ NotImplementedError: This method must be implemented in subclasses.
53
+ """
54
+ raise NotImplementedError
55
+
56
+ @abstractmethod
57
+ def close(self) -> None:
58
+ """Close the pipeline and release resources.
59
+
60
+ This method is responsible for closing the communication pipeline
61
+ and releasing any resources associated with it.
62
+
63
+ Raises:
64
+ NotImplementedError: This method must be implemented in subclasses.
65
+ """
66
+ raise NotImplementedError
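For intuition, the smallest thing that satisfies this interface is an in-process FIFO; the sketch below is purely illustrative and does no distributed communication at all.

```
import queue
from typing import Optional

import torch

from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase


class LocalQueuePipe(KVPipeBase):
    """Toy same-process pipe: a FIFO that also carries None."""

    def __init__(self) -> None:
        self._queue: "queue.Queue[Optional[torch.Tensor]]" = queue.Queue()

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        self._queue.put(tensor)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        return self._queue.get()

    def close(self) -> None:
        # Nothing to release for an in-process queue.
        pass
```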
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py ADDED
@@ -0,0 +1,274 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import json
4
+ import os
5
+ import pickle
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Union
9
+
10
+ import torch
11
+ import zmq
12
+
13
+ from vllm.config import KVTransferConfig
14
+ from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
15
+ from vllm.logger import init_logger
16
+
17
+ logger = init_logger(__name__)
18
+ NONE_INT = -150886311
19
+
20
+
21
+ @dataclass
22
+ class MooncakeTransferEngineConfig:
23
+ prefill_url: str
24
+ decode_url: str
25
+ metadata_backend: Union[str, None]
26
+ metadata_server: str
27
+ protocol: str
28
+ device_name: str
29
+
30
+ @staticmethod
31
+ def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
32
+ """Load the config from a JSON file."""
33
+ with open(file_path) as fin:
34
+ config = json.load(fin)
35
+ return MooncakeTransferEngineConfig(
36
+ prefill_url=config.get("prefill_url"),
37
+ decode_url=config.get("decode_url"),
38
+ metadata_backend=config.get("metadata_backend", None),
39
+ metadata_server=config.get("metadata_server"),
40
+ protocol=config.get("protocol", "tcp"),
41
+ device_name=config.get("device_name", ""),
42
+ )
43
+
44
+ @staticmethod
45
+ def load_from_env() -> 'MooncakeTransferEngineConfig':
46
+ """Load config from a file specified in the environment variable."""
47
+ config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
48
+ if config_file_path is None:
49
+ raise ValueError(
50
+ "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
51
+ return MooncakeTransferEngineConfig.from_file(config_file_path)
52
+
53
+
54
+ class MooncakeTransferEngine:
55
+ """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
56
+
57
+ def __init__(self, kv_rank: int, local_rank: int):
58
+ try:
59
+ import mooncake_vllm_adaptor as mva
60
+ except ImportError as e:
61
+ raise ImportError(
62
+ "Please install mooncake by following the instructions at "
63
+ "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
64
+ "to run vLLM with MooncakeConnector.") from e
65
+
66
+ self.engine = mva.mooncake_vllm_adaptor()
67
+ self.local_rank = local_rank
68
+
69
+ try:
70
+ self.config = MooncakeTransferEngineConfig.load_from_env()
71
+ logger.info("Mooncake Configuration loaded successfully.")
72
+ except ValueError as e:
73
+ logger.error(e)
74
+ raise
75
+ except Exception as exc:
76
+ logger.error(
77
+ "An error occurred while loading the configuration: %s", exc)
78
+ raise
79
+ prefill_host, base_prefill_port = self.config.prefill_url.split(':')
80
+ decode_host, base_decode_port = self.config.decode_url.split(':')
81
+
82
+ # Avoid ports conflict when running prefill and decode on the same node
83
+ if prefill_host == decode_host and \
84
+ base_prefill_port == base_decode_port:
85
+ base_decode_port = str(int(base_decode_port) + 100)
86
+
87
+ prefill_port = int(base_prefill_port) + self.local_rank
88
+ decode_port = int(base_decode_port) + self.local_rank
89
+ self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
90
+ self.decode_url = ':'.join([decode_host, str(decode_port)])
91
+
92
+ self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
93
+ self.config.metadata_server, self.config.protocol,
94
+ self.config.device_name, self.config.metadata_backend)
95
+
96
+ self.remote_url = (self.decode_url
97
+ if kv_rank == 0 else self.prefill_url)
98
+
99
+ # Initialize ZeroMQ context and sockets
100
+ self.context = zmq.Context() # type: ignore[attr-defined]
101
+ self.sender_socket = self.context.socket(zmq.constants.PUSH)
102
+ self.receiver_socket = self.context.socket(zmq.constants.PULL)
103
+ self.sender_ack = self.context.socket(zmq.constants.PULL)
104
+ self.receiver_ack = self.context.socket(zmq.constants.PUSH)
105
+
106
+ self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
107
+ self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
108
+ decode_host, base_decode_port)
109
+
110
+ def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
111
+ d_host: str, d_port: str) -> None:
112
+ """Set up ZeroMQ sockets for sending and receiving data."""
113
+ # Offsets < 8 are left for initialization in case tp and pp are enabled
114
+ p_rank_offset = int(p_port) + 8 + self.local_rank * 2
115
+ d_rank_offset = int(d_port) + 8 + self.local_rank * 2
116
+ if kv_rank == 0:
117
+ self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
118
+ self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
119
+ self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
120
+ self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
121
+ else:
122
+ self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
123
+ self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
124
+ self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
125
+ self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
126
+
127
+ def initialize(self, local_hostname: str, metadata_server: str,
128
+ protocol: str, device_name: str,
129
+ metadata_backend: Union[str, None]) -> None:
130
+ """Initialize the mooncake instance."""
131
+ if metadata_backend is None:
132
+ self.engine.initialize(local_hostname, metadata_server, protocol,
133
+ device_name)
134
+ else:
135
+ supported_backend = ["etcd", "redis"]
136
+ metadata_backend = metadata_backend.lower()
137
+ if metadata_backend not in supported_backend:
138
+ raise ValueError(
139
+ "Mooncake Configuration error. `metadata_backend` "
140
+ f"should be one of {supported_backend}.")
141
+
142
+ self.engine.initializeExt(local_hostname, metadata_server,
143
+ protocol, device_name, metadata_backend)
144
+
145
+ def allocate_managed_buffer(self, length: int) -> int:
146
+ """Allocate a managed buffer of the specified length."""
147
+ ret = self.engine.allocateManagedBuffer(length)
148
+ if ret <= 0:
149
+ logger.error("Allocation Return Error")
150
+ raise Exception("Allocation Return Error")
151
+ return ret
152
+
153
+ def free_managed_buffer(self, buffer: int, length: int) -> int:
154
+ """Free a previously allocated managed buffer."""
155
+ return self.engine.freeManagedBuffer(buffer, length)
156
+
157
+ def transfer_sync(self, buffer: int, peer_buffer_address: int,
158
+ length: int) -> int:
159
+ """Synchronously transfer data to the specified address."""
160
+ ret = self.engine.transferSync(self.remote_url, buffer,
161
+ peer_buffer_address, length)
162
+ if ret < 0:
163
+ logger.error("Transfer Return Error")
164
+ raise Exception("Transfer Return Error")
165
+ return ret
166
+
167
+ def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
168
+ length: int) -> int:
169
+ """Write bytes to the allocated buffer."""
170
+ return self.engine.writeBytesToBuffer(buffer, user_data, length)
171
+
172
+ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
173
+ """Read bytes from the allocated buffer."""
174
+ return self.engine.readBytesFromBuffer(buffer, length)
175
+
176
+ def wait_for_ack(self, src_ptr: int, length: int) -> None:
177
+ """Asynchronously wait for ACK from the receiver."""
178
+ ack = self.sender_ack.recv_pyobj()
179
+ if ack != b'ACK':
180
+ logger.error("Failed to receive ACK from the receiver")
181
+
182
+ self.free_managed_buffer(src_ptr, length)
183
+
184
+ def send_bytes(self, user_data: bytes) -> None:
185
+ """Send bytes to the remote process."""
186
+ length = len(user_data)
187
+ src_ptr = self.allocate_managed_buffer(length)
188
+ self.write_bytes_to_buffer(src_ptr, user_data, length)
189
+ self.sender_socket.send_pyobj((src_ptr, length))
190
+ self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
191
+
192
+ def recv_bytes(self) -> bytes:
193
+ """Receive bytes from the remote process."""
194
+ src_ptr, length = self.receiver_socket.recv_pyobj()
195
+ dst_ptr = self.allocate_managed_buffer(length)
196
+ self.transfer_sync(dst_ptr, src_ptr, length)
197
+ ret = self.read_bytes_from_buffer(dst_ptr, length)
198
+
199
+ # Buffer cleanup
200
+ self.receiver_ack.send_pyobj(b'ACK')
201
+ self.free_managed_buffer(dst_ptr, length)
202
+
203
+ return ret
204
+
205
+
206
+ class MooncakePipe(KVPipeBase):
207
+ """MooncakeTransferEngine based Pipe implementation."""
208
+
209
+ def __init__(self,
210
+ local_rank: int,
211
+ config: KVTransferConfig,
212
+ device: Optional[str] = None):
213
+ """Initialize the mooncake pipe and set related parameters."""
214
+ self.config = config
215
+ self.local_rank = local_rank
216
+ self.kv_rank = self.config.kv_rank
217
+ if device is None:
218
+ self.device = self._select_device(self.config.kv_buffer_device)
219
+ else:
220
+ self.device = self._select_device(device)
221
+
222
+ self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
223
+ self.local_rank)
224
+ self.transport_thread: Optional[ThreadPoolExecutor] = None
225
+ self.none_tensor = torch.tensor([NONE_INT], device=self.device)
226
+
227
+ def _select_device(self, device: str) -> torch.device:
228
+ """Select available device (CUDA or CPU)."""
229
+ logger.info("Selecting device: %s", device)
230
+ if device == "cuda":
231
+ return torch.device(f"cuda:{self.local_rank}")
232
+ else:
233
+ return torch.device("cpu")
234
+
235
+ def tensor_hash(self, tensor: torch.Tensor) -> int:
236
+ """Calculate the hash value of the tensor."""
237
+ return hash(tensor.data_ptr())
238
+
239
+ def _send_impl(self, tensor: torch.Tensor) -> None:
240
+ """Implement the tensor sending logic."""
241
+ value_bytes = pickle.dumps(tensor)
242
+ self.transfer_engine.send_bytes(value_bytes)
243
+
244
+ def _recv_impl(self) -> torch.Tensor:
245
+ """Implement the tensor receiving logic."""
246
+ data = self.transfer_engine.recv_bytes()
247
+ return pickle.loads(data)
248
+
249
+ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
250
+ """Send tensor to the target process."""
251
+ if self.transport_thread is None:
252
+ self.transport_thread = ThreadPoolExecutor(max_workers=1)
253
+ tensor = tensor if tensor is not None else self.none_tensor
254
+ assert (len(tensor.shape) > 0)
255
+ self.transport_thread.submit(self._send_impl, tensor)
256
+
257
+ def recv_tensor(self) -> Optional[torch.Tensor]:
258
+ """Receive tensor from other processes."""
259
+ if self.transport_thread is None:
260
+ self.transport_thread = ThreadPoolExecutor(max_workers=1)
261
+ tensor = self.transport_thread.submit(self._recv_impl).result()
262
+ if tensor.numel() == 1 and tensor.item() == NONE_INT:
263
+ return None
264
+ else:
265
+ return tensor
266
+
267
+ def close(self) -> None:
268
+ """Cleanup logic when closing the pipe."""
269
+ self.transfer_engine.sender_socket.close()
270
+ self.transfer_engine.receiver_socket.close()
271
+ self.transfer_engine.sender_ack.close()
272
+ self.transfer_engine.receiver_ack.close()
273
+ self.transfer_engine.context.term() # Terminate the ZMQ context
274
+ logger.info("Closed the transfer engine and cleaned up resources.")
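Per `MooncakeTransferEngineConfig.from_file` above, `MOONCAKE_CONFIG_PATH` should point at a small JSON file; the snippet below writes a plausible example (all addresses and values are placeholders, not a recommended deployment):

```
import json

example_config = {
    "prefill_url": "192.0.2.1:9000",      # placeholder host:port of the prefill side
    "decode_url": "192.0.2.2:9000",       # placeholder host:port of the decode side
    "metadata_backend": "etcd",           # optional; initialize() accepts "etcd" or "redis"
    "metadata_server": "192.0.2.3:2379",  # placeholder metadata service address
    "protocol": "tcp",                    # "tcp" is the default in from_file()
    "device_name": "",                    # empty string is the default
}

with open("mooncake_config.json", "w") as fout:
    json.dump(example_config, fout, indent=2)
```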
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py ADDED
@@ -0,0 +1,277 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ This module implements a PyNccl pipe for sending and receiving
4
+ Optional[torch.Tensor] between distributed ranks with advanced
5
+ communication features.
6
+
7
+ Key Features:
8
+ - Supports sending and receiving tensors with metadata
9
+ - Handles both CUDA and CPU device communications
10
+ - Implements a non-blocking tensor transfer mechanism
11
+ - Manages buffer size and provides backpressure control
12
+ - Supports distributed process groups with configurable parameters
13
+ """
14
+
15
+ import threading
16
+ import time
17
+ from concurrent.futures import ThreadPoolExecutor
18
+ from typing import Callable, Dict, Optional, Tuple
19
+
20
+ import torch
21
+
22
+ from vllm.config import KVTransferConfig
23
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
24
+ from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
25
+ from vllm.distributed.utils import StatelessProcessGroup
26
+ from vllm.logger import init_logger
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
+ class BrokenPipeException(Exception):
32
+
33
+ def __init__(self, message):
34
+ self.message = message
35
+ super().__init__(self.message)
36
+
37
+
38
+ Metadata = Dict[str, Optional[torch.Tensor]]
39
+
40
+
41
+ class PyNcclPipe(KVPipeBase):
42
+
43
+ METADATA_LENGTH = 16
44
+ MAX_TENSOR_DIMENSIONS = 14
45
+ METADATA_DTYPE = torch.int64
46
+
47
+ def __init__(self,
48
+ local_rank: int,
49
+ config: KVTransferConfig,
50
+ device: Optional[str] = None,
51
+ port_offset: int = 0):
52
+ self.config = config
53
+ self.local_rank = local_rank
54
+ self.kv_rank = self.config.kv_rank
55
+ self.kv_parallel_size = self.config.kv_parallel_size
56
+ if device is None:
57
+ self.device = self._select_device(self.config.kv_buffer_device)
58
+ else:
59
+ self.device = self._select_device(device)
60
+
61
+ # build distributed connection and send/recv implementation
62
+ self.group = StatelessProcessGroup.create(
63
+ host=self.config.kv_ip,
64
+ port=self.config.kv_port + port_offset,
65
+ rank=self.kv_rank,
66
+ world_size=self.kv_parallel_size,
67
+ )
68
+ # add a barrier to make sure the connection is initiated properly
69
+ self.group.barrier()
70
+ impl = self._get_device_send_recv_impl(self.group)
71
+ self.device_send_func, self.device_recv_func = impl
72
+ # set target rank
73
+ self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
74
+ self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
75
+
76
+ # transportation-related variables
77
+ self.transport_thread: Optional[ThreadPoolExecutor] = None
78
+ self.buffer_size = 0
79
+ self.buffer_size_lock = threading.Lock()
80
+ self.buffer_size_thresh = self.config.kv_buffer_size
81
+
82
+ def _get_device_send_recv_impl(
83
+ self, group: StatelessProcessGroup
84
+ ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
85
+ [torch.Tensor, int], None]]:
86
+
87
+ send: Callable[[torch.Tensor, int], None]
88
+ recv: Callable[[torch.Tensor, int], None]
89
+ if self.device.type == "cuda":
90
+ # use PyNCCL for send / recv
91
+ comm = PyNcclCommunicator(group, device=self.local_rank)
92
+ comm.disabled = False
93
+ send, recv = comm.send, comm.recv # type: ignore
94
+ else:
95
+ # This send / recv implementation here is NOT intended to transfer
96
+ # KV caches (and should NOT be repurposed to transfer KV caches).
97
+ # Currently it is only used to transmit control-plane messages
98
+ # for PyNcclBuffer.
99
+ send = group.send_obj
100
+
101
+ def my_recv(x, src):
102
+ x[...] = group.recv_obj(src)
103
+
104
+ recv = my_recv
105
+
106
+ return send, recv
107
+
108
+ def _select_device(self, device: str):
109
+ logger.info("Selecting device: %s", device)
110
+ if device == "cuda":
111
+ return torch.device(f"cuda:{self.local_rank}")
112
+ else:
113
+ return torch.device("cpu")
114
+
115
+ def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
116
+ """
117
+ Create the metadata as a dictionary based on the input tensor.
118
+
119
+ Parameters:
120
+ - tensor: The input tensor or None if no tensor is provided.
121
+
122
+ Returns:
123
+ - metadata: A dictionary with the following keys:
124
+ - "dtype": The data type of the tensor or None.
125
+ - "shape": The shape of the tensor or None.
126
+ """
127
+ if tensor is None:
128
+ return {"dtype": None, "shape": None}
129
+ else:
130
+ return {"dtype": tensor.dtype, "shape": tensor.shape}
131
+
132
+ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
133
+ """
134
+ Create a buffer to receive the tensor based on the provided metadata.
135
+
136
+ Parameters:
137
+ - metadata: A dictionary with keys "dtype" and "shape", describing
138
+ the tensor's data type and shape.
139
+
140
+ Returns:
141
+ - buffer: A tensor of the specified type and shape, allocated on
142
+ self.device.
143
+ """
144
+ return torch.empty(metadata["shape"],
145
+ dtype=metadata["dtype"],
146
+ device=self.device)
147
+
148
+ def _send_metadata(self, metadata: Metadata):
149
+ """
150
+ Send the metadata dictionary to the target rank.
151
+
152
+ Parameters:
153
+ - metadata: A dictionary with keys "dtype" and "shape".
154
+ """
155
+ self.group.send_obj(metadata, self.target_rank_for_send)
156
+
157
+ def _recv_metadata(self) -> Metadata:
158
+ """
159
+ Receive the metadata dictionary from the target rank.
160
+
161
+ Returns:
162
+ - metadata: A dictionary with keys "dtype" and "shape" describing
163
+ the tensor.
164
+ """
165
+ return self.group.recv_obj(self.target_rank_for_recv)
166
+
167
+ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
168
+ """
169
+ The actual implementation of sending the tensor and its metadata to the
170
+ target rank.
171
+
172
+ Parameters:
173
+ - tensor: The input tensor to be sent, or None if no tensor is
174
+ being sent.
175
+ """
176
+ metadata = self._make_metadata(tensor)
177
+ self._send_metadata(metadata)
178
+ if tensor is not None:
179
+ self.device_send_func(tensor.to(self.device),
180
+ self.target_rank_for_send)
181
+
182
+ def _recv_impl(self) -> Optional[torch.Tensor]:
183
+ """
184
+ The actual implementation of receiving a tensor and its metadata from
185
+ the target rank.
186
+
187
+ Returns:
188
+ - buffer: The received tensor, or None if no tensor is received.
189
+ """
190
+ metadata = self._recv_metadata()
191
+ if metadata["dtype"] is None:
192
+ return None
193
+ buffer = self._prepare_recv_buffer(metadata)
194
+ self.device_recv_func(buffer, self.target_rank_for_recv)
195
+
196
+ return buffer
197
+
198
+ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
199
+ tensor_size: int) -> None:
200
+ """
201
+ Wrapper for _send_impl to handle exceptions and update buffer size.
202
+ """
203
+ try:
204
+ self._send_impl(tensor)
205
+
206
+ with self.buffer_size_lock:
207
+ self.buffer_size -= tensor_size
208
+ except Exception as e:
209
+ logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
210
+ torch.distributed.get_rank(), str(tensor), str(e))
211
+ import traceback
212
+ traceback.print_exc()
213
+
214
+ def block_if_full(self):
215
+ """
216
+ Block the current thread if the buffer size is larger than the
217
+ threshold.
218
+ """
219
+ while self.buffer_size > self.buffer_size_thresh:
220
+ logger.debug("KV cache transfer pipe is full. Waiting...")
221
+ time.sleep(0.05)
222
+
223
+ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
224
+ """
225
+ Sends a tensor and its metadata to the destination rank in a
226
+ non-blocking way.
227
+
228
+ Parameters:
229
+ - tensor: The tensor to send, or None if no tensor is being sent.
230
+ """
231
+ if self.transport_thread is None:
232
+ self.transport_thread = ThreadPoolExecutor(max_workers=1)
233
+
234
+ if tensor is not None:
235
+ tensor_size = tensor.element_size() * tensor.numel()
236
+ else:
237
+ tensor_size = 0
238
+
239
+ self.block_if_full()
240
+
241
+ with self.buffer_size_lock:
242
+ self.buffer_size += tensor_size
243
+
244
+ self.transport_thread.submit(self.send_tensor_wrapper, tensor,
245
+ tensor_size)
246
+
247
+ def recv_tensor(self) -> Optional[torch.Tensor]:
248
+ """
249
+ Receives a tensor and its metadata from the source rank. Blocking call.
250
+
251
+ Returns:
252
+ - tensor: The received tensor, or None if no tensor is received.
253
+ """
254
+ if self.transport_thread is None:
255
+ self.transport_thread = ThreadPoolExecutor(max_workers=1)
256
+
257
+ future = self.transport_thread.submit(self._recv_impl)
258
+
259
+ try:
260
+ tensor = future.result()
261
+ except Exception as e:
262
+ logger.error("Encountering exception in KV receiving thread")
263
+ logger.error("%s", e)
264
+ logger.error("My device: %s", self.device)
265
+ import traceback
266
+ traceback.print_exc()
267
+ raise e
268
+
269
+ return tensor
270
+
271
+ def close(self):
272
+ """
273
+ Close the pipe and release associated resources.
274
+ """
275
+ if hasattr(self,
276
+ "transport_thread") and self.transport_thread is not None:
277
+ self.transport_thread.shutdown()
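The metadata-then-payload handshake used by `_send_impl`/`_recv_impl` above is easy to mimic locally; a sketch of those two helper steps with no ranks involved (purely illustrative):

```
import torch


def make_metadata(tensor):
    # Mirrors PyNcclPipe._make_metadata: record dtype/shape, or None markers.
    if tensor is None:
        return {"dtype": None, "shape": None}
    return {"dtype": tensor.dtype, "shape": tensor.shape}


def prepare_recv_buffer(metadata, device="cpu"):
    # Mirrors PyNcclPipe._prepare_recv_buffer: allocate an empty landing buffer.
    return torch.empty(metadata["shape"], dtype=metadata["dtype"], device=device)


payload = torch.randn(2, 3)
meta = make_metadata(payload)        # sent first over the CPU object channel
buffer = prepare_recv_buffer(meta)   # receiver sizes its buffer from the metadata
buffer.copy_(payload)                # stands in for the NCCL / CPU recv into the buffer
print(buffer.shape, buffer.dtype)
```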
.venv/lib/python3.11/site-packages/vllm/distributed/kv_transfer/kv_transfer_agent.py ADDED
@@ -0,0 +1,76 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """A centralized entrypoint to perform distributed KV cache transfer.
3
+
4
+ This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
5
+ 1. `send_kv_caches_and_hidden_states`
6
+ 2. `recv_kv_caches_and_hidden_states`
7
+ """
8
+ from typing import TYPE_CHECKING, List, Tuple, Union
9
+
10
+ if TYPE_CHECKING:
11
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
12
+ from vllm.config import VllmConfig
13
+
14
+ import torch
15
+
16
+ from vllm.distributed.kv_transfer.kv_connector.factory import (
17
+ KVConnectorFactory)
18
+ from vllm.logger import init_logger
19
+ from vllm.sequence import IntermediateTensors
20
+
21
+ logger = init_logger(__name__)
22
+
23
+
24
+ class KVTransferAgent:
25
+ """
26
+ A class designated for distributed KV transfer
27
+
28
+ Target use cases:
29
+ 1. Disaggregated prefill
30
+ 2. Remote KV cache storage
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ rank: int,
36
+ local_rank: int,
37
+ config: "VllmConfig",
38
+ ):
39
+
40
+ self.config = config
41
+
42
+ if config.kv_transfer_config is None:
43
+ raise ValueError("KVTransferConfig is not set in the VllmConfig,"
44
+ " cannot initialize KVConnector.")
45
+
46
+ assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
47
+ "TransferAgent should only be used when kv_connector is set."
48
+
49
+ self.connector = KVConnectorFactory.create_connector(
50
+ rank, local_rank, config)
51
+
52
+ def send_kv_caches_and_hidden_states(
53
+ self,
54
+ model_executable: torch.nn.Module,
55
+ model_input: "ModelInputForGPUWithSamplingMetadata",
56
+ kv_caches: List[torch.Tensor],
57
+ hidden_or_intermediate_states: Union[torch.Tensor,
58
+ IntermediateTensors],
59
+ ) -> None:
60
+
61
+ self.connector.send_kv_caches_and_hidden_states(
62
+ model_executable, model_input, kv_caches,
63
+ hidden_or_intermediate_states)
64
+
65
+ def close(self) -> None:
66
+ self.connector.close()
67
+
68
+ def recv_kv_caches_and_hidden_states(
69
+ self, model_executable: torch.nn.Module,
70
+ model_input: "ModelInputForGPUWithSamplingMetadata",
71
+ kv_caches: List[torch.Tensor]
72
+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
73
+ "ModelInputForGPUWithSamplingMetadata"]:
74
+
75
+ return self.connector.recv_kv_caches_and_hidden_states(
76
+ model_executable, model_input, kv_caches)
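The agent above is a thin façade: it validates the config, builds a connector through `KVConnectorFactory`, and forwards the two calls to it. A minimal usage sketch follows; the surrounding worker wiring (ranks, model, inputs) is assumed for illustration and is not taken from this diff, and the boolean in the received tuple is read here as "the local forward pass can be skipped".

    from vllm.distributed.kv_transfer.kv_transfer_agent import KVTransferAgent


    def transfer_kv(rank, local_rank, vllm_config, model, model_input,
                    kv_caches, hidden_states, is_producer):
        # One agent per worker; it owns the underlying connector.
        agent = KVTransferAgent(rank=rank, local_rank=local_rank,
                                config=vllm_config)
        try:
            if is_producer:
                # Prefill side: push KV caches and hidden states downstream.
                agent.send_kv_caches_and_hidden_states(
                    model, model_input, kv_caches, hidden_states)
                return hidden_states, False, model_input
            # Decode side: pull the KV caches and hidden states back in.
            return agent.recv_kv_caches_and_hidden_states(
                model, model_input, kv_caches)
        finally:
            agent.close()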
.venv/lib/python3.11/site-packages/vllm/distributed/parallel_state.py ADDED
@@ -0,0 +1,1285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # Copyright 2023 The vLLM team.
4
+ # Adapted from
5
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
6
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
7
+ """vLLM distributed state.
8
+ It takes over control of the distributed environment from PyTorch.
9
+ The typical workflow is:
10
+
11
+ - call `init_distributed_environment` to initialize the distributed environment.
12
+ - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
13
+ initialize the model parallel groups.
14
+
15
+ - run any code that uses the distributed environment
16
+
17
+ - call `destroy_model_parallel` to destroy the model parallel groups.
18
+ - call `destroy_distributed_environment` to destroy the distributed environment.
19
+
20
+ If you only need to use the distributed environment without model/pipeline
21
+ parallelism, you can skip the model parallel initialization and destruction
22
+ steps.
23
+ """
24
+ import contextlib
25
+ import gc
26
+ import pickle
27
+ import weakref
28
+ from collections import namedtuple
29
+ from contextlib import contextmanager, nullcontext
30
+ from dataclasses import dataclass
31
+ from multiprocessing import shared_memory
32
+ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
33
+ Union)
34
+ from unittest.mock import patch
35
+
36
+ import torch
37
+ import torch.distributed
38
+ from torch.distributed import Backend, ProcessGroup
39
+
40
+ import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
41
+ import vllm.envs as envs
42
+ from vllm.distributed.utils import StatelessProcessGroup
43
+ from vllm.logger import init_logger
44
+ from vllm.utils import direct_register_custom_op, supports_custom_op
45
+
46
+ if TYPE_CHECKING:
47
+ from vllm.config import VllmConfig
48
+
49
+
50
+ @dataclass
51
+ class GraphCaptureContext:
52
+ stream: torch.cuda.Stream
53
+
54
+
55
+ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
56
+
57
+
58
+ def _split_tensor_dict(
59
+ tensor_dict: Dict[str, Union[torch.Tensor, Any]]
60
+ ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
61
+ """Split the tensor dictionary into two parts:
62
+ 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
63
+ by its metadata.
64
+ 2. A list of tensors.
65
+ """
66
+ metadata_list: List[Tuple[str, Any]] = []
67
+ tensor_list: List[torch.Tensor] = []
68
+ for key, value in tensor_dict.items():
69
+ if isinstance(value, torch.Tensor):
70
+ # Note: we cannot use `value.device` here,
71
+ # because it contains not only the device type but also the device
72
+ # index (e.g. "cuda:0"). We only need the device type.
73
+ # receiving side will set the device index.
74
+ device = value.device.type
75
+ metadata_list.append(
76
+ (key, TensorMetadata(device, value.dtype, value.size())))
77
+ tensor_list.append(value)
78
+ else:
79
+ metadata_list.append((key, value))
80
+ return metadata_list, tensor_list
81
+
82
+
83
+ _group_name_counter: Dict[str, int] = {}
84
+
85
+
86
+ def _get_unique_name(name: str) -> str:
87
+ """Get a unique name for the group.
88
+ Example:
89
+ _get_unique_name("tp") -> "tp:0"
90
+ _get_unique_name("tp") -> "tp:1"
91
+ """
92
+ if name not in _group_name_counter:
93
+ _group_name_counter[name] = 0
94
+ newname = f"{name}:{_group_name_counter[name]}"
95
+ _group_name_counter[name] += 1
96
+ return newname
97
+
98
+
99
+ _groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
100
+
101
+
102
+ def _register_group(group: "GroupCoordinator") -> None:
103
+ _groups[group.unique_name] = weakref.ref(group)
104
+
105
+
106
+ def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
107
+ assert group_name in _groups, f"Group {group_name} is not found."
108
+ group = _groups[group_name]()
109
+ if group is None:
110
+ raise ValueError(f"Group {group_name} is destroyed.")
111
+ return group._all_reduce_out_place(tensor)
112
+
113
+
114
+ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
115
+ return torch.empty_like(tensor)
116
+
117
+
118
+ if supports_custom_op():
119
+ direct_register_custom_op(
120
+ op_name="all_reduce",
121
+ op_func=all_reduce,
122
+ mutates_args=[],
123
+ fake_impl=all_reduce_fake,
124
+ )
125
+
126
+
127
+ class GroupCoordinator:
128
+ """
129
+ PyTorch ProcessGroup wrapper for a group of processes.
130
+ PyTorch ProcessGroup is bound to one specific communication backend,
131
+ e.g. NCCL, Gloo, MPI, etc.
132
+ GroupCoordinator takes charge of all the communication operations among
133
+ the processes in the group. It can route the communication to
134
+ a specific implementation (e.g. switch allreduce implementation
135
+ based on the tensor size and cuda graph mode).
136
+ """
137
+
138
+ # available attributes:
139
+ rank: int # global rank
140
+ ranks: List[int] # global ranks in the group
141
+ world_size: int # size of the group
142
+ # difference between `local_rank` and `rank_in_group`:
143
+ # if we have a group of size 4 across two nodes:
144
+ # Process | Node | Rank | Local Rank | Rank in Group
145
+ # 0 | 0 | 0 | 0 | 0
146
+ # 1 | 0 | 1 | 1 | 1
147
+ # 2 | 1 | 2 | 0 | 2
148
+ # 3 | 1 | 3 | 1 | 3
149
+ local_rank: int # local rank used to assign devices
150
+ rank_in_group: int # rank inside the group
151
+ cpu_group: ProcessGroup # group for CPU communication
152
+ device_group: ProcessGroup # group for device communication
153
+ use_pynccl: bool # a hint of whether to use PyNccl
154
+ use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
155
+ # communicators are only created for world size > 1
156
+ pynccl_comm: Optional[Any] # PyNccl communicator
157
+ ca_comm: Optional[Any] # Custom allreduce communicator
158
+ mq_broadcaster: Optional[Any] # shared memory broadcaster
159
+
160
+ def __init__(
161
+ self,
162
+ group_ranks: List[List[int]],
163
+ local_rank: int,
164
+ torch_distributed_backend: Union[str, Backend],
165
+ use_pynccl: bool,
166
+ use_custom_allreduce: bool,
167
+ use_tpu_communicator: bool,
168
+ use_hpu_communicator: bool,
169
+ use_xpu_communicator: bool,
170
+ use_message_queue_broadcaster: bool = False,
171
+ group_name: Optional[str] = None,
172
+ ):
173
+ group_name = group_name or "anonymous"
174
+ self.unique_name = _get_unique_name(group_name)
175
+ _register_group(self)
176
+
177
+ self.rank = torch.distributed.get_rank()
178
+ self.local_rank = local_rank
179
+ self.device_group = None
180
+ self.cpu_group = None
181
+
182
+ for ranks in group_ranks:
183
+ device_group = torch.distributed.new_group(
184
+ ranks, backend=torch_distributed_backend)
185
+ # a group with `gloo` backend, to allow direct coordination between
186
+ # processes through the CPU.
187
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo")
188
+ if self.rank in ranks:
189
+ self.ranks = ranks
190
+ self.world_size = len(ranks)
191
+ self.rank_in_group = ranks.index(self.rank)
192
+ self.device_group = device_group
193
+ self.cpu_group = cpu_group
194
+
195
+ assert self.cpu_group is not None
196
+ assert self.device_group is not None
197
+
198
+ from vllm.platforms import current_platform
199
+ if current_platform.is_cuda_alike():
200
+ self.device = torch.device(f"cuda:{local_rank}")
201
+ else:
202
+ self.device = torch.device("cpu")
203
+
204
+ self.use_pynccl = use_pynccl
205
+ self.use_custom_allreduce = use_custom_allreduce
206
+ self.use_tpu_communicator = use_tpu_communicator
207
+ self.use_hpu_communicator = use_hpu_communicator
208
+ self.use_xpu_communicator = use_xpu_communicator
209
+
210
+ # lazy import to avoid documentation build error
211
+ from vllm.distributed.device_communicators.custom_all_reduce import (
212
+ CustomAllreduce)
213
+ from vllm.distributed.device_communicators.pynccl import (
214
+ PyNcclCommunicator)
215
+
216
+ self.pynccl_comm: Optional[PyNcclCommunicator] = None
217
+ if use_pynccl and self.world_size > 1:
218
+ self.pynccl_comm = PyNcclCommunicator(
219
+ group=self.cpu_group,
220
+ device=self.device,
221
+ )
222
+
223
+ self.ca_comm: Optional[CustomAllreduce] = None
224
+ if use_custom_allreduce and self.world_size > 1:
225
+ # Initialize a custom fast all-reduce implementation.
226
+ self.ca_comm = CustomAllreduce(
227
+ group=self.cpu_group,
228
+ device=self.device,
229
+ )
230
+
231
+ from vllm.distributed.device_communicators.tpu_communicator import (
232
+ TpuCommunicator)
233
+ self.tpu_communicator: Optional[TpuCommunicator] = None
234
+ if use_tpu_communicator and self.world_size > 1:
235
+ self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
236
+
237
+ from vllm.distributed.device_communicators.hpu_communicator import (
238
+ HpuCommunicator)
239
+ self.hpu_communicator: Optional[HpuCommunicator] = None
240
+ if use_hpu_communicator and self.world_size > 1:
241
+ self.hpu_communicator = HpuCommunicator(group=self.device_group)
242
+
243
+ from vllm.distributed.device_communicators.xpu_communicator import (
244
+ XpuCommunicator)
245
+ self.xpu_communicator: Optional[XpuCommunicator] = None
246
+ if use_xpu_communicator and self.world_size > 1:
247
+ self.xpu_communicator = XpuCommunicator(group=self.device_group)
248
+
249
+ from vllm.distributed.device_communicators.shm_broadcast import (
250
+ MessageQueue)
251
+ self.mq_broadcaster: Optional[MessageQueue] = None
252
+ if use_message_queue_broadcaster and self.world_size > 1:
253
+ self.mq_broadcaster = MessageQueue.create_from_process_group(
254
+ self.cpu_group, 1 << 22, 6)
255
+
256
+ @property
257
+ def first_rank(self):
258
+ """Return the global rank of the first process in the group"""
259
+ return self.ranks[0]
260
+
261
+ @property
262
+ def last_rank(self):
263
+ """Return the global rank of the last process in the group"""
264
+ return self.ranks[-1]
265
+
266
+ @property
267
+ def is_first_rank(self):
268
+ """Return whether the caller is the first process in the group"""
269
+ return self.rank == self.first_rank
270
+
271
+ @property
272
+ def is_last_rank(self):
273
+ """Return whether the caller is the last process in the group"""
274
+ return self.rank == self.last_rank
275
+
276
+ @property
277
+ def next_rank(self):
278
+ """Return the global rank of the process that follows the caller"""
279
+ rank_in_group = self.rank_in_group
280
+ world_size = self.world_size
281
+ return self.ranks[(rank_in_group + 1) % world_size]
282
+
283
+ @property
284
+ def prev_rank(self):
285
+ """Return the global rank of the process that precedes the caller"""
286
+ rank_in_group = self.rank_in_group
287
+ world_size = self.world_size
288
+ return self.ranks[(rank_in_group - 1) % world_size]
289
+
290
+ @contextmanager
291
+ def graph_capture(
292
+ self, graph_capture_context: Optional[GraphCaptureContext] = None):
293
+ if graph_capture_context is None:
294
+ stream = torch.cuda.Stream()
295
+ graph_capture_context = GraphCaptureContext(stream)
296
+ else:
297
+ stream = graph_capture_context.stream
298
+
299
+ ca_comm = self.ca_comm
300
+ maybe_ca_context = nullcontext(
301
+ ) if ca_comm is None else ca_comm.capture()
302
+
303
+ # ensure all initialization operations complete before attempting to
304
+ # capture the graph on another stream
305
+ curr_stream = torch.cuda.current_stream()
306
+ if curr_stream != stream:
307
+ stream.wait_stream(curr_stream)
308
+
309
+ with torch.cuda.stream(stream), maybe_ca_context:
310
+ yield graph_capture_context
311
+
312
+ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
313
+ """
314
+ User-facing all-reduce function before we actually call the
315
+ all-reduce operation.
316
+
317
+ We need this because Dynamo does not support passing an arbitrary
318
+ object (`self` in this case) to a custom op. We need to pass the
319
+ group name as a string, and then look up the group coordinator from
320
+ the group name, dispatch the all-reduce operation to the group
321
+ coordinator.
322
+
323
+ In addition, PyTorch custom ops do not support mutation or returning
324
+ a new tensor in the same op. So we always make the all-reduce operation
325
+ out-of-place.
326
+ """
327
+ # Bypass the function if we are using only 1 GPU.
328
+ if self.world_size == 1:
329
+ return input_
330
+
331
+ if input_.is_cpu:
332
+ try:
333
+ import intel_extension_for_pytorch as ipex
334
+ ipex.distributed.all_reduce(input_, group=self.device_group)
335
+ return input_
336
+ except ImportError:
337
+ """
338
+ Intel IPEX not found. Falling back to PyTorch native
339
+ all_reduce for CPU
340
+ """
341
+ torch.distributed.all_reduce(input_, group=self.device_group)
342
+ return input_
343
+
344
+ if self.tpu_communicator is not None and \
345
+ not self.tpu_communicator.disabled:
346
+ # TPU handles Dynamo with its own logic.
347
+ return self.tpu_communicator.all_reduce(input_)
348
+
349
+ if self.hpu_communicator is not None and \
350
+ not self.hpu_communicator.disabled:
351
+ return self.hpu_communicator.all_reduce(input_)
352
+
353
+ if self.xpu_communicator is not None and \
354
+ not self.xpu_communicator.disabled:
355
+ return self.xpu_communicator.all_reduce(input_)
356
+
357
+ return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
358
+
359
+ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
360
+ # always try custom allreduce first,
361
+ # and then pynccl.
362
+ ca_comm = self.ca_comm
363
+ if ca_comm is not None and not ca_comm.disabled and \
364
+ ca_comm.should_custom_ar(input_):
365
+ out = ca_comm.custom_all_reduce(input_)
366
+ assert out is not None
367
+ return out
368
+ pynccl_comm = self.pynccl_comm
369
+ assert pynccl_comm is not None
370
+ out = pynccl_comm.all_reduce(input_)
371
+ if out is None:
372
+ # fall back to the default all-reduce using PyTorch.
373
+ # this usually happens during testing.
374
+ # when we run the model, allreduce only happens for the TP
375
+ # group, where we always have either custom allreduce or pynccl.
376
+ out = input_.clone()
377
+ torch.distributed.all_reduce(out, group=self.device_group)
378
+ return out
379
+
380
+ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
381
+ world_size = self.world_size
382
+ # Bypass the function if we are using only 1 GPU.
383
+ if world_size == 1:
384
+ return input_
385
+ assert -input_.dim() <= dim < input_.dim(), (
386
+ f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
387
+
388
+ # For TPUs, use TPU communicator.
389
+ tpu_comm = self.tpu_communicator
390
+ if tpu_comm is not None and not tpu_comm.disabled:
391
+ return tpu_comm.all_gather(input_, dim)
392
+
393
+ # For HPUs, use HPU communicator.
394
+ hpu_comm = self.hpu_communicator
395
+ if hpu_comm is not None and not hpu_comm.disabled:
396
+ return hpu_comm.all_gather(input_, dim)
397
+
398
+ if dim < 0:
399
+ # Convert negative dim to positive.
400
+ dim += input_.dim()
401
+ input_size = input_.size()
402
+ # NOTE: we have to use concat-style all-gather here,
403
+ # stack-style all-gather has compatibility issues with
404
+ # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
405
+ output_size = (input_size[0] * world_size, ) + input_size[1:]
406
+ # Allocate output tensor.
407
+ output_tensor = torch.empty(output_size,
408
+ dtype=input_.dtype,
409
+ device=input_.device)
410
+ # All-gather.
411
+ torch.distributed.all_gather_into_tensor(output_tensor,
412
+ input_,
413
+ group=self.device_group)
414
+ # Reshape
415
+ output_tensor = output_tensor.reshape((world_size, ) + input_size)
416
+ output_tensor = output_tensor.movedim(0, dim)
417
+ output_tensor = output_tensor.reshape(input_size[:dim] +
418
+ (world_size *
419
+ input_size[dim], ) +
420
+ input_size[dim + 1:])
421
+ return output_tensor
422
+
423
+ def gather(self,
424
+ input_: torch.Tensor,
425
+ dst: int = 0,
426
+ dim: int = -1) -> Optional[torch.Tensor]:
427
+ """
428
+ NOTE: We assume that the input tensor is on the same device across
429
+ all the ranks.
430
+ NOTE: `dst` is the local rank of the destination rank.
431
+ """
432
+ world_size = self.world_size
433
+ # Bypass the function if we are using only 1 GPU.
434
+ if world_size == 1:
435
+ return input_
436
+ assert -input_.dim() <= dim < input_.dim(), (
437
+ f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
438
+ if dim < 0:
439
+ # Convert negative dim to positive.
440
+ dim += input_.dim()
441
+ if self.xpu_communicator is not None and \
442
+ not self.xpu_communicator.disabled:
443
+ return self.xpu_communicator.gather(input_, self.rank_in_group,
444
+ dst, dim)
445
+ # Allocate output tensor.
446
+ if self.rank_in_group == dst:
447
+ gather_list = [torch.empty_like(input_) for _ in range(world_size)]
448
+ else:
449
+ gather_list = None
450
+ # Gather.
451
+ torch.distributed.gather(input_,
452
+ gather_list,
453
+ dst=self.ranks[dst],
454
+ group=self.device_group)
455
+ if self.rank_in_group == dst:
456
+ output_tensor = torch.cat(gather_list, dim=dim)
457
+ else:
458
+ output_tensor = None
459
+ return output_tensor
460
+
461
+ def broadcast(self, input_: torch.Tensor, src: int = 0):
462
+ """Broadcast the input tensor.
463
+ NOTE: `src` is the local rank of the source rank.
464
+ """
465
+ assert src < self.world_size, f"Invalid src rank ({src})"
466
+
467
+ # Bypass the function if we are using only 1 GPU.
468
+ if self.world_size == 1:
469
+ return input_
470
+ # Broadcast.
471
+ torch.distributed.broadcast(input_,
472
+ src=self.ranks[src],
473
+ group=self.device_group)
474
+ return input_
475
+
476
+ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
477
+ """Broadcast the input object.
478
+ NOTE: `src` is the local rank of the source rank.
479
+ """
480
+ assert src < self.world_size, f"Invalid src rank ({src})"
481
+
482
+ # Bypass the function if we are using only 1 GPU.
483
+ if self.world_size == 1:
484
+ return obj
485
+ if self.mq_broadcaster is not None:
486
+ assert src == 0, "Message queue broadcaster only supports src=0"
487
+ return self.mq_broadcaster.broadcast_object(obj)
488
+ if self.rank_in_group == src:
489
+ torch.distributed.broadcast_object_list([obj],
490
+ src=self.ranks[src],
491
+ group=self.cpu_group)
492
+ return obj
493
+ else:
494
+ recv = [None]
495
+ torch.distributed.broadcast_object_list(recv,
496
+ src=self.ranks[src],
497
+ group=self.cpu_group)
498
+ return recv[0]
499
+
500
+ def broadcast_object_list(self,
501
+ obj_list: List[Any],
502
+ src: int = 0,
503
+ group: Optional[ProcessGroup] = None):
504
+ """Broadcast the input object list.
505
+ NOTE: `src` is the local rank of the source rank.
506
+ """
507
+ assert src < self.world_size, f"Invalid src rank ({src})"
508
+
509
+ # Bypass the function if we are using only 1 GPU.
510
+ if self.world_size == 1:
511
+ return obj_list
512
+ # Broadcast.
513
+ torch.distributed.broadcast_object_list(obj_list,
514
+ src=self.ranks[src],
515
+ group=self.device_group)
516
+ return obj_list
517
+
518
+ def send_object(self, obj: Any, dst: int) -> None:
519
+ """Send the input object to the destination rank.
520
+ NOTE: `dst` is the local rank of the destination rank."""
521
+
522
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
523
+
524
+ assert dst != self.rank_in_group, (
525
+ "Invalid destination rank. Destination rank is the same "
526
+ "as the current rank.")
527
+
528
+ # Serialize object to tensor and get the size as well
529
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
530
+
531
+ size_tensor = torch.tensor([object_tensor.numel()],
532
+ dtype=torch.long,
533
+ device="cpu")
534
+
535
+ # Send object size
536
+
537
+ torch.distributed.send(size_tensor,
538
+ dst=self.ranks[dst],
539
+ group=self.cpu_group)
540
+
541
+ # Send object
542
+ torch.distributed.send(object_tensor,
543
+ dst=self.ranks[dst],
544
+ group=self.cpu_group)
545
+
546
+ return None
547
+
548
+ def recv_object(self, src: int) -> Any:
549
+ """Receive an object from the source rank.
550
+ NOTE: `src` is the local rank of the source rank."""
551
+
552
+ assert src < self.world_size, f"Invalid src rank ({src})"
553
+
554
+ assert src != self.rank_in_group, (
555
+ "Invalid source rank. Source rank is the same as the current rank."
556
+ )
557
+
558
+ size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
559
+
560
+ # Receive object size
561
+ rank_size = torch.distributed.recv(size_tensor,
562
+ src=self.ranks[src],
563
+ group=self.cpu_group)
564
+
565
+ # Tensor to receive serialized objects into.
566
+ object_tensor = torch.empty( # type: ignore[call-overload]
567
+ size_tensor.item(), # type: ignore[arg-type]
568
+ dtype=torch.uint8,
569
+ device="cpu")
570
+
571
+ rank_object = torch.distributed.recv(object_tensor,
572
+ src=self.ranks[src],
573
+ group=self.cpu_group)
574
+
575
+ assert rank_object == rank_size, (
576
+ "Received object sender rank does not match the size sender rank.")
577
+
578
+ obj = pickle.loads(object_tensor.numpy().tobytes())
579
+
580
+ return obj
581
+
582
+ def broadcast_tensor_dict(
583
+ self,
584
+ tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
585
+ src: int = 0,
586
+ group: Optional[ProcessGroup] = None,
587
+ metadata_group: Optional[ProcessGroup] = None
588
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
589
+ """Broadcast the input tensor dictionary.
590
+ NOTE: `src` is the local rank of the source rank.
591
+ """
592
+ # Bypass the function if we are using only 1 GPU.
593
+ if (not torch.distributed.is_initialized() or self.world_size == 1):
594
+ return tensor_dict
595
+
596
+ group = self.device_group
597
+ metadata_group = self.cpu_group
598
+ assert src < self.world_size, f"Invalid src rank ({src})"
599
+
600
+ rank_in_group = self.rank_in_group
601
+ if rank_in_group == src:
602
+ metadata_list: List[Tuple[Any, Any]] = []
603
+ assert isinstance(
604
+ tensor_dict,
605
+ dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
606
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
607
+ # `metadata_list` lives in CPU memory.
608
+ # `broadcast_object_list` has serialization & deserialization,
609
+ # all happening on CPU. Therefore, we can use the CPU group.
610
+ self.broadcast_object(metadata_list, src=src)
611
+ async_handles = []
612
+ for tensor in tensor_list:
613
+ if tensor.numel() == 0:
614
+ # Skip broadcasting empty tensors.
615
+ continue
616
+ if tensor.is_cpu:
617
+ # use metadata_group for CPU tensors
618
+ handle = torch.distributed.broadcast(tensor,
619
+ src=self.ranks[src],
620
+ group=metadata_group,
621
+ async_op=True)
622
+ else:
623
+ # use group for GPU tensors
624
+ handle = torch.distributed.broadcast(tensor,
625
+ src=self.ranks[src],
626
+ group=group,
627
+ async_op=True)
628
+ async_handles.append(handle)
629
+ for async_handle in async_handles:
630
+ async_handle.wait()
631
+
632
+ else:
633
+ metadata_list = self.broadcast_object(None, src=src)
634
+ tensor_dict = {}
635
+ async_handles = []
636
+ for key, value in metadata_list:
637
+ if isinstance(value, TensorMetadata):
638
+ tensor = torch.empty(value.size,
639
+ dtype=value.dtype,
640
+ device=value.device)
641
+ if tensor.numel() == 0:
642
+ # Skip broadcasting empty tensors.
643
+ tensor_dict[key] = tensor
644
+ continue
645
+ if tensor.is_cpu:
646
+ # use metadata_group for CPU tensors
647
+ handle = torch.distributed.broadcast(
648
+ tensor,
649
+ src=self.ranks[src],
650
+ group=metadata_group,
651
+ async_op=True)
652
+ else:
653
+ # use group for GPU tensors
654
+ handle = torch.distributed.broadcast(
655
+ tensor,
656
+ src=self.ranks[src],
657
+ group=group,
658
+ async_op=True)
659
+ async_handles.append(handle)
660
+ tensor_dict[key] = tensor
661
+ else:
662
+ tensor_dict[key] = value
663
+ for async_handle in async_handles:
664
+ async_handle.wait()
665
+ return tensor_dict
666
+
667
+ def send_tensor_dict(
668
+ self,
669
+ tensor_dict: Dict[str, Union[torch.Tensor, Any]],
670
+ dst: Optional[int] = None,
671
+ all_gather_group: Optional["GroupCoordinator"] = None,
672
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
673
+ """Send the input tensor dictionary.
674
+ NOTE: `dst` is the local rank of the destination rank.
675
+ """
676
+ # Bypass the function if we are using only 1 GPU.
677
+ if not torch.distributed.is_initialized() or self.world_size == 1:
678
+ return tensor_dict
679
+
680
+ all_gather_size = (1 if all_gather_group is None else
681
+ all_gather_group.world_size)
682
+ all_gather_rank = (0 if all_gather_group is None else
683
+ all_gather_group.rank_in_group)
684
+
685
+ group = self.device_group
686
+ metadata_group = self.cpu_group
687
+
688
+ if dst is None:
689
+ dst = (self.rank_in_group + 1) % self.world_size
690
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
691
+
692
+ metadata_list: List[Tuple[Any, Any]] = []
693
+ assert isinstance(
694
+ tensor_dict,
695
+ dict), f"Expecting a dictionary, got {type(tensor_dict)}"
696
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
697
+ # `metadata_list` lives in CPU memory.
698
+ # `send_object_list` has serialization & deserialization,
699
+ # all happening on CPU. Therefore, we can use the CPU group.
700
+ self.send_object(metadata_list, dst=dst)
701
+ for tensor in tensor_list:
702
+ if tensor.numel() == 0:
703
+ # Skip sending empty tensors.
704
+ continue
705
+
706
+ # send-allgather: send only a slice, then do allgather.
707
+ if (all_gather_group is not None
708
+ and tensor.numel() % all_gather_size == 0):
709
+ tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
710
+
711
+ if tensor.is_cpu:
712
+ # use metadata_group for CPU tensors
713
+ torch.distributed.send(tensor,
714
+ dst=self.ranks[dst],
715
+ group=metadata_group)
716
+ else:
717
+ # use group for GPU tensors
718
+ torch.distributed.send(tensor,
719
+ dst=self.ranks[dst],
720
+ group=group)
721
+ return None
722
+
723
+ def recv_tensor_dict(
724
+ self,
725
+ src: Optional[int] = None,
726
+ all_gather_group: Optional["GroupCoordinator"] = None,
727
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
728
+ """Recv the input tensor dictionary.
729
+ NOTE: `src` is the local rank of the source rank.
730
+ """
731
+ # Bypass the function if we are using only 1 GPU.
732
+ if not torch.distributed.is_initialized() or self.world_size == 1:
733
+ return None
734
+
735
+ all_gather_size = (1 if all_gather_group is None else
736
+ all_gather_group.world_size)
737
+ all_gather_rank = (0 if all_gather_group is None else
738
+ all_gather_group.rank_in_group)
739
+
740
+ group = self.device_group
741
+ metadata_group = self.cpu_group
742
+
743
+ if src is None:
744
+ src = (self.rank_in_group - 1) % self.world_size
745
+ assert src < self.world_size, f"Invalid src rank ({src})"
746
+
747
+ recv_metadata_list = self.recv_object(src=src)
748
+ tensor_dict: Dict[str, Any] = {}
749
+ for key, value in recv_metadata_list:
750
+ if isinstance(value, TensorMetadata):
751
+ tensor = torch.empty(value.size,
752
+ dtype=value.dtype,
753
+ device=value.device)
754
+ if tensor.numel() == 0:
755
+ # Skip receiving empty tensors.
756
+ tensor_dict[key] = tensor
757
+ continue
758
+
759
+ # send-allgather: send only a slice, then do allgather.
760
+ use_all_gather = (all_gather_group is not None
761
+ and tensor.numel() % all_gather_size == 0)
762
+
763
+ if use_all_gather:
764
+ orig_shape = tensor.shape
765
+ tensor = tensor.reshape(all_gather_size,
766
+ -1)[all_gather_rank]
767
+
768
+ if tensor.is_cpu:
769
+ # use metadata_group for CPU tensors
770
+ torch.distributed.recv(tensor,
771
+ src=self.ranks[src],
772
+ group=metadata_group)
773
+ else:
774
+ # use group for GPU tensors
775
+ torch.distributed.recv(tensor,
776
+ src=self.ranks[src],
777
+ group=group)
778
+ if use_all_gather:
779
+ # do the allgather
780
+ tensor = all_gather_group.all_gather( # type: ignore
781
+ tensor, dim=0)
782
+ tensor = tensor.reshape(orig_shape)
783
+
784
+ tensor_dict[key] = tensor
785
+ else:
786
+ tensor_dict[key] = value
787
+ return tensor_dict
788
+
789
+ def barrier(self):
790
+ """Barrier synchronization among the group.
791
+ NOTE: don't use `device_group` here! `barrier` in NCCL is
792
+ terrible because it is internally a broadcast operation with
793
+ secretly created GPU tensors. It is easy to mess up the current
794
+ device. Use the CPU group instead.
795
+ """
796
+ torch.distributed.barrier(group=self.cpu_group)
797
+
798
+ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
799
+ """Send a tensor to the destination rank in a non-blocking way.
800
+ NOTE: `dst` is the local rank of the destination rank."""
801
+ if dst is None:
802
+ dst = (self.rank_in_group + 1) % self.world_size
803
+
804
+ pynccl_comm = self.pynccl_comm
805
+ if pynccl_comm is not None and not pynccl_comm.disabled:
806
+ pynccl_comm.send(tensor, dst)
807
+ else:
808
+ torch.distributed.send(tensor, self.ranks[dst], self.device_group)
809
+
810
+ def recv(self,
811
+ size: torch.Size,
812
+ dtype: torch.dtype,
813
+ src: Optional[int] = None) -> torch.Tensor:
814
+ """Receive a tensor from the source rank.
815
+ NOTE: `src` is the local rank of the source rank."""
816
+ if src is None:
817
+ src = (self.rank_in_group - 1) % self.world_size
818
+
819
+ tensor = torch.empty(size, dtype=dtype, device=self.device)
820
+ pynccl_comm = self.pynccl_comm
821
+ if pynccl_comm is not None and not pynccl_comm.disabled:
822
+ pynccl_comm.recv(tensor, src)
823
+ else:
824
+ torch.distributed.recv(tensor, self.ranks[src], self.device_group)
825
+ return tensor
826
+
827
+ def destroy(self):
828
+ if self.device_group is not None:
829
+ torch.distributed.destroy_process_group(self.device_group)
830
+ self.device_group = None
831
+ if self.cpu_group is not None:
832
+ torch.distributed.destroy_process_group(self.cpu_group)
833
+ self.cpu_group = None
834
+ if self.pynccl_comm is not None:
835
+ self.pynccl_comm = None
836
+ if self.ca_comm is not None:
837
+ self.ca_comm = None
838
+ if self.mq_broadcaster is not None:
839
+ self.mq_broadcaster = None
840
+
841
+
842
+ _WORLD: Optional[GroupCoordinator] = None
843
+
844
+
845
+ def get_world_group() -> GroupCoordinator:
846
+ assert _WORLD is not None, ("world group is not initialized")
847
+ return _WORLD
848
+
849
+
850
+ def init_world_group(ranks: List[int], local_rank: int,
851
+ backend: str) -> GroupCoordinator:
852
+ return GroupCoordinator(
853
+ group_ranks=[ranks],
854
+ local_rank=local_rank,
855
+ torch_distributed_backend=backend,
856
+ use_pynccl=False,
857
+ use_custom_allreduce=False,
858
+ use_tpu_communicator=False,
859
+ use_hpu_communicator=False,
860
+ use_xpu_communicator=False,
861
+ group_name="world",
862
+ )
863
+
864
+
865
+ def init_model_parallel_group(
866
+ group_ranks: List[List[int]],
867
+ local_rank: int,
868
+ backend: str,
869
+ use_custom_allreduce: Optional[bool] = None,
870
+ use_message_queue_broadcaster: bool = False,
871
+ group_name: Optional[str] = None,
872
+ ) -> GroupCoordinator:
873
+ if use_custom_allreduce is None:
874
+ use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
875
+ from vllm.platforms import current_platform
876
+ return GroupCoordinator(
877
+ group_ranks=group_ranks,
878
+ local_rank=local_rank,
879
+ torch_distributed_backend=backend,
880
+ use_pynccl=current_platform.is_cuda_alike(),
881
+ use_custom_allreduce=current_platform.is_cuda_alike()
882
+ and use_custom_allreduce,
883
+ use_tpu_communicator=True,
884
+ use_hpu_communicator=True,
885
+ use_xpu_communicator=True,
886
+ use_message_queue_broadcaster=use_message_queue_broadcaster,
887
+ group_name=group_name,
888
+ )
889
+
890
+
891
+ _TP: Optional[GroupCoordinator] = None
892
+
893
+
894
+ def get_tp_group() -> GroupCoordinator:
895
+ assert _TP is not None, ("tensor model parallel group is not initialized")
896
+ return _TP
897
+
898
+
899
+ # kept for backward compatibility
900
+ get_tensor_model_parallel_group = get_tp_group
901
+
902
+ _PP: Optional[GroupCoordinator] = None
903
+
904
+
905
+ def get_pp_group() -> GroupCoordinator:
906
+ assert _PP is not None, (
907
+ "pipeline model parallel group is not initialized")
908
+ return _PP
909
+
910
+
911
+ # kept for backward compatibility
912
+ get_pipeline_model_parallel_group = get_pp_group
913
+
914
+ _KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
915
+
916
+
917
+ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
918
+ assert _KV_TRANSFER is not None, (
919
+ "disaggregated KV cache transfer parallel group is not initialized")
920
+ return _KV_TRANSFER
921
+
922
+
923
+ @contextmanager
924
+ def graph_capture(device: torch.device):
925
+ """
926
+ `graph_capture` is a context manager which should surround the code that
927
+ is capturing the CUDA graph. Its main purpose is to ensure that
928
+ some operations will be run after the graph is captured, before the graph
929
+ is replayed. It returns a `GraphCaptureContext` object which contains the
930
+ necessary data for the graph capture. Currently, it only contains the
931
+ stream that the graph capture is running on. This stream is set to the
932
+ current CUDA stream when the context manager is entered and reset to the
933
+ default stream when the context manager is exited. This is to ensure that
934
+ the graph capture is running on a separate stream from the default stream,
935
+ in order to explicitly distinguish the kernels to capture
936
+ from other kernels possibly launched in the background in the default stream.
937
+ """
938
+ context = GraphCaptureContext(torch.cuda.Stream(device=device))
939
+ with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
940
+ context):
941
+ yield context
942
+
943
+
944
+ logger = init_logger(__name__)
945
+
946
+ _ENABLE_CUSTOM_ALL_REDUCE = True
947
+
948
+
949
+ def set_custom_all_reduce(enable: bool):
950
+ global _ENABLE_CUSTOM_ALL_REDUCE
951
+ _ENABLE_CUSTOM_ALL_REDUCE = enable
952
+
953
+
954
+ def init_distributed_environment(
955
+ world_size: int = -1,
956
+ rank: int = -1,
957
+ distributed_init_method: str = "env://",
958
+ local_rank: int = -1,
959
+ backend: str = "nccl",
960
+ ):
961
+ logger.debug(
962
+ "world_size=%d rank=%d local_rank=%d "
963
+ "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
964
+ distributed_init_method, backend)
965
+ if not torch.distributed.is_initialized():
966
+ assert distributed_init_method is not None, (
967
+ "distributed_init_method must be provided when initializing "
968
+ "distributed environment")
969
+ # this backend is used for WORLD
970
+ torch.distributed.init_process_group(
971
+ backend=backend,
972
+ init_method=distributed_init_method,
973
+ world_size=world_size,
974
+ rank=rank)
975
+ # set the local rank
976
+ # local_rank is not available in torch ProcessGroup,
977
+ # see https://github.com/pytorch/pytorch/issues/122816
978
+ if local_rank == -1:
979
+ # local rank not set, this usually happens in single-node
980
+ # setting, where we can use rank as local rank
981
+ if distributed_init_method == "env://":
982
+ local_rank = envs.LOCAL_RANK
983
+ else:
984
+ local_rank = rank
985
+ global _WORLD
986
+ if _WORLD is None:
987
+ ranks = list(range(torch.distributed.get_world_size()))
988
+ _WORLD = init_world_group(ranks, local_rank, backend)
989
+ else:
990
+ assert _WORLD.world_size == torch.distributed.get_world_size(), (
991
+ "world group already initialized with a different world size")
992
+
993
+
994
+ def initialize_model_parallel(
995
+ tensor_model_parallel_size: int = 1,
996
+ pipeline_model_parallel_size: int = 1,
997
+ backend: Optional[str] = None,
998
+ ) -> None:
999
+ """
1000
+ Initialize model parallel groups.
1001
+
1002
+ Arguments:
1003
+ tensor_model_parallel_size: number of GPUs used for tensor model
1004
+ parallelism.
1005
+ pipeline_model_parallel_size: number of GPUs used for pipeline model
1006
+ parallelism.
1007
+
1008
+ Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
1009
+ use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
1010
+ the model pipeline. The present function will
1011
+ create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
1012
+ 4 tensor model-parallel groups:
1013
+ [g0, g1], [g2, g3], [g4, g5], [g6, g7]
1014
+ 2 pipeline model-parallel groups:
1015
+ [g0, g2, g4, g6], [g1, g3, g5, g7]
1016
+ Note that for efficiency, the caller should make sure adjacent ranks
1017
+ are on the same DGX box. For example if we are using 2 DGX-1 boxes
1018
+ with a total of 16 GPUs, rank 0 to 7 belong to the first box and
1019
+ ranks 8 to 15 belong to the second box.
1020
+ """
1021
+ # Get world size and rank. Ensure some consistencies.
1022
+ assert torch.distributed.is_initialized()
1023
+ world_size: int = torch.distributed.get_world_size()
1024
+ backend = backend or torch.distributed.get_backend(
1025
+ get_world_group().device_group)
1026
+
1027
+ if (world_size
1028
+ != tensor_model_parallel_size * pipeline_model_parallel_size):
1029
+ raise RuntimeError(
1030
+ f"world_size ({world_size}) is not equal to "
1031
+ f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
1032
+ f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
1033
+
1034
+ # Build the tensor model-parallel groups.
1035
+ num_tensor_model_parallel_groups: int = (world_size //
1036
+ tensor_model_parallel_size)
1037
+ global _TP
1038
+ assert _TP is None, ("tensor model parallel group is already initialized")
1039
+ group_ranks = []
1040
+ for i in range(num_tensor_model_parallel_groups):
1041
+ ranks = list(
1042
+ range(i * tensor_model_parallel_size,
1043
+ (i + 1) * tensor_model_parallel_size))
1044
+ group_ranks.append(ranks)
1045
+
1046
+ # message queue broadcaster is only used in tensor model parallel group
1047
+ _TP = init_model_parallel_group(group_ranks,
1048
+ get_world_group().local_rank,
1049
+ backend,
1050
+ use_message_queue_broadcaster=True,
1051
+ group_name="tp")
1052
+
1053
+ # Build the pipeline model-parallel groups.
1054
+ num_pipeline_model_parallel_groups: int = (world_size //
1055
+ pipeline_model_parallel_size)
1056
+ global _PP
1057
+ assert _PP is None, (
1058
+ "pipeline model parallel group is already initialized")
1059
+ group_ranks = []
1060
+ for i in range(num_pipeline_model_parallel_groups):
1061
+ ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
1062
+ group_ranks.append(ranks)
1063
+ # pipeline parallel does not need custom allreduce
1064
+ _PP = init_model_parallel_group(group_ranks,
1065
+ get_world_group().local_rank,
1066
+ backend,
1067
+ use_custom_allreduce=False,
1068
+ group_name="pp")
1069
+
1070
+
1071
+ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
1072
+ """
1073
+ Initialize KV cache transfer parallel group.
1074
+ """
1075
+
1076
+ global _KV_TRANSFER
1077
+
1078
+ if vllm_config.kv_transfer_config is None:
1079
+ return
1080
+
1081
+ if all([
1082
+ vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
1083
+ is None
1084
+ ]):
1085
+ _KV_TRANSFER = kv_transfer.KVTransferAgent(
1086
+ rank=get_world_group().rank,
1087
+ local_rank=get_world_group().local_rank,
1088
+ config=vllm_config)
1089
+
1090
+
1091
+ def ensure_model_parallel_initialized(
1092
+ tensor_model_parallel_size: int,
1093
+ pipeline_model_parallel_size: int,
1094
+ backend: Optional[str] = None,
1095
+ ) -> None:
1096
+ """Helper to initialize model parallel groups if they are not initialized,
1097
+ or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
1098
+ values if the model parallel groups are initialized.
1099
+ """
1100
+ backend = backend or torch.distributed.get_backend(
1101
+ get_world_group().device_group)
1102
+ if not model_parallel_is_initialized():
1103
+ initialize_model_parallel(tensor_model_parallel_size,
1104
+ pipeline_model_parallel_size, backend)
1105
+ return
1106
+
1107
+ assert (
1108
+ get_tensor_model_parallel_world_size() == tensor_model_parallel_size
1109
+ ), ("tensor parallel group already initialized, but of unexpected size: "
1110
+ f"{get_tensor_model_parallel_world_size()=} vs. "
1111
+ f"{tensor_model_parallel_size=}")
1112
+ pp_world_size = get_pp_group().world_size
1113
+ assert (pp_world_size == pipeline_model_parallel_size), (
1114
+ "pipeline parallel group already initialized, but of unexpected size: "
1115
+ f"{pp_world_size=} vs. "
1116
+ f"{pipeline_model_parallel_size=}")
1117
+
1118
+
1119
+ def model_parallel_is_initialized():
1120
+ """Check if tensor and pipeline parallel groups are initialized."""
1121
+ return (_TP is not None and _PP is not None)
1122
+
1123
+
1124
+ _TP_STATE_PATCHED = False
1125
+
1126
+
1127
+ @contextmanager
1128
+ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
1129
+ """Patch the tp group temporarily until this function ends.
1130
+
1131
+ This method is for draft workers of speculative decoding to run draft model
1132
+ with different tp degree from that of target model workers.
1133
+
1134
+ Args:
1135
+ tp_group (GroupCoordinator): the tp group coordinator
1136
+ """
1137
+ global _TP_STATE_PATCHED
1138
+ assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
1139
+
1140
+ _TP_STATE_PATCHED = True
1141
+ old_tp_group = get_tp_group()
1142
+ global _TP
1143
+ _TP = tp_group
1144
+ try:
1145
+ yield
1146
+ finally:
1147
+ # restore the original state
1148
+ _TP_STATE_PATCHED = False
1149
+ _TP = old_tp_group
1150
+
1151
+
1152
+ def get_tensor_model_parallel_world_size():
1153
+ """Return world size for the tensor model parallel group."""
1154
+ return get_tp_group().world_size
1155
+
1156
+
1157
+ def get_tensor_model_parallel_rank():
1158
+ """Return my rank for the tensor model parallel group."""
1159
+ return get_tp_group().rank_in_group
1160
+
1161
+
1162
+ def destroy_model_parallel():
1163
+ """Set the groups to none and destroy them."""
1164
+ global _TP
1165
+ if _TP:
1166
+ _TP.destroy()
1167
+ _TP = None
1168
+
1169
+ global _PP
1170
+ if _PP:
1171
+ _PP.destroy()
1172
+ _PP = None
1173
+
1174
+
1175
+ def destroy_distributed_environment():
1176
+ global _WORLD
1177
+ if _WORLD:
1178
+ _WORLD.destroy()
1179
+ _WORLD = None
1180
+ if torch.distributed.is_initialized():
1181
+ torch.distributed.destroy_process_group()
1182
+
1183
+
1184
+ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
1185
+ destroy_model_parallel()
1186
+ destroy_distributed_environment()
1187
+ with contextlib.suppress(AssertionError):
1188
+ torch.distributed.destroy_process_group()
1189
+ if shutdown_ray:
1190
+ import ray # Lazy import Ray
1191
+ ray.shutdown()
1192
+ gc.collect()
1193
+ from vllm.platforms import current_platform
1194
+ if not current_platform.is_cpu():
1195
+ torch.cuda.empty_cache()
1196
+ try:
1197
+ torch._C._host_emptyCache()
1198
+ except AttributeError:
1199
+ logger.warning(
1200
+ "torch._C._host_emptyCache() only available in PyTorch >=2.5")
1201
+
1202
+
1203
+ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
1204
+ source_rank: int = 0) -> List[bool]:
1205
+ """
1206
+ This is a collective operation that returns whether each rank is in the same node
1207
+ as the source rank. It tests if processes are attached to the same
1208
+ memory system (shared access to shared memory).
1209
+ """
1210
+ if isinstance(pg, ProcessGroup):
1211
+ assert torch.distributed.get_backend(
1212
+ pg) != torch.distributed.Backend.NCCL, (
1213
+ "in_the_same_node_as should be tested with a non-NCCL group.")
1214
+ # local rank inside the group
1215
+ rank = torch.distributed.get_rank(group=pg)
1216
+ world_size = torch.distributed.get_world_size(group=pg)
1217
+
1218
+ # global ranks of the processes in the group
1219
+ ranks = torch.distributed.get_process_group_ranks(pg)
1220
+ else:
1221
+ rank = pg.rank
1222
+ world_size = pg.world_size
1223
+ ranks = list(range(world_size))
1224
+
1225
+ # local tensor in each process to store the result
1226
+ is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
1227
+
1228
+ magic_message = b"magic_message"
1229
+ shm = None
1230
+
1231
+ try:
1232
+ with contextlib.suppress(OSError):
1233
+ if rank == source_rank:
1234
+ # create a shared memory segment
1235
+ shm = shared_memory.SharedMemory(create=True, size=128)
1236
+ shm.buf[:len(magic_message)] = magic_message
1237
+ if isinstance(pg, ProcessGroup):
1238
+ torch.distributed.broadcast_object_list(
1239
+ [shm.name], src=ranks[source_rank], group=pg)
1240
+ else:
1241
+ pg.broadcast_obj(shm.name, src=source_rank)
1242
+ is_in_the_same_node[rank] = 1
1243
+ else:
1244
+ # try to open the shared memory segment
1245
+ if isinstance(pg, ProcessGroup):
1246
+ recv = [None]
1247
+ torch.distributed.broadcast_object_list(
1248
+ recv, src=ranks[source_rank], group=pg)
1249
+ name = recv[0]
1250
+ else:
1251
+ name = pg.broadcast_obj(None, src=source_rank)
1252
+ # fix to https://stackoverflow.com/q/62748654/9191338
1253
+ # Python incorrectly tracks shared memory even if it is not
1254
+ # created by the process. The following patch is a workaround.
1255
+ with patch("multiprocessing.resource_tracker.register",
1256
+ lambda *args, **kwargs: None):
1257
+ shm = shared_memory.SharedMemory(name=name)
1258
+ if shm.buf[:len(magic_message)] == magic_message:
1259
+ is_in_the_same_node[rank] = 1
1260
+ except Exception as e:
1261
+ logger.error("Error ignored in is_in_the_same_node: %s", e)
1262
+ finally:
1263
+ if shm:
1264
+ shm.close()
1265
+
1266
+ if isinstance(pg, ProcessGroup):
1267
+ torch.distributed.barrier(group=pg)
1268
+ else:
1269
+ pg.barrier()
1270
+
1271
+ # clean up the shared memory segment
1272
+ with contextlib.suppress(OSError):
1273
+ if rank == source_rank and shm:
1274
+ shm.unlink()
1275
+
1276
+ if isinstance(pg, ProcessGroup):
1277
+ torch.distributed.all_reduce(is_in_the_same_node, group=pg)
1278
+ aggregated_data = is_in_the_same_node
1279
+ else:
1280
+ aggregated_data = torch.zeros_like(is_in_the_same_node)
1281
+ for i in range(world_size):
1282
+ rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
1283
+ aggregated_data += rank_data
1284
+
1285
+ return [x == 1 for x in aggregated_data.tolist()]
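The lifecycle described in the module docstring of parallel_state.py maps onto a short driver. A minimal sketch follows, assuming a single-process run; the init method, port, backend, and parallel sizes are illustrative values, not taken from the diff.

    import vllm.distributed.parallel_state as ps


    def main() -> None:
        # 1. Bring up the global (world) process group.
        ps.init_distributed_environment(
            world_size=1, rank=0,
            distributed_init_method="tcp://127.0.0.1:29500",
            local_rank=0, backend="gloo")
        # 2. Carve it into tensor- and pipeline-parallel groups.
        ps.ensure_model_parallel_initialized(
            tensor_model_parallel_size=1, pipeline_model_parallel_size=1)

        # ... code that uses get_tp_group() / get_pp_group() goes here ...
        assert ps.get_tensor_model_parallel_world_size() == 1

        # 3. Tear down in reverse order.
        ps.destroy_model_parallel()
        ps.destroy_distributed_environment()


    if __name__ == "__main__":
        main()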