SkyAlone / services /glm.py
FreshPixels's picture
Rename services glm.py to services/glm.py
fa43855 verified
Raw
History Blame Contribute Delete
14.5 kB
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional, AsyncGenerator
import httpx
import openai
from openai import AsyncOpenAI
from config import config
from database import db
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
class LLMServiceError(Exception):
    """Root of the exception hierarchy used for LLM service failures."""
class LLMTimeoutError(LLMServiceError):
    """Request timed out / could not be completed and the fallback is exhausted."""
class LLMRateLimitError(LLMServiceError):
    """The provider rejected the request because of rate limiting."""
class GLMService:
    """GLM chat service with retry, fallback model, streaming and metrics.

    Two ``AsyncOpenAI`` clients are created from ``config``: a primary one and,
    when ``config.FALLBACK_ENABLED`` is set, a fallback one. Both are built
    with the SDK's own retries disabled (``max_retries=0``) because
    :meth:`chat` runs its own retry loop with exponential back-off.
    """

    def __init__(self) -> None:
        """Build the API clients and read the model names from ``config``."""
        # Granular timeouts for NVIDIA NIM
        self.timeout = httpx.Timeout(
            connect=config.TIMEOUT_CONNECT,
            read=config.TIMEOUT_READ,
            write=config.TIMEOUT_WRITE,
            pool=config.TIMEOUT_POOL,
        )
        self.primary_client = AsyncOpenAI(
            base_url=config.NVIDIA_BASE_URL,
            api_key=config.NVIDIA_API_KEY,
            timeout=self.timeout,
            max_retries=0,  # retries are handled by chat() itself
        )
        self.fallback_client = AsyncOpenAI(
            base_url=config.NVIDIA_BASE_URL,
            api_key=config.NVIDIA_API_KEY,
            timeout=self.timeout,
            max_retries=0,
        ) if config.FALLBACK_ENABLED else None
        self.primary_model = config.PRIMARY_MODEL
        self.fallback_model = config.FALLBACK_MODEL

    # ═══════════════════════════════════════════════════════════════
    # BLOCK: Exponential Retry with Jitter
    # ═══════════════════════════════════════════════════════════════
    def _calculate_delay(self, attempt: int) -> float:
        """Return the back-off delay (seconds) for a zero-based ``attempt``.

        The base delay grows exponentially and is capped at
        ``config.RETRY_MAX_DELAY``; up to 30% random jitter is then added on
        top, so the returned value may exceed the cap by that margin.
        """
        import random

        delay = min(
            config.RETRY_BASE_DELAY * (config.RETRY_EXPONENTIAL_BASE ** attempt),
            config.RETRY_MAX_DELAY,
        )
        jitter = random.uniform(0, delay * 0.3)
        return delay + jitter

    def _is_retryable_error(self, error: Exception) -> bool:
        """Decide whether ``error`` is worth retrying.

        Timeouts, connection errors, rate limits and internal server errors
        are retryable, as is any other ``APIStatusError`` whose status code is
        429, 502, 503 or 504. Everything else is treated as permanent.
        """
        always_retryable = (
            openai.APITimeoutError,
            openai.APIConnectionError,
            openai.RateLimitError,
            openai.InternalServerError,
        )
        if isinstance(error, always_retryable):
            return True
        if isinstance(error, openai.APIStatusError):
            return getattr(error, "status_code", None) in (429, 502, 503, 504)
        return False

    # ═══════════════════════════════════════════════════════════════
    # BLOCK: Core Chat with Retry and Fallback
    # ═══════════════════════════════════════════════════════════════
    async def _record_metric(
        self,
        *,
        user_id: Optional[int],
        model: str,
        duration_ms: float,
        success: bool,
        error_type: Optional[str] = None,
    ) -> None:
        """Persist one request metric via ``db.save_metric`` (best effort).

        Does nothing when ``config.METRICS_ENABLED`` is false. A failure to
        store the metric is logged and swallowed so that it can neither
        discard a successful LLM response nor mask the real LLM error.
        """
        if not config.METRICS_ENABLED:
            return
        fields: Dict[str, Any] = {
            "user_id": user_id,
            "model": model,
            "duration_ms": duration_ms,
            "success": success,
        }
        # Only pass error_type for failures, as the success path never did.
        if error_type is not None:
            fields["error_type"] = error_type
        try:
            await db.save_metric(**fields)
        except Exception:
            logger.exception("Failed to save LLM metric for model=%s", model)

    async def _complete(
        self,
        client: "AsyncOpenAI",
        params: Dict[str, Any],
        use_stream: bool,
        user_id: Optional[int],
    ) -> str:
        """Run one completion on ``client`` and return the full response text."""
        if use_stream:
            return await self._stream_chat(client, params, user_id)
        response = await client.chat.completions.create(**params)
        return response.choices[0].message.content or ""

    async def chat(
        self,
        messages: List[Dict[str, str]],
        user_id: Optional[int] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False,
    ) -> str:
        """Send a chat request to the LLM with retry and fallback.

        The primary model is tried up to ``config.MAX_RETRIES`` times with
        exponential back-off between retryable failures. If every attempt
        fails and a fallback client is configured, the fallback model is tried
        once.

        Args:
            messages: Chat messages to send.
            user_id: Identifier stored with the metrics, if any.
            temperature: Sampling temperature; ``config.GLM_TEMPERATURE`` when None.
            max_tokens: Completion limit; ``config.GLM_MAX_TOKENS`` when None.
            stream: Request a streamed response. Only honoured when
                ``config.STREAMING_ENABLED`` is true; the full text is still
                returned in one piece.

        Returns:
            The full text of the response (empty string if the model
            returned no content).

        Raises:
            LLMServiceError: The primary model failed with a non-retryable error.
            LLMTimeoutError: All retries (and the fallback, if enabled) failed.
        """
        start_time = time.perf_counter()
        model = self.primary_model
        last_error: Optional[Exception] = None

        # Streaming must be decided once and reflected in the request itself:
        # sending stream=True and then reading `.choices` from the returned
        # stream object would fail.
        use_stream = stream and config.STREAMING_ENABLED

        params = self._build_params(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=use_stream,
        )

        # Retry loop for primary model
        for attempt in range(config.MAX_RETRIES):
            try:
                logger.info(
                    "→ LLM request: model=%s, attempt=%d/%d, messages=%d, stream=%s",
                    model, attempt + 1, config.MAX_RETRIES, len(messages), stream
                )
                response_text = await self._complete(
                    self.primary_client, params, use_stream, user_id
                )
            except Exception as e:
                last_error = e
                duration_ms = (time.perf_counter() - start_time) * 1000

                if not self._is_retryable_error(e):
                    logger.error("Non-retryable error: %s", e)
                    await self._record_metric(
                        user_id=user_id,
                        model=model,
                        duration_ms=duration_ms,
                        success=False,
                        error_type=type(e).__name__,
                    )
                    raise LLMServiceError(f"Non-retryable error: {e}") from e

                if attempt < config.MAX_RETRIES - 1:
                    delay = self._calculate_delay(attempt)
                    logger.warning(
                        "⚠️ Retry %d/%d for model=%s after %.1fs: %s",
                        attempt + 1, config.MAX_RETRIES, model, delay, e
                    )
                    await asyncio.sleep(delay)
                else:
                    logger.error("Primary model exhausted all retries: %s", e)
            else:
                duration_ms = (time.perf_counter() - start_time) * 1000
                logger.info(
                    "← LLM response: model=%s, duration=%.1fms, length=%d, fallback=%s",
                    model, duration_ms, len(response_text), False
                )
                # Save metrics (outside the try so a DB error is not mistaken
                # for an LLM failure).
                await self._record_metric(
                    user_id=user_id,
                    model=model,
                    duration_ms=duration_ms,
                    success=True,
                )
                return response_text

        # Fallback model
        if config.FALLBACK_ENABLED and self.fallback_client:
            logger.warning("🔄 Switching to fallback model: %s", self.fallback_model)
            model = self.fallback_model
            fallback_params = {**params, "model": model}

            try:
                response_text = await self._complete(
                    self.fallback_client, fallback_params, use_stream, user_id
                )
            except Exception as e:
                duration_ms = (time.perf_counter() - start_time) * 1000
                logger.error("Fallback model also failed: %s", e)
                await self._record_metric(
                    user_id=user_id,
                    model=f"{model} (fallback)",
                    duration_ms=duration_ms,
                    success=False,
                    error_type=type(e).__name__,
                )
                raise LLMTimeoutError(
                    f"Both primary and fallback models failed. Last error: {last_error}"
                ) from e

            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.info(
                "← Fallback response: model=%s, duration=%.1fms, length=%d",
                model, duration_ms, len(response_text)
            )
            await self._record_metric(
                user_id=user_id,
                model=f"{model} (fallback)",
                duration_ms=duration_ms,
                success=True,
            )
            return response_text

        raise LLMTimeoutError(
            f"All retries exhausted. Last error: {last_error}"
        ) from last_error

    # ═══════════════════════════════════════════════════════════════
    # BLOCK: Streaming
    # ═══════════════════════════════════════════════════════════════
    async def _stream_chat(
        self,
        client: "AsyncOpenAI",
        params: Dict[str, Any],
        user_id: Optional[int] = None,
    ) -> str:
        """Consume a streaming response and return the assembled full text.

        ``params`` is not modified; the streaming flags are applied to a copy.
        ``user_id`` is accepted for signature symmetry and is not used here.
        Any error is logged and re-raised unchanged.
        """
        request = {
            **params,
            "stream": True,
            "stream_options": {"include_usage": True},
        }

        pieces: List[str] = []
        usage = None

        try:
            stream = await client.chat.completions.create(**request)
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    pieces.append(chunk.choices[0].delta.content)
                if chunk.usage:
                    usage = chunk.usage
        except Exception as e:
            logger.error("Streaming error: %s", e)
            raise

        if usage:
            logger.info(
                "Streaming complete: tokens_in=%d, tokens_out=%d",
                usage.prompt_tokens or 0, usage.completion_tokens or 0
            )

        return "".join(pieces)

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        user_id: Optional[int] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield text chunks for real-time Telegram updates.

        Uses the primary model only; there is no retry or fallback on this
        path. ``user_id`` is currently unused. Errors are logged and re-raised.
        """
        params = self._build_params(
            messages=messages,
            model=self.primary_model,
            stream=True,
        )
        params["stream_options"] = {"include_usage": True}

        try:
            stream = await self.primary_client.chat.completions.create(**params)
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        except Exception as e:
            logger.error("Streaming generation failed: %s", e)
            raise

    # ═══════════════════════════════════════════════════════════════
    # BLOCK: Summarization
    # ═══════════════════════════════════════════════════════════════
    async def summarize(self, dialog_text: str) -> str:
        """Summarize a dialog through the LLM.

        Returns the summary text, or an empty string if the request fails
        (the error is logged, not raised).
        """
        start_time = time.perf_counter()

        messages = [
            {
                "role": "system",
                "content": (
                    "Суммаризируй следующий диалог между пользователем и ассистентом. "
                    "Сохрани ключевые факты, предпочтения пользователя, важные детали и контекст. "
                    "Будь краток, максимум 4096 токенов. Используй русский язык."
                )
            },
            {"role": "user", "content": dialog_text}
        ]

        # httpx.Timeout requires either a default value or all four of
        # connect/read/write/pool; omitting `pool` raises ValueError, which
        # the except below would silently turn into an empty summary.
        summary_timeout = httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0)

        try:
            response = await self.primary_client.chat.completions.create(
                model=self.primary_model,
                messages=messages,
                temperature=0.1,
                max_tokens=config.SUMMARY_MAX_TOKENS,
                timeout=summary_timeout,
            )
            content = response.choices[0].message.content or ""
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.info("Summary generated in %.1fms, length=%d", duration_ms, len(content))
            return content
        except Exception as e:
            logger.error("Summary generation failed: %s", e)
            return ""

    # ═══════════════════════════════════════════════════════════════
    # BLOCK: Helpers
    # ═══════════════════════════════════════════════════════════════
    def _build_params(
        self,
        messages: List[Dict[str, str]],
        model: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the keyword arguments for a chat-completions API request.

        ``temperature`` and ``max_tokens`` fall back to the ``config.GLM_*``
        defaults when None; the remaining sampling settings always come from
        ``config``.
        """
        return {
            "model": model,
            "messages": messages,
            "temperature": temperature if temperature is not None else config.GLM_TEMPERATURE,
            "top_p": config.GLM_TOP_P,
            "frequency_penalty": config.GLM_FREQUENCY_PENALTY,
            "presence_penalty": config.GLM_PRESENCE_PENALTY,
            "max_tokens": max_tokens if max_tokens is not None else config.GLM_MAX_TOKENS,
            "stream": stream,
        }

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of ``text``.

        Rule of thumb from the original author: 1 token ≈ 0.75 English words
        or 0.5 Russian words. Never returns less than 1.
        """
        # Simple heuristic: ~4 characters per token for mixed text
        return max(1, len(text) // 4)
# Shared module-level instance. NOTE: it is constructed at import time, and
# GLMService.__init__ reads its settings from `config`, so `config` must be
# fully populated before this module is imported.
glm_service = GLMService()