File size: 6,094 Bytes
6252f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
LLM client with vLLM primary endpoint (AMD MI300X) and public API fallback.
Uses OpenAI-compatible API for both endpoints.
"""

import json
import logging
import time
import re
from typing import AsyncGenerator
from openai import AsyncOpenAI, APIConnectionError, APITimeoutError

from backend.config import Settings

# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)


class LLMClient:
    """
    Wraps the OpenAI-compatible API.
    Primary: vLLM on AMD MI300X (api_key="EMPTY").
    Fallback: Together.ai or any public Qwen API.
    """

    def __init__(self, settings: Settings):
        self._model = settings.vllm_model
        self._fallback_model = settings.fallback_model
        self._max_tokens = settings.llm_max_tokens
        self._temperature = settings.llm_temperature
        self._total_tokens = 0
        self._total_time = 0.0

        self._primary = AsyncOpenAI(
            base_url=settings.vllm_base_url,
            api_key="EMPTY",
            timeout=settings.llm_timeout,
            max_retries=1,
        )

        self._fallback: AsyncOpenAI | None = None
        if settings.fallback_api_key:
            self._fallback = AsyncOpenAI(
                base_url=settings.fallback_base_url,
                api_key=settings.fallback_api_key,
                timeout=settings.llm_timeout,
                max_retries=2,
            )
        else:
            log.warning("No FALLBACK_API_KEY set — LLM will fail if vLLM endpoint is unreachable")

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int | None = None,
        temperature: float | None = None,
        system: str | None = None,
    ) -> str:
        """Send a chat request. Returns assistant message content string."""
        if system:
            messages = [{"role": "system", "content": system}] + list(messages)

        mt = max_tokens or self._max_tokens
        temp = temperature or self._temperature

        t0 = time.time()
        try:
            resp = await self._primary.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=mt,
                temperature=temp,
            )
            content = resp.choices[0].message.content or ""
            elapsed = time.time() - t0
            tokens = resp.usage.completion_tokens if resp.usage else 0
            self._total_tokens += tokens
            self._total_time += elapsed
            log.info(f"vLLM: {tokens} tokens in {elapsed:.1f}s ({tokens/elapsed:.0f} tok/s)")
            return content

        except (APIConnectionError, APITimeoutError, Exception) as primary_err:
            log.warning(f"Primary vLLM endpoint failed ({primary_err}), trying fallback...")
            if not self._fallback:
                raise RuntimeError("vLLM endpoint unreachable and no fallback API key configured") from primary_err
            try:
                resp = await self._fallback.chat.completions.create(
                    model=self._fallback_model,
                    messages=messages,
                    max_tokens=mt,
                    temperature=temp,
                )
                content = resp.choices[0].message.content or ""
                elapsed = time.time() - t0
                tokens = resp.usage.completion_tokens if resp.usage else 0
                self._total_tokens += tokens
                self._total_time += elapsed
                log.info(f"Fallback API: {tokens} tokens in {elapsed:.1f}s")
                return content
            except Exception as fallback_err:
                raise RuntimeError(f"Both LLM endpoints failed. Primary: {primary_err}. Fallback: {fallback_err}")

    async def chat_stream(
        self,
        messages: list[dict],
        max_tokens: int | None = None,
        system: str | None = None,
    ) -> AsyncGenerator[str, None]:
        """Stream chat completions, yielding token chunks."""
        if system:
            messages = [{"role": "system", "content": system}] + list(messages)

        try:
            stream = await self._primary.chat.completions.create(
                model=self._model,
                messages=messages,
                max_tokens=max_tokens or self._max_tokens,
                temperature=self._temperature,
                stream=True,
            )
            async for chunk in stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    yield delta
        except Exception:
            if self._fallback:
                stream = await self._fallback.chat.completions.create(
                    model=self._fallback_model,
                    messages=messages,
                    max_tokens=max_tokens or self._max_tokens,
                    stream=True,
                )
                async for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta.content
                    if delta:
                        yield delta

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @property
    def avg_tokens_per_second(self) -> float:
        if self._total_time > 0:
            return self._total_tokens / self._total_time
        return 0.0


def extract_json(raw: str) -> dict | list:
    """Extract JSON from LLM response, handling markdown fences."""
    # Strip markdown fences
    cleaned = re.sub(r"```(?:json)?\s*", "", raw).strip().rstrip("```").strip()

    # Try direct parse
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # Try to find JSON object/array within the text
    for pattern in [r'\{.*\}', r'\[.*\]']:
        match = re.search(pattern, cleaned, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                continue

    raise ValueError(f"Could not extract valid JSON from LLM response: {raw[:200]}")