File size: 5,164 Bytes
cd3b358
 
342230a
cd3b358
 
 
342230a
cd3b358
 
 
 
 
 
342230a
 
 
cd3b358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342230a
 
 
 
 
 
cd3b358
 
 
 
 
 
342230a
cd3b358
 
 
 
 
 
 
 
 
342230a
 
 
 
 
 
 
 
 
 
 
cd3b358
342230a
 
 
 
cd3b358
342230a
 
 
 
 
 
 
 
 
 
 
 
 
cd3b358
 
 
 
 
 
 
342230a
cd3b358
 
 
 
 
 
 
 
342230a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd3b358
 
 
 
 
342230a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
LLM Client for GLM 5.1 via Z.ai API (OpenAI-compatible endpoint).
Includes automatic retry with exponential backoff for rate-limit (429) errors.
"""
import time
import logging
import random
from typing import Iterator, List, Dict, Any, Optional
import openai
from backend.config import GLM_API_KEY, GLM_BASE_URL, GLM_MODEL, LOG_LLM_CALLS

# Module-level logger; GLMClient prefixes its log lines with "[GLM]".
logger = logging.getLogger(__name__)

# Total number of attempts GLMClient makes for a single request when the API
# keeps answering with a rate-limit (429) error before it gives up.
MAX_RETRIES = 5
# Base delay for the exponential backoff between attempts: the wait before the
# next try is INITIAL_BACKOFF * 2**attempt, plus up to 2 seconds of jitter.
INITIAL_BACKOFF = 5  # seconds


class GLMClient:
    """OpenAI-compatible wrapper for Z.ai's GLM models.

    The underlying ``openai.OpenAI`` client is built lazily on first use, so
    the API key may be supplied or replaced (see ``update_api_key``) after the
    wrapper has been constructed.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = GLM_BASE_URL,
        model: str = GLM_MODEL,
    ):
        """Store connection settings without contacting the API.

        Args:
            api_key: API key to use; falls back to ``GLM_API_KEY`` when falsy.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model: Model identifier sent with every request.
        """
        self.api_key = api_key or GLM_API_KEY
        self.base_url = base_url
        self.model = model
        # Created on demand by _get_client(); reset by update_api_key().
        self._client: Optional[openai.OpenAI] = None

    def _get_client(self) -> openai.OpenAI:
        """Return the cached SDK client, creating it on first use.

        Raises:
            ValueError: If no API key is configured.
        """
        if self._client is None:
            if not self.api_key:
                raise ValueError(
                    "GLM API key is not set. Please provide it in the sidebar or .env file."
                )
            self._client = openai.OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )
        return self._client

    def _backoff_wait(self, attempt: int) -> None:
        """Sleep with exponential backoff plus jitter, logging the wait time.

        Args:
            attempt: Zero-based index of the attempt that was just rate limited.
        """
        wait = INITIAL_BACKOFF * (2 ** attempt) + random.uniform(0, 2)
        logger.warning("[GLM] Rate limited (429). Retrying in %.1fs (attempt %d/%d)...", wait, attempt + 1, MAX_RETRIES)
        time.sleep(wait)

    def _create_with_retry(self, client: openai.OpenAI, label: str, **params: Any) -> Any:
        """Issue a chat-completion request, retrying on rate-limit errors.

        Only the request itself is retried. Consuming a streamed response is
        left to the caller so that a retry can never replay output that has
        already been handed to the consumer.

        Args:
            client: SDK client to send the request with.
            label: Short noun ("request" / "stream") used in the final error.
            **params: Extra keyword arguments for ``chat.completions.create``.

        Returns:
            Whatever ``client.chat.completions.create`` returns.

        Raises:
            RuntimeError: On a non-rate-limit API error, or when every attempt
                was rate limited.
        """
        last_error = None
        for attempt in range(MAX_RETRIES):
            try:
                return client.chat.completions.create(model=self.model, **params)
            # RateLimitError must be handled before its base class APIError.
            except openai.RateLimitError as e:
                last_error = e
                if attempt >= MAX_RETRIES - 1:
                    raise RuntimeError(
                        f"GLM API rate limit exceeded after {MAX_RETRIES} retries. "
                        f"Please wait a moment and try again. Detail: {e}"
                    ) from e
                self._backoff_wait(attempt)
            except openai.APIError as e:
                raise RuntimeError(f"GLM API error: {e}") from e

        # Only reachable when MAX_RETRIES is configured as zero or negative.
        raise RuntimeError(f"GLM {label} failed after {MAX_RETRIES} attempts: {last_error}")

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 4096,
    ) -> str:
        """Run a synchronous chat completion with automatic retry on 429.

        Args:
            messages: Chat messages passed through to the API unchanged.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.

        Returns:
            The first choice's message content, or ``""`` when it is empty.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: On API errors or when rate limiting persists.
        """
        client = self._get_client()
        # Monotonic clock: the elapsed time is immune to wall-clock changes.
        start = time.monotonic()

        if LOG_LLM_CALLS:
            logger.info(
                "[GLM] chat() | model=%s | messages=%d | temp=%.1f",
                self.model, len(messages), temperature,
            )

        response = self._create_with_retry(
            client,
            "request",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content = response.choices[0].message.content or ""
        elapsed = time.monotonic() - start

        if LOG_LLM_CALLS:
            logger.info("[GLM] completed in %.2fs | output_chars=%d", elapsed, len(content))

        return content

    def chat_stream(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 4096,
    ) -> Iterator[str]:
        """Stream a chat completion, retrying the initial request on 429.

        This is a generator, so nothing (including the API-key check) runs
        until the first item is requested.

        Args:
            messages: Chat messages passed through to the API unchanged.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.

        Yields:
            Each non-empty content fragment, in the order received.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: On API errors (including errors raised while the
                stream is being consumed) or when rate limiting persists.
        """
        client = self._get_client()

        if LOG_LLM_CALLS:
            logger.info(
                "[GLM] chat_stream() | model=%s | messages=%d",
                self.model, len(messages),
            )

        stream = self._create_with_retry(
            client,
            "stream",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )

        # Errors while consuming the stream are reported, never retried:
        # restarting here would re-yield text the consumer already received.
        try:
            for chunk in stream:
                # Some chunks (e.g. a trailing usage chunk) carry no choices.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except openai.APIError as e:
            raise RuntimeError(f"GLM API error: {e}") from e

    def update_api_key(self, api_key: str) -> None:
        """Allow hot-swapping the API key (e.g. from Streamlit sidebar)."""
        self.api_key = api_key
        self._client = None  # Force re-initialization with the new key