from __future__ import annotations

from dataclasses import dataclass
import math
from typing import Any

import torch
import torch.nn.functional as F

from src.model.anchor_types import AnchorRecord
from src.model.future_influence import FutureInfluenceScorer


@dataclass
class AnchorDependencyEdge:
    """Directed dependency candidate from an earlier anchor to a later one."""

    source_id: int
    target_id: int
    # Weighted average of the four priors below (see _approx_dependency_score).
    approx_score: float
    # Equals approx_score unless a counterfactual delta was blended in.
    final_score: float
    similarity: float
    temporal_prior: float
    support_prior: float
    viability_prior: float
    # Counterfactual score change measured for this edge; 0.0 if not refined.
    refined_delta: float = 0.0
    is_refined: bool = False


@dataclass
class AnchorDependencyNode:
    """Per-anchor summary derived from its retained incoming edges."""

    anchor_id: int
    # Edge-weighted mean soft confirmation of the predecessors (1.0 if none).
    validity: float
    # Sigmoid of the anchor's own support around the confirm threshold.
    soft_confirmation: float
    # clamp(1 - validity * clamp(viability)) into [0, 1].
    node_pressure: float
    predecessor_ids: list[int]
    # Predecessors whose soft confirmation is below 0.5.
    broken_predecessor_ids: list[int]


def _to_float(value: torch.Tensor | float | int | None) -> float:
    """Convert a scalar tensor, number, or ``None`` to a float (``None`` -> 0.0)."""
    if value is None:
        return 0.0
    if isinstance(value, torch.Tensor):
        return float(value.detach().item())
    return float(value)


def _sigmoid_unit(value: float, threshold: float, slope: float) -> float:
    """Return ``sigmoid((value - threshold) / slope)`` as a Python float.

    The slope is floored at 1e-6 to avoid division by zero. The sigmoid is
    evaluated through a float32 tensor, so the result has float32 precision.
    """
    safe_slope = max(float(slope), 1e-6)
    tensor = torch.tensor((float(value) - float(threshold)) / safe_slope, dtype=torch.float32)
    return float(torch.sigmoid(tensor).item())


def _temporal_prior(source: AnchorRecord, target: AnchorRecord, temporal_window: float) -> float:
    """Return an exponential decay in the gap between ``source`` and ``target``.

    The gap is ``target.start_idx - source.end_idx`` floored at 1, so a gap of
    1 (or less) yields exactly 1.0; larger gaps decay as
    ``exp(-(gap - 1) / temporal_window)`` with the window floored at 1e-6.
    """
    distance = max(1, int(target.start_idx) - int(source.end_idx))
    return float(math.exp(-(float(distance) - 1.0) / max(float(temporal_window), 1e-6)))


def _approx_dependency_score(
    source: AnchorRecord,
    target: AnchorRecord,
    *,
    confirm_threshold: float,
    similarity_weight: float,
    temporal_weight: float,
    support_weight: float,
    viability_weight: float,
    temporal_window: float,
    viability_slope: float = 0.10,
) -> tuple[float, dict[str, float]]:
    """Score how strongly ``target`` depends on ``source`` without a model pass.

    Four components are combined as a weighted average (weight sum floored at
    1e-6):

    * similarity: cosine similarity of the two ``repr`` vectors, clamped at 0.
    * temporal:   ``_temporal_prior(source, target, temporal_window)``.
    * support:    mean of the two ``support`` values, clamped to [0, 1].
    * viability:  mean of the source's sigmoid-squashed support and the
      target's ``viability``, clamped to [0, 1].

    Args:
        viability_slope: Sigmoid slope used for the source-support term of the
            viability component. Defaults to the previously hard-coded 0.10.

    Returns:
        ``(score, parts)`` where ``parts`` holds the four components under the
        keys ``similarity``, ``temporal_prior``, ``support_prior`` and
        ``viability_prior``.
    """
    source_repr = F.normalize(source.repr.detach().float().unsqueeze(0), dim=-1)
    target_repr = F.normalize(target.repr.detach().float().unsqueeze(0), dim=-1)
    similarity = max(0.0, float(F.cosine_similarity(source_repr, target_repr, dim=-1).item()))

    temporal = _temporal_prior(source, target, temporal_window)

    source_support = _to_float(source.support)
    support = min(1.0, max(0.0, 0.5 * (source_support + _to_float(target.support))))
    viability = min(
        1.0,
        max(
            0.0,
            0.5
            * (
                _sigmoid_unit(source_support, confirm_threshold, viability_slope)
                + _to_float(target.viability)
            ),
        ),
    )

    total_weight = max(
        float(similarity_weight)
        + float(temporal_weight)
        + float(support_weight)
        + float(viability_weight),
        1e-6,
    )
    score = (
        float(similarity_weight) * similarity
        + float(temporal_weight) * temporal
        + float(support_weight) * support
        + float(viability_weight) * viability
    ) / total_weight

    return float(score), {
        "similarity": float(similarity),
        "temporal_prior": float(temporal),
        "support_prior": float(support),
        "viability_prior": float(viability),
    }


def _compute_counterfactual_scores(
    *,
    anchors: list[AnchorRecord],
    candidate_edges: list[AnchorDependencyEdge],
    hidden: torch.Tensor | None,
    input_ids: torch.Tensor | None,
    attention_mask: torch.Tensor | None,
    output_projection: torch.nn.Module | None,
    future_scorer: FutureInfluenceScorer | None,
    future_window: int,
    max_edges: int,
) -> dict[tuple[int, int], float]:
    """Measure, for the top candidate edges, how masking the source changes scores.

    The ``max_edges`` edges with the highest ``approx_score`` are selected. For
    each distinct source anchor among them, the source's span in ``hidden`` is
    zeroed, the scorer is re-run, and the mean absolute difference between the
    baseline and masked scores over the target's span becomes the edge delta.

    Returns an empty dict when any required input is missing, when ``hidden``
    is not 3-D with a leading dimension of 1, when ``input_ids`` is not 2-D,
    or when no edges are selected.

    Returns:
        Mapping ``(source_id, target_id) -> delta`` for every refined edge.
        NaN deltas become 0.0, +inf becomes 1.0 and -inf becomes 0.0.
    """
    if (
        not candidate_edges
        or hidden is None
        or input_ids is None
        or output_projection is None
        or future_scorer is None
    ):
        return {}
    if hidden.ndim != 3 or hidden.size(0) != 1 or input_ids.ndim != 2:
        return {}

    edge_budget = max(0, int(max_edges))
    top_edges = sorted(candidate_edges, key=lambda item: item.approx_score, reverse=True)[:edge_budget]
    if not top_edges:
        return {}

    anchor_by_id = {anchor.id: anchor for anchor in anchors}

    # NOTE(review): requires_grad_ is kept from the original; presumably the
    # scorer differentiates with respect to ``hidden`` internally - confirm.
    base_hidden = hidden.detach().clone().requires_grad_(True)
    base_logits = output_projection(base_hidden)
    base_scores = future_scorer(
        hidden=base_hidden,
        logits=base_logits,
        input_ids=input_ids,
        attention_mask=attention_mask,
        future_window=future_window,
    )["scores"].detach()
    # The slicing below indexes dim 1 of the scores as a position axis;
    # presumably scores are (batch, seq) - TODO confirm against the scorer.
    last_score_idx = base_scores.size(1) - 1

    # Group once so each masked forward pass is shared by all edges leaving
    # the same source, without rescanning top_edges per source.
    edges_by_source: dict[int, list[AnchorDependencyEdge]] = {}
    for edge in top_edges:
        edges_by_source.setdefault(edge.source_id, []).append(edge)

    deltas: dict[tuple[int, int], float] = {}
    for source_id in sorted(edges_by_source):
        source_anchor = anchor_by_id.get(source_id)
        if source_anchor is None:
            continue

        # Counterfactual input: zero the source anchor's span (inclusive end).
        masked_hidden = hidden.detach().clone()
        start = max(0, int(source_anchor.start_idx))
        end = min(masked_hidden.size(1) - 1, int(source_anchor.end_idx))
        masked_hidden[:, start : end + 1, :] = 0.0
        masked_hidden = masked_hidden.requires_grad_(True)
        masked_logits = output_projection(masked_hidden)
        masked_scores = future_scorer(
            hidden=masked_hidden,
            logits=masked_logits,
            input_ids=input_ids,
            attention_mask=attention_mask,
            future_window=future_window,
        )["scores"].detach()

        for edge in edges_by_source[source_id]:
            target_anchor = anchor_by_id.get(edge.target_id)
            if target_anchor is None:
                continue
            target_start = max(0, int(target_anchor.start_idx))
            target_end = min(last_score_idx, int(target_anchor.end_idx))
            if target_end < target_start:
                continue
            delta = (
                base_scores[:, target_start : target_end + 1]
                - masked_scores[:, target_start : target_end + 1]
            ).abs().mean()
            deltas[(edge.source_id, edge.target_id)] = float(
                torch.nan_to_num(delta, nan=0.0, posinf=1.0, neginf=0.0).item()
            )
    return deltas


def build_anchor_dependency_graph(
    anchors: list[AnchorRecord],
    *,
    confirm_threshold: float,
    dependency_threshold: float = 0.55,
    confirm_slope: float = 0.10,
    similarity_weight: float = 0.55,
    temporal_weight: float = 0.20,
    support_weight: float = 0.15,
    viability_weight: float = 0.10,
    temporal_window: float = 16.0,
    max_predecessors: int = 4,
    counterfactual_top_edges: int = 0,
    future_scorer: FutureInfluenceScorer | None = None,
    hidden: torch.Tensor | None = None,
    input_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    output_projection: torch.nn.Module | None = None,
    future_window: int = 16,
) -> dict[str, Any]:
    """Build a dependency graph over anchors and summarise its pressure.

    Steps:
      1. Sort anchors by ``(start_idx, end_idx, id)``.
      2. For every ordered pair where the source ends strictly before the
         target starts, compute an approximate score and keep pairs scoring at
         least ``dependency_threshold``.
      3. Optionally refine the top ``counterfactual_top_edges`` edges with a
         counterfactual delta; a refined edge's final score is the average of
         its approximate score and the delta clamped to [0, 1].
      4. Keep, per target, the ``max(1, max_predecessors)`` best incoming edges.
      5. Derive per-node validity and pressure from the retained edges.

    Returns:
        A dict with keys ``edges``, ``nodes``, ``graph_pressure`` (max node
        pressure), ``current_graph_pressure`` (pressure of the anchor with the
        largest ``(end_idx, start_idx, id)``), ``current_anchor_id``,
        ``edge_count``, ``broken_anchor_count`` and ``mean_validity``. For an
        empty ``anchors`` list a fixed result with no edges or nodes is
        returned.
    """
    if not anchors:
        return {
            "edges": [],
            "nodes": [],
            "graph_pressure": 0.0,
            "current_graph_pressure": 0.0,
            "current_anchor_id": None,
            "edge_count": 0,
            "broken_anchor_count": 0,
            "mean_validity": 1.0,
        }

    sorted_anchors = sorted(
        anchors,
        key=lambda item: (int(item.start_idx), int(item.end_idx), int(item.id)),
    )

    # --- candidate edges -------------------------------------------------
    # NOTE(review): edge scoring keeps the helper's default sigmoid slope
    # (0.10) rather than ``confirm_slope``, matching prior behaviour; confirm
    # whether the two are meant to be tied together.
    min_score = float(dependency_threshold)
    candidate_edges: list[AnchorDependencyEdge] = []
    for source in sorted_anchors:
        for target in sorted_anchors:
            # Only strictly earlier, distinct anchors can be predecessors.
            if int(source.end_idx) >= int(target.start_idx) or source.id == target.id:
                continue
            approx_score, parts = _approx_dependency_score(
                source,
                target,
                confirm_threshold=confirm_threshold,
                similarity_weight=similarity_weight,
                temporal_weight=temporal_weight,
                support_weight=support_weight,
                viability_weight=viability_weight,
                temporal_window=temporal_window,
            )
            if approx_score < min_score:
                continue
            candidate_edges.append(
                AnchorDependencyEdge(
                    source_id=source.id,
                    target_id=target.id,
                    approx_score=float(approx_score),
                    final_score=float(approx_score),
                    similarity=float(parts["similarity"]),
                    temporal_prior=float(parts["temporal_prior"]),
                    support_prior=float(parts["support_prior"]),
                    viability_prior=float(parts["viability_prior"]),
                )
            )

    # --- optional counterfactual refinement ------------------------------
    deltas = _compute_counterfactual_scores(
        anchors=sorted_anchors,
        candidate_edges=candidate_edges,
        hidden=hidden,
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_projection=output_projection,
        future_scorer=future_scorer,
        future_window=future_window,
        max_edges=counterfactual_top_edges,
    )
    for edge in candidate_edges:
        edge_key = (edge.source_id, edge.target_id)
        if edge_key not in deltas:
            continue
        delta = float(deltas[edge_key])
        edge.refined_delta = delta
        edge.is_refined = True
        edge.final_score = 0.5 * float(edge.approx_score) + 0.5 * min(1.0, max(0.0, delta))

    # --- keep the strongest predecessors per target ----------------------
    # One pass groups edges by target (keys inserted in sorted-anchor order),
    # instead of filtering the full edge list once per anchor.
    incoming_by_target: dict[int, list[AnchorDependencyEdge]] = {
        anchor.id: [] for anchor in sorted_anchors
    }
    for edge in candidate_edges:
        incoming_by_target[edge.target_id].append(edge)

    keep = max(1, int(max_predecessors))
    edges_by_target: dict[int, list[AnchorDependencyEdge]] = {
        target_id: sorted(incoming, key=lambda item: item.final_score, reverse=True)[:keep]
        for target_id, incoming in incoming_by_target.items()
    }

    # --- node statistics -------------------------------------------------
    # Each anchor's soft confirmation is computed once and reused both for
    # the node itself and wherever it appears as a predecessor.
    soft_confirmations = [
        _sigmoid_unit(_to_float(anchor.support), confirm_threshold, confirm_slope)
        for anchor in sorted_anchors
    ]
    soft_by_id = {
        anchor.id: soft for anchor, soft in zip(sorted_anchors, soft_confirmations)
    }

    nodes: list[AnchorDependencyNode] = []
    node_by_id: dict[int, AnchorDependencyNode] = {}
    for anchor, soft_confirmation in zip(sorted_anchors, soft_confirmations):
        predecessors = edges_by_target[anchor.id]
        if predecessors:
            total = sum(edge.final_score for edge in predecessors)
            weighted_confirmation = sum(
                edge.final_score * soft_by_id[edge.source_id] for edge in predecessors
            ) / max(total, 1e-6)
        else:
            # No predecessors: nothing can invalidate this anchor.
            weighted_confirmation = 1.0

        broken_predecessors = [
            edge.source_id for edge in predecessors if soft_by_id[edge.source_id] < 0.5
        ]
        node_pressure = 1.0 - float(weighted_confirmation) * min(
            1.0, max(0.0, _to_float(anchor.viability))
        )
        node = AnchorDependencyNode(
            anchor_id=anchor.id,
            validity=float(weighted_confirmation),
            soft_confirmation=float(soft_confirmation),
            node_pressure=float(min(1.0, max(0.0, node_pressure))),
            predecessor_ids=[edge.source_id for edge in predecessors],
            broken_predecessor_ids=broken_predecessors,
        )
        nodes.append(node)
        node_by_id[anchor.id] = node

    # The "current" anchor is the one ending last (ties: later start, then id).
    current_anchor = max(
        sorted_anchors,
        key=lambda item: (int(item.end_idx), int(item.start_idx), int(item.id)),
    )
    current_graph_pressure = float(node_by_id[current_anchor.id].node_pressure)
    graph_pressure = max((node.node_pressure for node in nodes), default=0.0)

    return {
        "edges": [
            {
                "source_id": edge.source_id,
                "target_id": edge.target_id,
                "approx_score": edge.approx_score,
                "final_score": edge.final_score,
                "similarity": edge.similarity,
                "temporal_prior": edge.temporal_prior,
                "support_prior": edge.support_prior,
                "viability_prior": edge.viability_prior,
                "refined_delta": edge.refined_delta,
                "is_refined": edge.is_refined,
            }
            for target_edges in edges_by_target.values()
            for edge in target_edges
        ],
        "nodes": [
            {
                "anchor_id": node.anchor_id,
                "validity": node.validity,
                "soft_confirmation": node.soft_confirmation,
                "node_pressure": node.node_pressure,
                "predecessor_ids": node.predecessor_ids,
                "broken_predecessor_ids": node.broken_predecessor_ids,
            }
            for node in nodes
        ],
        "graph_pressure": float(graph_pressure),
        "current_graph_pressure": float(current_graph_pressure),
        "current_anchor_id": int(current_anchor.id),
        "edge_count": int(sum(len(edges) for edges in edges_by_target.values())),
        "broken_anchor_count": int(sum(1 for node in nodes if node.broken_predecessor_ids)),
        "mean_validity": float(sum(node.validity for node in nodes) / max(len(nodes), 1)),
    }