# Source page header (not code): shreyas-joshi's picture
# Commit message: Add Phase 06 Plan for Adaptive Judging and Edge Intelligence; Create initial project outline for GraphReview RL Environment
# Commit: 86c3e08
from __future__ import annotations
from pathlib import Path
from typing import Literal
import networkx as nx
from sqlmodel import Session, select
from db.schema import ModuleEdge, ModuleNode
from db.store import Store
class GraphManager:
    """Load and query dependency graph state from SQLite.

    Module nodes and edges are read through ``self.store`` and assembled into
    a ``networkx.DiGraph`` that is cached on the instance. Betweenness
    centrality is cached separately and is dropped whenever the graph is
    rebuilt or ``invalidate_cache`` is called.
    """

    def __init__(self, source_root: str | Path, db_path: str | Path | None = None) -> None:
        """Create a manager for one source tree.

        Args:
            source_root: Root directory of the analysed sources; stored as a
                resolved absolute path string.
            db_path: Optional database location, forwarded to ``Store``.
        """
        self.source_root = str(Path(source_root).resolve())
        self.store = Store(source_root=self.source_root, db_path=db_path)
        self._graph_cache: nx.DiGraph | None = None
        self._centrality_cache: dict[str, float] | None = None

    def _build_graph(self) -> nx.DiGraph:
        """Read all nodes and edges for this source root into a new graph."""
        graph = nx.DiGraph()
        root = self.store.config.source_root
        with Session(self.store.engine) as session:
            nodes = session.exec(
                select(ModuleNode).where(ModuleNode.source_root == root)
            ).all()
            edges = session.exec(
                select(ModuleEdge).where(ModuleEdge.source_root == root)
            ).all()
            for node in nodes:
                graph.add_node(
                    node.module_id,
                    name=node.name,
                    raw_code=node.raw_code,
                    ast_summary=node.ast_summary,
                    summary=node.summary or "",
                    linter_flags=node.linter_flags,
                    parent_module_id=node.parent_module_id,
                    review_status=node.review_status.value,
                    review_summary=node.review_summary or "",
                    is_chunk=node.is_chunk,
                )
            for edge in edges:
                graph.add_edge(
                    edge.source_module_id,
                    edge.target_module_id,
                    edge_type=edge.edge_type.value,
                    import_line=edge.import_line,
                    weight=edge.weight,
                    connection_summary=edge.connection_summary,
                )
        return graph

    def _cached_graph(self) -> nx.DiGraph:
        """Return the shared cached graph, building it on first use.

        Unlike ``load_graph`` this does not copy, so it must only be used by
        code in this class that reads the graph without mutating it.
        """
        if self._graph_cache is None:
            self._graph_cache = self._build_graph()
            self._centrality_cache = None
        return self._graph_cache

    def load_graph(self, refresh: bool = False) -> nx.DiGraph:
        """Return a copy of the dependency graph.

        Args:
            refresh: When true, rebuild the graph from the database even if a
                cached graph exists.

        Returns:
            A copy of the cached graph, so callers may modify the returned
            graph's structure without touching the cache.
        """
        if refresh or self._graph_cache is None:
            self._graph_cache = self._build_graph()
            # Centrality was computed for the previous graph; drop it.
            self._centrality_cache = None
        return self._graph_cache.copy()

    def invalidate_cache(self) -> None:
        """Drop the cached graph and the cached centrality scores."""
        self._graph_cache = None
        self._centrality_cache = None

    def get_node(self, module_id: str) -> dict[str, object]:
        """Return a copy of the attribute mapping stored for ``module_id``.

        Raises:
            ValueError: If ``module_id`` is not a node of the graph.
        """
        graph = self._cached_graph()
        if module_id not in graph:
            raise ValueError(f"Unknown module_id: {module_id}")
        return dict(graph.nodes[module_id])

    def get_neighbors(
        self,
        module_id: str,
        direction: Literal["out", "in", "both"] = "both",
        limit: int | None = None,
    ) -> list[str]:
        """Return the sorted neighbor ids of ``module_id``.

        Args:
            module_id: Node whose neighbors are requested.
            direction: ``"out"`` for successors, ``"in"`` for predecessors;
                any other value yields the union of both.
            limit: Maximum number of ids to return; ``None`` means no limit
                and negative values are treated as zero.

        Raises:
            ValueError: If ``module_id`` is not a node of the graph.
        """
        graph = self._cached_graph()
        if module_id not in graph:
            raise ValueError(f"Unknown module_id: {module_id}")

        if direction == "out":
            neighbors = set(graph.successors(module_id))
        elif direction == "in":
            neighbors = set(graph.predecessors(module_id))
        else:
            neighbors = set(graph.successors(module_id))
            neighbors.update(graph.predecessors(module_id))

        ordered = sorted(neighbors)
        if limit is None:
            return ordered
        return ordered[: max(limit, 0)]

    def resolve_module_id(self, module_id: str) -> str:
        """Map a loosely written module reference onto an existing node id.

        The exact id is tried first. Otherwise the stripped text and, if it
        ends in ``.py``, the text without that suffix are tried, each as
        written and with ``/`` and/or ``\\`` replaced by ``.``. Spellings are
        tried in a fixed order, so the result is deterministic when several
        of them exist in the graph. A case-insensitive match is the last
        resort.

        Raises:
            ValueError: If no spelling matches a node.
        """
        graph = self._cached_graph()
        if module_id in graph:
            return module_id

        candidate = module_id.strip()
        stems = [candidate]
        if candidate.endswith(".py"):
            stems.append(candidate[:-3])

        spellings: list[str] = []
        for stem in stems:
            spellings.extend(
                (
                    stem,
                    stem.replace("/", "."),
                    stem.replace("\\", "."),
                    # Covers ids that mix both separator styles.
                    stem.replace("/", ".").replace("\\", "."),
                )
            )
        # dict.fromkeys de-duplicates while keeping the priority order above.
        variants = list(dict.fromkeys(spellings))

        for variant in variants:
            if variant in graph:
                return variant

        lower_lookup = {str(node).lower(): str(node) for node in graph.nodes()}
        for variant in variants:
            resolved = lower_lookup.get(variant.lower())
            if resolved is not None:
                return resolved
        raise ValueError(f"Unknown module_id: {module_id}")

    def centrality(self) -> dict[str, float]:
        """Return normalized betweenness centrality per node.

        Scores are cached until the graph is rebuilt or the cache is
        invalidated. A fresh dict is returned on every call, so callers may
        modify it without affecting the cache.
        """
        if self._centrality_cache is not None:
            return dict(self._centrality_cache)
        graph = self._cached_graph()
        if graph.number_of_nodes() == 0:
            self._centrality_cache = {}
            return {}
        self._centrality_cache = nx.betweenness_centrality(graph, normalized=True)
        return dict(self._centrality_cache)

    def traversal_order(self) -> list[str]:
        """Return a deterministic, leaf-first ordering of all nodes.

        Acyclic graphs are returned in reversed lexicographical topological
        order, so a node appears after every node it has an edge to. Graphs
        with cycles fall back to sorting by out-degree, then betweenness
        centrality, then node name.
        """
        graph = self._cached_graph()
        if graph.number_of_nodes() == 0:
            return []

        if nx.is_directed_acyclic_graph(graph):
            # Every node has a unique topological rank, so secondary sort
            # keys could never reorder this list; return it directly.
            # NOTE(review): centrality therefore does not influence the
            # order of acyclic graphs - confirm that this is intended.
            return list(reversed(list(nx.lexicographical_topological_sort(graph))))

        # Stable fallback for cyclic graphs.
        centrality = self.centrality()
        return sorted(
            graph.nodes(),
            key=lambda node: (
                int(graph.out_degree(node)),
                float(centrality.get(node, 0.0)),
                str(node),
            ),
        )