""" Sensitive Topic Detector Detects sensitive topics in queries and provides reframe guidance. Uses trigger patterns from sensitive_topics.json to identify: - Controversial claims - Allegations - Sensitive framings Returns appropriate response strategies and suggested responses. """ from dataclasses import dataclass from typing import List, Optional from .knowledge_base import KnowledgeBase @dataclass class SensitiveMatch: """Result from sensitive topic detection""" entity_id: str entity_name: str topic_type: str severity: str # "high", "medium", "low" matched_pattern: str problematic_framing: str suggested_response: str strategy: str # "provide_context", "diplomatic", "decline" tone: str class SensitiveTopicDetector: """ Detects sensitive topics in queries using trigger patterns. This is a simple pattern-matching approach (not LLM-based) for fast detection. For more nuanced detection, use LLM verification. """ def __init__(self, knowledge_base: KnowledgeBase = None, debug: bool = False): """ Initialize detector. Args: knowledge_base: UAEKnowledgeBase instance (will load if not provided) debug: Enable debug output """ self.debug = debug self.kb = knowledge_base or KnowledgeBase(debug=debug) # Build pattern index self._patterns: List[tuple] = [] # (pattern, topic_dict) self._build_pattern_index() def _build_pattern_index(self) -> None: """Build searchable pattern index""" for topic in self.kb._sensitive_topics_raw: for pattern in topic.get("trigger_patterns", []): self._patterns.append((pattern.lower(), topic)) # Sort by pattern length (longer patterns first for more specific matches) self._patterns.sort(key=lambda x: len(x[0]), reverse=True) if self.debug: print(f"✅ SensitiveTopicDetector: {len(self._patterns)} patterns indexed") def detect(self, query: str) -> List[SensitiveMatch]: """ Detect sensitive topics in a query. Args: query: User query text Returns: List of SensitiveMatch objects for each triggered pattern """ query_lower = query.lower() matches = [] seen_topics = set() # Avoid duplicates for pattern, topic in self._patterns: if pattern in query_lower: topic_id = topic.get("id", "") # Skip if already matched this topic if topic_id in seen_topics: continue seen_topics.add(topic_id) response = topic.get("appropriate_response", {}) match = SensitiveMatch( entity_id=topic.get("source_entity_id", ""), entity_name=topic.get("source_entity_name", ""), topic_type=topic.get("topic_type", ""), severity=topic.get("severity", "medium"), matched_pattern=pattern, problematic_framing=topic.get("problematic_framing", ""), suggested_response=response.get("suggested_response", ""), strategy=response.get("strategy", "provide_context"), tone=response.get("tone", "factual_neutral"), ) matches.append(match) if self.debug and matches: print(f"🚨 Sensitive topics detected: {len(matches)}") for m in matches: print(f" - {m.matched_pattern} ({m.severity})") return matches def is_sensitive(self, query: str) -> bool: """Quick check if query contains any sensitive patterns""" query_lower = query.lower() for pattern, _ in self._patterns: if pattern in query_lower: return True return False def get_reframe_guidance(self, query: str) -> Optional[dict]: """ Get reframe guidance for a sensitive query. Returns None if query is not sensitive. Returns dict with guidance if sensitive. """ matches = self.detect(query) if not matches: return None # Get the highest severity match severity_order = {"high": 3, "medium": 2, "low": 1} matches.sort(key=lambda m: severity_order.get(m.severity, 0), reverse=True) primary = matches[0] return { "is_sensitive": True, "severity": primary.severity, "entity_name": primary.entity_name, "strategy": primary.strategy, "suggested_response": primary.suggested_response, "tone": primary.tone, "all_matches": [ { "pattern": m.matched_pattern, "entity": m.entity_name, "severity": m.severity, } for m in matches ], } def get_statistics(self) -> dict: """Get detector statistics""" severity_counts = {"high": 0, "medium": 0, "low": 0} type_counts = {} for _, topic in self._patterns: sev = topic.get("severity", "medium") severity_counts[sev] = severity_counts.get(sev, 0) + 1 ttype = topic.get("topic_type", "unknown") type_counts[ttype] = type_counts.get(ttype, 0) + 1 return { "total_patterns": len(self._patterns), "by_severity": severity_counts, "by_type": type_counts, }