File size: 6,180 Bytes
56b60d4 1a6672d 28263c0 56b60d4 28263c0 1a6672d 984e3c2 56b60d4 1a6672d 984e3c2 56b60d4 984e3c2 56b60d4 1a6672d 984e3c2 1a6672d 984e3c2 56b60d4 1a6672d 984e3c2 1a6672d 984e3c2 56b60d4 1a6672d 984e3c2 56b60d4 1a6672d 27c4e2c 56b60d4 984e3c2 1a6672d 984e3c2 56b60d4 1a6672d 984e3c2 1a6672d 56b60d4 1a6672d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | from openai import OpenAI
from groq import Groq
from typing import Optional, Dict, Any
import os
from dotenv import load_dotenv
# pylint: disable=broad-exception-caught
# Load environment variables (e.g. from a local .env file) at import time so
# that the os.getenv() lookups performed by LLMClient can see them.
load_dotenv()
class LLMClient:
    """
    Unified LLM client supporting three providers:

    1. Groq (default, local dev) — GROQ_API_KEY
    2. vLLM on AMD Cloud (production) — USE_VLLM=true + VLLM_* vars
    3. Qwen via HuggingFace Inference — USE_QWEN=true + QWEN_API_KEY
       Model: Qwen/Qwen2.5-Coder-32B-Instruct (purpose-built for code tasks)
       Qualifies for the AMD hackathon Qwen bonus prize.

    If no provider can be initialised, ``client`` stays ``None`` and the
    instance runs in "mock" mode: ``chat_completion`` returns a canned JSON
    payload instead of contacting a service.

    Attributes:
        use_vllm: True when the USE_VLLM environment variable equals "true".
        use_qwen: True when the USE_QWEN environment variable equals "true".
        client: Provider SDK client, or None in mock mode.
        model: Model identifier sent with each request ("mock" in mock mode).
        provider: Human-readable provider label ("mock" in mock mode).
        init_error: Message of the most recent client-construction failure,
            or None if no construction attempt has failed.
        vllm_base_url / vllm_api_key: vLLM settings; None unless the vLLM
            initializer ran.
        groq_api_key: Groq API key; None unless the Groq initializer ran and
            found one.
    """

    def __init__(self):
        """Read provider flags from the environment and set up one provider."""
        self.use_vllm = os.getenv("USE_VLLM", "false").lower() == "true"
        self.use_qwen = os.getenv("USE_QWEN", "false").lower() == "true"

        # Mock-mode defaults; a successful initializer overwrites them.
        self.client = None
        self.model = "mock"
        self.provider = "mock"
        self.init_error: Optional[str] = None

        # Declared up front so every instance exposes the same attributes,
        # whichever initializer ends up running.
        self.vllm_base_url: Optional[str] = None
        self.vllm_api_key: Optional[str] = None
        self.groq_api_key: Optional[str] = None

        # vLLM takes precedence over Qwen; Groq is the default.
        if self.use_vllm:
            self._init_vllm()
        elif self.use_qwen:
            self._init_qwen()
        else:
            self._init_groq()

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------
    def _init_vllm(self) -> None:
        """Connect to vLLM endpoint on AMD Developer Cloud.

        Reads VLLM_BASE_URL, VLLM_API_KEY and VLLM_MODEL. If the client cannot
        be constructed, the error is recorded in ``init_error`` and the
        instance stays in mock mode.
        """
        # The OpenAI client appends "/chat/completions" to base_url, so the
        # default has to carry the "/v1" prefix of the OpenAI-compatible API.
        self.vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
        self.vllm_api_key = os.getenv("VLLM_API_KEY", "dummy-key")
        try:
            self.client = OpenAI(
                base_url=self.vllm_base_url,
                api_key=self.vllm_api_key,
            )
            self.model = os.getenv("VLLM_MODEL", "amd/llama-3.3-70b")
            self.provider = "vLLM (AMD Cloud)"
        except Exception as e:
            self.init_error = f"vLLM client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to mock mode.")

    def _init_qwen(self) -> None:
        """
        Connect to Qwen/Qwen2.5-Coder-32B-Instruct via HuggingFace Inference API.

        Qwen2.5-Coder-32B-Instruct is purpose-built for code tasks and is directly
        relevant to CUDA-to-HIP translation. Free tier on HuggingFace — no billing.
        Set USE_QWEN=true and QWEN_API_KEY=hf_... in .env to activate.

        Falls back to the Groq initializer when QWEN_API_KEY is missing or the
        client cannot be constructed.
        """
        qwen_api_key = os.getenv("QWEN_API_KEY")
        if not qwen_api_key:
            print("Warning: QWEN_API_KEY not found. Falling back to Groq.")
            self._init_groq()
            return

        try:
            # HuggingFace Inference API exposes an OpenAI-compatible endpoint
            hf_base_url = os.getenv(
                "QWEN_BASE_URL",
                "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B-Instruct/v1",
            )
            self.client = OpenAI(
                base_url=hf_base_url,
                api_key=qwen_api_key,
            )
            self.model = os.getenv("QWEN_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
            self.provider = "Qwen (HuggingFace)"
        except Exception as e:
            self.init_error = f"Qwen client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to Groq.")
            self._init_groq()

    def _init_groq(self) -> None:
        """Connect to Groq (LLaMA-3.3-70B). Default provider for local development.

        Without GROQ_API_KEY, or if the client cannot be constructed, the
        instance is left in mock mode.
        """
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            print("Warning: GROQ_API_KEY not found. Using mock mode.")
            self.provider = "mock"
            return

        try:
            self.client = Groq(api_key=self.groq_api_key)
            self.model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
            self.provider = "Groq (LLaMA-3.3-70B)"
        except Exception as e:
            self.init_error = f"Groq client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to mock mode.")
            self.provider = "mock"

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def chat_completion(self, messages: list, temperature: float = 0.7, max_tokens: int = 4000) -> str:
        """Send chat completion request to the configured LLM.

        Args:
            messages: Chat messages passed through unchanged to the provider's
                ``chat.completions.create`` call.
            temperature: Sampling temperature forwarded to the provider.
            max_tokens: Completion token limit forwarded to the provider.

        Returns:
            The text of the first choice. An empty string is returned when the
            provider reports no text content. In mock mode (no client) a fixed
            JSON analysis payload is returned instead.

        Raises:
            RuntimeError: If the request fails. The message starts with
                "LLM request rate-limited" when the underlying error mentions
                a rate limit, HTTP 429 or quota, and with "LLM request failed"
                otherwise. The original exception is chained as the cause.
        """
        if self.client is None:
            # Mock response when no API key is available
            return (
                '{"kernels_found": ["mock_kernel"], "cuda_apis": ["cudaMalloc"], '
                '"warp_size_issue": true, "workload_type": "memory-bound", '
                '"sharding_detected": false, "difficulty": "Medium"}'
            )

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            content = response.choices[0].message.content
            # The SDK may report None as content; honour the declared
            # `str` return type instead of leaking None to callers.
            return content if content is not None else ""
        except Exception as e:
            message = str(e)
            lowered = message.lower()
            if "rate limit" in lowered or "429" in lowered or "quota" in lowered:
                raise RuntimeError(f"LLM request rate-limited: {message}") from e
            raise RuntimeError(f"LLM request failed: {message}") from e

    # ------------------------------------------------------------------
    # Utility / introspection
    # ------------------------------------------------------------------
    def get_model_info(self) -> Dict[str, Any]:
        """Return current provider configuration for the /health and /benchmark-report endpoints."""
        return {
            "provider": self.provider,
            "model": self.model,
        }

    def test_connection(self) -> bool:
        """Test if the LLM connection is working.

        Returns:
            True if a short probe request succeeds and the reply contains "OK"
            (case-insensitive). False in mock mode or on any failure.
        """
        if self.client is None:
            # Mock mode has no connection to test; do not rely on the canned
            # mock payload happening not to contain "OK".
            return False

        try:
            test_messages = [
                {"role": "user", "content": "Respond with 'OK' if you can read this."}
            ]
            response = self.chat_completion(test_messages, max_tokens=10)
            return "OK" in response.upper()
        except Exception:
            return False
|