ChuxiJ committed on
Commit
6d3b89f
·
1 Parent(s): 033008e

resolve intermittent CUDA assertion error in concurrent serving scenarios

Browse files
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py CHANGED
@@ -1,4 +1,5 @@
1
  import atexit
 
2
  from dataclasses import fields
3
  from time import perf_counter
4
  from tqdm.auto import tqdm
@@ -20,6 +21,15 @@ class LLMEngine:
20
  config = Config(model, **config_kwargs)
21
  self.ps = []
22
  self.events = []
 
 
 
 
 
 
 
 
 
23
  ctx = mp.get_context("spawn")
24
  for i in range(1, config.tensor_parallel_size):
25
  event = ctx.Event()
@@ -108,6 +118,20 @@ class LLMEngine:
108
  sampling_params: SamplingParams | list[SamplingParams],
109
  use_tqdm: bool = True,
110
  unconditional_prompts: list[str] | list[list[int]] | None = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  ) -> list[str]:
112
  # Clean up any residual state from previous interrupted generations
113
  # This prevents 'deque index out of range' errors from accumulated block leaks
 
1
  import atexit
2
+ import threading
3
  from dataclasses import fields
4
  from time import perf_counter
5
  from tqdm.auto import tqdm
 
21
  config = Config(model, **config_kwargs)
22
  self.ps = []
23
  self.events = []
24
+ # Thread-safety lock for generate().
25
+ # The scheduler, block manager, model runner, and CUDA graph buffers are all
26
+ # shared mutable state that is NOT thread-safe. In concurrent serving scenarios
27
+ # (API server with ThreadPoolExecutor, multiple queue workers, Gradio with
28
+ # concurrent requests), multiple threads can call generate() simultaneously.
29
+ # Without this lock, concurrent access corrupts scheduler state, block tables,
30
+ # and CUDA graph input buffers, leading to intermittent CUDA device-side
31
+ # assertion failures (illegal memory access in KV cache).
32
+ self._generate_lock = threading.Lock()
33
  ctx = mp.get_context("spawn")
34
  for i in range(1, config.tensor_parallel_size):
35
  event = ctx.Event()
 
118
  sampling_params: SamplingParams | list[SamplingParams],
119
  use_tqdm: bool = True,
120
  unconditional_prompts: list[str] | list[list[int]] | None = None,
121
+ ) -> list[str]:
122
+ # Serialize access to the engine to prevent concurrent corruption of
123
+ # scheduler state, block manager, CUDA graph buffers, and KV cache.
124
+ # This is the primary defense against the intermittent CUDA device-side
125
+ # assertion error that occurs in concurrent serving scenarios.
126
+ with self._generate_lock:
127
+ return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)
128
+
129
+ def _generate_impl(
130
+ self,
131
+ prompts: list[str] | list[list[int]],
132
+ sampling_params: SamplingParams | list[SamplingParams],
133
+ use_tqdm: bool = True,
134
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
135
  ) -> list[str]:
136
  # Clean up any residual state from previous interrupted generations
137
  # This prevents 'deque index out of range' errors from accumulated block leaks
acestep/third_parts/nano-vllm/nanovllm/utils/context.py CHANGED
@@ -1,4 +1,5 @@
1
  from dataclasses import dataclass
 
2
  import torch
3
 
4
 
@@ -13,15 +14,34 @@ class Context:
13
  context_lens: torch.Tensor | None = None
14
  block_tables: torch.Tensor | None = None
15
 
16
- _CONTEXT = Context()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def get_context():
19
- return _CONTEXT
 
 
 
 
20
 
21
  def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
22
- global _CONTEXT
23
- _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
24
 
25
  def reset_context():
26
- global _CONTEXT
27
- _CONTEXT = Context()
 
1
  from dataclasses import dataclass
2
+ import threading
3
  import torch
4
 
5
 
 
14
  context_lens: torch.Tensor | None = None
15
  block_tables: torch.Tensor | None = None
16
 
17
+
18
+ # Thread-local storage for context.
19
+ #
20
+ # ROOT CAUSE FIX: The original implementation used a plain module-level global
21
+ # `_CONTEXT` variable. In concurrent serving scenarios (API server with
22
+ # ThreadPoolExecutor, multiple queue workers, or Gradio with concurrent requests),
23
+ # multiple threads can call set_context() / get_context() / reset_context()
24
+ # concurrently. This creates a race condition:
25
+ #
26
+ # Thread A: set_context(...) # sets slot_mapping, block_tables for request A
27
+ # Thread B: set_context(...) # OVERWRITES with request B's data
28
+ # Thread A: run_model(...) # reads Thread B's context → WRONG KV cache addresses
29
+ # # → CUDA illegal memory access / device-side assertion
30
+ #
31
+ # By using threading.local(), each thread gets its own independent Context,
32
+ # eliminating the race condition entirely.
33
+ _THREAD_LOCAL = threading.local()
34
+
35
 
36
  def get_context():
37
+ ctx = getattr(_THREAD_LOCAL, 'context', None)
38
+ if ctx is None:
39
+ ctx = Context()
40
+ _THREAD_LOCAL.context = ctx
41
+ return ctx
42
 
43
  def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
44
+ _THREAD_LOCAL.context = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
 
45
 
46
  def reset_context():
47
+ _THREAD_LOCAL.context = Context()