Spaces:
Running
Running
| """ | |
| LLM utilities with retry logic and token budgeting for the Groq API. | |
| """ | |
| import logging | |
| import os | |
| from typing import Any, Callable | |
| from tenacity import ( | |
| retry, | |
| wait_exponential, | |
| stop_after_attempt, | |
| retry_if_exception_type, | |
| RetryError | |
| ) | |
logger = logging.getLogger(__name__)

# Retry tuning knobs, each overridable through an environment variable of the
# same name. Values are parsed as integers at import time.
LLM_MAX_RETRIES = int(os.environ.get("LLM_MAX_RETRIES", "3"))
LLM_INITIAL_WAIT_MS = int(os.environ.get("LLM_INITIAL_WAIT_MS", "1000"))

# Build the tuple of exception types that count as retryable.
# When the groq package can be imported, its API error classes are listed
# ahead of the generic RuntimeError/OSError pair; otherwise only the generic
# pair is used.
try:
    from groq import (
        APIStatusError as _GroqAPIStatusError,
        RateLimitError as _GroqRateLimitError,
        APIConnectionError as _GroqAPIConnectionError,
        APITimeoutError as _GroqAPITimeoutError,
    )
except ImportError:
    _RETRYABLE_EXCEPTIONS = (RuntimeError, OSError)
    logger.debug(
        "groq package exceptions not available; LLM retry will only fire "
        "on RuntimeError/OSError."
    )
else:
    _RETRYABLE_EXCEPTIONS: tuple = (
        _GroqAPIStatusError,
        _GroqRateLimitError,
        _GroqAPIConnectionError,
        _GroqAPITimeoutError,
        RuntimeError,
        OSError,
    )
    logger.debug("Groq-specific exception types loaded for LLM retry targeting.")
def with_llm_retry(func: Callable) -> Callable:
    """
    Decorator for LLM calls with exponential backoff retry.

    Wraps ``func`` so that a call raising one of ``_RETRYABLE_EXCEPTIONS`` is
    attempted again, up to ``LLM_MAX_RETRIES`` attempts in total, pausing
    between attempts with an exponentially growing delay scaled by
    ``LLM_INITIAL_WAIT_MS``.

    Args:
        func: The callable that performs the LLM request.

    Returns:
        A callable accepting the same arguments as ``func``.

    Raises:
        RetryError: Raised by the returned callable once every attempt has
            failed with a retryable exception.
        Exception: Any exception type not listed in ``_RETRYABLE_EXCEPTIONS``
            propagates from the first attempt without being retried.
    """
    retrying = retry(
        retry=retry_if_exception_type(_RETRYABLE_EXCEPTIONS),
        # tenacity expresses waits in seconds; the setting is in milliseconds.
        wait=wait_exponential(multiplier=LLM_INITIAL_WAIT_MS / 1000),
        stop=stop_after_attempt(LLM_MAX_RETRIES),
        # ``reraise`` stays at its default so exhaustion surfaces as
        # RetryError, which safe_llm_call handles explicitly.
    )
    return retrying(func)
def safe_llm_call(llm_func: Callable, fallback_value: Any = None) -> Callable:
    """
    Wrap a non-critical LLM operation so that failures yield a fallback value.

    The returned callable runs ``llm_func`` through ``with_llm_retry`` and
    never raises an ``Exception`` subclass: any such failure is logged at
    ERROR level and ``fallback_value`` is returned instead.

    Args:
        llm_func: The callable that performs the LLM request.
        fallback_value: Value returned when the call fails. Defaults to None.

    Returns:
        A callable accepting the same arguments as ``llm_func`` and carrying
        its name and docstring.
    """
    import functools  # function-scope import: not part of the module header

    # Build the retry-wrapped callable once instead of on every invocation.
    retrying_call = with_llm_retry(llm_func)

    @functools.wraps(llm_func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return retrying_call(*args, **kwargs)
        except RetryError as exc:
            # str(RetryError) only shows the wrapped Future, so log the
            # exception raised by the final attempt when there is one.
            cause = exc.last_attempt.exception()
            logger.error(
                "LLM call exhausted all retries: %s",
                cause if cause is not None else exc,
            )
            return fallback_value
        except Exception as exc:
            # Deliberate best-effort boundary: this operation is non-critical,
            # so any other failure is logged and replaced by the fallback.
            logger.error("Unexpected LLM call failure: %s", exc)
            return fallback_value

    return wrapper
def estimate_prompt_tokens(text: str) -> int:
    """
    Approximate how many tokens ``text`` occupies.

    Uses the rule of thumb of four characters per token, rounding down.
    This is a heuristic, not a tokenizer-accurate count.
    """
    chars_per_token = 4
    char_count = len(text)
    return char_count // chars_per_token
def enforce_token_budget(text: str, max_tokens: int = 2000) -> str:
    """
    Truncate text to stay within an estimated token budget.

    The budget is converted to characters at four characters per token.
    When the cut would land mid-line and a newline exists in the last 20% of
    the allowed span, the text is cut at that newline instead. This reduces,
    but does not rule out, mid-sentence or structurally incomplete output.

    Args:
        text: The text to constrain.
        max_tokens: Estimated token budget. Values below zero are treated
            as zero.

    Returns:
        ``text`` unchanged when its estimate fits the budget, otherwise a
        prefix of ``text``. A warning is logged whenever truncation happens.
    """
    estimated = estimate_prompt_tokens(text)
    if estimated <= max_tokens:
        return text

    # Clamp at zero: a negative budget would give a negative slice bound,
    # which trims characters from the END of the text and would return
    # almost all of it instead of nothing.
    max_chars = max(max_tokens, 0) * 4
    truncated = text[:max_chars]

    # Prefer cutting at a newline close to (but not past) the limit.
    break_pos = truncated.rfind('\n')
    if break_pos > int(max_chars * 0.8):
        truncated = truncated[:break_pos]

    logger.warning(
        "Text truncated from ~%d to ~%d tokens to respect budget of %d tokens.",
        estimated,
        estimate_prompt_tokens(truncated),
        max_tokens,
    )
    return truncated