import os
import time
from typing import List

import torch
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()
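# load_dotenv() pulls optional overrides from a local .env file:
#   VLLM_API_BASE (default: http://localhost:8000/v1)
#   VLLM_API_KEY  (default: "EMPTY")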

def benchmark_vllm_throughput(model_name: str, prompts: List[str]):
    """
    Benchmarks token throughput using the vLLM OpenAI-compatible API.
    Designed for AMD Instinct MI300X high-bandwidth performance.
    """
    api_base = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
    api_key = os.getenv("VLLM_API_KEY", "EMPTY")
    
    client = OpenAI(api_key=api_key, base_url=api_base)
    
    print(f"\n--- Benchmarking vLLM on MI300X (Model: {model_name}) ---")
    
    total_tokens = 0
    start_time = time.time()
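    # Wall-clock timing includes HTTP round-trips and any server-side queuing,
    # so the reported tokens/s is end-to-end throughput, not pure decode speed.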
    
    for i, prompt in enumerate(prompts):
        p_start = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=256
        )
        p_end = time.time()
        
        tokens = response.usage.completion_tokens
        total_tokens += tokens
        print(f"  [Request {i+1}] {tokens} tokens in {p_end - p_start:.2f}s ({tokens/(p_end-p_start):.2f} tokens/s)")
        
    end_time = time.time()
    total_duration = end_time - start_time
    
    print(f"\n[SUMMARY]")
    print(f"  Total Duration: {total_duration:.2f}s")
    print(f"  Total Tokens: {total_tokens}")
    print(f"  Aggregate Throughput: {total_tokens/total_duration:.2f} tokens/s")
    print(f"  Hardware Target: AMD Instinct™ MI300X (ROCm 7.2)")

def validate_torch_rocm():
    """
    Validates ROCm device detection and HBM3 memory accessibility.
    """
    print("\n--- Validating ROCm Environment ---")
    if not torch.cuda.is_available():
        print("  [ERROR] ROCm/HIP not detected by PyTorch.")
        return
    
    device_name = torch.cuda.get_device_name(0)
    device_count = torch.cuda.device_count()
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    
    print(f"  Device: {device_name}")
    print(f"  GPU Count: {device_count}")
    print(f"  Total HBM3 Memory: {total_mem / (1024**3):.2f} GB")
    print(f"  Free HBM3 Memory: {free_mem / (1024**3):.2f} GB")
    
    # Simple matrix multiply on the GPU to verify compute end to end
    x = torch.randn(1000, 1000, device="cuda")
    y = torch.randn(1000, 1000, device="cuda")
    z = torch.matmul(x, y)
    torch.cuda.synchronize()
    print(f"  Compute Check (Matrix Mult, {tuple(z.shape)} result): SUCCESS")

if __name__ == "__main__":
    validate_torch_rocm()
    
    # Sample clinical prompts for throughput test
    test_prompts = [
        "Explain the standard of care for Metastatic Non-Small Cell Lung Cancer with EGFR T790M mutation.",
        "Summarize the NCCN guidelines for Triple Negative Breast Cancer Stage III.",
        "What are the second-line treatment options for BRAF V600E positive Melanoma?",
        "Compare the efficacy of Pembrolizumab vs Nivolumab in advanced RCC.",
        "List the common adverse effects of Sotorasib in KRAS G12C mutated NSCLC."
    ]
    
    try:
        benchmark_vllm_throughput("meta-llama/Meta-Llama-3.1-8B-Instruct", test_prompts)
    except Exception as e:
        print(f"\n[WARNING] vLLM benchmark skipped: {e}")
        print("Ensure the vLLM server is running on port 8000.")