""" Shared notation parsing utilities for PSI2. Parses Einstein-like notation strings (e.g. "rgb0,c01->rgb1") into structured element dicts with modality class references via GlobalRegistry. """ from __future__ import annotations import re from dataclasses import dataclass from typing import List, Dict, Set, Tuple, Any, Type, Union from .modalities import ( GlobalRegistry, SerializableModality, NodeModality, EdgeModality, ) @dataclass class ParsedNotation: name: str modalities: Set[Type[SerializableModality]] nodes: Dict[Type[SerializableModality], List[int]] edges: Dict[Type[SerializableModality], Union[List[int], List[Tuple[int, int]]]] elements: List[Tuple[Type[SerializableModality], Union[int, Tuple[int, int]]]] element_names: List[str] def parse_element(element_str: str) -> Dict[str, Any]: """ Parse a single element string. Format examples: rgb0 -> RGB frame at index 0 f01 -> Flow from frame 0 to frame 1 d0 -> Disparity at frame 0 c01 -> Camera pose from frame 0 to frame 1 Returns: Dict with keys: modality, modality_cls, indices, raw """ modality_name = re.match(r'^[^0-9]*', element_str) if not modality_name: raise ValueError(f"Invalid element string: {element_str}") modality_name = modality_name.group() modality_cls = GlobalRegistry.get(modality_name) if modality_cls is None: raise ValueError(f"Unknown modality: {element_str} (valid names: {GlobalRegistry.names_str()})") times_str = element_str.removeprefix(modality_name) if times_str and not times_str.isdigit(): match = re.match(r"^(\d+)(.+)$", times_str) if match: digits, trailing = match.groups() raise ValueError( f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only. " f"Did you mean '{modality_name}{digits},{trailing}'?" ) raise ValueError( f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only." ) times = [x for x in times_str] if issubclass(modality_cls, NodeModality): if len(times) != 1: raise ValueError(f"{modality_name} is a Node and expects 1 time index: {element_str}") elif issubclass(modality_cls, EdgeModality): if len(times) != 2: raise ValueError(f"{modality_name} is an Edge and expects 2 time indices: {element_str}") else: if times != ['']: raise ValueError( f"{modality_name} does not take time indices, but got {times}: {element_str}" ) if times != ['']: times = list(map(int, times)) else: times = [] return {"modality": modality_name, "modality_cls": modality_cls, "indices": times, "raw": element_str} def parse_notation(notation: str) -> List[Dict[str, Any]]: """Parse a notation string into a list of element dictionaries.""" elements = [] for part in notation.split(","): part = part.strip() if not part: continue elements.append(parse_element(part)) return elements def analyze_notation(name: str, elements: List[Dict[str, Any]]) -> ParsedNotation: """Analyze parsed notation elements to determine requirements.""" nodes = {} edges = {} out_elements = [] for e in elements: modality_cls, indices = e['modality_cls'], e['indices'] if len(indices) == 2: edges.setdefault(modality_cls, set()).add(tuple(indices)) out_elements.append((modality_cls, tuple(indices))) elif len(indices) == 1: nodes.setdefault(modality_cls, set()).add(indices[0]) out_elements.append((modality_cls, indices[0])) else: out_elements.append((modality_cls, 0)) return ParsedNotation( name=name, modalities=set(e['modality_cls'] for e in elements), nodes={k: list(v) for k, v in nodes.items()}, edges={k: list(v) for k, v in edges.items()}, elements=out_elements, element_names=[e['raw'] for e in elements], ) def parse_full_notation(notation: str) -> Tuple[List[Dict], List[Dict], ParsedNotation, ParsedNotation]: """ Parse full notation with input/output split on '->'. Returns: input_elements: List of parsed element dicts output_elements: List of parsed element dicts input_analysis: ParsedNotation for inputs output_analysis: ParsedNotation for outputs """ if "->" not in notation: raise ValueError(f"Notation must contain '->' to separate inputs from outputs: {notation}") input_str, output_str = notation.split("->", 1) input_elements = parse_notation(input_str) output_elements = parse_notation(output_str) if not output_elements: raise ValueError(f"No output elements specified in notation: {notation}") input_analysis = analyze_notation("input", input_elements) output_analysis = analyze_notation("output", output_elements) return input_elements, output_elements, input_analysis, output_analysis