"""
Correlation module — entity linking and cross-record correlation.

Engine contract: run(EngineInput) -> EngineOutput

Groups records by shared entity keys (email, IP, phone, domain, hash) and
produces correlation metadata. No external dependencies (Neo4j is optional
and handled separately).
"""

from __future__ import annotations

import logging
from collections import defaultdict
from typing import Any, Dict, List, Set

from engine.io_contract import (
    EngineInput,
    EngineOutput,
    NormalizedRecord,
    StageStatus,
)

logger = logging.getLogger("modules.correlation")

# NormalizedRecord attributes that act as linkable entity keys. The field name
# is kept as a prefix of the index key so equal strings found in different
# fields (e.g. a domain and a hash) are never treated as the same entity.
_ENTITY_FIELDS = (
    "entity_email",
    "entity_ip",
    "entity_phone",
    "entity_domain",
    "entity_hash",
)


def _build_entity_index(records: List[NormalizedRecord]) -> Dict[str, List[str]]:
    """
    Build an inverted index: "<field>:<value>" -> [row_ids].

    Fields that are missing or falsy (None, "") are skipped. Keys appear in
    the order they are first encountered and each row_id list follows record
    order, so the result is deterministic for a given input.
    """
    index: Dict[str, List[str]] = defaultdict(list)
    for r in records:
        for field in _ENTITY_FIELDS:
            val = getattr(r, field, None)
            if val:
                index[f"{field}:{val}"].append(r.row_id)
    return dict(index)


def _find_clusters(index: Dict[str, List[str]]) -> List[Set[str]]:
    """
    Find clusters of row_ids that share at least one entity value.

    Uses union-find with path halving. Linking is transitive: if A and B
    share one entity key and B and C share another, A, B and C end up in
    one cluster.

    Returns only clusters with two or more members. Clusters are ordered by
    the first appearance of any of their members in ``index``, so a
    cluster's position in the returned list (used as its ID by ``run``) is
    stable across processes.
    """
    parent: Dict[str, str] = {}

    def find(x: str) -> str:
        while parent.get(x, x) != x:
            # Path halving: re-point x at its grandparent while walking up.
            parent[x] = parent.get(parent[x], parent[x])
            x = parent[x]
        return x

    def union(a: str, b: str) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb

    # Insertion-ordered de-duplication of every row_id. A plain set() here
    # would iterate in hash order, which for str is randomized per process
    # (PYTHONHASHSEED), making cluster order -- and therefore the cluster
    # IDs assigned in run() -- differ between runs on identical input.
    all_ids: Dict[str, None] = {}
    for row_ids in index.values():
        all_ids.update(dict.fromkeys(row_ids))
        if len(row_ids) > 1:
            first = row_ids[0]
            for rid in row_ids[1:]:
                union(first, rid)

    clusters: Dict[str, Set[str]] = {}
    for rid in all_ids:
        clusters.setdefault(find(rid), set()).add(rid)

    # Only return clusters with 2+ members
    return [c for c in clusters.values() if len(c) > 1]


def run(engine_input: EngineInput) -> EngineOutput:
    """
    Correlate records by shared entity values.

    Every record whose row_id belongs to a cluster of two or more rows is
    copied with ``extra["correlation_cluster"]`` set to that cluster's
    0-based index; all other records are passed through unchanged, in the
    original order.

    Any exception is logged and reported as a FAILED stage output instead of
    being raised to the caller.
    """
    try:
        records = engine_input.records
        index = _build_entity_index(records)
        clusters = _find_clusters(index)

        # Map each clustered row_id to its cluster ID (position in `clusters`).
        # Clusters are disjoint, so each row_id is assigned exactly once.
        row_to_cluster: Dict[str, int] = {
            rid: i for i, cluster in enumerate(clusters) for rid in cluster
        }

        annotated: List[NormalizedRecord] = []
        for r in records:
            cluster_id = row_to_cluster.get(r.row_id)
            if cluster_id is None:
                annotated.append(r)
                continue
            # Build a new dict so the input record's `extra` is not mutated.
            extra = dict(r.extra)
            extra["correlation_cluster"] = cluster_id
            annotated.append(r.model_copy(update={"extra": extra}))

        correlated_count = len(row_to_cluster)
        return EngineOutput(
            stage="correlation",
            status=StageStatus.SUCCESS,
            records=annotated,
            summary=f"Found {len(clusters)} clusters linking {correlated_count} records",
            metadata={
                "cluster_count": len(clusters),
                "correlated_records": correlated_count,
                "entity_index_size": len(index),
            },
        )
    except Exception as exc:
        # Stage boundary: report failure through the engine contract rather
        # than propagating, with the traceback preserved in the log.
        logger.error("Correlation failed: %s", exc, exc_info=True)
        return EngineOutput(
            stage="correlation",
            status=StageStatus.FAILED,
            error=str(exc),
        )


if __name__ == "__main__":
    print("Correlation module — use via engine pipeline")