below-threshold commited on
Commit
54a5940
·
1 Parent(s): cd30e2d

Add enforce_terminology: deterministic post-processing corrective gate

Browse files
Files changed (2) hide show
  1. backend/pipeline.py +5 -2
  2. backend/rosetta.py +17 -0
backend/pipeline.py CHANGED
@@ -26,7 +26,7 @@ from sentence_transformers import SentenceTransformer
26
 
27
  from config import features_path, domain_for, DISPLAY_NAMES
28
  from grader import grade, GradeReport, get_embedder
29
- from rosetta import client_terms
30
 
31
  log = logging.getLogger(__name__)
32
 
@@ -321,7 +321,10 @@ def run(
321
  ]
322
 
323
  context = _build_context(retrieved)
324
- answer = _generate(query, context, client, domain, hf_client)
 
 
 
325
  report = grade(
326
  query=query,
327
  response=answer,
 
26
 
27
  from config import features_path, domain_for, DISPLAY_NAMES
28
  from grader import grade, GradeReport, get_embedder
29
+ from rosetta import client_terms, enforce_terminology
30
 
31
  log = logging.getLogger(__name__)
32
 
 
321
  ]
322
 
323
  context = _build_context(retrieved)
324
+ raw_answer = _generate(query, context, client, domain, hf_client)
325
+ answer, replacements = enforce_terminology(raw_answer, client)
326
+ if replacements:
327
+ log.info("Terminology enforced for client=%s replacements=%s", client, replacements)
328
  report = grade(
329
  query=query,
330
  response=answer,
backend/rosetta.py CHANGED
@@ -1,5 +1,6 @@
1
  """RosettaStone: canonical term -> client-specific term translation."""
2
 
 
3
  import yaml
4
  from functools import lru_cache
5
 
@@ -28,6 +29,22 @@ def client_terms(client: str) -> dict[str, str]:
28
  return dict(catalog.get(client, {}))
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def check_terminology(response_text: str, client: str) -> dict:
32
  """
33
  Deterministic chain_terminology check.
 
1
  """RosettaStone: canonical term -> client-specific term translation."""
2
 
3
+ import re
4
  import yaml
5
  from functools import lru_cache
6
 
 
29
  return dict(catalog.get(client, {}))
30
 
31
 
32
+ def enforce_terminology(text: str, client: str) -> tuple[str, list[dict]]:
33
+ """Replace rival client terms with correct client terms. Returns (corrected_text, replacements)."""
34
+ catalog = _load_catalog(domain_for(client))
35
+ expected = catalog.get(client, {})
36
+ other_clients = {c: terms for c, terms in catalog.items() if c != client}
37
+ result = text
38
+ replacements = []
39
+ for canonical_key, client_term in expected.items():
40
+ for other_terms in other_clients.values():
41
+ rival = other_terms.get(canonical_key, "")
42
+ if rival and re.search(re.escape(rival), result, re.IGNORECASE):
43
+ result = re.sub(re.escape(rival), client_term, result, flags=re.IGNORECASE)
44
+ replacements.append({"from": rival, "to": client_term})
45
+ return result, replacements
46
+
47
+
48
  def check_terminology(response_text: str, client: str) -> dict:
49
  """
50
  Deterministic chain_terminology check.