Spaces:
Running
Running
GitHub Actions commited on
Commit ·
c44df3b
1
Parent(s): b616cc1
Deploy 2e8cff3
Browse files- app/api/tts.py +4 -1
- app/core/portfolio_context.py +3 -0
- app/models/speech.py +1 -0
- app/pipeline/graph.py +6 -2
- app/pipeline/nodes/retrieve.py +70 -3
- app/services/reranker.py +30 -23
- app/services/tts_client.py +2 -2
- tests/test_enumerate_query.py +3 -0
- tests/test_graph_routing.py +30 -0
- tests/test_models.py +11 -0
- tests/test_retrieve_query_normalization.py +9 -1
- tests/test_speech_endpoints.py +27 -1
app/api/tts.py
CHANGED
|
@@ -22,5 +22,8 @@ async def synthesize_endpoint(
|
|
| 22 |
detail="TTS service is not configured.",
|
| 23 |
)
|
| 24 |
|
| 25 |
-
audio_bytes = await tts_client.synthesize(
|
|
|
|
|
|
|
|
|
|
| 26 |
return Response(content=audio_bytes, media_type="audio/wav")
|
|
|
|
| 22 |
detail="TTS service is not configured.",
|
| 23 |
)
|
| 24 |
|
| 25 |
+
audio_bytes = await tts_client.synthesize(
|
| 26 |
+
payload.text.strip(),
|
| 27 |
+
voice=payload.voice.strip().lower(),
|
| 28 |
+
)
|
| 29 |
return Response(content=audio_bytes, media_type="audio/wav")
|
app/core/portfolio_context.py
CHANGED
|
@@ -87,6 +87,9 @@ KNOWN_INTENTS: frozenset[str] = frozenset({
|
|
| 87 |
"work", "experience", "work experience", "career", "employment", "job", "role",
|
| 88 |
"internship", "internships", "skills", "skill", "education", "degree", "university",
|
| 89 |
"resume", "cv", "background", "certification", "certifications",
|
|
|
|
|
|
|
|
|
|
| 90 |
})
|
| 91 |
|
| 92 |
# ---------------------------------------------------------------------------
|
|
|
|
| 87 |
"work", "experience", "work experience", "career", "employment", "job", "role",
|
| 88 |
"internship", "internships", "skills", "skill", "education", "degree", "university",
|
| 89 |
"resume", "cv", "background", "certification", "certifications",
|
| 90 |
+
"tech", "stack", "tech stack", "technology", "technologies",
|
| 91 |
+
"framework", "frameworks", "tool", "tools", "tooling",
|
| 92 |
+
"language", "languages",
|
| 93 |
})
|
| 94 |
|
| 95 |
# ---------------------------------------------------------------------------
|
app/models/speech.py
CHANGED
|
@@ -7,3 +7,4 @@ class TranscribeResponse(BaseModel):
|
|
| 7 |
|
| 8 |
class SynthesizeRequest(BaseModel):
|
| 9 |
text: str = Field(..., min_length=1, max_length=300)
|
|
|
|
|
|
| 7 |
|
| 8 |
class SynthesizeRequest(BaseModel):
|
| 9 |
text: str = Field(..., min_length=1, max_length=300)
|
| 10 |
+
voice: str = Field(default="am_adam", min_length=2, max_length=32)
|
app/pipeline/graph.py
CHANGED
|
@@ -97,8 +97,12 @@ def route_retrieve_result(state: PipelineState) -> str:
|
|
| 97 |
# also failed (still empty after the first CRAG rewrite). When the query
|
| 98 |
# mentions a known portfolio entity, attempt one more vocabulary-shifted rewrite
|
| 99 |
# before admitting the not-found path.
|
| 100 |
-
if attempts == 3 and
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
return "generate"
|
| 104 |
|
|
|
|
| 97 |
# also failed (still empty after the first CRAG rewrite). When the query
|
| 98 |
# mentions a known portfolio entity, attempt one more vocabulary-shifted rewrite
|
| 99 |
# before admitting the not-found path.
|
| 100 |
+
if attempts == 3 and is_portfolio_relevant(query):
|
| 101 |
+
if not reranked:
|
| 102 |
+
return "rewrite"
|
| 103 |
+
top_score = state.get("top_rerank_score")
|
| 104 |
+
if top_score is not None and top_score < _CRAG_LOW_CONFIDENCE_SCORE:
|
| 105 |
+
return "rewrite"
|
| 106 |
|
| 107 |
return "generate"
|
| 108 |
|
app/pipeline/nodes/retrieve.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import asyncio
|
| 2 |
import logging
|
|
|
|
| 3 |
from typing import Callable
|
| 4 |
|
| 5 |
from langgraph.config import get_stream_writer
|
|
@@ -21,6 +22,7 @@ from app.services.sparse_encoder import SparseEncoder
|
|
| 21 |
# unrelated (noise), while –3.5 to –1.0 still captures valid skill/project
|
| 22 |
# passages that answer tech-stack or experience questions.
|
| 23 |
_MIN_TOP_SCORE: float = -3.5
|
|
|
|
| 24 |
|
| 25 |
# Default cap: max chunks per source document for BROAD queries.
|
| 26 |
# Without this, a verbose doc can crowd out all 5 context slots, hiding other
|
|
@@ -45,7 +47,9 @@ _FOCUS_KEYWORDS: dict[frozenset[str], str] = {
|
|
| 45 |
frozenset({"experience", "work", "job", "role", "career", "internship",
|
| 46 |
"skills", "skill", "education", "degree", "university",
|
| 47 |
"certification", "certifications", "qualified", "resume", "cv",
|
| 48 |
-
"employment", "professional", "placement", "history"
|
|
|
|
|
|
|
| 49 |
frozenset({"project", "built", "build", "developed", "architecture",
|
| 50 |
"system", "platform", "app", "application"}): "project",
|
| 51 |
frozenset({"blog", "post", "article", "wrote", "writing", "published"}): "blog",
|
|
@@ -59,6 +63,43 @@ _RRF_K: int = 60
|
|
| 59 |
# Module-level singleton — BM25 model downloads once (~5 MB), cached in memory.
|
| 60 |
_sparse_encoder = SparseEncoder()
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def _focused_source_type(query: str) -> str | None:
|
| 64 |
"""
|
|
@@ -69,7 +110,7 @@ def _focused_source_type(query: str) -> str | None:
|
|
| 69 |
that don't match any category retain the 2-per-doc default cap so no single
|
| 70 |
source dominates the 5 context slots.
|
| 71 |
"""
|
| 72 |
-
tokens = frozenset(query.lower()
|
| 73 |
for keyword_set, source_type in _FOCUS_KEYWORDS.items():
|
| 74 |
if tokens & keyword_set:
|
| 75 |
return source_type
|
|
@@ -166,6 +207,9 @@ def _normalise_focus_typos(query: str) -> str:
|
|
| 166 |
if len(stripped) < 4 or stripped in _FOCUS_VOCAB:
|
| 167 |
corrected.append(token)
|
| 168 |
continue
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
replacement = _best_focus_replacement(stripped)
|
| 171 |
|
|
@@ -177,6 +221,11 @@ def _normalise_focus_typos(query: str) -> str:
|
|
| 177 |
return " ".join(corrected)
|
| 178 |
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
def make_retrieve_node(
|
| 181 |
vector_store: VectorStore, embedder: Embedder, reranker: Reranker
|
| 182 |
) -> Callable[[PipelineState], dict]:
|
|
@@ -427,13 +476,31 @@ def make_retrieve_node(
|
|
| 427 |
|
| 428 |
# ── Relevance gate ─────────────────────────────────────────────────────
|
| 429 |
top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
return {
|
| 432 |
"answer": "",
|
| 433 |
"retrieved_chunks": [],
|
| 434 |
"reranked_chunks": [],
|
| 435 |
"retrieval_attempts": attempts + 1, "top_rerank_score": top_score, }
|
| 436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
# ── Source diversity cap (query-aware) ─────────────────────────────────
|
| 438 |
focused_type = _focused_source_type(retrieval_query)
|
| 439 |
doc_counts: dict[str, int] = {}
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import logging
|
| 3 |
+
import re
|
| 4 |
from typing import Callable
|
| 5 |
|
| 6 |
from langgraph.config import get_stream_writer
|
|
|
|
| 22 |
# unrelated (noise), while –3.5 to –1.0 still captures valid skill/project
|
| 23 |
# passages that answer tech-stack or experience questions.
|
| 24 |
_MIN_TOP_SCORE: float = -3.5
|
| 25 |
+
_MIN_RESCUE_SCORE: float = -6.0
|
| 26 |
|
| 27 |
# Default cap: max chunks per source document for BROAD queries.
|
| 28 |
# Without this, a verbose doc can crowd out all 5 context slots, hiding other
|
|
|
|
| 47 |
frozenset({"experience", "work", "job", "role", "career", "internship",
|
| 48 |
"skills", "skill", "education", "degree", "university",
|
| 49 |
"certification", "certifications", "qualified", "resume", "cv",
|
| 50 |
+
"employment", "professional", "placement", "history",
|
| 51 |
+
"tech", "stack", "technology", "technologies", "framework",
|
| 52 |
+
"frameworks", "tool", "tools", "tooling", "language", "languages"}): "cv",
|
| 53 |
frozenset({"project", "built", "build", "developed", "architecture",
|
| 54 |
"system", "platform", "app", "application"}): "project",
|
| 55 |
frozenset({"blog", "post", "article", "wrote", "writing", "published"}): "blog",
|
|
|
|
| 63 |
# Module-level singleton — BM25 model downloads once (~5 MB), cached in memory.
|
| 64 |
_sparse_encoder = SparseEncoder()
|
| 65 |
|
| 66 |
+
_CAPABILITY_QUERY_HINTS: frozenset[str] = frozenset(
|
| 67 |
+
{
|
| 68 |
+
"tech",
|
| 69 |
+
"stack",
|
| 70 |
+
"technology",
|
| 71 |
+
"technologies",
|
| 72 |
+
"framework",
|
| 73 |
+
"frameworks",
|
| 74 |
+
"tool",
|
| 75 |
+
"tools",
|
| 76 |
+
"tooling",
|
| 77 |
+
"language",
|
| 78 |
+
"languages",
|
| 79 |
+
"skills",
|
| 80 |
+
"skill",
|
| 81 |
+
}
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
_NORMALISATION_STOPWORDS: frozenset[str] = frozenset(
|
| 85 |
+
{
|
| 86 |
+
"tell",
|
| 87 |
+
"about",
|
| 88 |
+
"what",
|
| 89 |
+
"which",
|
| 90 |
+
"where",
|
| 91 |
+
"when",
|
| 92 |
+
"could",
|
| 93 |
+
"would",
|
| 94 |
+
"should",
|
| 95 |
+
"your",
|
| 96 |
+
"with",
|
| 97 |
+
"from",
|
| 98 |
+
"that",
|
| 99 |
+
"this",
|
| 100 |
+
}
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
|
| 104 |
def _focused_source_type(query: str) -> str | None:
|
| 105 |
"""
|
|
|
|
| 110 |
that don't match any category retain the 2-per-doc default cap so no single
|
| 111 |
source dominates the 5 context slots.
|
| 112 |
"""
|
| 113 |
+
tokens = frozenset(re.findall(r"[a-z0-9]+", query.lower()))
|
| 114 |
for keyword_set, source_type in _FOCUS_KEYWORDS.items():
|
| 115 |
if tokens & keyword_set:
|
| 116 |
return source_type
|
|
|
|
| 207 |
if len(stripped) < 4 or stripped in _FOCUS_VOCAB:
|
| 208 |
corrected.append(token)
|
| 209 |
continue
|
| 210 |
+
if stripped in _NORMALISATION_STOPWORDS:
|
| 211 |
+
corrected.append(token)
|
| 212 |
+
continue
|
| 213 |
|
| 214 |
replacement = _best_focus_replacement(stripped)
|
| 215 |
|
|
|
|
| 221 |
return " ".join(corrected)
|
| 222 |
|
| 223 |
|
| 224 |
+
def _is_capability_query(query: str) -> bool:
|
| 225 |
+
tokens = frozenset(re.findall(r"[a-z0-9]+", query.lower()))
|
| 226 |
+
return bool(tokens & _CAPABILITY_QUERY_HINTS)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
def make_retrieve_node(
|
| 230 |
vector_store: VectorStore, embedder: Embedder, reranker: Reranker
|
| 231 |
) -> Callable[[PipelineState], dict]:
|
|
|
|
| 476 |
|
| 477 |
# ── Relevance gate ─────────────────────────────────────────────────────
|
| 478 |
top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
|
| 479 |
+
low_confidence = top_score is not None and top_score < _MIN_TOP_SCORE
|
| 480 |
+
capability_query = _is_capability_query(retrieval_query)
|
| 481 |
+
rescue_low_confidence = bool(
|
| 482 |
+
reranked
|
| 483 |
+
and low_confidence
|
| 484 |
+
and top_score is not None
|
| 485 |
+
and top_score >= _MIN_RESCUE_SCORE
|
| 486 |
+
and (capability_query or _focused_source_type(retrieval_query) is not None)
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
if not reranked or (low_confidence and not rescue_low_confidence):
|
| 490 |
return {
|
| 491 |
"answer": "",
|
| 492 |
"retrieved_chunks": [],
|
| 493 |
"reranked_chunks": [],
|
| 494 |
"retrieval_attempts": attempts + 1, "top_rerank_score": top_score, }
|
| 495 |
|
| 496 |
+
if rescue_low_confidence:
|
| 497 |
+
writer(
|
| 498 |
+
{
|
| 499 |
+
"type": "status",
|
| 500 |
+
"label": "Applying retrieval rescue for portfolio capability query...",
|
| 501 |
+
}
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
# ── Source diversity cap (query-aware) ─────────────────────────────────
|
| 505 |
focused_type = _focused_source_type(retrieval_query)
|
| 506 |
doc_counts: dict[str, int] = {}
|
app/services/reranker.py
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
from typing import Any, Optional
|
| 7 |
|
| 8 |
import httpx
|
|
|
|
| 9 |
|
| 10 |
from app.models.pipeline import Chunk
|
| 11 |
|
|
@@ -43,29 +44,35 @@ class Reranker:
|
|
| 43 |
texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]
|
| 44 |
|
| 45 |
if self._remote:
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
model = _get_local_model()
|
| 71 |
pairs = [(query, text) for text in texts]
|
|
|
|
| 6 |
from typing import Any, Optional
|
| 7 |
|
| 8 |
import httpx
|
| 9 |
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
| 10 |
|
| 11 |
from app.models.pipeline import Chunk
|
| 12 |
|
|
|
|
| 44 |
texts = [chunk.get("contextualised_text") or chunk["text"] for chunk in chunks]
|
| 45 |
|
| 46 |
if self._remote:
|
| 47 |
+
@retry(
|
| 48 |
+
stop=stop_after_attempt(2),
|
| 49 |
+
wait=wait_exponential(multiplier=0.4, min=0.4, max=1.2),
|
| 50 |
+
retry=retry_if_exception_type((httpx.TimeoutException, httpx.HTTPError)),
|
| 51 |
+
reraise=True,
|
| 52 |
+
)
|
| 53 |
+
async def _remote_call() -> tuple[list[int], list[float]]:
|
| 54 |
+
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 55 |
+
truncated = [t[:1500] for t in texts]
|
| 56 |
+
resp = await client.post(
|
| 57 |
+
f"{self._url}/rerank",
|
| 58 |
+
json={"query": query[:512], "texts": truncated, "top_k": top_k},
|
| 59 |
+
)
|
| 60 |
+
resp.raise_for_status()
|
| 61 |
+
data = resp.json()
|
| 62 |
+
indices = data.get("indices")
|
| 63 |
+
scores = data.get("scores")
|
| 64 |
+
if not isinstance(indices, list) or not isinstance(scores, list):
|
| 65 |
+
raise httpx.HTTPError("Invalid reranker response schema")
|
| 66 |
+
return [int(i) for i in indices], [float(s) for s in scores]
|
| 67 |
+
|
| 68 |
+
indices, scores = await _remote_call()
|
| 69 |
+
result = []
|
| 70 |
+
for idx, score in zip(indices, scores):
|
| 71 |
+
chunk_copy = dict(chunks[idx])
|
| 72 |
+
chunk_copy["metadata"]["rerank_score"] = score
|
| 73 |
+
result.append(chunk_copy)
|
| 74 |
+
self._min_score = scores[-1] if scores else 0.0
|
| 75 |
+
return result # type: ignore[return-value]
|
| 76 |
|
| 77 |
model = _get_local_model()
|
| 78 |
pairs = [(query, text) for text in texts]
|
app/services/tts_client.py
CHANGED
|
@@ -14,7 +14,7 @@ class TTSClient:
|
|
| 14 |
def is_configured(self) -> bool:
|
| 15 |
return bool(self._tts_space_url)
|
| 16 |
|
| 17 |
-
async def synthesize(self, text: str) -> bytes:
|
| 18 |
if not self.is_configured:
|
| 19 |
raise GenerationError("TTS client is not configured")
|
| 20 |
|
|
@@ -22,7 +22,7 @@ class TTSClient:
|
|
| 22 |
async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
|
| 23 |
response = await client.post(
|
| 24 |
f"{self._tts_space_url}/synthesize",
|
| 25 |
-
json={"text": text},
|
| 26 |
headers={"Content-Type": "application/json"},
|
| 27 |
)
|
| 28 |
response.raise_for_status()
|
|
|
|
| 14 |
def is_configured(self) -> bool:
|
| 15 |
return bool(self._tts_space_url)
|
| 16 |
|
| 17 |
+
async def synthesize(self, text: str, voice: str = "am_adam") -> bytes:
|
| 18 |
if not self.is_configured:
|
| 19 |
raise GenerationError("TTS client is not configured")
|
| 20 |
|
|
|
|
| 22 |
async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
|
| 23 |
response = await client.post(
|
| 24 |
f"{self._tts_space_url}/synthesize",
|
| 25 |
+
json={"text": text, "voice": voice},
|
| 26 |
headers={"Content-Type": "application/json"},
|
| 27 |
)
|
| 28 |
response.raise_for_status()
|
tests/test_enumerate_query.py
CHANGED
|
@@ -217,3 +217,6 @@ class TestIsPortfolioRelevant:
|
|
| 217 |
|
| 218 |
def test_stt_typo_work_experience_is_still_relevant(self):
|
| 219 |
assert is_portfolio_relevant("tell me about his walk experience") is True
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def test_stt_typo_work_experience_is_still_relevant(self):
|
| 219 |
assert is_portfolio_relevant("tell me about his walk experience") is True
|
| 220 |
+
|
| 221 |
+
def test_tech_stack_intent_is_relevant(self):
|
| 222 |
+
assert is_portfolio_relevant("Could you tell me about his tech stack?") is True
|
tests/test_graph_routing.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.pipeline.graph import route_retrieve_result
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_attempt_three_portfolio_empty_rewrites() -> None:
|
| 5 |
+
state = {
|
| 6 |
+
"retrieval_attempts": 3,
|
| 7 |
+
"reranked_chunks": [],
|
| 8 |
+
"query": "Could you tell me about his tech stack?",
|
| 9 |
+
}
|
| 10 |
+
assert route_retrieve_result(state) == "rewrite"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_attempt_three_portfolio_low_confidence_rewrites() -> None:
|
| 14 |
+
state = {
|
| 15 |
+
"retrieval_attempts": 3,
|
| 16 |
+
"reranked_chunks": [{"text": "x", "metadata": {}}],
|
| 17 |
+
"top_rerank_score": -2.0,
|
| 18 |
+
"query": "Could you tell me about his tech stack?",
|
| 19 |
+
}
|
| 20 |
+
assert route_retrieve_result(state) == "rewrite"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def test_attempt_three_unrelated_low_confidence_generates() -> None:
|
| 24 |
+
state = {
|
| 25 |
+
"retrieval_attempts": 3,
|
| 26 |
+
"reranked_chunks": [{"text": "x", "metadata": {}}],
|
| 27 |
+
"top_rerank_score": -2.0,
|
| 28 |
+
"query": "what is the weather in london",
|
| 29 |
+
}
|
| 30 |
+
assert route_retrieve_result(state) == "generate"
|
tests/test_models.py
CHANGED
|
@@ -5,6 +5,7 @@
|
|
| 5 |
import pytest
|
| 6 |
from pydantic import ValidationError
|
| 7 |
from app.models.chat import ChatRequest, SourceRef, ChatResponse
|
|
|
|
| 8 |
|
| 9 |
VALID_UUID = "a1b2c3d4-e5f6-4789-8abc-def012345678"
|
| 10 |
|
|
@@ -85,3 +86,13 @@ class TestChatResponse:
|
|
| 85 |
assert resp.cached is False
|
| 86 |
assert resp.latency_ms == 312
|
| 87 |
assert len(resp.sources) == 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import pytest
|
| 6 |
from pydantic import ValidationError
|
| 7 |
from app.models.chat import ChatRequest, SourceRef, ChatResponse
|
| 8 |
+
from app.models.speech import SynthesizeRequest
|
| 9 |
|
| 10 |
VALID_UUID = "a1b2c3d4-e5f6-4789-8abc-def012345678"
|
| 11 |
|
|
|
|
| 86 |
assert resp.cached is False
|
| 87 |
assert resp.latency_ms == 312
|
| 88 |
assert len(resp.sources) == 1
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class TestSynthesizeRequest:
|
| 92 |
+
def test_default_voice_is_male(self):
|
| 93 |
+
req = SynthesizeRequest(text="hello")
|
| 94 |
+
assert req.voice == "am_adam"
|
| 95 |
+
|
| 96 |
+
def test_voice_too_long_rejected(self):
|
| 97 |
+
with pytest.raises(ValidationError):
|
| 98 |
+
SynthesizeRequest(text="hello", voice="x" * 33)
|
tests/test_retrieve_query_normalization.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from app.pipeline.nodes.retrieve import _normalise_focus_typos
|
| 2 |
|
| 3 |
|
| 4 |
def test_walk_experience_normalises_to_work_experience() -> None:
|
|
@@ -10,3 +10,11 @@ def test_walk_experience_normalises_to_work_experience() -> None:
|
|
| 10 |
def test_non_focus_text_is_not_overwritten() -> None:
|
| 11 |
original = "Tell me about widget orchestration internals"
|
| 12 |
assert _normalise_focus_typos(original) == original.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from app.pipeline.nodes.retrieve import _focused_source_type, _is_capability_query, _normalise_focus_typos
|
| 2 |
|
| 3 |
|
| 4 |
def test_walk_experience_normalises_to_work_experience() -> None:
|
|
|
|
| 10 |
def test_non_focus_text_is_not_overwritten() -> None:
|
| 11 |
original = "Tell me about widget orchestration internals"
|
| 12 |
assert _normalise_focus_typos(original) == original.lower()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_capability_query_detection_handles_punctuation() -> None:
|
| 16 |
+
assert _is_capability_query("What tech stack does he use?") is True
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_focus_source_type_for_tech_stack_query() -> None:
|
| 20 |
+
assert _focused_source_type("What technologies and skills does he work with?") == "cv"
|
tests/test_speech_endpoints.py
CHANGED
|
@@ -44,8 +44,12 @@ def test_tts_requires_auth(app_client):
|
|
| 44 |
|
| 45 |
|
| 46 |
def test_tts_success(app_client, valid_token):
|
| 47 |
-
|
|
|
|
|
|
|
| 48 |
await asyncio.sleep(0)
|
|
|
|
|
|
|
| 49 |
return b"RIFF....fake"
|
| 50 |
|
| 51 |
app_client.app.state.tts_client.synthesize = fake_synthesize
|
|
@@ -59,3 +63,25 @@ def test_tts_success(app_client, valid_token):
|
|
| 59 |
assert response.status_code == 200
|
| 60 |
assert response.headers.get("content-type", "").startswith("audio/wav")
|
| 61 |
assert response.content == b"RIFF....fake"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def test_tts_success(app_client, valid_token):
|
| 47 |
+
captured: dict[str, str] = {}
|
| 48 |
+
|
| 49 |
+
async def fake_synthesize(text, voice="am_adam"):
|
| 50 |
await asyncio.sleep(0)
|
| 51 |
+
captured["text"] = text
|
| 52 |
+
captured["voice"] = voice
|
| 53 |
return b"RIFF....fake"
|
| 54 |
|
| 55 |
app_client.app.state.tts_client.synthesize = fake_synthesize
|
|
|
|
| 63 |
assert response.status_code == 200
|
| 64 |
assert response.headers.get("content-type", "").startswith("audio/wav")
|
| 65 |
assert response.content == b"RIFF....fake"
|
| 66 |
+
assert captured["text"] == "Hello world"
|
| 67 |
+
assert captured["voice"] == "am_adam"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_tts_uses_provided_voice(app_client, valid_token):
|
| 71 |
+
captured: dict[str, str] = {}
|
| 72 |
+
|
| 73 |
+
async def fake_synthesize(text, voice="am_adam"):
|
| 74 |
+
await asyncio.sleep(0)
|
| 75 |
+
captured["voice"] = voice
|
| 76 |
+
return b"RIFF....fake"
|
| 77 |
+
|
| 78 |
+
app_client.app.state.tts_client.synthesize = fake_synthesize
|
| 79 |
+
|
| 80 |
+
response = app_client.post(
|
| 81 |
+
"/tts",
|
| 82 |
+
json={"text": "Hello world", "voice": "af_heart"},
|
| 83 |
+
headers={"Authorization": f"Bearer {valid_token}"},
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
assert response.status_code == 200
|
| 87 |
+
assert captured["voice"] == "af_heart"
|