"""JPharmatron parallel chat demo.

Serves the EQUES/JPharmatron-7B-chat model through a Gradio UI that accepts
up to MAX_PARALLEL_REQUESTS prompts and streams all of them concurrently via
vLLM's continuous batching.
"""

import asyncio
import os
import re
import threading
import uuid
from typing import Generator

import gradio as gr


def get_num_gpus() -> int:
    """Detect the number of available GPUs.

    Prefers torch's CUDA device count; falls back to parsing
    CUDA_VISIBLE_DEVICES, and defaults to 1 when nothing is detectable.
    """
    try:
        import torch
        if torch.cuda.is_available():
            return torch.cuda.device_count()
    except ImportError:
        pass
    # Fallback: infer the count from CUDA_VISIBLE_DEVICES.
    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if cuda_devices:
        return len(cuda_devices.split(","))
    return 1  # Default to 1


NUM_GPUS = get_num_gpus()
MAX_PARALLEL_REQUESTS = 8  # UI supports up to 8 parallel inputs

print(f"Detected {NUM_GPUS} GPU(s)")

# Stop strings for generation (duplicate "\nHuman:" entry removed).
STOP_STRINGS = [
    "\nUser:",
    "\nユーザ:",
    "\nユーザー:",
    "\nAssistant:",
    "\nアシスタント:",
    "\nHuman:",
]

# Regex for post-processing cleanup: catches role markers (half- or
# full-width colon) that leak past vLLM's stop-string handling.
STOP_RE = re.compile(
    r"(?:^|\n)(?:User|ユーザ|ユーザー|Assistant|アシスタント)[::].*",
    re.MULTILINE,
)

# Global vLLM engine (single instance, handles concurrent requests internally)
_engine = None
_engine_lock = threading.Lock()
_loop = None


def get_engine():
    """Get or create the global vLLM engine and its event loop.

    Returns:
        (engine, loop): the shared AsyncLLMEngine and the asyncio event loop
        it was created against.

    Uses double-checked locking. Both objects are built into locals first and
    ``_loop`` is published *before* ``_engine`` so that a racing reader that
    observes a non-None ``_engine`` can never see ``_loop`` as None.
    """
    global _engine, _loop
    if _engine is None:
        with _engine_lock:
            if _engine is None:
                from vllm import AsyncLLMEngine, AsyncEngineArgs

                engine_args = AsyncEngineArgs(
                    model="EQUES/JPharmatron-7B-chat",
                    enforce_eager=True,
                    gpu_memory_utilization=0.85,
                    tensor_parallel_size=NUM_GPUS,  # Use all available GPUs
                )
                # Create and install the loop before the engine so any tasks
                # the engine schedules during construction land on it.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                engine = AsyncLLMEngine.from_engine_args(engine_args)
                _loop = loop
                _engine = engine  # publish last (see docstring)
    return _engine, _loop


def build_prompt(user_input: str, mode: list[str]) -> str:
    """Build the prompt with system instructions and mode settings.

    Args:
        user_input: the raw user question.
        mode: selected checkbox labels; each recognized label appends an
            extra instruction line.

    Returns:
        The full completion-style prompt ending with the assistant marker.
    """
    base_prompt = "あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献を回答の末尾に常に提示してください。\n"
    # NOTE(review): this mode repeats the base system line almost verbatim —
    # presumably intentional emphasis; confirm with the prompt author.
    if "製薬の専門家" in mode:
        base_prompt += "あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献は常に提示してください。\n"
    if "国際基準に準拠" in mode:
        base_prompt += "回答に際して、国際基準に準拠してください。\n"
    if "具体的な手順" in mode:
        base_prompt += "回答には具体的な作業手順を含めてください。\n"
    base_prompt += f"ユーザー: {user_input}\nアシスタント:"
    return base_prompt


async def astream_generate(engine, prompt: str, request_id: str):
    """Async generator that streams incremental text chunks from vLLM.

    Yields only the delta since the previous event. Stops early (after
    yielding the clean prefix) if a role marker leaks through vLLM's
    stop-string handling.
    """
    from vllm import SamplingParams

    params = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        repetition_penalty=1.2,
        stop=STOP_STRINGS,
    )
    previous_text = ""
    async for out in engine.generate(prompt, params, request_id=request_id):
        full_text = out.outputs[0].text
        # Check for stop patterns that might have leaked through.
        m = STOP_RE.search(full_text)
        if m:
            cut = m.start()
            chunk = full_text[len(previous_text):cut]
            if chunk:
                yield chunk
            break
        chunk = full_text[len(previous_text):]
        previous_text = full_text
        if chunk:
            yield chunk


async def run_parallel_async(prompts: list[str], mode: list[str]):
    """Run multiple prompts in parallel using vLLM's continuous batching.

    Yields the shared ``results`` list (one accumulated string per slot)
    every time at least one stream produces a new chunk.

    Implementation note: the original version polled each generator with
    ``asyncio.wait_for(gen.__anext__(), timeout=0.05)``.  On timeout,
    ``wait_for`` *cancels* the awaited ``__anext__`` — that can drop an
    in-flight chunk and injects CancelledError into the async generator,
    corrupting its state.  Instead we keep exactly one pending ``__anext__``
    task per active slot and multiplex them with ``asyncio.wait``.
    """
    engine, _ = get_engine()

    results = [""] * MAX_PARALLEL_REQUESTS
    generators = {}
    for i, prompt in enumerate(prompts):
        if prompt and prompt.strip():
            full_prompt = build_prompt(prompt.strip(), mode)
            request_id = f"req_{i}_{uuid.uuid4().hex[:8]}"
            generators[i] = astream_generate(engine, full_prompt, request_id)

    if not generators:
        yield results
        return

    # One in-flight __anext__ task per active slot, mapped back to its slot.
    pending: dict[asyncio.Future, int] = {
        asyncio.ensure_future(gen.__anext__()): slot_id
        for slot_id, gen in generators.items()
    }
    while pending:
        done, _ = await asyncio.wait(
            pending.keys(), return_when=asyncio.FIRST_COMPLETED
        )
        for task in done:
            slot_id = pending.pop(task)
            try:
                results[slot_id] += task.result()
            except StopAsyncIteration:
                continue  # this slot's stream is finished
            # Re-arm the slot for its next chunk.
            nxt = asyncio.ensure_future(generators[slot_id].__anext__())
            pending[nxt] = slot_id
        yield results


def respond_parallel(
    prompt0: str, prompt1: str, prompt2: str, prompt3: str,
    prompt4: str, prompt5: str, prompt6: str, prompt7: str,
    mode: list[str],
) -> Generator:
    """Process up to 8 prompts in parallel using vLLM's continuous batching.

    Gradio streaming handler: yields an 8-tuple of accumulated responses on
    every update.  NOTE(review): drives the shared event loop with
    ``run_until_complete`` — safe only while Gradio serializes these calls
    (queue concurrency of 1 per event); confirm before raising concurrency.
    """
    prompts = [prompt0, prompt1, prompt2, prompt3,
               prompt4, prompt5, prompt6, prompt7]
    _, loop = get_engine()

    async def run():
        async for results in run_parallel_async(prompts, mode):
            yield tuple(results)

    # Drain the async generator step-by-step on the engine's event loop.
    agen = run()
    try:
        while True:
            results = loop.run_until_complete(agen.__anext__())
            yield results
    except StopAsyncIteration:
        return


def respond_single(slot_id: int, prompt: str, mode: list[str]) -> Generator:
    """Process a single prompt, yielding the accumulated response text."""
    if not prompt or not prompt.strip():
        yield ""
        return

    engine, loop = get_engine()
    full_prompt = build_prompt(prompt.strip(), mode)
    request_id = f"single_{slot_id}_{uuid.uuid4().hex[:8]}"

    async def run():
        result = ""
        async for chunk in astream_generate(engine, full_prompt, request_id):
            result += chunk
            yield result

    agen = run()
    try:
        while True:
            result = loop.run_until_complete(agen.__anext__())
            yield result
    except StopAsyncIteration:
        return


# Build the Gradio interface
with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
    gr.Markdown("# 💊 JPharmatron - Parallel Request Processing")
    gr.Markdown(
        f"Enter up to {MAX_PARALLEL_REQUESTS} prompts and process them simultaneously. "
        f"Using {NUM_GPUS} GPU(s) with vLLM continuous batching."
    )

    # Mode selection
    mode = gr.CheckboxGroup(
        label="モード (Mode)",
        choices=["製薬の専門家", "国際基準に準拠", "具体的な手順"],
        value=[],
    )

    # Preset examples
    gr.Markdown("### 🔧 Presets (click to copy)")
    preset_list = [
        "グレープフルーツと薬を一緒に飲んじゃだめなんですか?",
        "新薬の臨床試験(Phase I〜III)の概要を、具体例つきで簡単に教えて。",
        "ジェネリック医薬品が承認されるまでの流れを、タイムラインで解説して。",
        "抗生物質の作用機序と耐性菌について説明してください。",
        "COVID-19ワクチンの開発プロセスを教えてください。",
        "薬物相互作用の主なメカニズムを教えてください。",
        "バイオシミラーと先発医薬品の違いは何ですか?",
        "製薬企業のGMP(Good Manufacturing Practice)について説明してください。",
    ]

    # Input section: two columns of four textboxes each.
    gr.Markdown("### 📝 Input Prompts")
    with gr.Row():
        with gr.Column():
            input_boxes = []
            for i in range(4):
                tb = gr.Textbox(
                    label=f"Prompt {i+1}",
                    placeholder=f"Enter prompt {i+1}...",
                    lines=3,
                )
                input_boxes.append(tb)
        with gr.Column():
            for i in range(4, 8):
                tb = gr.Textbox(
                    label=f"Prompt {i+1}",
                    placeholder=f"Enter prompt {i+1}...",
                    lines=3,
                )
                input_boxes.append(tb)

    # Examples that fill multiple boxes at once.
    gr.Examples(
        examples=[preset_list[:4], preset_list[4:]],
        inputs=input_boxes[:4],
        label="Fill first 4 prompts with presets",
    )

    # Control buttons
    with gr.Row():
        run_all_btn = gr.Button("🚀 Run All in Parallel", variant="primary", scale=2)
        clear_inputs_btn = gr.Button("🗑️ Clear Inputs", scale=1)
        clear_outputs_btn = gr.Button("🗑️ Clear Outputs", scale=1)

    # Output section: mirrors the input layout.
    gr.Markdown("### 📤 Streaming Outputs")
    with gr.Row():
        with gr.Column():
            output_boxes = []
            for i in range(4):
                tb = gr.Textbox(
                    label=f"Response {i+1}",
                    lines=10,
                    interactive=False,
                    show_copy_button=True,
                )
                output_boxes.append(tb)
        with gr.Column():
            for i in range(4, 8):
                tb = gr.Textbox(
                    label=f"Response {i+1}",
                    lines=10,
                    interactive=False,
                    show_copy_button=True,
                )
                output_boxes.append(tb)

    # Wire up the "Run All" button
    run_all_btn.click(
        fn=respond_parallel,
        inputs=input_boxes + [mode],
        outputs=output_boxes,
    )

    # Clear buttons
    clear_inputs_btn.click(
        fn=lambda: tuple([""] * 8),
        inputs=None,
        outputs=input_boxes,
    )
    clear_outputs_btn.click(
        fn=lambda: tuple([""] * 8),
        inputs=None,
        outputs=output_boxes,
    )

    # Individual run buttons for each slot
    gr.Markdown("### 🎯 Run Individual Prompts")
    with gr.Row():
        for i in range(8):
            btn = gr.Button(f"Run #{i+1}", size="sm")

            # Factory binds slot_id by value, avoiding the late-binding
            # closure pitfall of referencing the loop variable directly.
            def make_single_handler(slot_id):
                def handler(prompt, mode):
                    yield from respond_single(slot_id, prompt, mode)
                return handler

            btn.click(
                fn=make_single_handler(i),
                inputs=[input_boxes[i], mode],
                outputs=[output_boxes[i]],
            )


def main():
    """Entry point for the application."""
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()