| """ |
| Correlation module — entity linking and cross-record correlation. |
| |
| Engine contract: |
| run(EngineInput) -> EngineOutput |
| |
| Groups records by shared entity keys (email, IP, phone, domain, hash) |
| and produces correlation metadata. No external dependencies (Neo4j |
| is optional and handled separately). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from collections import defaultdict |
| from typing import Any, Dict, List, Set |
|
|
| from engine.io_contract import ( |
| EngineInput, |
| EngineOutput, |
| NormalizedRecord, |
| StageStatus, |
| ) |
|
|
# Module-level logger; the name is a fixed string rather than __name__, so
# log routing keyed on "modules.correlation" does not depend on import path.
logger = logging.getLogger("modules.correlation")
|
|
|
|
def _build_entity_index(records: List[NormalizedRecord]) -> Dict[str, List[str]]:
    """
    Map every populated entity value to the rows that carry it.

    Keys have the form ``"<field>:<value>"`` so that equal values stored in
    different entity fields never collide. A field whose value is missing or
    falsy (``None``, empty string, ...) contributes nothing for that record.

    Args:
        records: Records to index; each is read for ``row_id`` and the five
            ``entity_*`` attributes.

    Returns:
        A plain dict: entity key -> row_ids, in the order the records
        were visited.
    """
    entity_fields = (
        "entity_email",
        "entity_ip",
        "entity_phone",
        "entity_domain",
        "entity_hash",
    )
    inverted: Dict[str, List[str]] = {}
    for record in records:
        for name in entity_fields:
            value = getattr(record, name, None)
            if not value:
                continue
            inverted.setdefault(f"{name}:{value}", []).append(record.row_id)
    return inverted
|
|
|
|
def _find_clusters(index: Dict[str, List[str]]) -> List[Set[str]]:
    """
    Find clusters of row_ids that share at least one entity value.

    Uses a simple union-find: all row_ids listed under the same entity key
    are merged into one set, and sets that share a row_id are merged
    transitively.

    Args:
        index: Inverted index mapping an entity key to the row_ids that
            carry that entity value.

    Returns:
        One set of row_ids per cluster that has more than one member.
        Clusters are ordered by the first appearance of any of their
        members in ``index``, so the result (and any numbering derived
        from it) is reproducible across runs.
    """
    parent: Dict[str, str] = {}

    def find(x: str) -> str:
        # Ids absent from `parent` are their own root. While walking up,
        # re-point each visited id at its grandparent to shorten the chain.
        while parent.get(x, x) != x:
            parent[x] = parent.get(parent[x], parent[x])
            x = parent[x]
        return x

    def union(a: str, b: str) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb

    # A dict keeps insertion order; iterating a set of strings instead would
    # make the cluster order depend on per-process hash randomization.
    ordered_ids: Dict[str, None] = {}
    for row_ids in index.values():
        ordered_ids.update(dict.fromkeys(row_ids))
        if len(row_ids) > 1:
            first = row_ids[0]
            for rid in row_ids[1:]:
                union(first, rid)

    clusters: Dict[str, Set[str]] = defaultdict(set)
    for rid in ordered_ids:
        clusters[find(rid)].add(rid)

    # Singletons are rows that share nothing with any other row.
    return [members for members in clusters.values() if len(members) > 1]
|
|
|
|
def run(engine_input: EngineInput) -> EngineOutput:
    """
    Correlate records by shared entity values.

    Builds the entity index for ``engine_input.records``, clusters the
    row_ids that share an entity value, and numbers the clusters in the
    order ``_find_clusters`` returns them.

    Args:
        engine_input: Stage input; only its ``records`` attribute is read.

    Returns:
        On success, an ``EngineOutput`` with status ``SUCCESS`` whose records
        are in input order. A clustered record is replaced by a copy whose
        ``extra`` carries ``"correlation_cluster"``; other records are passed
        through as-is. If anything raises, the error is logged and an
        ``EngineOutput`` with status ``FAILED`` and the exception text is
        returned instead.
    """
    try:
        source_records = engine_input.records
        entity_index = _build_entity_index(source_records)
        groups = _find_clusters(entity_index)

        # row_id -> position of its cluster in `groups`
        cluster_of: Dict[str, int] = {
            row_id: number
            for number, members in enumerate(groups)
            for row_id in members
        }

        def tag(record):
            number = cluster_of.get(record.row_id)
            if number is None:
                return record
            # Copy `extra` so the incoming record is not mutated.
            enriched = dict(record.extra)
            enriched["correlation_cluster"] = number
            return record.model_copy(update={"extra": enriched})

        annotated: List[NormalizedRecord] = [tag(rec) for rec in source_records]

        group_total = len(groups)
        linked_total = len(cluster_of)
        stats = {
            "cluster_count": group_total,
            "correlated_records": linked_total,
            "entity_index_size": len(entity_index),
        }
        return EngineOutput(
            stage="correlation",
            status=StageStatus.SUCCESS,
            records=annotated,
            summary=f"Found {group_total} clusters linking {linked_total} records",
            metadata=stats,
        )
    except Exception as exc:
        logger.error("Correlation failed: %s", exc, exc_info=True)
        return EngineOutput(
            stage="correlation",
            status=StageStatus.FAILED,
            error=str(exc),
        )
|
|
|
|
# Running this file directly does no work; it only prints a notice that the
# module is meant to be used via the engine pipeline.
if __name__ == "__main__":
    print("Correlation module — use via engine pipeline")
|
|