Commit ·
76be5a0
1
Parent(s): 54a5940
Replace enforce_terminology with pinned glossary doc in RAG context
Browse files
- backend/pipeline.py +10 -6
- backend/rosetta.py +11 -16
backend/pipeline.py
CHANGED
|
@@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer
|
|
| 26 |
|
| 27 |
from config import features_path, domain_for, DISPLAY_NAMES
|
| 28 |
from grader import grade, GradeReport, get_embedder
|
| 29 |
-
from rosetta import client_terms, enforce_terminology
|
| 30 |
|
| 31 |
log = logging.getLogger(__name__)
|
| 32 |
|
|
@@ -320,11 +320,15 @@ def run(
|
|
| 320 |
if scores[i] > MIN_RETRIEVAL_SCORE
|
| 321 |
]
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
report = grade(
|
| 329 |
query=query,
|
| 330 |
response=answer,
|
|
|
|
| 26 |
|
| 27 |
from config import features_path, domain_for, DISPLAY_NAMES
|
| 28 |
from grader import grade, GradeReport, get_embedder
|
| 29 |
+
from rosetta import client_terms, client_terms_doc
|
| 30 |
|
| 31 |
log = logging.getLogger(__name__)
|
| 32 |
|
|
|
|
| 320 |
if scores[i] > MIN_RETRIEVAL_SCORE
|
| 321 |
]
|
| 322 |
|
| 323 |
+
terms_doc = client_terms_doc(client)
|
| 324 |
+
pinned = RetrievedDoc(
|
| 325 |
+
id=terms_doc["id"],
|
| 326 |
+
title=terms_doc["title"],
|
| 327 |
+
content=terms_doc["content"],
|
| 328 |
+
score=1.0,
|
| 329 |
+
)
|
| 330 |
+
context = _build_context([pinned] + retrieved)
|
| 331 |
+
answer = _generate(query, context, client, domain, hf_client)
|
| 332 |
report = grade(
|
| 333 |
query=query,
|
| 334 |
response=answer,
|
backend/rosetta.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
"""RosettaStone: canonical term -> client-specific term translation."""
|
| 2 |
|
| 3 |
-
import re
|
| 4 |
import yaml
|
| 5 |
from functools import lru_cache
|
| 6 |
|
| 7 |
-
from config import term_catalog_path, domain_for
|
| 8 |
|
| 9 |
|
| 10 |
@lru_cache(maxsize=8)
|
|
@@ -29,20 +28,16 @@ def client_terms(client: str) -> dict[str, str]:
|
|
| 29 |
return dict(catalog.get(client, {}))
|
| 30 |
|
| 31 |
|
| 32 |
-
def
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
if rival and re.search(re.escape(rival), result, re.IGNORECASE):
|
| 43 |
-
result = re.sub(re.escape(rival), client_term, result, flags=re.IGNORECASE)
|
| 44 |
-
replacements.append({"from": rival, "to": client_term})
|
| 45 |
-
return result, replacements
|
| 46 |
|
| 47 |
|
| 48 |
def check_terminology(response_text: str, client: str) -> dict:
|
|
|
|
| 1 |
"""RosettaStone: canonical term -> client-specific term translation."""
|
| 2 |
|
|
|
|
| 3 |
import yaml
|
| 4 |
from functools import lru_cache
|
| 5 |
|
| 6 |
+
from config import term_catalog_path, domain_for, DISPLAY_NAMES
|
| 7 |
|
| 8 |
|
| 9 |
@lru_cache(maxsize=8)
|
|
|
|
| 28 |
return dict(catalog.get(client, {}))
|
| 29 |
|
| 30 |
|
| 31 |
+
def client_terms_doc(client: str) -> dict[str, str]:
    """Return the client's term catalog as a pinned KB document for context injection.

    Args:
        client: Client key, passed to ``client_terms`` and looked up in
            ``DISPLAY_NAMES`` (falling back to ``client.title()``).

    Returns:
        A dict with ``id``, ``title`` and ``content`` keys. ``content`` is an
        instruction line followed by one
        ``- <Canonical Term>: use '<client term>'`` bullet per catalog entry
        that has a non-empty client term.
    """
    display = DISPLAY_NAMES.get(client, client.title())
    # Skip entries without a client term: a bullet reading "use ''" or
    # "use 'None'" would instruct the model to use a nonsensical term.
    # Keys are coerced to str so a non-string key cannot break .replace()
    # (presumably the catalog comes from YAML, which may parse bare keys as
    # ints/bools -- verify against the loader).
    bullets = [
        f"- {str(canonical).replace('_', ' ').title()}: use '{term}'"
        for canonical, term in client_terms(client).items()
        if term
    ]
    lines = "\n".join(bullets)
    return {
        "id": f"terms_{client}",
        "title": f"{display} Terminology Guide",
        "content": f"Always use these exact terms when responding to {display} users:\n{lines}",
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
def check_terminology(response_text: str, client: str) -> dict:
|