GitHub Actions committed on
Commit
1d47e3c
·
1 Parent(s): c1411e9

Deploy 555915a

Browse files
app/api/chat.py CHANGED
@@ -29,6 +29,28 @@ def _is_criticism(message: str) -> bool:
29
  return any(sig in lowered for sig in _CRITICISM_SIGNALS)
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  async def _generate_follow_ups(
33
  query: str,
34
  answer: str,
@@ -270,17 +292,9 @@ async def chat_endpoint(
270
 
271
  elapsed_ms = int((time.monotonic() - start_time) * 1000)
272
 
273
- # Citation-index filtering single serialisation-time safety net.
274
- # Applies to all paths (RAG, Gemini fast-path, enumeration).
275
- # If the answer cites only [3][5], only sources 3 and 5 are sent;
276
- # all other chunks retrieved but not cited are discarded here.
277
- if final_answer and final_sources:
278
- cited_nums = {int(m) for m in re.findall(r"\[(\d+)\]", final_answer)}
279
- if cited_nums:
280
- final_sources = [
281
- s for i, s in enumerate(final_sources, start=1)
282
- if i in cited_nums
283
- ]
284
 
285
  sources_list = [
286
  s.model_dump() if hasattr(s, "model_dump")
 
29
  return any(sig in lowered for sig in _CRITICISM_SIGNALS)
30
 
31
 
32
+ def _filter_sources_by_citations(answer: str, sources: list) -> list:
33
+ """
34
+ Keep only sources explicitly cited in answer text.
35
+
36
+ If sources are already pre-filtered upstream (e.g. generate node returned
37
+ only cited sources from original indices), citation numbers may no longer
38
+ match local list positions. In that case, keep the original list unchanged.
39
+ """
40
+ if not answer or not sources:
41
+ return sources
42
+
43
+ cited_nums = {int(m) for m in re.findall(r"\[(\d+)\]", answer)}
44
+ if not cited_nums:
45
+ return sources
46
+
47
+ max_cited = max(cited_nums)
48
+ if max_cited > len(sources):
49
+ return sources
50
+
51
+ return [s for i, s in enumerate(sources, start=1) if i in cited_nums]
52
+
53
+
54
  async def _generate_follow_ups(
55
  query: str,
56
  answer: str,
 
292
 
293
  elapsed_ms = int((time.monotonic() - start_time) * 1000)
294
 
295
+ # Citation-index filtering safety net for paths that return full
296
+ # source lists. No-op when sources are already citation-filtered.
297
+ final_sources = _filter_sources_by_citations(final_answer, final_sources)
 
 
 
 
 
 
 
 
298
 
299
  sources_list = [
300
  s.model_dump() if hasattr(s, "model_dump")
app/core/config.py CHANGED
@@ -17,6 +17,9 @@ class Settings(BaseSettings):
17
  QDRANT_URL: str
18
  QDRANT_API_KEY: Optional[str] = None
19
  QDRANT_COLLECTION: str = "knowledge_base"
 
 
 
20
 
21
  # In-memory semantic cache
22
  # Replaces Redis. No external service required.
 
17
  QDRANT_URL: str
18
  QDRANT_API_KEY: Optional[str] = None
19
  QDRANT_COLLECTION: str = "knowledge_base"
20
+ # Keepalive ping interval to touch Qdrant regularly and avoid idle expiry.
21
+ # Default is 6 days (< 1 week) so the database is contacted at least weekly.
22
+ QDRANT_KEEPALIVE_SECONDS: int = 518400
23
 
24
  # In-memory semantic cache
25
  # Replaces Redis. No external service required.
app/main.py CHANGED
@@ -1,10 +1,13 @@
 
1
  from contextlib import asynccontextmanager
2
  import os
3
  import sqlite3
 
4
 
5
  from fastapi import FastAPI, Request
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from fastapi.responses import JSONResponse
 
8
  from slowapi.errors import RateLimitExceeded
9
 
10
  from app.api.admin import router as admin_router
@@ -28,6 +31,40 @@ from qdrant_client import QdrantClient
28
  logger = get_logger(__name__)
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def _sqlite_row_count(db_path: str) -> int:
32
  """Return the current interactions row count, or 0 if the table doesn't exist."""
33
  try:
@@ -39,6 +76,33 @@ def _sqlite_row_count(db_path: str) -> int:
39
  return 0
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @asynccontextmanager
43
  async def lifespan(app: FastAPI):
44
  settings = get_settings()
@@ -96,8 +160,9 @@ async def lifespan(app: FastAPI):
96
  from app.services.vector_store import VectorStore
97
  from app.security.guard_classifier import GuardClassifier
98
 
 
99
  qdrant = QdrantClient(
100
- url=settings.QDRANT_URL,
101
  api_key=settings.QDRANT_API_KEY,
102
  timeout=60,
103
  )
@@ -105,7 +170,26 @@ async def lifespan(app: FastAPI):
105
  vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
106
  # Idempotent: creates collection if absent so a cold-start before first
107
  # ingest run doesn't crash every search with "collection not found".
108
- vector_store.ensure_collection()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # Issue 7: shared TPM bucket tracks token consumption across the current 60s
111
  # window. Injected into GroqClient so it can downgrade 70B → 8B automatically
@@ -130,10 +214,28 @@ async def lifespan(app: FastAPI):
130
  app.state.settings = settings
131
  app.state.qdrant = qdrant
132
 
 
 
 
 
 
 
 
 
 
 
 
133
  logger.info("Startup complete")
134
  yield
135
 
136
  logger.info("Shutting down")
 
 
 
 
 
 
 
137
  app.state.semantic_cache = None
138
  app.state.qdrant.close()
139
  # Only attempt to end an MLflow run when DagsHub tracking was enabled at startup.
 
1
+ import asyncio
2
  from contextlib import asynccontextmanager
3
  import os
4
  import sqlite3
5
+ from urllib.parse import urlsplit, urlunsplit
6
 
7
  from fastapi import FastAPI, Request
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.responses import JSONResponse
10
+ from qdrant_client.http.exceptions import UnexpectedResponse
11
  from slowapi.errors import RateLimitExceeded
12
 
13
  from app.api.admin import router as admin_router
 
31
  logger = get_logger(__name__)
32
 
33
 
34
def _is_qdrant_not_found(exc: Exception) -> bool:
    """Return True when Qdrant responded with HTTP 404."""
    # A typed client error carries the status code directly.
    if isinstance(exc, UnexpectedResponse) and getattr(exc, "status_code", None) == 404:
        return True
    # Otherwise fall back to sniffing the rendered message for the
    # classic "404 page not found" text a misrouted URL produces.
    text = str(exc)
    return "404" in text and "page not found" in text.lower()
42
+
43
+
44
+ def _normalize_qdrant_url(url: str) -> str:
45
+ """
46
+ Normalize QDRANT_URL to an API base URL.
47
+
48
+ If the configured URL includes a non-root path (for example, a dashboard
49
+ URL), strip the path and keep scheme + host(+port) only.
50
+ """
51
+ raw = (url or "").strip().rstrip("/")
52
+ if not raw:
53
+ return raw
54
+
55
+ if "://" not in raw:
56
+ scheme = "http" if raw.startswith(("localhost", "127.0.0.1")) else "https"
57
+ raw = f"{scheme}://{raw}"
58
+
59
+ parsed = urlsplit(raw)
60
+ if not parsed.netloc:
61
+ return raw
62
+
63
+ if parsed.path and parsed.path != "/":
64
+ return urlunsplit((parsed.scheme, parsed.netloc, "", "", "")).rstrip("/")
65
+ return raw
66
+
67
+
68
  def _sqlite_row_count(db_path: str) -> int:
69
  """Return the current interactions row count, or 0 if the table doesn't exist."""
70
  try:
 
76
  return 0
77
 
78
 
79
async def _qdrant_keepalive_loop(
    qdrant: QdrantClient,
    interval_seconds: int,
    stop_event: asyncio.Event,
) -> None:
    """
    Periodically ping Qdrant so the deployment keeps an active connection.

    Waits ``interval_seconds`` between pings; setting ``stop_event`` ends
    the loop immediately. A non-positive interval disables the loop.
    Uses asyncio.to_thread because qdrant-client methods are synchronous.
    """
    if interval_seconds <= 0:
        return

    while not stop_event.is_set():
        try:
            await asyncio.wait_for(stop_event.wait(), timeout=interval_seconds)
            break  # stop requested while waiting
        except asyncio.TimeoutError:
            # Interval elapsed without a stop request: time to ping.
            # Catch asyncio.TimeoutError (not the builtin TimeoutError) so
            # the loop also works on Python < 3.11, where the two are
            # distinct classes; on 3.11+ they are the same alias.
            pass

        try:
            await asyncio.to_thread(qdrant.get_collections)
            logger.info("Qdrant keepalive ping succeeded")
        except Exception as exc:
            # Best-effort: a failed ping must never crash the app.
            logger.warning("Qdrant keepalive ping failed: %s", exc)
+
105
+
106
  @asynccontextmanager
107
  async def lifespan(app: FastAPI):
108
  settings = get_settings()
 
160
  from app.services.vector_store import VectorStore
161
  from app.security.guard_classifier import GuardClassifier
162
 
163
+ qdrant_url = (settings.QDRANT_URL or "").strip()
164
  qdrant = QdrantClient(
165
+ url=qdrant_url,
166
  api_key=settings.QDRANT_API_KEY,
167
  timeout=60,
168
  )
 
170
  vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
171
  # Idempotent: creates collection if absent so a cold-start before first
172
  # ingest run doesn't crash every search with "collection not found".
173
+ try:
174
+ vector_store.ensure_collection()
175
+ except UnexpectedResponse as exc:
176
+ fallback_url = _normalize_qdrant_url(qdrant_url)
177
+ if _is_qdrant_not_found(exc) and fallback_url and fallback_url != qdrant_url:
178
+ logger.warning(
179
+ "Qdrant URL returned 404, retrying with normalized root URL | original=%s normalized=%s",
180
+ qdrant_url,
181
+ fallback_url,
182
+ )
183
+ qdrant.close()
184
+ qdrant = QdrantClient(
185
+ url=fallback_url,
186
+ api_key=settings.QDRANT_API_KEY,
187
+ timeout=60,
188
+ )
189
+ vector_store = VectorStore(qdrant, settings.QDRANT_COLLECTION)
190
+ vector_store.ensure_collection()
191
+ else:
192
+ raise
193
 
194
  # Issue 7: shared TPM bucket tracks token consumption across the current 60s
195
  # window. Injected into GroqClient so it can downgrade 70B → 8B automatically
 
214
  app.state.settings = settings
215
  app.state.qdrant = qdrant
216
 
217
+ keepalive_stop = asyncio.Event()
218
+ keepalive_task = asyncio.create_task(
219
+ _qdrant_keepalive_loop(
220
+ qdrant=qdrant,
221
+ interval_seconds=settings.QDRANT_KEEPALIVE_SECONDS,
222
+ stop_event=keepalive_stop,
223
+ )
224
+ )
225
+ app.state.qdrant_keepalive_stop = keepalive_stop
226
+ app.state.qdrant_keepalive_task = keepalive_task
227
+
228
  logger.info("Startup complete")
229
  yield
230
 
231
  logger.info("Shutting down")
232
+ app.state.qdrant_keepalive_stop.set()
233
+ try:
234
+ await asyncio.wait_for(app.state.qdrant_keepalive_task, timeout=2)
235
+ except TimeoutError:
236
+ app.state.qdrant_keepalive_task.cancel()
237
+ except Exception:
238
+ pass
239
  app.state.semantic_cache = None
240
  app.state.qdrant.close()
241
  # Only attempt to end an MLflow run when DagsHub tracking was enabled at startup.
app/pipeline/nodes/generate.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  import re
3
  from typing import Callable
@@ -220,6 +221,26 @@ def _dedup_sources(source_refs: list[SourceRef], limit: int | None = None) -> li
220
  return result
221
 
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[PipelineState], dict]: # noqa: ANN001
224
  # Number of token chunks to buffer before deciding there is no CoT block.
225
  # Llama 3.1 8B may omit <think> entirely; Llama 3.3 70B always starts with one.
@@ -434,6 +455,8 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
434
  if reformatted:
435
  full_answer = reformatted
436
 
 
 
437
  # Only surface sources the LLM actually cited, deduplicated by URL so
438
  # multiple chunks from the same document show as one source card.
439
  cited_indices = {int(m) for m in re.findall(r"\[(\d+)\]", full_answer)}
 
1
+ import asyncio
2
  import logging
3
  import re
4
  from typing import Callable
 
221
  return result
222
 
223
 
224
+ def _normalise_answer_text(answer: str, max_citation_index: int) -> str:
225
+ """
226
+ Clean up model output while preserving citation semantics.
227
+
228
+ - Drops out-of-range citation markers like [99] when only 5 passages exist.
229
+ - Collapses adjacent duplicate citations ([2][2] -> [2]).
230
+ - Normalizes punctuation spacing and excess blank lines.
231
+ """
232
+
233
+ def _keep_valid_citation(match: re.Match[str]) -> str:
234
+ idx = int(match.group(1))
235
+ return f"[{idx}]" if 1 <= idx <= max_citation_index else ""
236
+
237
+ cleaned = re.sub(r"\[(\d+)\]", _keep_valid_citation, answer)
238
+ cleaned = re.sub(r"(\[\d+\])(\1)+", r"\1", cleaned)
239
+ cleaned = re.sub(r"\s+([,.;:!?])", r"\1", cleaned)
240
+ cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
241
+ return cleaned.strip()
242
+
243
+
244
  def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[PipelineState], dict]: # noqa: ANN001
245
  # Number of token chunks to buffer before deciding there is no CoT block.
246
  # Llama 3.1 8B may omit <think> entirely; Llama 3.3 70B always starts with one.
 
455
  if reformatted:
456
  full_answer = reformatted
457
 
458
+ full_answer = _normalise_answer_text(full_answer, max_citation_index=len(source_refs))
459
+
460
  # Only surface sources the LLM actually cited, deduplicated by URL so
461
  # multiple chunks from the same document show as one source card.
462
  cited_indices = {int(m) for m in re.findall(r"\[(\d+)\]", full_answer)}
tests/test_chat_source_filtering.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from app.api.chat import _filter_sources_by_citations


def test_filter_sources_by_citations_keeps_matching_positions() -> None:
    """Cited positions [1] and [3] select the first and third source."""
    sources = [{"title": "A"}, {"title": "B"}, {"title": "C"}]

    filtered = _filter_sources_by_citations("Uses A [1] and C [3].", sources)

    assert [entry["title"] for entry in filtered] == ["A", "C"]


def test_filter_sources_by_citations_skips_reindex_mismatch() -> None:
    """Pre-filtered lists with out-of-range citation numbers stay intact.

    Upstream may already return only cited sources while the answer keeps
    original citation numbers (e.g. [3][5]); stripping again would be wrong.
    """
    sources = [{"title": "Third source"}, {"title": "Fifth source"}]

    filtered = _filter_sources_by_citations("Summary from [3] and [5].", sources)

    assert filtered == sources


def test_filter_sources_by_citations_no_citations_returns_input() -> None:
    """An answer with no [n] markers leaves the source list unchanged."""
    sources = [{"title": "A"}]

    filtered = _filter_sources_by_citations(
        "No explicit references in this sentence.", sources
    )

    assert filtered == sources
tests/test_qdrant_keepalive.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio

import pytest

from app.main import _qdrant_keepalive_loop


class _FakeQdrant:
    """Stand-in client that only counts get_collections calls."""

    def __init__(self) -> None:
        self.calls = 0

    def get_collections(self) -> None:
        self.calls += 1


@pytest.mark.asyncio
async def test_keepalive_loop_pings_qdrant() -> None:
    # A fractional interval keeps the test fast; the loop only feeds the
    # value to asyncio.wait_for as a timeout, so an int is not required.
    qdrant = _FakeQdrant()
    stop_event = asyncio.Event()

    task = asyncio.create_task(
        _qdrant_keepalive_loop(
            qdrant=qdrant, interval_seconds=0.01, stop_event=stop_event
        )
    )

    # Poll with a deadline instead of one long fixed sleep: the original
    # 1.2 s sleep slowed the suite and was sensitive to scheduler jitter.
    for _ in range(200):
        if qdrant.calls:
            break
        await asyncio.sleep(0.01)

    stop_event.set()
    await asyncio.wait_for(task, timeout=1)

    assert qdrant.calls >= 1


@pytest.mark.asyncio
async def test_keepalive_loop_disabled_when_interval_non_positive() -> None:
    qdrant = _FakeQdrant()
    stop_event = asyncio.Event()

    await _qdrant_keepalive_loop(
        qdrant=qdrant, interval_seconds=0, stop_event=stop_event
    )

    assert qdrant.calls == 0
tests/test_qdrant_url_normalization.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from app.main import _normalize_qdrant_url


def test_normalize_qdrant_url_strips_dashboard_path() -> None:
    """A dashboard URL is reduced to scheme + host only."""
    assert (
        _normalize_qdrant_url("https://example.qdrant.io/dashboard")
        == "https://example.qdrant.io"
    )


def test_normalize_qdrant_url_adds_scheme_for_cloud_host() -> None:
    """A bare cloud host gains an https:// prefix."""
    assert (
        _normalize_qdrant_url("cluster-id.aws.cloud.qdrant.io")
        == "https://cluster-id.aws.cloud.qdrant.io"
    )


def test_normalize_qdrant_url_uses_http_for_localhost() -> None:
    """Local addresses default to plain http."""
    assert _normalize_qdrant_url("localhost:6333") == "http://localhost:6333"