import datetime
import numbers
from collections.abc import Callable
from logging import getLogger
from typing import Any, TypeVar

from sqlalchemy import ColumnElement, Select, and_, case, cast, literal, not_, or_
from sqlalchemy.types import Numeric

from ..exceptions import FilterError
from .formatting import ILIKE_ESCAPE_CHAR, escape_ilike_pattern, parse_datetime_iso

logger = getLogger(__name__)

# Type variable for SQLAlchemy model classes
T = TypeVar("T")

# Maximum nesting depth of logical operators (AND/OR/NOT) accepted in a filter.
MAX_FILTER_DEPTH = 5

# Keys of a filter dict that denote logical operators rather than fields.
LOGICAL_OPERATORS = frozenset({"AND", "OR", "NOT"})

# Module-level constants for comparison operators
COMPARISON_OPERATORS = {
    "gte",
    "lte",
    "gt",
    "lt",
    "ne",
    "in",
    "contains",
    "icontains",
}

NUMERIC_OPERATORS = {"gte", "lte", "gt", "lt", "ne"}

# Internal names of the JSONB columns. Dict values on these columns are matched
# with JSONB containment, and the "contains" operator uses containment instead
# of ILIKE.
JSONB_COLUMNS = frozenset({"h_metadata", "configuration", "internal_metadata"})

# SQL expression builders for NUMERIC_OPERATORS, shared by plain columns and
# JSONB sub-fields.
_SQL_COMPARATORS: dict[str, Callable[[Any, Any], ColumnElement[bool]]] = {
    "gte": lambda lhs, rhs: lhs >= rhs,
    "lte": lambda lhs, rhs: lhs <= rhs,
    "gt": lambda lhs, rhs: lhs > rhs,
    "lt": lambda lhs, rhs: lhs < rhs,
    "ne": lambda lhs, rhs: lhs != rhs,
}

# External (user-facing) field name -> internal model attribute, used for every
# model other than Message and Document.
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING = {
    "id": "name",
    "created_at": "created_at",
    "is_active": "is_active",
    "workspace_id": "workspace_name",
    "session_id": "session_name",
    "peer_id": "peer_name",
    "metadata": "h_metadata",
}

# External field name -> internal attribute for the Message model.
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_MESSAGES = {
    "workspace_id": "workspace_name",
    "session_id": "session_name",
    "peer_id": "peer_name",
    "token_count": "token_count",
    "created_at": "created_at",
    "metadata": "h_metadata",
}

# External field name -> internal attribute for the Document model. Keys missing
# from this table fall back to the key itself (see _build_field_condition).
ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_DOCUMENTS = {
    "session_id": "session_name",
    "workspace_id": "workspace_name",
    "observer_id": "observer",
    "observed_id": "observed",
    "metadata": "internal_metadata",
}


def apply_filter(
    stmt: Select[tuple[T]], model_class: type[T], filters: dict[str, Any] | None = None
) -> Select[tuple[T]]:
    """
    Apply an advanced filter to a SQL statement based on a filter dictionary.

    Supports logical operators (AND, OR, NOT), comparison operators (gte, lte,
    gt, lt, ne, contains, icontains, in), and the wildcard value "*".

    Field names are the user-facing (external) names, not the model's column
    names: e.g. `id` / `*_id` map to the internal `name` / `*_name` columns and
    `metadata` maps to `h_metadata` (`internal_metadata` for documents). See the
    ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING* tables for the exact mapping.

    Examples:
        # Simple filters (backward compatible)
        {"peer_id": "alice", "metadata": {"type": "user"}}

        # Logical operators
        {"AND": [{"peer_id": "alice"}, {"created_at": {"gte": "2024-01-01"}}]}
        {"OR": [{"peer_id": "alice"}, {"peer_id": "bob"}]}
        {"NOT": [{"peer_id": "alice"}]}

        # Comparison operators
        {"created_at": {"gte": "2024-01-01", "lte": "2024-12-31"}}
        {"peer_id": {"in": ["alice", "bob"]}}

        # Wildcards (matches everything for that field)
        {"peer_id": "*"}

    Args:
        stmt: SQLAlchemy Select statement to modify
        model_class: SQLAlchemy model class for column access
        filters: Optional filter dictionary

    Returns:
        Modified Select statement with the filter applied if provided

    Raises:
        FilterError: When the filter contains invalid configuration or values
    """
    if filters is None:
        return stmt

    conditions = _build_filter_conditions(filters, model_class)
    if conditions is not None:
        stmt = stmt.where(conditions)
    return stmt


def _build_filter_conditions(
    filter_dict: dict[str, Any], model_class: type[Any], *, _depth: int = 0
) -> ColumnElement[bool] | None:
    """
    Recursively build filter conditions from a filter dictionary.

    Logical operators are combined with the field conditions using AND. A
    sub-filter that imposes no constraint (empty, or wildcard-only) is ignored
    under AND and NOT; under OR it makes the whole OR unconstrained, because one
    match-everything branch means every row satisfies the disjunction.

    Args:
        filter_dict: Filter dictionary that may contain logical operators
        model_class: SQLAlchemy model class for column access
        _depth: Current nesting depth (internal; used to bound recursion)

    Returns:
        SQLAlchemy condition object, or None if the filter imposes no constraint

    Raises:
        FilterError: If the filter is not a dict, nests too deeply, or contains
            invalid operators, fields, or values
    """
    if _depth > MAX_FILTER_DEPTH:
        raise FilterError(
            f"Filter nesting exceeds maximum depth of {MAX_FILTER_DEPTH}"
        )
    if not isinstance(filter_dict, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
        # Sub-filters come straight from user input (e.g. {"AND": ["x"]}); reject
        # them with a FilterError instead of failing on .items() below.
        raise FilterError(
            f"Filter must be a dictionary, got {type(filter_dict).__name__}"
        )

    conditions: list[ColumnElement[bool]] = []

    if "AND" in filter_dict:
        and_parts = [
            sub_condition
            for sub_condition in _build_sub_conditions(
                "AND", filter_dict["AND"], model_class, _depth
            )
            if sub_condition is not None
        ]
        if and_parts:
            conditions.append(and_(*and_parts))

    if "OR" in filter_dict:
        or_results = _build_sub_conditions(
            "OR", filter_dict["OR"], model_class, _depth
        )
        or_parts = [c for c in or_results if c is not None]
        # Only constrain when every branch constrains: dropping a
        # match-everything branch would wrongly narrow the result.
        if or_parts and len(or_parts) == len(or_results):
            conditions.append(or_(*or_parts))

    if "NOT" in filter_dict:
        if filter_dict["NOT"] is None:
            raise FilterError("NOT operator cannot be None")
        # Negate each sub-filter individually, then require all negations.
        # NOTE(review): a match-everything sub-filter is ignored here rather
        # than making NOT match nothing; kept for backward compatibility.
        not_parts = [
            not_(sub_condition)
            for sub_condition in _build_sub_conditions(
                "NOT", filter_dict["NOT"], model_class, _depth
            )
            if sub_condition is not None
        ]
        if not_parts:
            conditions.append(and_(*not_parts))

    # Field-level conditions (logical operator keys were handled above).
    for key, value in filter_dict.items():
        if key in LOGICAL_OPERATORS:
            continue
        condition = _build_field_condition(key, value, model_class)
        if condition is not None:
            conditions.append(condition)

    return _combine_conditions_with_and(conditions)


def _build_sub_conditions(
    operator: str, sub_filters: Any, model_class: type[Any], depth: int
) -> list[ColumnElement[bool] | None]:
    """
    Build one condition per sub-filter of a logical operator.

    Args:
        operator: Logical operator name ("AND", "OR" or "NOT"), used in errors
        sub_filters: The operator's operand; must be a list of filter dicts
        model_class: SQLAlchemy model class for column access
        depth: Nesting depth of the filter dict that holds the operator

    Returns:
        One entry per sub-filter; None means that sub-filter imposes no constraint

    Raises:
        FilterError: If the operand is not a list or a sub-filter is invalid
    """
    if not isinstance(sub_filters, list):
        raise FilterError(
            f"{operator} operator must contain a list, got {type(sub_filters).__name__}"
        )
    return [
        _build_filter_conditions(
            sub_filter,  # pyright: ignore[reportUnknownArgumentType]
            model_class,
            _depth=depth + 1,
        )
        for sub_filter in sub_filters  # pyright: ignore[reportUnknownVariableType]
    ]


def _build_field_condition(
    key: str, value: Any, model_class: type[Any]
) -> ColumnElement[bool] | None:
    """
    Build the condition for a single field of a filter dict.

    Args:
        key: External (user-facing) field name
        value: A plain value, the wildcard "*", a dict of comparison operators,
            or (for JSONB columns) a dict of nested field conditions
        model_class: SQLAlchemy model class

    Returns:
        SQLAlchemy condition object, or None if the value imposes no constraint

    Raises:
        FilterError: If the field may not be filtered on or is missing on the model
    """
    model_name = model_class.__name__
    if model_name == "Message":
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_MESSAGES.get(key)
    elif model_name == "Document":
        # Unmapped keys fall back to the key itself (intended for internal use).
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING_DOCUMENTS.get(
            key, key
        )
    else:
        column_name = ALLOWED_EXTERNAL_TO_INTERNAL_COLUMN_MAPPING.get(key)

    if column_name is None:
        raise FilterError(
            f"Column '{key}' is not allowed to be filtered on or does not exist on {model_name}"
        )
    if not hasattr(model_class, column_name):
        raise FilterError(f"Column '{key}' does not exist on {model_name}")
    column = getattr(model_class, column_name)

    # Wildcard matches everything, so no condition is needed.
    if value == "*":
        return None

    is_jsonb_column = column_name in JSONB_COLUMNS
    if isinstance(value, dict):
        # Any known operator key marks this as a comparison dict.
        if any(op_key in COMPARISON_OPERATORS for op_key in value):  # pyright: ignore
            return _build_comparison_conditions(column, column_name, value)  # pyright: ignore
        # A plain dict on a JSONB column may carry nested per-key comparisons.
        if is_jsonb_column:
            return _build_nested_metadata_conditions(column, value)  # pyright: ignore
        return column == value

    if is_jsonb_column:
        return column.contains(value)
    return column == value


def _safe_numeric_cast(
    column_accessor: ColumnElement[Any], op_value: Any
) -> tuple[ColumnElement[Any], Any]:
    """
    Prepare a JSONB text accessor and a value for an ordering/inequality comparison.

    - bool values are compared with the strings "true"/"false", which is how
      JSONB booleans read back when extracted with ->>.
    - Numbers, and values that parse as int or float, are compared numerically
      against the accessor cast to NUMERIC (empty string and NULL become NULL).
    - Other strings (e.g. dates like "2024-02-01") are compared as text, i.e.
      lexicographically.

    Args:
        column_accessor: SQLAlchemy JSONB column accessor (.astext)
        op_value: The value to compare against

    Returns:
        (accessor, value) pair ready to be combined with a comparison operator

    Raises:
        FilterError: If op_value is neither numeric nor a string, or the cast
            expression cannot be built
    """
    try:
        if isinstance(op_value, bool):
            return column_accessor, str(op_value).lower()

        # NOTE(review): only "" and NULL are mapped to NULL; any other
        # non-numeric text still reaches CAST(... AS NUMERIC), which PostgreSQL
        # is expected to reject at query time - TODO confirm / handle.
        safe_cast = case(
            (column_accessor == "", literal(None)),  # Empty string -> NULL
            (column_accessor.is_(None), literal(None)),  # NULL -> NULL
            else_=cast(column_accessor, Numeric()),
        )

        if isinstance(op_value, int | float):
            return safe_cast, op_value

        # Try int first, then float, so integral strings stay integers.
        for parse in (int, float):
            try:
                return safe_cast, parse(op_value)
            except (ValueError, TypeError):
                continue

        if isinstance(op_value, str):
            return column_accessor, str(op_value)
        raise FilterError(
            f"Invalid value for numeric operator: {op_value}. Expected a number, got {type(op_value).__name__}"
        )
    except FilterError:
        # Already descriptive; don't bury it under another "Failed to ..." prefix.
        raise
    except Exception as e:
        raise FilterError(
            f"Failed to process numeric cast for value '{op_value}': {str(e)}"
        ) from e


def _is_non_string_iterable(value: Any) -> bool:
    """Return True if value is iterable and is not a str/bytes (an 'in' operand)."""
    return hasattr(value, "__iter__") and not isinstance(value, str | bytes)


def _invalid_in_operand_error(op_value: Any) -> FilterError:
    """Build the FilterError raised when an 'in' operand is not a collection."""
    return FilterError(
        f"Invalid value for 'in' operator: {op_value}. Expected an iterable (list, tuple, set), got {type(op_value).__name__}"
    )


def _jsonb_text(value: Any) -> str:
    """Render an 'in' operand as text for comparison against a ->> accessor.

    Booleans are lowercased to match JSONB's "true"/"false" text form.
    """
    return str(value).lower() if isinstance(value, bool) else str(value)


def _build_comparison_condition(
    column: Any, field_name: str, operator: str, op_value: Any
) -> ColumnElement[bool] | None:
    """
    Build a single comparison condition for one key inside a JSONB column.

    Args:
        column: SQLAlchemy JSONB column object
        field_name: Name of the key inside the JSONB column
        operator: Comparison operator
        op_value: Value to compare against

    Returns:
        SQLAlchemy condition object, or None if the value is a wildcard

    Raises:
        FilterError: On an unsupported operator or an invalid value
    """
    if operator not in COMPARISON_OPERATORS:
        raise FilterError(f"Unsupported comparison operator: {operator}")

    # Wildcard matches everything, so no condition is needed.
    if op_value == "*":
        return None

    field_accessor = column[field_name].astext

    if operator in NUMERIC_OPERATORS:
        try:
            safe_accessor, safe_value = _safe_numeric_cast(field_accessor, op_value)
            return _SQL_COMPARATORS[operator](safe_accessor, safe_value)
        except FilterError:
            raise
        except Exception as e:
            raise FilterError(
                f"Failed to build numeric comparison condition for operator '{operator}' with value '{op_value}': {str(e)}"
            ) from e

    if operator == "in":
        if not _is_non_string_iterable(op_value):
            raise _invalid_in_operand_error(op_value)
        # A wildcard anywhere in the list matches everything.
        if "*" in op_value:
            return None
        # ->> yields text, so compare against text renderings of the operands.
        return field_accessor.in_([_jsonb_text(v) for v in op_value])

    if operator in ("contains", "icontains"):
        escaped_value = escape_ilike_pattern(str(op_value))
        return field_accessor.ilike(f"%{escaped_value}%", escape=ILIKE_ESCAPE_CHAR)

    return None


def _build_nested_metadata_conditions(
    column: Any, metadata_dict: dict[str, Any]
) -> ColumnElement[bool] | None:
    """
    Build conditions for nested JSONB fields, optionally with comparison operators.

    Keys whose value is an operator dict get per-operator comparisons. Other keys
    are matched with JSONB containment ({key: value}), and "*" values are skipped.

    Args:
        column: SQLAlchemy JSONB column object
        metadata_dict: Dictionary of nested field conditions

    Returns:
        Combined SQLAlchemy condition object, or None if nothing constrains
    """
    conditions: list[ColumnElement[bool]] = []

    for field_name, field_value in metadata_dict.items():
        has_operators = isinstance(field_value, dict) and any(
            op in COMPARISON_OPERATORS for op in field_value  # pyright: ignore
        )
        if has_operators:
            field_conditions: list[ColumnElement[bool]] = []
            for operator, op_value in field_value.items():  # pyright: ignore
                condition = _build_comparison_condition(
                    column,
                    field_name,
                    operator,  # pyright: ignore
                    op_value,
                )
                if condition is not None:
                    field_conditions.append(condition)
            combined = _combine_conditions_with_and(field_conditions)
            if combined is not None:
                conditions.append(combined)
            continue

        # Wildcard matches everything, so no condition is needed.
        if field_value == "*":
            continue
        # Plain equality via JSONB containment (also matches nested objects).
        conditions.append(column.contains({field_name: field_value}))

    return _combine_conditions_with_and(conditions)


def _combine_conditions_with_and(
    conditions: list[ColumnElement[bool]],
) -> ColumnElement[bool] | None:
    """
    Combine a list of conditions with AND logic.

    Args:
        conditions: List of SQLAlchemy condition objects

    Returns:
        None for an empty list, the sole condition for a single-element list,
        otherwise the AND of all conditions
    """
    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]
    return and_(*conditions)


def _column_python_type(column: Any) -> type | None:
    """
    Return the Python type mapped to the column's SQL type, or None if unknown.

    SQLAlchemy types without a Python mapping raise NotImplementedError from
    python_type, and non-column attributes may lack .type entirely; hasattr()
    only guards against AttributeError, so both are caught explicitly here.
    """
    try:
        python_type = column.type.python_type
    except (AttributeError, NotImplementedError):
        return None
    return python_type if isinstance(python_type, type) else None


def _build_comparison_conditions(
    column: Any, column_name: str, comparisons: dict[str, Any]
) -> ColumnElement[bool] | None:
    """
    Build comparison conditions for a single (non-nested) column.

    String values on datetime columns are parsed and validated for every
    operator and bound as datetime objects. Values for gte/lte/gt/lt/ne are
    converted to float only when the column is numeric (or its type cannot be
    determined); on text, boolean, date, ... columns they are compared as given.

    Args:
        column: SQLAlchemy column object
        column_name: Internal name of the column
        comparisons: Dictionary of comparison operators and values

    Returns:
        Combined SQLAlchemy condition object, or None if nothing constrains

    Raises:
        FilterError: On an unsupported operator, an unparseable datetime or
            numeric value, or a non-collection operand for "in"
    """
    conditions: list[ColumnElement[bool]] = []

    python_type = _column_python_type(column)
    is_datetime_column = python_type is not None and issubclass(
        python_type, datetime.datetime
    )
    # Unknown column types keep the historical behaviour of coercing to float;
    # bool is a Number subclass but boolean columns must compare with bools.
    coerce_to_number = python_type is None or (
        issubclass(python_type, numbers.Number) and not issubclass(python_type, bool)
    )

    for operator, op_value in comparisons.items():
        if operator not in COMPARISON_OPERATORS:
            raise FilterError(f"Unsupported comparison operator: {operator}")

        # Wildcard matches everything, so no condition is needed.
        if op_value == "*":
            continue

        # Datetime strings are validated whatever the operator; the parsed
        # datetime is bound as a parameter rather than interpolated into SQL.
        if is_datetime_column and isinstance(op_value, str):
            casted_value: Any = _validate_datetime_string(op_value)
            if casted_value is None:
                raise FilterError(f"Invalid datetime value: {op_value}")
        elif operator in NUMERIC_OPERATORS and coerce_to_number:
            try:
                casted_value = float(op_value)
            except (ValueError, TypeError):
                raise FilterError(
                    f"Invalid numeric value: {op_value}. Expected a number, got {type(op_value).__name__}"
                ) from None
        else:
            casted_value = op_value

        condition: ColumnElement[bool] | None = None
        if operator in NUMERIC_OPERATORS:
            condition = _SQL_COMPARATORS[operator](column, casted_value)
        elif operator == "in":
            if not _is_non_string_iterable(op_value):
                raise _invalid_in_operand_error(op_value)
            # A wildcard anywhere in the list matches everything.
            if "*" in op_value:
                continue
            if is_datetime_column:
                casted_values: list[Any] = []
                for val in op_value:
                    if isinstance(val, str):
                        parsed = _validate_datetime_string(val)
                        if parsed is None:
                            raise FilterError(f"Invalid datetime value in list: {val}")
                        casted_values.append(parsed)
                    else:
                        casted_values.append(val)
                # An empty list yields an empty IN (matches nothing), exactly as
                # for non-datetime columns.
                condition = column.in_(casted_values)
            else:
                condition = column.in_(list(op_value))
        elif operator == "contains" and column_name in JSONB_COLUMNS:
            # JSONB containment for every JSONB column, not only h_metadata.
            condition = column.contains(op_value)
        else:
            # "contains" on a text column, or "icontains": escaped ILIKE pattern.
            escaped_value = escape_ilike_pattern(str(op_value))
            condition = column.ilike(f"%{escaped_value}%", escape=ILIKE_ESCAPE_CHAR)

        if condition is not None:
            conditions.append(condition)

    return _combine_conditions_with_and(conditions)


def _validate_datetime_string(value: str) -> datetime.datetime | None:
    """
    Parse and validate a datetime string supplied in a filter.

    Timezone-aware ISO 8601 is tried first via parse_datetime_iso; a small set
    of naive layouts is then accepted and interpreted as UTC. The resulting
    datetime object is bound as a query parameter by the caller instead of the
    raw string being interpolated into SQL.

    Args:
        value: String value to validate as a datetime

    Returns:
        Parsed datetime object if valid, None if invalid
    """
    value = value.strip()

    try:
        return parse_datetime_iso(value)
    except ValueError:
        pass

    # Naive fallbacks, assumed to be UTC for compatibility.
    naive_formats = (
        "%Y-%m-%dT%H:%M:%S",  # 2024-01-01T12:00:00
        "%Y-%m-%dT%H:%M:%S.%f",  # 2024-01-01T12:00:00.123456
        "%Y-%m-%d %H:%M:%S",  # 2024-01-01 12:00:00
        "%Y-%m-%d %H:%M:%S.%f",  # 2024-01-01 12:00:00.123456
        "%Y-%m-%d",  # 2024-01-01 (start of day)
    )
    for fmt in naive_formats:
        try:
            parsed = datetime.datetime.strptime(value, fmt)
        except ValueError:
            continue
        return parsed.replace(tzinfo=datetime.timezone.utc)

    # Invalid datetime: let the caller decide how to report it.
    return None