Nomearod commited on
Commit
1bf7f2d
Β·
1 Parent(s): f56d519

refactor: address batch-2 review feedback

Browse files

Must-fix:
- Silent fallthrough when body.corpus was valid-per-Literal but not
in corpus_map (e.g. k8s removed from YAML but still in the
AskRequest Literal) is now a loud HTTPException(400) with the list
of available corpora in the detail. Mirrors the Task 2 AppConfig
validator for default_corpus but at request time.

Should-fix:
- Moved two_corpus_two_provider_app and _FakeOpenAI into
tests/conftest.py so test_corpus_routing.py, test_meta_corpus.py,
and test_prompt_template.py all consume a single source of truth
instead of cross-importing each other.
- Hoisted format_system_prompt import to the top of routes.py and
HTTPException / AppConfig alongside it.
- Fixed test_prompt_template.py ordering: imports at top, no more
bottom-of-file E402 noqa, no more `import pytest` inside a class
body.
- Removed `# noqa: F811` from test_meta_corpus.py fixture parameters
now that the fixture comes from conftest.

Security:
- AskRequest.provider is now a Literal constrained to
{openai, anthropic, selfhosted, mock}. Unknown providers fail
validation with 422 instead of silently falling back to
app.state.orchestrator.

Performance:
- format_system_prompt wrapped in @lru_cache(maxsize=32). The corpus
label set is tiny (a handful) and the function is called once per
/ask request; cache hit rate is effectively 100% post-warmup.
Includes a test that asserts identity-equality on repeated calls
so the cache can't silently regress.

Nice-to-haves:
- Direct unit tests for _resolve_orchestrator and
_resolve_system_prompt using a SimpleNamespace fake Request. Cover
happy path, provider fallback within corpus, legacy mode, and the
400-raise. Run in ~0.7s vs ~60s for HTTP-based coverage.
- Fixed `config: object` -> `config: AppConfig` in ask_stream.

Tests:
- 410 -> 421 (+11 new): 7 helper unit tests, 3 misconfiguration-400
tests, 1 lru_cache identity test. Ruff clean. Mypy clean on routes,
schemas, prompts.

agent_bench/core/prompts.py CHANGED
@@ -7,6 +7,8 @@ prevents per-corpus drift when the prompt is tuned.
7
 
8
  from __future__ import annotations
9
 
 
 
10
  SYSTEM_PROMPT_TEMPLATE = """\
11
  You are a technical documentation assistant for {corpus_label}. Answer \
12
  questions using ONLY the retrieved context from the {corpus_label} \
@@ -18,11 +20,13 @@ extrapolate, do not draw on general knowledge.\
18
  """
19
 
20
 
 
21
  def format_system_prompt(corpus_label: str) -> str:
22
  """Format the template with a corpus label.
23
 
24
- Raises at call time if the caller forgets to pass a label, which is
25
- louder than silently returning a prompt with an unresolved
 
26
  placeholder.
27
  """
28
  if not corpus_label:
 
7
 
8
  from __future__ import annotations
9
 
10
+ from functools import lru_cache
11
+
12
  SYSTEM_PROMPT_TEMPLATE = """\
13
  You are a technical documentation assistant for {corpus_label}. Answer \
14
  questions using ONLY the retrieved context from the {corpus_label} \
 
20
  """
21
 
22
 
23
+ @lru_cache(maxsize=32)
24
  def format_system_prompt(corpus_label: str) -> str:
25
  """Format the template with a corpus label.
26
 
27
+ Cached because the corpus label set is small (a handful of corpora)
28
+ and the prompt is requested once per /ask call. Raises on empty
29
+ label β€” louder than silently returning a prompt with an unresolved
30
  placeholder.
31
  """
32
  if not corpus_label:
agent_bench/serving/routes.py CHANGED
@@ -4,11 +4,13 @@ from __future__ import annotations
4
 
5
  import time
6
 
7
- from fastapi import APIRouter, Request
8
  from fastapi.responses import StreamingResponse
9
  from starlette.responses import Response
10
 
11
  from agent_bench.agents.orchestrator import Orchestrator
 
 
12
  from agent_bench.serving.middleware import MetricsCollector
13
  from agent_bench.serving.schemas import (
14
  AskRequest,
@@ -33,14 +35,34 @@ def _resolve_orchestrator(
33
  Legacy single-corpus mode: use the flat orchestrators dict keyed by
34
  provider name, then fall back to app.state.orchestrator.
35
 
 
 
 
 
 
 
 
36
  Returns the resolved orchestrator and the corpus name used (empty
37
  string in legacy mode when no default_corpus is configured).
38
  """
39
- config = request.app.state.config
40
  corpus_map: dict = getattr(request.app.state, "corpus_map", {})
41
  default_corpus: str = getattr(config, "default_corpus", "") or ""
42
  provider_default: str = config.provider.default
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  corpus_name: str = body.corpus or default_corpus
45
 
46
  if corpus_map and corpus_name in corpus_map:
@@ -71,11 +93,9 @@ def _resolve_system_prompt(
71
  from the task config (app.state.system_prompt) is returned unchanged
72
  and corpus_label is empty.
73
  """
74
- config = request.app.state.config
75
  corpora = getattr(config, "corpora", None) or {}
76
  if corpus_name and corpus_name in corpora:
77
- from agent_bench.core.prompts import format_system_prompt
78
-
79
  label = corpora[corpus_name].label
80
  return format_system_prompt(label), label
81
  return request.app.state.system_prompt, ""
@@ -210,7 +230,7 @@ async def ask_stream(body: AskRequest, request: Request) -> StreamingResponse:
210
  system_prompt, corpus_label = _resolve_system_prompt(request, corpus_name)
211
  metrics: MetricsCollector = request.app.state.metrics
212
  request_id: str = getattr(request.state, "request_id", "unknown")
213
- config: object = request.app.state.config
214
 
215
  # --- Meta event data (available before request starts) ---
216
  provider_name = getattr(config, "provider", None)
 
4
 
5
  import time
6
 
7
+ from fastapi import APIRouter, HTTPException, Request
8
  from fastapi.responses import StreamingResponse
9
  from starlette.responses import Response
10
 
11
  from agent_bench.agents.orchestrator import Orchestrator
12
+ from agent_bench.core.config import AppConfig
13
+ from agent_bench.core.prompts import format_system_prompt
14
  from agent_bench.serving.middleware import MetricsCollector
15
  from agent_bench.serving.schemas import (
16
  AskRequest,
 
35
  Legacy single-corpus mode: use the flat orchestrators dict keyed by
36
  provider name, then fall back to app.state.orchestrator.
37
 
38
+ Raises:
39
+ HTTPException(400): body.corpus is explicitly set in multi-corpus
40
+ mode but the named corpus is not wired on this server. Pydantic
41
+ Literal on AskRequest.corpus catches unknown names (422); this
42
+ catches "known per schema but not deployed" at 400. Mirrors
43
+ the AppConfig validator for default_corpus at request time.
44
+
45
  Returns the resolved orchestrator and the corpus name used (empty
46
  string in legacy mode when no default_corpus is configured).
47
  """
48
+ config: AppConfig = request.app.state.config
49
  corpus_map: dict = getattr(request.app.state, "corpus_map", {})
50
  default_corpus: str = getattr(config, "default_corpus", "") or ""
51
  provider_default: str = config.provider.default
52
 
53
+ # Fail loud if the request names a corpus that is not wired. Only
54
+ # fires when body.corpus is explicit β€” a None corpus always falls
55
+ # through to default_corpus (which the AppConfig validator guarantees
56
+ # is in corpus_map when corpora is non-empty).
57
+ if corpus_map and body.corpus is not None and body.corpus not in corpus_map:
58
+ raise HTTPException(
59
+ status_code=400,
60
+ detail=(
61
+ f"Corpus {body.corpus!r} is not configured on this server. "
62
+ f"Available corpora: {sorted(corpus_map.keys())}"
63
+ ),
64
+ )
65
+
66
  corpus_name: str = body.corpus or default_corpus
67
 
68
  if corpus_map and corpus_name in corpus_map:
 
93
  from the task config (app.state.system_prompt) is returned unchanged
94
  and corpus_label is empty.
95
  """
96
+ config: AppConfig = request.app.state.config
97
  corpora = getattr(config, "corpora", None) or {}
98
  if corpus_name and corpus_name in corpora:
 
 
99
  label = corpora[corpus_name].label
100
  return format_system_prompt(label), label
101
  return request.app.state.system_prompt, ""
 
230
  system_prompt, corpus_label = _resolve_system_prompt(request, corpus_name)
231
  metrics: MetricsCollector = request.app.state.metrics
232
  request_id: str = getattr(request.state, "request_id", "unknown")
233
+ config: AppConfig = request.app.state.config
234
 
235
  # --- Meta event data (available before request starts) ---
236
  provider_name = getattr(config, "provider", None)
agent_bench/serving/schemas.py CHANGED
@@ -15,9 +15,14 @@ class AskRequest(BaseModel):
15
  top_k: int = 5
16
  retrieval_strategy: Literal["semantic", "keyword", "hybrid"] = "hybrid"
17
  session_id: str | None = None # None = stateless (V1 behavior)
18
- provider: str | None = None # None = use server default
 
 
 
19
  # Per-request corpus selection. None = use default_corpus from config.
20
- # Unknown values are rejected at validation time with HTTP 422.
 
 
21
  corpus: Literal["fastapi", "k8s"] | None = None
22
 
23
 
 
15
  top_k: int = 5
16
  retrieval_strategy: Literal["semantic", "keyword", "hybrid"] = "hybrid"
17
  session_id: str | None = None # None = stateless (V1 behavior)
18
+ # Per-request provider override. Constrained to the set of known
19
+ # provider names so unknown values are rejected at validation time
20
+ # with HTTP 422 instead of silently falling back.
21
+ provider: Literal["openai", "anthropic", "selfhosted", "mock"] | None = None
22
  # Per-request corpus selection. None = use default_corpus from config.
23
+ # Unknown values rejected at validation time with HTTP 422. Names that
24
+ # pass validation but are not wired on the current server produce a
25
+ # 400 in the route handler (see _resolve_orchestrator).
26
  corpus: Literal["fastapi", "k8s"] | None = None
27
 
28
 
tests/conftest.py CHANGED
@@ -110,3 +110,70 @@ def test_store(mock_embedder: Embedder, sample_chunks: list[Chunk]) -> HybridSto
110
  def test_retriever(mock_embedder: Embedder, test_store: HybridStore) -> Retriever:
111
  """Retriever wired to mock embedder + test store."""
112
  return Retriever(embedder=mock_embedder, store=test_store)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def test_retriever(mock_embedder: Embedder, test_store: HybridStore) -> Retriever:
111
  """Retriever wired to mock embedder + test store."""
112
  return Retriever(embedder=mock_embedder, store=test_store)
113
+
114
+
115
+ # --- Multi-corpus test app (shared across routing / meta / prompt tests) ---
116
+
117
+
118
+ class _FakeOpenAI(MockProvider):
119
+ """Distinct MockProvider subclass so tests can tell it apart from
120
+ the default mock when asserting which orchestrator actually ran."""
121
+
122
+
123
+ @pytest.fixture
124
+ def two_corpus_two_provider_app(tmp_path, monkeypatch):
125
+ """Two corpora (fastapi, k8s) Γ— two providers (mock, openai-faked).
126
+
127
+ After building the app, each corpus Γ— provider cell gets a *unique*
128
+ MockProvider instance tagged with `_tag`. create_app deliberately
129
+ shares one provider instance across corpora in production (providers
130
+ hold LLM clients and are expensive), but the test needs to distinguish
131
+ which cell ran a given request β€” so the fixture breaks the sharing
132
+ here and only here.
133
+ """
134
+ from agent_bench.core import provider as provider_mod
135
+ from agent_bench.core.config import (
136
+ AppConfig,
137
+ CorpusConfig,
138
+ EmbeddingConfig,
139
+ ProviderConfig,
140
+ RAGConfig,
141
+ SecurityConfig,
142
+ )
143
+ from agent_bench.serving.app import create_app
144
+
145
+ monkeypatch.setattr(provider_mod, "OpenAIProvider", lambda _cfg: _FakeOpenAI())
146
+ monkeypatch.setenv("OPENAI_API_KEY", "test-key")
147
+
148
+ config = AppConfig(
149
+ provider=ProviderConfig(default="mock"),
150
+ rag=RAGConfig(store_path=str(tmp_path / "store_default")),
151
+ embedding=EmbeddingConfig(cache_dir=str(tmp_path / "emb_cache")),
152
+ security=SecurityConfig(),
153
+ corpora={
154
+ "fastapi": CorpusConfig(
155
+ label="FastAPI Docs",
156
+ store_path=str(tmp_path / "store_fastapi"),
157
+ data_path="data/tech_docs",
158
+ ),
159
+ "k8s": CorpusConfig(
160
+ label="Kubernetes",
161
+ store_path=str(tmp_path / "store_k8s"),
162
+ data_path="data/k8s_docs",
163
+ ),
164
+ },
165
+ default_corpus="fastapi",
166
+ )
167
+ app = create_app(config)
168
+
169
+ # Stamp a unique provider into each cell so call_count is per-cell.
170
+ for c_name, inner in app.state.corpus_map.items():
171
+ for p_name, orch in inner.items():
172
+ unique = MockProvider()
173
+ unique._tag = f"{c_name}:{p_name}" # type: ignore[attr-defined]
174
+ orch.provider = unique
175
+ # Keep the flat orchestrators dict and the singular orchestrator in
176
+ # sync with the per-cell instances for the default corpus.
177
+ app.state.orchestrators = dict(app.state.corpus_map[config.default_corpus])
178
+ app.state.orchestrator = app.state.orchestrators[config.provider.default]
179
+ return app
tests/test_corpus_routing.py CHANGED
@@ -1,8 +1,9 @@
1
  """Tests for per-request corpus routing.
2
 
3
  Exercises the full corpus Γ— provider matrix through /ask and /ask/stream.
4
- Uses create_app with a multi-corpus fixture and monkeypatches a second
5
- provider so both toggles can be tested without real API keys.
 
6
  """
7
 
8
  from __future__ import annotations
@@ -18,64 +19,9 @@ from agent_bench.core.config import (
18
  RAGConfig,
19
  SecurityConfig,
20
  )
21
- from agent_bench.core.provider import MockProvider
22
  from agent_bench.serving.app import create_app
23
 
24
 
25
- class _FakeOpenAI(MockProvider):
26
- """Distinct MockProvider subclass so we can distinguish it from the
27
- default mock when asserting which orchestrator actually ran."""
28
-
29
-
30
- @pytest.fixture
31
- def two_corpus_two_provider_app(tmp_path, monkeypatch):
32
- """Two corpora (fastapi, k8s) Γ— two providers (mock, openai-faked).
33
-
34
- After building the app, each corpusΓ—provider cell gets a *unique*
35
- MockProvider instance stamped with a `_tag` attribute. create_app
36
- deliberately shares one provider instance across corpora (it's an
37
- expensive object), but the test needs to distinguish which cell ran
38
- a given request, so we break the sharing here and only here.
39
- """
40
- from agent_bench.core import provider as provider_mod
41
-
42
- monkeypatch.setattr(provider_mod, "OpenAIProvider", lambda _cfg: _FakeOpenAI())
43
- monkeypatch.setenv("OPENAI_API_KEY", "test-key")
44
-
45
- config = AppConfig(
46
- provider=ProviderConfig(default="mock"),
47
- rag=RAGConfig(store_path=str(tmp_path / "store_default")),
48
- embedding=EmbeddingConfig(cache_dir=str(tmp_path / "emb_cache")),
49
- security=SecurityConfig(),
50
- corpora={
51
- "fastapi": CorpusConfig(
52
- label="FastAPI Docs",
53
- store_path=str(tmp_path / "store_fastapi"),
54
- data_path="data/tech_docs",
55
- ),
56
- "k8s": CorpusConfig(
57
- label="Kubernetes",
58
- store_path=str(tmp_path / "store_k8s"),
59
- data_path="data/k8s_docs",
60
- ),
61
- },
62
- default_corpus="fastapi",
63
- )
64
- app = create_app(config)
65
-
66
- # Stamp a unique provider into each cell so call_count is per-cell.
67
- for c_name, inner in app.state.corpus_map.items():
68
- for p_name, orch in inner.items():
69
- unique = MockProvider()
70
- unique._tag = f"{c_name}:{p_name}" # type: ignore[attr-defined]
71
- orch.provider = unique
72
- # Keep the flat orchestrators dict and the singular orchestrator in
73
- # sync with the per-cell instances for the default corpus.
74
- app.state.orchestrators = dict(app.state.corpus_map[config.default_corpus])
75
- app.state.orchestrator = app.state.orchestrators[config.provider.default]
76
- return app
77
-
78
-
79
  def _reset_call_counts(app):
80
  """Zero out provider.call_count on every orchestrator in corpus_map."""
81
  for inner in app.state.corpus_map.values():
@@ -196,3 +142,227 @@ class TestLegacyRouting:
196
  "/ask", json={"question": "hi", "corpus": "fastapi"},
197
  )
198
  assert resp.status_code == 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Tests for per-request corpus routing.
2
 
3
  Exercises the full corpus Γ— provider matrix through /ask and /ask/stream.
4
+ The multi-corpus test-app fixture (`two_corpus_two_provider_app`) lives
5
+ in tests/conftest.py and is shared with test_meta_corpus.py and
6
+ test_prompt_template.py.
7
  """
8
 
9
  from __future__ import annotations
 
19
  RAGConfig,
20
  SecurityConfig,
21
  )
 
22
  from agent_bench.serving.app import create_app
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def _reset_call_counts(app):
26
  """Zero out provider.call_count on every orchestrator in corpus_map."""
27
  for inner in app.state.corpus_map.values():
 
142
  "/ask", json={"question": "hi", "corpus": "fastapi"},
143
  )
144
  assert resp.status_code == 200
145
+
146
+
147
+ class TestMisconfiguredCorpus:
148
+ """body.corpus valid-per-Literal but not in corpus_map should fail
149
+ loud at request time with 400 instead of silently falling through
150
+ to the legacy orchestrator."""
151
+
152
+ @pytest.fixture
153
+ def fastapi_only_app(self, tmp_path, monkeypatch):
154
+ """Multi-corpus mode with ONLY fastapi configured (k8s removed)."""
155
+ from agent_bench.core import provider as provider_mod
156
+ from agent_bench.core.provider import MockProvider
157
+
158
+ monkeypatch.setattr(
159
+ provider_mod, "OpenAIProvider", lambda _cfg: MockProvider(),
160
+ )
161
+ config = AppConfig(
162
+ provider=ProviderConfig(default="mock"),
163
+ rag=RAGConfig(store_path=str(tmp_path / "store_default")),
164
+ embedding=EmbeddingConfig(cache_dir=str(tmp_path / "emb_cache")),
165
+ security=SecurityConfig(),
166
+ corpora={
167
+ "fastapi": CorpusConfig(
168
+ label="FastAPI Docs",
169
+ store_path=str(tmp_path / "store_fastapi"),
170
+ data_path="data/tech_docs",
171
+ ),
172
+ },
173
+ default_corpus="fastapi",
174
+ )
175
+ return create_app(config)
176
+
177
+ @pytest.mark.asyncio
178
+ async def test_unconfigured_corpus_returns_400(self, fastapi_only_app):
179
+ """k8s passes Literal but is not in corpus_map β€” expect 400."""
180
+ async with AsyncClient(
181
+ transport=ASGITransport(app=fastapi_only_app), base_url="http://test",
182
+ ) as client:
183
+ resp = await client.post(
184
+ "/ask", json={"question": "hi", "corpus": "k8s"},
185
+ )
186
+ assert resp.status_code == 400
187
+ detail = resp.json().get("detail", "")
188
+ assert "not configured" in detail.lower()
189
+ assert "k8s" in detail
190
+ assert "fastapi" in detail # lists available corpora
191
+
192
+ @pytest.mark.asyncio
193
+ async def test_unconfigured_corpus_returns_400_stream(self, fastapi_only_app):
194
+ """Same guard on /ask/stream."""
195
+ async with AsyncClient(
196
+ transport=ASGITransport(app=fastapi_only_app), base_url="http://test",
197
+ ) as client:
198
+ resp = await client.post(
199
+ "/ask/stream", json={"question": "hi", "corpus": "k8s"},
200
+ )
201
+ assert resp.status_code == 400
202
+
203
+ @pytest.mark.asyncio
204
+ async def test_default_corpus_still_works(self, fastapi_only_app):
205
+ """fastapi (the only configured corpus) still routes fine."""
206
+ async with AsyncClient(
207
+ transport=ASGITransport(app=fastapi_only_app), base_url="http://test",
208
+ ) as client:
209
+ resp = await client.post(
210
+ "/ask", json={"question": "hi", "corpus": "fastapi"},
211
+ )
212
+ assert resp.status_code == 200
213
+
214
+
215
+ class TestResolveOrchestratorDirect:
216
+ """Unit tests for _resolve_orchestrator without the HTTP stack.
217
+
218
+ Builds a fake Request object with just the app.state attributes the
219
+ helper reads. Catches edge cases that integration tests would miss
220
+ (unknown provider, explicit provider not in inner dict, etc.).
221
+ """
222
+
223
+ @pytest.fixture
224
+ def fake_request_builder(self):
225
+ """Return a factory that makes a fake Request with the given state."""
226
+ from types import SimpleNamespace
227
+
228
+ def build(
229
+ corpus_map,
230
+ default_corpus,
231
+ provider_default,
232
+ orchestrators=None,
233
+ orchestrator=None,
234
+ system_prompt="legacy prompt",
235
+ ):
236
+ state = SimpleNamespace(
237
+ config=SimpleNamespace(
238
+ corpora={k: SimpleNamespace(label=k.title()) for k in corpus_map},
239
+ default_corpus=default_corpus,
240
+ provider=SimpleNamespace(default=provider_default),
241
+ ),
242
+ corpus_map=corpus_map,
243
+ orchestrators=orchestrators or {},
244
+ orchestrator=orchestrator,
245
+ system_prompt=system_prompt,
246
+ )
247
+ return SimpleNamespace(app=SimpleNamespace(state=state))
248
+
249
+ return build
250
+
251
+ def _make_body(self, corpus=None, provider=None):
252
+ from agent_bench.serving.schemas import AskRequest
253
+
254
+ return AskRequest(question="x", corpus=corpus, provider=provider)
255
+
256
+ def test_multi_corpus_happy_path(self, fake_request_builder):
257
+ from agent_bench.serving.routes import _resolve_orchestrator
258
+
259
+ sentinel = object()
260
+ req = fake_request_builder(
261
+ corpus_map={"fastapi": {"mock": sentinel}},
262
+ default_corpus="fastapi",
263
+ provider_default="mock",
264
+ )
265
+ orch, name = _resolve_orchestrator(req, self._make_body())
266
+ assert orch is sentinel
267
+ assert name == "fastapi"
268
+
269
+ def test_provider_fallback_to_corpus_default(self, fake_request_builder):
270
+ """body.provider=None uses corpus default; Literal accepts None."""
271
+ from agent_bench.serving.routes import _resolve_orchestrator
272
+
273
+ mock_sent = object()
274
+ oai_sent = object()
275
+ req = fake_request_builder(
276
+ corpus_map={"fastapi": {"mock": mock_sent, "openai": oai_sent}},
277
+ default_corpus="fastapi",
278
+ provider_default="mock",
279
+ )
280
+ orch, _ = _resolve_orchestrator(req, self._make_body(provider="openai"))
281
+ assert orch is oai_sent
282
+ orch, _ = _resolve_orchestrator(req, self._make_body())
283
+ assert orch is mock_sent
284
+
285
+ def test_explicit_unconfigured_corpus_raises_400(self, fake_request_builder):
286
+ from fastapi import HTTPException
287
+
288
+ from agent_bench.serving.routes import _resolve_orchestrator
289
+
290
+ req = fake_request_builder(
291
+ corpus_map={"fastapi": {"mock": object()}},
292
+ default_corpus="fastapi",
293
+ provider_default="mock",
294
+ )
295
+ with pytest.raises(HTTPException) as exc_info:
296
+ _resolve_orchestrator(req, self._make_body(corpus="k8s"))
297
+ assert exc_info.value.status_code == 400
298
+ assert "k8s" in exc_info.value.detail
299
+ assert "fastapi" in exc_info.value.detail
300
+
301
+ def test_legacy_mode_uses_flat_orchestrators(self, fake_request_builder):
302
+ from agent_bench.serving.routes import _resolve_orchestrator
303
+
304
+ legacy_orch = object()
305
+ flat_oai = object()
306
+ req = fake_request_builder(
307
+ corpus_map={},
308
+ default_corpus="",
309
+ provider_default="mock",
310
+ orchestrators={"openai": flat_oai},
311
+ orchestrator=legacy_orch,
312
+ )
313
+ # body.provider=openai finds it in flat dict
314
+ orch, _ = _resolve_orchestrator(req, self._make_body(provider="openai"))
315
+ assert orch is flat_oai
316
+ # No provider falls back to app.state.orchestrator
317
+ orch, _ = _resolve_orchestrator(req, self._make_body())
318
+ assert orch is legacy_orch
319
+
320
+
321
+ class TestResolveSystemPromptDirect:
322
+ """Unit tests for _resolve_system_prompt."""
323
+
324
+ def _build_req(self, corpora, system_prompt="legacy"):
325
+ from types import SimpleNamespace
326
+
327
+ state = SimpleNamespace(
328
+ config=SimpleNamespace(corpora=corpora),
329
+ system_prompt=system_prompt,
330
+ )
331
+ return SimpleNamespace(app=SimpleNamespace(state=state))
332
+
333
+ def test_multi_corpus_formats_template(self):
334
+ from types import SimpleNamespace
335
+
336
+ from agent_bench.serving.routes import _resolve_system_prompt
337
+
338
+ req = self._build_req(
339
+ {"fastapi": SimpleNamespace(label="FastAPI Docs")},
340
+ )
341
+ prompt, label = _resolve_system_prompt(req, "fastapi")
342
+ assert label == "FastAPI Docs"
343
+ assert "FastAPI Docs" in prompt
344
+ assert "{corpus_label}" not in prompt
345
+ assert "refuse" in prompt.lower()
346
+
347
+ def test_legacy_returns_task_prompt(self):
348
+ from agent_bench.serving.routes import _resolve_system_prompt
349
+
350
+ req = self._build_req({}, system_prompt="legacy task prompt")
351
+ prompt, label = _resolve_system_prompt(req, "")
352
+ assert prompt == "legacy task prompt"
353
+ assert label == ""
354
+
355
+ def test_unknown_corpus_name_falls_to_legacy(self):
356
+ """If corpus_name isn't in corpora (shouldn't happen post-resolve
357
+ because of the 400 guard, but the helper should still be safe)."""
358
+ from types import SimpleNamespace
359
+
360
+ from agent_bench.serving.routes import _resolve_system_prompt
361
+
362
+ req = self._build_req(
363
+ {"fastapi": SimpleNamespace(label="FastAPI Docs")},
364
+ system_prompt="legacy",
365
+ )
366
+ prompt, label = _resolve_system_prompt(req, "nonexistent")
367
+ assert prompt == "legacy"
368
+ assert label == ""
tests/test_meta_corpus.py CHANGED
@@ -1,4 +1,7 @@
1
- """Tests for corpus + corpus_label fields in the SSE meta event."""
 
 
 
2
 
3
  from __future__ import annotations
4
 
@@ -7,8 +10,6 @@ import json as json_mod
7
  import pytest
8
  from httpx import ASGITransport, AsyncClient
9
 
10
- from tests.test_corpus_routing import two_corpus_two_provider_app # noqa: F401
11
-
12
 
13
  def _parse_sse(text: str) -> list[dict]:
14
  events = []
@@ -21,7 +22,7 @@ def _parse_sse(text: str) -> list[dict]:
21
  class TestMetaCorpus:
22
  @pytest.mark.asyncio
23
  async def test_meta_includes_corpus_and_label_default(
24
- self, two_corpus_two_provider_app, # noqa: F811
25
  ):
26
  app = two_corpus_two_provider_app
27
  async with AsyncClient(
@@ -35,7 +36,7 @@ class TestMetaCorpus:
35
 
36
  @pytest.mark.asyncio
37
  async def test_meta_reflects_explicit_corpus(
38
- self, two_corpus_two_provider_app, # noqa: F811
39
  ):
40
  app = two_corpus_two_provider_app
41
  async with AsyncClient(
@@ -51,7 +52,7 @@ class TestMetaCorpus:
51
 
52
  @pytest.mark.asyncio
53
  async def test_meta_provider_label_composes_with_corpus(
54
- self, two_corpus_two_provider_app, # noqa: F811
55
  ):
56
  """Provider field in meta still reflects the config default."""
57
  app = two_corpus_two_provider_app
 
1
+ """Tests for corpus + corpus_label fields in the SSE meta event.
2
+
3
+ The multi-corpus fixture is auto-loaded from tests/conftest.py.
4
+ """
5
 
6
  from __future__ import annotations
7
 
 
10
  import pytest
11
  from httpx import ASGITransport, AsyncClient
12
 
 
 
13
 
14
  def _parse_sse(text: str) -> list[dict]:
15
  events = []
 
22
  class TestMetaCorpus:
23
  @pytest.mark.asyncio
24
  async def test_meta_includes_corpus_and_label_default(
25
+ self, two_corpus_two_provider_app,
26
  ):
27
  app = two_corpus_two_provider_app
28
  async with AsyncClient(
 
36
 
37
  @pytest.mark.asyncio
38
  async def test_meta_reflects_explicit_corpus(
39
+ self, two_corpus_two_provider_app,
40
  ):
41
  app = two_corpus_two_provider_app
42
  async with AsyncClient(
 
52
 
53
  @pytest.mark.asyncio
54
  async def test_meta_provider_label_composes_with_corpus(
55
+ self, two_corpus_two_provider_app,
56
  ):
57
  """Provider field in meta still reflects the config default."""
58
  app = two_corpus_two_provider_app
tests/test_prompt_template.py CHANGED
@@ -1,7 +1,14 @@
1
- """Tests for the parameterized system prompt template."""
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
 
 
 
5
  from agent_bench.core.prompts import SYSTEM_PROMPT_TEMPLATE, format_system_prompt
6
 
7
 
@@ -47,25 +54,26 @@ def test_format_requires_citations():
47
  def test_format_rejects_empty_label():
48
  """Empty label is a caller bug β€” fail loud instead of producing a
49
  prompt with an unresolved placeholder."""
50
- import pytest as _pytest
51
-
52
- with _pytest.raises(ValueError, match="corpus_label"):
53
  format_system_prompt("")
54
 
55
 
 
 
 
 
 
 
 
56
  class TestRouteHandlerUsesFormattedPrompt:
57
  """In multi-corpus mode the orchestrator must receive a prompt
58
  formatted with the active corpus's label β€” not the legacy
59
  app.state.system_prompt."""
60
 
61
- import pytest
62
-
63
  @pytest.mark.asyncio
64
  async def test_stream_passes_k8s_prompt_to_orchestrator(
65
- self, two_corpus_two_provider_app, # noqa: F811
66
  ):
67
- from httpx import ASGITransport, AsyncClient
68
-
69
  app = two_corpus_two_provider_app
70
  # Record every system_prompt the orchestrator sees.
71
  captured: list[str] = []
@@ -96,10 +104,8 @@ class TestRouteHandlerUsesFormattedPrompt:
96
 
97
  @pytest.mark.asyncio
98
  async def test_fastapi_and_k8s_prompts_differ(
99
- self, two_corpus_two_provider_app, # noqa: F811
100
  ):
101
- from httpx import ASGITransport, AsyncClient
102
-
103
  app = two_corpus_two_provider_app
104
  captured: dict[str, str] = {}
105
 
@@ -130,7 +136,3 @@ class TestRouteHandlerUsesFormattedPrompt:
130
  assert "FastAPI Docs" in captured["fastapi"]
131
  assert "Kubernetes" in captured["k8s"]
132
  assert captured["fastapi"] != captured["k8s"]
133
-
134
-
135
- # Re-export the multi-corpus fixture so the class above can use it
136
- from tests.test_corpus_routing import two_corpus_two_provider_app # noqa: F401, E402
 
1
+ """Tests for the parameterized system prompt template.
2
+
3
+ The integration tests rely on `two_corpus_two_provider_app` from
4
+ tests/conftest.py.
5
+ """
6
 
7
  from __future__ import annotations
8
 
9
+ import pytest
10
+ from httpx import ASGITransport, AsyncClient
11
+
12
  from agent_bench.core.prompts import SYSTEM_PROMPT_TEMPLATE, format_system_prompt
13
 
14
 
 
54
  def test_format_rejects_empty_label():
55
  """Empty label is a caller bug β€” fail loud instead of producing a
56
  prompt with an unresolved placeholder."""
57
+ with pytest.raises(ValueError, match="corpus_label"):
 
 
58
  format_system_prompt("")
59
 
60
 
61
+ def test_format_is_cached():
62
+ """@lru_cache on format_system_prompt β€” same input returns same object."""
63
+ a = format_system_prompt("FastAPI Docs")
64
+ b = format_system_prompt("FastAPI Docs")
65
+ assert a is b # cached: same object identity, not just equal
66
+
67
+
68
  class TestRouteHandlerUsesFormattedPrompt:
69
  """In multi-corpus mode the orchestrator must receive a prompt
70
  formatted with the active corpus's label β€” not the legacy
71
  app.state.system_prompt."""
72
 
 
 
73
  @pytest.mark.asyncio
74
  async def test_stream_passes_k8s_prompt_to_orchestrator(
75
+ self, two_corpus_two_provider_app,
76
  ):
 
 
77
  app = two_corpus_two_provider_app
78
  # Record every system_prompt the orchestrator sees.
79
  captured: list[str] = []
 
104
 
105
  @pytest.mark.asyncio
106
  async def test_fastapi_and_k8s_prompts_differ(
107
+ self, two_corpus_two_provider_app,
108
  ):
 
 
109
  app = two_corpus_two_provider_app
110
  captured: dict[str, str] = {}
111
 
 
136
  assert "FastAPI Docs" in captured["fastapi"]
137
  assert "Kubernetes" in captured["k8s"]
138
  assert captured["fastapi"] != captured["k8s"]