below-threshold committed on
Commit
76be5a0
·
1 Parent(s): 54a5940

Replace enforce_terminology with pinned glossary doc in RAG context

Browse files
Files changed (2) hide show
  1. backend/pipeline.py +10 -6
  2. backend/rosetta.py +11 -16
backend/pipeline.py CHANGED
@@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer
26
 
27
  from config import features_path, domain_for, DISPLAY_NAMES
28
  from grader import grade, GradeReport, get_embedder
29
- from rosetta import client_terms, enforce_terminology
30
 
31
  log = logging.getLogger(__name__)
32
 
@@ -320,11 +320,15 @@ def run(
320
  if scores[i] > MIN_RETRIEVAL_SCORE
321
  ]
322
 
323
- context = _build_context(retrieved)
324
- raw_answer = _generate(query, context, client, domain, hf_client)
325
- answer, replacements = enforce_terminology(raw_answer, client)
326
- if replacements:
327
- log.info("Terminology enforced for client=%s replacements=%s", client, replacements)
 
 
 
 
328
  report = grade(
329
  query=query,
330
  response=answer,
 
26
 
27
  from config import features_path, domain_for, DISPLAY_NAMES
28
  from grader import grade, GradeReport, get_embedder
29
+ from rosetta import client_terms, client_terms_doc
30
 
31
  log = logging.getLogger(__name__)
32
 
 
320
  if scores[i] > MIN_RETRIEVAL_SCORE
321
  ]
322
 
323
+ terms_doc = client_terms_doc(client)
324
+ pinned = RetrievedDoc(
325
+ id=terms_doc["id"],
326
+ title=terms_doc["title"],
327
+ content=terms_doc["content"],
328
+ score=1.0,
329
+ )
330
+ context = _build_context([pinned] + retrieved)
331
+ answer = _generate(query, context, client, domain, hf_client)
332
  report = grade(
333
  query=query,
334
  response=answer,
backend/rosetta.py CHANGED
@@ -1,10 +1,9 @@
1
  """RosettaStone: canonical term -> client-specific term translation."""
2
 
3
- import re
4
  import yaml
5
  from functools import lru_cache
6
 
7
- from config import term_catalog_path, domain_for
8
 
9
 
10
  @lru_cache(maxsize=8)
@@ -29,20 +28,16 @@ def client_terms(client: str) -> dict[str, str]:
29
  return dict(catalog.get(client, {}))
30
 
31
 
32
- def enforce_terminology(text: str, client: str) -> tuple[str, list[dict]]:
33
- """Replace rival client terms with correct client terms. Returns (corrected_text, replacements)."""
34
- catalog = _load_catalog(domain_for(client))
35
- expected = catalog.get(client, {})
36
- other_clients = {c: terms for c, terms in catalog.items() if c != client}
37
- result = text
38
- replacements = []
39
- for canonical_key, client_term in expected.items():
40
- for other_terms in other_clients.values():
41
- rival = other_terms.get(canonical_key, "")
42
- if rival and re.search(re.escape(rival), result, re.IGNORECASE):
43
- result = re.sub(re.escape(rival), client_term, result, flags=re.IGNORECASE)
44
- replacements.append({"from": rival, "to": client_term})
45
- return result, replacements
46
 
47
 
48
  def check_terminology(response_text: str, client: str) -> dict:
 
1
  """RosettaStone: canonical term -> client-specific term translation."""
2
 
 
3
  import yaml
4
  from functools import lru_cache
5
 
6
+ from config import term_catalog_path, domain_for, DISPLAY_NAMES
7
 
8
 
9
  @lru_cache(maxsize=8)
 
28
  return dict(catalog.get(client, {}))
29
 
30
 
31
+ def client_terms_doc(client: str) -> dict:
32
+ """Return the term catalog as a pinned KB document for context injection."""
33
+ terms = client_terms(client)
34
+ display = DISPLAY_NAMES.get(client, client.title())
35
+ lines = "\n".join(f"- {k.replace('_', ' ').title()}: use '{v}'" for k, v in terms.items())
36
+ return {
37
+ "id": f"terms_{client}",
38
+ "title": f"{display} Terminology Guide",
39
+ "content": f"Always use these exact terms when responding to {display} users:\n{lines}",
40
+ }
 
 
 
 
41
 
42
 
43
  def check_terminology(response_text: str, client: str) -> dict: