Spaces:
Sleeping
Sleeping
Nashid-Noor commited on
Commit ·
4a18f13
1
Parent(s): 880df70
Upgrade domain guardrail to Vector Semantic Similarity
Browse files- rag/guardrails.py +27 -18
rag/guardrails.py
CHANGED
|
@@ -9,9 +9,11 @@ import re
|
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
from enum import Enum
|
| 11 |
from typing import Optional
|
|
|
|
| 12 |
|
| 13 |
from rag.logger import log_security_event
|
| 14 |
from rag.retrieve import RetrievedPassage
|
|
|
|
| 15 |
|
| 16 |
# =========================================================================
|
| 17 |
# Enums & data models
|
|
@@ -116,24 +118,31 @@ def _redact_pii(query: str) -> str:
|
|
| 116 |
# 3. DOMAIN ENFORCEMENT
|
| 117 |
# =========================================================================
|
| 118 |
|
| 119 |
-
|
| 120 |
-
"
|
| 121 |
-
"substation
|
| 122 |
-
"
|
| 123 |
-
"
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
"
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
# =========================================================================
|
|
|
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
from enum import Enum
|
| 11 |
from typing import Optional
|
| 12 |
+
import numpy as np
|
| 13 |
|
| 14 |
from rag.logger import log_security_event
|
| 15 |
from rag.retrieve import RetrievedPassage
|
| 16 |
+
from rag.index import embed_texts
|
| 17 |
|
| 18 |
# =========================================================================
|
| 19 |
# Enums & data models
|
|
|
|
| 118 |
# 3. DOMAIN ENFORCEMENT
|
| 119 |
# =========================================================================
|
| 120 |
|
| 121 |
+
_DOMAIN_ANCHOR_TEXT = (
|
| 122 |
+
"electricity grid operations power distribution high voltage transmission lines "
|
| 123 |
+
"substation maintenance electrical safety arc flash ppe lockout tagout circuits "
|
| 124 |
+
"breakers transformers utility lineman restoration outage fault generator "
|
| 125 |
+
"switchgear solar clearance grounding"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
_domain_anchor_vec = None
|
| 129 |
+
|
| 130 |
+
def _is_on_domain(query: str, threshold: float = 0.20) -> bool:
|
| 131 |
+
"""Check if query is semantically similar to the grid domain anchor text."""
|
| 132 |
+
global _domain_anchor_vec
|
| 133 |
+
|
| 134 |
+
if _domain_anchor_vec is None:
|
| 135 |
+
# Lazy load and normalise the anchor text vector
|
| 136 |
+
vec = embed_texts([_DOMAIN_ANCHOR_TEXT])
|
| 137 |
+
_domain_anchor_vec = vec / np.linalg.norm(vec, axis=1, keepdims=True)
|
| 138 |
+
|
| 139 |
+
# Embed and normalise the user query
|
| 140 |
+
query_vec = embed_texts([query])
|
| 141 |
+
query_vec = query_vec / np.linalg.norm(query_vec, axis=1, keepdims=True)
|
| 142 |
+
|
| 143 |
+
# Compute cosine similarity
|
| 144 |
+
similarity = float(np.dot(_domain_anchor_vec, query_vec.T)[0][0])
|
| 145 |
+
return similarity >= threshold
|
| 146 |
|
| 147 |
|
| 148 |
# =========================================================================
|