codebook / potato /server_utils /display_logic.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
26.6 kB
"""
Display Logic Module for Conditional Schema Branching

This module provides the core validation and evaluation logic for conditional
annotation schemas. It allows schemas to be shown or hidden based on user
responses to other schemas.

Key Components:
- DisplayLogicCondition: Represents a single condition (e.g., "schema X equals 'Yes'")
- DisplayLogicRule: Represents a complete rule with multiple conditions and AND/OR logic
- DisplayLogicValidator: Validates display_logic configurations
- DisplayLogicEvaluator: Evaluates conditions at runtime

Example Configuration:
    display_logic:
      show_when:
        - schema: contains_pii
          operator: equals
          value: "Yes"
      logic: all  # 'all' = AND, 'any' = OR
"""
import re
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple, Union
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Supported operators and their human-readable descriptions.
# Membership in this mapping is what DisplayLogicCondition and
# DisplayLogicValidator check to decide whether an operator is supported,
# and the keys (in insertion order) are listed in their error messages.
SUPPORTED_OPERATORS: Dict[str, str] = {
    # Value comparison
    "equals": "Exact value match (single value or list of values)",
    "not_equals": "Value doesn't match any specified values",
    # Collection operators
    "contains": "List/text contains value(s)",
    "not_contains": "List/text doesn't contain value(s)",
    # Regex
    "matches": "Regex pattern match",
    # Numeric comparison
    "gt": "Greater than",
    "gte": "Greater than or equal",
    "lt": "Less than",
    "lte": "Less than or equal",
    "in_range": "Value is within range (inclusive)",
    "not_in_range": "Value is outside range",
    # Emptiness (these two operators take no comparison value)
    "empty": "Field is empty or not set",
    "not_empty": "Field has a value",
    # Text length
    "length_gt": "Text length greater than",
    "length_lt": "Text length less than",
    "length_in_range": "Text length within range (inclusive)",
}
@dataclass
class DisplayLogicCondition:
    """
    One condition of a display logic rule.

    Attributes:
        schema: Name of the schema whose value is watched.
        operator: Comparison operator; must be a key of SUPPORTED_OPERATORS.
        value: Comparison operand. A two-element numeric list/tuple for the
            range operators, a number for the numeric and length operators,
            unused for "empty"/"not_empty", any non-None value otherwise.
        case_sensitive: Whether text comparisons are case-sensitive
            (default: False).
    """

    schema: str
    operator: str
    value: Any = None
    case_sensitive: bool = False

    def __post_init__(self):
        """
        Check that the operator is supported and the value fits it.

        Raises:
            ValueError: If the operator is unknown or the value does not meet
                the operator's requirements.
        """
        op = self.operator
        if op not in SUPPORTED_OPERATORS:
            raise ValueError(f"Unsupported operator: {op}. "
                             f"Supported operators: {list(SUPPORTED_OPERATORS)}")

        # Emptiness checks take no operand, so there is nothing to verify.
        if op in ("empty", "not_empty"):
            return

        operand = self.value

        if op in ("in_range", "not_in_range", "length_in_range"):
            # Range operators need exactly two numeric bounds.
            is_pair = isinstance(operand, (list, tuple)) and len(operand) == 2
            if not is_pair:
                raise ValueError(f"Operator '{op}' requires a range value "
                                 f"as [min, max], got: {operand}")
            if not all(isinstance(bound, (int, float)) for bound in operand):
                raise ValueError(f"Operator '{op}' requires numeric range values, "
                                 f"got: {operand}")
            return

        if op in ("gt", "gte", "lt", "lte", "length_gt", "length_lt"):
            # Threshold operators need a single number.
            if operand is None:
                raise ValueError(f"Operator '{op}' requires a numeric value")
            if not isinstance(operand, (int, float)):
                raise ValueError(f"Operator '{op}' requires a numeric value, "
                                 f"got: {type(operand).__name__} '{operand}'")
            return

        # Every remaining operator just needs some operand.
        if operand is None:
            raise ValueError(f"Operator '{op}' requires a value")

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the condition to a plain dictionary.

        "value" is omitted when it is None and "case_sensitive" is omitted
        when it is falsy.
        """
        payload: Dict[str, Any] = {"schema": self.schema, "operator": self.operator}
        if self.value is not None:
            payload["value"] = self.value
        if self.case_sensitive:
            payload["case_sensitive"] = True
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DisplayLogicCondition":
        """
        Build a condition from a dictionary.

        "schema" and "operator" are required keys (KeyError if missing);
        "value" defaults to None and "case_sensitive" to False.
        """
        schema = data["schema"]
        operator = data["operator"]
        return cls(
            schema=schema,
            operator=operator,
            value=data.get("value"),
            case_sensitive=data.get("case_sensitive", False),
        )
@dataclass
class DisplayLogicRule:
    """
    A display logic rule: a list of conditions plus how to combine them.

    Attributes:
        conditions: The DisplayLogicCondition objects making up the rule.
        logic: 'all' (AND) or 'any' (OR) - how the conditions are combined.
    """

    conditions: List[DisplayLogicCondition] = field(default_factory=list)
    logic: str = "all"  # 'all' = AND, 'any' = OR

    def __post_init__(self):
        """Reject any combination mode other than 'all' or 'any' (ValueError)."""
        if self.logic in ("all", "any"):
            return
        raise ValueError(f"Invalid logic type: {self.logic}. Must be 'all' or 'any'")

    def get_watched_schemas(self) -> Set[str]:
        """Return the names of all schemas referenced by this rule's conditions."""
        return set(cond.schema for cond in self.conditions)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the rule to the config layout ('show_when' + 'logic')."""
        serialized = [cond.to_dict() for cond in self.conditions]
        return {"show_when": serialized, "logic": self.logic}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DisplayLogicRule":
        """
        Build a rule from a config-format dictionary.

        A missing 'show_when' yields a rule without conditions; a missing
        'logic' defaults to 'all'.
        """
        parsed = [
            DisplayLogicCondition.from_dict(entry)
            for entry in data.get("show_when", [])
        ]
        return cls(conditions=parsed, logic=data.get("logic", "all"))
class DisplayLogicValidator:
    """
    Validates display_logic configurations in annotation schemes.

    Responsibilities:
    - Validate condition syntax and operators
    - Check that referenced schemas exist
    - Detect circular dependencies

    Malformed configurations (display_logic that is not a dict, show_when that
    is not a list, conditions that are not dicts) are reported as validation
    errors by validate(); they do not raise from the constructor or from the
    reference/cycle checks.
    """

    def __init__(self, annotation_schemes: List[Dict[str, Any]]):
        """
        Initialize the validator with all annotation schemes.

        Args:
            annotation_schemes: List of annotation scheme configurations
        """
        self.schemes = annotation_schemes
        self.schema_names = {s.get("name") for s in annotation_schemes if "name" in s}
        # Maps each named schema to the set of schemas its conditions reference.
        self.dependency_graph: Dict[str, Set[str]] = {}
        self._build_dependency_graph()

    @staticmethod
    def _extract_dependencies(display_logic: Any) -> Set[str]:
        """Collect the schema names referenced by one display_logic block.

        Anything that is not shaped like {"show_when": [ {...}, ... ]} simply
        contributes no dependencies; reporting the malformed shape is the job
        of validate().
        """
        dependencies: Set[str] = set()
        if not isinstance(display_logic, dict):
            return dependencies
        show_when = display_logic.get("show_when", [])
        if not isinstance(show_when, list):
            return dependencies
        for condition in show_when:
            if not isinstance(condition, dict) or "schema" not in condition:
                continue
            try:
                dependencies.add(condition["schema"])
            except TypeError:
                # Unhashable reference (e.g. a list); it cannot be a graph node.
                continue
        return dependencies

    def _build_dependency_graph(self) -> None:
        """Build a graph of schema dependencies for cycle detection."""
        for scheme in self.schemes:
            schema_name = scheme.get("name")
            if not schema_name:
                continue
            self.dependency_graph[schema_name] = self._extract_dependencies(
                scheme.get("display_logic")
            )

    def validate(self) -> Tuple[bool, List[str]]:
        """
        Validate all display_logic configurations.

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        errors: List[str] = []
        for scheme in self.schemes:
            schema_name = scheme.get("name", "<unnamed>")
            display_logic = scheme.get("display_logic")
            # Schemes without display_logic have nothing to validate.
            if display_logic is None:
                continue
            if not isinstance(display_logic, dict):
                errors.append(f"Schema '{schema_name}': display_logic must be a dictionary")
                continue
            # Validate structure (including empty dict - which is invalid)
            errors.extend(self._validate_structure(schema_name, display_logic))
            # Validate referenced schemas exist
            errors.extend(self._validate_references(schema_name, display_logic))
        # Check for circular dependencies
        errors.extend(self._detect_cycles())
        return len(errors) == 0, errors

    def _validate_structure(self, schema_name: str, display_logic: Dict) -> List[str]:
        """Validate the structure of a display_logic configuration.

        Returns a list of error messages (empty when the structure is valid).
        """
        errors: List[str] = []
        # Must have show_when
        if "show_when" not in display_logic:
            errors.append(f"Schema '{schema_name}': display_logic must have 'show_when' field")
            return errors
        show_when = display_logic["show_when"]
        if not isinstance(show_when, list):
            errors.append(f"Schema '{schema_name}': 'show_when' must be a list of conditions")
            return errors
        if len(show_when) == 0:
            errors.append(f"Schema '{schema_name}': 'show_when' must have at least one condition")
            return errors
        # Validate each condition
        for i, condition in enumerate(show_when):
            prefix = f"Schema '{schema_name}', condition {i+1}"
            if not isinstance(condition, dict):
                errors.append(f"{prefix}: condition must be a dictionary")
                continue
            # Required fields
            if "schema" not in condition:
                errors.append(f"{prefix}: missing required 'schema' field")
            operator = condition.get("operator")
            if "operator" not in condition:
                errors.append(f"{prefix}: missing required 'operator' field")
            elif not isinstance(operator, str) or operator not in SUPPORTED_OPERATORS:
                # The isinstance test short-circuits before the dict lookup so
                # an unhashable operator (e.g. a list) is reported, not raised.
                errors.append(f"{prefix}: unsupported operator '{operator}'. "
                              f"Supported: {list(SUPPORTED_OPERATORS.keys())}")
            # Validate operator-specific requirements
            if operator:
                errors.extend(
                    self._validate_operator_value(prefix, operator, condition.get("value"))
                )
        # Validate logic field if present
        logic = display_logic.get("logic", "all")
        if logic not in ("all", "any"):
            errors.append(f"Schema '{schema_name}': 'logic' must be 'all' or 'any', got '{logic}'")
        return errors

    def _validate_operator_value(self, prefix: str, operator: str, value: Any) -> List[str]:
        """Validate that the value is appropriate for the operator.

        Args:
            prefix: Text identifying the condition, prepended to each message.
            operator: The condition's operator.
            value: The condition's configured value (None when absent).

        Returns:
            List of error messages (empty when the value is acceptable).
        """
        errors: List[str] = []
        # Operators that don't need a value
        if operator in ("empty", "not_empty"):
            return errors
        # Range operators need [min, max]
        if operator in ("in_range", "not_in_range", "length_in_range"):
            if not isinstance(value, (list, tuple)):
                errors.append(f"{prefix}: operator '{operator}' requires a range value as [min, max]")
            elif len(value) != 2:
                errors.append(f"{prefix}: range value must have exactly 2 elements [min, max]")
            else:
                try:
                    min_val, max_val = float(value[0]), float(value[1])
                    if min_val > max_val:
                        errors.append(f"{prefix}: range min ({min_val}) is greater than max ({max_val})")
                except (ValueError, TypeError):
                    errors.append(f"{prefix}: range values must be numeric")
            return errors
        # Numeric operators need numeric values
        if operator in ("gt", "gte", "lt", "lte", "length_gt", "length_lt"):
            if value is None:
                errors.append(f"{prefix}: operator '{operator}' requires a value")
            else:
                try:
                    float(value)
                except (ValueError, TypeError):
                    errors.append(f"{prefix}: operator '{operator}' requires a numeric value")
            return errors
        # Regex operator needs a valid pattern
        if operator == "matches":
            if value is None:
                errors.append(f"{prefix}: operator 'matches' requires a regex pattern")
            elif not isinstance(value, str):
                # re.compile() raises TypeError (not re.error) for non-string
                # input such as a bare number, so report it explicitly.
                errors.append(f"{prefix}: operator 'matches' requires a regex pattern string, "
                              f"got {type(value).__name__}")
            else:
                try:
                    re.compile(value)
                except re.error as e:
                    errors.append(f"{prefix}: invalid regex pattern '{value}': {e}")
            return errors
        # Other operators just need a non-None value
        if value is None:
            errors.append(f"{prefix}: operator '{operator}' requires a value")
        return errors

    def _validate_references(self, schema_name: str, display_logic: Dict) -> List[str]:
        """Validate that all referenced schemas exist.

        Malformed show_when entries are skipped here; they are reported by
        _validate_structure.
        """
        errors: List[str] = []
        show_when = display_logic.get("show_when", [])
        if not isinstance(show_when, list):
            return errors
        for i, condition in enumerate(show_when):
            if not isinstance(condition, dict):
                continue
            ref_schema = condition.get("schema")
            if not ref_schema:
                continue
            try:
                known = ref_schema in self.schema_names
            except TypeError:
                # Unhashable reference: it cannot match any schema name.
                known = False
            if not known:
                errors.append(
                    f"Schema '{schema_name}', condition {i+1}: references unknown schema '{ref_schema}'"
                )
        return errors

    def _detect_cycles(self) -> List[str]:
        """Detect circular dependencies using DFS.

        Returns one error message per cycle found. Once a node has been
        visited it is not explored again, so each traversal root reports at
        most one cycle.
        """
        errors: List[str] = []
        visited = set()
        rec_stack = set()  # nodes on the current DFS path only

        def dfs(node: str, path: List[str]) -> Optional[List[str]]:
            """DFS to detect cycles, returns cycle path if found."""
            if node in rec_stack:
                # Found a cycle
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]
            if node in visited:
                return None
            visited.add(node)
            rec_stack.add(node)
            try:
                # Sorted for a reproducible traversal (set order is arbitrary).
                for neighbor in sorted(self.dependency_graph.get(node, set()), key=str):
                    result = dfs(neighbor, path + [node])
                    if result:
                        return result
                return None
            finally:
                # Always unwind, including when a cycle is being returned.
                # Leaving stale entries behind would make a later traversal
                # treat an already-finished node as "on the current path" and
                # fail in path.index().
                rec_stack.discard(node)

        for schema in self.dependency_graph:
            if schema not in visited:
                cycle = dfs(schema, [])
                if cycle:
                    cycle_str = " -> ".join(str(step) for step in cycle)
                    errors.append(f"Circular dependency detected: {cycle_str}")
        return errors

    def get_schema_dependencies(self, schema_name: str) -> Set[str]:
        """Get the schemas that a given schema depends on."""
        return self.dependency_graph.get(schema_name, set())

    def get_dependents(self, schema_name: str) -> Set[str]:
        """Get schemas that depend on the given schema."""
        return {
            schema
            for schema, deps in self.dependency_graph.items()
            if schema_name in deps
        }
class DisplayLogicEvaluator:
    """
    Evaluates display logic conditions at runtime.

    All methods are static. Per the module's own notes, this server-side
    logic is also replicated in the frontend JavaScript.
    """

    @staticmethod
    def evaluate_condition(
        condition: "DisplayLogicCondition",
        schema_value: Any
    ) -> bool:
        """
        Evaluate a single condition against a schema value.

        Args:
            condition: The condition to evaluate
            schema_value: Current value of the watched schema

        Returns:
            bool: Whether the condition is satisfied. Unknown operators and
            unusable regex patterns evaluate to False (with a warning).
        """
        operator = condition.operator
        expected = condition.value
        case_sensitive = condition.case_sensitive
        # Handle empty checks first
        if operator == "empty":
            return DisplayLogicEvaluator._is_empty(schema_value)
        if operator == "not_empty":
            return not DisplayLogicEvaluator._is_empty(schema_value)
        # The value is deliberately NOT lower-cased up front: the comparison
        # helpers fold case themselves and the regex path uses re.IGNORECASE.
        # Lower-casing here can change the text's length (for example
        # "\u0130".lower() is two characters), which would skew the length
        # operators and regex matches.
        actual = schema_value
        # Equality operators
        if operator == "equals":
            return DisplayLogicEvaluator._check_equals(actual, expected, case_sensitive)
        if operator == "not_equals":
            return not DisplayLogicEvaluator._check_equals(actual, expected, case_sensitive)
        # Contains operators (for lists and text)
        if operator == "contains":
            return DisplayLogicEvaluator._check_contains(actual, expected, case_sensitive)
        if operator == "not_contains":
            return not DisplayLogicEvaluator._check_contains(actual, expected, case_sensitive)
        # Regex matching
        if operator == "matches":
            if not isinstance(actual, str):
                actual = str(actual) if actual is not None else ""
            flags = 0 if case_sensitive else re.IGNORECASE
            try:
                return bool(re.search(expected, actual, flags))
            except (re.error, TypeError):
                # re.error: malformed pattern. TypeError: the pattern is not
                # a string at all (e.g. a number or None from the config).
                logger.warning(f"Invalid regex pattern: {expected}")
                return False
        # Numeric comparisons
        if operator in ("gt", "gte", "lt", "lte"):
            return DisplayLogicEvaluator._check_numeric(operator, actual, expected)
        # Range operators
        if operator in ("in_range", "not_in_range"):
            result = DisplayLogicEvaluator._check_range(actual, expected)
            return result if operator == "in_range" else not result
        # Length operators
        if operator in ("length_gt", "length_lt"):
            return DisplayLogicEvaluator._check_length(operator, actual, expected)
        if operator == "length_in_range":
            return DisplayLogicEvaluator._check_length_range(actual, expected)
        logger.warning(f"Unknown operator: {operator}")
        return False

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """Check if a value is considered empty.

        None, whitespace-only strings and zero-length lists/dicts/sets are
        empty; every other value (including 0 and False) is not.
        """
        if value is None:
            return True
        if isinstance(value, str):
            return len(value.strip()) == 0
        if isinstance(value, (list, dict, set)):
            return len(value) == 0
        return False

    @staticmethod
    def _check_equals(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check equality, handling single values and lists.

        When expected is a list, actual only has to equal ANY one entry.
        """
        if isinstance(expected, list):
            return any(
                DisplayLogicEvaluator._values_equal(actual, exp, case_sensitive)
                for exp in expected
            )
        return DisplayLogicEvaluator._values_equal(actual, expected, case_sensitive)

    @staticmethod
    def _values_equal(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Compare two values for equality.

        When expected is a string, actual is converted with str() first and
        case is folded unless case_sensitive is set.
        """
        # Handle None
        if actual is None and expected is None:
            return True
        if actual is None or expected is None:
            return False
        # String comparison with case sensitivity
        if isinstance(expected, str):
            actual_str = str(actual)
            if not case_sensitive:
                return actual_str.lower() == expected.lower()
            return actual_str == expected
        # Direct comparison for non-strings
        return actual == expected

    @staticmethod
    def _check_contains(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check if actual contains expected value(s).

        When expected is a list, containing ANY one entry is sufficient.
        """
        if isinstance(expected, list):
            return any(
                DisplayLogicEvaluator._value_contains(actual, exp, case_sensitive)
                for exp in expected
            )
        return DisplayLogicEvaluator._value_contains(actual, expected, case_sensitive)

    @staticmethod
    def _value_contains(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check if actual contains a single expected value."""
        # If actual is a list (multiselect), check membership
        if isinstance(actual, list):
            return any(
                DisplayLogicEvaluator._values_equal(item, expected, case_sensitive)
                for item in actual
            )
        # If actual is a string, check substring
        if isinstance(actual, str):
            expected_str = str(expected)
            if not case_sensitive:
                return expected_str.lower() in actual.lower()
            return expected_str in actual
        # Fallback to equality
        return DisplayLogicEvaluator._values_equal(actual, expected, case_sensitive)

    @staticmethod
    def _check_numeric(operator: str, actual: Any, expected: Any) -> bool:
        """Check numeric comparison.

        A None actual is treated as 0; values that cannot be converted to
        float make the comparison False.
        """
        try:
            actual_num = float(actual) if actual is not None else 0
            expected_num = float(expected)
        except (ValueError, TypeError):
            return False
        if operator == "gt":
            return actual_num > expected_num
        if operator == "gte":
            return actual_num >= expected_num
        if operator == "lt":
            return actual_num < expected_num
        if operator == "lte":
            return actual_num <= expected_num
        return False

    @staticmethod
    def _check_range(actual: Any, range_val: List) -> bool:
        """Check if actual is within range (inclusive).

        A None actual is treated as 0; unconvertible values or a malformed
        range yield False.
        """
        try:
            actual_num = float(actual) if actual is not None else 0
            min_val, max_val = float(range_val[0]), float(range_val[1])
        except (ValueError, TypeError, IndexError):
            return False
        return min_val <= actual_num <= max_val

    @staticmethod
    def _check_length(operator: str, actual: Any, expected: Any) -> bool:
        """Check text length comparison.

        The length is that of str(actual) (0 for None); expected is
        truncated to an int.
        """
        try:
            length = len(str(actual)) if actual is not None else 0
            expected_len = int(expected)
        except (ValueError, TypeError):
            return False
        if operator == "length_gt":
            return length > expected_len
        if operator == "length_lt":
            return length < expected_len
        return False

    @staticmethod
    def _check_length_range(actual: Any, range_val: List) -> bool:
        """Check if text length is within range (inclusive)."""
        try:
            length = len(str(actual)) if actual is not None else 0
            min_len, max_len = int(range_val[0]), int(range_val[1])
        except (ValueError, TypeError, IndexError):
            return False
        return min_len <= length <= max_len

    @staticmethod
    def evaluate_rule(
        rule: "DisplayLogicRule",
        annotations: Dict[str, Any]
    ) -> bool:
        """
        Evaluate a complete display logic rule.

        Args:
            rule: The DisplayLogicRule to evaluate
            annotations: Current annotations dictionary {schema_name: value}

        Returns:
            bool: Whether the schema should be visible
        """
        if not rule.conditions:
            # No conditions = always visible
            return True
        # Every condition is evaluated (no short-circuit) so that warnings
        # about bad operators/patterns are logged for all of them.
        results = [
            DisplayLogicEvaluator.evaluate_condition(
                condition, annotations.get(condition.schema)
            )
            for condition in rule.conditions
        ]
        if rule.logic == "all":
            return all(results)
        return any(results)  # "any"

    @staticmethod
    def evaluate_visibility(
        schema_name: str,
        display_logic: Optional[Dict],
        annotations: Dict[str, Any]
    ) -> Tuple[bool, Optional[str]]:
        """
        Evaluate whether a schema should be visible.

        Args:
            schema_name: Name of the schema being evaluated
            display_logic: The display_logic configuration (can be None)
            annotations: Current annotations dictionary

        Returns:
            Tuple of (is_visible, reason_if_hidden). Any error while parsing
            or evaluating the configuration is logged and the schema is
            treated as visible.
        """
        if not display_logic:
            return True, None
        try:
            rule = DisplayLogicRule.from_dict(display_logic)
            is_visible = DisplayLogicEvaluator.evaluate_rule(rule, annotations)
            if not is_visible:
                # Build reason string
                reasons = []
                for cond in rule.conditions:
                    actual_val = annotations.get(cond.schema, "<not set>")
                    reasons.append(f"{cond.schema} {cond.operator} {cond.value} (actual: {actual_val})")
                reason = f"Conditions not met ({rule.logic}): " + ", ".join(reasons)
                return False, reason
            return True, None
        except Exception as e:
            # Deliberate best-effort boundary: a broken rule must not hide a
            # schema from annotators, so fail open.
            logger.error(f"Error evaluating display logic for {schema_name}: {e}")
            # Default to visible on error
            return True, None
def validate_display_logic_config(
    annotation_schemes: List[Dict[str, Any]]
) -> Tuple[bool, List[str]]:
    """
    Validate display_logic across all annotation schemes in one call.

    Builds a DisplayLogicValidator over the given schemes and returns the
    result of its validate() method.

    Args:
        annotation_schemes: List of annotation scheme configurations

    Returns:
        Tuple of (is_valid, list_of_errors)
    """
    return DisplayLogicValidator(annotation_schemes).validate()
def get_display_logic_dependencies(
    annotation_schemes: List[Dict[str, Any]]
) -> Dict[str, Set[str]]:
    """
    Return the schema dependency graph derived from display_logic settings.

    Builds a DisplayLogicValidator over the given schemes and hands back the
    dependency_graph it computed during construction.

    Args:
        annotation_schemes: List of annotation scheme configurations

    Returns:
        Dictionary mapping schema names to their dependencies
    """
    return DisplayLogicValidator(annotation_schemes).dependency_graph