Spaces:
Sleeping
Sleeping
File size: 6,094 Bytes
"""
LLM client with a vLLM primary endpoint (AMD MI300X) and a public API fallback.
Uses the OpenAI-compatible API for both endpoints.
"""
import json
import logging
import time
import re
from typing import AsyncGenerator
from openai import AsyncOpenAI, APIConnectionError, APITimeoutError
from backend.config import Settings
log = logging.getLogger(__name__)
class LLMClient:
"""
Wraps the OpenAI-compatible API.
Primary: vLLM on AMD MI300X (api_key="EMPTY").
Fallback: Together.ai or any public Qwen API.
"""
def __init__(self, settings: Settings):
self._model = settings.vllm_model
self._fallback_model = settings.fallback_model
self._max_tokens = settings.llm_max_tokens
self._temperature = settings.llm_temperature
self._total_tokens = 0
self._total_time = 0.0
self._primary = AsyncOpenAI(
base_url=settings.vllm_base_url,
api_key="EMPTY",
timeout=settings.llm_timeout,
max_retries=1,
)
self._fallback: AsyncOpenAI | None = None
if settings.fallback_api_key:
self._fallback = AsyncOpenAI(
base_url=settings.fallback_base_url,
api_key=settings.fallback_api_key,
timeout=settings.llm_timeout,
max_retries=2,
)
else:
log.warning("No FALLBACK_API_KEY set — LLM will fail if vLLM endpoint is unreachable")
async def chat(
self,
messages: list[dict],
max_tokens: int | None = None,
temperature: float | None = None,
system: str | None = None,
) -> str:
"""Send a chat request. Returns assistant message content string."""
if system:
messages = [{"role": "system", "content": system}] + list(messages)
mt = max_tokens or self._max_tokens
temp = temperature or self._temperature
t0 = time.time()
try:
resp = await self._primary.chat.completions.create(
model=self._model,
messages=messages,
max_tokens=mt,
temperature=temp,
)
content = resp.choices[0].message.content or ""
elapsed = time.time() - t0
tokens = resp.usage.completion_tokens if resp.usage else 0
self._total_tokens += tokens
self._total_time += elapsed
log.info(f"vLLM: {tokens} tokens in {elapsed:.1f}s ({tokens/elapsed:.0f} tok/s)")
return content
except (APIConnectionError, APITimeoutError, Exception) as primary_err:
log.warning(f"Primary vLLM endpoint failed ({primary_err}), trying fallback...")
if not self._fallback:
raise RuntimeError("vLLM endpoint unreachable and no fallback API key configured") from primary_err
try:
resp = await self._fallback.chat.completions.create(
model=self._fallback_model,
messages=messages,
max_tokens=mt,
temperature=temp,
)
content = resp.choices[0].message.content or ""
elapsed = time.time() - t0
tokens = resp.usage.completion_tokens if resp.usage else 0
self._total_tokens += tokens
self._total_time += elapsed
log.info(f"Fallback API: {tokens} tokens in {elapsed:.1f}s")
return content
except Exception as fallback_err:
raise RuntimeError(f"Both LLM endpoints failed. Primary: {primary_err}. Fallback: {fallback_err}")
async def chat_stream(
self,
messages: list[dict],
max_tokens: int | None = None,
system: str | None = None,
) -> AsyncGenerator[str, None]:
"""Stream chat completions, yielding token chunks."""
if system:
messages = [{"role": "system", "content": system}] + list(messages)
try:
stream = await self._primary.chat.completions.create(
model=self._model,
messages=messages,
max_tokens=max_tokens or self._max_tokens,
temperature=self._temperature,
stream=True,
)
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta.content
if delta:
yield delta
except Exception:
if self._fallback:
stream = await self._fallback.chat.completions.create(
model=self._fallback_model,
messages=messages,
max_tokens=max_tokens or self._max_tokens,
stream=True,
)
async for chunk in stream:
if not chunk.choices:
continue
delta = chunk.choices[0].delta.content
if delta:
yield delta
@property
def total_tokens(self) -> int:
return self._total_tokens
@property
def avg_tokens_per_second(self) -> float:
if self._total_time > 0:
return self._total_tokens / self._total_time
return 0.0
def extract_json(raw: str) -> dict | list:
"""Extract JSON from LLM response, handling markdown fences."""
# Strip markdown fences
cleaned = re.sub(r"```(?:json)?\s*", "", raw).strip().rstrip("```").strip()
# Try direct parse
try:
return json.loads(cleaned)
except json.JSONDecodeError:
pass
# Try to find JSON object/array within the text
for pattern in [r'\{.*\}', r'\[.*\]']:
match = re.search(pattern, cleaned, re.DOTALL)
if match:
try:
return json.loads(match.group())
except json.JSONDecodeError:
continue
raise ValueError(f"Could not extract valid JSON from LLM response: {raw[:200]}")
|