File size: 3,874 Bytes
e3e5444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
LLM utilities with retry logic and token budgeting for the Groq API.
"""
import functools
import logging
import os
from typing import Any, Callable

from tenacity import (
    retry,
    wait_exponential,
    stop_after_attempt,
    retry_if_exception_type,
    RetryError
)

logger = logging.getLogger(__name__)

# Retry tuning knobs; both can be overridden through environment variables.
LLM_MAX_RETRIES = int(os.environ.get("LLM_MAX_RETRIES", "3"))
LLM_INITIAL_WAIT_MS = int(os.environ.get("LLM_INITIAL_WAIT_MS", "1000"))

# Exception classes that the retry decorator treats as transient.
_RETRYABLE_EXCEPTIONS: tuple

# Prefer the Groq SDK's own exception classes so that retries fire on real
# API failures (rate limits, timeouts, dropped connections, status errors).
# When the groq package cannot be imported, only the generic built-in
# exception types are retried.
try:
    from groq import (
        APIStatusError as _GroqAPIStatusError,
        RateLimitError as _GroqRateLimitError,
        APIConnectionError as _GroqAPIConnectionError,
        APITimeoutError as _GroqAPITimeoutError,
    )
except ImportError:
    _RETRYABLE_EXCEPTIONS = (RuntimeError, OSError)
    logger.debug(
        "groq package exceptions not available; "
        "LLM retry will only fire on RuntimeError/OSError."
    )
else:
    _RETRYABLE_EXCEPTIONS = (
        _GroqAPIStatusError,
        _GroqRateLimitError,
        _GroqAPIConnectionError,
        _GroqAPITimeoutError,
        RuntimeError,
        OSError,
    )
    logger.debug("Groq-specific exception types loaded for LLM retry targeting.")

def with_llm_retry(func: Callable) -> Callable:
    """
    Decorator for LLM calls with exponential backoff retry.

    The wrapped call is retried only when it raises one of the exception
    types in ``_RETRYABLE_EXCEPTIONS``; any other exception propagates
    immediately. At most ``LLM_MAX_RETRIES`` attempts are made in total.

    The backoff is scaled by ``LLM_INITIAL_WAIT_MS`` (converted to seconds),
    grows exponentially between attempts and is capped at 60 seconds. With
    the default of 1000 ms the schedule is the same as with a fixed
    multiplier of 1; values other than 1000 ms now actually change the
    delays instead of merely acting as a lower bound.

    Args:
        func: The callable performing the LLM request.

    Returns:
        A callable with the same signature and metadata (name, docstring)
        as ``func`` that applies the retry policy.

    Raises:
        Whatever ``func`` raised on the final attempt: ``reraise=True``
        makes tenacity re-raise the original exception rather than
        wrapping it in ``RetryError``.
    """
    initial_wait_s = LLM_INITIAL_WAIT_MS / 1000

    def _log_before_sleep(retry_state: Any) -> None:
        # Emitted after a failed attempt, right before sleeping for the
        # computed backoff interval.
        logger.warning(
            "LLM call failed (attempt %d/%d), retrying in %.1fs...",
            retry_state.attempt_number,
            LLM_MAX_RETRIES,
            retry_state.next_action.sleep,
        )

    @retry(
        retry=retry_if_exception_type(_RETRYABLE_EXCEPTIONS),
        wait=wait_exponential(
            # Scale the exponential curve by the configured initial wait so
            # the setting is honoured below one second as well.
            multiplier=initial_wait_s,
            min=initial_wait_s,
            max=60,
        ),
        stop=stop_after_attempt(LLM_MAX_RETRIES),
        reraise=True,
        before_sleep=_log_before_sleep,
    )
    @functools.wraps(func)  # keep func's __name__/__doc__ on the result
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    return wrapper


def safe_llm_call(llm_func: Callable, fallback_value: Any = None) -> Callable:
    """
    Wrapper for non-critical LLM operations that returns a fallback value
    on exhausted retries or unexpected failures.

    The call is routed through ``with_llm_retry``. Because that decorator
    re-raises the original exception once its attempts are used up, an
    exhausted retry surfaces here as one of ``_RETRYABLE_EXCEPTIONS`` and
    not as ``RetryError``; both are reported as retry exhaustion. Any other
    exception is reported as an unexpected failure, with its traceback.

    Args:
        llm_func: The callable performing the LLM request.
        fallback_value: Value returned in place of a result when the call
            ultimately fails.

    Returns:
        A callable with the same signature and metadata as ``llm_func``
        that never propagates ``Exception`` subclasses; it returns either
        the result of ``llm_func`` or ``fallback_value``.
    """
    # Build the retrying callable once instead of on every invocation.
    retrying_call = with_llm_retry(llm_func)

    @functools.wraps(llm_func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return retrying_call(*args, **kwargs)
        except (RetryError, *_RETRYABLE_EXCEPTIONS) as e:
            # Retryable error that persisted through every attempt.
            logger.error("LLM call exhausted all retries: %s", e)
            return fallback_value
        except Exception as e:
            # Deliberate best-effort boundary: non-retryable errors are
            # logged with a traceback and swallowed, not raised.
            logger.error("Unexpected LLM call failure: %s", e, exc_info=True)
            return fallback_value

    return wrapper


def estimate_prompt_tokens(text: str) -> int:
    """
    Approximate the number of tokens in ``text``.

    Uses the rule of thumb that one token is about four characters and
    rounds down, so strings shorter than four characters count as zero
    tokens. This is only a heuristic; use Groq's tokenizer endpoint when
    an accurate count is required.
    """
    chars_per_token = 4
    whole_tokens, _ = divmod(len(text), chars_per_token)
    return whole_tokens


def enforce_token_budget(text: str, max_tokens: int = 2000) -> str:
    """
    Truncate text to stay within an estimated token budget.

    Token counts come from ``estimate_prompt_tokens`` (about four characters
    per token). Text whose estimate already fits is returned unchanged.
    Otherwise the text is cut to ``max_tokens * 4`` characters and, when a
    newline lies in the last 20% of that window, cut back to that newline.
    This is a best-effort way to avoid mid-line truncation; it does not
    guarantee the result is structurally valid (e.g. well-formed JSON).

    Args:
        text: The text to constrain.
        max_tokens: Estimated token budget. Negative values are treated as
            zero.

    Returns:
        ``text`` itself when it fits the budget, otherwise a truncated
        prefix of it. A warning is logged whenever truncation happens.
    """
    # A negative budget would become a negative slice bound below, and
    # text[:negative] strips characters from the END of the text instead of
    # truncating it. Clamp so a negative budget behaves like zero.
    budget = max(max_tokens, 0)

    estimated = estimate_prompt_tokens(text)
    if estimated <= budget:
        return text

    max_chars = budget * 4
    truncated = text[:max_chars]

    # Prefer cutting at a newline close to (but not past) the limit
    break_pos = truncated.rfind('\n')
    if break_pos > int(max_chars * 0.8):
        truncated = truncated[:break_pos]

    logger.warning(
        "Text truncated from ~%d to ~%d tokens to respect budget of %d tokens.",
        estimated,
        estimate_prompt_tokens(truncated),
        max_tokens
    )
    return truncated