# uae-kb/ir/sensitive_detector.py
# Provenance: "Initial UAE Knowledge System demo" (commit 8124364, author Demon1212122)
"""
Sensitive Topic Detector
Detects sensitive topics in queries and provides reframe guidance.
Uses trigger patterns from sensitive_topics.json to identify:
- Controversial claims
- Allegations
- Sensitive framings
Returns appropriate response strategies and suggested responses.
"""
from dataclasses import dataclass
from typing import List, Optional
from .knowledge_base import KnowledgeBase
@dataclass
class SensitiveMatch:
    """Result from sensitive topic detection.

    One instance is produced per distinct topic triggered by a query
    (see SensitiveTopicDetector.detect); fields are populated from the
    topic entries loaded out of sensitive_topics.json.
    """
    entity_id: str  # topic's "source_entity_id" ("" if absent)
    entity_name: str  # topic's "source_entity_name" ("" if absent)
    topic_type: str  # topic's "topic_type" category ("" if absent)
    severity: str  # "high", "medium", "low"
    matched_pattern: str  # the lowercased trigger pattern that fired
    problematic_framing: str  # framing flagged as problematic in the topic data
    suggested_response: str  # pre-authored response text ("" if absent)
    strategy: str  # "provide_context", "diplomatic", "decline"
    tone: str  # response tone, defaults to "factual_neutral"
class SensitiveTopicDetector:
    """
    Detects sensitive topics in queries using trigger patterns.

    This is a simple substring-matching approach (not LLM-based) for
    fast detection. For more nuanced detection, use LLM verification.
    """

    def __init__(self, knowledge_base: Optional[KnowledgeBase] = None, debug: bool = False):
        """
        Initialize detector.

        Args:
            knowledge_base: KnowledgeBase instance (a new one is created
                if not provided).
            debug: Enable debug output.
        """
        self.debug = debug
        self.kb = knowledge_base or KnowledgeBase(debug=debug)
        # Flat index of (lowercased trigger pattern, owning topic dict) pairs.
        self._patterns: List[tuple] = []
        self._build_pattern_index()

    def _build_pattern_index(self) -> None:
        """Build the searchable (pattern, topic) index from the KB's raw topics."""
        # NOTE(review): reads the KB's private `_sensitive_topics_raw`
        # attribute -- consider exposing a public accessor on KnowledgeBase.
        for topic in self.kb._sensitive_topics_raw:
            for pattern in topic.get("trigger_patterns", []):
                self._patterns.append((pattern.lower(), topic))
        # Longer patterns first, so the most specific pattern for a topic
        # is the one reported by detect().
        self._patterns.sort(key=lambda x: len(x[0]), reverse=True)
        if self.debug:
            print(f"✅ SensitiveTopicDetector: {len(self._patterns)} patterns indexed")

    def detect(self, query: str) -> List[SensitiveMatch]:
        """
        Detect sensitive topics in a query.

        Args:
            query: User query text

        Returns:
            One SensitiveMatch per distinct triggered topic; because the
            index is sorted longest-first, the longest matching pattern
            for each topic is the one reported.
        """
        query_lower = query.lower()
        matches: List[SensitiveMatch] = []
        seen_topics = set()  # Report each topic at most once.
        for pattern, topic in self._patterns:
            if pattern not in query_lower:
                continue
            # NOTE(review): topics missing an "id" all collapse onto "",
            # so only the first id-less topic would ever be reported --
            # confirm every entry in sensitive_topics.json carries an "id".
            topic_id = topic.get("id", "")
            if topic_id in seen_topics:
                continue
            seen_topics.add(topic_id)
            response = topic.get("appropriate_response", {})
            matches.append(SensitiveMatch(
                entity_id=topic.get("source_entity_id", ""),
                entity_name=topic.get("source_entity_name", ""),
                topic_type=topic.get("topic_type", ""),
                severity=topic.get("severity", "medium"),
                matched_pattern=pattern,
                problematic_framing=topic.get("problematic_framing", ""),
                suggested_response=response.get("suggested_response", ""),
                strategy=response.get("strategy", "provide_context"),
                tone=response.get("tone", "factual_neutral"),
            ))
        if self.debug and matches:
            print(f"🚨 Sensitive topics detected: {len(matches)}")
            for m in matches:
                print(f"   - {m.matched_pattern} ({m.severity})")
        return matches

    def is_sensitive(self, query: str) -> bool:
        """Quick check if query contains any sensitive patterns."""
        query_lower = query.lower()
        return any(pattern in query_lower for pattern, _ in self._patterns)

    def get_reframe_guidance(self, query: str) -> Optional[dict]:
        """
        Get reframe guidance for a sensitive query.

        Returns None if query is not sensitive; otherwise a dict with the
        strategy/response of the highest-severity match plus a summary of
        all matches (sorted by descending severity).
        """
        matches = self.detect(query)
        if not matches:
            return None
        # Unknown severities rank below "low" (weight 0).
        severity_order = {"high": 3, "medium": 2, "low": 1}
        matches.sort(key=lambda m: severity_order.get(m.severity, 0), reverse=True)
        primary = matches[0]
        return {
            "is_sensitive": True,
            "severity": primary.severity,
            "entity_name": primary.entity_name,
            "strategy": primary.strategy,
            "suggested_response": primary.suggested_response,
            "tone": primary.tone,
            "all_matches": [
                {
                    "pattern": m.matched_pattern,
                    "entity": m.entity_name,
                    "severity": m.severity,
                }
                for m in matches
            ],
        }

    def get_statistics(self) -> dict:
        """Get detector statistics.

        Counts are per indexed *pattern*, not per topic, so a topic with
        several trigger patterns is counted once per pattern.
        """
        severity_counts = {"high": 0, "medium": 0, "low": 0}
        type_counts: dict = {}
        for _, topic in self._patterns:
            sev = topic.get("severity", "medium")
            severity_counts[sev] = severity_counts.get(sev, 0) + 1
            ttype = topic.get("topic_type", "unknown")
            type_counts[ttype] = type_counts.get(ttype, 0) + 1
        return {
            "total_patterns": len(self._patterns),
            "by_severity": severity_counts,
            "by_type": type_counts,
        }