# NOTE: "Spaces: Paused / Paused" below was Hugging Face Space status text
# captured when this file was exported; it is not part of the program.
import asyncio
import os
import re
import threading
import uuid
from typing import Generator

import gradio as gr
def get_num_gpus() -> int:
    """Return the number of GPUs visible to this process.

    Prefers torch's CUDA device count when torch is importable and CUDA is
    usable; otherwise falls back to counting entries in CUDA_VISIBLE_DEVICES,
    and finally assumes a single GPU.
    """
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        return torch.cuda.device_count()
    # Fallback: a non-empty CUDA_VISIBLE_DEVICES lists one device per comma.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    return len(visible.split(",")) if visible else 1
NUM_GPUS = get_num_gpus()
# The UI exposes exactly eight prompt slots, so parallelism is capped there.
MAX_PARALLEL_REQUESTS = 8
print(f"Detected {NUM_GPUS} GPU(s)")
# Stop strings for generation: role markers that indicate the model has
# started a new conversational turn, at which point decoding should halt.
# NOTE: the original list contained "\nHuman:" twice; the duplicate entry
# was removed (it had no effect beyond redundancy).
STOP_STRINGS = [
    "\nUser:", "\nユーザ:", "\nユーザー:",
    "\nAssistant:", "\nアシスタント:",
    "\nHuman:",
]
# Compiled once at import time: matches a role marker ("User:", "Assistant:",
# or a Japanese equivalent, followed by an ASCII or full-width colon) at the
# start of a line, together with the rest of that line. Used to trim role
# text that slips past the engine-level stop strings.
STOP_RE = re.compile(
    r"(?:^|\n)(?:User|ユーザ|ユーザー|Assistant|アシスタント)[::].*",
    re.MULTILINE,
)
# Process-wide singletons: one vLLM engine (it batches concurrent requests
# internally) plus the event loop it runs on. Both are created lazily on
# first use; the lock serializes that initialization across threads.
_engine = None
_loop = None
_engine_lock = threading.Lock()
def get_engine():
    """Lazily create and return the global vLLM engine and its event loop.

    Thread-safe via double-checked locking on ``_engine_lock``.

    Returns:
        tuple: ``(engine, loop)`` where ``engine`` is the shared
        ``AsyncLLMEngine`` and ``loop`` is its dedicated asyncio event loop.
    """
    global _engine, _loop
    if _engine is None:  # unlocked fast path; re-checked under the lock
        with _engine_lock:
            if _engine is None:
                from vllm import AsyncLLMEngine, AsyncEngineArgs
                engine_args = AsyncEngineArgs(
                    model="EQUES/JPharmatron-7B-chat",
                    enforce_eager=True,
                    gpu_memory_utilization=0.85,
                    tensor_parallel_size=NUM_GPUS,  # use all available GPUs
                )
                # Create the loop BEFORE publishing the engine: ``_engine``
                # is the sentinel for the unlocked fast path above, so every
                # field it guards must be fully initialized before it is set.
                # (Previously ``_engine`` was assigned first, so a racing
                # thread could observe ``(_engine, None)``.)
                _loop = asyncio.new_event_loop()
                # NOTE(review): set_event_loop only affects the thread that
                # happens to run this initialization — confirm callers always
                # use the returned loop explicitly.
                asyncio.set_event_loop(_loop)
                _engine = AsyncLLMEngine.from_engine_args(engine_args)
    return _engine, _loop
def build_prompt(user_input: str, mode: list[str]) -> str:
    """Compose the full generation prompt.

    Starts from the fixed system instruction, appends one extra directive
    per selected mode checkbox, then closes with the user turn and the
    assistant cue the model is expected to continue from.
    """
    parts = ["あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献を回答の末尾に常に提示してください。\n"]
    if "製薬の専門家" in mode:
        parts.append("あなたは製薬に関する専門家です。製薬に関するユーザーの質問に親切に回答してください。参照した文献は常に提示してください。\n")
    if "国際基準に準拠" in mode:
        parts.append("回答に際して、国際基準に準拠してください。\n")
    if "具体的な手順" in mode:
        parts.append("回答には具体的な作業手順を含めてください。\n")
    parts.append(f"ユーザー: {user_input}\nアシスタント:")
    return "".join(parts)
async def astream_generate(engine, prompt: str, request_id: str):
    """Stream newly generated text for one vLLM request.

    Yields only the delta since the previous engine update (vLLM reports the
    full accumulated text each step). If a role marker that leaked past the
    engine-level stop strings is detected via STOP_RE, yields the text up to
    that marker and stops.
    """
    from vllm import SamplingParams
    sampling = SamplingParams(
        temperature=0.0,
        max_tokens=4096,
        repetition_penalty=1.2,
        stop=STOP_STRINGS,
    )
    emitted = ""
    async for output in engine.generate(prompt, sampling, request_id=request_id):
        text = output.outputs[0].text
        leak = STOP_RE.search(text)
        if leak is not None:
            # A stop pattern slipped through: emit everything before it, stop.
            tail = text[len(emitted):leak.start()]
            if tail:
                yield tail
            return
        delta = text[len(emitted):]
        emitted = text
        if delta:
            yield delta
async def run_parallel_async(prompts: list[str], mode: list[str]):
    """
    Run multiple prompts concurrently on the shared vLLM engine.

    One consumer task is spawned per non-empty prompt; each forwards its
    streamed chunks into a single queue, and this generator drains that
    queue, yielding the (shared, mutable) list of accumulated texts after
    every chunk. Empty slots keep their "" entry.

    FIX: the previous implementation polled each stream with
    ``asyncio.wait_for(gen.__anext__(), timeout=0.05)``. A timeout there
    *cancels* the pending ``__anext__``, which throws CancelledError into
    the async generator — corrupting it and potentially dropping tokens —
    and the 0.05 s round-robin also busy-waited. The queue-based fan-in
    avoids both problems.
    """
    engine, _ = get_engine()
    results = [""] * MAX_PARALLEL_REQUESTS
    chunk_queue: asyncio.Queue = asyncio.Queue()
    _DONE = object()  # per-slot end-of-stream sentinel

    async def _consume(slot_id: int, full_prompt: str, request_id: str) -> None:
        # Forward every streamed chunk; always signal completion, even if
        # the stream raises, so the drain loop below can terminate.
        try:
            async for chunk in astream_generate(engine, full_prompt, request_id):
                await chunk_queue.put((slot_id, chunk))
        finally:
            await chunk_queue.put((slot_id, _DONE))

    tasks = []
    for slot_id, prompt in enumerate(prompts):
        if prompt and prompt.strip():
            full_prompt = build_prompt(prompt.strip(), mode)
            request_id = f"req_{slot_id}_{uuid.uuid4().hex[:8]}"
            tasks.append(asyncio.ensure_future(_consume(slot_id, full_prompt, request_id)))

    if not tasks:
        yield results
        return

    pending = len(tasks)
    try:
        while pending:
            slot_id, item = await chunk_queue.get()
            if item is _DONE:
                pending -= 1
            else:
                results[slot_id] += item
            yield results
    finally:
        # If our consumer stops early (e.g. the UI disconnected), do not
        # leave orphaned streaming tasks running on the engine loop.
        for task in tasks:
            if not task.done():
                task.cancel()
def respond_parallel(
    prompt0: str, prompt1: str, prompt2: str, prompt3: str,
    prompt4: str, prompt5: str, prompt6: str, prompt7: str,
    mode: list[str]
) -> Generator:
    """
    Process up to 8 prompts in parallel using vLLM's continuous batching.

    Bridges the async stream onto the engine's dedicated event loop so that
    Gradio, which drives this as a plain synchronous generator, receives a
    tuple of 8 accumulated response texts on every update.

    Yields:
        tuple[str, ...]: current accumulated text for each of the 8 slots.
    """
    prompts = [prompt0, prompt1, prompt2, prompt3, prompt4, prompt5, prompt6, prompt7]
    _, loop = get_engine()

    async def run():
        async for results in run_parallel_async(prompts, mode):
            yield tuple(results)

    agen = run()
    try:
        while True:
            yield loop.run_until_complete(agen.__anext__())
    except StopAsyncIteration:
        pass
    finally:
        # FIX: close the async generator even when Gradio cancels this
        # stream mid-flight; previously the generator (and its in-flight
        # engine work) was simply abandoned. aclose() on an already
        # exhausted generator is a no-op.
        loop.run_until_complete(agen.aclose())
def respond_single(slot_id: int, prompt: str, mode: list[str]) -> Generator:
    """Stream the response for a single prompt slot.

    Yields the accumulated response text after each streamed chunk. Blank
    (or None) input yields a single empty string without touching the
    engine at all.
    """
    if not prompt or not prompt.strip():
        yield ""
        return
    engine, loop = get_engine()
    full_prompt = build_prompt(prompt.strip(), mode)
    request_id = f"single_{slot_id}_{uuid.uuid4().hex[:8]}"

    async def run():
        accumulated = ""
        async for chunk in astream_generate(engine, full_prompt, request_id):
            accumulated += chunk
            yield accumulated

    agen = run()
    try:
        while True:
            yield loop.run_until_complete(agen.__anext__())
    except StopAsyncIteration:
        pass
    finally:
        # FIX: close the async generator even when Gradio cancels this
        # stream, so the underlying engine request is not abandoned.
        loop.run_until_complete(agen.aclose())
# ---------------------------------------------------------------------------
# Gradio interface: eight prompt slots, shared mode checkboxes, and both
# parallel and per-slot execution wired to the vLLM-backed handlers above.
# ---------------------------------------------------------------------------
with gr.Blocks(title="JPharmatron Parallel Chat") as demo:
    gr.Markdown("# 💊 JPharmatron - Parallel Request Processing")
    gr.Markdown(
        f"Enter up to {MAX_PARALLEL_REQUESTS} prompts and process them simultaneously. "
        f"Using {NUM_GPUS} GPU(s) with vLLM continuous batching."
    )

    # Shared mode checkboxes; the selected labels are consumed by build_prompt().
    mode = gr.CheckboxGroup(
        label="モード (Mode)",
        choices=["製薬の専門家", "国際基準に準拠", "具体的な手順"],
        value=[],
    )

    gr.Markdown("### 🔧 Presets (click to copy)")
    preset_list = [
        "グレープフルーツと薬を一緒に飲んじゃだめなんですか?",
        "新薬の臨床試験(Phase I〜III)の概要を、具体例つきで簡単に教えて。",
        "ジェネリック医薬品が承認されるまでの流れを、タイムラインで解説して。",
        "抗生物質の作用機序と耐性菌について説明してください。",
        "COVID-19ワクチンの開発プロセスを教えてください。",
        "薬物相互作用の主なメカニズムを教えてください。",
        "バイオシミラーと先発医薬品の違いは何ですか?",
        "製薬企業のGMP(Good Manufacturing Practice)について説明してください。",
    ]

    # Input textboxes: two columns of four (slots 0-3 left, 4-7 right).
    gr.Markdown("### 📝 Input Prompts")
    input_boxes = []
    with gr.Row():
        for column_slots in (range(4), range(4, 8)):
            with gr.Column():
                for slot in column_slots:
                    input_boxes.append(
                        gr.Textbox(
                            label=f"Prompt {slot + 1}",
                            placeholder=f"Enter prompt {slot + 1}...",
                            lines=3,
                        )
                    )

    # Clicking an example row fills the first four prompt boxes.
    gr.Examples(
        examples=[preset_list[:4], preset_list[4:]],
        inputs=input_boxes[:4],
        label="Fill first 4 prompts with presets",
    )

    with gr.Row():
        run_all_btn = gr.Button("🚀 Run All in Parallel", variant="primary", scale=2)
        clear_inputs_btn = gr.Button("🗑️ Clear Inputs", scale=1)
        clear_outputs_btn = gr.Button("🗑️ Clear Outputs", scale=1)

    # Output textboxes mirror the input layout.
    gr.Markdown("### 📤 Streaming Outputs")
    output_boxes = []
    with gr.Row():
        for column_slots in (range(4), range(4, 8)):
            with gr.Column():
                for slot in column_slots:
                    output_boxes.append(
                        gr.Textbox(
                            label=f"Response {slot + 1}",
                            lines=10,
                            interactive=False,
                            show_copy_button=True,
                        )
                    )

    # Fan all eight inputs (plus the mode selection) into the parallel handler.
    run_all_btn.click(
        fn=respond_parallel,
        inputs=input_boxes + [mode],
        outputs=output_boxes,
    )
    clear_inputs_btn.click(fn=lambda: ("",) * 8, inputs=None, outputs=input_boxes)
    clear_outputs_btn.click(fn=lambda: ("",) * 8, inputs=None, outputs=output_boxes)

    gr.Markdown("### 🎯 Run Individual Prompts")

    def _single_handler(slot_id):
        # Factory binds slot_id at definition time; a bare lambda inside the
        # loop below would capture the loop variable late (all slots == 7).
        def handler(prompt, mode):
            yield from respond_single(slot_id, prompt, mode)
        return handler

    with gr.Row():
        for slot in range(8):
            run_one_btn = gr.Button(f"Run #{slot + 1}", size="sm")
            run_one_btn.click(
                fn=_single_handler(slot),
                inputs=[input_boxes[slot], mode],
                outputs=[output_boxes[slot]],
            )
def main():
    """Entry point: enable Gradio's request queue, then start the server."""
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()