Spaces:
Paused
Paused
| """ | |
Display Logic Module for Conditional Schema Branching

This module provides the core validation and evaluation logic for conditional
annotation schemas. It allows schemas to be shown or hidden based on user
responses to other schemas.

Key Components:
- DisplayLogicCondition: Represents a single condition (e.g., "schema X equals 'Yes'")
- DisplayLogicRule: Represents a complete rule with multiple conditions and AND/OR logic
- DisplayLogicValidator: Validates display_logic configurations
- DisplayLogicEvaluator: Evaluates conditions at runtime

Example Configuration:
    display_logic:
      show_when:
        - schema: contains_pii
          operator: equals
          value: "Yes"
      logic: all  # 'all' = AND, 'any' = OR
| """ | |
| import re | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Optional, Set, Tuple, Union | |
| logger = logging.getLogger(__name__) | |
# Registry of every operator a display-logic condition may use, mapped to a
# short human-readable description. Key order is preserved because the key
# list is echoed verbatim in "unsupported operator" error messages.
SUPPORTED_OPERATORS = dict(
    (
        # Value comparison
        ("equals", "Exact value match (single value or list of values)"),
        ("not_equals", "Value doesn't match any specified values"),
        # Collection operators
        ("contains", "List/text contains value(s)"),
        ("not_contains", "List/text doesn't contain value(s)"),
        # Regex
        ("matches", "Regex pattern match"),
        # Numeric comparison
        ("gt", "Greater than"),
        ("gte", "Greater than or equal"),
        ("lt", "Less than"),
        ("lte", "Less than or equal"),
        ("in_range", "Value is within range (inclusive)"),
        ("not_in_range", "Value is outside range"),
        # Emptiness
        ("empty", "Field is empty or not set"),
        ("not_empty", "Field has a value"),
        # Text length
        ("length_gt", "Text length greater than"),
        ("length_lt", "Text length less than"),
        ("length_in_range", "Text length within range (inclusive)"),
    )
)
@dataclass
class DisplayLogicCondition:
    """
    Represents a single condition in a display logic rule.

    Declared as a dataclass so the annotated fields below produce the
    keyword constructor that ``from_dict`` relies on, and so that
    ``__post_init__`` runs validation on every construction.

    Attributes:
        schema: Name of the schema to watch
        operator: Comparison operator (equals, contains, gt, etc.)
        value: Value(s) to compare against (can be single value, list, or range)
        case_sensitive: Whether text comparisons are case-sensitive (default: False)

    Raises:
        ValueError: If the operator is not in SUPPORTED_OPERATORS or the value
            does not fit the operator (see ``__post_init__``).
    """

    schema: str
    operator: str
    value: Any = None
    case_sensitive: bool = False

    def __post_init__(self):
        """Validate the operator and its value after initialization."""
        op = self.operator
        if op not in SUPPORTED_OPERATORS:
            raise ValueError(f"Unsupported operator: {op}. "
                             f"Supported operators: {list(SUPPORTED_OPERATORS.keys())}")

        # Emptiness operators take no value, so there is nothing more to check.
        if op in ("empty", "not_empty"):
            return

        if op in ("in_range", "not_in_range", "length_in_range"):
            # Range operators require a list/tuple of exactly 2 numeric values.
            if not isinstance(self.value, (list, tuple)) or len(self.value) != 2:
                raise ValueError(f"Operator '{op}' requires a range value "
                                 f"as [min, max], got: {self.value}")
            for bound in self.value:
                if not isinstance(bound, (int, float)):
                    raise ValueError(f"Operator '{op}' requires numeric range values, "
                                     f"got: {self.value}")
        elif op in ("gt", "gte", "lt", "lte", "length_gt", "length_lt"):
            # Numeric operators require a single numeric value.
            if self.value is None:
                raise ValueError(f"Operator '{op}' requires a numeric value")
            if not isinstance(self.value, (int, float)):
                raise ValueError(f"Operator '{op}' requires a numeric value, "
                                 f"got: {type(self.value).__name__} '{self.value}'")
        elif self.value is None:
            # Every remaining operator (equals, contains, matches, ...) needs
            # something to compare against.
            raise ValueError(f"Operator '{op}' requires a value")

    def to_dict(self) -> Dict[str, Any]:
        """Convert condition to dictionary for serialization.

        ``value`` is omitted when it is None and ``case_sensitive`` is only
        written when it is truthy, so defaults do not clutter the output.
        """
        result = {
            "schema": self.schema,
            "operator": self.operator,
        }
        if self.value is not None:
            result["value"] = self.value
        if self.case_sensitive:
            result["case_sensitive"] = True
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DisplayLogicCondition":
        """Create a condition from a dictionary.

        Args:
            data: Mapping with required keys ``schema`` and ``operator`` and
                optional keys ``value`` and ``case_sensitive``.

        Raises:
            KeyError: If ``schema`` or ``operator`` is missing.
            ValueError: If validation in ``__post_init__`` fails.
        """
        return cls(
            schema=data["schema"],
            operator=data["operator"],
            value=data.get("value"),
            case_sensitive=data.get("case_sensitive", False),
        )
@dataclass
class DisplayLogicRule:
    """
    Represents a complete display logic rule with multiple conditions.

    Declared as a dataclass so that ``field(default_factory=list)`` gives each
    rule its own condition list and ``__post_init__`` validates ``logic``.

    Attributes:
        conditions: List of DisplayLogicCondition objects
        logic: 'all' (AND) or 'any' (OR) - how to combine conditions

    Raises:
        ValueError: If ``logic`` is neither 'all' nor 'any'.
    """

    conditions: List["DisplayLogicCondition"] = field(default_factory=list)
    logic: str = "all"  # 'all' = AND, 'any' = OR

    def __post_init__(self):
        """Validate the rule after initialization."""
        if self.logic not in ("all", "any"):
            raise ValueError(f"Invalid logic type: {self.logic}. Must be 'all' or 'any'")

    def get_watched_schemas(self) -> Set[str]:
        """Return set of schema names this rule depends on."""
        return {condition.schema for condition in self.conditions}

    def to_dict(self) -> Dict[str, Any]:
        """Convert rule to dictionary for serialization (config format)."""
        return {
            "show_when": [c.to_dict() for c in self.conditions],
            "logic": self.logic,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DisplayLogicRule":
        """Create a rule from a dictionary (config format).

        Args:
            data: Mapping with an optional ``show_when`` list of condition
                dictionaries and an optional ``logic`` key (default 'all').

        Raises:
            KeyError / ValueError: Propagated from building each condition or
                from validating ``logic``.
        """
        # A present-but-null ``show_when`` is treated like a missing one
        # instead of failing on iteration over None.
        show_when = data.get("show_when") or []
        return cls(
            conditions=[DisplayLogicCondition.from_dict(cond_data) for cond_data in show_when],
            logic=data.get("logic", "all"),
        )
class DisplayLogicValidator:
    """
    Validates display_logic configurations in annotation schemes.

    Responsibilities:
    - Validate condition syntax and operators
    - Check that referenced schemas exist
    - Detect circular dependencies

    Malformed configurations (non-dict display_logic, non-list show_when,
    non-dict conditions) are reported as errors by ``validate`` rather than
    raising while the validator is being built or run.
    """

    def __init__(self, annotation_schemes: List[Dict[str, Any]]):
        """
        Initialize the validator with all annotation schemes.

        Args:
            annotation_schemes: List of annotation scheme configurations
        """
        self.schemes = annotation_schemes
        self.schema_names = {s.get("name") for s in annotation_schemes if "name" in s}
        self.dependency_graph: Dict[str, Set[str]] = {}
        self._build_dependency_graph()

    @staticmethod
    def _extract_dependencies(display_logic: Any) -> Set[str]:
        """Return the schema names referenced by one display_logic config.

        Anything that is not shaped like ``{"show_when": [ {...}, ... ]}``
        contributes no dependencies; the shape problem itself is reported by
        ``_validate_structure``.
        """
        if not isinstance(display_logic, dict):
            return set()
        show_when = display_logic.get("show_when", [])
        if not isinstance(show_when, list):
            return set()
        return {
            condition["schema"]
            for condition in show_when
            if isinstance(condition, dict) and "schema" in condition
        }

    def _build_dependency_graph(self) -> None:
        """Build a graph of schema dependencies for cycle detection."""
        for scheme in self.schemes:
            schema_name = scheme.get("name")
            if not schema_name:
                continue
            self.dependency_graph[schema_name] = self._extract_dependencies(
                scheme.get("display_logic")
            )

    def validate(self) -> Tuple[bool, List[str]]:
        """
        Validate all display_logic configurations.

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        errors = []
        for scheme in self.schemes:
            schema_name = scheme.get("name", "<unnamed>")
            display_logic = scheme.get("display_logic")
            # Schemes without display_logic are always shown; nothing to check.
            if display_logic is None:
                continue
            if not isinstance(display_logic, dict):
                errors.append(f"Schema '{schema_name}': display_logic must be a dictionary")
                continue
            # Validate structure (including empty dict - which is invalid)
            errors.extend(self._validate_structure(schema_name, display_logic))
            # Validate referenced schemas exist
            errors.extend(self._validate_references(schema_name, display_logic))
        # Check for circular dependencies across the whole graph
        errors.extend(self._detect_cycles())
        return len(errors) == 0, errors

    def _validate_structure(self, schema_name: str, display_logic: Dict) -> List[str]:
        """Validate the structure of a display_logic configuration."""
        errors = []
        # Must have show_when
        if "show_when" not in display_logic:
            errors.append(f"Schema '{schema_name}': display_logic must have 'show_when' field")
            return errors
        show_when = display_logic["show_when"]
        if not isinstance(show_when, list):
            errors.append(f"Schema '{schema_name}': 'show_when' must be a list of conditions")
            return errors
        if len(show_when) == 0:
            errors.append(f"Schema '{schema_name}': 'show_when' must have at least one condition")
            return errors
        # Validate each condition
        for i, condition in enumerate(show_when):
            prefix = f"Schema '{schema_name}', condition {i+1}"
            if not isinstance(condition, dict):
                errors.append(f"{prefix}: condition must be a dictionary")
                continue
            # Required fields
            if "schema" not in condition:
                errors.append(f"{prefix}: missing required 'schema' field")
            if "operator" not in condition:
                errors.append(f"{prefix}: missing required 'operator' field")
            elif condition["operator"] not in SUPPORTED_OPERATORS:
                errors.append(f"{prefix}: unsupported operator '{condition['operator']}'. "
                              f"Supported: {list(SUPPORTED_OPERATORS.keys())}")
            # Validate operator-specific requirements
            operator = condition.get("operator")
            if operator:
                errors.extend(
                    self._validate_operator_value(prefix, operator, condition.get("value"))
                )
        # Validate logic field if present
        logic = display_logic.get("logic", "all")
        if logic not in ("all", "any"):
            errors.append(f"Schema '{schema_name}': 'logic' must be 'all' or 'any', got '{logic}'")
        return errors

    def _validate_operator_value(self, prefix: str, operator: str, value: Any) -> List[str]:
        """Validate that the value is appropriate for the operator.

        Args:
            prefix: Text placed in front of every error message.
            operator: Operator name taken from the condition.
            value: The condition's ``value`` (None when absent).

        Returns:
            List of error messages (empty when the value is acceptable).
        """
        errors = []
        # Operators that don't need a value
        if operator in ("empty", "not_empty"):
            return errors
        # Range operators need [min, max]
        if operator in ("in_range", "not_in_range", "length_in_range"):
            if not isinstance(value, (list, tuple)):
                errors.append(f"{prefix}: operator '{operator}' requires a range value as [min, max]")
            elif len(value) != 2:
                errors.append(f"{prefix}: range value must have exactly 2 elements [min, max]")
            else:
                try:
                    min_val, max_val = float(value[0]), float(value[1])
                    if min_val > max_val:
                        errors.append(f"{prefix}: range min ({min_val}) is greater than max ({max_val})")
                except (ValueError, TypeError):
                    errors.append(f"{prefix}: range values must be numeric")
            return errors
        # Numeric operators need numeric values
        if operator in ("gt", "gte", "lt", "lte", "length_gt", "length_lt"):
            if value is None:
                errors.append(f"{prefix}: operator '{operator}' requires a value")
            else:
                try:
                    float(value)
                except (ValueError, TypeError):
                    errors.append(f"{prefix}: operator '{operator}' requires a numeric value")
            return errors
        # Regex operator needs a valid pattern
        if operator == "matches":
            if value is None:
                errors.append(f"{prefix}: operator 'matches' requires a regex pattern")
            elif not isinstance(value, str):
                # re.compile raises TypeError (not re.error) for non-strings,
                # so report it instead of letting validation crash.
                errors.append(f"{prefix}: invalid regex pattern '{value}': pattern must be a string")
            else:
                try:
                    re.compile(value)
                except re.error as e:
                    errors.append(f"{prefix}: invalid regex pattern '{value}': {e}")
            return errors
        # Other operators just need a non-None value
        if value is None:
            errors.append(f"{prefix}: operator '{operator}' requires a value")
        return errors

    def _validate_references(self, schema_name: str, display_logic: Dict) -> List[str]:
        """Validate that all referenced schemas exist.

        Conditions that are not dictionaries, or a ``show_when`` that is not a
        list, are skipped here: ``_validate_structure`` already reports them.
        """
        errors = []
        show_when = display_logic.get("show_when", [])
        if not isinstance(show_when, list):
            return errors
        for i, condition in enumerate(show_when):
            if not isinstance(condition, dict):
                continue
            ref_schema = condition.get("schema")
            if ref_schema and ref_schema not in self.schema_names:
                errors.append(
                    f"Schema '{schema_name}', condition {i+1}: references unknown schema '{ref_schema}'"
                )
        return errors

    def _detect_cycles(self) -> List[str]:
        """Detect circular dependencies using DFS.

        Returns:
            One "Circular dependency detected: a -> b -> a" message per DFS
            root from which a cycle was reached.
        """
        errors = []
        visited = set()
        rec_stack = set()

        def dfs(node: str, path: List[str]) -> Optional[List[str]]:
            """DFS to detect cycles, returns cycle path if found."""
            if node in rec_stack:
                # Found a cycle: node is an ancestor on the current path.
                cycle_start = path.index(node)
                return path[cycle_start:] + [node]
            if node in visited:
                return None
            visited.add(node)
            rec_stack.add(node)
            try:
                for neighbor in self.dependency_graph.get(node, set()):
                    result = dfs(neighbor, path + [node])
                    if result:
                        return result
                return None
            finally:
                # Always unwind, even when returning a cycle early. Otherwise
                # stale entries make later traversals believe a finished node
                # is still on their path and path.index() raises ValueError.
                rec_stack.discard(node)

        for schema in self.dependency_graph:
            if schema not in visited:
                cycle = dfs(schema, [])
                if cycle:
                    cycle_str = " -> ".join(cycle)
                    errors.append(f"Circular dependency detected: {cycle_str}")
        return errors

    def get_schema_dependencies(self, schema_name: str) -> Set[str]:
        """Get the schemas that a given schema depends on."""
        return self.dependency_graph.get(schema_name, set())

    def get_dependents(self, schema_name: str) -> Set[str]:
        """Get schemas that depend on the given schema."""
        return {
            schema
            for schema, deps in self.dependency_graph.items()
            if schema_name in deps
        }
class DisplayLogicEvaluator:
    """
    Evaluates display logic conditions at runtime.

    This class is used both server-side (Python) and provides the logic
    that's replicated in the frontend JavaScript.

    All methods are static: the class is a namespace and holds no state, so
    they can be called on the class or on an instance.
    """

    @staticmethod
    def evaluate_condition(
        condition: "DisplayLogicCondition",
        schema_value: Any
    ) -> bool:
        """
        Evaluate a single condition against a schema value.

        Args:
            condition: The condition to evaluate
            schema_value: Current value of the watched schema

        Returns:
            bool: Whether the condition is satisfied. Unknown operators and
            unusable regex patterns log a warning and evaluate to False.
        """
        ev = DisplayLogicEvaluator
        operator = condition.operator
        expected = condition.value
        case_sensitive = condition.case_sensitive

        # Handle empty checks first
        if operator == "empty":
            return ev._is_empty(schema_value)
        if operator == "not_empty":
            return not ev._is_empty(schema_value)

        # The value is deliberately NOT lowercased up front. The equality and
        # containment helpers apply case folding themselves and the regex
        # branch uses re.IGNORECASE; lowercasing here would also change the
        # length of some strings and skew the length_* operators.
        actual = schema_value

        # Equality operators
        if operator == "equals":
            return ev._check_equals(actual, expected, case_sensitive)
        if operator == "not_equals":
            return not ev._check_equals(actual, expected, case_sensitive)

        # Contains operators (for lists and text)
        if operator == "contains":
            return ev._check_contains(actual, expected, case_sensitive)
        if operator == "not_contains":
            return not ev._check_contains(actual, expected, case_sensitive)

        # Regex matching
        if operator == "matches":
            if not isinstance(actual, str):
                actual = str(actual) if actual is not None else ""
            flags = 0 if case_sensitive else re.IGNORECASE
            try:
                return bool(re.search(expected, actual, flags))
            except (re.error, TypeError):
                # TypeError covers a pattern that is not a string at all.
                logger.warning("Invalid regex pattern: %s", expected)
                return False

        # Numeric comparisons
        if operator in ("gt", "gte", "lt", "lte"):
            return ev._check_numeric(operator, actual, expected)

        # Range operators
        if operator in ("in_range", "not_in_range"):
            in_range = ev._check_range(actual, expected)
            return in_range if operator == "in_range" else not in_range

        # Length operators
        if operator in ("length_gt", "length_lt"):
            return ev._check_length(operator, actual, expected)
        if operator == "length_in_range":
            return ev._check_length_range(actual, expected)

        logger.warning("Unknown operator: %s", operator)
        return False

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """Check if a value is considered empty.

        None, whitespace-only strings and zero-length list/dict/set are
        empty; every other value (including 0 and False) is not.
        """
        if value is None:
            return True
        if isinstance(value, str):
            return len(value.strip()) == 0
        if isinstance(value, (list, dict, set)):
            return len(value) == 0
        return False

    @staticmethod
    def _check_equals(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check equality, handling single values and lists.

        When ``expected`` is a list, ``actual`` only has to match ANY entry.
        """
        if isinstance(expected, list):
            return any(
                DisplayLogicEvaluator._values_equal(actual, exp, case_sensitive)
                for exp in expected
            )
        return DisplayLogicEvaluator._values_equal(actual, expected, case_sensitive)

    @staticmethod
    def _values_equal(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Compare two values for equality.

        When ``expected`` is a string, ``actual`` is converted with ``str()``
        before comparing (optionally ignoring case); otherwise ``==`` is used.
        """
        # Handle None
        if actual is None and expected is None:
            return True
        if actual is None or expected is None:
            return False
        # String comparison with case sensitivity
        if isinstance(expected, str):
            actual_str = str(actual)
            if not case_sensitive:
                return actual_str.lower() == expected.lower()
            return actual_str == expected
        # Direct comparison for non-strings
        return actual == expected

    @staticmethod
    def _check_contains(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check if actual contains expected value(s).

        When ``expected`` is a list, containing ANY entry is enough.
        """
        if isinstance(expected, list):
            return any(
                DisplayLogicEvaluator._value_contains(actual, exp, case_sensitive)
                for exp in expected
            )
        return DisplayLogicEvaluator._value_contains(actual, expected, case_sensitive)

    @staticmethod
    def _value_contains(actual: Any, expected: Any, case_sensitive: bool) -> bool:
        """Check if actual contains a single expected value."""
        # If actual is a list (multiselect), check membership
        if isinstance(actual, list):
            return any(
                DisplayLogicEvaluator._values_equal(item, expected, case_sensitive)
                for item in actual
            )
        # If actual is a string, check substring
        if isinstance(actual, str):
            expected_str = str(expected)
            if not case_sensitive:
                return expected_str.lower() in actual.lower()
            return expected_str in actual
        # Fallback to equality
        return DisplayLogicEvaluator._values_equal(actual, expected, case_sensitive)

    @staticmethod
    def _check_numeric(operator: str, actual: Any, expected: Any) -> bool:
        """Check numeric comparison.

        A missing value (None) is compared as 0; values that cannot be
        converted to float make the comparison False.
        """
        try:
            actual_num = float(actual) if actual is not None else 0
            expected_num = float(expected)
        except (ValueError, TypeError):
            return False
        if operator == "gt":
            return actual_num > expected_num
        if operator == "gte":
            return actual_num >= expected_num
        if operator == "lt":
            return actual_num < expected_num
        if operator == "lte":
            return actual_num <= expected_num
        return False

    @staticmethod
    def _check_range(actual: Any, range_val: List) -> bool:
        """Check if actual is within range (inclusive).

        A missing value (None) is compared as 0; unconvertible values or a
        malformed range yield False.
        """
        try:
            actual_num = float(actual) if actual is not None else 0
            min_val, max_val = float(range_val[0]), float(range_val[1])
        except (ValueError, TypeError, IndexError):
            return False
        return min_val <= actual_num <= max_val

    @staticmethod
    def _check_length(operator: str, actual: Any, expected: Any) -> bool:
        """Check text length comparison.

        The length is that of ``str(actual)`` (0 for None). The threshold is
        compared as a float so fractional thresholds are not truncated.
        """
        try:
            length = len(str(actual)) if actual is not None else 0
            threshold = float(expected)
        except (ValueError, TypeError):
            return False
        if operator == "length_gt":
            return length > threshold
        if operator == "length_lt":
            return length < threshold
        return False

    @staticmethod
    def _check_length_range(actual: Any, range_val: List) -> bool:
        """Check if text length is within range (inclusive).

        Bounds are compared as floats so fractional bounds are not truncated.
        """
        try:
            length = len(str(actual)) if actual is not None else 0
            min_len, max_len = float(range_val[0]), float(range_val[1])
        except (ValueError, TypeError, IndexError):
            return False
        return min_len <= length <= max_len

    @staticmethod
    def evaluate_rule(
        rule: "DisplayLogicRule",
        annotations: Dict[str, Any]
    ) -> bool:
        """
        Evaluate a complete display logic rule.

        Args:
            rule: The DisplayLogicRule to evaluate
            annotations: Current annotations dictionary {schema_name: value}

        Returns:
            bool: Whether the schema should be visible
        """
        if not rule.conditions:
            # No conditions = always visible
            return True
        outcomes = (
            DisplayLogicEvaluator.evaluate_condition(
                condition, annotations.get(condition.schema)
            )
            for condition in rule.conditions
        )
        # 'all' = AND; anything else is treated as 'any' = OR
        return all(outcomes) if rule.logic == "all" else any(outcomes)

    @staticmethod
    def evaluate_visibility(
        schema_name: str,
        display_logic: Optional[Dict],
        annotations: Dict[str, Any]
    ) -> Tuple[bool, Optional[str]]:
        """
        Evaluate whether a schema should be visible.

        Args:
            schema_name: Name of the schema being evaluated
            display_logic: The display_logic configuration (can be None)
            annotations: Current annotations dictionary

        Returns:
            Tuple of (is_visible, reason_if_hidden). Any error while building
            or evaluating the rule is logged and the schema stays visible.
        """
        if not display_logic:
            return True, None
        try:
            rule = DisplayLogicRule.from_dict(display_logic)
            is_visible = DisplayLogicEvaluator.evaluate_rule(rule, annotations)
            if not is_visible:
                # Build reason string
                reasons = []
                for cond in rule.conditions:
                    actual_val = annotations.get(cond.schema, "<not set>")
                    reasons.append(f"{cond.schema} {cond.operator} {cond.value} (actual: {actual_val})")
                reason = f"Conditions not met ({rule.logic}): " + ", ".join(reasons)
                return False, reason
            return True, None
        except Exception as e:
            # Deliberate best-effort boundary: a broken rule must not hide a
            # schema from the annotator.
            logger.error(f"Error evaluating display logic for {schema_name}: {e}")
            # Default to visible on error
            return True, None
def validate_display_logic_config(
    annotation_schemes: List[Dict[str, Any]]
) -> Tuple[bool, List[str]]:
    """
    Validate the display_logic of every annotation scheme in one call.

    Builds a DisplayLogicValidator over the given schemes and returns the
    result of its ``validate`` method.

    Args:
        annotation_schemes: List of annotation scheme configurations

    Returns:
        Tuple of (is_valid, list_of_errors)
    """
    return DisplayLogicValidator(annotation_schemes).validate()
def get_display_logic_dependencies(
    annotation_schemes: List[Dict[str, Any]]
) -> Dict[str, Set[str]]:
    """
    Return the dependency graph derived from the schemes' display_logic.

    Builds a DisplayLogicValidator over the given schemes and exposes the
    ``dependency_graph`` it computed during construction.

    Args:
        annotation_schemes: List of annotation scheme configurations

    Returns:
        Dictionary mapping schema names to their dependencies
    """
    return DisplayLogicValidator(annotation_schemes).dependency_graph