# ROCmPort-AI — backend/tools/llm_client.py
# Last commit: 27c4e2c "fixing bugs23" (author: tazwarrrr)
from openai import OpenAI
from groq import Groq
from typing import Optional, Dict, Any
import os
from dotenv import load_dotenv
# pylint: disable=broad-exception-caught
# Populate os.environ from a local .env file before any LLMClient reads its
# configuration; variables already present in the environment are kept.
load_dotenv(override=False)
class LLMClient:
    """
    Unified LLM client supporting three providers:
    1. Groq (default, local dev) — GROQ_API_KEY
    2. vLLM on AMD Cloud (production) — USE_VLLM=true + VLLM_* vars
    3. Qwen via HuggingFace Inference — USE_QWEN=true + QWEN_API_KEY

    The provider is chosen once, in ``__init__``. vLLM takes precedence over
    Qwen, and Qwen takes precedence over Groq. If the chosen provider cannot
    be initialised, the client falls back: Qwen falls back to Groq, and any
    failure that leaves no client falls back to mock mode. In mock mode
    ``self.client`` is ``None`` and ``chat_completion`` returns a canned
    JSON string instead of calling an API.

    Model: Qwen/Qwen2.5-Coder-32B-Instruct (purpose-built for code tasks)
    Qualifies for the AMD hackathon Qwen bonus prize.

    Attributes:
        client: OpenAI-compatible client, Groq client, or ``None`` (mock mode).
        model: Model name passed to ``chat.completions.create``.
        provider: Human-readable provider label ("mock" when no client).
        init_error: Message from the last failed provider initialisation, or
            ``None``. It is not cleared when a fallback provider succeeds.
    """

    # Values accepted as "enabled" for the USE_* flags (compared lower-cased).
    _TRUTHY = frozenset({"true", "1", "yes", "on"})

    def __init__(self):
        self.use_vllm = self._env_flag("USE_VLLM")
        self.use_qwen = self._env_flag("USE_QWEN")
        self.client = None
        self.model = "mock"
        self.provider = "mock"
        self.init_error: Optional[str] = None
        if self.use_vllm:
            self._init_vllm()
        elif self.use_qwen:
            self._init_qwen()
        else:
            self._init_groq()

    @classmethod
    def _env_flag(cls, name: str) -> bool:
        """Return True when env var *name* is "true", "1", "yes" or "on" (any case)."""
        return os.getenv(name, "false").strip().lower() in cls._TRUTHY

    # ------------------------------------------------------------------
    # Provider initializers
    # ------------------------------------------------------------------
    def _init_vllm(self) -> None:
        """Connect to the vLLM OpenAI-compatible endpoint on AMD Developer Cloud.

        Reads VLLM_BASE_URL, VLLM_API_KEY and VLLM_MODEL. If client creation
        raises, the client stays ``None`` (mock mode) and ``init_error`` is set.
        """
        # vLLM serves its OpenAI-compatible API under /v1, and the OpenAI SDK
        # appends "/chat/completions" to base_url, so the default must include /v1.
        self.vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
        self.vllm_api_key = os.getenv("VLLM_API_KEY", "dummy-key")
        try:
            self.client = OpenAI(
                base_url=self.vllm_base_url,
                api_key=self.vllm_api_key
            )
            self.model = os.getenv("VLLM_MODEL", "amd/llama-3.3-70b")
            self.provider = "vLLM (AMD Cloud)"
        except Exception as e:
            self.init_error = f"vLLM client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to mock mode.")

    def _init_qwen(self) -> None:
        """
        Connect to Qwen/Qwen2.5-Coder-32B-Instruct via HuggingFace Inference API.

        Qwen2.5-Coder-32B-Instruct is purpose-built for code tasks and is directly
        relevant to CUDA-to-HIP translation. Free tier on HuggingFace — no billing.
        Set USE_QWEN=true and QWEN_API_KEY=hf_... in .env to activate.

        Falls back to Groq when QWEN_API_KEY is missing or client creation raises.
        QWEN_BASE_URL and QWEN_MODEL override the endpoint and model name.
        """
        qwen_api_key = os.getenv("QWEN_API_KEY")
        if not qwen_api_key:
            print("Warning: QWEN_API_KEY not found. Falling back to Groq.")
            self._init_groq()
            return
        try:
            # HuggingFace Inference API exposes an OpenAI-compatible endpoint
            hf_base_url = os.getenv(
                "QWEN_BASE_URL",
                "https://api-inference.huggingface.co/models/Qwen/Qwen2.5-Coder-32B-Instruct/v1"
            )
            self.client = OpenAI(
                base_url=hf_base_url,
                api_key=qwen_api_key,
            )
            self.model = os.getenv("QWEN_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct")
            self.provider = "Qwen (HuggingFace)"
        except Exception as e:
            self.init_error = f"Qwen client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to Groq.")
            self._init_groq()

    def _init_groq(self) -> None:
        """Connect to Groq (LLaMA-3.3-70B). Default provider for local development.

        Without GROQ_API_KEY, or if client creation raises, the client is left
        unchanged (``None`` on every path that reaches here) and the provider
        is "mock".
        """
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        if not self.groq_api_key:
            print("Warning: GROQ_API_KEY not found. Using mock mode.")
            self.provider = "mock"
            return
        try:
            self.client = Groq(api_key=self.groq_api_key)
            self.model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
            self.provider = "Groq (LLaMA-3.3-70B)"
        except Exception as e:
            self.init_error = f"Groq client init failed: {str(e)}"
            print(f"Warning: {self.init_error}. Falling back to mock mode.")
            self.provider = "mock"

    # ------------------------------------------------------------------
    # Core interface
    # ------------------------------------------------------------------
    def chat_completion(self, messages: list, temperature: float = 0.7, max_tokens: int = 4000) -> str:
        """Send a chat completion request to the configured LLM.

        Args:
            messages: Chat messages, e.g. ``[{"role": "user", "content": "..."}]``.
            temperature: Sampling temperature forwarded to the provider.
            max_tokens: Maximum number of generated tokens, forwarded to the provider.

        Returns:
            The text of the first choice, or "" if the provider returned no
            text content. In mock mode, a fixed JSON string.

        Raises:
            RuntimeError: If the request (or reading its first choice) fails.
                The message starts with "LLM request rate-limited:" for
                rate-limit or quota errors and "LLM request failed:" otherwise.
                The original exception is chained as ``__cause__``.
        """
        if self.client is None:
            # Mock response when no API key is available (or provider init failed)
            return (
                '{"kernels_found": ["mock_kernel"], "cuda_apis": ["cudaMalloc"], '
                '"warp_size_issue": true, "workload_type": "memory-bound", '
                '"sharding_detected": false, "difficulty": "Medium"}'
            )
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            # The SDKs type message.content as Optional[str] (e.g. refusals or
            # tool-call-only replies); keep this method's str contract.
            return response.choices[0].message.content or ""
        except Exception as e:
            message = str(e)
            if self._is_rate_limit(e, message):
                raise RuntimeError(f"LLM request rate-limited: {message}") from e
            raise RuntimeError(f"LLM request failed: {message}") from e

    @staticmethod
    def _is_rate_limit(exc: Exception, message: str) -> bool:
        """Return True if *exc* looks like a rate-limit or quota error."""
        # openai/groq APIStatusError exposes the HTTP status; other exceptions lack it.
        if getattr(exc, "status_code", None) == 429:
            return True
        lowered = message.lower()
        markers = ("rate limit", "rate_limit", "too many requests", "429", "quota")
        return any(marker in lowered for marker in markers)

    # ------------------------------------------------------------------
    # Utility / introspection
    # ------------------------------------------------------------------
    def get_model_info(self) -> Dict[str, Any]:
        """Return current provider configuration for the /health and /benchmark-report endpoints."""
        return {
            "provider": self.provider,
            "model": self.model,
        }

    def test_connection(self) -> bool:
        """Return True if the configured LLM answers a trivial prompt with "OK".

        Returns False in mock mode (no request is sent) and on any request error.
        """
        if self.client is None:
            return False
        test_messages = [
            {"role": "user", "content": "Respond with 'OK' if you can read this."}
        ]
        try:
            response = self.chat_completion(test_messages, max_tokens=10)
        except Exception:
            return False
        return "OK" in response.upper()