File size: 5,120 Bytes
8986591
 
 
4cc24b5
 
 
 
 
 
 
 
 
 
8986591
 
4cc24b5
 
8986591
 
 
4cc24b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8986591
4cc24b5
 
 
 
 
 
 
 
8986591
4cc24b5
8986591
 
 
 
 
4cc24b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8986591
4cc24b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
app/utils/llm.py
────────────────
LLM singleton with automatic model fallback chain.

When a model hits its rate limit (429), the client transparently
tries the next model in the FALLBACK_MODELS list.

Fallback order (separate daily token quotas on Groq free tier):
  1. Primary model from config (default: llama-3.3-70b-versatile, 500k TPD)
  2. llama-3.1-8b-instant                        (500k TPD)
  3. openai/gpt-oss-120b                         (100k TPD)
  4. meta-llama/llama-4-scout-17b-16e-instruct   (100k TPD)
"""

import re
import time
from langchain_groq import ChatGroq
from app.config import settings

# ── Fallback chain ─────────────────────────────────────────────────────────
# Primary is whatever LLM_MODEL is set to in .env / HF Secrets.
# The rest are tried in order when the current one is rate-limited.
# dict.fromkeys dedupes while preserving insertion order (guaranteed since
# Python 3.7) — unlike the seen-set comprehension trick, it leaves no stray
# module-level `seen` variable behind.
FALLBACK_MODELS = list(dict.fromkeys([
    settings.LLM_MODEL,
    "llama-3.1-8b-instant",
    "openai/gpt-oss-120b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
]))

_RATE_LIMIT_RE = re.compile(r'try again in\s+(?:(\d+)m)?(?:([\d.]+)s)?', re.IGNORECASE)


def _is_rate_limit(error: Exception) -> bool:
    return "429" in str(error) or "rate_limit_exceeded" in str(error)


def _parse_wait(error: Exception) -> float:
    m = _RATE_LIMIT_RE.search(str(error))
    if m:
        return float(m.group(1) or 0) * 60 + float(m.group(2) or 0)
    return 30.0


def _build(model: str) -> ChatGroq:
    """Construct a fresh ChatGroq client for *model* using app settings."""
    params = {
        "model": model,
        "temperature": settings.LLM_TEMPERATURE,
        "api_key": settings.GROQ_API_KEY,
    }
    return ChatGroq(**params)


# ── FallbackLLM wrapper ────────────────────────────────────────────────────

class FallbackLLM:
    """
    Drop-in replacement for a ChatGroq instance.
    On 429, switches to the next model in the chain automatically.
    Remembers which model is currently active across calls.
    """

    def __init__(self):
        # Position within FALLBACK_MODELS of the currently active model.
        self._index = 0
        self._client = _build(FALLBACK_MODELS[0])
        print(f"[LLM] Active model: {FALLBACK_MODELS[0]}")

    @property
    def current_model(self) -> str:
        """Name of the model currently backing this client."""
        return FALLBACK_MODELS[self._index]

    def _next_model(self, error: Exception) -> bool:
        """Switch to next model. Returns False if all exhausted."""
        wait = _parse_wait(error)
        print(f"[LLM] ⚠ {self.current_model} rate-limited β€” trying next model (wait would be {wait:.0f}s)")

        self._index += 1
        if self._index < len(FALLBACK_MODELS):
            # Still models left in the chain — hop to the next one.
            self._client = _build(FALLBACK_MODELS[self._index])
            print(f"[LLM] βœ“ Switched to: {self.current_model}")
            return True

        # Full rotation: every model is rate-limited.  Fall back to the
        # primary, sit out the advertised cool-down, and signal exhaustion.
        self._index = 0
        mins, secs = int(wait // 60), int(wait % 60)
        print(f"[LLM] All models exhausted. Waiting {mins}m {secs}s for {self.current_model}...")
        time.sleep(wait + 2)
        self._client = _build(FALLBACK_MODELS[0])
        return False

    def invoke(self, messages, **kwargs):
        """Invoke the active model, hopping down the chain on rate limits."""
        while True:
            try:
                return self._client.invoke(messages, **kwargs)
            except Exception as e:
                if not _is_rate_limit(e):
                    raise
                if not self._next_model(e):
                    raise  # re-raise after waiting on primary

    def bind_tools(self, tools):
        """Return a bound-tools version that also falls back on rate limit."""
        return FallbackLLMWithTools(self, tools)

    # Passthrough for any other ChatGroq attributes callers might use
    def __getattr__(self, name):
        return getattr(self._client, name)


class FallbackLLMWithTools:
    """Wraps FallbackLLM for tool-calling routes."""

    def __init__(self, parent: FallbackLLM, tools: list):
        self._parent = parent
        self._tools = tools

    def invoke(self, messages, **kwargs):
        """Invoke with tools bound, delegating 429 fallback to the parent."""
        while True:
            try:
                # Re-bind on every attempt: a fallback may have swapped the
                # parent's underlying client since the previous try.
                fresh = self._parent._client.bind_tools(self._tools)
                return fresh.invoke(messages, **kwargs)
            except Exception as err:
                if not _is_rate_limit(err):
                    raise
                if not self._parent._next_model(err):
                    raise


# ── Singletons ─────────────────────────────────────────────────────────────

# Module-level shared client; importers all see the same fallback state.
llm = FallbackLLM()


def get_llm_with_tools(tools: list) -> FallbackLLMWithTools:
    """Return the shared LLM singleton with *tools* bound for tool-calling."""
    bound = llm.bind_tools(tools)
    return bound