File size: 8,755 Bytes
699677f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af380f4
 
 
699677f
 
 
 
 
 
af380f4
 
699677f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b2d478
699677f
0b2d478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699677f
 
0b2d478
 
 
 
 
 
699677f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af380f4
 
 
699677f
 
 
 
af380f4
 
699677f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9457bc
af380f4
 
 
 
 
b9457bc
 
699677f
af380f4
699677f
af380f4
699677f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""LLM clients for Azure OpenAI and Azure AI Foundry Agents."""
from __future__ import annotations

from typing import Any

import httpx
import anyio
from azure.ai.projects import AIProjectClient
from azure.identity import DefaultAzureCredential

from ..core.config import get_settings
from ..core.errors import LLMError


class AzureOpenAIClient:
    """Minimal Azure OpenAI chat completions client."""

    def __init__(self) -> None:
        self._settings = get_settings()

    async def chat(
        self, transcript: str, prompt: str | None = None, language: str | None = None
    ) -> str:
        """Call Azure OpenAI chat completions and return assistant text."""

        system_prompt = (
            "You are a concise, helpful assistant. "
            "Answer briefly and ask a clarifying question if needed."
        )
        if language:
            system_prompt += f" Reply in the same language as the user ({language})."
        user_content = f"Transcript: {transcript}"
        if prompt:
            user_content += f"\nUser instruction: {prompt}"

        base = self._normalize_endpoint(self._settings.azure_openai_endpoint)
        url = (
            f"{base}/openai/deployments/"
            f"{self._settings.azure_openai_deployment}/chat/completions"
            f"?api-version={self._settings.azure_openai_api_version}"
        )

        payload = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            "temperature": 0.2,
            "max_tokens": 300,
        }

        headers = {"api-key": self._settings.azure_openai_api_key}

        last_exc: httpx.HTTPStatusError | None = None
        try:
            for attempt in range(3):
                try:
                    async with httpx.AsyncClient(timeout=30.0) as client:
                        response = await client.post(url, json=payload, headers=headers)
                        response.raise_for_status()
                        data: dict[str, Any] = response.json()
                        last_exc = None
                        break
                except httpx.HTTPStatusError as exc:
                    last_exc = exc
                    body = exc.response.text or ""
                    if exc.response.status_code == 400 and "content_filter" in body:
                        # retry twice, then return guardrail message
                        if attempt < 2:
                            continue
                        raise LLMError(
                            code="llm_guardrail",
                            message="Query is violating some guardrails.",
                            details={"body": body},
                        ) from exc
                    raise LLMError(
                        code="llm_http",
                        message=f"LLM request failed with status {exc.response.status_code}.",
                        details={"body": body},
                    ) from exc
        except httpx.HTTPError as exc:
            raise LLMError(code="llm_http", message="LLM request failed.") from exc
        if last_exc is not None:
            raise LLMError(
                code="llm_http",
                message=f"LLM request failed with status {last_exc.response.status_code}.",
                details={"body": last_exc.response.text},
            ) from last_exc

        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise LLMError(code="llm_response", message="Invalid LLM response.") from exc

        text = str(content).strip()
        if not text:
            raise LLMError(code="llm_empty", message="Empty LLM response.")
        return text

    def _normalize_endpoint(self, endpoint: str) -> str:
        cleaned = endpoint.rstrip("/")
        marker = "/openai/"
        if marker in cleaned:
            cleaned = cleaned.split(marker, 1)[0]
        return cleaned


class FoundryAgentClient:
    """Azure AI Foundry Agent client using a connection string."""

    def __init__(self) -> None:
        self._settings = get_settings()
        if not hasattr(AIProjectClient, "from_connection_string"):
            raise LLMError(
                code="llm_config",
                message=(
                    "azure-ai-projects is missing from_connection_string(). "
                    "Install azure-ai-projects==1.0.0b10."
                ),
            )
        self._credential = DefaultAzureCredential(
            exclude_managed_identity_credential=True
        )
        self._client = AIProjectClient.from_connection_string(
            credential=self._credential,
            conn_str=self._settings.foundry_project_conn_str,
        )

    async def chat(
        self, transcript: str, prompt: str | None = None, language: str | None = None
    ) -> str:
        """Send a message to the Foundry agent and return the reply text."""
        user_content = f"Transcript: {transcript}"
        if prompt:
            user_content += f"\nUser instruction: {prompt}"
        if language:
            user_content += f"\nDetected language: {language}. Reply in the same language."

        try:
            return await anyio.to_thread.run_sync(self._chat_sync, user_content)
        except LLMError:
            raise
        except Exception as exc:
            raise LLMError(
                code="llm_http",
                message="LLM request failed.",
                details={"error": repr(exc)},
            ) from exc

    def _chat_sync(self, user_content: str) -> str:
        thread_id = self._client.agents.create_thread().id
        self._client.agents.create_message(
            thread_id=thread_id, role="user", content=user_content
        )
        run = self._client.agents.create_and_process_run(
            thread_id=thread_id, agent_id=self._settings.foundry_agent_id
        )
        messages = self._client.agents.list_messages(thread_id=thread_id)
        text = self._extract_assistant_text(messages, run_id=getattr(run, "id", None))
        if not text:
            raise LLMError(
                code="llm_empty",
                message="Empty LLM response.",
                details={
                    "messages_type": type(messages).__name__,
                    "messages_repr": self._safe_repr(messages),
                },
            )
        return text

    def _extract_assistant_text(
        self, messages: Any, run_id: str | None = None
    ) -> str | None:
        data = getattr(messages, "data", None)
        if data is None and isinstance(messages, dict):
            data = messages.get("data")
        if not data:
            return None

        def get(field: str, obj: Any, default: Any = None) -> Any:
            if isinstance(obj, dict):
                return obj.get(field, default)
            return getattr(obj, field, default)

        candidates: list[Any] = []
        for m in data:
            if get("role", m) != "assistant":
                continue
            if run_id is None or get("run_id", m) == run_id:
                candidates.append(m)

        if not candidates:
            candidates = [m for m in data if get("role", m) == "assistant"]
            if not candidates:
                return None

        msg = candidates[0]
        content = get("content", msg, []) or []

        parts: list[str] = []
        for block in content:
            btype = get("type", block)
            if btype == "text":
                text_obj = get("text", block, {})
                value = get("value", text_obj)
                if value:
                    parts.append(value)

        final = "\n".join(parts).strip()
        return final or None

    def _safe_repr(self, value: Any) -> str:
        try:
            return repr(value)[:2000]
        except Exception:
            return "<unreprable>"


class LLMClient:
    """LLM router that dispatches to the configured provider.

    Providers are constructed lazily on first use: eagerly building
    FoundryAgentClient in __init__ raises LLMError (missing SDK method)
    and performs credential/SDK setup even when the configured provider
    is azure_openai, breaking the unrelated path.
    """

    def __init__(self) -> None:
        self._settings = get_settings()
        # Created on demand in chat(); None means "not built yet".
        self._azure: AzureOpenAIClient | None = None
        self._foundry: FoundryAgentClient | None = None

    async def chat(
        self,
        transcript: str,
        prompt: str | None = None,
        provider: str | None = None,
        language: str | None = None,
    ) -> str:
        """Route a chat request to the selected provider.

        Args:
            transcript: User transcript forwarded to the provider.
            prompt: Optional extra user instruction.
            provider: Explicit provider name; defaults to settings.llm_provider.
            language: Optional language hint forwarded to the provider.

        Raises:
            LLMError: ``llm_provider`` for an unknown provider name, plus
                whatever the chosen provider's chat() raises.
        """
        provider = provider or self._settings.llm_provider
        if provider == "foundry_agent":
            if self._foundry is None:
                self._foundry = FoundryAgentClient()
            return await self._foundry.chat(transcript, prompt, language)
        if provider == "azure_openai":
            if self._azure is None:
                self._azure = AzureOpenAIClient()
            return await self._azure.chat(transcript, prompt, language)
        raise LLMError(code="llm_provider", message="Unsupported LLM provider.")