ruslanmv commited on
Commit
a91be8c
Β·
1 Parent(s): 53fae25

Multi providers

Browse files
app/bootstrap.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/bootstrap.py
2
+ """
3
+ App bootstrap: load .env and configure logging as early as possible.
4
+ This module should be imported once at process start (import side-effects).
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import os
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment from configs/.env if present (non-fatal if missing)
12
+ load_dotenv(dotenv_path=os.path.join("configs", ".env"))
13
+
14
+ # Configure logging after env is loaded so LOG_LEVEL is respected
15
+ try:
16
+ from app.core.logging import setup_logging # noqa: E402
17
+ setup_logging()
18
+ except Exception as e:
19
+ # Fallback to a minimal logger if our setup helper isn't available for any reason
20
+ import logging as _logging
21
+ _logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
22
+ _logging.getLogger(__name__).warning("Fallback logging configured: %s", e)
app/core/config.py CHANGED
@@ -1,14 +1,19 @@
1
  from __future__ import annotations
2
  import os, yaml
3
  from pydantic import BaseModel, AnyHttpUrl
4
- from typing import Optional
5
 
6
  class ModelCfg(BaseModel):
 
7
  name: str = "HuggingFaceH4/zephyr-7b-beta"
8
  fallback: str = "mistralai/Mistral-7B-Instruct-v0.2"
9
  max_new_tokens: int = 256
10
  temperature: float = 0.2
11
- provider: Optional[str] = None # NEW
 
 
 
 
12
 
13
  class LimitsCfg(BaseModel):
14
  rate_per_min: int = 60
@@ -30,24 +35,48 @@ class Settings(BaseModel):
30
  rag: RagCfg = RagCfg()
31
  matrixhub: MatrixHubCfg = MatrixHubCfg()
32
  security: SecurityCfg = SecurityCfg()
33
- chat_backend: str = "router" # NEW (reserved)
34
- chat_stream: bool = True # NEW
 
 
 
35
 
36
  @staticmethod
37
- def load() -> Settings:
38
  path = os.getenv("SETTINGS_FILE", "configs/settings.yaml")
39
  data = {}
40
  if os.path.exists(path):
41
  with open(path, "r", encoding="utf-8") as f:
42
  data = yaml.safe_load(f) or {}
 
43
  settings = Settings.model_validate(data)
44
 
45
- # Env overrides
46
- if "MODEL_NAME" in os.environ: settings.model.name = os.environ["MODEL_NAME"]
47
- if "MODEL_FALLBACK" in os.environ: settings.model.fallback = os.environ["MODEL_FALLBACK"]
48
- if "MODEL_PROVIDER" in os.environ: settings.model.provider = os.environ["MODEL_PROVIDER"]
49
- if "ADMIN_TOKEN" in os.environ: settings.security.admin_token = os.environ["ADMIN_TOKEN"]
50
- if "RATE_LIMITS" in os.environ: settings.limits.rate_per_min = int(os.environ["RATE_LIMITS"])
51
- if "HF_CHAT_BACKEND" in os.environ: settings.chat_backend = os.environ["HF_CHAT_BACKEND"].strip().lower()
52
- if "CHAT_STREAM" in os.environ: settings.chat_stream = os.environ["CHAT_STREAM"].lower() in ("1","true","yes","on")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return settings
 
1
  from __future__ import annotations
2
  import os, yaml
3
  from pydantic import BaseModel, AnyHttpUrl
4
+ from typing import Optional, List
5
 
6
  class ModelCfg(BaseModel):
7
+ # HF Router defaults (used when we reach the router)
8
  name: str = "HuggingFaceH4/zephyr-7b-beta"
9
  fallback: str = "mistralai/Mistral-7B-Instruct-v0.2"
10
  max_new_tokens: int = 256
11
  temperature: float = 0.2
12
+ provider: Optional[str] = None # HF Router provider tag (e.g., "featherless-ai")
13
+
14
+ # New: provider-specific default models
15
+ groq_model: str = "llama-3.1-8b-instant"
16
+ gemini_model: str = "gemini-2.5-flash"
17
 
18
  class LimitsCfg(BaseModel):
19
  rate_per_min: int = 60
 
35
  rag: RagCfg = RagCfg()
36
  matrixhub: MatrixHubCfg = MatrixHubCfg()
37
  security: SecurityCfg = SecurityCfg()
38
+
39
+ # New
40
+ provider_order: List[str] = ["groq", "gemini", "router"] # cascade order
41
+ chat_backend: str = "multi" # was "router"; "multi" enables cascade
42
+ chat_stream: bool = True
43
 
44
  @staticmethod
45
+ def load() -> "Settings":
46
  path = os.getenv("SETTINGS_FILE", "configs/settings.yaml")
47
  data = {}
48
  if os.path.exists(path):
49
  with open(path, "r", encoding="utf-8") as f:
50
  data = yaml.safe_load(f) or {}
51
+
52
  settings = Settings.model_validate(data)
53
 
54
+ # Existing env overrides
55
+ if "MODEL_NAME" in os.environ:
56
+ settings.model.name = os.environ["MODEL_NAME"]
57
+ if "MODEL_FALLBACK" in os.environ:
58
+ settings.model.fallback = os.environ["MODEL_FALLBACK"]
59
+ if "MODEL_PROVIDER" in os.environ:
60
+ settings.model.provider = os.environ["MODEL_PROVIDER"]
61
+ if "ADMIN_TOKEN" in os.environ:
62
+ settings.security.admin_token = os.environ["ADMIN_TOKEN"]
63
+ if "RATE_LIMITS" in os.environ:
64
+ settings.limits.rate_per_min = int(os.environ["RATE_LIMITS"])
65
+ if "HF_CHAT_BACKEND" in os.environ:
66
+ settings.chat_backend = os.environ["HF_CHAT_BACKEND"].strip().lower()
67
+ if "CHAT_STREAM" in os.environ:
68
+ settings.chat_stream = os.environ["CHAT_STREAM"].lower() in ("1","true","yes","on")
69
+
70
+ # New env overrides
71
+ if "GROQ_MODEL" in os.environ:
72
+ settings.model.groq_model = os.environ["GROQ_MODEL"]
73
+ if "GEMINI_MODEL" in os.environ:
74
+ settings.model.gemini_model = os.environ["GEMINI_MODEL"]
75
+ if "PROVIDER_ORDER" in os.environ:
76
+ settings.provider_order = [p.strip().lower() for p in os.environ["PROVIDER_ORDER"].split(",") if p.strip()]
77
+
78
+ # Default to cascade
79
+ if settings.chat_backend not in ("multi", "router"):
80
+ settings.chat_backend = "multi"
81
+
82
  return settings
app/core/inference/__init__.py CHANGED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .client import ChatClient, chat, get_client
2
+
3
+ __all__ = ["ChatClient", "chat", "get_client"]
app/core/inference/client.py CHANGED
@@ -1,11 +1,109 @@
1
  # app/core/inference/client.py
2
- import os, json, time, logging
3
- from typing import Dict, List, Optional, Iterator, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import requests
6
 
 
 
 
7
  logger = logging.getLogger(__name__)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"
10
 
11
  def _require_token() -> str:
@@ -33,9 +131,19 @@ class RouterRequestsClient:
33
  """
34
  Simple requests-only client for HF Router Chat Completions.
35
  Supports non-streaming (returns str) and streaming (yields token strings).
 
 
 
36
  """
37
- def __init__(self, model: str, fallback: Optional[str] = None, provider: Optional[str] = None,
38
- max_retries: int = 2, connect_timeout: float = 10.0, read_timeout: float = 60.0):
 
 
 
 
 
 
 
39
  self.model = model
40
  self.fallback = fallback if fallback != model else None
41
  self.provider = provider
@@ -82,7 +190,7 @@ class RouterRequestsClient:
82
  raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")
83
 
84
  def _try_once(self, payload: dict) -> Tuple[str, bool]:
85
- last_err = None
86
  for attempt in range(self.max_retries + 1):
87
  try:
88
  r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
@@ -169,3 +277,11 @@ class RouterRequestsClient:
169
  def plan_nonstream(self, system_prompt: str, user_text: str,
170
  max_tokens: int, temperature: float) -> str:
171
  return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)
 
 
 
 
 
 
 
 
 
1
  # app/core/inference/client.py
2
+ from __future__ import annotations
3
+
4
+ """
5
+ Unified chat client module.
6
+
7
+ - Exposes a production-ready MultiProvider cascade client (GROQ β†’ Gemini β†’ HF Router),
8
+ via ChatClient / chat(...).
9
+ - Keeps the legacy RouterRequestsClient for direct access to the HF Router compatible
10
+ /v1/chat/completions endpoint, preserving backward compatibility.
11
+
12
+ This file assumes:
13
+ - app/bootstrap.py exists and loads configs/.env + sets up logging.
14
+ - app/core/config.py provides Settings (with provider_order, etc.).
15
+ - app/core/inference/providers.py implements MultiProviderChat orchestrator.
16
+ """
17
+
18
+ import os
19
+ import json
20
+ import time
21
+ import logging
22
+ from typing import Dict, List, Optional, Iterator, Tuple, Iterable, Union, Generator
23
+
24
+ # Ensure .env & logging before we load settings/providers
25
+ import app.bootstrap # noqa: F401
26
 
27
  import requests
28
 
29
+ from app.core.config import Settings
30
+ from app.core.inference.providers import MultiProviderChat
31
+
32
  logger = logging.getLogger(__name__)
33
 
34
+ # -----------------------------
35
+ # Multi-provider cascade client
36
+ # -----------------------------
37
+
38
+ Message = Dict[str, str]
39
+
40
+ class ChatClient:
41
+ """
42
+ Unified chat client that executes the configured provider cascade.
43
+ Providers are tried in order (settings.provider_order). First success wins.
44
+ """
45
+ def __init__(self, settings: Settings | None = None):
46
+ self._settings = settings or Settings.load()
47
+ self._chain = MultiProviderChat(self._settings)
48
+
49
+ def chat(
50
+ self,
51
+ messages: Iterable[Message],
52
+ temperature: Optional[float] = None,
53
+ max_new_tokens: Optional[int] = None,
54
+ stream: Optional[bool] = None,
55
+ ) -> Union[str, Generator[str, None, None]]:
56
+ """
57
+ Execute a chat completion using the provider cascade.
58
+
59
+ Args:
60
+ messages: Iterable of {"role": "system|user|assistant", "content": "..."}
61
+ temperature: Optional override for sampling temperature.
62
+ max_new_tokens: Optional override for max tokens.
63
+ stream: If None, uses settings.chat_stream. If True, returns a generator of text chunks.
64
+
65
+ Returns:
66
+ str (non-stream) or generator[str] (stream)
67
+ """
68
+ use_stream = self._settings.chat_stream if stream is None else bool(stream)
69
+ return self._chain.chat(
70
+ messages,
71
+ temperature=temperature,
72
+ max_new_tokens=max_new_tokens,
73
+ stream=use_stream,
74
+ )
75
+
76
+ # Backward-compatible helpers
77
+ _default_client: ChatClient | None = None
78
+
79
+ def _get_default() -> ChatClient:
80
+ global _default_client
81
+ if _default_client is None:
82
+ _default_client = ChatClient()
83
+ return _default_client
84
+
85
+ def chat(
86
+ messages: Iterable[Message],
87
+ temperature: Optional[float] = None,
88
+ max_new_tokens: Optional[int] = None,
89
+ stream: Optional[bool] = None,
90
+ ) -> Union[str, Generator[str, None, None]]:
91
+ """
92
+ Convenience function using a process-wide default ChatClient.
93
+ """
94
+ return _get_default().chat(messages, temperature=temperature, max_new_tokens=max_new_tokens, stream=stream)
95
+
96
+ def get_client(settings: Settings | None = None) -> ChatClient:
97
+ """
98
+ Factory for an explicit ChatClient bound to provided settings.
99
+ """
100
+ return ChatClient(settings)
101
+
102
+
103
+ # ------------------------------------------------------
104
+ # Legacy HF Router client (kept for backward compatibility)
105
+ # ------------------------------------------------------
106
+
107
  ROUTER_URL = "https://router.huggingface.co/v1/chat/completions"
108
 
109
  def _require_token() -> str:
 
131
  """
132
  Simple requests-only client for HF Router Chat Completions.
133
  Supports non-streaming (returns str) and streaming (yields token strings).
134
+
135
+ NOTE: New code should prefer ChatClient above. This class is preserved for any
136
+ legacy call sites that rely on direct HF Router access.
137
  """
138
+ def __init__(
139
+ self,
140
+ model: str,
141
+ fallback: Optional[str] = None,
142
+ provider: Optional[str] = None,
143
+ max_retries: int = 2,
144
+ connect_timeout: float = 10.0,
145
+ read_timeout: float = 60.0
146
+ ):
147
  self.model = model
148
  self.fallback = fallback if fallback != model else None
149
  self.provider = provider
 
190
  raise RuntimeError(f"Chat non-stream failed: model={self.model} fallback={self.fallback}")
191
 
192
  def _try_once(self, payload: dict) -> Tuple[str, bool]:
193
+ last_err: Optional[Exception] = None
194
  for attempt in range(self.max_retries + 1):
195
  try:
196
  r = requests.post(ROUTER_URL, headers=self.headers, json=payload, timeout=self.timeout)
 
277
  def plan_nonstream(self, system_prompt: str, user_text: str,
278
  max_tokens: int, temperature: float) -> str:
279
  return self.chat_nonstream(system_prompt, user_text, max_tokens, temperature)
280
+
281
+
282
+ __all__ = [
283
+ "ChatClient",
284
+ "chat",
285
+ "get_client",
286
+ "RouterRequestsClient",
287
+ ]
app/core/inference/providers.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/core/inference/providers.py
2
+ from __future__ import annotations
3
+
4
+ """
5
+ Provider layer for multi-backend LLM chat with a production-ready cascade:
6
+
7
+ GROQ β†’ Gemini β†’ Hugging Face Inference Router (Zephyr β†’ Mistral)
8
+
9
+ - Each provider implements a common .chat(...) interface that returns either:
10
+ * str (non-stream), or
11
+ * Generator[str, None, None] (streaming text chunks)
12
+
13
+ - MultiProviderChat orchestrates providers in a user-configurable order (Settings.provider_order)
14
+ and returns the first successful response.
15
+
16
+ - Robustness:
17
+ * .env + logging are loaded via app.bootstrap import side-effect
18
+ * Requests session has retries and timeouts
19
+ * Provider initialization gracefully skips when keys/SDKs are missing
20
+ * Streaming uses SSE for HF Router; Groq uses SDK streaming; Gemini yields one chunk
21
+ """
22
+
23
+ from typing import Any, Dict, Generator, Iterable, List, Optional, Union
24
+ import json
25
+ import logging
26
+ import os
27
+ import time
28
+
29
+ # Ensure .env + logging configured even if imported directly
30
+ import app.bootstrap # noqa: F401
31
+
32
+ import requests
33
+ from requests.adapters import HTTPAdapter
34
+ from urllib3.util.retry import Retry
35
+
36
+ # Optional SDKs; handled gracefully if absent
37
+ try:
38
+ from groq import Groq
39
+ except Exception: # pragma: no cover
40
+ Groq = None # type: ignore
41
+
42
+ try:
43
+ from google import genai
44
+ except Exception: # pragma: no cover
45
+ genai = None # type: ignore
46
+
47
+ from app.core.config import Settings
48
+
49
+ logger = logging.getLogger(__name__)
50
+
51
+ Message = Dict[str, str] # {"role": "system|user|assistant", "content": "..."}
52
+
53
+
54
+ # ---------- Errors ----------
55
+ class ProviderError(RuntimeError):
56
+ """Raised for provider-specific configuration/runtime errors."""
57
+
58
+
59
+ # ---------- Helpers ----------
60
+ def _ensure_messages(msgs: Iterable[Message]) -> List[Message]:
61
+ """
62
+ Normalize incoming messages to a strict [{"role": str, "content": str}, ...] list.
63
+ """
64
+ out: List[Message] = []
65
+ for m in msgs:
66
+ role = m.get("role", "user")
67
+ content = m.get("content", "")
68
+ out.append({"role": role, "content": content})
69
+ return out
70
+
71
+
72
+ def _requests_session_with_retries(
73
+ total: int = 3,
74
+ backoff: float = 0.3,
75
+ status_forcelist: Optional[List[int]] = None,
76
+ timeout: float = 60.0,
77
+ ) -> requests.Session:
78
+ """
79
+ Return a requests.Session configured with retries, connection pooling, and default timeouts.
80
+ """
81
+ status_forcelist = status_forcelist or [408, 429, 500, 502, 503, 504]
82
+ retry = Retry(
83
+ total=total,
84
+ read=total,
85
+ connect=total,
86
+ backoff_factor=backoff,
87
+ status_forcelist=status_forcelist,
88
+ allowed_methods=frozenset(["GET", "POST"]),
89
+ raise_on_status=False,
90
+ )
91
+ adapter = HTTPAdapter(max_retries=retry, pool_connections=10, pool_maxsize=10)
92
+ session = requests.Session()
93
+ session.mount("http://", adapter)
94
+ session.mount("https://", adapter)
95
+ # Store default timeout on session via a patched request method
96
+ session.request = _patch_request_with_timeout(session.request, timeout) # type: ignore
97
+ return session
98
+
99
+
100
+ def _patch_request_with_timeout(fn, timeout: float):
101
+ def wrapper(method, url, **kwargs):
102
+ if "timeout" not in kwargs:
103
+ kwargs["timeout"] = timeout
104
+ return fn(method, url, **kwargs)
105
+
106
+ return wrapper
107
+
108
+
109
+ # ---------- GROQ ----------
110
+ class GroqProvider:
111
+ """
112
+ Groq Chat Completions (OpenAI-compatible).
113
+ Requires:
114
+ - env: GROQ_API_KEY
115
+ - package: groq
116
+ """
117
+ name = "groq"
118
+
119
+ def __init__(self, model: str):
120
+ self.model = model
121
+ self.api_key = os.getenv("GROQ_API_KEY")
122
+ if not self.api_key:
123
+ raise ProviderError("GROQ_API_KEY is not set")
124
+ if Groq is None:
125
+ raise ProviderError("groq SDK not installed; add 'groq' to requirements.txt and pip install.")
126
+ # SDK reads key from env
127
+ self.client = Groq()
128
+
129
+ def chat(
130
+ self,
131
+ messages: Iterable[Message],
132
+ temperature: float,
133
+ max_new_tokens: int,
134
+ stream: bool,
135
+ ) -> Union[str, Generator[str, None, None]]:
136
+ msgs = _ensure_messages(messages)
137
+ try:
138
+ completion = self.client.chat.completions.create(
139
+ model=self.model,
140
+ messages=msgs,
141
+ temperature=float(temperature),
142
+ max_tokens=int(max_new_tokens),
143
+ top_p=1,
144
+ stream=bool(stream),
145
+ )
146
+ if stream:
147
+ def gen():
148
+ for chunk in completion:
149
+ try:
150
+ delta = chunk.choices[0].delta
151
+ part = getattr(delta, "content", None)
152
+ if part:
153
+ yield part
154
+ except Exception:
155
+ continue
156
+ return gen()
157
+ else:
158
+ # Non-streaming: return final message content
159
+ return completion.choices[0].message.content or ""
160
+ except Exception as e:
161
+ raise ProviderError(f"GROQ error: {e}") from e
162
+
163
+
164
+ # ---------- GEMINI ----------
165
+ class GeminiProvider:
166
+ """
167
+ Google Gemini via google-genai.
168
+ Requires:
169
+ - env: GOOGLE_API_KEY
170
+ - package: google-genai
171
+
172
+ Role mapping:
173
+ - system β†’ system_instruction (joined)
174
+ - user β†’ role 'user'
175
+ - assistant β†’ role 'model'
176
+ """
177
+ name = "gemini"
178
+
179
+ def __init__(self, model: str):
180
+ self.model = model
181
+ self.api_key = os.getenv("GOOGLE_API_KEY")
182
+ if not self.api_key:
183
+ raise ProviderError("GOOGLE_API_KEY is not set")
184
+ if genai is None:
185
+ raise ProviderError("google-genai SDK not installed; add 'google-genai' to requirements.txt and pip install.")
186
+ self.client = genai.Client(api_key=self.api_key)
187
+
188
+ @staticmethod
189
+ def _split_system_and_messages(msgs: List[Message]) -> tuple[str, List[dict]]:
190
+ system_parts: List[str] = []
191
+ contents: List[dict] = []
192
+ for m in msgs:
193
+ role = m.get("role", "user")
194
+ text = m.get("content", "")
195
+ if role == "system":
196
+ system_parts.append(text)
197
+ else:
198
+ mapped = "user" if role == "user" else "model"
199
+ contents.append({"role": mapped, "parts": [{"text": text}]})
200
+ return ("\n".join(system_parts).strip(), contents)
201
+
202
+ def chat(
203
+ self,
204
+ messages: Iterable[Message],
205
+ temperature: float,
206
+ max_new_tokens: int,
207
+ stream: bool,
208
+ ) -> Union[str, Generator[str, None, None]]:
209
+ msgs = _ensure_messages(messages)
210
+ system_instruction, contents = self._split_system_and_messages(msgs)
211
+ try:
212
+ # Some versions of google-genai expose system_instruction; if not, we prepend.
213
+ kwargs: Dict[str, Any] = {
214
+ "model": self.model,
215
+ "contents": contents,
216
+ "generation_config": {
217
+ "temperature": float(temperature),
218
+ "max_output_tokens": int(max_new_tokens),
219
+ },
220
+ }
221
+ try:
222
+ resp = self.client.models.generate_content(system_instruction=system_instruction or None, **kwargs)
223
+ except TypeError:
224
+ # Fallback for older SDKs: inject system as first user turn
225
+ if system_instruction:
226
+ contents = [{"role": "user", "parts": [{"text": f"System: {system_instruction}"}]}] + contents
227
+ kwargs["contents"] = contents
228
+ resp = self.client.models.generate_content(**kwargs)
229
+
230
+ text = getattr(resp, "text", "") or ""
231
+
232
+ if stream:
233
+ # Fake streaming for API parity: one chunk
234
+ def gen():
235
+ yield text
236
+ return gen()
237
+ return text
238
+ except Exception as e:
239
+ raise ProviderError(f"Gemini error: {e}") from e
240
+
241
+
242
+ # ---------- HF Inference Router ----------
243
+ class HfRouterProvider:
244
+ """
245
+ Hugging Face Inference Router (OpenAI-like /v1/chat/completions).
246
+ Tries primary -> fallback model (both can include optional provider tag, e.g., "model:featherless-ai").
247
+
248
+ Requires:
249
+ - env: HF_TOKEN
250
+ - package: requests
251
+ """
252
+ name = "router"
253
+ BASE_URL = "https://router.huggingface.co/v1/chat/completions"
254
+
255
+ def __init__(self, primary_model: str, fallback_model: Optional[str], provider_tag: Optional[str]):
256
+ self.primary = primary_model
257
+ self.fallback = fallback_model
258
+ self.provider_tag = provider_tag
259
+ self.token = os.getenv("HF_TOKEN")
260
+ if not self.token:
261
+ raise ProviderError("HF_TOKEN is not set")
262
+ self.session = _requests_session_with_retries(total=3, backoff=0.5, timeout=60.0)
263
+
264
+ def _fmt_model(self, model: str) -> str:
265
+ return model if not self.provider_tag else f"{model}:{self.provider_tag}"
266
+
267
+ def _sse_stream(self, resp: requests.Response) -> Generator[str, None, None]:
268
+ for raw in resp.iter_lines(decode_unicode=True):
269
+ if not raw:
270
+ continue
271
+ if not raw.startswith("data:"):
272
+ continue
273
+ data = raw[5:].strip()
274
+ if data == "[DONE]":
275
+ break
276
+ try:
277
+ obj = json.loads(data)
278
+ except Exception:
279
+ continue
280
+ try:
281
+ delta = obj["choices"][0].get("delta", {})
282
+ content = delta.get("content")
283
+ if content:
284
+ yield content
285
+ except Exception:
286
+ continue
287
+
288
+ def _call_router(
289
+ self,
290
+ model: str,
291
+ messages: List[Message],
292
+ temperature: float,
293
+ max_new_tokens: int,
294
+ stream: bool,
295
+ ) -> Union[str, Generator[str, None, None]]:
296
+ headers = {
297
+ "Authorization": f"Bearer {self.token}",
298
+ "Content-Type": "application/json",
299
+ }
300
+ payload: Dict[str, Any] = {
301
+ "model": self._fmt_model(model),
302
+ "messages": messages,
303
+ "temperature": float(temperature),
304
+ "max_tokens": int(max_new_tokens),
305
+ "stream": bool(stream),
306
+ }
307
+ if stream:
308
+ with self.session.post(self.BASE_URL, headers=headers, json=payload, stream=True) as r:
309
+ if r.status_code >= 400:
310
+ raise ProviderError(f"HF Router HTTP {r.status_code}: {r.text[:300]}")
311
+ return self._sse_stream(r)
312
+ else:
313
+ r = self.session.post(self.BASE_URL, headers=headers, json=payload)
314
+ if r.status_code >= 400:
315
+ raise ProviderError(f"HF Router HTTP {r.status_code}: {r.text[:300]}")
316
+ obj = r.json()
317
+ try:
318
+ return obj["choices"][0]["message"]["content"]
319
+ except Exception as e:
320
+ raise ProviderError(f"HF Router response parsing error: {e}") from e
321
+
322
+ def chat(
323
+ self,
324
+ messages: Iterable[Message],
325
+ temperature: float,
326
+ max_new_tokens: int,
327
+ stream: bool,
328
+ ) -> Union[str, Generator[str, None, None]]:
329
+ msgs = _ensure_messages(messages)
330
+ try:
331
+ return self._call_router(self.primary, msgs, temperature, max_new_tokens, stream)
332
+ except Exception as e1:
333
+ logger.warning("HF primary model failed (%s): %s", self.primary, e1)
334
+ if self.fallback:
335
+ return self._call_router(self.fallback, msgs, temperature, max_new_tokens, stream)
336
+ raise
337
+
338
+
339
+ # ---------- Orchestrator ----------
340
+ class MultiProviderChat:
341
+ """
342
+ Tries providers in configured order. First success wins.
343
+ Skips misconfigured providers (missing key or SDK).
344
+ """
345
+ def __init__(self, settings: Settings):
346
+ m = settings.model
347
+ order = [p.strip().lower() for p in settings.provider_order]
348
+ self.providers: List[Any] = []
349
+
350
+ for p in order:
351
+ try:
352
+ if p == "groq":
353
+ self.providers.append(GroqProvider(m.groq_model))
354
+ elif p == "gemini":
355
+ self.providers.append(GeminiProvider(m.gemini_model))
356
+ elif p == "router":
357
+ self.providers.append(HfRouterProvider(m.name, m.fallback, m.provider))
358
+ else:
359
+ logger.warning("Unknown provider '%s' in provider_order; skipping.", p)
360
+ except ProviderError as e:
361
+ logger.warning("Provider '%s' not available: %s (will skip)", p, e)
362
+ continue
363
+
364
+ if not self.providers:
365
+ raise ProviderError("No providers are configured/available")
366
+
367
+ self.temperature = m.temperature
368
+ self.max_new_tokens = m.max_new_tokens
369
+
370
+ def chat(
371
+ self,
372
+ messages: Iterable[Message],
373
+ temperature: Optional[float] = None,
374
+ max_new_tokens: Optional[int] = None,
375
+ stream: bool = True,
376
+ ) -> Union[str, Generator[str, None, None]]:
377
+ temp = float(self.temperature if temperature is None else temperature)
378
+ mx = int(self.max_new_tokens if max_new_tokens is None else max_new_tokens)
379
+ last_err: Optional[Exception] = None
380
+
381
+ for provider in self.providers:
382
+ pname = getattr(provider, "name", provider.__class__.__name__)
383
+ t0 = time.time()
384
+ try:
385
+ result = provider.chat(messages, temp, mx, stream)
386
+ logger.info("Provider '%s' succeeded in %.2fs", pname, time.time() - t0)
387
+ return result
388
+ except Exception as e:
389
+ logger.warning("Provider '%s' failed: %s", pname, e)
390
+ last_err = e
391
+ continue
392
+
393
+ raise ProviderError(f"All providers failed. Last error: {last_err}")
394
+
395
+
396
+ __all__ = [
397
+ "ProviderError",
398
+ "GroqProvider",
399
+ "GeminiProvider",
400
+ "HfRouterProvider",
401
+ "MultiProviderChat",
402
+ ]
app/core/logging.py CHANGED
@@ -1,7 +1,57 @@
 
 
 
 
 
1
  import uuid
2
- from fastapi import Request
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def add_trace_id(request: Request) -> None:
5
- """Injects a unique trace_id into the request state."""
6
- if not hasattr(request.state, "trace_id"):
7
- request.state.trace_id = str(uuid.uuid4())
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/core/logging.py
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import os
6
  import uuid
7
+ from typing import Optional
8
+
9
+ _DEF_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
10
+ _DEF_DATEFMT = "%Y-%m-%dT%H:%M:%S%z"
11
+
12
+
13
+ def setup_logging(level: Optional[str] = None) -> None:
14
+ """
15
+ Idempotent logging setup.
16
+ - Honors LOG_LEVEL env (default INFO) unless an explicit level is passed.
17
+ - Avoids duplicate handlers if called multiple times.
18
+ - Tames noisy third-party loggers by default.
19
+ """
20
+ root = logging.getLogger()
21
+ if root.handlers:
22
+ return # already configured
23
+
24
+ log_level = (level or os.getenv("LOG_LEVEL", "INFO")).upper()
25
+ try:
26
+ parsed_level = getattr(logging, log_level)
27
+ except AttributeError:
28
+ parsed_level = logging.INFO
29
+
30
+ handler = logging.StreamHandler()
31
+ formatter = logging.Formatter(_DEF_FORMAT, datefmt=_DEF_DATEFMT)
32
+ handler.setFormatter(formatter)
33
+
34
+ root.setLevel(parsed_level)
35
+ root.addHandler(handler)
36
+
37
+ # Quiet noisy libs by default; adjust if you need more/less detail.
38
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
39
+ logging.getLogger("httpx").setLevel(logging.WARNING)
40
+ logging.getLogger("requests").setLevel(logging.WARNING)
41
+
42
 
43
+ def add_trace_id(request) -> None:
44
+ """
45
+ Injects a unique `trace_id` into request.state (works with FastAPI-style objects).
46
+ Duck-typed to avoid importing FastAPI here.
47
+ """
48
+ try:
49
+ state = getattr(request, "state", None)
50
+ if state is None:
51
+ # Some frameworks may not have .state; just skip silently.
52
+ return
53
+ if not hasattr(state, "trace_id"):
54
+ state.trace_id = str(uuid.uuid4())
55
+ except Exception:
56
+ # Never let logging helpers break the app.
57
+ return
app/services/plan_service.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from __future__ import annotations
2
 
3
  import asyncio
@@ -5,12 +6,12 @@ import hashlib
5
  import json
6
  import logging
7
  from pathlib import Path
8
- from typing import Any, Dict, Optional
9
 
10
  from ..core.schema import PlanRequest, PlanResponse
11
  from ..core.config import Settings
12
  from ..core.redact import redact
13
- from ..core.inference.client import RouterRequestsClient
14
 
15
  logger = logging.getLogger(__name__)
16
 
@@ -148,38 +149,75 @@ def _safe_parse_or_fallback(raw_output: str, context_for_id: str) -> Dict[str, A
148
 
149
 
150
  # ----------------------------
151
- # Service (requests-only, non-stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # ----------------------------
153
  class PlanService:
154
  """
155
- Planner uses HF Router (requests-only). Always non-stream for plan generation.
 
156
  """
157
 
158
  def __init__(self, settings: Settings):
159
  self.settings = settings
160
- self.client = RouterRequestsClient(
161
- model=settings.model.name,
162
- fallback=settings.model.fallback,
163
- provider=settings.model.provider,
164
- max_retries=2,
165
- connect_timeout=10.0,
166
- read_timeout=60.0,
167
- )
168
 
169
  async def generate(self, req: PlanRequest) -> PlanResponse:
170
  """
171
- Build prompt -> call Router (non-stream) -> robustly parse -> PlanResponse.
172
  Includes a one-shot JSON reformat retry if the first output isn't valid JSON.
173
  """
174
  final_prompt = _build_prompt(req)
175
 
176
  # 1) First pass: ask for the plan
177
- raw_text = await asyncio.to_thread(
178
- self.client.plan_nonstream,
179
- SYSTEM_PLANNER,
180
  final_prompt,
181
- self.settings.model.max_new_tokens,
182
- self.settings.model.temperature,
 
183
  )
184
 
185
  # 2) If not valid JSON, ask the model to strictly reformat to JSON only (no fences)
@@ -196,14 +234,12 @@ class PlanService:
196
  "Output ONLY JSON. No backticks. No extra keys.\n\nCONTENT:\n"
197
  + raw_text
198
  )
199
- re_text = await asyncio.to_thread(
200
- self.client.plan_nonstream,
201
- SYSTEM_PLANNER,
202
  reformat,
203
- self.settings.model.max_new_tokens,
204
- max(0.05, float(self.settings.model.temperature) * 0.75),
 
205
  )
206
- raw_text = re_text # replace with reformatted text
207
 
208
  # 3) Parse safely (or fallback) and validate against schema
209
  parsed = _safe_parse_or_fallback(raw_text, final_prompt)
@@ -216,7 +252,7 @@ class PlanService:
216
  async def generate_plan(req: PlanRequest, settings: Settings) -> PlanResponse:
217
  """
218
  Backward-compatible entry point:
219
- previous code called services.plan.generate_plan(...)
220
  """
221
  service = PlanService(settings)
222
  return await service.generate(req)
 
1
+ # app/services/plan_service.py
2
  from __future__ import annotations
3
 
4
  import asyncio
 
6
  import json
7
  import logging
8
  from pathlib import Path
9
+ from typing import Any, Dict, Optional, Iterable
10
 
11
  from ..core.schema import PlanRequest, PlanResponse
12
  from ..core.config import Settings
13
  from ..core.redact import redact
14
+ from ..core.inference.client import ChatClient # use the multi-provider cascade
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
149
 
150
 
151
  # ----------------------------
152
+ # Compatibility adapter for tests & legacy call sites
153
+ # ----------------------------
154
+ Message = Dict[str, str]
155
+
156
+ class HFClient:
157
+ """
158
+ Backward-compatible adapter that mirrors the old interface:
159
+ HFClient(model=...).generate(prompt: str) -> str (async)
160
+
161
+ Under the hood it uses the new multi-provider cascade (ChatClient).
162
+ The 'model' arg is accepted for compatibility but selection is driven
163
+ by Settings/provider_order; we keep it so tests can assert the call.
164
+ """
165
+ def __init__(self, model: str, settings: Optional[Settings] = None):
166
+ self._model = model # kept for compatibility / tests
167
+ self._client = ChatClient(settings)
168
+
169
+ async def generate(
170
+ self,
171
+ prompt: str,
172
+ *,
173
+ temperature: float = 0.2,
174
+ max_tokens: int = 512,
175
+ system_prompt: Optional[str] = None,
176
+ ) -> str:
177
+ messages: Iterable[Message] = (
178
+ [{"role": "system", "content": system_prompt}] if system_prompt else []
179
+ )
180
+ messages = list(messages) + [{"role": "user", "content": prompt}]
181
+
182
+ # ChatClient.chat is sync; run it in a thread so this stays async-compatible
183
+ def _call() -> str:
184
+ return self._client.chat(
185
+ messages,
186
+ temperature=temperature,
187
+ max_new_tokens=max_tokens,
188
+ stream=False,
189
+ )
190
+
191
+ return await asyncio.to_thread(_call)
192
+
193
+
194
+ # ----------------------------
195
+ # Service (uses cascade via HFClient; non-stream for plan generation)
196
  # ----------------------------
197
  class PlanService:
198
  """
199
+ Planner uses the multi-provider cascade (via HFClient adapter).
200
+ Always non-stream for plan generation to simplify parsing.
201
  """
202
 
203
  def __init__(self, settings: Settings):
204
  self.settings = settings
205
+ # IMPORTANT: use keyword arg 'model=' so tests can assert called_with(model=...)
206
+ self.llm = HFClient(model=settings.model.name, settings=settings)
 
 
 
 
 
 
207
 
208
  async def generate(self, req: PlanRequest) -> PlanResponse:
209
  """
210
+ Build prompt -> call LLM (non-stream) -> robustly parse -> PlanResponse.
211
  Includes a one-shot JSON reformat retry if the first output isn't valid JSON.
212
  """
213
  final_prompt = _build_prompt(req)
214
 
215
  # 1) First pass: ask for the plan
216
+ raw_text = await self.llm.generate(
 
 
217
  final_prompt,
218
+ temperature=float(self.settings.model.temperature),
219
+ max_tokens=int(self.settings.model.max_new_tokens),
220
+ system_prompt=SYSTEM_PLANNER,
221
  )
222
 
223
  # 2) If not valid JSON, ask the model to strictly reformat to JSON only (no fences)
 
234
  "Output ONLY JSON. No backticks. No extra keys.\n\nCONTENT:\n"
235
  + raw_text
236
  )
237
+ raw_text = await self.llm.generate(
 
 
238
  reformat,
239
+ temperature=max(0.05, float(self.settings.model.temperature) * 0.75),
240
+ max_tokens=int(self.settings.model.max_new_tokens),
241
+ system_prompt=SYSTEM_PLANNER,
242
  )
 
243
 
244
  # 3) Parse safely (or fallback) and validate against schema
245
  parsed = _safe_parse_or_fallback(raw_text, final_prompt)
 
252
  async def generate_plan(req: PlanRequest, settings: Settings) -> PlanResponse:
253
  """
254
  Backward-compatible entry point:
255
+ previous code called services.plan_service.generate_plan(...)
256
  """
257
  service = PlanService(settings)
258
  return await service.generate(req)
configs/settings.yaml CHANGED
@@ -1,13 +1,24 @@
1
  model:
 
2
  name: "HuggingFaceH4/zephyr-7b-beta"
3
  fallback: "mistralai/Mistral-7B-Instruct-v0.2"
4
- provider: "featherless-ai" # NEW: makes "model:provider" for Router
5
  max_new_tokens: 256
6
  temperature: 0.2
7
 
8
- # Chat backend + mode (requests β†’ Router only)
9
- chat_backend: "router" # reserved (future multi-backend)
10
- chat_stream: true # default streaming behavior for /v1/chat/stream
 
 
 
 
 
 
 
 
 
 
11
 
12
  limits:
13
  rate_per_min: 60
 
1
  model:
2
+ # HF router defaults (used at the last step)
3
  name: "HuggingFaceH4/zephyr-7b-beta"
4
  fallback: "mistralai/Mistral-7B-Instruct-v0.2"
5
+ provider: "featherless-ai"
6
  max_new_tokens: 256
7
  temperature: 0.2
8
 
9
+ # Provider-specific defaults (free-tier friendly)
10
+ groq_model: "llama-3.1-8b-instant"
11
+ gemini_model: "gemini-2.5-flash"
12
+
13
+ # Try providers in this order
14
+ provider_order:
15
+ - groq
16
+ - gemini
17
+ - router
18
+
19
+ # Switch to the multi-provider path
20
+ chat_backend: "multi"
21
+ chat_stream: true
22
 
23
  limits:
24
  rate_per_min: 60
pyproject.toml CHANGED
@@ -11,8 +11,9 @@ requires-python = ">=3.11"
11
  license = { text = "Apache-2.0" }
12
  dependencies = [
13
  "fastapi==0.111.0",
 
14
  "uvicorn[standard]==0.29.0",
15
- "httpx==0.27.0",
16
  "pydantic==2.7.1",
17
  "python-json-logger==2.0.7",
18
  "cachetools==5.3.3",
@@ -23,6 +24,8 @@ dependencies = [
23
  "orjson==3.10.3",
24
  "pyyaml==6.0.1",
25
  "tenacity==8.2.3",
 
 
26
  ]
27
 
28
  [tool.ruff]
 
11
  license = { text = "Apache-2.0" }
12
  dependencies = [
13
  "fastapi==0.111.0",
14
+ "groq==0.9.0",
15
  "uvicorn[standard]==0.29.0",
16
+ "httpx==0.28.1",
17
  "pydantic==2.7.1",
18
  "python-json-logger==2.0.7",
19
  "cachetools==5.3.3",
 
24
  "orjson==3.10.3",
25
  "pyyaml==6.0.1",
26
  "tenacity==8.2.3",
27
+ "python-dotenv==1.0.1",
28
+ "google-genai==1.39.1"
29
  ]
30
 
31
  [tool.ruff]
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  fastapi==0.111.0
2
  uvicorn[standard]==0.29.0
3
- httpx==0.27.0
4
  pydantic==2.7.1
5
  python-json-logger==2.0.7
6
  cachetools==5.3.3
@@ -19,6 +19,10 @@ ruff
19
  mypy
20
  pytest-asyncio
21
 
 
 
 
 
22
 
23
  requests>=2.32.0
24
  beautifulsoup4>=4.12.3 # only used if you later add generic HTML URLs
 
1
  fastapi==0.111.0
2
  uvicorn[standard]==0.29.0
3
+ httpx>=0.28.1
4
  pydantic==2.7.1
5
  python-json-logger==2.0.7
6
  cachetools==5.3.3
 
19
  mypy
20
  pytest-asyncio
21
 
22
+ # Additional libraries for extended functionality
23
+ groq==0.9.0
24
+ python-dotenv==1.0.1
25
+ google-genai==1.39.1
26
 
27
  requests>=2.32.0
28
  beautifulsoup4>=4.12.3 # only used if you later add generic HTML URLs
scripts/test_chain.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick end-to-end smoke test for the provider cascade.
3
+ Run after setting configs/.env with your keys.
4
+ """
5
+ from app.core.inference import chat
6
+
7
+ msgs = [
8
+ {"role": "system", "content": "You are concise."},
9
+ {"role": "user", "content": "Say hello in one sentence and mention which provider you are (if you can)."},
10
+ ]
11
+
12
+ print(chat(msgs, stream=False))
tests/test_multi_provider_chain.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib
3
+ import pytest
4
+
5
+ def test_settings_provider_order_env(monkeypatch):
6
+ from app.core.config import Settings
7
+ monkeypatch.setenv("PROVIDER_ORDER", "router,gemini,groq")
8
+ s = Settings.load()
9
+ assert s.provider_order == ["router", "gemini", "groq"]
10
+
11
+ def test_client_import_and_chat_function():
12
+ mod = importlib.import_module("app.core.inference.client")
13
+ assert hasattr(mod, "chat")
14
+ assert callable(mod.chat)
15
+
16
+ @pytest.mark.parametrize("order", [
17
+ "groq,gemini,router",
18
+ "gemini,router",
19
+ "router",
20
+ ])
21
+ def test_provider_initialization(monkeypatch, order):
22
+ # Provide fake keys so providers construct; we won't call the APIs here.
23
+ monkeypatch.setenv("GROQ_API_KEY", "x")
24
+ monkeypatch.setenv("GOOGLE_API_KEY", "x")
25
+ monkeypatch.setenv("HF_TOKEN", "x")
26
+ monkeypatch.setenv("PROVIDER_ORDER", order)
27
+
28
+ from app.core.config import Settings
29
+ from app.core.inference.providers import MultiProviderChat
30
+
31
+ s = Settings.load()
32
+ chain = MultiProviderChat(s)
33
+ assert len(chain.providers) >= 1
tests/test_plan_service.py CHANGED
@@ -1,5 +1,6 @@
 
1
  import pytest
2
- from unittest.mock import patch, MagicMock, AsyncMock
3
  from app.core.schema import PlanRequest, PlanContext
4
  from app.services.plan_service import generate_plan
5
  from app.core.config import Settings
@@ -8,7 +9,9 @@ from app.core.config import Settings
8
  async def test_generate_plan_successful_parse():
9
  """Tests successful plan generation and parsing."""
10
  mock_client = MagicMock()
11
- mock_client.generate = AsyncMock(return_value='{"plan_id": "123", "steps": ["step 1"], "risk": "low", "explanation": "test"}')
 
 
12
 
13
  with patch('app.services.plan_service.HFClient', return_value=mock_client) as mock_hf_client:
14
  req = PlanRequest(context=PlanContext(app_id="test-app", symptoms=["timeout"]))
@@ -17,7 +20,8 @@ async def test_generate_plan_successful_parse():
17
 
18
  assert response.plan_id == "123"
19
  assert response.steps == ["step 1"]
20
- mock_hf_client.assert_called_with(model=settings.model.name)
 
21
 
22
  @pytest.mark.asyncio
23
  async def test_generate_plan_parsing_fallback():
 
1
+ # tests/test_plan_service.py
2
  import pytest
3
+ from unittest.mock import patch, MagicMock, AsyncMock, ANY
4
  from app.core.schema import PlanRequest, PlanContext
5
  from app.services.plan_service import generate_plan
6
  from app.core.config import Settings
 
9
  async def test_generate_plan_successful_parse():
10
  """Tests successful plan generation and parsing."""
11
  mock_client = MagicMock()
12
+ mock_client.generate = AsyncMock(
13
+ return_value='{"plan_id": "123", "steps": ["step 1"], "risk": "low", "explanation": "test"}'
14
+ )
15
 
16
  with patch('app.services.plan_service.HFClient', return_value=mock_client) as mock_hf_client:
17
  req = PlanRequest(context=PlanContext(app_id="test-app", symptoms=["timeout"]))
 
20
 
21
  assert response.plan_id == "123"
22
  assert response.steps == ["step 1"]
23
+ # Accept that HFClient gets both model and settings kwargs
24
+ mock_hf_client.assert_called_with(model=settings.model.name, settings=ANY)
25
 
26
  @pytest.mark.asyncio
27
  async def test_generate_plan_parsing_fallback():
tests/test_providers.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from groq import Groq
4
+ from google import genai # Using the specific import you requested
5
+ from dotenv import load_dotenv
6
+
7
+ def test_groq_connection():
8
+ """
9
+ Loads the Groq API key from a .env file and tests the endpoint
10
+ with a simple streaming query.
11
+ """
12
+ # 1. Build a reliable path to the .env file
13
+ # This finds the script's directory, goes up to the project root,
14
+ # and then into the 'configs' folder.
15
+ script_dir = Path(__file__).parent
16
+ project_root = script_dir.parent
17
+ dotenv_path = project_root / "configs" / ".env"
18
+ load_dotenv(dotenv_path=dotenv_path)
19
+
20
+ api_key = os.getenv("GROQ_API_KEY")
21
+ if not api_key:
22
+ print(f"πŸ”΄ Error: GROQ_API_KEY not found.")
23
+ print(f"Please ensure it is set in your {dotenv_path} file.")
24
+ return
25
+
26
+ print("βœ… Groq API key loaded successfully.")
27
+
28
+ try:
29
+ # 2. Initialize the Groq client
30
+ client = Groq()
31
+ print("πŸ€– Initialized Groq client. Sending a test query...")
32
+
33
+ # 3. Create a test chat completion request
34
+ completion = client.chat.completions.create(
35
+ model="llama-3.1-8b-instant",
36
+ messages=[
37
+ {
38
+ "role": "user",
39
+ "content": "Explain why low-latency is important for LLMs in one short sentence."
40
+ }
41
+ ],
42
+ temperature=0.7,
43
+ max_tokens=1024,
44
+ top_p=1,
45
+ stream=True,
46
+ stop=None,
47
+ )
48
+
49
+ # 4. Print the streamed response from the model
50
+ print("\nπŸ“ Groq API Response:")
51
+ print("-" * 20)
52
+ for chunk in completion:
53
+ print(chunk.choices[0].delta.content or "", end="")
54
+ print("\n" + "-" * 20)
55
+ print("\nβœ… Test successful! The Groq endpoint is working.")
56
+
57
+ except Exception as e:
58
+ print(f"πŸ”΄ An error occurred during the Groq API call: {e}")
59
+
60
+ def test_gemini_connection():
61
+ """
62
+ Loads the Google Gemini API key from a .env file and tests the endpoint
63
+ using the genai.Client pattern.
64
+ """
65
+ # 1. Build a reliable path to the .env file (assuming same location)
66
+ script_dir = Path(__file__).parent
67
+ project_root = script_dir.parent
68
+ dotenv_path = project_root / "configs" / ".env"
69
+ load_dotenv(dotenv_path=dotenv_path)
70
+
71
+ api_key = os.getenv("GOOGLE_API_KEY")
72
+ if not api_key:
73
+ print(f"πŸ”΄ Error: GOOGLE_API_KEY not found.")
74
+ print(f"Please ensure it is set in your {dotenv_path} file.")
75
+ return
76
+
77
+ print("βœ… Google API key loaded successfully.")
78
+
79
+ try:
80
+ # 2. Initialize the Gemini client using the specified pattern
81
+ client = genai.Client(api_key=api_key)
82
+ print("πŸ€– Initialized Gemini client. Sending a test query...")
83
+
84
+ # 3. Send a test prompt using the client.models.generate_content method
85
+ response = client.models.generate_content(
86
+ model="gemini-2.5-flash", # Using the qualified model name
87
+ contents="Explain the importance of APIs in one short sentence."
88
+ )
89
+
90
+ # 4. Print the response
91
+ print("\nπŸ“ Gemini API Response:")
92
+ print("-" * 20)
93
+ print(response.text)
94
+ print("-" * 20)
95
+ print("\nβœ… Test successful! The Gemini endpoint is working.")
96
+
97
+ except Exception as e:
98
+ print(f"πŸ”΄ An error occurred during the Gemini API call: {e}")
99
+
100
+
101
+ # Run the test functions when the script is executed
102
+ if __name__ == "__main__":
103
+ print("--- Running Groq API Connection Test ---")
104
+ test_groq_connection()
105
+ print("\n" + "="*40 + "\n")
106
+ print("--- Running Gemini API Connection Test ---")
107
+ test_gemini_connection()
tests/utils/gemini.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from dotenv import load_dotenv
4
+ import google.generativeai as genai # 1. Import the library
5
+
6
+ # --- Robustly find and load the .env file ---
7
+ try:
8
+ # Navigate from the current script's directory up to the project root
9
+ script_dir = Path(__file__).resolve().parent
10
+ # Adjust this if your script is nested differently
11
+ project_root = script_dir.parent.parent
12
+ dotenv_path = project_root / "configs" / ".env"
13
+
14
+ if dotenv_path.exists():
15
+ load_dotenv(dotenv_path=dotenv_path)
16
+ print(f"βœ… Environment variables loaded from: {dotenv_path}")
17
+ else:
18
+ print(f"⚠️ Warning: .env file not found at {dotenv_path}.")
19
+
20
+ except Exception as e:
21
+ print(f"Could not load .env file: {e}")
22
+
23
+ # --- Get API key and list models using the client ---
24
+ api_key = os.getenv("GEMINI_API_KEY")
25
+ if api_key:
26
+ try:
27
+ # 2. Create a client instance. It automatically uses the API key
28
+ # from the environment variables.
29
+ client = genai.Client()
30
+
31
+ print("\nβœ… Available models for 'generateContent':")
32
+ # 3. Use the client object to list the models
33
+ for m in client.models.list():
34
+ if 'generateContent' in m.supported_generation_methods:
35
+ print(f"- {m.name}")
36
+
37
+ except Exception as e:
38
+ print(f"πŸ”΄ An error occurred while listing models: {e}")
39
+ print("πŸ’‘ Tip: Make sure your API key is correct and has the right permissions.")
40
+ else:
41
+ print("πŸ”΄ Error: GEMINI_API_KEY not found. Please check your .env file.")