Commit ·
54a5940
1
Parent(s): cd30e2d
Add enforce_terminology: deterministic post-processing corrective gate
Browse files- backend/pipeline.py +5 -2
- backend/rosetta.py +17 -0
backend/pipeline.py
CHANGED
|
@@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer
|
|
| 26 |
|
| 27 |
from config import features_path, domain_for, DISPLAY_NAMES
|
| 28 |
from grader import grade, GradeReport, get_embedder
|
| 29 |
-
from rosetta import client_terms
|
| 30 |
|
| 31 |
log = logging.getLogger(__name__)
|
| 32 |
|
|
@@ -321,7 +321,10 @@ def run(
|
|
| 321 |
]
|
| 322 |
|
| 323 |
context = _build_context(retrieved)
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
| 325 |
report = grade(
|
| 326 |
query=query,
|
| 327 |
response=answer,
|
|
|
|
| 26 |
|
| 27 |
from config import features_path, domain_for, DISPLAY_NAMES
|
| 28 |
from grader import grade, GradeReport, get_embedder
|
| 29 |
+
from rosetta import client_terms, enforce_terminology
|
| 30 |
|
| 31 |
log = logging.getLogger(__name__)
|
| 32 |
|
|
|
|
| 321 |
]
|
| 322 |
|
| 323 |
context = _build_context(retrieved)
|
| 324 |
+
raw_answer = _generate(query, context, client, domain, hf_client)
|
| 325 |
+
answer, replacements = enforce_terminology(raw_answer, client)
|
| 326 |
+
if replacements:
|
| 327 |
+
log.info("Terminology enforced for client=%s replacements=%s", client, replacements)
|
| 328 |
report = grade(
|
| 329 |
query=query,
|
| 330 |
response=answer,
|
backend/rosetta.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""RosettaStone: canonical term -> client-specific term translation."""
|
| 2 |
|
|
|
|
| 3 |
import yaml
|
| 4 |
from functools import lru_cache
|
| 5 |
|
|
@@ -28,6 +29,22 @@ def client_terms(client: str) -> dict[str, str]:
|
|
| 28 |
return dict(catalog.get(client, {}))
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def check_terminology(response_text: str, client: str) -> dict:
|
| 32 |
"""
|
| 33 |
Deterministic chain_terminology check.
|
|
|
|
| 1 |
"""RosettaStone: canonical term -> client-specific term translation."""
|
| 2 |
|
| 3 |
+
import re
|
| 4 |
import yaml
|
| 5 |
from functools import lru_cache
|
| 6 |
|
|
|
|
| 29 |
return dict(catalog.get(client, {}))
|
| 30 |
|
| 31 |
|
| 32 |
+
def enforce_terminology(text: str, client: str) -> tuple[str, list[dict]]:
|
| 33 |
+
"""Replace rival client terms with correct client terms. Returns (corrected_text, replacements)."""
|
| 34 |
+
catalog = _load_catalog(domain_for(client))
|
| 35 |
+
expected = catalog.get(client, {})
|
| 36 |
+
other_clients = {c: terms for c, terms in catalog.items() if c != client}
|
| 37 |
+
result = text
|
| 38 |
+
replacements = []
|
| 39 |
+
for canonical_key, client_term in expected.items():
|
| 40 |
+
for other_terms in other_clients.values():
|
| 41 |
+
rival = other_terms.get(canonical_key, "")
|
| 42 |
+
if rival and re.search(re.escape(rival), result, re.IGNORECASE):
|
| 43 |
+
result = re.sub(re.escape(rival), client_term, result, flags=re.IGNORECASE)
|
| 44 |
+
replacements.append({"from": rival, "to": client_term})
|
| 45 |
+
return result, replacements
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def check_terminology(response_text: str, client: str) -> dict:
|
| 49 |
"""
|
| 50 |
Deterministic chain_terminology check.
|