File size: 8,104 Bytes
8981bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""LLM provider interface with a deterministic offline stub.

Every LLM call in the app goes through :class:`LLMProvider`. In production you set
``ANTHROPIC_API_KEY`` or ``OPENAI_API_KEY`` and a real model answers. With no key set
(CI, tests, local demo) the :class:`StubProvider` produces a grounded answer from the
retrieved context using simple extractive logic — no network, fully reproducible.

The stub is intentionally not a toy: it composes a real answer out of the retrieved
policy/catalog snippets so the end-to-end product (citations, refund decisions,
escalation) is demonstrable with zero paid keys.
"""

from __future__ import annotations

import re
from typing import Protocol, Sequence

from .config import Settings, get_settings


class LLMProvider(Protocol):
    """Minimal surface the app depends on. Easy to swap or mock."""

    name: str

    def answer(self, question: str, context: Sequence[str]) -> str:
        """Answer ``question`` grounded in ``context`` snippets."""
        ...


_SENT_SPLIT = re.compile(r"(?<=[.!?])\s+")
_WORD = re.compile(r"[a-z0-9]+")


def _tokens(text: str) -> list[str]:
    return _WORD.findall(text.lower())


class StubProvider:
    """Deterministic, offline, network-free provider.

    Strategy: rank the sentences inside the retrieved context by lexical overlap
    with the question and stitch the best ones into a concise grounded reply. This
    keeps answers faithful to the knowledge base (no hallucination) which is exactly
    what you want from a support agent.
    """

    name = "stub"

    def answer(self, question: str, context: Sequence[str]) -> str:
        q_terms = set(_tokens(question))
        if not context:
            return (
                "I couldn't find that in our policies or catalog, so I'm escalating "
                "this to a human teammate."
            )

        # Flatten context into candidate sentences. Each is scored by how many query
        # terms it covers (overlap). Trivially short fragments (e.g. a bare section
        # heading like "Shipping.") are skipped so they can't win on a single keyword
        # while carrying no actual answer. Ties prefer the more informative (longer)
        # sentence, then earlier source order, keeping the result deterministic.
        candidates: list[tuple[int, int, int, str]] = []
        for src_idx, block in enumerate(context):
            for sent in _SENT_SPLIT.split(block.strip()):
                sent = sent.strip()
                sent_tokens = _tokens(sent)
                if len(sent_tokens) < 4:
                    continue
                overlap = len(q_terms & set(sent_tokens))
                candidates.append((overlap, len(sent_tokens), -src_idx, sent))

        if not candidates:
            return context[0].strip()

        # Highest overlap first; then longer (more informative); then earlier source.
        candidates.sort(key=lambda c: (c[0], c[1], c[2]), reverse=True)
        if candidates[0][0] == 0:
            # No lexical match — surface the most relevant retrieved block rather than
            # fabricate. The orchestrator's confidence will be low and it will likely
            # escalate.
            return context[0].strip()

        chosen: list[str] = []
        for overlap, _len, _src, sent in candidates:
            if overlap == 0:
                break
            if sent not in chosen:
                chosen.append(sent)
            if len(chosen) >= 2:
                break
        return " ".join(chosen)


class AnthropicProvider:
    """Real provider backed by the Anthropic Messages API."""

    name = "anthropic"

    def __init__(self, settings: Settings):
        # Imported lazily so the dependency is optional and the offline path never
        # needs it installed.
        import anthropic  # type: ignore

        self._client = anthropic.Anthropic(api_key=settings.anthropic_api_key)
        self._model = settings.anthropic_model

    def answer(self, question: str, context: Sequence[str]) -> str:
        joined = "\n\n".join(f"[doc {i + 1}]\n{c}" for i, c in enumerate(context))
        system = (
            "You are an e-commerce support agent. Answer ONLY from the provided "
            "context. If the answer is not in the context, say you don't know. Be "
            "concise and cite the relevant policy heading or product name."
        )
        msg = self._client.messages.create(
            model=self._model,
            max_tokens=400,
            system=system,
            messages=[
                {
                    "role": "user",
                    "content": f"Context:\n{joined}\n\nCustomer question: {question}",
                }
            ],
        )
        return "".join(block.text for block in msg.content if block.type == "text").strip()


class OpenAIProvider:
    """Real provider backed by the OpenAI Chat Completions API."""

    name = "openai"

    def __init__(self, settings: Settings):
        from openai import OpenAI  # type: ignore

        self._client = OpenAI(api_key=settings.openai_api_key)
        self._model = settings.openai_model

    def answer(self, question: str, context: Sequence[str]) -> str:
        joined = "\n\n".join(f"[doc {i + 1}]\n{c}" for i, c in enumerate(context))
        system = (
            "You are an e-commerce support agent. Answer ONLY from the provided "
            "context. If the answer is not in the context, say you don't know. Be "
            "concise and cite the relevant policy heading or product name."
        )
        resp = self._client.chat.completions.create(
            model=self._model,
            max_tokens=400,
            messages=[
                {"role": "system", "content": system},
                {
                    "role": "user",
                    "content": f"Context:\n{joined}\n\nCustomer question: {question}",
                },
            ],
        )
        return (resp.choices[0].message.content or "").strip()


class SafeProvider:
    """Wrap a real provider so any runtime failure degrades to the offline stub.

    A bad API key, rate limit, or network blip on a live call should never 500 a
    support request — we fall back to a grounded extractive answer instead. The
    reported ``name`` is the wrapped provider's so observability stays accurate.
    """

    def __init__(self, inner: LLMProvider):
        self._inner = inner
        self._fallback = StubProvider()
        self.name = inner.name

    def answer(self, question: str, context: Sequence[str]) -> str:
        try:
            text = self._inner.answer(question, context)
            return text or self._fallback.answer(question, context)
        except Exception:
            return self._fallback.answer(question, context)


def get_provider(settings: Settings | None = None) -> LLMProvider:
    """Resolve a provider from settings.

    ``auto`` (default) uses a real provider only when its key is present, otherwise
    the offline stub. Explicit values (``anthropic``/``openai``/``stub``) are honored.
    Any failure to construct a real provider degrades gracefully to the stub so the
    app never hard-crashes on a missing optional dependency, and any failure at call
    time is caught by :class:`SafeProvider`.
    """
    settings = settings or get_settings()
    choice = settings.llm_provider.lower()

    def _try(builder) -> LLMProvider | None:
        try:
            return SafeProvider(builder(settings))
        except Exception:
            return None

    if choice == "stub":
        return StubProvider()
    if choice == "anthropic" and settings.anthropic_api_key:
        return _try(AnthropicProvider) or StubProvider()
    if choice == "openai" and settings.openai_api_key:
        return _try(OpenAIProvider) or StubProvider()
    if choice == "auto":
        if settings.anthropic_api_key:
            p = _try(AnthropicProvider)
            if p:
                return p
        if settings.openai_api_key:
            p = _try(OpenAIProvider)
            if p:
                return p
    return StubProvider()