Nashid-Noor commited on
Commit
4a18f13
·
1 Parent(s): 880df70

Upgrade domain guardrail to Vector Semantic Similarity

Browse files
Files changed (1) hide show
  1. rag/guardrails.py +27 -18
rag/guardrails.py CHANGED
@@ -9,9 +9,11 @@ import re
9
  from dataclasses import dataclass, field
10
  from enum import Enum
11
  from typing import Optional
 
12
 
13
  from rag.logger import log_security_event
14
  from rag.retrieve import RetrievedPassage
 
15
 
16
  # =========================================================================
17
  # Enums & data models
@@ -116,24 +118,31 @@ def _redact_pii(query: str) -> str:
116
  # 3. DOMAIN ENFORCEMENT
117
  # =========================================================================
118
 
119
- _DOMAIN_KEYWORDS: list[str] = [
120
- "electric", "grid", "power", "voltage", "transformer", "cable",
121
- "substation", "safety", "maintenance", "outage", "fault", "relay",
122
- "circuit", "breaker", "distribution", "transmission", "insulation",
123
- "grounding", "earthing", "load", "switchgear", "overhead", "conductor",
124
- "utility", "compliance", "inspection", "clearance", "arc flash",
125
- "lockout", "tagout", "ppe", "hazard", "regulation", "pole",
126
- "energy", "wiring", "current", "amp", "ohm", "resistance",
127
- "kilowatt", "megawatt", "feeder", "capacitor", "inverter", "solar",
128
- "battery", "generator", "diesel", "backup", "scada", "meter",
129
- "lineman", "crew", "storm", "restoration", "vegetation", "tree",
130
- "trimming", "right.of.way", "easement", "permit",
131
- ]
132
-
133
-
134
- def _is_on_domain(query: str) -> bool:
135
- q_lower = query.lower()
136
- return any(kw in q_lower for kw in _DOMAIN_KEYWORDS)
 
 
 
 
 
 
 
137
 
138
 
139
  # =========================================================================
 
9
  from dataclasses import dataclass, field
10
  from enum import Enum
11
  from typing import Optional
12
+ import numpy as np
13
 
14
  from rag.logger import log_security_event
15
  from rag.retrieve import RetrievedPassage
16
+ from rag.index import embed_texts
17
 
18
  # =========================================================================
19
  # Enums & data models
 
118
  # 3. DOMAIN ENFORCEMENT
119
  # =========================================================================
120
 
121
+ _DOMAIN_ANCHOR_TEXT = (
122
+ "electricity grid operations power distribution high voltage transmission lines "
123
+ "substation maintenance electrical safety arc flash ppe lockout tagout circuits "
124
+ "breakers transformers utility lineman restoration outage fault generator "
125
+ "switchgear solar clearance grounding"
126
+ )
127
+
128
+ _domain_anchor_vec = None
129
+
130
+ def _is_on_domain(query: str, threshold: float = 0.20) -> bool:
131
+ """Check if query is semantically similar to the grid domain anchor text."""
132
+ global _domain_anchor_vec
133
+
134
+ if _domain_anchor_vec is None:
135
+ # Lazy load and normalise the anchor text vector
136
+ vec = embed_texts([_DOMAIN_ANCHOR_TEXT])
137
+ _domain_anchor_vec = vec / np.linalg.norm(vec, axis=1, keepdims=True)
138
+
139
+ # Embed and normalise the user query
140
+ query_vec = embed_texts([query])
141
+ query_vec = query_vec / np.linalg.norm(query_vec, axis=1, keepdims=True)
142
+
143
+ # Compute cosine similarity
144
+ similarity = float(np.dot(_domain_anchor_vec, query_vec.T)[0][0])
145
+ return similarity >= threshold
146
 
147
 
148
  # =========================================================================