Spaces:
Sleeping
Sleeping
File size: 5,829 Bytes
899a7c7 b196357 899a7c7 86c3e08 899a7c7 b196357 899a7c7 bc02c39 899a7c7 b196357 899a7c7 b196357 899a7c7 b196357 899a7c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | from __future__ import annotations
from pathlib import Path
from typing import Literal
import networkx as nx
from sqlmodel import Session, select
from db.schema import ModuleEdge, ModuleNode
from db.store import Store
class GraphManager:
"""Load and query dependency graph state from SQLite."""
def __init__(self, source_root: str | Path, db_path: str | Path | None = None) -> None:
self.source_root = str(Path(source_root).resolve())
self.store = Store(source_root=self.source_root, db_path=db_path)
self._graph_cache: nx.DiGraph | None = None
self._centrality_cache: dict[str, float] | None = None
def load_graph(self, refresh: bool = False) -> nx.DiGraph:
if self._graph_cache is not None and not refresh:
return self._graph_cache.copy()
graph = nx.DiGraph()
with Session(self.store.engine) as session:
nodes = list(
session.exec(
select(ModuleNode).where(ModuleNode.source_root == self.store.config.source_root)
).all()
)
edges = list(
session.exec(
select(ModuleEdge).where(ModuleEdge.source_root == self.store.config.source_root)
).all()
)
for node in nodes:
graph.add_node(
node.module_id,
name=node.name,
raw_code=node.raw_code,
ast_summary=node.ast_summary,
summary=node.summary or "",
linter_flags=node.linter_flags,
parent_module_id=node.parent_module_id,
review_status=node.review_status.value,
review_summary=node.review_summary or "",
is_chunk=node.is_chunk,
)
for edge in edges:
graph.add_edge(
edge.source_module_id,
edge.target_module_id,
edge_type=edge.edge_type.value,
import_line=edge.import_line,
weight=edge.weight,
connection_summary=edge.connection_summary,
)
self._graph_cache = graph
self._centrality_cache = None
return graph.copy()
def invalidate_cache(self) -> None:
self._graph_cache = None
self._centrality_cache = None
def get_node(self, module_id: str) -> dict[str, object]:
graph = self.load_graph()
if module_id not in graph:
raise ValueError(f"Unknown module_id: {module_id}")
return dict(graph.nodes[module_id])
def get_neighbors(
self,
module_id: str,
direction: Literal["out", "in", "both"] = "both",
limit: int | None = None,
) -> list[str]:
graph = self.load_graph()
if module_id not in graph:
raise ValueError(f"Unknown module_id: {module_id}")
if direction == "out":
neighbors = set(graph.successors(module_id))
elif direction == "in":
neighbors = set(graph.predecessors(module_id))
else:
neighbors = set(graph.successors(module_id))
neighbors.update(graph.predecessors(module_id))
ordered = sorted(neighbors)
if limit is None:
return ordered
return ordered[: max(limit, 0)]
def resolve_module_id(self, module_id: str) -> str:
graph = self.load_graph()
if module_id in graph:
return module_id
candidate = module_id.strip()
variants = {
candidate,
candidate.replace("/", "."),
candidate.replace("\\", "."),
}
if candidate.endswith(".py"):
without_suffix = candidate[:-3]
variants.add(without_suffix)
variants.add(without_suffix.replace("/", "."))
variants.add(without_suffix.replace("\\", "."))
for variant in variants:
if variant in graph:
return variant
lower_lookup = {str(node).lower(): str(node) for node in graph.nodes()}
for variant in variants:
resolved = lower_lookup.get(variant.lower())
if resolved:
return resolved
raise ValueError(f"Unknown module_id: {module_id}")
def centrality(self) -> dict[str, float]:
if self._centrality_cache is not None:
return dict(self._centrality_cache)
graph = self.load_graph()
if graph.number_of_nodes() == 0:
self._centrality_cache = {}
return {}
self._centrality_cache = nx.betweenness_centrality(graph, normalized=True)
return dict(self._centrality_cache)
def traversal_order(self) -> list[str]:
"""
Return a deterministic, leaf-first traversal where high-centrality nodes are later.
"""
graph = self.load_graph()
if graph.number_of_nodes() == 0:
return []
centrality = self.centrality()
# For DAGs, reverse topological order visits leaves first.
if nx.is_directed_acyclic_graph(graph):
topo_reversed = list(reversed(list(nx.lexicographical_topological_sort(graph))))
topo_rank = {node: idx for idx, node in enumerate(topo_reversed)}
return sorted(
graph.nodes(),
key=lambda node: (
int(topo_rank.get(node, 0)),
float(centrality.get(node, 0.0)),
str(node),
),
)
# Stable fallback for cyclic graphs.
return sorted(
graph.nodes(),
key=lambda node: (
int(graph.out_degree(node)),
float(centrality.get(node, 0.0)),
str(node),
),
)
|