Spaces:
Configuration error
Resolve intermittent CUDA assertion error in concurrent serving scenarios
Browse files
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import atexit
|
|
|
|
| 2 |
from dataclasses import fields
|
| 3 |
from time import perf_counter
|
| 4 |
from tqdm.auto import tqdm
|
|
@@ -20,6 +21,15 @@ class LLMEngine:
|
|
| 20 |
config = Config(model, **config_kwargs)
|
| 21 |
self.ps = []
|
| 22 |
self.events = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
ctx = mp.get_context("spawn")
|
| 24 |
for i in range(1, config.tensor_parallel_size):
|
| 25 |
event = ctx.Event()
|
|
@@ -108,6 +118,20 @@ class LLMEngine:
|
|
| 108 |
sampling_params: SamplingParams | list[SamplingParams],
|
| 109 |
use_tqdm: bool = True,
|
| 110 |
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
) -> list[str]:
|
| 112 |
# Clean up any residual state from previous interrupted generations
|
| 113 |
# This prevents 'deque index out of range' errors from accumulated block leaks
|
|
|
|
| 1 |
import atexit
|
| 2 |
+
import threading
|
| 3 |
from dataclasses import fields
|
| 4 |
from time import perf_counter
|
| 5 |
from tqdm.auto import tqdm
|
|
|
|
| 21 |
config = Config(model, **config_kwargs)
|
| 22 |
self.ps = []
|
| 23 |
self.events = []
|
| 24 |
+
# Thread-safety lock for generate().
|
| 25 |
+
# The scheduler, block manager, model runner, and CUDA graph buffers are all
|
| 26 |
+
# shared mutable state that is NOT thread-safe. In concurrent serving scenarios
|
| 27 |
+
# (API server with ThreadPoolExecutor, multiple queue workers, Gradio with
|
| 28 |
+
# concurrent requests), multiple threads can call generate() simultaneously.
|
| 29 |
+
# Without this lock, concurrent access corrupts scheduler state, block tables,
|
| 30 |
+
# and CUDA graph input buffers, leading to intermittent CUDA device-side
|
| 31 |
+
# assertion failures (illegal memory access in KV cache).
|
| 32 |
+
self._generate_lock = threading.Lock()
|
| 33 |
ctx = mp.get_context("spawn")
|
| 34 |
for i in range(1, config.tensor_parallel_size):
|
| 35 |
event = ctx.Event()
|
|
|
|
| 118 |
sampling_params: SamplingParams | list[SamplingParams],
|
| 119 |
use_tqdm: bool = True,
|
| 120 |
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
| 121 |
+
) -> list[str]:
|
| 122 |
+
# Serialize access to the engine to prevent concurrent corruption of
|
| 123 |
+
# scheduler state, block manager, CUDA graph buffers, and KV cache.
|
| 124 |
+
# This is the primary defense against the intermittent CUDA device-side
|
| 125 |
+
# assertion error that occurs in concurrent serving scenarios.
|
| 126 |
+
with self._generate_lock:
|
| 127 |
+
return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)
|
| 128 |
+
|
| 129 |
+
def _generate_impl(
|
| 130 |
+
self,
|
| 131 |
+
prompts: list[str] | list[list[int]],
|
| 132 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
| 133 |
+
use_tqdm: bool = True,
|
| 134 |
+
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
| 135 |
) -> list[str]:
|
| 136 |
# Clean up any residual state from previous interrupted generations
|
| 137 |
# This prevents 'deque index out of range' errors from accumulated block leaks
|
acestep/third_parts/nano-vllm/nanovllm/utils/context.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from dataclasses import dataclass
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
|
|
@@ -13,15 +14,34 @@ class Context:
|
|
| 13 |
context_lens: torch.Tensor | None = None
|
| 14 |
block_tables: torch.Tensor | None = None
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def get_context():
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
|
| 22 |
-
|
| 23 |
-
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
| 24 |
|
| 25 |
def reset_context():
|
| 26 |
-
|
| 27 |
-
_CONTEXT = Context()
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
+
import threading
|
| 3 |
import torch
|
| 4 |
|
| 5 |
|
|
|
|
| 14 |
context_lens: torch.Tensor | None = None
|
| 15 |
block_tables: torch.Tensor | None = None
|
| 16 |
|
| 17 |
+
|
| 18 |
+
# Thread-local storage for context.
|
| 19 |
+
#
|
| 20 |
+
# ROOT CAUSE FIX: The original implementation used a plain module-level global
|
| 21 |
+
# `_CONTEXT` variable. In concurrent serving scenarios (API server with
|
| 22 |
+
# ThreadPoolExecutor, multiple queue workers, or Gradio with concurrent requests),
|
| 23 |
+
# multiple threads can call set_context() / get_context() / reset_context()
|
| 24 |
+
# concurrently. This creates a race condition:
|
| 25 |
+
#
|
| 26 |
+
# Thread A: set_context(...) # sets slot_mapping, block_tables for request A
|
| 27 |
+
# Thread B: set_context(...) # OVERWRITES with request B's data
|
| 28 |
+
# Thread A: run_model(...) # reads Thread B's context → WRONG KV cache addresses
|
| 29 |
+
# # → CUDA illegal memory access / device-side assertion
|
| 30 |
+
#
|
| 31 |
+
# By using threading.local(), each thread gets its own independent Context,
|
| 32 |
+
# eliminating the race condition entirely.
|
| 33 |
+
_THREAD_LOCAL = threading.local()
|
| 34 |
+
|
| 35 |
|
| 36 |
def get_context():
    """Return the calling thread's Context, lazily creating an empty one.

    Each thread sees its own independent Context via the module-level
    thread-local store, so concurrent callers never observe another
    request's attention metadata.
    """
    current = getattr(_THREAD_LOCAL, 'context', None)
    if current is None:
        # First access on this thread: install a default (empty) Context.
        current = Context()
        _THREAD_LOCAL.context = current
    return current
|
| 42 |
|
| 43 |
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    """Install a fresh Context for the calling thread.

    Builds a new Context from the given attention metadata and stores it
    in thread-local storage, leaving other threads' contexts untouched.
    """
    new_ctx = Context(
        is_prefill,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        slot_mapping,
        context_lens,
        block_tables,
    )
    _THREAD_LOCAL.context = new_ctx
|
|
|
|
| 45 |
|
| 46 |
def reset_context():
    """Replace the calling thread's Context with an empty default.

    Only the current thread's slot is cleared; contexts owned by other
    threads are unaffected.
    """
    # Overwrite rather than delete so subsequent get_context() calls
    # see a valid (empty) Context object.
    _THREAD_LOCAL.context = Context()
|
|
|