shinnosukeono Claude Opus 4.5 committed on
Commit
4161fbd
·
1 Parent(s): 0ab0126

Fix: Use single vLLM engine with continuous batching

Browse files

- Detect available GPUs dynamically
- Use vLLM's built-in continuous batching for concurrent requests
- Remove multiprocessing approach that failed on HF Spaces
- Use tensor_parallel_size for multi-GPU when available

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +160 -191
app.py CHANGED
@@ -3,13 +3,31 @@ import os
3
  import re
4
  import uuid
5
  import threading
6
- import queue
7
- from multiprocessing import Process, Queue
8
  from typing import Generator
9
 
10
  import gradio as gr
11
 
12
- NUM_GPUS = 8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Stop strings for generation
15
  STOP_STRINGS = [
@@ -23,134 +41,30 @@ STOP_RE = re.compile(
23
  re.MULTILINE
24
  )
25
 
26
-
27
- def gpu_worker_main(gpu_id: int, request_queue: Queue, response_queue: Queue):
28
- """
29
- Worker process that runs on a dedicated GPU.
30
- Each worker has its own vLLM engine instance.
31
- """
32
- os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
33
-
34
- import asyncio
35
- from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
36
-
37
- # Initialize vLLM engine on this GPU
38
- engine_args = AsyncEngineArgs(
39
- model="EQUES/JPharmatron-7B-chat",
40
- enforce_eager=True,
41
- gpu_memory_utilization=0.85,
42
- )
43
- engine = AsyncLLMEngine.from_engine_args(engine_args)
44
-
45
- loop = asyncio.new_event_loop()
46
- asyncio.set_event_loop(loop)
47
-
48
- async def generate_and_stream(request_id: str, prompt: str):
49
- """Generate tokens and stream chunks back via response queue."""
50
- params = SamplingParams(
51
- temperature=0.0,
52
- max_tokens=4096,
53
- repetition_penalty=1.2,
54
- stop=STOP_STRINGS,
55
- )
56
-
57
- previous_text = ""
58
- try:
59
- async for out in engine.generate(prompt, params, request_id=request_id):
60
- full_text = out.outputs[0].text
61
-
62
- # Check for stop patterns that might have leaked through
63
- m = STOP_RE.search(full_text)
64
- if m:
65
- cut = m.start()
66
- chunk = full_text[len(previous_text):cut]
67
- if chunk:
68
- response_queue.put((gpu_id, chunk, False))
69
- break
70
-
71
- chunk = full_text[len(previous_text):]
72
- previous_text = full_text
73
- if chunk:
74
- response_queue.put((gpu_id, chunk, False))
75
-
76
- except Exception as e:
77
- response_queue.put((gpu_id, f"\n[Error: {str(e)}]", False))
78
-
79
- # Signal completion
80
- response_queue.put((gpu_id, "", True))
81
-
82
- # Main worker loop
83
- while True:
84
- try:
85
- request = request_queue.get(timeout=1.0)
86
- except:
87
- continue
88
-
89
- if request is None: # Shutdown signal
90
- break
91
-
92
- request_id, prompt = request
93
- loop.run_until_complete(generate_and_stream(request_id, prompt))
94
-
95
-
96
- class ParallelInferenceManager:
97
- """Manages multiple GPU worker processes for parallel inference."""
98
-
99
- def __init__(self, num_gpus: int = NUM_GPUS):
100
- self.num_gpus = num_gpus
101
- self.workers = []
102
- self.request_queues = []
103
- self.response_queue = Queue() # Shared response queue
104
- self._started = False
105
-
106
- def start(self):
107
- """Start all GPU worker processes."""
108
- if self._started:
109
- return
110
-
111
- for gpu_id in range(self.num_gpus):
112
- request_queue = Queue()
113
- self.request_queues.append(request_queue)
114
-
115
- process = Process(
116
- target=gpu_worker_main,
117
- args=(gpu_id, request_queue, self.response_queue),
118
- daemon=True
119
- )
120
- process.start()
121
- self.workers.append(process)
122
-
123
- self._started = True
124
-
125
- def submit_request(self, gpu_id: int, prompt: str, request_id: str):
126
- """Submit a request to a specific GPU worker."""
127
- if 0 <= gpu_id < self.num_gpus:
128
- self.request_queues[gpu_id].put((request_id, prompt))
129
-
130
- def shutdown(self):
131
- """Shutdown all workers."""
132
- for q in self.request_queues:
133
- q.put(None)
134
- for w in self.workers:
135
- w.join(timeout=5)
136
- if w.is_alive():
137
- w.terminate()
138
-
139
-
140
- # Global manager instance (initialized lazily)
141
- _manager = None
142
- _manager_lock = threading.Lock()
143
-
144
-
145
- def get_manager() -> ParallelInferenceManager:
146
- """Get or create the global inference manager."""
147
- global _manager
148
- if _manager is None:
149
- with _manager_lock:
150
- if _manager is None:
151
- _manager = ParallelInferenceManager(NUM_GPUS)
152
- _manager.start()
153
- return _manager
154
 
155
 
156
  def build_prompt(user_input: str, mode: list[str]) -> str:
@@ -168,90 +82,145 @@ def build_prompt(user_input: str, mode: list[str]) -> str:
168
  return base_prompt
169
 
170
 
171
- def respond_parallel(
172
- prompt0: str, prompt1: str, prompt2: str, prompt3: str,
173
- prompt4: str, prompt5: str, prompt6: str, prompt7: str,
174
- mode: list[str]
175
- ) -> Generator:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  """
177
- Process up to 8 prompts in parallel, streaming results back.
178
- Each prompt is sent to a dedicated GPU worker.
179
  """
180
- prompts = [prompt0, prompt1, prompt2, prompt3, prompt4, prompt5, prompt6, prompt7]
181
- manager = get_manager()
182
 
183
- # Track active requests and their results
184
- results = [""] * NUM_GPUS
185
- active_gpus = set()
186
 
187
- # Submit non-empty prompts to their respective GPUs
188
- for gpu_id, prompt in enumerate(prompts):
189
  if prompt and prompt.strip():
190
  full_prompt = build_prompt(prompt.strip(), mode)
191
- request_id = f"req_{gpu_id}_{uuid.uuid4().hex[:8]}"
192
- manager.submit_request(gpu_id, full_prompt, request_id)
193
- active_gpus.add(gpu_id)
194
- results[gpu_id] = "" # Initialize result
195
- else:
196
- results[gpu_id] = "" # Empty prompt = empty result
197
-
198
- if not active_gpus:
199
- # No prompts to process
200
- yield tuple(results)
201
  return
202
 
203
- # Stream results from all active workers
204
- while active_gpus:
205
- try:
206
- gpu_id, chunk, is_done = manager.response_queue.get(timeout=0.1)
 
207
 
208
- if is_done:
209
- active_gpus.discard(gpu_id)
210
- else:
211
- results[gpu_id] += chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- # Yield current state of all results
214
- yield tuple(results)
215
 
216
- except:
217
- # Timeout - yield current state and continue
 
 
 
 
 
 
 
 
 
 
 
 
218
  yield tuple(results)
219
 
 
 
 
 
 
 
 
 
220
 
221
- def respond_single(gpu_id: int, prompt: str, mode: list[str]) -> Generator:
222
- """Process a single prompt on a specific GPU."""
223
- manager = get_manager()
224
 
 
 
225
  if not prompt or not prompt.strip():
226
  yield ""
227
  return
228
 
 
229
  full_prompt = build_prompt(prompt.strip(), mode)
230
- request_id = f"single_{gpu_id}_{uuid.uuid4().hex[:8]}"
231
- manager.submit_request(gpu_id, full_prompt, request_id)
232
-
233
- result = ""
234
- while True:
235
- try:
236
- recv_gpu_id, chunk, is_done = manager.response_queue.get(timeout=0.1)
237
-
238
- # Only process responses for our request
239
- if recv_gpu_id == gpu_id:
240
- if is_done:
241
- break
242
- result += chunk
243
- yield result
244
-
245
- except:
246
  yield result
247
 
 
 
 
 
 
 
 
 
248
 
249
  # Build the Gradio interface
250
  with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
251
  gr.Markdown("# 💊 JPharmatron - Parallel Request Processing")
252
  gr.Markdown(
253
- "Enter up to 8 prompts and process them simultaneously on dedicated GPUs. "
254
- "Each response streams independently."
255
  )
256
 
257
  # Mode selection
@@ -355,10 +324,10 @@ with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
355
  with gr.Row():
356
  for i in range(8):
357
  btn = gr.Button(f"Run #{i+1}", size="sm")
358
- # Create closure to capture gpu_id
359
- def make_single_handler(gpu_id):
360
  def handler(prompt, mode):
361
- yield from respond_single(gpu_id, prompt, mode)
362
  return handler
363
  btn.click(
364
  fn=make_single_handler(i),
 
3
  import re
4
  import uuid
5
  import threading
 
 
6
  from typing import Generator
7
 
8
  import gradio as gr
9
 
10
# GPU detection: prefer the torch runtime's count, then fall back to
# inspecting CUDA_VISIBLE_DEVICES, and finally assume a single GPU.
def get_num_gpus() -> int:
    """Detect the number of available GPUs."""
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        return torch.cuda.device_count()

    # No usable CUDA runtime via torch: infer the count from the
    # CUDA_VISIBLE_DEVICES environment variable when it is set.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if visible:
        return len(visible.split(","))

    return 1  # Default to 1

NUM_GPUS = get_num_gpus()
MAX_PARALLEL_REQUESTS = 8  # UI supports up to 8 parallel inputs

print(f"Detected {NUM_GPUS} GPU(s)")
31
 
32
  # Stop strings for generation
33
  STOP_STRINGS = [
 
41
  re.MULTILINE
42
  )
43
 
44
# Global vLLM engine (single instance, handles concurrent requests internally)
_engine = None
_engine_lock = threading.Lock()
_loop = None


def get_engine():
    """Get or create the global vLLM engine and its event loop.

    Returns:
        tuple: ``(AsyncLLMEngine, event_loop)`` shared by all request
        handlers. Built lazily, exactly once, via double-checked locking.
    """
    global _engine, _loop
    if _engine is None:
        with _engine_lock:
            if _engine is None:
                # Imported lazily: `asyncio` is not in the module's import
                # block, and vllm is only needed once serving starts.
                import asyncio
                from vllm import AsyncLLMEngine, AsyncEngineArgs

                engine_args = AsyncEngineArgs(
                    model="EQUES/JPharmatron-7B-chat",
                    enforce_eager=True,
                    gpu_memory_utilization=0.85,
                    tensor_parallel_size=NUM_GPUS,  # Use all available GPUs
                )
                engine = AsyncLLMEngine.from_engine_args(engine_args)

                # Dedicated loop used by the synchronous Gradio handlers to
                # drive async generation.
                # NOTE(review): run_until_complete on this shared loop is not
                # safe from multiple threads — confirm Gradio serializes these
                # handlers, or move the loop to a background thread and use
                # run_coroutine_threadsafe.
                _loop = asyncio.new_event_loop()
                asyncio.set_event_loop(_loop)

                # Publish the engine LAST: a racing reader that passes the
                # unlocked `_engine is None` check must never observe a
                # non-None engine paired with a still-None loop.
                _engine = engine
    return _engine, _loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  def build_prompt(user_input: str, mode: list[str]) -> str:
 
82
  return base_prompt
83
 
84
 
85
async def astream_generate(engine, prompt: str, request_id: str):
    """Stream newly generated text from vLLM as incremental string chunks."""
    from vllm import SamplingParams

    sampling = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        repetition_penalty=1.2,
        stop=STOP_STRINGS,
    )

    emitted = ""
    async for output in engine.generate(prompt, sampling, request_id=request_id):
        text = output.outputs[0].text

        # A stop pattern leaked past vLLM's own stop handling: yield only the
        # portion up to the match, then end the stream.
        match = STOP_RE.search(text)
        if match is not None:
            tail = text[len(emitted):match.start()]
            if tail:
                yield tail
            break

        delta = text[len(emitted):]
        emitted = text
        if delta:
            yield delta
113
+
114
+
115
async def run_parallel_async(prompts: list[str], mode: list[str]):
    """
    Run multiple prompts in parallel using vLLM's continuous batching.

    Yields the list of per-slot accumulated texts after every update so the
    caller can stream all outputs together. Empty prompts leave their slot
    as an empty string.
    """
    import asyncio  # not in the module's import block; keep this local

    engine, _ = get_engine()
    results = [""] * MAX_PARALLEL_REQUESTS

    async def consume(slot_id: int, full_prompt: str, request_id: str) -> None:
        # Drain one stream into its slot; all consumers run concurrently so
        # vLLM can batch the requests internally.
        async for chunk in astream_generate(engine, full_prompt, request_id):
            results[slot_id] += chunk

    # One consumer task per non-empty prompt.
    #
    # FIX: the previous version polled each generator with
    # asyncio.wait_for(gen.__anext__(), timeout=0.05). On timeout, wait_for
    # CANCELS __anext__, which injects CancelledError into the async
    # generator and aborts the in-flight vLLM request, losing its output.
    # Driving each stream in its own task avoids cancellation entirely.
    tasks = []
    for i, prompt in enumerate(prompts):
        if prompt and prompt.strip():
            full_prompt = build_prompt(prompt.strip(), mode)
            request_id = f"req_{i}_{uuid.uuid4().hex[:8]}"
            tasks.append(asyncio.ensure_future(consume(i, full_prompt, request_id)))

    if not tasks:
        yield results
        return

    # Snapshot loop: periodically yield the current accumulated texts until
    # every consumer has finished.
    while not all(task.done() for task in tasks):
        await asyncio.sleep(0.05)
        yield results

    # Surface the first generation failure (if any) after all streams end.
    for task in tasks:
        exc = task.exception()
        if exc is not None:
            raise exc

    yield results
 
166
 
167
+
168
def respond_parallel(
    prompt0: str, prompt1: str, prompt2: str, prompt3: str,
    prompt4: str, prompt5: str, prompt6: str, prompt7: str,
    mode: list[str]
) -> Generator:
    """
    Process up to 8 prompts in parallel using vLLM's continuous batching.

    Bridges the async snapshot stream onto the shared event loop and yields
    one tuple of per-slot texts per update.
    """
    slot_prompts = [
        prompt0, prompt1, prompt2, prompt3,
        prompt4, prompt5, prompt6, prompt7,
    ]
    _, loop = get_engine()

    async def updates():
        async for snapshot in run_parallel_async(slot_prompts, mode):
            yield tuple(snapshot)

    # Step the async generator to completion, forwarding each snapshot.
    stream = updates()
    while True:
        try:
            yield loop.run_until_complete(stream.__anext__())
        except StopAsyncIteration:
            return
191
 
 
 
 
192
 
193
def respond_single(slot_id: int, prompt: str, mode: list[str]) -> Generator:
    """Process a single prompt, streaming the accumulated response text."""
    if not prompt or not prompt.strip():
        yield ""
        return

    engine, loop = get_engine()
    full_prompt = build_prompt(prompt.strip(), mode)
    request_id = f"single_{slot_id}_{uuid.uuid4().hex[:8]}"

    async def updates():
        accumulated = ""
        async for chunk in astream_generate(engine, full_prompt, request_id):
            accumulated += chunk
            yield accumulated

    # Drive the async stream on the shared loop, yielding each growth step.
    stream = updates()
    while True:
        try:
            yield loop.run_until_complete(stream.__anext__())
        except StopAsyncIteration:
            return
216
+
217
 
218
  # Build the Gradio interface
219
  with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
220
  gr.Markdown("# 💊 JPharmatron - Parallel Request Processing")
221
  gr.Markdown(
222
+ f"Enter up to {MAX_PARALLEL_REQUESTS} prompts and process them simultaneously. "
223
+ f"Using {NUM_GPUS} GPU(s) with vLLM continuous batching."
224
  )
225
 
226
  # Mode selection
 
324
  with gr.Row():
325
  for i in range(8):
326
  btn = gr.Button(f"Run #{i+1}", size="sm")
327
+ # Create closure to capture slot_id
328
+ def make_single_handler(slot_id):
329
  def handler(prompt, mode):
330
+ yield from respond_single(slot_id, prompt, mode)
331
  return handler
332
  btn.click(
333
  fn=make_single_handler(i),