# shinnosukeono — Fix: Use single vLLM engine with continuous batching (commit 4161fbd)
import asyncio
import os
import re
import uuid
import threading
from typing import Generator
import gradio as gr
# Detect available GPUs
def get_num_gpus() -> int:
    """Detect the number of available GPUs.

    Prefers torch's CUDA device count; if torch is missing (or reports no
    CUDA), falls back to counting entries in CUDA_VISIBLE_DEVICES, and
    finally defaults to 1.
    """
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        return torch.cuda.device_count()
    # Fallback: count the comma-separated device ids, if any were pinned.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    return len(visible.split(",")) if visible else 1
# Resolved once at import time; also handed to vLLM as tensor_parallel_size.
NUM_GPUS = get_num_gpus()
MAX_PARALLEL_REQUESTS = 8  # UI supports up to 8 parallel inputs
print(f"Detected {NUM_GPUS} GPU(s)")
# Stop strings for generation: role headers that mark the start of the next
# conversational turn. (The original list contained "\nHuman:" twice; the
# duplicate has been removed.)
STOP_STRINGS = [
    "\nUser:", "\nユーザ:", "\nユーザー:",
    "\nAssistant:", "\nアシスタント:",
    "\nHuman:",
]
# Regex for post-processing cleanup: strips any role header (and everything
# after it) that leaked past the engine-side stop strings. "Human" is now
# included so the cleanup covers the same roles as STOP_STRINGS; both
# fullwidth (:) and ASCII (:) colons are accepted.
STOP_RE = re.compile(
    r"(?:^|\n)(?:User|ユーザ|ユーザー|Assistant|アシスタント|Human)[::].*",
    re.MULTILINE
)
# Global vLLM engine (single instance, handles concurrent requests internally)
_engine = None
_engine_lock = threading.Lock()
_loop = None
def get_engine():
    """Get or create the global vLLM engine and its dedicated event loop.

    Returns:
        tuple: (AsyncLLMEngine, asyncio event loop) — the same pair for every
        caller after first initialization.
    """
    global _engine, _loop
    # Double-checked locking: cheap unlocked probe first, then re-check under
    # the lock so only one thread pays the (expensive) engine construction.
    if _engine is None:
        with _engine_lock:
            if _engine is None:
                # Imported lazily so the module can load without vLLM present.
                from vllm import AsyncLLMEngine, AsyncEngineArgs
                engine_args = AsyncEngineArgs(
                    model="EQUES/JPharmatron-7B-chat",
                    enforce_eager=True,
                    gpu_memory_utilization=0.85,
                    tensor_parallel_size=NUM_GPUS,  # Use all available GPUs
                )
                _engine = AsyncLLMEngine.from_engine_args(engine_args)
                # NOTE(review): this loop is later driven via run_until_complete
                # from Gradio handler threads; set_event_loop only binds it to
                # the thread that happens to build the engine. Concurrent
                # run_until_complete calls on one loop from different threads
                # are unsafe — confirm Gradio serializes these handlers.
                _loop = asyncio.new_event_loop()
                asyncio.set_event_loop(_loop)
    return _engine, _loop
def build_prompt(user_input: str, mode: list[str]) -> str:
    """Build the prompt with system instructions and mode settings.

    Starts from the base pharma-expert system instruction, appends one extra
    instruction line per selected mode (in a fixed order), then closes with
    the user turn and the assistant header.
    """
    # Mode label -> instruction line appended when that mode is checked.
    mode_instructions = {
        "製薬の専門家": "あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献は常に提示してください。\n",
        "国際基準に準拠": "回答に際して、国際基準に準拠してください。\n",
        "具体的な手順": "回答には具体的な作業手順を含めてください。\n",
    }
    parts = ["あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献を回答の末尾に常に提示してください。\n"]
    parts.extend(text for label, text in mode_instructions.items() if label in mode)
    parts.append(f"ユーザー: {user_input}\nアシスタント:")
    return "".join(parts)
async def astream_generate(engine, prompt: str, request_id: str):
    """Async generator that streams incremental text chunks from vLLM.

    Yields only the delta since the previous yield. If a role header matching
    STOP_RE leaks past the engine-side stop strings, the text before it is
    emitted and streaming stops.
    """
    from vllm import SamplingParams
    params = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        repetition_penalty=1.2,
        stop=STOP_STRINGS,
    )
    emitted = ""
    async for result in engine.generate(prompt, params, request_id=request_id):
        text = result.outputs[0].text
        leak = STOP_RE.search(text)
        if leak is not None:
            # A stop pattern slipped through: flush everything up to it, then
            # end the stream.
            tail = text[len(emitted):leak.start()]
            if tail:
                yield tail
            break
        delta = text[len(emitted):]
        emitted = text
        if delta:
            yield delta
async def run_parallel_async(prompts: list[str], mode: list[str]):
    """
    Run multiple prompts in parallel using vLLM's continuous batching.

    Builds a full prompt for every non-empty entry in ``prompts``, streams all
    of them concurrently, and yields the whole results list (one accumulated
    string per UI slot) every polling round as tokens arrive.

    Args:
        prompts: Up to MAX_PARALLEL_REQUESTS raw user inputs; blank entries
            are skipped and their slot stays "".
        mode: Mode checkbox values forwarded to build_prompt.

    Yields:
        list[str]: Accumulated response text per slot (len MAX_PARALLEL_REQUESTS).
    """
    engine, _ = get_engine()
    results = [""] * MAX_PARALLEL_REQUESTS
    # One streaming async generator per non-empty prompt slot.
    generators = {}
    for i, prompt in enumerate(prompts):
        if prompt and prompt.strip():
            request_id = f"req_{i}_{uuid.uuid4().hex[:8]}"
            generators[i] = astream_generate(
                engine, build_prompt(prompt.strip(), mode), request_id
            )
    if not generators:
        yield results
        return
    # Keep ONE pending __anext__() task per slot across polling rounds.
    # BUGFIX: the previous implementation used
    # asyncio.wait_for(gen.__anext__(), timeout=0.05), but wait_for CANCELS
    # the awaitable on timeout — that can drop an in-flight chunk and leaves
    # the async generator cancelled/"already running", breaking later polls.
    # A plain task simply stays pending until it completes.
    tasks = {
        slot_id: asyncio.ensure_future(gen.__anext__())
        for slot_id, gen in generators.items()
    }
    while tasks:
        # Short timeout keeps UI updates flowing even while slots are quiet.
        done, _pending = await asyncio.wait(
            tasks.values(),
            timeout=0.05,
            return_when=asyncio.FIRST_COMPLETED,
        )
        for slot_id in list(tasks):
            task = tasks[slot_id]
            if task not in done:
                continue  # still streaming; poll again next round
            try:
                results[slot_id] += task.result()
            except StopAsyncIteration:
                del tasks[slot_id]  # this slot finished generating
            else:
                # Immediately request the next chunk for this slot.
                tasks[slot_id] = asyncio.ensure_future(
                    generators[slot_id].__anext__()
                )
        yield results
def respond_parallel(
    prompt0: str, prompt1: str, prompt2: str, prompt3: str,
    prompt4: str, prompt5: str, prompt6: str, prompt7: str,
    mode: list[str]
) -> Generator:
    """
    Process up to 8 prompts in parallel using vLLM's continuous batching.

    Bridges the async streaming generator onto the engine's event loop and
    yields one 8-tuple of accumulated responses per update.
    """
    all_prompts = [prompt0, prompt1, prompt2, prompt3,
                   prompt4, prompt5, prompt6, prompt7]
    _, loop = get_engine()

    async def drive():
        async for snapshot in run_parallel_async(all_prompts, mode):
            yield tuple(snapshot)

    # Pump the async generator synchronously so Gradio can stream updates.
    stream = drive()
    while True:
        try:
            yield loop.run_until_complete(stream.__anext__())
        except StopAsyncIteration:
            return
def respond_single(slot_id: int, prompt: str, mode: list[str]) -> Generator:
    """Process a single prompt, yielding the accumulated response text."""
    text = (prompt or "").strip()
    if not text:
        yield ""
        return
    engine, loop = get_engine()
    full_prompt = build_prompt(text, mode)
    request_id = f"single_{slot_id}_{uuid.uuid4().hex[:8]}"

    async def drive():
        accumulated = ""
        async for piece in astream_generate(engine, full_prompt, request_id):
            accumulated += piece
            yield accumulated

    # Pump the async generator synchronously so Gradio can stream updates.
    stream = drive()
    while True:
        try:
            yield loop.run_until_complete(stream.__anext__())
        except StopAsyncIteration:
            return
# Build the Gradio interface (module-level so `demo` exists at import time).
with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
    gr.Markdown("# 💊 JPharmatron - Parallel Request Processing")
    gr.Markdown(
        f"Enter up to {MAX_PARALLEL_REQUESTS} prompts and process them simultaneously. "
        f"Using {NUM_GPUS} GPU(s) with vLLM continuous batching."
    )
    # Mode selection: each checked mode appends one extra instruction line to
    # the system prompt (see build_prompt).
    mode = gr.CheckboxGroup(
        label="モード (Mode)",
        choices=["製薬の専門家", "国際基準に準拠", "具体的な手順"],
        value=[],
    )
    # Preset example prompts (pharma-domain questions, in Japanese).
    gr.Markdown("### 🔧 Presets (click to copy)")
    preset_list = [
        "グレープフルーツと薬を一緒に飲んじゃだめなんですか?",
        "新薬の臨床試験(Phase I〜III)の概要を、具体例つきで簡単に教えて。",
        "ジェネリック医薬品が承認されるまでの流れを、タイムラインで解説して。",
        "抗生物質の作用機序と耐性菌について説明してください。",
        "COVID-19ワクチンの開発プロセスを教えてください。",
        "薬物相互作用の主なメカニズムを教えてください。",
        "バイオシミラーと先発医薬品の違いは何ですか?",
        "製薬企業のGMP(Good Manufacturing Practice)について説明してください。",
    ]
    # Input section: 8 prompt boxes laid out as two columns of four; order in
    # input_boxes (0-7) matches the slot indices used by the handlers.
    gr.Markdown("### 📝 Input Prompts")
    with gr.Row():
        with gr.Column():
            input_boxes = []
            for i in range(4):
                tb = gr.Textbox(
                    label=f"Prompt {i+1}",
                    placeholder=f"Enter prompt {i+1}...",
                    lines=3
                )
                input_boxes.append(tb)
        with gr.Column():
            for i in range(4, 8):
                tb = gr.Textbox(
                    label=f"Prompt {i+1}",
                    placeholder=f"Enter prompt {i+1}...",
                    lines=3
                )
                input_boxes.append(tb)
    # Examples that fill multiple boxes at once.
    # NOTE(review): both example rows target only input_boxes[:4]; the label
    # says "first 4", so presumably intentional — confirm.
    gr.Examples(
        examples=[preset_list[:4], preset_list[4:]],
        inputs=input_boxes[:4],
        label="Fill first 4 prompts with presets"
    )
    # Control buttons
    with gr.Row():
        run_all_btn = gr.Button("🚀 Run All in Parallel", variant="primary", scale=2)
        clear_inputs_btn = gr.Button("🗑️ Clear Inputs", scale=1)
        clear_outputs_btn = gr.Button("🗑️ Clear Outputs", scale=1)
    # Output section: 8 read-only streaming boxes mirroring the input layout.
    gr.Markdown("### 📤 Streaming Outputs")
    with gr.Row():
        with gr.Column():
            output_boxes = []
            for i in range(4):
                tb = gr.Textbox(
                    label=f"Response {i+1}",
                    lines=10,
                    interactive=False,
                    show_copy_button=True
                )
                output_boxes.append(tb)
        with gr.Column():
            for i in range(4, 8):
                tb = gr.Textbox(
                    label=f"Response {i+1}",
                    lines=10,
                    interactive=False,
                    show_copy_button=True
                )
                output_boxes.append(tb)
    # Wire up the "Run All" button: respond_parallel is a generator, so all 8
    # outputs stream as generation progresses.
    run_all_btn.click(
        fn=respond_parallel,
        inputs=input_boxes + [mode],
        outputs=output_boxes
    )
    # Clear buttons: reset all 8 input (resp. output) boxes to empty strings.
    clear_inputs_btn.click(
        fn=lambda: tuple([""] * 8),
        inputs=None,
        outputs=input_boxes
    )
    clear_outputs_btn.click(
        fn=lambda: tuple([""] * 8),
        inputs=None,
        outputs=output_boxes
    )
    # Individual run buttons for each slot
    gr.Markdown("### 🎯 Run Individual Prompts")
    with gr.Row():
        for i in range(8):
            btn = gr.Button(f"Run #{i+1}", size="sm")
            # Factory closure captures slot_id by value, avoiding the classic
            # late-binding-loop-variable pitfall (all handlers seeing i == 7).
            def make_single_handler(slot_id):
                def handler(prompt, mode):
                    yield from respond_single(slot_id, prompt, mode)
                return handler
            btn.click(
                fn=make_single_handler(i),
                inputs=[input_boxes[i], mode],
                outputs=[output_boxes[i]]
            )
def main():
    """Entry point for the application."""
    # queue() is required for generator (streaming) handlers in Gradio.
    demo.queue()
    demo.launch()
if __name__ == "__main__":
    main()