import time
import torch
import numpy as np
from typing import List, Dict
import os
from dotenv import load_dotenv
from openai import OpenAI

# Pull VLLM_API_BASE / VLLM_API_KEY (and any other settings) from a .env file.
load_dotenv()


def benchmark_vllm_throughput(model_name: str, prompts: List[str]) -> None:
    """
    Benchmarks token throughput using the vLLM OpenAI-compatible API.
    Designed for AMD Instinct MI300X high-bandwidth performance.

    Args:
        model_name: Model identifier as served by the vLLM endpoint.
        prompts: User prompts sent sequentially as single-turn chat requests.

    Raises:
        openai.APIConnectionError (and other OpenAI client errors) if the
        vLLM server is unreachable — the caller is expected to handle these.
    """
    api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
    # vLLM does not require auth by default; "EMPTY" is the conventional placeholder.
    api_key = os.getenv("VLLM_API_KEY", "EMPTY")
    client = OpenAI(api_key=api_key, base_url=api_base)

    print(f"\n--- Benchmarking vLLM on MI300X (Model: {model_name}) ---")
    total_tokens = 0
    start_time = time.time()
    for i, prompt in enumerate(prompts):
        p_start = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,  # deterministic decoding for repeatable benchmark runs
            max_tokens=256,
        )
        p_end = time.time()
        # Server-reported completion token count (prompt tokens excluded).
        tokens = response.usage.completion_tokens
        total_tokens += tokens
        print(f" [Request {i+1}] {tokens} tokens in {p_end - p_start:.2f}s ({tokens/(p_end-p_start):.2f} tokens/s)")

    end_time = time.time()
    total_duration = end_time - start_time
    print(f"\n[SUMMARY]")
    print(f" Total Duration: {total_duration:.2f}s")
    print(f" Total Tokens: {total_tokens}")
    print(f" Aggregate Throughput: {total_tokens/total_duration:.2f} tokens/s")
    # NOTE: repaired mojibake in the original label ("Instinctâ„¢" -> "Instinct™").
    print(f" Hardware Target: AMD Instinct™ MI300X (ROCm 7.2)")


def validate_torch_rocm() -> None:
    """
    Validates ROCm device detection and HBM3 memory accessibility.

    PyTorch's ROCm build exposes HIP devices through the torch.cuda API,
    so torch.cuda.* calls below target AMD GPUs. Prints device info and
    runs a small matmul to confirm the compute path works. Returns early
    (after printing an error) if no ROCm/HIP device is detected.
    """
    print("\n--- Validating ROCm Environment ---")
    if not torch.cuda.is_available():
        print(" [ERROR] ROCm/HIP not detected by PyTorch.")
        return

    device_name = torch.cuda.get_device_name(0)
    device_count = torch.cuda.device_count()
    # mem_get_info returns (free_bytes, total_bytes) for the given device.
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    print(f" Device: {device_name}")
    print(f" GPU Count: {device_count}")
    print(f" Total HBM3 Memory: {total_mem / (1024**3):.2f} GB")
    print(f" Free HBM3 Memory: {free_mem / (1024**3):.2f} GB")

    # Simple tensor operation to verify compute
    x = torch.randn(1000, 1000).to("cuda")
    y = torch.randn(1000, 1000).to("cuda")
    z = torch.matmul(x, y)
    # Kernels launch asynchronously; synchronize so failures surface here.
    torch.cuda.synchronize()
    print(" Compute Check (Matrix Mult): SUCCESS")


if __name__ == "__main__":
    validate_torch_rocm()

    # Sample clinical prompts for throughput test
    test_prompts = [
        "Explain the standard of care for Metastatic Non-Small Cell Lung Cancer with EGFR T790M mutation.",
        "Summarize the NCCN guidelines for Triple Negative Breast Cancer Stage III.",
        "What are the second-line treatment options for BRAF V600E positive Melanoma?",
        "Compare the efficacy of Pembrolizumab vs Nivolumab in advanced RCC.",
        "List the common adverse effects of Sotorasib in KRAS G12C mutated NSCLC.",
    ]
    try:
        benchmark_vllm_throughput("meta-llama/Meta-Llama-3.1-8B-Instruct", test_prompts)
    except Exception as e:
        # Benchmark is best-effort: report and continue rather than crash
        # when the vLLM server is not running.
        print(f"\n[WARNING] vLLM benchmark skipped: {e}")
        print("Ensure the vLLM server is running on port 8000.")