File size: 3,269 Bytes
53ea588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD 2-Clause License

"""This module defines the `GuardrailProcessor` for handling blocked words/topics in user queries.

Ensures that specific queries can be detected and blocked from further processing.
"""

import re

from loguru import logger
from pipecat.frames.frames import (
    Frame,
    TranscriptionFrame,
    TTSSpeakFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor


class GuardrailProcessor(FrameProcessor):
    """Blocks queries containing specified words or topics.

    Incoming `TranscriptionFrame`s are checked against a list of blocked words.
    When a match is found the transcription is dropped and a `TTSSpeakFrame`
    carrying the block message is pushed instead; every other frame is
    forwarded unchanged.

    Args:
        blocked_words (list[str]): List of words that trigger query blocking.
            Matching is case-insensitive and only matches whole words/phrases
            (the term must not be directly preceded or followed by a word
            character). Empty or whitespace-only entries are ignored.
        block_message (str): Message returned when a query is blocked.
            Defaults to "I am not allowed to answer this question".
        **kwargs: Additional arguments passed to parent FrameProcessor.
    """

    def __init__(self, blocked_words=None, block_message: str = "I am not allowed to answer this question", **kwargs):
        """Initializes the GuardrailProcessor.

        Args:
            blocked_words (list[str] or None): A list of words for which queries should be blocked.
                A single string is treated as a one-element list.
            block_message (str): The message to return when a blocked query is detected.
            **kwargs: Additional keyword arguments passed to the parent class initializer.
        """
        super().__init__(**kwargs)
        # A bare string would otherwise be iterated character by character,
        # turning every letter into its own blocked "word".
        if isinstance(blocked_words, str):
            blocked_words = [blocked_words]
        # Normalize to lowercase for case-insensitive matching and trim padding.
        normalized = (word.strip().lower() for word in (blocked_words or []))
        # Drop empty entries: an empty pattern would match almost any query.
        self._blocked_words = [word for word in normalized if word]
        self._block_message = block_message

    def is_query_blocked(self, query: str) -> bool:
        """Checks if the given query contains any blocked word.

        Args:
            query (str): The query text to evaluate. Empty or None is never blocked.

        Returns:
            bool: True if query contains any blocked word, False otherwise.
        """
        if not query:
            return False

        query_lower = query.lower()  # Convert query to lowercase for case-insensitive matching
        for word in self._blocked_words:
            if not word:
                continue
            # Lookarounds instead of `\b`: `\b` requires a word character on one
            # side, so terms that start or end with punctuation (e.g. "c++")
            # could never match at the end of a query or before a space.
            if re.search(rf"(?<!\w){re.escape(word)}(?!\w)", query_lower):
                logger.debug(f"Query: {query} contains blocked word: {word}")
                return True
        return False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Processes incoming frames and blocks those containing restricted content.

        Args:
            frame (Frame): The incoming frame to process.
            direction (FrameDirection): The direction the frame is traveling.
        """
        await super().process_frame(frame, direction)

        # Handle user queries
        if isinstance(frame, TranscriptionFrame) and self.is_query_blocked(frame.text):
            logger.debug(f"Blocked query detected: {frame.text}")
            # Respond with blocking message
            await self.push_frame(TTSSpeakFrame(self._block_message))
            return  # Stop further propagation of the blocked query

        # For all other frames, proceed as usual
        await self.push_frame(frame, direction)