Add files using upload-large-folder tool
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__init__.py +13 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/_staggered.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/impl.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/_staggered.py +202 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/impl.py +210 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/py.typed +0 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/types.py +12 -0
- .venv/lib/python3.11/site-packages/aiohappyeyeballs/utils.py +97 -0
- .venv/lib/python3.11/site-packages/attrs/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/attrs/__pycache__/validators.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/__init__.py +62 -0
- .venv/lib/python3.11/site-packages/vllm/_custom_ops.py +1098 -0
- .venv/lib/python3.11/site-packages/vllm/_ipex_ops.py +228 -0
- .venv/lib/python3.11/site-packages/vllm/_version.py +16 -0
- .venv/lib/python3.11/site-packages/vllm/beam_search.py +73 -0
- .venv/lib/python3.11/site-packages/vllm/config.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/connections.py +169 -0
- .venv/lib/python3.11/site-packages/vllm/cumem_allocator.abi3.so +0 -0
- .venv/lib/python3.11/site-packages/vllm/envs.py +588 -0
- .venv/lib/python3.11/site-packages/vllm/executor/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/vllm/executor/__pycache__/ray_distributed_executor.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/executor/executor_base.py +376 -0
- .venv/lib/python3.11/site-packages/vllm/executor/mp_distributed_executor.py +243 -0
- .venv/lib/python3.11/site-packages/vllm/executor/msgspec_utils.py +29 -0
- .venv/lib/python3.11/site-packages/vllm/executor/ray_distributed_executor.py +638 -0
- .venv/lib/python3.11/site-packages/vllm/executor/ray_utils.py +378 -0
- .venv/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py +134 -0
- .venv/lib/python3.11/site-packages/vllm/forward_context.py +101 -0
- .venv/lib/python3.11/site-packages/vllm/logger.py +210 -0
- .venv/lib/python3.11/site-packages/vllm/logits_process.py +121 -0
- .venv/lib/python3.11/site-packages/vllm/outputs.py +529 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__init__.py +237 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/cpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/cuda.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/hpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/interface.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/neuron.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/openvino.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/rocm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/tpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/xpu.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/cpu.py +145 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/cuda.py +390 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/hpu.py +90 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/interface.py +308 -0
- .venv/lib/python3.11/site-packages/vllm/platforms/neuron.py +57 -0
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Version string of the aiohappyeyeballs distribution.
__version__ = "2.4.6"

from .impl import start_connection
from .types import AddrInfoType
from .utils import addr_to_addr_infos, pop_addr_infos_interleave, remove_addr_infos

# Names re-exported as the public API of the package.
__all__ = (
    "AddrInfoType",
    "addr_to_addr_infos",
    "pop_addr_infos_interleave",
    "remove_addr_infos",
    "start_connection",
)
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (520 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/_staggered.cpython-311.pyc
ADDED
|
Binary file (8.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/impl.cpython-311.pyc
ADDED
|
Binary file (9.75 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/types.cpython-311.pyc
ADDED
|
Binary file (514 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (4.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/_staggered.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import contextlib
|
| 3 |
+
from typing import (
|
| 4 |
+
TYPE_CHECKING,
|
| 5 |
+
Any,
|
| 6 |
+
Awaitable,
|
| 7 |
+
Callable,
|
| 8 |
+
Iterable,
|
| 9 |
+
List,
|
| 10 |
+
Optional,
|
| 11 |
+
Set,
|
| 12 |
+
Tuple,
|
| 13 |
+
TypeVar,
|
| 14 |
+
Union,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Type variable for the result type produced by the raced coroutines.
_T = TypeVar("_T")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _set_result(wait_next: "asyncio.Future[None]") -> None:
    """
    Resolve *wait_next* with ``None`` unless it has already completed.

    Calling this on a future that is already done is a no-op, so it never
    raises ``InvalidStateError`` when invoked more than once.
    """
    if wait_next.done():
        return
    wait_next.set_result(None)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def _wait_one(
    futures: "Iterable[asyncio.Future[Any]]",
    loop: asyncio.AbstractEventLoop,
) -> "asyncio.Future[Any]":
    """
    Wait for the first future to complete and return that future.

    Args:
    ----
    futures: the futures (or tasks) to watch. Any iterable is accepted; it
        is consumed exactly once.

    loop: the event loop used to create the internal waiter future.

    Returns:
    -------
    The future from *futures* whose done-callback fired first. The caller
    is responsible for inspecting its result or exception.

    """
    # The collection is walked twice (register, then unregister), so take a
    # snapshot; otherwise a generator would be exhausted by the first pass
    # and the callbacks would never be removed.
    watched = tuple(futures)
    wait_next: "asyncio.Future[asyncio.Future[Any]]" = loop.create_future()

    def _on_completion(fut: "asyncio.Future[Any]") -> None:
        # Only the first completion wins; later ones are ignored.
        if not wait_next.done():
            wait_next.set_result(fut)

    for f in watched:
        f.add_done_callback(_on_completion)

    try:
        return await wait_next
    finally:
        # Always detach, including when this coroutine is cancelled.
        for f in watched:
            f.remove_done_callback(_on_completion)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
async def staggered_race(
    coro_fns: Iterable[Callable[[], Awaitable[_T]]],
    delay: Optional[float],
    *,
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
    """
    Run coroutines with staggered start times and take the first to finish.

    This method takes an iterable of coroutine functions. The first one is
    started immediately. From then on, whenever the immediately preceding one
    fails (raises an exception), or when *delay* seconds has passed, the next
    coroutine is started. This continues until one of the coroutines completes
    successfully, in which case all others are cancelled, or until all
    coroutines fail.

    The coroutines provided should be well-behaved in the following way:

    * They should only ``return`` if completed successfully.

    * They should always raise an exception if they did not complete
      successfully. In particular, if they handle cancellation, they should
      probably reraise, like this::

        try:
            # do work
        except asyncio.CancelledError:
            # undo partially completed work
            raise

    Args:
    ----
    coro_fns: an iterable of coroutine functions, i.e. callables that
        return a coroutine object when called. Use ``functools.partial`` or
        lambdas to pass arguments.

    delay: amount of time, in seconds, between starting coroutines. If
        ``None`` (or any other falsy value such as ``0``), no timer is
        armed and the coroutines run sequentially.

    loop: the event loop to use. If ``None``, the running loop is used.

    Returns:
    -------
    tuple *(winner_result, winner_index, exceptions)* where

    - *winner_result*: the result of the winning coroutine, or ``None``
      if no coroutines won.

    - *winner_index*: the index of the winning coroutine in
      ``coro_fns``, or ``None`` if no coroutines won. If the winning
      coroutine may return None on success, *winner_index* can be used
      to definitively determine whether any coroutine won.

    - *exceptions*: list of exceptions returned by the coroutines.
      ``len(exceptions)`` is equal to the number of coroutines actually
      started, and the order is the same as in ``coro_fns``. The winning
      coroutine's entry is ``None``.

    """
    loop = loop or asyncio.get_running_loop()
    exceptions: List[Optional[BaseException]] = []
    tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()

    async def run_one_coro(
        coro_fn: Callable[[], Awaitable[_T]],
        this_index: int,
        start_next: "asyncio.Future[None]",
    ) -> Optional[Tuple[_T, int]]:
        """
        Run a single coroutine.

        If the coroutine fails, set the exception in the exceptions list and
        start the next coroutine by setting the result of the start_next.

        If the coroutine succeeds, return the result and the index of the
        coroutine in the coro_fns list.

        If SystemExit or KeyboardInterrupt is raised, re-raise it.
        """
        try:
            result = await coro_fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            exceptions[this_index] = e
            _set_result(start_next)  # Kickstart the next coroutine
            return None

        return result, this_index

    start_next_timer: Optional[asyncio.TimerHandle] = None
    start_next: Optional[asyncio.Future[None]]
    task: asyncio.Task[Optional[Tuple[_T, int]]]
    done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
    coro_iter = iter(coro_fns)
    this_index = -1
    try:
        while True:
            if coro_fn := next(coro_iter, None):
                this_index += 1
                # Reserve this attempt's slot; run_one_coro fills it on failure.
                exceptions.append(None)
                start_next = loop.create_future()
                task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
                tasks.add(task)
                start_next_timer = (
                    loop.call_later(delay, _set_result, start_next) if delay else None
                )
            elif not tasks:
                # We exhausted the coro_fns list and no tasks are running
                # so we have no winner and all coroutines failed.
                break

            while tasks or start_next:
                done = await _wait_one(
                    (*tasks, start_next) if start_next else tasks, loop
                )
                if done is start_next:
                    # The current task has failed or the timer has expired
                    # so we need to start the next task.
                    start_next = None
                    if start_next_timer:
                        start_next_timer.cancel()
                        start_next_timer = None

                    # Break out of the task waiting loop to start the next
                    # task.
                    break

                if TYPE_CHECKING:
                    assert isinstance(done, asyncio.Task)

                tasks.remove(done)
                if winner := done.result():
                    return *winner, exceptions
    finally:
        # We either have:
        #  - a winner
        #  - all tasks failed
        #  - a KeyboardInterrupt or SystemExit.

        #
        # If the timer is still running, cancel it.
        #
        if start_next_timer:
            start_next_timer.cancel()

        #
        # If there are any tasks left, cancel them and then
        # wait for them so they fill the exceptions list.
        #
        for task in tasks:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

    return None, None, exceptions
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/impl.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base implementation."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import collections
|
| 5 |
+
import functools
|
| 6 |
+
import itertools
|
| 7 |
+
import socket
|
| 8 |
+
from typing import List, Optional, Sequence, Union
|
| 9 |
+
|
| 10 |
+
from . import _staggered
|
| 11 |
+
from .types import AddrInfoType
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
async def start_connection(
    addr_infos: Sequence[AddrInfoType],
    *,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
    happy_eyeballs_delay: Optional[float] = None,
    interleave: Optional[int] = None,
    loop: Optional[asyncio.AbstractEventLoop] = None,
) -> socket.socket:
    """
    Connect to a TCP server.

    Create a socket connection to a specified destination. The
    destination is specified as a list of AddrInfoType tuples as
    returned from getaddrinfo().

    The fields of each tuple are, in order:

    * ``family``: the address family, e.g. ``socket.AF_INET`` or
        ``socket.AF_INET6``.
    * ``type``: the socket type, e.g. ``socket.SOCK_STREAM`` or
        ``socket.SOCK_DGRAM``.
    * ``proto``: the protocol, e.g. ``socket.IPPROTO_TCP`` or
        ``socket.IPPROTO_UDP``.
    * ``canonname``: the canonical name of the address, e.g.
        ``"www.python.org"``.
    * ``sockaddr``: the socket address

    This method is a coroutine which will try to establish the connection
    in the background. When successful, the coroutine returns a
    socket.

    The expected use case is to use this method in conjunction with
    loop.create_connection() to establish a connection to a server::

            socket = await start_connection(addr_infos)
            transport, protocol = await loop.create_connection(
                MyProtocol, sock=socket, ...)

    Args:
    ----
    addr_infos: candidate destination addresses, tried in order (after
        optional interleaving).

    local_addr_infos: optional local addresses; when given, each socket is
        bound to the first one of matching family before connecting.

    happy_eyeballs_delay: when not ``None`` and more than one address is
        given, attempts are raced with this delay between starts; otherwise
        addresses are tried one after another.

    interleave: forwarded to the address interleaving step; defaults to
        ``1`` when *happy_eyeballs_delay* is set.

    loop: the event loop to use. If ``None``, the running loop is used.

    Raises:
    ------
    OSError or RuntimeError: when no connection could be established. If
        no attempt recorded an error at all (for example *addr_infos* is
        empty) an ``OSError`` is raised.

    """
    current_loop = loop or asyncio.get_running_loop()

    single_addr_info = len(addr_infos) == 1

    if happy_eyeballs_delay is not None and interleave is None:
        # If using happy eyeballs, default to interleave addresses by family
        interleave = 1

    if interleave and not single_addr_info:
        addr_infos = _interleave_addrinfos(addr_infos, interleave)

    sock: Optional[socket.socket] = None
    # uvloop can raise RuntimeError instead of OSError
    exceptions: List[List[Union[OSError, RuntimeError]]] = []
    if happy_eyeballs_delay is None or single_addr_info:
        # not using happy eyeballs
        for addrinfo in addr_infos:
            try:
                sock = await _connect_sock(
                    current_loop, exceptions, addrinfo, local_addr_infos
                )
                break
            except (RuntimeError, OSError):
                continue
    else:  # using happy eyeballs
        sock, _, _ = await _staggered.staggered_race(
            (
                functools.partial(
                    _connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
                )
                for addrinfo in addr_infos
            ),
            happy_eyeballs_delay,
            # Race on the same loop the sockets are connected on.
            loop=current_loop,
        )

    if sock is None:
        all_exceptions = [exc for sub in exceptions for exc in sub]
        try:
            if not all_exceptions:
                # Nothing was recorded (e.g. no addresses were supplied);
                # report that instead of failing with an IndexError below.
                raise OSError(
                    "start_connection() failed without recording any "
                    "connection error; addr_infos may be empty"
                )
            first_exception = all_exceptions[0]
            if len(all_exceptions) == 1:
                raise first_exception
            # If they all have the same str(), raise one.
            model = str(first_exception)
            if all(str(exc) == model for exc in all_exceptions):
                raise first_exception
            # Raise a combined exception so the user can see all
            # the various error messages.
            msg = "Multiple exceptions: {}".format(
                ", ".join(str(exc) for exc in all_exceptions)
            )
            # If the errno is the same for all exceptions, raise
            # an OSError with that errno.
            if isinstance(first_exception, OSError):
                first_errno = first_exception.errno
                if all(
                    isinstance(exc, OSError) and exc.errno == first_errno
                    for exc in all_exceptions
                ):
                    raise OSError(first_errno, msg)
            elif isinstance(first_exception, RuntimeError) and all(
                isinstance(exc, RuntimeError) for exc in all_exceptions
            ):
                raise RuntimeError(msg)
            # We have a mix of OSError and RuntimeError
            # so we have to pick which one to raise.
            # and we raise OSError for compatibility
            raise OSError(msg)
        finally:
            # Drop local references to the collected exceptions before
            # the raised exception leaves this frame.
            all_exceptions = None  # type: ignore[assignment]
            exceptions = None  # type: ignore[assignment]

    return sock
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
async def _connect_sock(
    loop: asyncio.AbstractEventLoop,
    exceptions: List[List[Union[OSError, RuntimeError]]],
    addr_info: AddrInfoType,
    local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
    """
    Create, bind and connect one socket.

    A fresh list is appended to *exceptions* for this attempt and every
    ``OSError``/``RuntimeError`` hit while binding, connecting or closing is
    recorded in it. On any failure the socket is closed and the error is
    re-raised; on success the connected non-blocking socket is returned.
    """
    my_exceptions: List[Union[OSError, RuntimeError]] = []
    exceptions.append(my_exceptions)
    family, type_, proto, _, address = addr_info
    sock = None
    try:
        sock = socket.socket(family=family, type=type_, proto=proto)
        sock.setblocking(False)
        if local_addr_infos is not None:
            for lfamily, _, _, _, laddr in local_addr_infos:
                # skip local addresses of different family
                if lfamily != family:
                    continue
                try:
                    sock.bind(laddr)
                    break
                except OSError as bind_exc:
                    msg = (
                        f"error while attempting to bind on "
                        f"address {laddr!r}: "
                        f"{(bind_exc.strerror or '').lower()}"
                    )
                    my_exceptions.append(OSError(bind_exc.errno, msg))
            else:  # all bind attempts failed
                if my_exceptions:
                    raise my_exceptions.pop()
                else:
                    raise OSError(f"no matching local address with {family=} found")
        await loop.sock_connect(sock, address)
        return sock
    except (RuntimeError, OSError) as exc:
        my_exceptions.append(exc)
        if sock is not None:
            try:
                sock.close()
            except OSError as e:
                my_exceptions.append(e)
                raise
        raise
    except BaseException:
        # Anything else (e.g. cancellation): still close the socket, then
        # let the original exception propagate.
        if sock is not None:
            try:
                sock.close()
            except OSError as e:
                my_exceptions.append(e)
                raise
        raise
    finally:
        # Drop local references to the exception lists.
        exceptions = my_exceptions = None  # type: ignore[assignment]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _interleave_addrinfos(
    addrinfos: Sequence[AddrInfoType], first_address_family_count: int = 1
) -> List[AddrInfoType]:
    """
    Interleave list of addrinfo tuples by family.

    Addresses are grouped by family in first-seen order and then emitted
    round-robin across the families. When *first_address_family_count* is
    greater than one, that many minus one addresses of the first family are
    emitted up front before the round-robin starts. An empty input yields
    an empty list.
    """
    # Group addresses by family
    addrinfos_by_family: "collections.OrderedDict[int, List[AddrInfoType]]" = (
        collections.OrderedDict()
    )
    for addr in addrinfos:
        addrinfos_by_family.setdefault(addr[0], []).append(addr)
    addrinfos_lists = list(addrinfos_by_family.values())

    reordered: List[AddrInfoType] = []
    # Guard against empty input: there is no "first family" to take from.
    if first_address_family_count > 1 and addrinfos_lists:
        head = first_address_family_count - 1
        reordered.extend(addrinfos_lists[0][:head])
        del addrinfos_lists[0][:head]
    reordered.extend(
        a
        for a in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists))
        if a is not None
    )
    return reordered
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/py.typed
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/types.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Types for aiohappyeyeballs."""
|
| 2 |
+
|
| 3 |
+
import socket
|
| 4 |
+
from typing import Tuple, Union
|
| 5 |
+
|
| 6 |
+
# One address entry laid out as
# ``(family, type, proto, canonname, sockaddr)``.
AddrInfoType = Tuple[
    Union[int, socket.AddressFamily],  # family
    Union[int, socket.SocketKind],  # type
    int,  # proto
    str,  # canonname
    Tuple,  # type: ignore[type-arg]  # sockaddr
]
|
.venv/lib/python3.11/site-packages/aiohappyeyeballs/utils.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for aiohappyeyeballs."""
|
| 2 |
+
|
| 3 |
+
import ipaddress
|
| 4 |
+
import socket
|
| 5 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
from .types import AddrInfoType
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def addr_to_addr_infos(
    addr: Optional[
        Union[Tuple[str, int, int, int], Tuple[str, int, int], Tuple[str, int]]
    ],
) -> Optional[List[AddrInfoType]]:
    """
    Convert an address tuple to a list of addr_info tuples.

    Returns ``None`` when *addr* is ``None``. Otherwise returns a single
    TCP stream entry. A host containing ``":"`` is treated as IPv6 and its
    socket address is normalized to ``(host, port, flowinfo, scopeid)``,
    with missing trailing fields defaulting to ``0``; any other host is
    treated as IPv4 and reduced to ``(host, port)``.
    """
    if addr is None:
        return None
    host = addr[0]
    port = addr[1]
    if ":" in host:
        addr_len = len(addr)
        flowinfo = addr[2] if addr_len >= 3 else 0  # type: ignore[misc]
        scopeid = addr[3] if addr_len >= 4 else 0  # type: ignore[misc]
        addr = (host, port, flowinfo, scopeid)
        family = socket.AF_INET6
    else:
        addr = (host, port)
        family = socket.AF_INET
    return [(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pop_addr_infos_interleave(
    addr_infos: List[AddrInfoType], interleave: Optional[int] = None
) -> None:
    """
    Pop addr_info from the list of addr_infos by family up to interleave times.

    The interleave parameter is used to know how many addr_infos for
    each family should be popped off the top of the list. ``None`` is
    treated as ``1``. The list is modified in place; nothing is returned.
    """
    if interleave is None:
        interleave = 1
    seen: Dict[int, int] = {}
    kept: List[AddrInfoType] = []
    for addr_info in addr_infos:
        family = addr_info[0]
        taken = seen.get(family, 0)
        if taken < interleave:
            # Still within this family's quota: drop the entry.
            seen[family] = taken + 1
        else:
            kept.append(addr_info)
    # Single in-place rewrite instead of repeated list.remove() calls.
    addr_infos[:] = kept
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _addr_tuple_to_ip_address(
    addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> Union[
    Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
]:
    """
    Convert the host of an address tuple to an IPv4Address or IPv6Address.

    The first element is parsed with ``ipaddress.ip_address``; the remaining
    elements (port and, for IPv6, flowinfo and scopeid) are passed through
    unchanged. This lets differently formatted spellings of the same address
    compare equal.
    """
    return (ipaddress.ip_address(addr[0]), *addr[1:])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def remove_addr_infos(
    addr_infos: List[AddrInfoType],
    addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> None:
    """
    Remove an address from the list of addr_infos.

    The addr value is typically the return value of
    sock.getpeername().

    Every entry whose socket address equals *addr* is removed in place. If
    none matches exactly, the comparison is retried with the hosts parsed as
    IP address objects so that differently formatted spellings still match.

    Raises:
    ------
    ValueError: if no entry matches *addr* either way.

    """
    total = len(addr_infos)
    remaining = [info for info in addr_infos if info[-1] != addr]
    if len(remaining) == total:
        # Slow path in case addr is formatted differently
        match_addr = _addr_tuple_to_ip_address(addr)
        remaining = [
            info
            for info in addr_infos
            if _addr_tuple_to_ip_address(info[-1]) != match_addr
        ]
        if len(remaining) == total:
            raise ValueError(f"Address {addr} not found in addr_infos")
    # Rewrite in place so callers holding the list see the change.
    addr_infos[:] = remaining
|
.venv/lib/python3.11/site-packages/attrs/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/attrs/__pycache__/validators.cpython-311.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/__init__.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
| 8 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
| 9 |
+
from vllm.engine.llm_engine import LLMEngine
|
| 10 |
+
from vllm.entrypoints.llm import LLM
|
| 11 |
+
from vllm.executor.ray_utils import initialize_ray_cluster
|
| 12 |
+
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
| 13 |
+
from vllm.model_executor.models import ModelRegistry
|
| 14 |
+
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
|
| 15 |
+
CompletionOutput, EmbeddingOutput,
|
| 16 |
+
EmbeddingRequestOutput, PoolingOutput,
|
| 17 |
+
PoolingRequestOutput, RequestOutput, ScoringOutput,
|
| 18 |
+
ScoringRequestOutput)
|
| 19 |
+
from vllm.pooling_params import PoolingParams
|
| 20 |
+
from vllm.sampling_params import SamplingParams
|
| 21 |
+
|
| 22 |
+
from .version import __version__, __version_tuple__
|
| 23 |
+
|
| 24 |
+
# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# see https://github.com/vllm-project/vllm/issues/10480
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1

# Names re-exported as the public API of the package.
__all__ = [
    "__version__",
    "__version_tuple__",
    "LLM",
    "ModelRegistry",
    "PromptType",
    "TextPrompt",
    "TokensPrompt",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "PoolingOutput",
    "PoolingRequestOutput",
    "EmbeddingOutput",
    "EmbeddingRequestOutput",
    "ClassificationOutput",
    "ClassificationRequestOutput",
    "ScoringOutput",
    "ScoringRequestOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
]
|
.venv/lib/python3.11/site-packages/vllm/_custom_ops.py
ADDED
|
@@ -0,0 +1,1098 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import importlib
|
| 5 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.library
|
| 9 |
+
|
| 10 |
+
import vllm.envs as envs
|
| 11 |
+
from vllm.logger import init_logger
|
| 12 |
+
from vllm.platforms import current_platform
|
| 13 |
+
from vllm.scalar_type import ScalarType
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
# The compiled `_C` extension is imported on every platform except TPU and
# HPU. A failed import is logged as a warning and module loading continues.
if not (current_platform.is_tpu() or current_platform.is_hpu()):
    try:
        import vllm._C
    except ImportError as e:
        logger.warning("Failed to import from vllm._C with %r", e)

# `supports_moe_ops` records whether the optional `_moe_C` extension could be
# imported; an ImportError simply leaves the flag False.
supports_moe_ops = False
try:
    import vllm._moe_C  # noqa: F401
except ImportError:
    pass
else:
    supports_moe_ops = True

# Pick the fake-implementation registration decorator. At runtime prefer
# `torch.library.register_fake` and fall back to `impl_abstract` when the
# former cannot be imported. Static type checkers only see a stub.
if not TYPE_CHECKING:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake
else:

    def register_fake(fn):
        return lambda name: fn
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# paged attention ops
|
| 40 |
+
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled `_C.paged_attention_v1` op.

    All arguments are forwarded positionally, in declaration order, without
    any preprocessing. Nothing is returned.

    NOTE(review): presumably the kernel writes its result into `out` --
    confirm against the kernel implementation.
    """
    kernel_args = (
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v1(*kernel_args)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Invoke the compiled `_C.paged_attention_v2` op.

    Compared with `paged_attention_v1`, the signature additionally carries
    `exp_sum`, `max_logits` and `tmp_out`. All arguments are forwarded
    positionally, in declaration order. Nothing is returned.

    NOTE(review): presumably `out` receives the result and the three extra
    tensors are scratch buffers -- confirm against the kernel.
    """
    kernel_args = (
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
    torch.ops._C.paged_attention_v2(*kernel_args)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Invoke the compiled `_rocm_C.paged_attention` op.

    All arguments are forwarded positionally, in declaration order. Nothing
    is returned.

    NOTE(review): presumably the kernel writes its result into `out` --
    confirm against the kernel implementation.
    """
    kernel_args = (
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    torch.ops._rocm_C.paged_attention(*kernel_args)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# pos encoding ops
|
| 128 |
+
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Invoke the compiled `_C.rotary_embedding` op with the given arguments.

    Nothing is returned.

    NOTE(review): presumably `query` and `key` are modified in place --
    confirm against the kernel implementation.
    """
    kernel = torch.ops._C.rotary_embedding
    kernel(positions, query, key, head_size, cos_sin_cache, is_neox)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def batched_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    rot_dim: int,
    cos_sin_cache_offsets: torch.Tensor,
) -> None:
    """Invoke the compiled `_C.batched_rotary_embedding` op.

    Takes the same arguments as `rotary_embedding` plus `rot_dim` and
    `cos_sin_cache_offsets`; everything is forwarded positionally. Nothing is
    returned.

    NOTE(review): presumably `query` and `key` are modified in place --
    confirm against the kernel implementation.
    """
    kernel = torch.ops._C.batched_rotary_embedding
    kernel(
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox,
        rot_dim,
        cos_sin_cache_offsets,
    )
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# layer norm ops
|
| 151 |
+
def rms_norm(
    out: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the compiled `_C.rms_norm` op as `(out, input, weight, epsilon)`.

    Nothing is returned.

    NOTE(review): presumably the result is written into `out` -- confirm.
    """
    kernel = torch.ops._C.rms_norm
    kernel(out, input, weight, epsilon)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def fused_add_rms_norm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> None:
    """Invoke the compiled `_C.fused_add_rms_norm` op.

    Arguments are forwarded as `(input, residual, weight, epsilon)`. Nothing
    is returned.

    NOTE(review): presumably `input` and `residual` are updated in place --
    confirm against the kernel implementation.
    """
    kernel = torch.ops._C.fused_add_rms_norm
    kernel(input, residual, weight, epsilon)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def advance_step_flashattn(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner.

    Forwards every argument positionally to the compiled
    `_C.advance_step_flashattn` op and returns whatever that op returns.
    """
    step_args = (
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
    )
    return torch.ops._C.advance_step_flashattn(*step_args)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def advance_step_flashinfer(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    input_positions: torch.Tensor,
    seq_lens: torch.Tensor,
    slot_mapping: torch.Tensor,
    block_tables: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    block_table_bound: torch.Tensor,
) -> None:
    """Invoke the compiled `_C.advance_step_flashinfer` op.

    Takes the same leading arguments as `advance_step_flashattn`, followed by
    the `paged_kv_*` tensors and `block_table_bound`. Everything is forwarded
    positionally and the op's return value is passed back to the caller.
    """
    step_args = (
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        paged_kv_indices,
        paged_kv_indptr,
        paged_kv_last_page_len,
        block_table_bound,
    )
    return torch.ops._C.advance_step_flashinfer(*step_args)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# fused quant layer norm ops
|
| 194 |
+
def rms_norm_dynamic_per_token_quant(
    input: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
    quant_dtype: torch.dtype,
    scale_ub: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the compiled `_C.rms_norm_dynamic_per_token_quant` op.

    Allocates two uninitialized buffers and hands them to the op together
    with the inputs:

    * an output tensor shaped like `input` with dtype `quant_dtype`;
    * a float32 scales tensor of shape `(rows, 1)` on `input`'s device,
      where `rows = input.numel() // input.shape[-1]`.

    Returns:
        The `(output, scales)` pair. NOTE(review): presumably both are
        filled in by the kernel -- confirm.
    """
    # One scale per row, where a row is a slice along the last dimension.
    num_rows = input.numel() // input.shape[-1]
    quantized = torch.empty_like(input, dtype=quant_dtype)
    row_scales = torch.empty(
        (num_rows, 1),
        device=input.device,
        dtype=torch.float32,
    )

    kernel = torch.ops._C.rms_norm_dynamic_per_token_quant
    kernel(quantized, input, weight, row_scales, epsilon, scale_ub, residual)
    return quantized, row_scales
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# quantization ops
|
| 214 |
+
# awq
|
| 215 |
+
def awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    split_k_iters: int,
    thx: int,
    thy: int,
) -> torch.Tensor:
    """Dequantize AWQ weights via either the Triton or the compiled path.

    When `envs.VLLM_USE_TRITON_AWQ` is truthy the Triton implementation is
    imported lazily and called with `(qweight, scales, zeros)` only;
    otherwise the compiled `_C.awq_dequantize` op receives all six arguments.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_dequantize(
            qweight, scales, zeros, split_k_iters, thx, thy
        )

    # Imported lazily, only when the Triton path is selected.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_dequantize_triton,
    )

    return awq_dequantize_triton(qweight, scales, zeros)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def awq_gemm(
    input: torch.Tensor,
    qweight: torch.Tensor,
    qzeros: torch.Tensor,
    scales: torch.Tensor,
    split_k_iters: int,
) -> torch.Tensor:
    """Run the AWQ GEMM via either the Triton or the compiled path.

    When `envs.VLLM_USE_TRITON_AWQ` is truthy the Triton implementation is
    imported lazily and used; otherwise the compiled `_C.awq_gemm` op is
    called. Both receive the same five arguments in the same order.
    """
    if not envs.VLLM_USE_TRITON_AWQ:
        return torch.ops._C.awq_gemm(
            input, qweight, qzeros, scales, split_k_iters
        )

    # Imported lazily, only when the Triton path is selected.
    from vllm.model_executor.layers.quantization.awq_triton import (
        awq_gemm_triton,
    )

    return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# gptq
|
| 236 |
+
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    bit: int,
) -> torch.Tensor:
    """Return the result of the compiled `_C.gptq_gemm` op.

    All arguments are forwarded positionally, in declaration order.
    """
    kernel = torch.ops._C.gptq_gemm
    return kernel(
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_exllama,
        bit,
    )
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Register the fake implementation only when the compiled extension actually
# exposes `gptq_gemm`.
if hasattr(torch.ops._C, "gptq_gemm"):

    @register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_gptq_qzeros: torch.Tensor,
        b_gptq_scales: torch.Tensor,
        b_g_idx: torch.Tensor,
        use_exllama: bool,
        bit: int,
    ) -> torch.Tensor:
        """Fake `_C::gptq_gemm`: allocate the output without computing it.

        The result has shape `(a.size(0), b_q_weight.size(1))` and shares
        `a`'s dtype and device.
        """
        out_shape = (a.size(0), b_q_weight.size(1))
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def gptq_shuffle(
    q_weight: torch.Tensor,
    q_perm: torch.Tensor,
    bit: int,
) -> None:
    """Invoke the compiled `_C.gptq_shuffle` op as `(q_weight, q_perm, bit)`.

    Nothing is returned.

    NOTE(review): presumably `q_weight` is rearranged in place -- confirm.
    """
    kernel = torch.ops._C.gptq_shuffle
    kernel(q_weight, q_perm, bit)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# marlin
|
| 262 |
+
def marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Return the result of the compiled `_C.marlin_gemm` op.

    All arguments are forwarded positionally, in declaration order.
    """
    kernel = torch.ops._C.marlin_gemm
    return kernel(
        a,
        b_q_weight,
        b_scales,
        workspace,
        size_m,
        size_n,
        size_k,
    )
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# marlin_24
|
| 270 |
+
def gptq_marlin_24_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_meta: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Return the result of the compiled `_C.gptq_marlin_24_gemm` op.

    Arguments are forwarded positionally in declaration order, except that
    the op receives `b_q_type.id` rather than the `ScalarType` object itself.
    """
    kernel = torch.ops._C.gptq_marlin_24_gemm
    return kernel(
        a,
        b_q_weight,
        b_meta,
        b_scales,
        workspace,
        b_q_type.id,  # the op takes the scalar type's id, not the object
        size_m,
        size_n,
        size_k,
    )
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# Fake implementations for a group of compiled ops. The whole group is
# registered only when the extension exposes `gptq_marlin_24_gemm`. Each fake
# merely allocates an uninitialized tensor with the output's shape, dtype and
# device; no real computation happens.
if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

    @register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_meta: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output: `(size_m, size_n)` with `a`'s dtype and device."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        b_zeros: torch.Tensor,
        g_idx: torch.Tensor,
        perm: torch.Tensor,
        workspace: torch.Tensor,
        b_q_type: ScalarType,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
        is_k_full: bool,
        has_zp: bool = False,
        use_fp32_reduce: bool = False,
        is_zp_float: bool = False,
    ) -> torch.Tensor:
        """Fake output: `(size_m, size_n)` with `a`'s dtype and device."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        s_tok: torch.Tensor,
        s_ch: torch.Tensor,
        s_group: torch.Tensor,
        workspace: torch.Tensor,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output: float16 `(size_m, size_n)` on `a`'s device."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=torch.float16, device=a.device)

    @register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output: float16 `(size_m, size_n)` on `a`'s device."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=torch.float16, device=a.device)

    @register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(
        qweight: torch.Tensor,
        scales: torch.Tensor,
        zeros: torch.Tensor,
        split_k_iters: torch.SymInt,
        thx: int,
        thy: int,
    ) -> torch.Tensor:
        """Fake output: `(qweight.size(0), qweight.size(1) * 8)`.

        Dtype and device are taken from `scales`.
        """
        rows = qweight.size(0)
        cols = qweight.size(1) * 8
        return torch.empty(
            (rows, cols), dtype=scales.dtype, device=scales.device
        )

    @register_fake("_C::awq_gemm")
    def _awq_gemm_fake(
        input: torch.Tensor,
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
        split_k_iters: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output: `(input.size(0), qweight.size(1) * 8)`.

        Built by allocating a `(split_k_iters, rows, cols)` buffer with
        `input`'s dtype and device and summing over the leading dimension.
        """
        partial_shape = (split_k_iters, input.size(0), qweight.size(1) * 8)
        partials = torch.empty(
            partial_shape, dtype=input.dtype, device=input.device
        )
        return partials.sum(0)

    @register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(
        input: torch.Tensor,
        codes: torch.Tensor,
        codebooks: torch.Tensor,
        scales: torch.Tensor,
        codebook_partition_sizes: List[int],
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Fake output: `input`'s leading dims with a new last dimension.

        The last dimension holds `codes.size(0) * codebooks.size(2)`
        features; dtype and device follow `input`.
        """
        n_out = codes.size(0) * codebooks.size(2)
        input_2d = input.reshape((-1, input.size(-1)))
        output_2d = torch.empty(
            (input_2d.size(0), n_out),
            dtype=input.dtype,
            device=input.device,
        )
        # Keep every leading dimension of `input`; `-1` lets reshape infer
        # the final one.
        target_shape = (*input.shape[:-1], -1)
        return output_2d.reshape(target_shape)

    @register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
        codes: torch.Tensor,
        codebooks: torch.Tensor,
        codebook_partition_sizes: List[int],
    ) -> torch.Tensor:
        """Fake output: `(codes.size(0), codes.size(1) * 8)`.

        Dtype and device are taken from `codebooks`.
        """
        rows = codes.size(0)
        cols = codes.size(1) * 8
        return torch.empty(
            (rows, cols), dtype=codebooks.dtype, device=codebooks.device
        )

    @register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(
        a: torch.Tensor,
        b_q_weight: torch.Tensor,
        b_scales: torch.Tensor,
        workspace: torch.Tensor,
        num_bits: int,
        size_m: torch.SymInt,
        size_n: torch.SymInt,
        size_k: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output: `(size_m, size_n)` with `a`'s dtype and device."""
        out_shape = (size_m, size_n)
        return torch.empty(out_shape, dtype=a.dtype, device=a.device)

    @register_fake("_C::machete_mm")
    def machete_mm_fake(
        a: torch.Tensor,
        # b_q should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        out_type: Optional[torch.dtype] = None,
        b_group_scales: Optional[torch.Tensor] = None,
        b_group_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        b_channel_scales: Optional[torch.Tensor] = None,
        a_token_scales: Optional[torch.Tensor] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        """Fake output: `(a.size(0), b_q.size(1))` with `a`'s dtype/device.

        `out_type` is accepted but not consulted by this fake.
        """
        out_shape = (a.size(0), b_q.size(1))
        return torch.empty(out_shape, device=a.device, dtype=a.dtype)

    @register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(
        b_q_weight: torch.Tensor,
        a_type: torch.dtype,
        b_type: ScalarType,
        group_scales_type: Optional[torch.dtype],
    ) -> torch.Tensor:
        """Fake output: a contiguous-format tensor shaped like `b_q_weight`."""
        return torch.empty_like(
            b_q_weight, memory_format=torch.contiguous_format
        )
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
# Fake implementations for the GGML ops, registered only when the compiled
# extension exposes `ggml_dequantize`. Each one allocates an uninitialized
# float16 tensor on `W`'s device.
if hasattr(torch.ops._C, "ggml_dequantize"):

    @register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(
        W: torch.Tensor,
        quant_type: int,
        m: torch.SymInt,
        n: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output of shape `(m, n)`."""
        out_shape = (m, n)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output of shape `(1, row)`."""
        out_shape = (1, row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)

    @register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: torch.SymInt,
    ) -> torch.Tensor:
        """Fake output of shape `(X.size(0), row)`."""
        out_shape = (X.size(0), row)
        return torch.empty(out_shape, dtype=torch.float16, device=W.device)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# cutlass
|
| 436 |
+
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return the compiled `_C.cutlass_scaled_mm_supports_fp8` op's answer.

    `cuda_device_capability` is passed through unchanged.
    """
    query = torch.ops._C.cutlass_scaled_mm_supports_fp8
    return query(cuda_device_capability)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
    """Return the compiled `_C.cutlass_scaled_mm_supports_block_fp8` answer.

    `cuda_device_capability` is passed through unchanged.
    """
    query = torch.ops._C.cutlass_scaled_mm_supports_block_fp8
    return query(cuda_device_capability)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused equivalent of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where `scale_a * a` and `scale_b * b` follow numpy-style broadcasting.

    To support blockwise scaling (as found in DeepSeek V3) the numpy rules
    are extended with a "group" broadcast rule:
        "if the extent of a dimension in the source shape is between 1 and
        the corresponding extent in the target shape, each element along
        that dimension is repeated target_shape[dim] // src_shape[dim] times
        consecutively"
    For example, with
        a = [[1, 2],     and target_shape = (2, 4)
             [3, 4]]
    `a` is expanded to
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]
    Currently only this case is supported:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape

    Preconditions (checked with `assert`): both dimensions of `b` are
    multiples of 16; `out_dtype` is bfloat16 or float16; `bias`, when given,
    has `b.shape[1]` entries along dim 0 and dtype `out_dtype`.

    On ROCm the Triton `triton_scaled_mm` implementation is imported lazily
    and used; otherwise the compiled `_C.cutlass_scaled_mm` op is called
    with a freshly allocated `(a.shape[0], b.shape[1])` output.
    """
    assert b.shape[0] % 16 == 0
    assert b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (
        bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
    )

    rows = a.shape[0]
    cols = b.shape[1]

    if current_platform.is_rocm():
        # Resolved by module path at call time, not by a top-level import.
        triton_module = importlib.import_module(
            "vllm.model_executor.layers.quantization.compressed_tensors."
            "triton_scaled_mm"
        )
        return triton_module.triton_scaled_mm(
            a, b, scale_a, scale_b, out_dtype, bias
        )

    result = torch.empty((rows, cols), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(result, a, b, scale_a, scale_b, bias)
    return result
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the compiled `_C.cutlass_scaled_mm_azp` op into a new output tensor.

    :param azp_adj: In the per-tensor case, this should include the azp.
        Always per-channel.
    :param azp: Only set in the per-token case. Per-token if set.

    Preconditions (checked with `assert`): both dimensions of `b` are
    multiples of 16; `out_dtype` is bfloat16 or float16; `bias`, when given,
    has `b.shape[1]` elements and dtype `out_dtype`; `azp`, when given, has
    `a.shape[0]` elements.

    :return: A new `(a.shape[0], b.shape[1])` tensor of dtype `out_dtype` on
        `a`'s device, passed to the op as its output argument.
    """
    assert b.shape[0] % 16 == 0
    assert b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (
        bias.numel() == b.shape[1] and bias.dtype == out_dtype
    )
    assert azp is None or azp.numel() == a.shape[0]

    out_shape = (a.shape[0], b.shape[1])
    result = torch.empty(out_shape, dtype=out_dtype, device=a.device)

    kernel = torch.ops._C.cutlass_scaled_mm_azp
    kernel(result, a, b, scale_a, scale_b, azp_adj, azp, bias)
    return result
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
    """Return the compiled `_C.cutlass_sparse_scaled_mm_supported` answer.

    `cuda_device_capability` is passed through unchanged.
    """
    query = torch.ops._C.cutlass_sparse_scaled_mm_supported
    return query(cuda_device_capability)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def cutlass_sparse_compress(a: torch.Tensor) \
|
| 529 |
+
-> Tuple[torch.Tensor, torch.Tensor]:
|
| 530 |
+
"""
|
| 531 |
+
Compresses a sparse matrix for use with Cutlass sparse operations.
|
| 532 |
+
|
| 533 |
+
This function takes a dense tensor and compresses it into two components:
|
| 534 |
+
non-zero elements and metadata. The compressed representation is compatible
|
| 535 |
+
with Cutlass sparse kernels.
|
| 536 |
+
|
| 537 |
+
Args:
|
| 538 |
+
a (torch.Tensor):
|
| 539 |
+
The input tensor to be compressed. Must have one of the following data types:
|
| 540 |
+
- `torch.int8`
|
| 541 |
+
- `torch.float8_e4m3fn`
|
| 542 |
+
- `torch.bfloat16`
|
| 543 |
+
- `torch.float16`
|
| 544 |
+
|
| 545 |
+
Returns:
|
| 546 |
+
Tuple[torch.Tensor, torch.Tensor]:
|
| 547 |
+
A tuple containing:
|
| 548 |
+
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
|
| 549 |
+
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
|
| 550 |
+
|
| 551 |
+
Raises:
|
| 552 |
+
ValueError: If the compression operation fails.
|
| 553 |
+
|
| 554 |
+
Notes:
|
| 555 |
+
- The `a_meta` tensor has a data type of `torch.uint8`.
|
| 556 |
+
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
|
| 557 |
+
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
|
| 558 |
+
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
|
| 559 |
+
"""
|
| 560 |
+
assert (a.dtype in [
|
| 561 |
+
torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
|
| 562 |
+
])
|
| 563 |
+
assert (a.is_contiguous())
|
| 564 |
+
|
| 565 |
+
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
|
| 566 |
+
elemsPerMetaElem = 4
|
| 567 |
+
|
| 568 |
+
m = a.shape[0]
|
| 569 |
+
k = a.shape[1]
|
| 570 |
+
assert (k % 2 == 0)
|
| 571 |
+
a_nzs = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
|
| 572 |
+
a_meta = torch.empty((m, k // 2 // elemsPerMetaElem),
|
| 573 |
+
dtype=torch.uint8,
|
| 574 |
+
device=a.device)
|
| 575 |
+
|
| 576 |
+
if not (torch.ops._C.cutlass_sparse_compress_entry(a_nzs, a_meta, a)):
|
| 577 |
+
raise ValueError
|
| 578 |
+
|
| 579 |
+
assert (a_nzs.is_contiguous())
|
| 580 |
+
assert (a_meta.is_contiguous())
|
| 581 |
+
|
| 582 |
+
return a_nzs, a_meta
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def cutlass_scaled_sparse_mm(
|
| 586 |
+
a: torch.Tensor,
|
| 587 |
+
bt_nzs: torch.Tensor,
|
| 588 |
+
bt_meta: torch.Tensor,
|
| 589 |
+
scale_a: torch.Tensor,
|
| 590 |
+
scale_b: torch.Tensor,
|
| 591 |
+
out_dtype: torch.dtype,
|
| 592 |
+
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 593 |
+
"""
|
| 594 |
+
Performs a scaled sparse matrix multiplication using Cutlass.
|
| 595 |
+
|
| 596 |
+
Steps:
|
| 597 |
+
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
|
| 598 |
+
`a = torch.randn((m, k), device='cuda')`.
|
| 599 |
+
|
| 600 |
+
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
|
| 601 |
+
`b = torch.randn((k, n), device='cuda')`.
|
| 602 |
+
|
| 603 |
+
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
|
| 604 |
+
`b = prune_to_2_4(b, dim=0)`.
|
| 605 |
+
|
| 606 |
+
4. Compress the transposed sparse matrix `b.t()`:
|
| 607 |
+
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
|
| 608 |
+
|
| 609 |
+
5. Perform sparse matrix multiplication using the compressed matrix,
|
| 610 |
+
applying scaling factors for `a` and `b`, and the output data type:
|
| 611 |
+
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
|
| 612 |
+
|
| 613 |
+
Returns:
|
| 614 |
+
- The result of the scaled sparse matrix multiplication.
|
| 615 |
+
"""
|
| 616 |
+
assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0)
|
| 617 |
+
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
|
| 618 |
+
assert bias is None or bias.shape[0] == bt_nzs.shape[0] \
|
| 619 |
+
and bias.dtype == out_dtype
|
| 620 |
+
|
| 621 |
+
m = a.shape[0]
|
| 622 |
+
n = bt_nzs.shape[0]
|
| 623 |
+
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
| 624 |
+
|
| 625 |
+
torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a,
|
| 626 |
+
scale_b, bias)
|
| 627 |
+
|
| 628 |
+
return out
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
# aqlm
|
| 632 |
+
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
|
| 633 |
+
codebooks: torch.Tensor, scales: torch.Tensor,
|
| 634 |
+
codebook_partition_sizes: List[int],
|
| 635 |
+
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
| 636 |
+
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
|
| 637 |
+
codebook_partition_sizes, bias)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
|
| 641 |
+
codebook_partition_sizes: List[int]) -> torch.Tensor:
|
| 642 |
+
return torch.ops._C.aqlm_dequant(codes, codebooks,
|
| 643 |
+
codebook_partition_sizes)
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
# gptq_marlin
|
| 647 |
+
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
|
| 648 |
+
size_k: int, size_n: int,
|
| 649 |
+
num_bits: int) -> torch.Tensor:
|
| 650 |
+
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
|
| 651 |
+
num_bits)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
# gptq_marlin
|
| 655 |
+
def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
|
| 656 |
+
num_bits: int) -> torch.Tensor:
|
| 657 |
+
return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
|
| 661 |
+
size_k: int, size_n: int,
|
| 662 |
+
num_bits: int) -> torch.Tensor:
|
| 663 |
+
num_experts = b_q_weight.shape[0]
|
| 664 |
+
assert size_k % 16 == 0
|
| 665 |
+
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
|
| 666 |
+
device=b_q_weight.device,
|
| 667 |
+
dtype=b_q_weight.dtype)
|
| 668 |
+
for e in range(num_experts):
|
| 669 |
+
output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
|
| 670 |
+
size_k, size_n, num_bits)
|
| 671 |
+
return output
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
|
| 675 |
+
size_k: int, size_n: int,
|
| 676 |
+
num_bits: int) -> torch.Tensor:
|
| 677 |
+
num_experts = b_q_weight.shape[0]
|
| 678 |
+
assert size_k % 16 == 0
|
| 679 |
+
output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
|
| 680 |
+
device=b_q_weight.device,
|
| 681 |
+
dtype=b_q_weight.dtype)
|
| 682 |
+
for e in range(num_experts):
|
| 683 |
+
output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
|
| 684 |
+
size_n, num_bits)
|
| 685 |
+
return output
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def gptq_marlin_gemm(a: torch.Tensor,
|
| 689 |
+
b_q_weight: torch.Tensor,
|
| 690 |
+
b_scales: torch.Tensor,
|
| 691 |
+
b_zeros: torch.Tensor,
|
| 692 |
+
g_idx: torch.Tensor,
|
| 693 |
+
perm: torch.Tensor,
|
| 694 |
+
workspace: torch.Tensor,
|
| 695 |
+
b_q_type: ScalarType,
|
| 696 |
+
size_m: int,
|
| 697 |
+
size_n: int,
|
| 698 |
+
size_k: int,
|
| 699 |
+
is_k_full: bool,
|
| 700 |
+
has_zp: bool = False,
|
| 701 |
+
use_fp32_reduce: bool = False,
|
| 702 |
+
is_zp_float: bool = False) -> torch.Tensor:
|
| 703 |
+
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
|
| 704 |
+
g_idx, perm, workspace, b_q_type.id,
|
| 705 |
+
size_m, size_n, size_k, is_k_full,
|
| 706 |
+
has_zp, use_fp32_reduce, is_zp_float)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
# fp8 marlin
|
| 710 |
+
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
| 711 |
+
b_scales: torch.Tensor, workspace: torch.Tensor,
|
| 712 |
+
num_bits: int, size_m: int, size_n: int,
|
| 713 |
+
size_k: int) -> torch.Tensor:
|
| 714 |
+
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
| 715 |
+
num_bits, size_m, size_n, size_k)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# machete
|
| 719 |
+
def machete_supported_schedules(
|
| 720 |
+
a_type: torch.dtype,
|
| 721 |
+
b_type: ScalarType,
|
| 722 |
+
group_scales_type: Optional[torch.dtype],
|
| 723 |
+
group_zeros_type: Optional[torch.dtype] = None,
|
| 724 |
+
channel_scales_type: Optional[torch.dtype] = None,
|
| 725 |
+
token_scales_type: Optional[torch.dtype] = None,
|
| 726 |
+
out_type: Optional[torch.dtype] = None) -> List[str]:
|
| 727 |
+
return torch.ops._C.machete_supported_schedules(
|
| 728 |
+
a_type, b_type.id, group_scales_type, group_zeros_type,
|
| 729 |
+
channel_scales_type, token_scales_type, out_type)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def machete_mm(
|
| 733 |
+
a: torch.Tensor,
|
| 734 |
+
# b_q Should be the tensor returned by machete_prepack_B
|
| 735 |
+
b_q: torch.Tensor,
|
| 736 |
+
b_type: ScalarType,
|
| 737 |
+
out_type: Optional[torch.dtype] = None,
|
| 738 |
+
b_group_scales: Optional[torch.Tensor] = None,
|
| 739 |
+
b_group_zeros: Optional[torch.Tensor] = None,
|
| 740 |
+
b_group_size: Optional[int] = None,
|
| 741 |
+
b_channel_scales: Optional[torch.Tensor] = None,
|
| 742 |
+
a_token_scales: Optional[torch.Tensor] = None,
|
| 743 |
+
schedule: Optional[str] = None) -> torch.Tensor:
|
| 744 |
+
return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
|
| 745 |
+
b_group_zeros, b_group_size,
|
| 746 |
+
b_channel_scales, a_token_scales, schedule)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def machete_prepack_B(
|
| 750 |
+
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
|
| 751 |
+
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
|
| 752 |
+
return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id,
|
| 753 |
+
group_scales_type)
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
if hasattr(torch.ops._C, "permute_cols"):
|
| 757 |
+
|
| 758 |
+
@register_fake("_C::permute_cols")
|
| 759 |
+
def _permute_cols_fake(a: torch.Tensor,
|
| 760 |
+
perm: torch.Tensor) -> torch.Tensor:
|
| 761 |
+
return torch.empty_like(a)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
|
| 765 |
+
return torch.ops._C.permute_cols(a, perm)
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
# fp8
|
| 769 |
+
def scaled_fp8_quant(
|
| 770 |
+
input: torch.Tensor,
|
| 771 |
+
scale: Optional[torch.Tensor] = None,
|
| 772 |
+
num_token_padding: Optional[int] = None,
|
| 773 |
+
scale_ub: Optional[torch.Tensor] = None,
|
| 774 |
+
use_per_token_if_dynamic: bool = False,
|
| 775 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 776 |
+
"""
|
| 777 |
+
Quantize input tensor to FP8 and return quantized tensor and scale.
|
| 778 |
+
|
| 779 |
+
This function supports both static and dynamic quantization: If you
|
| 780 |
+
provide the scale, it will use static scaling and if you omit it,
|
| 781 |
+
the scale will be determined dynamically. The function also allows
|
| 782 |
+
optional padding of the output tensors for downstream kernels that
|
| 783 |
+
will benefit from padding.
|
| 784 |
+
|
| 785 |
+
Args:
|
| 786 |
+
input: The input tensor to be quantized to FP8
|
| 787 |
+
scale: Optional scaling factor for the FP8 quantization
|
| 788 |
+
scale_ub: Optional upper bound for scaling factor in dynamic
|
| 789 |
+
per token case
|
| 790 |
+
num_token_padding: If specified, pad the first dimension
|
| 791 |
+
of the output to at least this value.
|
| 792 |
+
use_per_token_if_dynamic: Whether to do per_tensor or per_token
|
| 793 |
+
in the dynamic quantization case.
|
| 794 |
+
|
| 795 |
+
Returns:
|
| 796 |
+
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
|
| 797 |
+
scaling factor.
|
| 798 |
+
"""
|
| 799 |
+
# This code assumes batch_dim and num_tokens are flattened
|
| 800 |
+
assert (input.ndim == 2)
|
| 801 |
+
shape: Union[Tuple[int, int], torch.Size] = input.shape
|
| 802 |
+
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
|
| 803 |
+
out_dtype: torch.dtype = torch.float8_e4m3fnuz \
|
| 804 |
+
if current_platform.is_rocm() else torch.float8_e4m3fn
|
| 805 |
+
if num_token_padding:
|
| 806 |
+
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
| 807 |
+
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
| 808 |
+
|
| 809 |
+
if scale is None:
|
| 810 |
+
if use_per_token_if_dynamic:
|
| 811 |
+
scale = torch.empty((shape[0], 1),
|
| 812 |
+
device=input.device,
|
| 813 |
+
dtype=torch.float32)
|
| 814 |
+
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
| 815 |
+
output, input, scale, scale_ub)
|
| 816 |
+
else:
|
| 817 |
+
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
| 818 |
+
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
| 819 |
+
else:
|
| 820 |
+
# num_token_padding not implemented for this case
|
| 821 |
+
assert (scale.numel() == 1 or num_token_padding is None)
|
| 822 |
+
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
| 823 |
+
|
| 824 |
+
return output, scale
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
# int8
|
| 828 |
+
def scaled_int8_quant(
|
| 829 |
+
input: torch.Tensor,
|
| 830 |
+
scale: Optional[torch.Tensor] = None,
|
| 831 |
+
azp: Optional[torch.Tensor] = None,
|
| 832 |
+
symmetric: bool = True
|
| 833 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
| 834 |
+
"""
|
| 835 |
+
Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.
|
| 836 |
+
|
| 837 |
+
Args:
|
| 838 |
+
input: The input tensor to be quantized to int8.
|
| 839 |
+
scale: Optional scaling factor for the int8 quantization.
|
| 840 |
+
When not provided, we invoke dynamic-per-token quantization.
|
| 841 |
+
azp: Optional zero-point for the int8 quantization.
|
| 842 |
+
Must be provided for asymmetric quantization if `scale` is provided.
|
| 843 |
+
symmetric: Whether to use symmetric quantization (scale only, azp ignored).
|
| 844 |
+
|
| 845 |
+
Returns:
|
| 846 |
+
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
|
| 847 |
+
"""
|
| 848 |
+
output = torch.empty_like(input, dtype=torch.int8)
|
| 849 |
+
if scale is not None:
|
| 850 |
+
# static-per-tensor quantization.
|
| 851 |
+
assert symmetric == (
|
| 852 |
+
azp
|
| 853 |
+
is None), "azp must only be provided for asymmetric quantization."
|
| 854 |
+
torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
|
| 855 |
+
return output, scale, azp
|
| 856 |
+
|
| 857 |
+
# dynamic-per-token quantization.
|
| 858 |
+
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
|
| 859 |
+
device=input.device,
|
| 860 |
+
dtype=torch.float32)
|
| 861 |
+
input_azp = None if symmetric else torch.empty_like(input_scales,
|
| 862 |
+
dtype=torch.int32)
|
| 863 |
+
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
|
| 864 |
+
input_azp)
|
| 865 |
+
return output, input_scales, input_azp
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
# qqq ops
|
| 869 |
+
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
| 870 |
+
s_tok: torch.Tensor, s_ch: torch.Tensor,
|
| 871 |
+
s_group: torch.Tensor, workspace: torch.Tensor,
|
| 872 |
+
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
|
| 873 |
+
return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
|
| 874 |
+
workspace, size_m, size_n, size_k)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
# gguf
|
| 878 |
+
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
|
| 879 |
+
n: int) -> torch.Tensor:
|
| 880 |
+
return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
def ggml_mul_mat_vec_a8(
|
| 884 |
+
W: torch.Tensor,
|
| 885 |
+
X: torch.Tensor,
|
| 886 |
+
quant_type: int,
|
| 887 |
+
row: int,
|
| 888 |
+
) -> torch.Tensor:
|
| 889 |
+
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def ggml_mul_mat_a8(
|
| 893 |
+
W: torch.Tensor,
|
| 894 |
+
X: torch.Tensor,
|
| 895 |
+
quant_type: int,
|
| 896 |
+
row: int,
|
| 897 |
+
) -> torch.Tensor:
|
| 898 |
+
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
# mamba
|
| 902 |
+
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
|
| 903 |
+
bias_: Optional[torch.Tensor],
|
| 904 |
+
conv_states: Optional[torch.Tensor],
|
| 905 |
+
query_start_loc: Optional[torch.Tensor],
|
| 906 |
+
cache_indices: Optional[torch.Tensor],
|
| 907 |
+
has_initial_state: Optional[torch.Tensor],
|
| 908 |
+
silu_activation: bool, pad_slot_id: int):
|
| 909 |
+
torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
|
| 910 |
+
query_start_loc, cache_indices,
|
| 911 |
+
has_initial_state, silu_activation,
|
| 912 |
+
pad_slot_id)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
|
| 916 |
+
weight: torch.Tensor, bias_: Optional[torch.Tensor],
|
| 917 |
+
silu_activation: bool,
|
| 918 |
+
cache_seqlens: Optional[torch.Tensor],
|
| 919 |
+
conv_state_indices: Optional[torch.Tensor],
|
| 920 |
+
pad_slot_id: int):
|
| 921 |
+
torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
|
| 922 |
+
silu_activation, cache_seqlens,
|
| 923 |
+
conv_state_indices, pad_slot_id)
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
|
| 927 |
+
B: torch.Tensor, C: torch.Tensor,
|
| 928 |
+
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
|
| 929 |
+
delta_bias_: Optional[torch.Tensor],
|
| 930 |
+
delta_softplus: bool,
|
| 931 |
+
query_start_loc: Optional[torch.Tensor],
|
| 932 |
+
cache_indices: Optional[torch.Tensor],
|
| 933 |
+
has_initial_state: Optional[torch.Tensor],
|
| 934 |
+
ssm_states: torch.Tensor, pad_slot_id: int):
|
| 935 |
+
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
|
| 936 |
+
delta_softplus, query_start_loc,
|
| 937 |
+
cache_indices, has_initial_state,
|
| 938 |
+
ssm_states, pad_slot_id)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
# moe
|
| 942 |
+
def moe_sum(input: torch.Tensor, output: torch.Tensor):
|
| 943 |
+
torch.ops._moe_C.moe_sum(input, output)
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
| 947 |
+
block_size: int, sorted_token_ids: torch.Tensor,
|
| 948 |
+
experts_ids: torch.Tensor,
|
| 949 |
+
num_tokens_post_pad: torch.Tensor) -> None:
|
| 950 |
+
torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
|
| 951 |
+
sorted_token_ids, experts_ids,
|
| 952 |
+
num_tokens_post_pad)
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
def sgl_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
| 956 |
+
block_size: int, sorted_token_ids: torch.Tensor,
|
| 957 |
+
experts_ids: torch.Tensor,
|
| 958 |
+
num_tokens_post_pad: torch.Tensor) -> None:
|
| 959 |
+
torch.ops._moe_C.sgl_moe_align_block_size(topk_ids, num_experts,
|
| 960 |
+
block_size, sorted_token_ids,
|
| 961 |
+
experts_ids, num_tokens_post_pad)
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
| 965 |
+
token_expert_indicies: torch.Tensor,
|
| 966 |
+
gating_output: float) -> None:
|
| 967 |
+
torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
|
| 968 |
+
token_expert_indicies, gating_output)
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
|
| 972 |
+
|
| 973 |
+
@register_fake("_moe_C::marlin_gemm_moe")
|
| 974 |
+
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
|
| 975 |
+
sorted_ids: torch.Tensor,
|
| 976 |
+
topk_weights: torch.Tensor,
|
| 977 |
+
topk_ids: torch.Tensor, b_scales: torch.Tensor,
|
| 978 |
+
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
|
| 979 |
+
perm: torch.Tensor, workspace: torch.Tensor,
|
| 980 |
+
b_q_type: ScalarType, size_m: torch.SymInt,
|
| 981 |
+
size_n: torch.SymInt, size_k: torch.SymInt,
|
| 982 |
+
is_k_full: bool, num_experts: int, topk: int,
|
| 983 |
+
moe_block_size: int, replicate_input: bool,
|
| 984 |
+
apply_weights: bool) -> torch.Tensor:
|
| 985 |
+
return torch.empty((size_m, topk, size_n),
|
| 986 |
+
dtype=a.dtype,
|
| 987 |
+
device=a.device)
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
def reshape_and_cache(
|
| 991 |
+
key: torch.Tensor,
|
| 992 |
+
value: torch.Tensor,
|
| 993 |
+
key_cache: torch.Tensor,
|
| 994 |
+
value_cache: torch.Tensor,
|
| 995 |
+
slot_mapping: torch.Tensor,
|
| 996 |
+
kv_cache_dtype: str,
|
| 997 |
+
k_scale: torch.Tensor,
|
| 998 |
+
v_scale: torch.Tensor,
|
| 999 |
+
) -> None:
|
| 1000 |
+
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
|
| 1001 |
+
value_cache, slot_mapping,
|
| 1002 |
+
kv_cache_dtype, k_scale, v_scale)
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def reshape_and_cache_flash(
|
| 1006 |
+
key: torch.Tensor,
|
| 1007 |
+
value: torch.Tensor,
|
| 1008 |
+
key_cache: torch.Tensor,
|
| 1009 |
+
value_cache: torch.Tensor,
|
| 1010 |
+
slot_mapping: torch.Tensor,
|
| 1011 |
+
kv_cache_dtype: str,
|
| 1012 |
+
k_scale: torch.Tensor,
|
| 1013 |
+
v_scale: torch.Tensor,
|
| 1014 |
+
) -> None:
|
| 1015 |
+
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
|
| 1016 |
+
value_cache, slot_mapping,
|
| 1017 |
+
kv_cache_dtype, k_scale,
|
| 1018 |
+
v_scale)
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def concat_and_cache_mla(
|
| 1022 |
+
kv_c: torch.Tensor,
|
| 1023 |
+
k_pe: torch.Tensor,
|
| 1024 |
+
kv_cache: torch.Tensor,
|
| 1025 |
+
slot_mapping: torch.Tensor,
|
| 1026 |
+
kv_cache_dtype: str,
|
| 1027 |
+
scale: torch.Tensor,
|
| 1028 |
+
) -> None:
|
| 1029 |
+
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
|
| 1030 |
+
slot_mapping, kv_cache_dtype,
|
| 1031 |
+
scale)
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
def copy_blocks(key_caches: List[torch.Tensor],
|
| 1035 |
+
value_caches: List[torch.Tensor],
|
| 1036 |
+
block_mapping: torch.Tensor) -> None:
|
| 1037 |
+
torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
def copy_blocks_mla(kv_caches: List[torch.Tensor],
|
| 1041 |
+
block_mapping: torch.Tensor) -> None:
|
| 1042 |
+
torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
|
| 1046 |
+
block_mapping: torch.Tensor) -> None:
|
| 1047 |
+
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
def convert_fp8(output: torch.Tensor,
|
| 1051 |
+
input: torch.Tensor,
|
| 1052 |
+
scale: float = 1.0,
|
| 1053 |
+
kv_dtype: str = "fp8") -> None:
|
| 1054 |
+
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
+
def get_device_attribute(attribute: int, device: int) -> int:
|
| 1058 |
+
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
|
| 1062 |
+
# ruff: noqa: E501
|
| 1063 |
+
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
|
| 1064 |
+
device)
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
# custom ar
|
| 1068 |
+
def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
|
| 1069 |
+
rank: int, full_nvlink: bool) -> int:
|
| 1070 |
+
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
|
| 1071 |
+
full_nvlink)
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
|
| 1075 |
+
reg_buffer_sz_bytes: int) -> None:
|
| 1076 |
+
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
|
| 1077 |
+
reg_buffer_sz_bytes)
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
def dispose(fa: int) -> None:
|
| 1081 |
+
torch.ops._C_custom_ar.dispose(fa)
|
| 1082 |
+
|
| 1083 |
+
|
| 1084 |
+
def meta_size() -> int:
|
| 1085 |
+
return torch.ops._C_custom_ar.meta_size()
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
|
| 1089 |
+
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
|
| 1093 |
+
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
def register_graph_buffers(fa: int, handles: List[List[int]],
|
| 1097 |
+
offsets: List[List[int]]) -> None:
|
| 1098 |
+
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
.venv/lib/python3.11/site-packages/vllm/_ipex_ops.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from vllm.logger import init_logger
|
| 8 |
+
|
| 9 |
+
logger = init_logger(__name__)
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import intel_extension_for_pytorch as ipex
|
| 13 |
+
except ImportError as e:
|
| 14 |
+
logger.warning("Import error msg: %s", e.msg)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ipex_ops:
|
| 18 |
+
|
| 19 |
+
@staticmethod
|
| 20 |
+
def _reshape_activation_tensor(
|
| 21 |
+
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 22 |
+
num = x.size(0)
|
| 23 |
+
d = x.size(1) // 2
|
| 24 |
+
x = x.reshape(num, 2, d)
|
| 25 |
+
x1, x2 = torch.chunk(x, chunks=2, dim=1)
|
| 26 |
+
x1 = x1.reshape(num, d)
|
| 27 |
+
x2 = x2.reshape(num, d)
|
| 28 |
+
return x1, x2
|
| 29 |
+
|
| 30 |
+
@staticmethod
|
| 31 |
+
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 32 |
+
ipex.llm.functional.silu_and_mul(x, out)
|
| 33 |
+
|
| 34 |
+
@staticmethod
|
| 35 |
+
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 36 |
+
ipex.llm.functional.gelu_and_mul(x, out)
|
| 37 |
+
|
| 38 |
+
@staticmethod
|
| 39 |
+
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 40 |
+
ipex.llm.functional.gelu_and_mul(x, out)
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
|
| 44 |
+
return torch.nn.functional.gelu(x)
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def gelu_new(x: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
return torch.nn.functional.gelu(x)
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
|
| 52 |
+
ipex.llm.functional.gelu_quick(x, out)
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def paged_attention_v1(
|
| 56 |
+
out: torch.Tensor,
|
| 57 |
+
query: torch.Tensor,
|
| 58 |
+
key_cache: torch.Tensor,
|
| 59 |
+
value_cache: torch.Tensor,
|
| 60 |
+
num_kv_heads: int,
|
| 61 |
+
scale: float,
|
| 62 |
+
block_tables: torch.Tensor,
|
| 63 |
+
context_lens: torch.Tensor,
|
| 64 |
+
block_size: int,
|
| 65 |
+
max_context_len: int,
|
| 66 |
+
alibi_slopes: Optional[torch.Tensor],
|
| 67 |
+
kv_cache_dtype: str,
|
| 68 |
+
k_scale: float,
|
| 69 |
+
v_scale: float,
|
| 70 |
+
tp_rank: int = 0,
|
| 71 |
+
blocksparse_local_blocks: int = 0,
|
| 72 |
+
blocksparse_vert_stride: int = 0,
|
| 73 |
+
blocksparse_block_size: int = 64,
|
| 74 |
+
blocksparse_head_sliding_step: int = 0,
|
| 75 |
+
) -> None:
|
| 76 |
+
assert kv_cache_dtype == "auto"
|
| 77 |
+
num_heads = out.size(1)
|
| 78 |
+
num_queries_per_tokens = num_heads // num_kv_heads
|
| 79 |
+
ipex.llm.modules.PagedAttention.single_query_kv_attention(
|
| 80 |
+
out,
|
| 81 |
+
query.contiguous(),
|
| 82 |
+
key_cache.view_as(value_cache),
|
| 83 |
+
value_cache,
|
| 84 |
+
num_queries_per_tokens,
|
| 85 |
+
scale,
|
| 86 |
+
block_tables,
|
| 87 |
+
context_lens,
|
| 88 |
+
block_size,
|
| 89 |
+
max_context_len,
|
| 90 |
+
alibi_slopes,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
@staticmethod
|
| 94 |
+
def paged_attention_v2(
|
| 95 |
+
out: torch.Tensor,
|
| 96 |
+
exp_sum: torch.Tensor,
|
| 97 |
+
max_logits: torch.Tensor,
|
| 98 |
+
tmp_out: torch.Tensor,
|
| 99 |
+
query: torch.Tensor,
|
| 100 |
+
key_cache: torch.Tensor,
|
| 101 |
+
value_cache: torch.Tensor,
|
| 102 |
+
num_kv_heads: int,
|
| 103 |
+
scale: float,
|
| 104 |
+
block_tables: torch.Tensor,
|
| 105 |
+
context_lens: torch.Tensor,
|
| 106 |
+
block_size: int,
|
| 107 |
+
max_context_len: int,
|
| 108 |
+
alibi_slopes: Optional[torch.Tensor],
|
| 109 |
+
kv_cache_dtype: str,
|
| 110 |
+
k_scale: float,
|
| 111 |
+
v_scale: float,
|
| 112 |
+
tp_rank: int = 0,
|
| 113 |
+
blocksparse_local_blocks: int = 0,
|
| 114 |
+
blocksparse_vert_stride: int = 0,
|
| 115 |
+
blocksparse_block_size: int = 64,
|
| 116 |
+
blocksparse_head_sliding_step: int = 0,
|
| 117 |
+
) -> None:
|
| 118 |
+
assert kv_cache_dtype == "auto"
|
| 119 |
+
num_heads = out.size(1)
|
| 120 |
+
num_queries_per_tokens = num_heads // num_kv_heads
|
| 121 |
+
ipex.llm.modules.PagedAttention.single_query_kv_attention(
|
| 122 |
+
out,
|
| 123 |
+
query.contiguous(),
|
| 124 |
+
key_cache.view_as(value_cache),
|
| 125 |
+
value_cache,
|
| 126 |
+
num_queries_per_tokens,
|
| 127 |
+
scale,
|
| 128 |
+
block_tables,
|
| 129 |
+
context_lens,
|
| 130 |
+
block_size,
|
| 131 |
+
max_context_len,
|
| 132 |
+
alibi_slopes,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def rotary_embedding(
|
| 137 |
+
positions: torch.Tensor, # [batch_size, seq_len]
|
| 138 |
+
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
|
| 139 |
+
key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size]
|
| 140 |
+
head_size: int,
|
| 141 |
+
cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim]
|
| 142 |
+
is_neox: bool,
|
| 143 |
+
) -> None:
|
| 144 |
+
rot_dim = cos_sin_cache.size(1)
|
| 145 |
+
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
|
| 146 |
+
head_size, cos_sin_cache,
|
| 147 |
+
is_neox, rot_dim)
|
| 148 |
+
|
| 149 |
+
@staticmethod
|
| 150 |
+
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
| 151 |
+
key: torch.Tensor, head_size: int,
|
| 152 |
+
cos_sin_cache: torch.Tensor, is_neox: bool,
|
| 153 |
+
rot_dim: int,
|
| 154 |
+
cos_sin_cache_offsets: torch.Tensor) -> None:
|
| 155 |
+
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
|
| 156 |
+
head_size, cos_sin_cache,
|
| 157 |
+
is_neox, rot_dim,
|
| 158 |
+
cos_sin_cache_offsets)
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
|
| 162 |
+
epsilon: float) -> torch.Tensor:
|
| 163 |
+
return ipex.llm.functional.rms_norm(input, weight, epsilon)
|
| 164 |
+
|
| 165 |
+
@staticmethod
|
| 166 |
+
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
|
| 167 |
+
weight: torch.Tensor, epsilon: float) -> None:
|
| 168 |
+
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
|
| 169 |
+
epsilon, True)
|
| 170 |
+
input.copy_(tmp)
|
| 171 |
+
|
| 172 |
+
@staticmethod
|
| 173 |
+
def varlen_attention(
|
| 174 |
+
query: torch.Tensor,
|
| 175 |
+
key: torch.Tensor,
|
| 176 |
+
value: torch.Tensor,
|
| 177 |
+
out: torch.Tensor,
|
| 178 |
+
seqlen_q: torch.Tensor,
|
| 179 |
+
seqlen_k: torch.Tensor,
|
| 180 |
+
max_seqlen_q: int,
|
| 181 |
+
max_seqlen_k: int,
|
| 182 |
+
pdropout: float,
|
| 183 |
+
softmax_scale: float,
|
| 184 |
+
zero_tensors: bool,
|
| 185 |
+
is_causal: bool,
|
| 186 |
+
return_softmax: bool,
|
| 187 |
+
gen_: torch.Generator,
|
| 188 |
+
logits_soft_cap: float,
|
| 189 |
+
) -> None:
|
| 190 |
+
ipex.llm.functional.varlen_attention(query.contiguous(),
|
| 191 |
+
key.contiguous(),
|
| 192 |
+
value.contiguous(), out,
|
| 193 |
+
seqlen_q.int(), seqlen_k.int(),
|
| 194 |
+
max_seqlen_q, max_seqlen_k,
|
| 195 |
+
pdropout, softmax_scale,
|
| 196 |
+
zero_tensors, is_causal,
|
| 197 |
+
return_softmax, gen_,
|
| 198 |
+
logits_soft_cap)
|
| 199 |
+
|
| 200 |
+
@staticmethod
|
| 201 |
+
def reshape_and_cache(
|
| 202 |
+
key: torch.Tensor,
|
| 203 |
+
value: torch.Tensor,
|
| 204 |
+
key_cache: torch.Tensor,
|
| 205 |
+
value_cache: torch.Tensor,
|
| 206 |
+
slot_mapping: torch.Tensor,
|
| 207 |
+
kv_cache_dtype: str,
|
| 208 |
+
k_scale: float,
|
| 209 |
+
v_scale: float,
|
| 210 |
+
) -> None:
|
| 211 |
+
assert kv_cache_dtype == "auto"
|
| 212 |
+
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
| 213 |
+
key, value, key_cache, value_cache, slot_mapping)
|
| 214 |
+
|
| 215 |
+
@staticmethod
|
| 216 |
+
def copy_blocks(key_caches: List[torch.Tensor],
|
| 217 |
+
value_caches: List[torch.Tensor],
|
| 218 |
+
block_mapping: torch.Tensor) -> None:
|
| 219 |
+
torch.xpu.copy_blocks( # type: ignore
|
| 220 |
+
key_caches,
|
| 221 |
+
value_caches,
|
| 222 |
+
block_mapping,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
@staticmethod
|
| 226 |
+
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
|
| 227 |
+
block_mapping: torch.Tensor) -> None:
|
| 228 |
+
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
|
.venv/lib/python3.11/site-packages/vllm/_version.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# file generated by setuptools_scm
|
| 2 |
+
# don't change, don't track in version control
|
| 3 |
+
TYPE_CHECKING = False
if not TYPE_CHECKING:
    # At runtime the alias is only a placeholder; no typing import is needed.
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Both spellings of each name are bound to the same object.
version = __version__ = '0.7.2'
version_tuple = __version_tuple__ = (0, 7, 2)
|
.venv/lib/python3.11/site-packages/vllm/beam_search.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from vllm.sequence import Logprob
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from vllm.multimodal import MultiModalDataDict
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked during beam search.

    Stores the token ids and per-step log probabilities together with the
    cumulative log probability of the sequence. ``text`` is optional and is
    only filled in when the sequence is about to be returned to the user.
    """
    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    logprobs: List[Dict[int, Logprob]]
    cum_logprob: float = 0.0
    text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class BeamSearchOutput:
    """Result of a beam search.

    Holds the best beam search sequences; the list has as many entries as
    the beam width.
    """
    sequences: List[BeamSearchSequence]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BeamSearchInstance:
    """Beam search state for one prompt: live beams and completed sequences."""

    def __init__(self, prompt_tokens: List[int]):
        # The search starts from a single beam that holds the prompt tokens
        # (the same list object that was passed in) and no log probabilities.
        initial_beam = BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
        self.beams: List[BeamSearchSequence] = [initial_beam]
        self.completed: List[BeamSearchSequence] = []
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_beam_search_score(
    tokens: List[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Return the length-penalised beam search score of a sequence.

    The cumulative log probability is divided by the sequence length raised
    to ``length_penalty``. A trailing EOS token is not counted towards the
    length.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
    """
    ends_with_eos = tokens[-1] == eos_token_id
    effective_len = len(tokens) - 1 if ends_with_eos else len(tokens)
    normaliser = effective_len**length_penalty
    return cumulative_logprob / normaliser
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a sort key that ranks beams by their beam search score.

    The returned callable takes a ``BeamSearchSequence`` and returns
    ``get_beam_search_score`` evaluated on its tokens and cumulative log
    probability with the given ``eos_token_id`` and ``length_penalty``.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        return get_beam_search_score(
            tokens=x.tokens,
            cumulative_logprob=x.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )

    return sort_beams_key
|
.venv/lib/python3.11/site-packages/vllm/config.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
.venv/lib/python3.11/site-packages/vllm/connections.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Mapping, MutableMapping, Optional
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
|
| 7 |
+
import aiohttp
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
from vllm.version import __version__ as VLLM_VERSION
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class HTTPConnection:
    """Helper class to send HTTP requests.

    Wraps a synchronous ``requests.Session`` and an asynchronous
    ``aiohttp.ClientSession``. Both are created lazily. When ``reuse_client``
    is true the same session is returned on every call; otherwise a new
    session is created each time one is requested.
    """

    def __init__(self, *, reuse_client: bool = True) -> None:
        super().__init__()

        self.reuse_client = reuse_client

        self._sync_client: Optional[requests.Session] = None
        self._async_client: Optional[aiohttp.ClientSession] = None

    def get_sync_client(self) -> requests.Session:
        """Return the synchronous session, creating it when required."""
        # NOTE(review): with reuse_client=False the previous session is
        # replaced without being closed -- confirm this is intended.
        if self._sync_client is None or not self.reuse_client:
            self._sync_client = requests.Session()

        return self._sync_client

    # NOTE: We intentionally use an async function even though it is not
    # required, so that the client is only accessible inside async event loop
    async def get_async_client(self) -> aiohttp.ClientSession:
        """Return the asynchronous session, creating it when required."""
        if self._async_client is None or not self.reuse_client:
            self._async_client = aiohttp.ClientSession(trust_env=True)

        return self._async_client

    def _validate_http_url(self, url: str):
        """Raise ``ValueError`` unless ``url`` has scheme http or https."""
        parsed_url = urlparse(url)

        if parsed_url.scheme not in ("http", "https"):
            raise ValueError("Invalid HTTP URL: A valid HTTP URL "
                             "must have scheme 'http' or 'https'.")

    def _headers(self, **extras: str) -> MutableMapping[str, str]:
        """Return the request headers: the vLLM User-Agent plus ``extras``.

        Entries in ``extras`` take precedence over the default User-Agent.
        """
        return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

    def get_response(
        self,
        url: str,
        *,
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Send a synchronous GET request and return the response.

        Args:
            url: Target URL; must use the http or https scheme.
            stream: Forwarded to ``requests`` as the ``stream`` argument.
            timeout: Forwarded to ``requests`` as the ``timeout`` argument.
            extra_headers: Additional headers merged over the defaults.

        Raises:
            ValueError: If ``url`` does not use http or https.
        """
        self._validate_http_url(url)

        client = self.get_sync_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)

    async def get_async_response(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
    ):
        """Start an asynchronous GET request.

        The value returned by ``client.get`` is handed back without being
        awaited; the methods of this class use it with ``async with``.

        Raises:
            ValueError: If ``url`` does not use http or https.
        """
        self._validate_http_url(url)

        client = await self.get_async_client()
        extra_headers = extra_headers or {}

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=timeout)

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        """Fetch ``url`` and return the response body as bytes."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.content

    async def async_get_bytes(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> bytes:
        """Asynchronously fetch ``url`` and return the body as bytes."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.read()

    def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
        """Fetch ``url`` and return the response body as text."""
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.text

    async def async_get_text(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> str:
        """Asynchronously fetch ``url`` and return the body as text."""
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.text()

    def get_json(self, url: str, *, timeout: Optional[float] = None) -> Any:
        """Fetch ``url`` and return the decoded JSON body.

        The result is whatever ``Response.json()`` produces, not a string.
        """
        with self.get_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return r.json()

    async def async_get_json(
        self,
        url: str,
        *,
        timeout: Optional[float] = None,
    ) -> Any:
        """Asynchronously fetch ``url`` and return the decoded JSON body.

        The result is whatever ``r.json()`` produces, not a string.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            return await r.json()

    def download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Download ``url`` to ``save_path`` and return ``save_path``.

        The body is written in chunks of ``chunk_size`` bytes.
        """
        # stream=True defers the body download so that iter_content really
        # reads it chunk by chunk; without it requests loads the whole body
        # into memory before the loop below starts.
        with self.get_response(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                for chunk in r.iter_content(chunk_size):
                    f.write(chunk)

        return save_path

    async def async_download_file(
        self,
        url: str,
        save_path: Path,
        *,
        timeout: Optional[float] = None,
        chunk_size: int = 128,
    ) -> Path:
        """Asynchronously download ``url`` to ``save_path`` and return it.

        The body is read in chunks of ``chunk_size`` bytes; the file writes
        themselves are ordinary blocking writes.
        """
        async with await self.get_async_response(url, timeout=timeout) as r:
            r.raise_for_status()

            with save_path.open("wb") as f:
                async for chunk in r.content.iter_chunked(chunk_size):
                    f.write(chunk)

        return save_path
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# Module-level instance, created at import time with the default settings.
global_http_connection = HTTPConnection()
"""The global :class:`HTTPConnection` instance used by vLLM."""
|
.venv/lib/python3.11/site-packages/vllm/cumem_allocator.abi3.so
ADDED
|
Binary file (27.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/envs.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
# Declarations for static type checkers only; this branch does not run at
# runtime. The values mirror the runtime defaults produced by the lambdas in
# `environment_variables` below, so keep both places in sync.
if TYPE_CHECKING:
    VLLM_HOST_IP: str = ""
    VLLM_PORT: Optional[int] = None
    VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
    VLLM_USE_MODELSCOPE: bool = False
    VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
    CUDA_HOME: Optional[str] = None
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
    # Runtime default is "True".
    VLLM_USE_TRITON_FLASH_ATTN: bool = True
    VLLM_FLASH_ATTN_VERSION: Optional[int] = None
    VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE: bool = True
    LOCAL_RANK: int = 0
    CUDA_VISIBLE_DEVICES: Optional[str] = None
    VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60
    VLLM_API_KEY: Optional[str] = None
    S3_ACCESS_KEY_ID: Optional[str] = None
    S3_SECRET_ACCESS_KEY: Optional[str] = None
    S3_ENDPOINT_URL: Optional[str] = None
    VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm")
    VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
    VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
    VLLM_NO_USAGE_STATS: bool = False
    VLLM_DO_NOT_TRACK: bool = False
    # Runtime default is "production".
    VLLM_USAGE_SOURCE: str = "production"
    VLLM_CONFIGURE_LOGGING: int = 1
    VLLM_LOGGING_LEVEL: str = "INFO"
    VLLM_LOGGING_PREFIX: str = ""
    VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
    VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None
    VLLM_CPU_KVCACHE_SPACE: int = 0
    # Runtime default is "all".
    VLLM_CPU_OMP_THREADS_BIND: str = "all"
    VLLM_OPENVINO_DEVICE: str = "CPU"
    VLLM_OPENVINO_KVCACHE_SPACE: int = 0
    VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
    VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
    VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
    VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
    VLLM_USE_RAY_SPMD_WORKER: bool = False
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
    VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
    VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
    VLLM_IMAGE_FETCH_TIMEOUT: int = 5
    # Runtime default is 15 seconds.
    VLLM_VIDEO_FETCH_TIMEOUT: int = 15
    VLLM_AUDIO_FETCH_TIMEOUT: int = 10
    VLLM_TARGET_DEVICE: str = "cuda"
    MAX_JOBS: Optional[str] = None
    NVCC_THREADS: Optional[str] = None
    VLLM_USE_PRECOMPILED: bool = False
    VLLM_NO_DEPRECATION_WARNING: bool = False
    VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
    CMAKE_BUILD_TYPE: Optional[str] = None
    VERBOSE: bool = False
    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
    VLLM_TEST_FORCE_FP8_MARLIN: bool = False
    VLLM_RPC_TIMEOUT: int = 10000  # ms
    VLLM_PLUGINS: Optional[List[str]] = None
    VLLM_TORCH_PROFILER_DIR: Optional[str] = None
    VLLM_USE_TRITON_AWQ: bool = False
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_DISABLED_KERNELS: List[str] = []
    VLLM_USE_V1: bool = False
    VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
    VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
    VLLM_DISABLE_COMPILE_CACHE: bool = False
    K_SCALE_CONSTANT: int = 200
    V_SCALE_CONSTANT: int = 100
    VLLM_SERVER_DEV_MODE: bool = False
    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
    VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
    VLLM_RAY_PER_WORKER_GPUS: float = 1.0
    VLLM_RAY_BUNDLE_INDICES: str = ""
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_default_cache_root():
    """Return the base directory for user cache files.

    Uses ``$XDG_CACHE_HOME`` when it is set to a non-empty value and falls
    back to ``~/.cache`` otherwise. An empty value is treated like an unset
    one, as the XDG Base Directory Specification requires; returning the
    empty string would make paths derived from it relative to the current
    working directory.
    """
    return os.getenv("XDG_CACHE_HOME") or os.path.join(
        os.path.expanduser("~"), ".cache")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_default_config_root():
    """Return the base directory for user configuration files.

    Uses ``$XDG_CONFIG_HOME`` when it is set to a non-empty value and falls
    back to ``~/.config`` otherwise. An empty value is treated like an unset
    one, as the XDG Base Directory Specification requires; returning the
    empty string would make paths derived from it relative to the current
    working directory.
    """
    return os.getenv("XDG_CONFIG_HOME") or os.path.join(
        os.path.expanduser("~"), ".config")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def maybe_convert_int(value: Optional[str]) -> Optional[int]:
    """Convert ``value`` with ``int()``, passing ``None`` through unchanged."""
    return None if value is None else int(value)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# The begin-* and end* here are used by the documentation generator
|
| 113 |
+
# to extract the used env vars.
|
| 114 |
+
|
| 115 |
+
# begin-env-vars-definition
|
| 116 |
+
|
| 117 |
+
environment_variables: Dict[str, Callable[[], Any]] = {
|
| 118 |
+
|
| 119 |
+
# ================== Installation Time Env Vars ==================
|
| 120 |
+
|
| 121 |
+
# Target device of vLLM, supporting [cuda (by default),
|
| 122 |
+
# rocm, neuron, cpu, openvino]
|
| 123 |
+
"VLLM_TARGET_DEVICE":
|
| 124 |
+
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"),
|
| 125 |
+
|
| 126 |
+
# Maximum number of compilation jobs to run in parallel.
|
| 127 |
+
# By default this is the number of CPUs
|
| 128 |
+
"MAX_JOBS":
|
| 129 |
+
lambda: os.getenv("MAX_JOBS", None),
|
| 130 |
+
|
| 131 |
+
# Number of threads to use for nvcc
|
| 132 |
+
# By default this is 1.
|
| 133 |
+
# If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
|
| 134 |
+
"NVCC_THREADS":
|
| 135 |
+
lambda: os.getenv("NVCC_THREADS", None),
|
| 136 |
+
|
| 137 |
+
# If set, vllm will use precompiled binaries (*.so)
|
| 138 |
+
"VLLM_USE_PRECOMPILED":
|
| 139 |
+
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool(
|
| 140 |
+
os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")),
|
| 141 |
+
|
| 142 |
+
# CMake build type
|
| 143 |
+
# If not set, defaults to "Debug" or "RelWithDebInfo"
|
| 144 |
+
# Available options: "Debug", "Release", "RelWithDebInfo"
|
| 145 |
+
"CMAKE_BUILD_TYPE":
|
| 146 |
+
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
| 147 |
+
|
| 148 |
+
# If set, vllm will print verbose logs during installation
|
| 149 |
+
"VERBOSE":
|
| 150 |
+
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
| 151 |
+
|
| 152 |
+
# Root directory for VLLM configuration files
|
| 153 |
+
# Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set
|
| 154 |
+
# Note that this not only affects how vllm finds its configuration files
|
| 155 |
+
# during runtime, but also affects how vllm installs its configuration
|
| 156 |
+
# files during **installation**.
|
| 157 |
+
"VLLM_CONFIG_ROOT":
|
| 158 |
+
lambda: os.path.expanduser(
|
| 159 |
+
os.getenv(
|
| 160 |
+
"VLLM_CONFIG_ROOT",
|
| 161 |
+
os.path.join(get_default_config_root(), "vllm"),
|
| 162 |
+
)),
|
| 163 |
+
|
| 164 |
+
# ================== Runtime Env Vars ==================
|
| 165 |
+
|
| 166 |
+
# Root directory for VLLM cache files
|
| 167 |
+
# Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set
|
| 168 |
+
"VLLM_CACHE_ROOT":
|
| 169 |
+
lambda: os.path.expanduser(
|
| 170 |
+
os.getenv(
|
| 171 |
+
"VLLM_CACHE_ROOT",
|
| 172 |
+
os.path.join(get_default_cache_root(), "vllm"),
|
| 173 |
+
)),
|
| 174 |
+
|
| 175 |
+
# used in distributed environment to determine the ip address
|
| 176 |
+
# of the current node, when the node has multiple network interfaces.
|
| 177 |
+
# If you are using multi-node inference, you should set this differently
|
| 178 |
+
# on each node.
|
| 179 |
+
'VLLM_HOST_IP':
|
| 180 |
+
lambda: os.getenv('VLLM_HOST_IP', ""),
|
| 181 |
+
|
| 182 |
+
# used in distributed environment to manually set the communication port
|
| 183 |
+
# Note: if VLLM_PORT is set, and some code asks for multiple ports, the
|
| 184 |
+
# VLLM_PORT will be used as the first port, and the rest will be generated
|
| 185 |
+
# by incrementing the VLLM_PORT value.
|
| 186 |
+
# '0' is used to make mypy happy
|
| 187 |
+
'VLLM_PORT':
|
| 188 |
+
lambda: int(os.getenv('VLLM_PORT', '0'))
|
| 189 |
+
if 'VLLM_PORT' in os.environ else None,
|
| 190 |
+
|
| 191 |
+
# path used for ipc when the frontend api server is running in
|
| 192 |
+
# multi-processing mode to communicate with the backend engine process.
|
| 193 |
+
'VLLM_RPC_BASE_PATH':
|
| 194 |
+
lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()),
|
| 195 |
+
|
| 196 |
+
# If true, will load models from ModelScope instead of Hugging Face Hub.
|
| 197 |
+
# note that the value is true or false, not numbers
|
| 198 |
+
"VLLM_USE_MODELSCOPE":
|
| 199 |
+
lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true",
|
| 200 |
+
|
| 201 |
+
# Interval in seconds to log a warning message when the ring buffer is full
|
| 202 |
+
"VLLM_RINGBUFFER_WARNING_INTERVAL":
|
| 203 |
+
lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
|
| 204 |
+
|
| 205 |
+
# path to cudatoolkit home directory, under which should be bin, include,
|
| 206 |
+
# and lib directories.
|
| 207 |
+
"CUDA_HOME":
|
| 208 |
+
lambda: os.environ.get("CUDA_HOME", None),
|
| 209 |
+
|
| 210 |
+
# Path to the NCCL library file. It is needed because nccl>=2.19 brought
|
| 211 |
+
# by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
|
| 212 |
+
"VLLM_NCCL_SO_PATH":
|
| 213 |
+
lambda: os.environ.get("VLLM_NCCL_SO_PATH", None),
|
| 214 |
+
|
| 215 |
+
# when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
|
| 216 |
+
# library file in the locations specified by `LD_LIBRARY_PATH`
|
| 217 |
+
"LD_LIBRARY_PATH":
|
| 218 |
+
lambda: os.environ.get("LD_LIBRARY_PATH", None),
|
| 219 |
+
|
| 220 |
+
# flag to control if vllm should use triton flash attention
|
| 221 |
+
"VLLM_USE_TRITON_FLASH_ATTN":
|
| 222 |
+
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
|
| 223 |
+
("true", "1")),
|
| 224 |
+
|
| 225 |
+
# Force vllm to use a specific flash-attention version (2 or 3), only valid
|
| 226 |
+
# when using the flash-attention backend.
|
| 227 |
+
"VLLM_FLASH_ATTN_VERSION":
|
| 228 |
+
lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)),
|
| 229 |
+
|
| 230 |
+
# Internal flag to enable Dynamo fullgraph capture
|
| 231 |
+
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
|
| 232 |
+
lambda: bool(
|
| 233 |
+
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
|
| 234 |
+
|
| 235 |
+
# local rank of the process in the distributed setting, used to determine
|
| 236 |
+
# the GPU device id
|
| 237 |
+
"LOCAL_RANK":
|
| 238 |
+
lambda: int(os.environ.get("LOCAL_RANK", "0")),
|
| 239 |
+
|
| 240 |
+
# used to control the visible devices in the distributed setting
|
| 241 |
+
"CUDA_VISIBLE_DEVICES":
|
| 242 |
+
lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
|
| 243 |
+
|
| 244 |
+
# timeout for each iteration in the engine
|
| 245 |
+
"VLLM_ENGINE_ITERATION_TIMEOUT_S":
|
| 246 |
+
lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")),
|
| 247 |
+
|
| 248 |
+
# API key for VLLM API server
|
| 249 |
+
"VLLM_API_KEY":
|
| 250 |
+
lambda: os.environ.get("VLLM_API_KEY", None),
|
| 251 |
+
|
| 252 |
+
# S3 access information, used for tensorizer to load model from S3
|
| 253 |
+
"S3_ACCESS_KEY_ID":
|
| 254 |
+
lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
|
| 255 |
+
"S3_SECRET_ACCESS_KEY":
|
| 256 |
+
lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
|
| 257 |
+
"S3_ENDPOINT_URL":
|
| 258 |
+
lambda: os.environ.get("S3_ENDPOINT_URL", None),
|
| 259 |
+
|
| 260 |
+
# Usage stats collection
|
| 261 |
+
"VLLM_USAGE_STATS_SERVER":
|
| 262 |
+
lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"),
|
| 263 |
+
"VLLM_NO_USAGE_STATS":
|
| 264 |
+
lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
|
| 265 |
+
"VLLM_DO_NOT_TRACK":
|
| 266 |
+
lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get(
|
| 267 |
+
"DO_NOT_TRACK", None) or "0") == "1",
|
| 268 |
+
"VLLM_USAGE_SOURCE":
|
| 269 |
+
lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"),
|
| 270 |
+
|
| 271 |
+
# Logging configuration
|
| 272 |
+
# If set to 0, vllm will not configure logging
|
| 273 |
+
# If set to 1, vllm will configure logging using the default configuration
|
| 274 |
+
# or the configuration file specified by VLLM_LOGGING_CONFIG_PATH
|
| 275 |
+
"VLLM_CONFIGURE_LOGGING":
|
| 276 |
+
lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")),
|
| 277 |
+
"VLLM_LOGGING_CONFIG_PATH":
|
| 278 |
+
lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"),
|
| 279 |
+
|
| 280 |
+
# this is used for configuring the default logging level
|
| 281 |
+
"VLLM_LOGGING_LEVEL":
|
| 282 |
+
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"),
|
| 283 |
+
|
| 284 |
+
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
|
| 285 |
+
"VLLM_LOGGING_PREFIX":
|
| 286 |
+
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
|
| 287 |
+
|
| 288 |
+
# if set, vllm will call logits processors in a thread pool with this many
|
| 289 |
+
# threads. This is useful when using custom logits processors that either
|
| 290 |
+
# (a) launch additional CUDA kernels or (b) do significant CPU-bound work
|
| 291 |
+
# while not holding the python GIL, or both.
|
| 292 |
+
"VLLM_LOGITS_PROCESSOR_THREADS":
|
| 293 |
+
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
|
| 294 |
+
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
|
| 295 |
+
|
| 296 |
+
# Trace function calls
|
| 297 |
+
# If set to 1, vllm will trace function calls
|
| 298 |
+
# Useful for debugging
|
| 299 |
+
"VLLM_TRACE_FUNCTION":
|
| 300 |
+
lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
|
| 301 |
+
|
| 302 |
+
# Backend for attention computation
|
| 303 |
+
# Available options:
|
| 304 |
+
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
|
| 305 |
+
# - "FLASH_ATTN": use FlashAttention
|
| 306 |
+
# - "XFORMERS": use XFormers
|
| 307 |
+
# - "ROCM_FLASH": use ROCmFlashAttention
|
| 308 |
+
# - "FLASHINFER": use flashinfer
|
| 309 |
+
"VLLM_ATTENTION_BACKEND":
|
| 310 |
+
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
|
| 311 |
+
|
| 312 |
+
# If set, vllm will use flashinfer sampler
|
| 313 |
+
"VLLM_USE_FLASHINFER_SAMPLER":
|
| 314 |
+
lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
|
| 315 |
+
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,
|
| 316 |
+
|
| 317 |
+
# If set, vllm will force flashinfer to use tensor cores;
|
| 318 |
+
# otherwise will use heuristic based on model architecture.
|
| 319 |
+
"VLLM_FLASHINFER_FORCE_TENSOR_CORES":
|
| 320 |
+
lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
|
| 321 |
+
|
| 322 |
+
# Pipeline stage partition strategy
|
| 323 |
+
"VLLM_PP_LAYER_PARTITION":
|
| 324 |
+
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
|
| 325 |
+
|
| 326 |
+
# (CPU backend only) CPU key-value cache space.
|
| 327 |
+
# default is 4GB
|
| 328 |
+
"VLLM_CPU_KVCACHE_SPACE":
|
| 329 |
+
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
|
| 330 |
+
|
| 331 |
+
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
|
| 332 |
+
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
|
| 333 |
+
"VLLM_CPU_OMP_THREADS_BIND":
|
| 334 |
+
lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
|
| 335 |
+
|
| 336 |
+
# OpenVINO device selection
|
| 337 |
+
# default is CPU
|
| 338 |
+
"VLLM_OPENVINO_DEVICE":
|
| 339 |
+
lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
|
| 340 |
+
|
| 341 |
+
# OpenVINO key-value cache space
|
| 342 |
+
# default is 4GB
|
| 343 |
+
"VLLM_OPENVINO_KVCACHE_SPACE":
|
| 344 |
+
lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")),
|
| 345 |
+
|
| 346 |
+
# OpenVINO KV cache precision
|
| 347 |
+
# default is bf16 if natively supported by platform, otherwise f16
|
| 348 |
+
# To enable KV cache compression, please, explicitly specify u8
|
| 349 |
+
"VLLM_OPENVINO_CPU_KV_CACHE_PRECISION":
|
| 350 |
+
lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None),
|
| 351 |
+
|
| 352 |
+
# Enables weights compression during model export via HF Optimum
|
| 353 |
+
# default is False
|
| 354 |
+
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
|
| 355 |
+
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
|
| 356 |
+
|
| 357 |
+
# If the env var is set, then all workers will execute as separate
|
| 358 |
+
# processes from the engine, and we use the same mechanism to trigger
|
| 359 |
+
# execution on all workers.
|
| 360 |
+
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
|
| 361 |
+
"VLLM_USE_RAY_SPMD_WORKER":
|
| 362 |
+
lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),
|
| 363 |
+
|
| 364 |
+
# If the env var is set, it uses the Ray's compiled DAG API
|
| 365 |
+
# which optimizes the control plane overhead.
|
| 366 |
+
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
| 367 |
+
"VLLM_USE_RAY_COMPILED_DAG":
|
| 368 |
+
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
|
| 369 |
+
|
| 370 |
+
# If the env var is set, it uses NCCL for communication in
|
| 371 |
+
# Ray's compiled DAG. This flag is ignored if
|
| 372 |
+
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
| 373 |
+
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
|
| 374 |
+
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
|
| 375 |
+
),
|
| 376 |
+
|
| 377 |
+
# If the env var is set, it enables GPU communication overlap
|
| 378 |
+
# (experimental feature) in Ray's compiled DAG. This flag is ignored if
|
| 379 |
+
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
| 380 |
+
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM":
|
| 381 |
+
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
|
| 382 |
+
),
|
| 383 |
+
|
| 384 |
+
# Use dedicated multiprocess context for workers.
|
| 385 |
+
# Both spawn and fork work
|
| 386 |
+
"VLLM_WORKER_MULTIPROC_METHOD":
|
| 387 |
+
lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "fork"),
|
| 388 |
+
|
| 389 |
+
# Path to the cache for storing downloaded assets
|
| 390 |
+
"VLLM_ASSETS_CACHE":
|
| 391 |
+
lambda: os.path.expanduser(
|
| 392 |
+
os.getenv(
|
| 393 |
+
"VLLM_ASSETS_CACHE",
|
| 394 |
+
os.path.join(get_default_cache_root(), "vllm", "assets"),
|
| 395 |
+
)),
|
| 396 |
+
|
| 397 |
+
# Timeout for fetching images when serving multimodal models
|
| 398 |
+
# Default is 5 seconds
|
| 399 |
+
"VLLM_IMAGE_FETCH_TIMEOUT":
|
| 400 |
+
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
|
| 401 |
+
|
| 402 |
+
# Timeout for fetching videos when serving multimodal models
|
| 403 |
+
# Default is 15 seconds
|
| 404 |
+
"VLLM_VIDEO_FETCH_TIMEOUT":
|
| 405 |
+
lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")),
|
| 406 |
+
|
| 407 |
+
# Timeout for fetching audio when serving multimodal models
|
| 408 |
+
# Default is 10 seconds
|
| 409 |
+
"VLLM_AUDIO_FETCH_TIMEOUT":
|
| 410 |
+
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
| 411 |
+
|
| 412 |
+
# Path to the XLA persistent cache directory.
|
| 413 |
+
# Only used for XLA devices such as TPUs.
|
| 414 |
+
"VLLM_XLA_CACHE_PATH":
|
| 415 |
+
lambda: os.path.expanduser(
|
| 416 |
+
os.getenv(
|
| 417 |
+
"VLLM_XLA_CACHE_PATH",
|
| 418 |
+
os.path.join(get_default_cache_root(), "vllm", "xla_cache"),
|
| 419 |
+
)),
|
| 420 |
+
"VLLM_FUSED_MOE_CHUNK_SIZE":
|
| 421 |
+
lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")),
|
| 422 |
+
|
| 423 |
+
# If set, vllm will skip the deprecation warnings.
|
| 424 |
+
"VLLM_NO_DEPRECATION_WARNING":
|
| 425 |
+
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
|
| 426 |
+
|
| 427 |
+
# If set, the OpenAI API server will stay alive even after the underlying
|
| 428 |
+
# AsyncLLMEngine errors and stops serving requests
|
| 429 |
+
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
|
| 430 |
+
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
|
| 431 |
+
|
| 432 |
+
# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
|
| 433 |
+
# the user to specify a max sequence length greater than
|
| 434 |
+
# the max length derived from the model's config.json.
|
| 435 |
+
# To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
|
| 436 |
+
"VLLM_ALLOW_LONG_MAX_MODEL_LEN":
|
| 437 |
+
lambda:
|
| 438 |
+
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
|
| 439 |
+
("1", "true")),
|
| 440 |
+
|
| 441 |
+
# If set, forces FP8 Marlin to be used for FP8 quantization regardless
|
| 442 |
+
# of the hardware support for FP8 compute.
|
| 443 |
+
"VLLM_TEST_FORCE_FP8_MARLIN":
|
| 444 |
+
lambda:
|
| 445 |
+
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
|
| 446 |
+
("1", "true")),
|
| 447 |
+
"VLLM_TEST_FORCE_LOAD_FORMAT":
|
| 448 |
+
lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"),
|
| 449 |
+
|
| 450 |
+
# Time in ms for the zmq client to wait for a response from the backend
|
| 451 |
+
# server for simple data operations
|
| 452 |
+
"VLLM_RPC_TIMEOUT":
|
| 453 |
+
lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
|
| 454 |
+
|
| 455 |
+
# a list of plugin names to load, separated by commas.
|
| 456 |
+
# if this is not set, it means all plugins will be loaded
|
| 457 |
+
# if this is set to an empty string, no plugins will be loaded
|
| 458 |
+
"VLLM_PLUGINS":
|
| 459 |
+
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
|
| 460 |
+
"VLLM_PLUGINS"].split(","),
|
| 461 |
+
|
| 462 |
+
# Enables torch profiler if set. Path to the directory where torch profiler
|
| 463 |
+
# traces are saved. Note that it must be an absolute path.
|
| 464 |
+
"VLLM_TORCH_PROFILER_DIR":
|
| 465 |
+
lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
|
| 466 |
+
.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
|
| 467 |
+
|
| 468 |
+
# If set, vLLM will use Triton implementations of AWQ.
|
| 469 |
+
"VLLM_USE_TRITON_AWQ":
|
| 470 |
+
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
|
| 471 |
+
|
| 472 |
+
# If set, allow loading or unloading lora adapters in runtime,
|
| 473 |
+
"VLLM_ALLOW_RUNTIME_LORA_UPDATING":
|
| 474 |
+
lambda:
|
| 475 |
+
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
|
| 476 |
+
("1", "true")),
|
| 477 |
+
|
| 478 |
+
# By default, vLLM will check the peer-to-peer capability itself,
|
| 479 |
+
# in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
|
| 480 |
+
# If this env var is set to 1, vLLM will skip the peer-to-peer check,
|
| 481 |
+
# and trust the driver's peer-to-peer capability report.
|
| 482 |
+
"VLLM_SKIP_P2P_CHECK":
|
| 483 |
+
lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
|
| 484 |
+
|
| 485 |
+
# List of quantization kernels that should be disabled, used for testing
|
| 486 |
+
# and performance comparisons. Currently only affects MPLinearKernel
|
| 487 |
+
# selection
|
| 488 |
+
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
|
| 489 |
+
"VLLM_DISABLED_KERNELS":
|
| 490 |
+
lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
|
| 491 |
+
"VLLM_DISABLED_KERNELS"].split(","),
|
| 492 |
+
|
| 493 |
+
# If set, use the V1 code path.
|
| 494 |
+
"VLLM_USE_V1":
|
| 495 |
+
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
|
| 496 |
+
|
| 497 |
+
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
|
| 498 |
+
"K_SCALE_CONSTANT":
|
| 499 |
+
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
|
| 500 |
+
|
| 501 |
+
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
|
| 502 |
+
"V_SCALE_CONSTANT":
|
| 503 |
+
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
|
| 504 |
+
# If set, enable multiprocessing in LLM for the V1 code path.
|
| 505 |
+
"VLLM_ENABLE_V1_MULTIPROCESSING":
|
| 506 |
+
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
|
| 507 |
+
"VLLM_LOG_BATCHSIZE_INTERVAL":
|
| 508 |
+
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
|
| 509 |
+
"VLLM_DISABLE_COMPILE_CACHE":
|
| 510 |
+
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
|
| 511 |
+
|
| 512 |
+
# If set, vllm will run in development mode, which will enable
|
| 513 |
+
# some additional endpoints for developing and debugging,
|
| 514 |
+
# e.g. `/reset_prefix_cache`
|
| 515 |
+
"VLLM_SERVER_DEV_MODE":
|
| 516 |
+
lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
|
| 517 |
+
|
| 518 |
+
# Controls the maximum number of requests to handle in a
|
| 519 |
+
# single asyncio task when processing per-token outputs in the
|
| 520 |
+
# V1 AsyncLLM interface. It is applicable when handling a high
|
| 521 |
+
# concurrency of streaming requests.
|
| 522 |
+
# Setting this too high can result in a higher variance of
|
| 523 |
+
# inter-message latencies. Setting it too low can negatively impact
|
| 524 |
+
# TTFT and overall throughput.
|
| 525 |
+
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
|
| 526 |
+
lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
|
| 527 |
+
|
| 528 |
+
# If set, vLLM will disable the MLA attention optimizations.
|
| 529 |
+
"VLLM_MLA_DISABLE":
|
| 530 |
+
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
|
| 531 |
+
|
| 532 |
+
# Flag that can control whether or not we perform matrix-absorption for MLA
|
| 533 |
+
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
|
| 534 |
+
# matrices reduces the runtime FLOPs needed to compute MLA but requires
|
| 535 |
+
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
| 536 |
+
# this is enabled by default
|
| 537 |
+
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
| 538 |
+
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
|
| 539 |
+
|
| 540 |
+
# When running MLA with matrix-absorption enabled and fp8 quantized weights
|
| 541 |
+
# we perform the matrix-absorption in float32 precision, after the matrices
|
| 542 |
+
# are absorbed we requantize the weights back to fp8, this flag can be used
|
| 543 |
+
# to disable the requantization step, and instead convert the absorbed
|
| 544 |
+
# matrices to match the activation type. This can lead to higher memory and
|
| 545 |
+
# compute usage but better preserves the accuracy of the original model.
|
| 546 |
+
"VLLM_MLA_DISABLE_REQUANTIZATION":
|
| 547 |
+
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),
|
| 548 |
+
|
| 549 |
+
# If set, vLLM will use the Triton implementation of moe_align_block_size,
|
| 550 |
+
# i.e. moe_align_block_size_triton in fused_moe.py.
|
| 551 |
+
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
|
| 552 |
+
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
| 553 |
+
),
|
| 554 |
+
|
| 555 |
+
# Number of GPUs per worker in Ray, if it is set to be a fraction,
|
| 556 |
+
# it allows ray to schedule multiple actors on a single GPU,
|
| 557 |
+
# so that users can colocate other actors on the same GPUs as vLLM.
|
| 558 |
+
"VLLM_RAY_PER_WORKER_GPUS":
|
| 559 |
+
lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")),
|
| 560 |
+
|
| 561 |
+
# Bundle indices for Ray, if it is set, it can control precisely
|
| 562 |
+
# which indices are used for the Ray bundle, for every worker.
|
| 563 |
+
# Format: comma-separated list of integers, e.g. "0,1,2,3"
|
| 564 |
+
"VLLM_RAY_BUNDLE_INDICES":
|
| 565 |
+
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
|
| 566 |
+
|
| 567 |
+
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
|
| 568 |
+
# byte aligned for better performance, this increases the memory usage of
|
| 569 |
+
# the cache. Currently this only affects MLA that results in non-256
|
| 570 |
+
# byte aligned entries. This matches the alignment the CUDA runtime uses
|
| 571 |
+
# for all allocations. Currently this primarily affects MLA, for most other
|
| 572 |
+
# models the alignment is already naturally aligned to 256 bytes.
|
| 573 |
+
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
|
| 574 |
+
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
# end-env-vars-definition
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def __getattr__(name: str):
    """Resolve a module attribute lazily from ``environment_variables``.

    Each entry in ``environment_variables`` is a zero-argument callable; it is
    invoked on every attribute access, so the value is computed at lookup
    time rather than at import time.

    Raises:
        AttributeError: if ``name`` is not a key of ``environment_variables``.
    """
    # Guard clause: unknown names behave like any missing module attribute.
    if name not in environment_variables:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # lazy evaluation of environment variables
    return environment_variables[name]()
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def __dir__():
    """Return the names of all entries in ``environment_variables``."""
    # Iterating a dict yields its keys, in insertion order.
    return list(environment_variables)
|
.venv/lib/python3.11/site-packages/vllm/executor/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/vllm/executor/__pycache__/ray_distributed_executor.cpython-311.pyc
ADDED
|
Binary file (30 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/executor/executor_base.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
| 6 |
+
Union)
|
| 7 |
+
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from typing_extensions import TypeVar
|
| 10 |
+
|
| 11 |
+
from vllm.config import VllmConfig
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.lora.request import LoRARequest
|
| 14 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 15 |
+
from vllm.platforms import current_platform
|
| 16 |
+
from vllm.prompt_adapter.request import PromptAdapterRequest
|
| 17 |
+
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
| 18 |
+
from vllm.utils import make_async
|
| 19 |
+
from vllm.worker.worker_base import WorkerBase
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
_R = TypeVar("_R", default=Any)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ExecutorBase(ABC):
    """Common interface shared by every executor.

    An executor runs the model either on a single device or, for the
    distributed variants, across several devices.
    """

    uses_ray: bool  # whether the executor uses Ray for orchestration.

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store the config, expose its sub-configs and run the init hook."""
        config = vllm_config
        self.vllm_config = config
        self.model_config = config.model_config
        self.cache_config = config.cache_config
        self.lora_config = config.lora_config
        self.load_config = config.load_config
        self.parallel_config = config.parallel_config
        self.scheduler_config = config.scheduler_config
        self.device_config = config.device_config
        self.speculative_config = config.speculative_config
        self.prompt_adapter_config = config.prompt_adapter_config
        self.observability_config = config.observability_config
        # The subclass hook runs first; the sleep flag is only set afterwards.
        self._init_executor()
        self.is_sleeping = False

    @abstractmethod
    def _init_executor(self) -> None:
        """Subclass hook invoked at the end of ``__init__``."""
        raise NotImplementedError

    @abstractmethod
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Run an RPC on every worker and collect the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by whatever is given in `args`
                and `kwargs`.
            timeout: Upper bound in seconds on the wait for execution; a
                :exc:`TimeoutError` is raised when it elapses. `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding one result per worker.

        Note:
            This API is meant for control messages only; set up data-plane
            communication for passing data.
        """
        raise NotImplementedError

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Compute how many KV cache blocks fit on the GPU and in CPU swap.

        Every worker is asked for its own figures and the minimum of each
        count is returned, so the chosen sizes are compatible with all
        workers.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks]. num_gpu_blocks are
        the blocks that are "active" on the device and can be appended to;
        num_cpu_blocks are "swapped" blocks in CPU memory that cannot be
        appended to.
        """
        per_worker = self.collective_rpc("determine_num_available_blocks")
        num_gpu_blocks = min(counts[0] for counts in per_worker)
        num_cpu_blocks = min(counts[1] for counts in per_worker)
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Record the block counts and initialize the KV cache on the workers.
        """
        # NOTE: This is logged in the executor because there can be >1 workers.
        logger.info("# %s blocks: %d, # CPU blocks: %d",
                    current_platform.dispatch_key, num_gpu_blocks,
                    num_cpu_blocks)

        # Total cache capacity in tokens divided by the longest request.
        cache_tokens = num_gpu_blocks * self.cache_config.block_size
        max_model_len = self.model_config.max_model_len
        max_concurrency = cache_tokens / max_model_len
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    max_model_len, max_concurrency)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        block_counts = (num_gpu_blocks, num_cpu_blocks)
        self.collective_rpc("initialize_cache", args=block_counts)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model held by each worker and return the
        per-worker results.
        """

        def rpc_func(worker: WorkerBase) -> _R:
            # Executed on the worker: fetch its model and hand it to ``func``.
            model = worker.get_model()
            return func(model)

        return self.collective_rpc(rpc_func)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Run ``execute_model`` on all workers; return the first result."""
        request_args = (execute_model_req, )
        return self.collective_rpc("execute_model", args=request_args)[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Releases parallel workers from model loop.

        The base implementation does nothing.
        """
        return None

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA on every worker; True only if all workers report True."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        outcomes = self.collective_rpc("add_lora", args=(lora_request, ))
        return all(outcomes)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA on every worker; True only if all report True."""
        assert lora_id > 0, "lora_id must be greater than 0."
        outcomes = self.collective_rpc("remove_lora", args=(lora_id, ))
        return all(outcomes)

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRA on every worker; True only if all report True."""
        assert lora_id > 0, "lora_id must be greater than 0."
        outcomes = self.collective_rpc("pin_lora", args=(lora_id, ))
        return all(outcomes)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the workers (all must agree)."""
        sets = self.collective_rpc("list_loras")
        for worker_set in sets:
            assert worker_set == sets[0], \
                "All workers should have the same LORAs."
        return sets[0]

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add a prompt adapter on every worker; True only if all succeed."""
        assert prompt_adapter_request.prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        outcomes = self.collective_rpc("add_prompt_adapter",
                                       args=(prompt_adapter_request, ))
        return all(outcomes)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove a prompt adapter on every worker; True if all succeed."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        outcomes = self.collective_rpc("remove_prompt_adapter",
                                       args=(prompt_adapter_id, ))
        return all(outcomes)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a prompt adapter on every worker; True if all succeed."""
        assert prompt_adapter_id > 0, \
            "prompt_adapter_id must be greater than 0."
        outcomes = self.collective_rpc("pin_prompt_adapter",
                                       args=(prompt_adapter_id, ))
        return all(outcomes)

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt adapter ids reported by the workers."""
        sets = self.collective_rpc("list_prompt_adapters")
        for worker_set in sets:
            assert (worker_set == sets[0]
                    ), "All workers should have the same prompt adapters."
        return sets[0]

    def start_profile(self) -> None:
        """Invoke ``start_profile`` on every worker."""
        self.collective_rpc("start_profile")

    def stop_profile(self) -> None:
        """Invoke ``stop_profile`` on every worker."""
        self.collective_rpc("stop_profile")

    def sleep(self, level: int = 1):
        """Put all workers to sleep unless the executor already sleeps."""
        if self.is_sleeping:
            logger.warning("Executor is already sleeping.")
            return
        self.collective_rpc("sleep", kwargs={"level": level})
        self.is_sleeping = True

    def wake_up(self):
        """Wake all workers up unless the executor is not sleeping."""
        if not self.is_sleeping:
            logger.warning("Executor is not sleeping.")
            return
        self.collective_rpc("wake_up")
        self.is_sleeping = False

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Forward ``save_sharded_state`` to every worker."""
        rpc_kwargs = {
            "path": path,
            "pattern": pattern,
            "max_size": max_size,
        }
        self.collective_rpc("save_sharded_state", kwargs=rpc_kwargs)

    @abstractmethod
    def check_health(self) -> None:
        """Verify the executor is healthy; implementations raise an
        exception when it is not."""
        raise NotImplementedError

    def shutdown(self) -> None:
        """Shutdown the executor. The base implementation does nothing."""
        return None

    def __del__(self):
        # Release resources when the executor object is finalized.
        self.shutdown()

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Executes one model step on the given sequences."""
        # Wrap the synchronous implementation with ``make_async``.
        run_step = make_async(self.execute_model)
        return await run_step(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Releases parallel workers from model loop.

        The base implementation does nothing.
        """
        return None

    async def check_health_async(self) -> None:
        """Verify the executor is healthy by calling the synchronous
        ``check_health``, which raises an exception when it is not."""
        self.check_health()
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class DistributedExecutorBase(ExecutorBase):
    """Abstract base for executors that spread work over several workers."""

    def __init__(self, *args, **kwargs):
        # Non-None while the execute-model loop is running in the parallel
        # workers. In the AsyncLLMEngine case it holds a coroutine.
        # It is set before super().__init__(), which runs _init_executor().
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        super().__init__(*args, **kwargs)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run one step on the driver, starting the worker loop if needed."""
        # TODO: unify into collective_rpc
        loop_not_started = self.parallel_worker_tasks is None
        if loop_not_started:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True)

        # Only the driver worker returns the sampling results.
        sampler_outputs = self._driver_execute_model(execute_model_req)
        assert sampler_outputs is not None
        return sampler_outputs

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the worker execution loop, if one is currently running."""
        if self.parallel_worker_tasks is None:
            return

        # A None request tells the driver to stop the remote loops.
        self._driver_execute_model(execute_model_req=None)
        pending_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(pending_tasks)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        A None request makes the driver stop the model execution loop that
        runs in each remote worker, and the method then returns None. For
        any other request the model output is returned.
        """
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Run ``method`` on the workers by delegating to ``_run_workers``.

        ``timeout`` is accepted for interface compatibility but is not
        forwarded to ``_run_workers``.
        """
        worker_kwargs = kwargs or {}
        return self._run_workers(method, *args, **worker_kwargs)

    @abstractmethod
    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: When True, the method
                runs only in the remote TP workers and not in the driver
                worker. It also runs asynchronously, returning a list of
                futures instead of blocking on the results.

        # TODO: simplify and merge with collective_rpc
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        raise NotImplementedError

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Async variant of ``execute_model``."""
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            worker_loop = self._start_worker_execution_loop()
            self.parallel_worker_tasks = asyncio.create_task(worker_loop)

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Async variant of ``stop_remote_worker_execution_loop``."""
        if self.parallel_worker_tasks is None:
            return

        # Called without a request, i.e. with the default None, which stops
        # the remote loops.
        await self._driver_execute_model_async()
        pending_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await pending_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        A None request makes the driver stop the model execution loop that
        runs in each remote worker.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or None of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError
|
.venv/lib/python3.11/site-packages/vllm/executor/mp_distributed_executor.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Callable, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import cloudpickle
|
| 8 |
+
|
| 9 |
+
from vllm.executor.executor_base import DistributedExecutorBase
|
| 10 |
+
from vllm.executor.multiproc_worker_utils import (
|
| 11 |
+
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
|
| 12 |
+
set_multiprocessing_worker_envs)
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 15 |
+
from vllm.sequence import ExecuteModelRequest
|
| 16 |
+
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
|
| 17 |
+
get_distributed_init_method, get_ip, get_open_port,
|
| 18 |
+
make_async, run_method, update_environment_variables)
|
| 19 |
+
from vllm.worker.worker_base import WorkerWrapperBase
|
| 20 |
+
|
| 21 |
+
logger = init_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
    """Python multiprocessing-based distributed executor.

    Rank 0 is hosted in the current process through a ``WorkerWrapperBase``;
    every other rank is hosted by a ``ProcessWorkerWrapper``. All ranks are
    assumed to live on the same node.
    """

    uses_ray: bool = False

    def _check_cuda(self) -> None:
        """Check that the number of GPUs is sufficient for the parallel
        configuration. Separate from _init_executor to reduce the number of
        indented blocks.

        Also sets ``CUDA_VISIBLE_DEVICES`` to ``0..world_size-1`` when the
        variable is not already present in the environment.

        Raises:
            RuntimeError: If ``tensor_parallel_size`` or ``world_size``
                exceeds the number of visible CUDA devices.
        """
        parallel_config = self.parallel_config
        world_size = parallel_config.world_size
        tensor_parallel_size = parallel_config.tensor_parallel_size

        cuda_device_count = cuda_device_count_stateless()
        # Check tensor parallelism first so the more common TP-only case gets
        # a message that names the setting the user actually changed.
        if tensor_parallel_size > cuda_device_count:
            raise RuntimeError(
                f"please set tensor_parallel_size ({tensor_parallel_size}) "
                f"to less than max local gpu count ({cuda_device_count})")

        if world_size > cuda_device_count:
            raise RuntimeError(
                f"please ensure that world_size ({world_size}) "
                f"is less than max local gpu count ({cuda_device_count})")

        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            update_environment_variables({
                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
            })

    def _init_executor(self) -> None:
        """Create the workers, then initialise devices and load the model.

        Populates ``workers``, ``tp_driver_workers``, ``non_driver_workers``,
        ``worker_monitor``, ``driver_worker``, ``driver_exec_model`` and
        ``pp_locks`` on the instance.
        """
        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            self._check_cuda()

        # Create the parallel GPU workers.
        world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        # This is the only place the init method is computed: it must not be
        # recomputed from get_ip() later, which would defeat the loopback
        # choice above.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        self.workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are rank 0 of each TP group EXCEPT
        # global rank 0. These are the workers that will broadcast to the
        # rest of the workers.
        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
        # This is the list of workers that are not drivers and not the first
        # worker in a TP group. These are the workers that will be
        # broadcasted to.
        self.non_driver_workers: List[ProcessWorkerWrapper] = []

        if world_size == 1:
            # No remote workers, hence nothing to monitor.
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            # Rank 0 is the in-process driver, so remote ranks start at 1.
            for rank in range(1, world_size):
                worker = ProcessWorkerWrapper(result_handler,
                                              WorkerWrapperBase,
                                              self.vllm_config, rank)
                self.workers.append(worker)
                if rank % tensor_parallel_size == 0:
                    self.tp_driver_workers.append(worker)
                else:
                    self.non_driver_workers.append(worker)

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)

        # One kwargs dict per rank; on a single node the local rank equals
        # the global rank.
        all_kwargs = []
        for rank in range(world_size):
            kwargs = dict(
                vllm_config=self.vllm_config,
                local_rank=rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=(not self.parallel_config)
                or (rank % self.parallel_config.tensor_parallel_size == 0),
            )
            all_kwargs.append(kwargs)
        self._run_workers("init_worker", all_kwargs)
        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)
        self.driver_exec_model = make_async(self.driver_worker.execute_model)
        # Created lazily in _driver_execute_model_async.
        self.pp_locks: Optional[List[asyncio.Lock]] = None

    def shutdown(self):
        """Close the worker monitor, if one was created."""
        # getattr tolerates the attribute never having been assigned.
        if (worker_monitor := getattr(self, "worker_monitor",
                                      None)) is not None:
            worker_monitor.close()

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_model(execute_model_req)

    def _run_workers(
        self,
        method: Union[str, Callable],
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> List[Any]:
        """Runs the given method on all workers.

        Args:
            method: Name of the worker method, or a callable. A callable is
                serialized with cloudpickle before being sent.
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
            max_concurrent_workers: Not supported; any truthy value raises.

        Returns:
            The driver worker's output followed by each remote worker's
            output, or the list of futures in the async case.

        Raises:
            NotImplementedError: If ``max_concurrent_workers`` is truthy.
        """
        if isinstance(method, str):
            sent_method = method
        else:
            sent_method = cloudpickle.dumps(method)
        del method

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        if async_run_tensor_parallel_workers_only:
            # Run only non-driver workers and just return futures.
            return [
                worker.execute_method(sent_method, *args, **kwargs)
                for worker in self.non_driver_workers
            ]

        # Start all remote workers first.
        worker_outputs = [
            worker.execute_method(sent_method, *args, **kwargs)
            for worker in self.workers
        ]

        driver_worker_output = run_method(self.driver_worker, sent_method,
                                          args, kwargs)

        # Get the results of the workers.
        return [driver_worker_output
                ] + [output.get() for output in worker_outputs]

    def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        if self.worker_monitor is not None and not self.worker_monitor.is_alive(
        ):
            raise RuntimeError("Worker processes are not running")

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        for result in parallel_worker_tasks:
            result.get()

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run ``execute_model`` on the driver and every TP-group driver.

        Without ``tp_driver_workers`` the request runs on the local driver
        only. Otherwise one task is started per stage, each guarded by its
        own lock, and the result of the last task is returned.
        """
        if not self.tp_driver_workers:
            return await self.driver_exec_model(execute_model_req)

        if self.pp_locks is None:
            # This locks each pipeline parallel stage so multiple virtual
            # engines can't execute on the same stage at the same time
            # We create the locks here to avoid creating them in the constructor
            # which uses a different asyncio loop.
            self.pp_locks = [
                asyncio.Lock()
                for _ in range(self.parallel_config.pipeline_parallel_size)
            ]

        tasks = [
            asyncio.create_task(
                _run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
                                    execute_model_req))
        ]
        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
                                                start=1):
            tasks.append(
                asyncio.create_task(
                    _run_task_with_lock(driver_worker.execute_method_async,
                                        self.pp_locks[pp_rank],
                                        "execute_model", execute_model_req)))
        results = await asyncio.gather(*tasks)

        # Only the last PP stage has the final results.
        return results[-1]

    async def _start_worker_execution_loop(self):
        """Start ``start_worker_execution_loop`` on every non-driver worker
        and wait for all of them to finish."""
        coros = [
            worker.execute_method_async("start_worker_execution_loop")
            for worker in self.non_driver_workers
        ]
        return await asyncio.gather(*coros)
|
.venv/lib/python3.11/site-packages/vllm/executor/msgspec_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from array import array
|
| 4 |
+
from typing import Any, Type
|
| 5 |
+
|
| 6 |
+
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def encode_hook(obj: Any) -> Any:
    """msgspec ``enc_hook`` that serializes ``array.array`` values as bytes.

    Objects that are not ``array.array`` instances are not handled here and
    the function returns ``None`` for them.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
    """
    if not isinstance(obj, array):
        return None

    # Only the vLLM token-id typecode is expected here.
    typecode_ok = obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE
    assert typecode_ok, (
        f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
        f"Given array has a type code of {obj.typecode}.")
    return obj.tobytes()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def decode_hook(type: Type, obj: Any) -> Any:
    """msgspec ``dec_hook`` that rebuilds ``array.array`` values from bytes.

    When the requested ``type`` is ``array``, ``obj`` is loaded into a new
    array of typecode ``VLLM_TOKEN_ID_ARRAY_TYPE``. Any other requested type
    is not handled here and the function returns ``None``.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Decoder
    """
    if type is not array:
        return None

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE)
    token_ids.frombytes(obj)
    return token_ids
|
.venv/lib/python3.11/site-packages/vllm/executor/ray_distributed_executor.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
import cloudpickle
|
| 10 |
+
import msgspec
|
| 11 |
+
|
| 12 |
+
import vllm.envs as envs
|
| 13 |
+
from vllm.executor.executor_base import (
|
| 14 |
+
DistributedExecutorBase) # yapf: disable
|
| 15 |
+
from vllm.executor.msgspec_utils import encode_hook
|
| 16 |
+
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
|
| 17 |
+
ray)
|
| 18 |
+
from vllm.logger import init_logger
|
| 19 |
+
from vllm.model_executor.layers.sampler import SamplerOutput
|
| 20 |
+
from vllm.platforms import current_platform
|
| 21 |
+
from vllm.sequence import ExecuteModelRequest
|
| 22 |
+
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
| 23 |
+
get_ip, get_open_port, make_async)
|
| 24 |
+
|
| 25 |
+
if ray is not None:
|
| 26 |
+
from ray.actor import ActorHandle
|
| 27 |
+
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
| 28 |
+
else:
|
| 29 |
+
ActorHandle = None
|
| 30 |
+
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from ray.util.placement_group import PlacementGroup
|
| 33 |
+
|
| 34 |
+
logger = init_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class RayWorkerMetaData:
    """
    Metadata for a Ray worker.
    The order of ray worker creation can be random,
    and we need to reset the rank after creating all workers.
    """
    # Handle of the Ray actor hosting the worker.
    worker: ActorHandle
    # Rank the worker was given when it was created.
    created_rank: int
    # Rank assigned after all workers exist; -1 until it has been set.
    adjusted_rank: int = -1
    # IP address of the worker's node; empty until it has been filled in.
    ip: str = ""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class RayDistributedExecutor(DistributedExecutorBase):
|
| 51 |
+
|
| 52 |
+
uses_ray: bool = True
|
| 53 |
+
|
| 54 |
+
def _init_executor(self) -> None:
    """Resolve the Ray execution mode and bring up the Ray workers.

    Reads the compiled-DAG / SPMD switches from ``envs``, checks that they
    are either both on or both off, initialises the Ray cluster, creates the
    workers and sets up the msgpack encoder/decoder used for model I/O.
    """
    self.forward_dag: Optional[ray.dag.CompiledDAG] = None

    if envs.VLLM_USE_V1:
        # v1 always uses the compiled DAG and SPMD worker.
        os.environ.update({
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG": "1",
        })

    # VLLM_USE_RAY_COMPILED_DAG=1 selects Ray's compiled DAG API, which
    # optimizes the control plane overhead. It currently requires
    # VLLM_USE_RAY_SPMD_WORKER=1 as well.
    self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
    # With VLLM_USE_RAY_SPMD_WORKER=1 there is no distinction between the
    # "driver worker" and the other workers, and rank 0 also runs in a
    # remote Ray worker. It currently requires the compiled DAG.
    self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER

    # The two switches are only supported together.
    if self.use_ray_compiled_dag:
        assert self.use_ray_spmd_worker, (
            "VLLM_USE_RAY_COMPILED_DAG=1 requires "
            "VLLM_USE_RAY_SPMD_WORKER=1")
    if self.use_ray_spmd_worker:
        # TODO: Support SPMD worker for non-DAG Ray executor.
        assert self.use_ray_compiled_dag, (
            "VLLM_USE_RAY_SPMD_WORKER=1 requires "
            "VLLM_USE_RAY_COMPILED_DAG=1")

    assert self.uses_ray
    initialize_ray_cluster(self.parallel_config)

    # Disable Ray usage stats collection unless explicitly enabled.
    if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
        os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

    # Create the parallel GPU workers.
    self._init_workers_ray(self.parallel_config.placement_group)

    self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    output_type = Optional[List[SamplerOutput]]
    self.output_decoder = msgspec.msgpack.Decoder(output_type)
    self.use_v1 = envs.VLLM_USE_V1

    # Created lazily; None until first needed.
    self.pp_locks: Optional[List[asyncio.Lock]] = None
    self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
    if not self.use_ray_compiled_dag:
        # A local driver worker is only used outside compiled-DAG mode.
        self.driver_exec_method = make_async(
            self.driver_worker.execute_method)
|
| 102 |
+
|
| 103 |
+
def shutdown(self) -> None:
    """Tear down the compiled DAG and kill the Ray workers.

    Does nothing when ``forward_dag`` is missing or ``None``; in that case
    the workers are left untouched as well.
    """
    dag = getattr(self, "forward_dag", None)
    if dag is None:
        return

    dag.teardown()
    import ray
    for actor in self.workers:
        ray.kill(actor)
    self.forward_dag = None
|
| 110 |
+
|
| 111 |
+
def _configure_ray_workers_use_nsight(self,
|
| 112 |
+
ray_remote_kwargs) -> Dict[str, Any]:
|
| 113 |
+
# If nsight profiling is enabled, we need to set the profiling
|
| 114 |
+
# configuration for the ray workers as runtime env.
|
| 115 |
+
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
|
| 116 |
+
runtime_env.update({
|
| 117 |
+
"nsight": {
|
| 118 |
+
"t": "cuda,cudnn,cublas",
|
| 119 |
+
"o": "'worker_process_%p'",
|
| 120 |
+
"cuda-graph-trace": "node",
|
| 121 |
+
}
|
| 122 |
+
})
|
| 123 |
+
|
| 124 |
+
return ray_remote_kwargs
|
| 125 |
+
|
| 126 |
+
# child class could overwrite this to return actual env vars.
|
| 127 |
+
def _get_env_vars_to_be_updated(self):
    """Return the stored environment-variable updates for all workers.

    This is simply ``self._env_vars_for_all_workers``; a subclass may
    override the method to supply different values.
    """
    return self._env_vars_for_all_workers
|
| 129 |
+
|
| 130 |
+
def _init_workers_ray(self, placement_group: "PlacementGroup",
                      **ray_remote_kwargs):
    """Create the Ray worker actors, assign their ranks and initialise them.

    In order, this method:
      1. chooses the placement-group bundles to use (from
         ``VLLM_RAY_BUNDLE_INDICES`` or the first bundles exposing the
         platform's device key);
      2. creates one ``RayWorkerWrapper`` actor per chosen bundle;
      3. outside SPMD mode, takes the first actor whose IP equals the
         driver's IP as ``driver_dummy_worker`` and creates a local
         ``driver_worker``;
      4. sorts the remaining actors and sends the new ranks to them via
         ``adjust_rank``;
      5. collects the device ids per node and sends environment-variable
         updates to every worker;
      6. runs ``init_worker``, ``init_device`` and ``load_model``;
      7. fills ``pp_tp_workers`` (SPMD mode only), ``tp_driver_workers``
         and ``non_driver_workers``.

    Args:
        placement_group: Ray placement group whose bundles host the workers.
        **ray_remote_kwargs: Extra keyword arguments passed to
            ``ray.remote(...)`` for every worker actor.

    Raises:
        ValueError: Outside SPMD mode, when no worker shares the driver's IP.
        RuntimeError: When the number of node ids differs from the number of
            unique IP addresses.
    """
    num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS

    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerWrapper] = []

    # Used in ray compiled DAG: indexed first by PP rank,
    # and then TP rank. In other words, the inner list is
    # the TP group of workers for a PP rank.
    self.pp_tp_workers: List[List[RayWorkerWrapper]] = []

    if self.parallel_config.ray_workers_use_nsight:
        ray_remote_kwargs = self._configure_ray_workers_use_nsight(
            ray_remote_kwargs)

    logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)

    # Create the workers.
    bundle_indices: List[int]
    if envs.VLLM_RAY_BUNDLE_INDICES:
        # Use the bundle indices specified by the user.
        bundle_indices = list(
            map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
        assert len(bundle_indices) == self.parallel_config.world_size, \
        ("VLLM_RAY_BUNDLE_INDICES must have the same size"
        f" as the world size, but got {bundle_indices=} "
        f"and {self.parallel_config.world_size=}")
        assert len(set(bundle_indices)) == len(bundle_indices), \
        ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
        f" but got {bundle_indices=}")
    else:
        # use the first N bundles that have GPU resources.
        bundle_indices = []
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if bundle.get(current_platform.ray_device_key, 0):
                bundle_indices.append(bundle_id)
        # NOTE(review): no check here that world_size bundles were found.
        bundle_indices = bundle_indices[:self.parallel_config.world_size]

    worker_metadata: List[RayWorkerMetaData] = []
    driver_ip = get_ip()
    for rank, bundle_id in enumerate(bundle_indices):
        # Pin each actor to its own bundle of the placement group.
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_id,
        )

        if current_platform.ray_device_key == "GPU":
            # NV+AMD GPUs, and Intel XPUs
            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                       rpc_rank=rank)
        else:
            # Other device keys are requested as a custom Ray resource.
            worker = ray.remote(
                num_cpus=0,
                num_gpus=0,
                resources={current_platform.ray_device_key: num_gpus},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
                                       rpc_rank=rank)
        worker_metadata.append(
            RayWorkerMetaData(worker=worker, created_rank=rank))

    # Fetch all node IPs with a single ray.get call.
    worker_ips = ray.get([
        each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
        for each in worker_metadata
    ])

    for each, ip in zip(worker_metadata, worker_ips):
        each.ip = ip

    if not self.use_ray_spmd_worker:
        for i, each in enumerate(worker_metadata):
            # find and remove the dummy worker from the list
            worker = each.worker
            worker_ip = each.ip
            if self.driver_dummy_worker is None and worker_ip == driver_ip:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    vllm_config=self.vllm_config, rpc_rank=0)
                # Popping while iterating is safe only because of the
                # immediate break.
                worker_metadata.pop(i)
                break

    logger.debug("workers: %s", worker_metadata)
    logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
    if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")

    # Number of workers per IP, counted over all created workers (including
    # the one that may have been taken as the driver dummy worker).
    ip_counts: Dict[str, int] = {}
    for ip in worker_ips:
        ip_counts[ip] = ip_counts.get(ip, 0) + 1

    def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
        """
        Sort the workers based on 3 properties:
        1. If the worker is on the same node as the driver (vllm engine),
            it should be placed first.
        2. Then, if the worker is on a node with fewer workers, it should
            be placed first.
        3. Finally, if the worker is on a node with smaller IP address, it
            should be placed first.
        """
        ip = item.ip
        return (0 if ip == driver_ip else 1, ip_counts[ip], ip)

    # After sorting, the workers on the same node will be
    # close to each other, and the workers on the driver
    # node will be placed first.
    sorted_worker_metadata = sorted(worker_metadata,
                                    key=sort_by_driver_then_worker_ip)
    # Outside SPMD mode rank 0 belongs to the local driver worker, so the
    # remote workers start at rank 1.
    start_rank = 0 if self.use_ray_spmd_worker else 1
    for i, item in enumerate(sorted_worker_metadata):
        item.adjusted_rank = i + start_rank
    self.workers = [item.worker for item in sorted_worker_metadata]
    rerank_mapping = {
        item.created_rank: item.adjusted_rank
        for item in sorted_worker_metadata
    }
    self._run_workers("adjust_rank", rerank_mapping)

    # Get the set of GPU IDs used on each node.
    worker_node_and_gpu_ids = []
    for worker in [self.driver_dummy_worker] + self.workers:
        if worker is None:
            # driver_dummy_worker can be None when using ray spmd worker.
            continue
        worker_node_and_gpu_ids.append(
            ray.get(worker.get_node_and_gpu_ids.remote()) \
        ) # type: ignore

    node_workers = defaultdict(list)  # node id -> list of worker ranks
    node_gpus = defaultdict(list)  # node id -> list of gpu ids

    for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
        node_workers[node_id].append(i)
        # `gpu_ids` can be a list of strings or integers.
        # convert them to integers for consistency.
        # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
        # string sorting is not sufficient.
        # see https://github.com/vllm-project/vllm/issues/5590
        gpu_ids = [int(x) for x in gpu_ids]
        node_gpus[node_id].extend(gpu_ids)
    for node_id, gpu_ids in node_gpus.items():
        node_gpus[node_id] = sorted(gpu_ids)

    all_ips = set(worker_ips + [driver_ip])
    n_ips = len(all_ips)
    n_nodes = len(node_workers)

    if n_nodes != n_ips:
        raise RuntimeError(
            f"Every node should have a unique IP address. Got {n_nodes}"
            f" nodes with node ids {list(node_workers.keys())} and "
            f"{n_ips} unique IP addresses {all_ips}. Please check your"
            " network configuration. If you set `VLLM_HOST_IP`"
            " environment variable, make sure it is unique for"
            " each node.")

    # Set environment variables for the driver and workers.
    # One dict per worker; each lists every device id seen on that worker's
    # node.
    all_args_to_update_environment_variables = [{
        current_platform.device_control_env_var:
        ",".join(map(str, node_gpus[node_id])),
    } for (node_id, _) in worker_node_and_gpu_ids]

    for args in all_args_to_update_environment_variables:
        # some carry-over env vars from the driver
        # TODO: refactor platform-specific env vars
        for name in [
                "VLLM_ATTENTION_BACKEND",
                "TPU_CHIPS_PER_HOST_BOUNDS",
                "TPU_HOST_BOUNDS",
                "VLLM_USE_V1",
                "VLLM_TRACE_FUNCTION",
        ]:
            if name in os.environ:
                args[name] = os.environ[name]

    self._env_vars_for_all_workers = (
        all_args_to_update_environment_variables)

    self._run_workers("update_environment_variables",
                      self._get_env_vars_to_be_updated())

    if len(node_gpus) == 1:
        # in single node case, we don't need to get the IP address.
        # the loopback address is sufficient
        # NOTE: a node may have several IP addresses, one for each
        # network interface. `get_ip()` might return any of them,
        # while they might not work for communication inside the node
        # if the network setup is complicated. Using the loopback address
        # solves this issue, as it always works for communication inside
        # the node.
        driver_ip = "127.0.0.1"
    distributed_init_method = get_distributed_init_method(
        driver_ip, get_open_port())

    # Initialize the actual workers inside worker wrapper.
    all_kwargs = []
    for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
        # Local rank is the worker's position among the workers of its node.
        local_rank = node_workers[node_id].index(rank)
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=(not self.parallel_config)
            or (rank % self.parallel_config.tensor_parallel_size == 0),
        )
        all_kwargs.append(kwargs)
    self._run_workers("init_worker", all_kwargs)

    self._run_workers("init_device")
    self._run_workers("load_model",
                      max_concurrent_workers=self.parallel_config.
                      max_parallel_loading_workers)

    if self.use_ray_spmd_worker:
        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
            self.pp_tp_workers.append([])
            for tp_rank in range(
                    self.parallel_config.tensor_parallel_size):
                # PP=2, TP=4
                # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
                rank = (pp_rank * self.parallel_config.tensor_parallel_size
                        ) + tp_rank
                assert len(self.pp_tp_workers[pp_rank]) == tp_rank
                assert pp_rank < len(self.pp_tp_workers)
                self.pp_tp_workers[pp_rank].append(self.workers[rank])

    # This is the list of workers that are rank 0 of each TP group EXCEPT
    # global rank 0. These are the workers that will broadcast to the
    # rest of the workers.
    self.tp_driver_workers: List[RayWorkerWrapper] = []
    # This is the list of workers that are not drivers and not the first
    # worker in a TP group. These are the workers that will be
    # broadcasted to.
    self.non_driver_workers: List[RayWorkerWrapper] = []

    # Enforce rank order for correct rank to return final output.
    for index, worker in enumerate(self.workers):
        # The driver worker is rank 0 and not in self.workers.
        # NOTE(review): in SPMD mode self.workers starts at rank 0, so this
        # offset looks off by one there - confirm these two lists are unused
        # in that mode.
        rank = index + 1
        if rank % self.parallel_config.tensor_parallel_size == 0:
            self.tp_driver_workers.append(worker)
        else:
            self.non_driver_workers.append(worker)
|
| 390 |
+
|
| 391 |
+
def _driver_execute_model(
    self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
    """Invoke ``execute_model`` on the driver worker.

    A ``None`` request makes the driver stop the model execution loop that
    runs in each of the remote workers.

    Args:
        execute_model_req: The request to execute, or ``None`` (see above).

    Returns:
        Whatever the driver worker's ``execute_method`` call returns.
    """
    # There is no driver worker to call in SPMD mode.
    assert not self.use_ray_spmd_worker, (
        "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
    driver = self.driver_worker
    return driver.execute_method("execute_model", execute_model_req)
|
| 403 |
+
|
| 404 |
+
def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    """Run one model step, through the compiled Ray DAG in SPMD mode.

    Outside SPMD mode the call is delegated to the parent class. In SPMD
    mode the DAG is compiled on first use and cached in ``self.forward_dag``.

    Args:
        execute_model_req: The request to execute.

    Returns:
        The first output of the DAG; decoded with ``self.output_decoder``
        unless ``self.use_v1`` is set.
    """
    if not self.use_ray_spmd_worker:
        return super().execute_model(execute_model_req)

    if self.forward_dag is None:
        self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

    # With use_v1 the request is handed to the DAG unchanged; otherwise it
    # is encoded with self.input_encoder first.
    dag_input = (execute_model_req if self.use_v1 else
                 self.input_encoder.encode(execute_model_req))
    dag_outputs = ray.get(self.forward_dag.execute(dag_input))

    first_output = dag_outputs[0]
    if self.use_v1:
        return first_output
    return self.output_decoder.decode(first_output)
|
| 423 |
+
|
| 424 |
+
def _run_workers(
    self,
    method: Union[str, Callable],
    *args,
    async_run_tensor_parallel_workers_only: bool = False,
    max_concurrent_workers: Optional[int] = None,
    **kwargs,
) -> Any:
    """Run ``method`` on the workers and collect the results.

    Args:
        method: Method name, or a callable. A callable is serialized with
            ``cloudpickle`` before being sent to the workers.
        async_run_tensor_parallel_workers_only: If True, the method is run
            only on ``self.non_driver_workers`` (not on the driver worker)
            and the list of futures is returned without blocking on the
            results. Not supported in SPMD mode.
        max_concurrent_workers: Not supported yet; any truthy value raises
            ``NotImplementedError``.
        args/kwargs: Shared by all workers.

    Returns:
        The list of futures in the async case. Otherwise a list holding
        the driver worker's result first (omitted in SPMD mode), followed
        by the results of the Ray workers.
    """
    sent_method = (method
                   if isinstance(method, str) else cloudpickle.dumps(method))
    del method

    remote_only = async_run_tensor_parallel_workers_only
    if self.use_ray_spmd_worker:
        assert not remote_only, (
            "async_run_tensor_parallel_workers_only is not supported for "
            "spmd mode.")

    if max_concurrent_workers:
        raise NotImplementedError(
            "max_concurrent_workers is not supported yet.")

    # Kick off the Ray workers first.
    targets = self.non_driver_workers if remote_only else self.workers
    pending = [
        target.execute_method.remote(sent_method, *args, **kwargs)
        for target in targets
    ]
    if remote_only:
        # Just hand back the futures.
        return pending

    # In SPMD mode the driver worker is the same as any other worker, so
    # the driver is only executed explicitly for the non-SPMD worker class.
    # It is started after all the Ray workers.
    if self.use_ray_spmd_worker:
        driver_results = []
    else:
        driver_results = [
            self.driver_worker.execute_method(sent_method, *args, **kwargs)
        ]

    # Resolve the Ray futures (skipped when there are no Ray workers).
    remote_results = ray.get(pending) if self.workers else pending
    return driver_results + remote_results
|
| 484 |
+
|
| 485 |
+
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
    """Block until the given Ray futures have completed.

    Args:
        parallel_worker_tasks: Futures as returned by ``_run_workers()``
            when it is called with
            ``async_run_tensor_parallel_workers_only=True``.
    """
    ray.get(parallel_worker_tasks)
|
| 489 |
+
|
| 490 |
+
def _check_ray_adag_installation(self):
    """Verify that the prerequisites of the Ray compiled DAG are installed.

    Raises:
        ValueError: If the installed Ray is older than 2.40, if
            ``ray.experimental.compiled_dag_ref`` cannot be found, or if
            ``cupy`` cannot be found while
            ``VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL`` is set.
    """
    # importlib.metadata is used instead of the deprecated pkg_resources,
    # which additionally requires setuptools to be installed at runtime.
    import importlib.metadata
    import importlib.util

    from packaging import version

    required_version = version.parse("2.40")
    current_version = version.parse(importlib.metadata.version("ray"))
    if current_version < required_version:
        raise ValueError(f"Ray version {required_version} is "
                         f"required, but found {current_version}")

    adag_spec = importlib.util.find_spec(
        "ray.experimental.compiled_dag_ref")
    if adag_spec is None:
        raise ValueError("Ray accelerated DAG is not installed. "
                         "Run `pip install ray[adag]` to install it.")

    # cupy is only needed when the NCCL channel is requested.
    cupy_spec = importlib.util.find_spec("cupy")
    if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
        raise ValueError(
            "cupy is not installed but required since "
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. "
            "Run `pip install ray[adag]` and check cupy installation.")
|
| 514 |
+
|
| 515 |
+
def _compiled_ray_dag(self, enable_asyncio: bool):
    """Build and compile the Ray DAG that chains the pipeline stages.

    Every worker of the first group in ``self.pp_tp_workers`` is bound to
    the DAG input; each worker of a later group is bound to the output of
    the worker with the same index in the previous group. The outputs of
    the last group are collected in a ``MultiOutputNode``.

    Args:
        enable_asyncio: Forwarded to ``experimental_compile``.

    Returns:
        The result of ``experimental_compile`` on the assembled DAG.
    """
    assert self.parallel_config.use_ray
    self._check_ray_adag_installation()
    from ray.dag import InputNode, MultiOutputNode
    from ray.experimental.channel.torch_tensor_type import TorchTensorType

    logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
                envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
    logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
                envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
    with InputNode() as input_data:
        # Example DAG: PP=2, TP=4
        # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput  # noqa: E501
        #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput  # noqa: E501
        #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput  # noqa: E501
        #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput  # noqa: E501

        # The whole first TP group is fed the ExecuteModelRequest.
        outputs = [input_data for _ in self.pp_tp_workers[0]]
        final_stage = len(self.pp_tp_workers) - 1
        for pp_rank, tp_group in enumerate(self.pp_tp_workers):
            # A stage consumes the output of the previous stage, and the
            # TP group executes in SPMD fashion.
            bound_method = ("execute_model"
                            if self.use_v1 else "execute_model_spmd")
            outputs = [
                getattr(worker, bound_method).bind(outputs[i])
                for i, worker in enumerate(tp_group)
            ]

            if pp_rank < final_stage:
                # Specify how intermediate tensors should be passed
                # between pp stages; not needed for the last pp stage.
                if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL:
                    transport = "nccl"
                else:
                    transport = "auto"
                outputs = [
                    node.with_type_hint(
                        TorchTensorType(transport=transport))
                    for node in outputs
                ]

        forward_dag = MultiOutputNode(outputs)

    return forward_dag.experimental_compile(
        enable_asyncio=enable_asyncio,
        _overlap_gpu_communication=envs.
        VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
|
| 571 |
+
|
| 572 |
+
def __del__(self):
    # Finalizer: delegate all cleanup to self.shutdown().
    self.shutdown()
|
| 574 |
+
|
| 575 |
+
async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    """Asynchronously run one model step.

    Outside SPMD mode the call is delegated to the parent class. In SPMD
    mode the request is encoded with ``self.input_encoder``, pushed through
    the compiled DAG (compiled with asyncio enabled on first use) and the
    first DAG output is decoded with ``self.output_decoder``.

    Args:
        execute_model_req: The request to execute.

    Returns:
        The decoded first output of the DAG.
    """
    if not self.use_ray_spmd_worker:
        return await super().execute_model_async(execute_model_req)

    if self.forward_dag is None:
        self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)

    encoded_request = self.input_encoder.encode(execute_model_req)
    dag_futures = await self.forward_dag.execute_async(encoded_request)
    encoded_output = await dag_futures[0]
    return self.output_decoder.decode(encoded_output)
|
| 588 |
+
|
| 589 |
+
async def _driver_execute_model_async(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
    """Run ``execute_model`` on the driver and on every TP driver worker.

    When ``self.tp_driver_workers`` is empty only ``self.driver_exec_method``
    is awaited. Otherwise one task per pipeline stage is started, each
    guarded by that stage's lock, and all of them are awaited together.

    Args:
        execute_model_req: The request to execute (may be ``None``).

    Returns:
        The result of the last task, i.e. of the last pipeline stage.
    """
    assert not self.use_ray_spmd_worker, (
        "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")

    if not self.tp_driver_workers:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    if self.pp_locks is None:
        # One lock per pipeline parallel stage, so that multiple virtual
        # engines can't execute on the same stage at the same time.
        # The locks are created here rather than in the constructor,
        # which uses a different asyncio loop.
        num_stages = self.parallel_config.pipeline_parallel_size
        self.pp_locks = [asyncio.Lock() for _ in range(num_stages)]

    # Stage 0 is the local driver; the TP driver workers follow as
    # stages 1, 2, ...
    tasks = [
        asyncio.create_task(
            _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
                                "execute_model", execute_model_req))
    ]
    tasks.extend(
        asyncio.create_task(
            _run_task_with_lock(stage_driver.execute_method.remote,
                                self.pp_locks[stage], "execute_model",
                                execute_model_req))
        for stage, stage_driver in enumerate(self.tp_driver_workers,
                                             start=1))

    stage_results = await asyncio.gather(*tasks)

    # Only the last PP stage has the final results.
    return stage_results[-1]
|
| 625 |
+
|
| 626 |
+
async def _start_worker_execution_loop(self):
    """Start ``start_worker_execution_loop`` on every non-driver worker.

    Returns:
        The gathered results of all the remote calls, in the order of
        ``self.non_driver_workers``.
    """
    assert not self.use_ray_spmd_worker, (
        "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
    loop_calls = []
    for remote_worker in self.non_driver_workers:
        loop_calls.append(
            remote_worker.execute_method.remote(
                "start_worker_execution_loop"))
    return await asyncio.gather(*loop_calls)
|
| 634 |
+
|
| 635 |
+
def check_health(self) -> None:
    """Health check hook; currently a no-op that always returns ``None``."""
    # Assume that the Ray workers are healthy.
    # TODO: check the health of the Ray workers
    return
|
.venv/lib/python3.11/site-packages/vllm/executor/ray_utils.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import msgspec
|
| 9 |
+
|
| 10 |
+
from vllm.config import ParallelConfig
|
| 11 |
+
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.platforms import current_platform
|
| 14 |
+
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
| 15 |
+
from vllm.utils import get_ip
|
| 16 |
+
from vllm.worker.worker_base import WorkerWrapperBase
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from vllm.v1.core.scheduler import SchedulerOutput
|
| 20 |
+
from vllm.v1.outputs import ModelRunnerOutput
|
| 21 |
+
|
| 22 |
+
logger = init_logger(__name__)
|
| 23 |
+
PG_WAIT_TIMEOUT = 1800
|
| 24 |
+
|
| 25 |
+
try:
    import ray
    from ray.util import placement_group_table
    from ray.util.placement_group import PlacementGroup
    try:
        from ray._private.state import available_resources_per_node
    except ImportError:
        # Ray 2.9.x doesn't expose `available_resources_per_node`
        from ray._private.state import state as _state
        available_resources_per_node = _state._available_resources_per_node

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a different
            # thread, which has to call cuda.set_device itself. This flag
            # records whether set_device has been called on that thread.
            self.compiled_dag_cuda_device_set = False

            self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                                         dec_hook=decode_hook)
            self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

        def get_node_ip(self) -> str:
            """Return the IP address reported by ``get_ip()``."""
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            """Return the Ray node id and the accelerator ids of this worker.

            Raises:
                RuntimeError: If the current platform defines no
                    ``ray_device_key``.
            """
            node_id = ray.get_runtime_context().get_node_id()
            device_key = current_platform.ray_device_key
            if not device_key:
                # Exceptions do not apply %-formatting to their arguments,
                # so the message has to be interpolated explicitly.
                raise RuntimeError(
                    f"current platform {current_platform.device_name} "
                    "does not support ray.")
            gpu_ids = ray.get_runtime_context().get_accelerator_ids(
            )[device_key]
            return node_id, gpu_ids

        def execute_model_spmd(
            self, req_or_tuple: Union[bytes,
                                      Tuple[bytes,
                                            Optional[IntermediateTensors]]]
        ) -> bytes:
            """Execute model in SPMD fashion: used only when SPMD worker and
            compiled DAG are both enabled.

            Args:
                req_or_tuple: A request, or a tuple containing the request
                    and intermediate tensors. The intermediate tensors are
                    None unless they are provided because this is a > 0
                    pipeline stage. The request is serialized by msgspec.

            Returns:
                The msgspec-encoded output, or a ``(serialized request,
                IntermediateTensors)`` tuple when the worker produced
                intermediate tensors for the next pipeline stage.
            """
            if isinstance(req_or_tuple, bytes):
                serialized_req, intermediate_tensors = req_or_tuple, None
            else:
                serialized_req, intermediate_tensors = req_or_tuple

            execute_model_req = self.input_decoder.decode(serialized_req)

            # Ray aDAG executes on a background thread, so torch's current
            # device has to be set there before running the model.
            self.setup_device_if_necessary()

            output = self.worker._execute_model_spmd(execute_model_req,
                                                     intermediate_tensors)
            # Pipeline model request and output to the next pipeline stage.
            if isinstance(output, IntermediateTensors):
                output = serialized_req, output
            else:
                output = self.output_encoder.encode(output)

            return output

        def setup_device_if_necessary(self):
            """Set torch's CUDA device once for the calling thread.

            The first call runs ``torch.cuda.set_device`` with the worker's
            device; ``compiled_dag_cuda_device_set`` turns later calls into
            no-ops.
            """
            # TODO(swang): This is needed right now because Ray CG executes
            # on a background thread, so we need to reset torch's current
            # device.
            # We can remove this API after it is fixed in compiled graph.
            import torch
            assert self.worker is not None, "Worker is not initialized"
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

        def execute_model(
            self,
            scheduler_output: "SchedulerOutput",
        ) -> "ModelRunnerOutput":
            """Run the worker's model runner on ``scheduler_output``."""
            self.setup_device_if_necessary()
            assert self.worker is not None, "Worker is not initialized"
            output = self.worker.model_runner.execute_model(scheduler_output)
            return output

        def override_env_vars(self, vars: Dict[str, str]):
            """Merge ``vars`` into this process's ``os.environ``."""
            os.environ.update(vars)

    ray_import_err = None

except ImportError as e:
    # Ray is optional: remember the failure so it can be reported later.
    ray = None  # type: ignore
    ray_import_err = e
    RayWorkerWrapper = None  # type: ignore
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def ray_is_available() -> bool:
    """Returns True if Ray is available.

    The module-level ``ray`` name is set to ``None`` when importing Ray
    failed, so this is a plain ``None`` check.
    """
    return ray is not None
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def assert_ray_available():
    """Raise a ``ValueError`` if Ray could not be imported.

    The error is chained to the import failure recorded in
    ``ray_import_err``.
    """
    if ray is not None:
        return
    raise ValueError("Failed to import Ray, please install Ray with "
                     "`pip install ray`.") from ray_import_err
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _verify_bundles(placement_group: "PlacementGroup",
                    parallel_config: ParallelConfig, device_str: str):
    """Check that the bundles of a placement group are placed sensibly.

    Two rules are applied:
    - Fail if the driver node holds none of the placement group's bundles.
    - Warn for every node that holds fewer bundles than
      ``tensor_parallel_size``.

    Args:
        placement_group: The placement group to inspect.
        parallel_config: Supplies ``tensor_parallel_size``.
        device_str: Device name used in the warning message.

    Raises:
        RuntimeError: If the driver node is not part of the placement group.
    """
    assert ray.is_initialized(), (
        "Ray is not initialized although distributed-executor-backend is ray.")
    pg_data = placement_group_table(placement_group)
    # bundle_idx -> node_id
    bundle_to_node_ids = pg_data["bundles_to_node_id"]
    # bundle_idx -> bundle (e.g., {"GPU": 1})
    bundles = pg_data["bundles"]

    # Group the bundles by the node they live on.
    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
    for bundle_idx, node_id in bundle_to_node_ids.items():
        node_id_to_bundle[node_id].append(bundles[bundle_idx])

    driver_node_id = ray.get_runtime_context().get_node_id()
    if driver_node_id not in node_id_to_bundle:
        raise RuntimeError(
            f"driver node id {driver_node_id} is not included in a placement "
            f"group {placement_group.id}. Node id -> bundles "
            f"{node_id_to_bundle}. "
            "You don't have enough GPUs available in a current node. Check "
            "`ray status` to see if you have available GPUs in a node "
            f"{driver_node_id} before starting an vLLM engine.")

    for node_id, node_bundles in node_id_to_bundle.items():
        if len(node_bundles) >= parallel_config.tensor_parallel_size:
            continue
        logger.warning(
            "tensor_parallel_size=%d "
            "is bigger than a reserved number of %ss (%d "
            "%ss) in a node %s. Tensor parallel workers can be "
            "spread out to 2+ nodes which can degrade the performance "
            "unless you have fast interconnect across nodes, like "
            "Infiniband. To resolve this issue, make sure you have more "
            "than %d GPUs available at each node.",
            parallel_config.tensor_parallel_size, device_str,
            len(node_bundles), device_str, node_id,
            parallel_config.tensor_parallel_size)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
    """Block until a placement group is ready.

    While waiting, an informative message is logged after every
    unsuccessful wait; the wait timeout starts at 10 seconds and doubles
    each time.

    Raises:
        ValueError: If the placement group is still not ready once the
            waiting loop has ended.
    """
    # Waiting on the PG blocks until all requested resources are
    # available; it times out if they cannot be provisioned.
    placement_group_specs = current_placement_group.bundle_specs

    started_at = time.time()
    pg_ready_ref = current_placement_group.ready()
    poll_timeout = 10
    while time.time() - started_at < PG_WAIT_TIMEOUT:
        finished, _ = ray.wait([pg_ready_ref], timeout=poll_timeout)
        if len(finished) > 0:
            break

        # Exponential backoff for warning print.
        poll_timeout *= 2
        logger.info(
            "Waiting for creating a placement group of specs for "
            "%d seconds. specs=%s. Check "
            "`ray status` to see if you have enough resources,"
            " and make sure the IP addresses used by ray cluster"
            " are the same as VLLM_HOST_IP environment variable"
            " specified in each node if you are running on a multi-node.",
            int(time.time() - started_at), placement_group_specs)

    # A zero-timeout get distinguishes "ready" from "loop timed out".
    try:
        ray.get(pg_ready_ref, timeout=0)
    except ray.exceptions.GetTimeoutError:
        raise ValueError(
            "Cannot provide a placement group of "
            f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
            "`ray status` to make sure the cluster has enough resources."
        ) from None
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
    """Remove a placement group and wait for the removal to take effect.

    After requesting the removal, the function polls
    ``ray.util.get_current_placement_group()`` until it returns ``None`` or
    ``PG_WAIT_TIMEOUT`` seconds have elapsed. The sleep between polls
    doubles on every iteration (20s, 40s, ...). Nothing is raised when the
    loop ends without observing ``None``.
    """
    # NOTE(review): the poll looks at the *current* placement group, not at
    # the one passed in - confirm this is the intended check.
    ray.util.remove_placement_group(current_placement_group)
    started_at = time.time()
    sleep_seconds = 10
    while time.time() - started_at < PG_WAIT_TIMEOUT:
        if ray.util.get_current_placement_group() is None:
            break

        # Exponential backoff for warning print.
        sleep_seconds *= 2
        logger.info(
            "Waiting for removing a placement group of specs for "
            "%d seconds.", int(time.time() - started_at))
        time.sleep(sleep_seconds)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    It connects to the Ray cluster and, unless ``parallel_config`` already
    carries a placement group, makes sure a placement group for the workers
    exists: the current placement group is validated and reused if there is
    one, otherwise a new one with one device per worker is created. The
    resulting group is stored in ``parallel_config.placement_group``.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Raises:
        ValueError: If Ray is not importable, if the platform has no Ray
            device key, if the current placement group has a bundle with
            more than one device or too few device bundles, or if the
            current node has no device available.
    """
    assert_ray_available()
    from vllm.platforms import current_platform

    # Connect to a ray cluster.
    if current_platform.is_rocm() or current_platform.is_xpu():
        # Try to connect existing ray instance and create a new one if not found
        try:
            ray.init("auto", ignore_reinit_error=True)
        except ConnectionError:
            logger.warning(
                "No existing RAY instance detected. "
                "A new instance will be launched with current node resources.")
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    device_str = current_platform.ray_device_key
    if not device_str:
        raise ValueError(
            f"current platform {current_platform.device_name} does not "
            "support ray.")

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        device_bundles = 0
        for bundle in bundles:
            bundle_devices = bundle.get(device_str, 0)
            if bundle_devices > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 "
                    f"{device_str}.")
            if bundle_devices:
                device_bundles += 1
        if parallel_config.world_size > device_bundles:
            raise ValueError(
                f"The number of required {device_str}s exceeds the total "
                f"number of available {device_str}s in the placement group. "
                f"Required number of devices: {parallel_config.world_size}. "
                f"Total number of devices: {device_bundles}.")
    else:
        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
        # Log a warning message and delay resource allocation failure response.
        # Avoid immediate rejection to allow user-initiated placement group
        # created and wait cluster to be ready
        if parallel_config.world_size > num_devices_in_cluster:
            # The count comes from the cluster resources, so the message
            # refers to the cluster rather than to a placement group.
            logger.warning(
                "The number of required %ss exceeds the total "
                "number of available %ss in the cluster.", device_str,
                device_str)
        # Create a new placement group
        placement_group_specs: List[Dict[str, float]] = ([{
            device_str: 1.0
        } for _ in range(parallel_config.world_size)])

        # vLLM engine is also a worker to execute model with an accelerator,
        # so it requires to have the device in a current node. Check if
        # the current node has at least one device.
        current_ip = get_ip()
        current_node_id = ray.get_runtime_context().get_node_id()
        current_node_resource = available_resources_per_node()[current_node_id]
        if current_node_resource.get(device_str, 0) < 1:
            raise ValueError(
                f"Current node has no {device_str} available. "
                f"{current_node_resource=}. vLLM engine cannot start without "
                f"{device_str}. Make sure you have at least 1 {device_str} "
                f"available in a node {current_node_id=} {current_ip=}.")
        # This way, at least one bundle is required to be created in the
        # current node.
        placement_group_specs[0][f"node:{current_ip}"] = 0.001

        # By default, Ray packs resources as much as possible.
        current_placement_group = ray.util.placement_group(
            placement_group_specs, strategy="PACK")
        _wait_until_pg_ready(current_placement_group)

    assert current_placement_group is not None
    _verify_bundles(current_placement_group, parallel_config, device_str)
    # Set the placement group in the parallel config
    parallel_config.placement_group = current_placement_group
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def get_num_tpu_nodes() -> int:
    """Return the cluster's TPU count divided by the TPUs of this node.

    The cluster-wide ``"TPU"`` resource is divided by the accelerator count
    reported for the current node; the division is asserted to be exact.
    """
    from ray._private.accelerators import TPUAcceleratorManager

    total_tpus = int(ray.cluster_resources()["TPU"])
    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
    num_nodes, leftover = divmod(total_tpus, tpus_per_node)
    assert leftover == 0
    return num_nodes
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def get_num_nodes_in_placement_group() -> int:
    """Count the distinct nodes that host the current placement group.

    Returns:
        The number of distinct node ids in the ``bundles_to_node_id``
        mapping of the current placement group's table entry, or 0 when
        there is no current placement group.
    """
    pg_table = ray.util.placement_group_table()
    current_pg = ray.util.get_current_placement_group()
    if not current_pg:
        return 0

    current_pg_id = current_pg.id.hex()
    nodes_in_pg = {
        node
        for pg_key, pg in pg_table.items() if pg_key == current_pg_id
        for node in pg["bundles_to_node_id"].values()
    }
    return len(nodes_in_pg)
|
.venv/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
|
| 9 |
+
import vllm.envs as envs
|
| 10 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 11 |
+
from vllm.logger import init_logger
|
| 12 |
+
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
| 13 |
+
run_method)
|
| 14 |
+
from vllm.worker.worker_base import WorkerWrapperBase
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class UniProcExecutor(ExecutorBase):
    """Executor that runs every RPC on a single ``WorkerWrapperBase``."""

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Create the worker wrapper, then initialize the worker, its
        device and load the model.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        # The only worker is both local rank 0 and global rank 0.
        local_rank = 0
        rank = 0
        is_driver_worker = (not self.parallel_config) or (
            rank % self.parallel_config.tensor_parallel_size == 0)
        worker_kwargs = {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
        }
        for step, step_args in (("init_worker", ([worker_kwargs], )),
                                ("init_device", ()),
                                ("load_model", ())):
            self.collective_rpc(step, args=step_args)

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Run ``method`` on the worker and return its result in a list.

        Args:
            method: Method name or callable, handed to ``run_method``.
            timeout: Not used by this implementation.
            args: Positional arguments for the method.
            kwargs: Keyword arguments for the method; ``None`` means none.

        Returns:
            A one-element list holding the worker's result.
        """
        call_kwargs = {} if kwargs is None else kwargs
        return [run_method(self.driver_worker, method, args, call_kwargs)]

    def check_health(self) -> None:
        """Health check hook; a no-op."""
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Alias: the "async" name refers to the very same class object.
UniProcExecutorAsync = UniProcExecutor
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ExecutorWithExternalLauncher(UniProcExecutor):
    """An executor that uses external launchers to launch engines,
    specially designed for torchrun-compatible launchers, for
    offline inference with tensor parallelism.

    see https://github.com/vllm-project/vllm/issues/11400 for
    the motivation, and examples/offline_inference/torchrun_example.py
    for the usage example.

    The key idea: although it is tensor-parallel inference, we only
    create one worker per executor, users will launch multiple
    engines with torchrun-compatible launchers, and all these engines
    work together to process the same prompts. When scheduling is
    deterministic, all the engines will generate the same outputs,
    and they don't need to synchronize the states with each other.
    """
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Reads the process rank from the environment populated by the
        external launcher and initialises the worker with the ``env://``
        distributed init method.

        Raises:
            AssertionError: if pipeline parallelism, a non-zero scheduler
                ``delay_factor`` or the V1 architecture is configured.
            KeyError: if the ``RANK`` environment variable is missing.
        """
        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
            ("ExecutorWithExternalLauncher does not "
            "support pipeline parallelism.")
        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
            ("ExecutorWithExternalLauncher needs deterministic "
            "execution, so it "
            "does not support delay_factor in scheduling")
        assert not envs.VLLM_USE_V1, \
            ("V1 architecture cannot guarantee deterministic execution, "
            "so it is not supported in ExecutorWithExternalLauncher.")
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        # engines are launched in torchrun-compatible launchers
        # so we can use the env:// method.
        # required env vars:
        # - RANK
        # - MASTER_ADDR
        # - MASTER_PORT
        # optional env var:
        # - LOCAL_RANK (set by torchrun; index of this process on its node)
        distributed_init_method = "env://"
        rank = int(os.environ["RANK"])
        # On multi-node launches the global rank is not a valid per-node
        # index, so prefer LOCAL_RANK. Fall back to RANK (the previous
        # behaviour) for launchers that do not export LOCAL_RANK.
        local_rank = int(os.environ.get("LOCAL_RANK", rank))
        # Every engine process drives its own single worker.
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.
        Add an additional all_reduce to get the min across all ranks.
        Note that even if we have the same `gpu_memory_utilization` and
        `swap_space`, the available memory in every rank might still
        differ because NCCL can take different amounts of memory in
        different ranks. Therefore, it is necessary to test if all ranks
        agree on the same KV cache configuration.

        Returns:
            The two block counts reported by the parent class, each
            replaced by its minimum over all ranks of the world group.
        """
        first, second = super().determine_num_available_blocks()
        # Imported lazily, as in the original implementation.
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        first_tensor = torch.tensor([first], device="cpu", dtype=torch.int64)
        second_tensor = torch.tensor([second],
                                     device="cpu",
                                     dtype=torch.int64)
        # MIN-reduce so that every rank ends up with the same counts.
        dist.all_reduce(first_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        dist.all_reduce(second_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return first_tensor.item(), second_tensor.item()
|
.venv/lib/python3.11/site-packages/vllm/forward_context.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from contextlib import contextmanager
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
import vllm.envs as envs
|
| 12 |
+
from vllm.config import VllmConfig
|
| 13 |
+
from vllm.logger import init_logger
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from vllm.attention.backends.abstract import AttentionMetadata
|
| 17 |
+
|
| 18 |
+
logger = init_logger(__name__)
|
| 19 |
+
|
| 20 |
+
# Module-level state for the optional batch-size timing statistics.

# Minimum gap between two statistics log lines, taken from the environment.
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Timing is enabled whenever the configured interval is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
# perf_counter() readings: last statistics log / start of current forward.
last_logging_time: float = 0
forward_start_time: float = 0
# Maps a batch size to the list of measured forward times (milliseconds).
batchsize_forward_time: defaultdict = defaultdict(list)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class ForwardContext:
    """Per-forward-pass state published through `get_forward_context`."""

    # Copied from vllm_config.compilation_config.static_forward_context.
    attn_layers: Dict[str, Any]

    # Set dynamically for each forward pass.
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"

    # Set dynamically for each forward pass.
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Currently active ForwardContext; None while no `set_forward_context`
# block is running.
_forward_context: Optional[ForwardContext] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_forward_context() -> ForwardContext:
    """Return the active forward context.

    Raises:
        AssertionError: if no forward context is currently set.
    """
    active = _forward_context
    assert active is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return active
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
                        virtual_engine: int = 0):
    """Install a `ForwardContext` for the duration of the ``with`` block.

    The previously active context is saved on entry and put back at the
    end of the ``finally`` clause. When batch-size tracking is enabled and
    ``attn_metadata`` is not None, the wall-clock time of the block is
    recorded per batch size and median statistics are logged periodically.
    """
    global _forward_context, forward_start_time
    global last_logging_time, batchsize_logging_interval

    timing_enabled = track_batchsize and attn_metadata is not None
    if timing_enabled:
        forward_start_time = time.perf_counter()

    outer_context = _forward_context
    _forward_context = ForwardContext(
        attn_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata)
    try:
        yield
    finally:
        if timing_enabled:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = (attn_metadata.num_prefill_tokens +
                             attn_metadata.num_decode_tokens)
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            torch.cuda.synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            elapsed_ms = (now - forward_start_time) * 1000
            batchsize_forward_time[batchsize].append(elapsed_ms)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                # Batch sizes seen at most once are skipped: they can be
                # cudagraph / profiling runs.
                forward_stats = [
                    (bs, len(samples),
                     round(
                         torch.quantile(torch.tensor(samples), q=0.5).item(),
                         2))
                    for bs, samples in batchsize_forward_time.items()
                    if len(samples) > 1
                ]
                # Most frequently observed batch sizes first.
                forward_stats.sort(key=lambda stat: stat[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)
        _forward_context = outer_context
|
.venv/lib/python3.11/site-packages/vllm/logger.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Logging configuration for vLLM."""
|
| 3 |
+
import datetime
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from functools import lru_cache, partial
|
| 9 |
+
from logging import Logger
|
| 10 |
+
from logging.config import dictConfig
|
| 11 |
+
from os import path
|
| 12 |
+
from types import MethodType
|
| 13 |
+
from typing import Any, Optional, cast
|
| 14 |
+
|
| 15 |
+
import vllm.envs as envs
|
| 16 |
+
|
| 17 |
+
# Logging-related settings, read from vllm.envs once at import time.
VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX

# Timestamp layout and record layout used by the "vllm" formatter.
_DATE_FORMAT = "%m-%d %H:%M:%S"
_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
           "%(filename)s:%(lineno)d] %(message)s")

# dictConfig schema applied when logging is configured and no custom
# config file is given: one stdout handler attached to the "vllm" logger.
DEFAULT_LOGGING_CONFIG = {
    "formatters": {
        "vllm": {
            "class": "vllm.logging_utils.NewLineFormatter",
            "datefmt": _DATE_FORMAT,
            "format": _FORMAT,
        },
    },
    "handlers": {
        "vllm": {
            "class": "logging.StreamHandler",
            "formatter": "vllm",
            # The handler, not the logger, applies the configured level.
            "level": VLLM_LOGGING_LEVEL,
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "vllm": {
            "handlers": ["vllm"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
    "version": 1,
    "disable_existing_loggers": False,
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@lru_cache
def _print_info_once(logger: Logger, msg: str) -> None:
    """Emit ``msg`` at INFO level; cached repeats of the call do nothing.

    The ``lru_cache`` wrapper is keyed on ``(logger, msg)``, so the body
    only runs when that pair is not already in the cache.
    """
    # stacklevel=2 makes the record carry the original caller's line info.
    logger.info(msg, stacklevel=2)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@lru_cache
def _print_warning_once(logger: Logger, msg: str) -> None:
    """Emit ``msg`` at WARNING level; cached repeats of the call do nothing.

    The ``lru_cache`` wrapper is keyed on ``(logger, msg)``, so the body
    only runs when that pair is not already in the cache.
    """
    # stacklevel=2 makes the record carry the original caller's line info.
    logger.warning(msg, stacklevel=2)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class _VllmLogger(Logger):
    """Typing-only view of the loggers returned by :func:`init_logger`.

    Note:
        This class is just to provide type information.
        We actually patch the methods directly on the :class:`logging.Logger`
        instance to avoid conflicting with other libraries such as
        `intel_extension_for_pytorch.utils._logger`.
    """

    def info_once(self, msg: str) -> None:
        """
        Like :meth:`info`, except that a message which was already
        emitted through this method is silently dropped.
        """
        _print_info_once(self, msg)

    def warning_once(self, msg: str) -> None:
        """
        Like :meth:`warning`, except that a message which was already
        emitted through this method is silently dropped.
        """
        _print_warning_once(self, msg)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _configure_vllm_root_logger() -> None:
    """Apply the logging configuration selected by the environment.

    Starts from ``DEFAULT_LOGGING_CONFIG`` when ``VLLM_CONFIGURE_LOGGING``
    is truthy, replaces it with the JSON document at
    ``VLLM_LOGGING_CONFIG_PATH`` when that is set, and hands the result to
    ``dictConfig``. Nothing is configured when the resulting config is
    empty.

    Raises:
        RuntimeError: if a config path is given while logging
            configuration is disabled, or the config file does not exist.
        ValueError: if the config file does not contain a JSON object.
    """
    logging_config = dict[str, Any]()

    if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH:
        raise RuntimeError(
            "VLLM_CONFIGURE_LOGGING evaluated to false, but "
            "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH "
            "implies VLLM_CONFIGURE_LOGGING. Please enable "
            "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.")

    if VLLM_CONFIGURE_LOGGING:
        logging_config = DEFAULT_LOGGING_CONFIG

    if VLLM_LOGGING_CONFIG_PATH:
        if not path.exists(VLLM_LOGGING_CONFIG_PATH):
            # Exceptions do not %-interpolate extra arguments the way
            # logging calls do, so the message is formatted explicitly.
            raise RuntimeError(
                "Could not load logging config. File does not exist: "
                f"{VLLM_LOGGING_CONFIG_PATH}")
        with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.load(file)

        if not isinstance(custom_config, dict):
            raise ValueError(
                "Invalid logging config. Expected Dict, got "
                f"{type(custom_config).__name__}.")
        logging_config = custom_config

    for formatter in logging_config.get("formatters", {}).values():
        # This provides backwards compatibility after #10134.
        if formatter.get("class") == "vllm.logging.NewLineFormatter":
            formatter["class"] = "vllm.logging_utils.NewLineFormatter"

    if logging_config:
        dictConfig(logging_config)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def init_logger(name: str) -> _VllmLogger:
    """Return the logger called ``name`` with the ``*_once`` helpers bound.

    The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root vllm logger has
    already been configured.
    """
    named_logger = logging.getLogger(name)

    # Bind the helpers onto this instance only; the Logger class itself
    # is left untouched.
    once_helpers = (
        ("info_once", _print_info_once),
        ("warning_once", _print_warning_once),
    )
    for attr_name, helper in once_helpers:
        setattr(named_logger, attr_name, MethodType(helper, named_logger))

    return cast(_VllmLogger, named_logger)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# The root logger is initialized when the module is imported.
|
| 144 |
+
# This is thread-safe as the module is only imported once,
|
| 145 |
+
# guaranteed by Python's import system (module cache and import locks).
|
| 146 |
+
_configure_vllm_root_logger()
|
| 147 |
+
|
| 148 |
+
logger = init_logger(__name__)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _trace_calls(log_path, root_dir, frame, event, arg=None):
    """Trace function that appends call/return events to ``log_path``.

    Only 'call' and 'return' events are logged, and only for code whose
    file name starts with ``root_dir``. For frames outside ``root_dir``
    the function returns None; otherwise it returns a new partial of
    itself bound to the same ``log_path`` and ``root_dir``.
    """
    if event not in ['call', 'return']:
        return partial(_trace_calls, log_path, root_dir)

    # Extract the filename, line number and function name of the frame
    code = frame.f_code
    filename = code.co_filename
    lineno = frame.f_lineno
    func_name = code.co_name
    if not filename.startswith(root_dir):
        # only log the functions in the vllm root_dir
        return

    # Log every function call or return
    try:
        caller = frame.f_back
        if caller is None:
            # initial frame
            caller_file, caller_line, caller_func = "", 0, ""
        else:
            caller_file = caller.f_code.co_filename
            caller_line = caller.f_lineno
            caller_func = caller.f_code.co_name
        with open(log_path, 'a') as f:
            ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
            if event == 'call':
                entry = (f"{ts} Call to"
                         f" {func_name} in {filename}:{lineno}"
                         f" from {caller_func} in {caller_file}:"
                         f"{caller_line}\n")
            else:
                entry = (f"{ts} Return from"
                         f" {func_name} in {filename}:{lineno}"
                         f" to {caller_func} in {caller_file}:"
                         f"{caller_line}\n")
            f.write(entry)
    except NameError:
        # modules are deleted during shutdown
        pass
    return partial(_trace_calls, log_path, root_dir)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def enable_trace_function_call(log_file_path: str,
                               root_dir: Optional[str] = None):
    """
    Install a trace function that records every call and return made by
    code located under `root_dir`. Useful for debugging hangs or crashes.

    `log_file_path` is the file the trace records are appended to.
    `root_dir` is the root directory of the code to trace. If None, the
    parent of the directory containing this file is used (the vllm root
    directory).

    Note that this call is thread-level, any threads calling this function
    will have the trace enabled. Other threads will not be affected.
    """
    logger.warning(
        "VLLM_TRACE_FUNCTION is enabled. It will record every"
        " function executed by Python. This will slow down the code. It "
        "is suggested to be used for debugging hang or crashes only.")
    logger.info("Trace frame log is saved to %s", log_file_path)

    if root_dir is None:
        # by default, this is the vllm root directory
        package_dir = os.path.dirname(__file__)
        root_dir = os.path.dirname(package_dir)

    tracer = partial(_trace_calls, log_file_path, root_dir)
    sys.settrace(tracer)
|
.venv/lib/python3.11/site-packages/vllm/logits_process.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import Callable, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
| 8 |
+
|
| 9 |
+
# Two accepted call shapes: (tokens, logits) and (tokens, tokens, logits).
LogitsProcessor = Union[
    Callable[[List[int], torch.Tensor], torch.Tensor],
    Callable[[List[int], List[int], torch.Tensor], torch.Tensor],
]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_bad_words_logits_processors(
        bad_words: List[str],
        tokenizer: AnyTokenizer) -> List[LogitsProcessor]:
    """Build the logits processor that bans the given words.

    Each word is tokenized without leading whitespace, and a second time
    with a single leading space, to prohibit words both at the beginning
    and in the middle of text (related to the add_prefix_space tokenizer
    parameter). The space-prefixed variant is kept only when it has the
    same number of tokens as the plain variant and starts with a
    different token.

    Words that tokenize to no tokens at all (for example an empty string)
    are skipped instead of raising ``IndexError``.

    Args:
        bad_words: Words that must not be generated.
        tokenizer: Tokenizer used to turn each word into token ids.

    Returns:
        A one-element list holding a ``NoBadWordsLogitsProcessor``.
    """

    def _encode(text: str) -> List[int]:
        # Tokenize `text` without special tokens.
        if isinstance(tokenizer, MistralTokenizer):
            # Mistral tokenizers should not add special tokens
            return tokenizer.encode(prompt=text)
        return tokenizer.encode(text=text, add_special_tokens=False)

    bad_words_ids: List[List[int]] = []

    for bad_word in bad_words:
        word = bad_word.lstrip()

        plain_ids = _encode(word)
        if len(plain_ids) == 0:
            # Nothing to ban; an empty entry would also break the
            # processor, which indexes the last token of every entry.
            continue
        bad_words_ids.append(plain_ids)

        prefixed_ids = _encode(" " + word)
        # Keep the prefixed form only if the prefix space produced a new
        # word token rather than an extra, separate token.
        if (len(prefixed_ids) == len(plain_ids)
                and prefixed_ids[0] != plain_ids[0]):
            bad_words_ids.append(prefixed_ids)

    return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class NoBadWordsLogitsProcessor:
    """Logits processor that sets the logits of banned tokens to -inf.

    Single-token bad words are banned unconditionally through a bias
    vector built on the first call. For multi-token bad words only the
    final token is banned, and only when the most recent past tokens
    equal the word's leading tokens.
    """
    _SMALLEST_LOGIT = float("-inf")
    _NEUTRAL_LOGIT = 0.0

    def __init__(self, bad_words_ids: List[List[int]]):
        self.bad_words_ids = bad_words_ids
        # Built lazily by `_init_word_bias` on the first call.
        self.word_bias: torch.FloatTensor = None

    def __call__(
        self,
        past_tokens_ids: Union[List[int], Tuple[int]],
        logits: torch.FloatTensor,
    ) -> torch.Tensor:
        """Return ``logits`` with the banned continuations suppressed."""
        if self.word_bias is None:
            self._init_word_bias(logits=logits)

        last_token_bias = torch.zeros_like(logits)
        history_len = len(past_tokens_ids)

        for bad_word_ids in self.bad_words_ids:
            word_len = len(bad_word_ids)
            # 1-token words are already covered by `word_bias`; words
            # longer than the history plus one token cannot complete now.
            if word_len == 1 or word_len > history_len + 1:
                continue

            prefix_length = word_len - 1
            last_token_id = bad_word_ids[-1]
            actual_prefix = past_tokens_ids[-prefix_length:]
            expected_prefix = bad_word_ids[:prefix_length]

            assert len(actual_prefix) == len(expected_prefix)

            if tuple(actual_prefix) == tuple(expected_prefix):
                bias = self._SMALLEST_LOGIT
            else:
                bias = self._NEUTRAL_LOGIT
            last_token_bias[last_token_id] += bias

        return logits + self.word_bias + last_token_bias

    def _init_word_bias(self, logits: torch.FloatTensor) -> None:
        """Validate the token ids and build the single-token bias vector."""
        # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor  # noqa: E501
        # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py

        vocab_size = logits.shape[-1]
        self._check_token_ids_bounds(vocab_size=vocab_size)

        self.word_bias = torch.zeros((vocab_size, ),
                                     dtype=torch.float,
                                     device=logits.device)

        for token_ids in self.bad_words_ids:
            if len(token_ids) == 1:
                self.word_bias[token_ids[-1]] = self._SMALLEST_LOGIT

    def _check_token_ids_bounds(self, vocab_size: int) -> None:
        """Raise ValueError if any bad-word token id is out of range."""
        invalid_token_ids = [
            token_id for token_ids in self.bad_words_ids
            for token_id in token_ids
            if token_id < 0 or token_id >= vocab_size
        ]

        if len(invalid_token_ids) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocab_size},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id < {vocab_size}.")
|
.venv/lib/python3.11/site-packages/vllm/outputs.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, Generic, List, MutableSequence, Optional
|
| 6 |
+
from typing import Sequence as GenericSequence
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from typing_extensions import TypeVar, deprecated
|
| 11 |
+
|
| 12 |
+
from vllm.lora.request import LoRARequest
|
| 13 |
+
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
| 14 |
+
from vllm.sampling_params import RequestOutputKind
|
| 15 |
+
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
|
| 16 |
+
SequenceGroup, SequenceGroupBase, SequenceStatus)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class CompletionOutput:
    """One generated completion belonging to a request.

    Args:
        index: The index of the output in the request.
        text: The generated output text.
        token_ids: The token IDs of the generated output text.
        cumulative_logprob: The cumulative log probability of the generated
            output text.
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        stop_reason: The stop string or token id that caused the completion
            to stop, None if the completion finished for some other reason
            including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finish_reason is not None

    def __repr__(self) -> str:
        # `lora_request` is not part of the textual representation.
        shown = (
            f"index={self.index}",
            f"text={self.text!r}",
            f"token_ids={self.token_ids}",
            f"cumulative_logprob={self.cumulative_logprob}",
            f"logprobs={self.logprobs}",
            f"finish_reason={self.finish_reason}",
            f"stop_reason={self.stop_reason}",
        )
        return "CompletionOutput(" + ", ".join(shown) + ")"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class PoolingOutput:
    """One pooling result belonging to a request.

    Args:
        data: The extracted hidden states.
    """
    data: torch.Tensor

    def __repr__(self) -> str:
        return f"PoolingOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        # Equal only to instances of the same class whose tensors compare
        # equal element-wise everywhere.
        if not isinstance(other, self.__class__):
            return False
        return bool((self.data == other.data).all())

    @property
    @deprecated("`LLM.encode()` now stores raw outputs in the `data` "
                "attribute. To return embeddings, use `LLM.embed()`. "
                "To return class probabilities, use `LLM.classify()` "
                "and access the `probs` attribute. ")
    def embedding(self) -> list[float]:
        """Deprecated accessor: ``data`` converted with ``tolist()``."""
        return self.data.tolist()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class RequestOutput:
    """The output data of a completion request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
                For encoder/decoder models, this is the
                decoder input prompt.
        prompt_token_ids: The token IDs of the prompt.
                          For encoder/decoder models, this is the
                          decoder input prompt token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the output.
        encoder_prompt: The encoder prompt string of the request.
                        None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
                                  None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        multi_modal_placeholders: Keyword-only; stored as given, or as an
                                  empty dict when omitted / falsy.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[List[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
        self.metrics = metrics
        self.lora_request = lora_request
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids
        self.num_cached_tokens = num_cached_tokens

    @classmethod
    def new(
        cls,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        text: str,
        token_ids: List[int],
        finished: bool = False,
    ) -> "RequestOutput":
        """Initialize a new RequestOutput object."""

        # TODO: Support `n` > 1.
        completion_output = CompletionOutput(
            index=0,
            text=text,
            token_ids=token_ids,
            cumulative_logprob=None,
            logprobs=None,  # TODO
        )

        return RequestOutput(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            prompt_logprobs=None,  # TODO
            outputs=[completion_output],
            finished=finished,
        )

    def add(self, next_output: "RequestOutput") -> None:
        """Merge subsequent RequestOutput into this one"""

        self.prompt = next_output.prompt
        self.prompt_token_ids = next_output.prompt_token_ids
        self.prompt_logprobs = next_output.prompt_logprobs
        self.finished |= next_output.finished

        #TODO assuming n == 1 for now
        completion = self.outputs[0]
        next_completion = next_output.outputs[0]
        completion.text += next_completion.text
        # token_ids may be an immutable sequence; make it extendable first.
        if not isinstance(completion.token_ids, MutableSequence):
            completion.token_ids = list(completion.token_ids)
        completion.token_ids.extend(next_completion.token_ids)
        if next_completion.logprobs:
            assert completion.logprobs is not None
            completion.logprobs.extend(next_completion.logprobs)
        completion.cumulative_logprob = next_completion.cumulative_logprob

    @classmethod
    def from_seq_group(
        cls, seq_group: SequenceGroup, use_cache: bool,
        seq_id_to_seq_group: Dict[str, SequenceGroupBase]
    ) -> Optional["RequestOutput"]:
        """Build a RequestOutput from the current state of `seq_group`.

        Args:
            seq_group: The sequence group to report on.
            use_cache: When true, reuse (and re-initialise in place) the
                RequestOutput / CompletionOutput objects cached on
                `seq_group` instead of allocating new ones.
            seq_id_to_seq_group: Maps request ids to a parent group; when
                `seq_group.request_id` is present, the output is built from
                the group assembled by that parent instead.

        Returns:
            The request output, or None when the parent group has nothing
            assembled yet, or when only the final output was requested and
            the group is not finished.

        Raises:
            ValueError: If `seq_group.sampling_params` is None.
        """
        finished = seq_group.is_finished()

        if seq_group.request_id in seq_id_to_seq_group:
            group: SequenceGroupBase = seq_id_to_seq_group[
                seq_group.request_id]
            assembled_seq_group = group.maybe_assemble_group(seq_group)
            if finished:
                group.finish_seq(seq_group)
            if assembled_seq_group is None:
                return None
            return cls.from_seq_group(assembled_seq_group, use_cache,
                                      seq_id_to_seq_group)

        sampling_params = seq_group.sampling_params
        if sampling_params is None:
            raise ValueError(
                "Sampling parameters are missing for a CompletionRequest.")

        if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                not finished):
            return None

        # Init cache (if needed)
        if use_cache and seq_group.cached_request_output is None:
            seq_group.cached_request_output = RequestOutput(  # type: ignore
                request_id="",
                prompt=None,
                prompt_token_ids=[],
                prompt_logprobs=None,
                outputs=[],
                finished=False)

        top_n_seqs = seq_group.get_seqs()

        # Create the outputs.
        # NOTE: We need omit logprobs here explicitly because the sequence
        # always has the logprobs of the sampled tokens even if the
        # logprobs are not requested.
        include_logprobs = sampling_params.logprobs is not None
        text_buffer_length = sampling_params.output_text_buffer_length
        delta = sampling_params.output_kind == RequestOutputKind.DELTA

        outputs = []
        include_prompt = True
        # num_cached_tokens should be the same for all the sequences
        num_cached_tokens = None
        for i, seq in enumerate(top_n_seqs):
            output_text = seq.get_output_text_to_return(
                text_buffer_length, delta)

            output_token_ids = seq.get_output_token_ids_to_return(delta)
            num_output_tokens = 1 if isinstance(output_token_ids,
                                                int) else len(output_token_ids)
            num_cached_tokens = seq.data.get_num_cached_tokens()

            output_logprobs = seq.output_logprobs if include_logprobs else None

            if delta:
                # Slice logprobs delta if applicable
                if output_logprobs:
                    if num_output_tokens > 0:
                        output_logprobs = output_logprobs[-num_output_tokens:]
                    else:
                        # `[-0:]` would return the *entire* history rather
                        # than an empty delta, so report no logprobs when
                        # this step produced no new tokens.
                        output_logprobs = None
                # Don't include prompt if this is after the first output
                # containing decode token ids
                if include_prompt and seq.get_output_len() > num_output_tokens:
                    include_prompt = False

            if use_cache:
                # Get cached output object
                cached_outputs = seq_group.cached_request_output.outputs  # type: ignore
                if i >= len(cached_outputs):
                    cached_outputs.append(
                        CompletionOutput(index=i,
                                         text="",
                                         token_ids=[],
                                         cumulative_logprob=None,
                                         logprobs=None,
                                         finish_reason=None,
                                         stop_reason=None))
                output = cached_outputs[i]

                # Init cached output object
                assert output.index == i
                output.text = output_text

                if isinstance(output_token_ids, int):
                    output.token_ids.clear()
                    output.token_ids.append(output_token_ids)
                else:
                    output.token_ids = output_token_ids

                output.cumulative_logprob = seq.get_cumulative_logprob() \
                    if include_logprobs else None
                output.logprobs = output_logprobs
                output.finish_reason = SequenceStatus.get_finished_reason(
                    seq.status)
                output.stop_reason = seq.stop_reason

            else:
                # The enumerate index is the sequence's position; no need
                # for a linear `top_n_seqs.index(seq)` lookup.
                output = CompletionOutput(
                    i, output_text, [output_token_ids]
                    if isinstance(output_token_ids, int) else output_token_ids,
                    seq.get_cumulative_logprob() if include_logprobs else None,
                    output_logprobs,
                    SequenceStatus.get_finished_reason(seq.status),
                    seq.stop_reason)

            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        if include_prompt:
            prompt = seq_group.prompt
            prompt_token_ids = seq_group.prompt_token_ids
            encoder_prompt = seq_group.encoder_prompt
            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
            prompt_logprobs = seq_group.prompt_logprobs
        else:
            prompt = None
            prompt_token_ids = None
            encoder_prompt = None
            encoder_prompt_token_ids = None
            prompt_logprobs = None
        finished_time = time.time() if finished else None
        seq_group.set_finished_time(finished_time)

        init_kwargs = {
            "request_id": seq_group.request_id,
            "prompt": prompt,
            "prompt_token_ids": prompt_token_ids,
            "prompt_logprobs": prompt_logprobs,
            "outputs": outputs,
            "finished": finished,
            "metrics": seq_group.metrics,
            "lora_request": seq_group.lora_request,
            "encoder_prompt": encoder_prompt,
            "encoder_prompt_token_ids": encoder_prompt_token_ids,
            "num_cached_tokens": num_cached_tokens,
            "multi_modal_placeholders": seq_group.multi_modal_placeholders
        }

        if use_cache:
            # Re-run __init__ on the cached object so it is updated in place.
            request_output = seq_group.cached_request_output
            request_output.__init__(**init_kwargs)  # type: ignore
        else:
            request_output = cls(**init_kwargs)  # type: ignore

        return request_output

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
                f"prompt={self.prompt!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"encoder_prompt={self.encoder_prompt!r}, "
                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
                f"finished={self.finished}, "
                f"metrics={self.metrics}, "
                f"lora_request={self.lora_request}, "
                f"num_cached_tokens={self.num_cached_tokens}, "
                f"multi_modal_placeholders={self.multi_modal_placeholders})")
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
_O = TypeVar("_O", default=PoolingOutput)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class PoolingRequestOutput(Generic[_O]):
    """
    The output data of a pooling request to the LLM.

    Args:
        request_id (str): A unique identifier for the pooling request.
        outputs (PoolingOutput): The pooling results for the given input.
        prompt_token_ids (List[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the pooling is completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: List[int], finished: bool):
        self.request_id = request_id
        self.prompt_token_ids = prompt_token_ids
        self.finished = finished
        self.outputs = outputs

    @staticmethod
    def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
        """Wrap the pooled data of `seq_group` in a PoolingRequestOutput.

        The pooled tensor is converted to float32 on the CPU before it is
        stored. `seq_group.pooled_data` must not be None.
        """
        raw = seq_group.pooled_data
        assert raw is not None

        pooling_output = PoolingOutput(
            raw.to(dtype=torch.float32, device="cpu"))
        token_ids = seq_group.prompt_token_ids
        is_done = seq_group.is_finished()

        return PoolingRequestOutput(seq_group.request_id, pooling_output,
                                    token_ids, is_done)

    def __repr__(self):
        """
        Returns a string representation of a PoolingRequestOutput instance.

        The representation lists the request_id, the outputs, the prompt
        token IDs and the finished flag.

        Returns:
            str: A string representation of the PoolingRequestOutput instance.
        """
        cls_name = type(self).__name__
        return (f"{cls_name}(request_id={self.request_id!r}, "
                f"outputs={self.outputs!r}, "
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"finished={self.finished})")
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class RequestOutputFactory:
    """Chooses the request-output type that matches a sequence group."""

    @staticmethod
    def create(seq_group: SequenceGroup,
               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
               use_cache: bool = False):
        """Return a pooling output if `seq_group` carries pooled data,
        otherwise a (possibly None) completion RequestOutput."""
        if seq_group.pooled_data is None:
            return RequestOutput.from_seq_group(seq_group, use_cache,
                                                seq_id_to_seq_group)
        return PoolingRequestOutput.from_seq_group(seq_group)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@dataclass
class EmbeddingOutput:
    """The output data of one embedding output of a request.

    Args:
        embedding: The embedding vector, which is a list of floats.
        Its length depends on the hidden dimension of the model.
    """
    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a 1-D pooling result into an EmbeddingOutput.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        tensor = pooling_output.data
        if tensor.ndim == 1:
            return EmbeddingOutput(tensor.tolist())

        raise ValueError("pooled_data should be a 1-D embedding vector")

    @property
    def hidden_size(self) -> int:
        """Length of the embedding vector."""
        return len(self.embedding)

    def __repr__(self) -> str:
        # Only the size is shown; the vector itself is omitted.
        return f"EmbeddingOutput(hidden_size={self.hidden_size})"
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """Pooling request output whose `outputs` is an EmbeddingOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling request output as an embedding one."""
        embedding_output = EmbeddingOutput.from_base(request_output.outputs)
        return EmbeddingRequestOutput(
            request_id=request_output.request_id,
            outputs=embedding_output,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
@dataclass
class ClassificationOutput:
    """The output data of one classification output of a request.

    Args:
        probs: The probability vector, which is a list of floats.
        Its length depends on the number of classes.
    """
    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a 1-D pooling result into a ClassificationOutput.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        tensor = pooling_output.data
        if tensor.ndim == 1:
            return ClassificationOutput(tensor.tolist())

        raise ValueError("pooled_data should be a 1-D probability vector")

    @property
    def num_classes(self) -> int:
        """Length of the probability vector."""
        return len(self.probs)

    def __repr__(self) -> str:
        # Only the class count is shown; the probabilities are omitted.
        return f"ClassificationOutput(num_classes={self.num_classes})"
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """Pooling request output whose `outputs` is a ClassificationOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling request output as a classification one."""
        classification = ClassificationOutput.from_base(
            request_output.outputs)
        return ClassificationRequestOutput(
            request_id=request_output.request_id,
            outputs=classification,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@dataclass
class ScoringOutput:
    """The output data of one scoring output of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """
    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Convert a 0-D pooling result into a ScoringOutput.

        Raises:
            ValueError: If the pooled tensor is not zero-dimensional.
        """
        tensor = pooling_output.data
        if tensor.ndim == 0:
            return ScoringOutput(tensor.item())

        raise ValueError("pooled_data should be a scalar score")

    def __repr__(self) -> str:
        return f"ScoringOutput(score={self.score})"

    @property
    @deprecated("`LLM.score()` now returns scalar scores. "
                "Please access it via the `score` attribute. ")
    def embedding(self) -> list[float]:
        """Deprecated accessor returning the score as a one-element list."""
        return [self.score]
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """Pooling request output whose `outputs` is a ScoringOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Re-wrap a generic pooling request output as a scoring one."""
        scoring_output = ScoringOutput.from_base(request_output.outputs)
        return ScoringRequestOutput(
            request_id=request_output.request_id,
            outputs=scoring_output,
            prompt_token_ids=request_output.prompt_token_ids,
            finished=request_output.finished,
        )
|
.venv/lib/python3.11/site-packages/vllm/platforms/__init__.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import traceback
|
| 5 |
+
from itertools import chain
|
| 6 |
+
from typing import TYPE_CHECKING, Optional
|
| 7 |
+
|
| 8 |
+
from vllm.plugins import load_plugins_by_group
|
| 9 |
+
from vllm.utils import resolve_obj_by_qualname
|
| 10 |
+
|
| 11 |
+
from .interface import _Backend # noqa: F401
|
| 12 |
+
from .interface import CpuArchEnum, Platform, PlatformEnum
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def tpu_platform_plugin() -> Optional[str]:
    """Return the TPU platform class path if `libtpu` can be imported,
    otherwise None."""
    try:
        # While it's technically possible to install libtpu on a
        # non-TPU machine, this is a very uncommon scenario. Therefore,
        # we assume that libtpu is installed if and only if the machine
        # has TPUs.
        import libtpu  # noqa: F401
    except Exception:
        return None

    return "vllm.platforms.tpu.TpuPlatform"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def cuda_platform_plugin() -> Optional[str]:
    """Return the CUDA platform class path if a CUDA device is detected.

    Detection first asks NVML for the device count. If that fails with an
    NVML-named exception, the function falls back to looking for Jetson
    marker paths, since CUDA is supported on Jetson but NVML may not be.

    Returns:
        "vllm.platforms.cuda.CudaPlatform" when CUDA is detected, else None.

    Raises:
        Exception: Any error raised during NVML probing whose class name
            does not contain "nvml" is propagated unchanged.
    """
    is_cuda = False

    try:
        from vllm.utils import import_pynvml
        pynvml = import_pynvml()
        pynvml.nvmlInit()
        try:
            if pynvml.nvmlDeviceGetCount() > 0:
                is_cuda = True
        finally:
            # Always pair nvmlInit() with nvmlShutdown().
            pynvml.nvmlShutdown()
    except Exception as e:
        if "nvml" not in e.__class__.__name__.lower():
            # If the error is not related to NVML, re-raise it.
            # A bare `raise` keeps the original traceback intact.
            raise

        # CUDA is supported on Jetson, but NVML may not be.
        import os

        is_jetson = (os.path.isfile("/etc/nv_tegra_release")
                     or os.path.exists("/sys/class/tegra-firmware"))
        if is_jetson:
            is_cuda = True

    return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def rocm_platform_plugin() -> Optional[str]:
    """Return the ROCm platform class path if `amdsmi` reports at least one
    processor handle, otherwise None. Any failure counts as "not found"."""
    found = False

    try:
        import amdsmi
        amdsmi.amdsmi_init()
        try:
            handles = amdsmi.amdsmi_get_processor_handles()
            if len(handles) > 0:
                found = True
        finally:
            # Always pair amdsmi_init() with amdsmi_shut_down().
            amdsmi.amdsmi_shut_down()
    except Exception:
        pass

    if found:
        return "vllm.platforms.rocm.RocmPlatform"
    return None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def hpu_platform_plugin() -> Optional[str]:
    """Return the HPU platform class path if the `habana_frameworks`
    package can be located, otherwise None."""
    try:
        from importlib import util
        spec = util.find_spec('habana_frameworks')
    except Exception:
        return None

    if spec is None:
        return None
    return "vllm.platforms.hpu.HpuPlatform"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def xpu_platform_plugin() -> Optional[str]:
    """Return the XPU platform class path if IPEX and oneCCL bindings import
    and `torch.xpu` reports availability, otherwise None."""
    try:
        # installed IPEX if the machine has XPUs.
        import intel_extension_for_pytorch  # noqa: F401
        import oneccl_bindings_for_pytorch  # noqa: F401
        import torch
        available = hasattr(torch, 'xpu') and torch.xpu.is_available()
    except Exception:
        available = False

    if available:
        return "vllm.platforms.xpu.XPUPlatform"
    return None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def cpu_platform_plugin() -> Optional[str]:
    """Return the CPU platform class path when the installed vllm version
    string contains "cpu", or failing that when the machine name starts
    with "arm"; otherwise None. Any failure counts as "not detected"."""
    try:
        from importlib.metadata import version
        detected = "cpu" in version("vllm")
        if not detected:
            import platform
            machine = platform.machine().lower()
            detected = machine.startswith("arm")
    except Exception:
        # Every failure path leaves the platform undetected.
        detected = False

    if detected:
        return "vllm.platforms.cpu.CpuPlatform"
    return None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def neuron_platform_plugin() -> Optional[str]:
    """Return the Neuron platform class path if `transformers_neuronx` can
    be imported, otherwise None. Only ImportError is treated as absence."""
    try:
        import transformers_neuronx  # noqa: F401
    except ImportError:
        return None

    return "vllm.platforms.neuron.NeuronPlatform"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def openvino_platform_plugin() -> Optional[str]:
    """Return the OpenVINO platform class path when the installed vllm
    version string contains "openvino", otherwise None."""
    try:
        from importlib.metadata import version
        detected = "openvino" in version("vllm")
    except Exception:
        detected = False

    if detected:
        return "vllm.platforms.openvino.OpenVinoPlatform"
    return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
builtin_platform_plugins = {
|
| 144 |
+
'tpu': tpu_platform_plugin,
|
| 145 |
+
'cuda': cuda_platform_plugin,
|
| 146 |
+
'rocm': rocm_platform_plugin,
|
| 147 |
+
'hpu': hpu_platform_plugin,
|
| 148 |
+
'xpu': xpu_platform_plugin,
|
| 149 |
+
'cpu': cpu_platform_plugin,
|
| 150 |
+
'neuron': neuron_platform_plugin,
|
| 151 |
+
'openvino': openvino_platform_plugin,
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def resolve_current_platform_cls_qualname() -> str:
    """Pick the platform class to use and return its qualified name.

    Every built-in detector and every out-of-tree (OOT) plugin loaded from
    the 'vllm.platform_plugins' group is run exactly once. A detector is
    "activated" when it returns a non-None qualified name. An activated OOT
    plugin takes precedence over built-in detection.

    Returns:
        The qualified class name of the selected platform, or the
        UnspecifiedPlatform class name when nothing was detected.

    Raises:
        RuntimeError: If more than one OOT plugin is activated, or if no OOT
            plugin is activated and more than one built-in detector is.
    """
    platform_plugins = load_plugins_by_group('vllm.platform_plugins')

    def _probe(plugins) -> dict:
        """Run each detector once; map activated names to their result."""
        activated = {}
        for name, func in plugins.items():
            try:
                assert callable(func)
                qualname = func()
            except Exception:
                # Detection is best-effort: a failing detector simply does
                # not activate, but leave a trace for debugging.
                logger.debug("Platform plugin %s failed during detection",
                             name,
                             exc_info=True)
                continue
            if qualname is not None:
                activated[name] = qualname
        return activated

    # Built-in and OOT results are kept apart so that an OOT plugin sharing
    # a name with a built-in detector cannot be mistaken for it, and the
    # stored result is reused instead of calling the detector a second time.
    activated_builtin = _probe(builtin_platform_plugins)
    activated_oot = _probe(platform_plugins)

    if len(activated_oot) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{list(activated_oot)}")
    elif len(activated_oot) == 1:
        plugin_name, platform_cls_qualname = next(iter(activated_oot.items()))
        logger.info("Platform plugin %s is activated", plugin_name)
    elif len(activated_builtin) >= 2:
        raise RuntimeError(
            "Only one platform plugin can be activated, but got: "
            f"{list(activated_builtin)}")
    elif len(activated_builtin) == 1:
        plugin_name, platform_cls_qualname = next(
            iter(activated_builtin.items()))
        logger.info("Automatically detected platform %s.", plugin_name)
    else:
        platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform"
        logger.info(
            "No platform detected, vLLM is running on UnspecifiedPlatform")
    return platform_cls_qualname
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
_current_platform = None
|
| 200 |
+
_init_trace: str = ''
|
| 201 |
+
|
| 202 |
+
if TYPE_CHECKING:
|
| 203 |
+
current_platform: Platform
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def __getattr__(name: str):
    """Module-level attribute hook that lazily creates `current_platform`.

    The first request for `current_platform` resolves the platform class,
    instantiates it, and records the call stack in `_init_trace`; later
    requests return the stored instance.

    Raises:
        AttributeError: If `name` is neither `current_platform` nor a
            module global.
    """
    global _current_platform, _init_trace

    if name != 'current_platform':
        module_globals = globals()
        if name in module_globals:
            return module_globals[name]
        raise AttributeError(
            f"No attribute named '{name}' exists in {__name__}.")

    # lazy init current_platform.
    # 1. out-of-tree platform plugins need `from vllm.platforms import
    #    Platform` so that they can inherit `Platform` class. Therefore,
    #    we cannot resolve `current_platform` during the import of
    #    `vllm.platforms`.
    # 2. when users use out-of-tree platform plugins, they might run
    #    `import vllm`, some vllm internal code might access
    #    `current_platform` during the import, and we need to make sure
    #    `current_platform` is only resolved after the plugins are loaded
    #    (we have tests for this, if any developer violate this, they will
    #    see the test failures).
    if _current_platform is None:
        qualname = resolve_current_platform_cls_qualname()
        platform_cls = resolve_obj_by_qualname(qualname)
        _current_platform = platform_cls()
        # Remember where the first access came from.
        _init_trace = "".join(traceback.format_stack())
    return _current_platform
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
__all__ = [
|
| 235 |
+
'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum',
|
| 236 |
+
"_init_trace"
|
| 237 |
+
]
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (9.54 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/cpu.cpython-311.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/cuda.cpython-311.pyc
ADDED
|
Binary file (20.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/hpu.cpython-311.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/interface.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/neuron.cpython-311.pyc
ADDED
|
Binary file (3.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/openvino.cpython-311.pyc
ADDED
|
Binary file (8.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/rocm.cpython-311.pyc
ADDED
|
Binary file (9.38 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/tpu.cpython-311.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/__pycache__/xpu.cpython-311.pyc
ADDED
|
Binary file (7.59 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/vllm/platforms/cpu.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import TYPE_CHECKING, Optional
|
| 5 |
+
|
| 6 |
+
import psutil
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
|
| 11 |
+
from .interface import Platform, PlatformEnum, _Backend
|
| 12 |
+
|
| 13 |
+
logger = init_logger(__name__)
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from vllm.config import VllmConfig
|
| 17 |
+
else:
|
| 18 |
+
VllmConfig = None
|
| 19 |
+
|
| 20 |
+
logger = init_logger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CpuPlatform(Platform):
    """Platform hooks used when vLLM runs on the host CPU."""

    _enum = PlatformEnum.CPU
    device_name: str = "cpu"
    device_type: str = "cpu"
    dispatch_key: str = "CPU"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"cpu"`` regardless of ``device_id``."""
        return "cpu"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the Torch SDPA attention backend.

        A ``selected_backend`` other than ``_Backend.TORCH_SDPA`` is only
        reported in the log; the remaining arguments do not affect the result.
        """
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        logger.info("Using Torch SDPA backend.")
        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``psutil.virtual_memory().total``; ``device_id`` is unused."""
        memory_stats = psutil.virtual_memory()
        return memory_stats.total

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False`` for every value of ``enforce_eager``."""
        return False

    @classmethod
    def inference_mode(cls):
        """Return a ``torch.no_grad()`` context manager."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for CPU and export CPU env vars.

        Forces eager mode, fills in a default block size, derives the KV cache
        size from ``VLLM_CPU_KVCACHE_SPACE``, replaces fp16 with bf16 when
        chunked prefill or prefix caching is enabled, falls back to the "mp"
        executor backend, resolves an "auto" worker class, and finally sets
        several process environment variables.

        Raises:
            RuntimeError: if ``VLLM_CPU_KVCACHE_SPACE`` is negative.
        """
        import vllm.envs as envs
        from vllm.utils import GiB_bytes

        model_config = vllm_config.model_config
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        if not model_config.enforce_eager:
            logger.warning(
                "CUDA graph is not supported on CPU, fallback to the eager "
                "mode.")
            model_config.enforce_eager = True

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        # Requested KV cache size; zero selects the built-in default.
        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
        if kv_cache_space == 0:
            cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
            logger.warning(
                "Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
                "for CPU backend is not set, using 4 by default.")
        elif kv_cache_space > 0:
            cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes  # type: ignore # noqa
        else:
            raise RuntimeError(
                "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
                f" {kv_cache_space}, expect a positive integer value.")

        scheduler_config = vllm_config.scheduler_config
        chunked_or_prefix_cached = (scheduler_config.chunked_prefill_enabled
                                    or cache_config.enable_prefix_caching)
        if chunked_or_prefix_cached and model_config.dtype == torch.half:
            logger.warning("Chunked-prefill on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        requested_backend = parallel_config.distributed_executor_backend
        if requested_backend is not None and requested_backend != "mp":
            logger.warning(("%s is not supported on CPU, fallback to mp "
                            "distributed executor backend."),
                           requested_backend)
            parallel_config.distributed_executor_backend = "mp"

        if parallel_config.worker_cls == "auto":
            cpu_worker = "vllm.worker.cpu_worker.CPUWorker"
            if vllm_config.speculative_config:
                # The spec-decode factory becomes the worker class and the
                # CPU worker is recorded as ``sd_worker_cls``.
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = cpu_worker
            else:
                parallel_config.worker_cls = cpu_worker

        assert vllm_config.device_config.device_type == "cpu"

        #
        # Environment variables for CPU executor
        #

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Intel OpenMP setting
        ld_preload = os.getenv("LD_PRELOAD", "")
        if "libiomp5.so" in ld_preload:
            os.environ.update({
                # The time(milliseconds) that a thread should wait after
                # completing the execution of a parallel region, before
                # sleeping.
                'KMP_BLOCKTIME': "1",
                # Prevents the CPU to run into low performance state
                'KMP_TPAUSE': "0",
                # Provides fine granularity parallelism
                'KMP_FORKJOIN_BARRIER_PATTERN': "dist,dist",
                'KMP_PLAIN_BARRIER_PATTERN': "dist,dist",
                'KMP_REDUCTION_BARRIER_PATTERN': "dist,dist",
            })

        # To hint IPEX uses shared memory based AllReduce
        tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size
        os.environ["LOCAL_WORLD_SIZE"] = str(tensor_parallel_size)

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log that pinned memory is unsupported and return ``False``."""
        logger.warning("Pin memory is not supported on CPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the CPU Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
|
.venv/lib/python3.11/site-packages/vllm/platforms/cuda.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Code inside this file can safely assume cuda platform, e.g. importing
|
| 3 |
+
pynvml. However, it should not initialize cuda context.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from functools import lru_cache, wraps
|
| 8 |
+
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
|
| 9 |
+
Union)
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from typing_extensions import ParamSpec
|
| 13 |
+
|
| 14 |
+
# import custom ops, trigger op registration
|
| 15 |
+
import vllm._C # noqa
|
| 16 |
+
import vllm.envs as envs
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.utils import import_pynvml
|
| 19 |
+
|
| 20 |
+
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
| 21 |
+
|
| 22 |
+
if TYPE_CHECKING:
|
| 23 |
+
from vllm.config import VllmConfig
|
| 24 |
+
else:
|
| 25 |
+
VllmConfig = None
|
| 26 |
+
|
| 27 |
+
logger = init_logger(__name__)
|
| 28 |
+
|
| 29 |
+
_P = ParamSpec("_P")
|
| 30 |
+
_R = TypeVar("_R")
|
| 31 |
+
|
| 32 |
+
pynvml = import_pynvml()
|
| 33 |
+
|
| 34 |
+
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
| 35 |
+
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
| 36 |
+
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` indexes into its
    comma-separated entries and the selected entry is returned as an ``int``.
    When the variable is absent, ``device_id`` is returned unchanged.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to an empty string.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    entries = visible_devices.split(",")
    if entries == [""]:
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")
    return int(entries[device_id])
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so every call is bracketed by NVML init and shutdown.

    ``pynvml.nvmlInit()`` runs before ``fn``; ``pynvml.nvmlShutdown()`` runs
    afterwards whether ``fn`` returns or raises.
    """

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Init stays outside the ``try`` so a failed init skips the shutdown.
        pynvml.nvmlInit()
        try:
            outcome = fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()
        return outcome

    return wrapper
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class CudaPlatformBase(Platform):
    """Shared CUDA platform behaviour.

    The device-query methods raise ``NotImplementedError`` here and are
    expected to be supplied by subclasses.
    """

    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Not implemented on the base class."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Not implemented on the base class."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Not implemented on the base class."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``True`` unless ``enforce_eager`` is truthy.

        A warning is logged in the eager case.
        """
        if not enforce_eager:
            return True
        logger.warning(
            "To see benefits of async output processing, enable CUDA "
            "graph. Since, enforce-eager is enabled, async output "
            "processor cannot be used")
        return False

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        """Not implemented on the base class."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Do nothing on the base class."""
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Resolve an "auto" worker class and a missing cache block size.

        Raises:
            NotImplementedError: if multi-step scheduling or speculative
                decoding is configured while ``envs.VLLM_USE_V1`` is set.
        """
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on VLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on VLLM V1."
                    )
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                parallel_config.sd_worker_cls = \
                    "vllm.worker.worker.Worker"
            elif envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.gpu_worker.Worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Reset the peak memory stats of ``device`` and return
        ``torch.cuda.max_memory_allocated(device)``."""
        # Order matters: the reset must happen before the reading.
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        """Return the import path of the attention backend to use.

        ``use_v1`` and ``use_mla`` are checked first, then an explicit
        FlashInfer or XFormers selection. Otherwise FlashAttention is used
        when the device capability, dtype, KV cache dtype, block size, head
        size and package availability checks below pass, and XFormers when
        any of them fails.

        Raises:
            ValueError: if ``selected_backend`` is truthy and is not one of
                FLASHINFER, XFORMERS or FLASH_ATTN.
        """
        if use_v1:
            logger.info("Using Flash Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        if selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        if selected_backend != _Backend.FLASH_ATTN and selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        # Start from FlashAttention; any failed check falls back to XFormers.
        use_flash_attn = True
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            use_flash_attn = False
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            use_flash_attn = False
        elif kv_cache_dtype is not None and \
            kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting environment variable "
                "VLLM_ATTENTION_BACKEND=FLASHINFER")
            use_flash_attn = False
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            use_flash_attn = False

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if use_flash_attn:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)

                head_sizes = FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in head_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    use_flash_attn = False
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                use_flash_attn = False

        if not use_flash_attn:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the GPU Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# NVML utils
|
| 239 |
+
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
| 240 |
+
# all the related functions work on real physical device ids.
|
| 241 |
+
# the major benefit of using NVML is that it will not initialize CUDA
|
| 242 |
+
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries are answered through ``pynvml``."""

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        """Return the compute capability of ``device_id``.

        ``device_id`` is first mapped through
        ``device_id_to_physical_device_id``. ``None`` is returned when a
        ``RuntimeError`` is raised during the lookup.
        """
        try:
            nvml_index = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(nvml_index)
            capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            major, minor = capability
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """Delegate to the parent implementation, returning ``False`` when it
        raises ``RuntimeError``."""
        try:
            supported = super().has_device_capability(capability, device_id)
            return supported
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the physical device behind ``device_id``."""
        nvml_index = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(nvml_index)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID of the physical device behind ``device_id``."""
        nvml_index = device_id_to_physical_device_id(device_id)
        return pynvml.nvmlDeviceGetUUID(
            pynvml.nvmlDeviceGetHandleByIndex(nvml_index))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info as an ``int``."""
        nvml_index = device_id_to_physical_device_id(device_id)
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(
            pynvml.nvmlDeviceGetHandleByIndex(nvml_index))
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Check every unordered pair once: each handle against the later ones.
        for position, handle in enumerate(handles):
            for peer_handle in handles[position + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name for ``device_id``, used directly as the NVML
        index (no ``CUDA_VISIBLE_DEVICES`` mapping is applied here)."""
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when differently named devices are present and
        ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``."""
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_count)
            ]
            mixed_devices = len(set(device_names)) > 1
            if (mixed_devices
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: \n%s\nPlease"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    "\n".join(device_names),
                )
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that answers device queries through ``torch.cuda``."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by ``torch.cuda``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``torch.cuda.get_device_name(device_id)``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the torch device properties."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Log that NVLink detection is unavailable and return ``False``.

        ``physical_device_ids`` is not inspected.
        """
        # ``logger.exception`` is meant for ``except`` blocks only; called
        # outside one it attaches an empty traceback to the record.
        # ``logger.error`` keeps the ERROR level and message without it.
        logger.error(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    nvml_available = True
    # Pair the successful probe's nvmlInit() with an nvmlShutdown().
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit the platform warnings at import time, except when ``pynvml`` is a
# Sphinx autodoc mock. If Sphinx is not installed there is nothing to check.
try:
    from sphinx.ext.autodoc.mock import _MockModule

    pynvml_is_mocked = isinstance(pynvml, _MockModule)
    if not pynvml_is_mocked:
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    CudaPlatform.log_warnings()
|
.venv/lib/python3.11/site-packages/vllm/platforms/hpu.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import TYPE_CHECKING, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from vllm import envs
|
| 9 |
+
from vllm.logger import init_logger
|
| 10 |
+
|
| 11 |
+
from .interface import Platform, PlatformEnum, _Backend
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from vllm.config import VllmConfig
|
| 15 |
+
else:
|
| 16 |
+
VllmConfig = None
|
| 17 |
+
|
| 18 |
+
logger = init_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HpuPlatform(Platform):
    """Platform hooks used when vLLM runs on HPU devices."""

    _enum = PlatformEnum.HPU
    device_name: str = "hpu"
    device_type: str = "hpu"
    dispatch_key: str = "HPU"
    ray_device_key: str = "HPU"
    device_control_env_var: str = "HABANA_VISIBLE_MODULES"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Return the import path of the HPU attention backend.

        None of the arguments affect the result.
        """
        logger.info("Using HPUAttention backend.")
        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``True`` for every value of ``enforce_eager``."""
        return True

    @staticmethod
    def inference_mode():
        """Return a ``torch.no_grad()`` context manager."""
        return torch.no_grad()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` for HPU and fill in HPU defaults.

        Raises:
            NotImplementedError: if multi-step scheduling or speculative
                decoding is configured.
        """
        if vllm_config.scheduler_config.is_multi_step:
            raise NotImplementedError(
                "Multi-step execution is not implemented for HPU")

        if vllm_config.speculative_config is not None:
            raise NotImplementedError(
                "Speculative decoding is not implemented for HPU")

        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

        # NOTE(kzawora): default block size for Gaudi should be 128
        # smaller sizes still work, but very inefficiently
        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 128

        forked_mp_workers = (
            parallel_config.distributed_executor_backend == 'mp'
            and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork')
        if forked_mp_workers:
            # Only override the start method when the variable is not present
            # in the process environment.
            if "VLLM_WORKER_MULTIPROC_METHOD" in os.environ:
                logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
                               "might cause application hangs on exit. Using "
                               "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, "
                               "as it was explicitly requested.")
            else:
                logger.warning(
                    "On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork "
                    "might cause application hangs on exit. Setting "
                    "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                    "To override that behavior, please set "
                    "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    @classmethod
    def is_pin_memory_available(cls):
        """Log that pinned memory is unsupported and return ``False``."""
        logger.warning("Pin memory is not supported on HPU.")
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the import path of the HPU Punica wrapper class."""
        return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
|
.venv/lib/python3.11/site-packages/vllm/platforms/interface.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
import enum
|
| 4 |
+
import platform
|
| 5 |
+
import random
|
| 6 |
+
from platform import uname
|
| 7 |
+
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from vllm.config import VllmConfig
|
| 16 |
+
else:
|
| 17 |
+
VllmConfig = None
|
| 18 |
+
|
| 19 |
+
logger = init_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def in_wsl() -> bool:
    """Return whether any ``uname()`` field contains "microsoft".

    The match is case-insensitive.
    """
    # Reference: https://github.com/microsoft/WSL/issues/4071
    system_banner = " ".join(uname())
    return "microsoft" in system_banner.lower()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class _Backend(enum.Enum):
    """Identifiers for the attention backends.

    Values come from ``enum.auto()``, so they follow declaration order.
    """

    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    TRITON_MLA = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
    BLOCK_SPARSE_FLASH_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class PlatformEnum(enum.Enum):
    """Identifiers for the platform kinds.

    Values come from ``enum.auto()``, so they follow declaration order.
    """

    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    HPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    NEURON = enum.auto()
    OPENVINO = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class CpuArchEnum(enum.Enum):
    """Identifiers for CPU architecture families.

    Values come from ``enum.auto()``, so they follow declaration order.
    """

    X86 = enum.auto()
    ARM = enum.auto()
    POWERPC = enum.auto()
    OTHER = enum.auto()
    UNKNOWN = enum.auto()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DeviceCapability(NamedTuple):
    """A device capability expressed as a ``(major, minor)`` pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Return the capability formatted as ``"<major>.<minor>"``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        major, minor = self
        assert 0 <= minor < 10
        return 10 * major + minor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class Platform:
    """Base class describing a hardware platform.

    Subclasses set the class attributes declared here and override the
    classmethods they support. Several defaults in this base class raise
    ``NotImplementedError``; others return a neutral value (``""``, ``None``
    or ``True``) as noted on each method.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    # available dispatch keys:
    # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
    # use "CPU" as a fallback for platforms not registered in PyTorch
    dispatch_key: str = "CPU"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""

    # platform-agnostic way to specify the device control environment variable,
    # e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # Quantization names accepted by ``verify_quantization``. An empty list
    # disables the check, i.e. every quantization is accepted.
    supported_quantization: list[str] = []

    def is_cuda(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.CUDA``."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.ROCM``."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.TPU``."""
        return self._enum == PlatformEnum.TPU

    def is_hpu(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.HPU``."""
        return self._enum == PlatformEnum.HPU

    def is_xpu(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.XPU``."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.CPU``."""
        return self._enum == PlatformEnum.CPU

    def is_neuron(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.NEURON``."""
        return self._enum == PlatformEnum.NEURON

    def is_openvino(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.OPENVINO``."""
        return self._enum == PlatformEnum.OPENVINO

    def is_out_of_tree(self) -> bool:
        """Return whether this platform's ``_enum`` is ``PlatformEnum.OOT``."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`.

        True when ``_enum`` is either ``PlatformEnum.CUDA`` or
        ``PlatformEnum.ROCM``.
        """
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Get the attention backend class of a device.

        The base implementation ignores its arguments and returns an empty
        string.
        """
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> Optional[DeviceCapability]:
        """Stateless version of :func:`torch.cuda.get_device_capability`.

        The base implementation returns ``None``.
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)

        Returns ``False`` when :meth:`get_device_capability` reports ``None``.
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            # DeviceCapability is a NamedTuple, so this is a tuple comparison.
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """
        Check if the current platform supports async output.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """
        Check and update the configuration for the current platform.

        It can raise an exception if the configuration is not compatible with
        the current platform, or it can update the configuration to make it
        compatible with the current platform.

        The config is passed by reference, so it can be modified in place.
        The base implementation does nothing.
        """
        pass

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.

        Raises:
            ValueError: if ``supported_quantization`` is non-empty and does
                not contain ``quant``.
        """
        if cls.supported_quantization and \
            quant not in cls.supported_quantization:
            raise ValueError(
                f"{quant} quantization is currently not supported in "
                f"{cls.device_name}.")

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """
        Determine the CPU architecture of the current system.
        Returns CpuArchEnum indicating the architecture type.
        """
        machine = platform.machine().lower()

        if machine in ("x86_64", "amd64", "i386", "i686"):
            return CpuArchEnum.X86
        if machine.startswith(("arm", "aarch")):
            return CpuArchEnum.ARM
        if machine.startswith("ppc"):
            return CpuArchEnum.POWERPC

        # A non-empty but unrecognized string is OTHER; an empty one is UNKNOWN.
        return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Checks whether pin memory is available on the current platform."""
        if in_wsl():
            # Pinning memory in WSL is not supported.
            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
            logger.warning("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
            return False
        return True

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """
        Return the memory usage in bytes.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """
        Return the punica wrapper for current platform.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class UnspecifiedPlatform(Platform):
    """Platform whose kind is ``PlatformEnum.UNSPECIFIED``.

    It sets an empty ``device_type`` and otherwise inherits everything from
    ``Platform``.
    """

    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""
|
.venv/lib/python3.11/site-packages/vllm/platforms/neuron.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING, Optional
|
| 4 |
+
|
| 5 |
+
from vllm.logger import init_logger
|
| 6 |
+
|
| 7 |
+
from .interface import Platform, PlatformEnum
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from vllm.config import VllmConfig
|
| 11 |
+
else:
|
| 12 |
+
VllmConfig = None
|
| 13 |
+
|
| 14 |
+
# Logger for this module, obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class NeuronPlatform(Platform):
    """``Platform`` implementation for the Neuron backend."""

    _enum = PlatformEnum.NEURON
    device_name: str = "neuron"
    device_type: str = "neuron"
    ray_device_key: str = "neuron_cores"
    supported_quantization: list[str] = ["neuron_quant"]
    device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return ``"neuron"`` regardless of ``device_id``."""
        return "neuron"

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Return ``False``; ``enforce_eager`` is not consulted."""
        return False

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Adjust ``vllm_config`` in place for the Neuron backend.

        - Replaces a ``worker_cls`` of ``"auto"`` with the Neuron worker.
        - Sets the distributed executor backend to ``"uni"`` when
          ``world_size`` is greater than one.
        - Sets the cache block size to the model's ``max_model_len`` when a
          cache config is present.

        Raises:
            AssertionError: if a LoRA config or a speculative config is set.
        """
        parallel_config = vllm_config.parallel_config
        if parallel_config.worker_cls == "auto":
            parallel_config.worker_cls = \
                "vllm.worker.neuron_worker.NeuronWorker"

        if parallel_config.world_size > 1:
            parallel_config.distributed_executor_backend = "uni"

        # NOTE: these checks disappear when Python runs with ``-O``.
        assert (vllm_config.lora_config
                is None), "LoRA is not supported for Neuron backend."
        assert (not vllm_config.speculative_config
                ), "Speculative decoding not yet supported for Neuron backend."

        cache_config = vllm_config.cache_config
        if cache_config:
            # neuron needs block_size = max_model_len
            cache_config.block_size = \
                vllm_config.model_config.max_model_len

    @classmethod
    def is_pin_memory_available(cls) -> bool:
        """Log a warning and return ``False``."""
        logger.warning("Pin memory is not supported on Neuron.")
        return False
|