"""Generate GraphReview report artifacts (JSON, Markdown, HTML graph) from stored review data."""

from __future__ import annotations

import ast
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, ConfigDict
from sqlmodel import Session, select

from db.schema import LinterFinding, ModuleEdge, ModuleNode, ReviewAnnotation
from db.store import Store
from graph.graph_manager import GraphManager
from visualizer.pyvis_renderer import render_graph_html


class ReviewQualityMetrics(BaseModel):
    """Aggregate review-quality numbers produced by ``_compute_metrics``."""

    model_config = ConfigDict(strict=True, extra="forbid")

    true_positives: int
    false_positives: int
    false_negatives: int
    precision: float
    recall: float
    f1: float
    severity_weighted_coverage: float
    security_coverage: float
    dependency_attribution_validity: float
    stage_coverage: float
    llm_stage_catch_rate: float
    llm_any_match_rate: float
    consistency: float
    confidence_score: float


class GeneratedArtifacts(BaseModel):
    """Paths and counters describing the files written by ``generate_phase5_outputs``."""

    model_config = ConfigDict(strict=True, extra="forbid")

    markdown_path: str
    json_path: str
    html_path: str
    module_count: int
    edge_count: int
    annotation_count: int
    confidence_score: float


@dataclass(frozen=True)
class _Context:
    """Bundle of the graph manager, the store and the module ids in scope."""

    graph_manager: GraphManager
    store: Store
    module_ids: list[str]


def _safe_div(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is falsy."""
    return numerator / denominator if denominator else 0.0


def _stage_of(task_id: object) -> str | None:
    """Return the text before the first ``"_"`` of *task_id*, or ``None`` if it is falsy."""
    if not task_id:
        return None
    return str(task_id).split("_", 1)[0]


def _parse_note_payload(note: str | None) -> dict[str, object]:
    """Decode an annotation note as a JSON object.

    Returns an empty dict when the note is ``None``, blank, not valid JSON,
    or valid JSON that is not an object.
    """
    if note is None:
        # A missing note carries no payload; previously this raised AttributeError.
        return {}
    text = note.strip()
    if not text:
        return {}
    try:
        loaded = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return loaded if isinstance(loaded, dict) else {}


def _resolve_module_scope(graph_manager: GraphManager, module_filter: list[str] | None, hops: int) -> list[str]:
    """Return the sorted module ids that are in scope for the report.

    With no filter, every node of the loaded graph is in scope. Otherwise the
    filter entries are resolved through ``graph_manager.resolve_module_id`` and
    expanded by up to *hops* steps along both successors and predecessors.
    Negative *hops* is treated as zero.
    """
    graph = graph_manager.load_graph()
    if not module_filter:
        return sorted(str(node) for node in graph.nodes())

    seeds: set[str] = set()
    for module_id in module_filter:
        seeds.add(graph_manager.resolve_module_id(module_id))

    related: set[str] = set(seeds)
    frontier: set[str] = set(seeds)
    for _ in range(max(hops, 0)):
        if not frontier:
            # Nothing new was reached on the previous hop; further hops cannot add anything.
            break
        next_frontier: set[str] = set()
        for module_id in frontier:
            next_frontier.update(str(item) for item in graph.successors(module_id))
            next_frontier.update(str(item) for item in graph.predecessors(module_id))
        # Only modules not seen before form the next frontier.
        next_frontier -= related
        related.update(next_frontier)
        frontier = next_frontier
    return sorted(related)


def _load_context(
    source_root: str | Path,
    db_path: str | Path | None,
    module_filter: list[str] | None,
    hops: int,
) -> _Context:
    """Create the graph manager and store for *source_root* and resolve the module scope."""
    graph_manager = GraphManager(source_root=source_root, db_path=db_path)
    store = Store(source_root=str(source_root), db_path=db_path)
    module_ids = _resolve_module_scope(graph_manager, module_filter, hops)
    return _Context(graph_manager=graph_manager, store=store, module_ids=module_ids)


def _compatible_finding_ids(action_type: str, findings: list[LinterFinding]) -> list[int]:
    """Return ids of findings that a flag of *action_type* could plausibly refer to.

    * ``FLAG_SECURITY`` matches findings whose tool is ``"bandit"``.
    * ``FLAG_STYLE`` matches non-bandit findings with ``"low"`` severity.
    * ``FLAG_BUG`` matches findings with ``"medium"`` or ``"high"`` severity.

    Findings without an id are skipped; input order is preserved.
    """
    ids: list[int] = []
    for finding in findings:
        if finding.id is None:
            continue
        if action_type == "FLAG_SECURITY" and finding.tool == "bandit":
            ids.append(finding.id)
        elif action_type == "FLAG_STYLE" and finding.tool != "bandit" and finding.severity.value == "low":
            ids.append(finding.id)
        elif action_type == "FLAG_BUG" and finding.severity.value in {"medium", "high"}:
            ids.append(finding.id)
    return ids


def _compute_metrics(
    module_ids: list[str],
    graph_manager: GraphManager,
    findings_by_module: dict[str, list[LinterFinding]],
    annotations_by_module: dict[str, list[ReviewAnnotation]],
) -> ReviewQualityMetrics:
    """Score review annotations against linter findings for the modules in scope.

    For every module, each ``FLAG_STYLE`` / ``FLAG_BUG`` / ``FLAG_SECURITY``
    annotation is matched to at most one finding: first through the
    ``matched_finding_id`` key of its JSON note, otherwise through the first
    still-unmatched compatible finding. Matched flags count as true positives,
    unmatched flags as false positives, and unmatched findings as false
    negatives. The remaining metrics and the weighted ``confidence_score``
    (clamped to ``[0, 1]``) are derived from those tallies.
    """
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    severity_weight_total = 0.0
    severity_weight_matched = 0.0
    security_total = 0
    security_matched = 0
    attribution_total = 0
    attribution_valid = 0
    stage_set: set[str] = set()
    total_findings = 0
    llm_stage_caught = 0
    llm_any_matched = 0
    contradictory_modules = 0

    severity_weight = {"low": 1.0, "medium": 2.0, "high": 3.0}
    graph = graph_manager.load_graph()

    for module_id in module_ids:
        findings = findings_by_module.get(module_id, [])
        annotations = sorted(
            annotations_by_module.get(module_id, []),
            key=lambda item: (item.created_at, item.step_number),
        )
        finding_by_id = {finding.id: finding for finding in findings if finding.id is not None}
        matched_finding_ids: set[int] = set()
        llm_matched_ids: set[int] = set()
        finding_stage_map: dict[int, str] = {}
        terminal_actions: set[str] = set()

        for annotation in annotations:
            stage = _stage_of(annotation.task_id)
            if stage is not None:
                stage_set.add(stage)
            if annotation.action_type in {"APPROVE", "REQUEST_CHANGES"}:
                terminal_actions.add(annotation.action_type)
            if annotation.action_type == "FLAG_DEPENDENCY_ISSUE" and annotation.attributed_to:
                attribution_total += 1
                # An attribution is valid when the named module is a direct
                # graph neighbour of this module, in either direction.
                if annotation.attributed_to in graph and (
                    graph.has_edge(module_id, annotation.attributed_to)
                    or graph.has_edge(annotation.attributed_to, module_id)
                ):
                    attribution_valid += 1
            if annotation.action_type not in {"FLAG_STYLE", "FLAG_BUG", "FLAG_SECURITY"}:
                continue

            payload = _parse_note_payload(annotation.note)
            matched_id = payload.get("matched_finding_id")
            if stage == "hard" and isinstance(matched_id, int):
                # Recorded even if the id is unknown or already matched; only
                # ids of real findings are counted further below.
                llm_matched_ids.add(matched_id)

            if (
                isinstance(matched_id, int)
                and matched_id in finding_by_id
                and matched_id not in matched_finding_ids
            ):
                matched_finding_ids.add(matched_id)
                if stage is not None:
                    finding_stage_map[matched_id] = stage
                true_positives += 1
                continue

            # No usable explicit match: fall back to the first compatible
            # finding that has not been claimed yet.
            compatible = [
                item
                for item in _compatible_finding_ids(annotation.action_type, findings)
                if item not in matched_finding_ids
            ]
            if compatible:
                matched_finding_ids.add(compatible[0])
                if stage is not None:
                    finding_stage_map[compatible[0]] = stage
                true_positives += 1
            else:
                false_positives += 1

        # A module that received both APPROVE and REQUEST_CHANGES is contradictory.
        if len(terminal_actions) > 1:
            contradictory_modules += 1

        false_negatives += max(len(findings) - len(matched_finding_ids), 0)
        for finding in findings:
            total_findings += 1
            weight = severity_weight.get(finding.severity.value, 1.0)
            severity_weight_total += weight
            if finding.id in matched_finding_ids:
                severity_weight_matched += weight
            if finding.id is not None and finding_stage_map.get(finding.id) == "hard":
                llm_stage_caught += 1
            if finding.id in llm_matched_ids:
                llm_any_matched += 1
            if finding.tool == "bandit":
                security_total += 1
                if finding.id in matched_finding_ids:
                    security_matched += 1

    precision = _safe_div(true_positives, true_positives + false_positives)
    recall = _safe_div(true_positives, true_positives + false_negatives)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    severity_coverage = _safe_div(severity_weight_matched, severity_weight_total)
    security_coverage = _safe_div(security_matched, security_total)
    attribution_validity = _safe_div(attribution_valid, attribution_total)
    # The distinct stage prefixes seen are measured against a fixed count of 3.
    stage_coverage = _safe_div(len(stage_set), 3)
    llm_stage_catch_rate = _safe_div(llm_stage_caught, total_findings)
    llm_any_match_rate = _safe_div(llm_any_matched, total_findings)
    consistency = 1.0 - _safe_div(contradictory_modules, len(module_ids))

    confidence_score = (
        0.35 * f1
        + 0.2 * severity_coverage
        + 0.15 * security_coverage
        + 0.15 * attribution_validity
        + 0.1 * stage_coverage
        + 0.03 * llm_any_match_rate
        + 0.02 * consistency
    )
    confidence_score = max(0.0, min(1.0, confidence_score))

    return ReviewQualityMetrics(
        true_positives=true_positives,
        false_positives=false_positives,
        false_negatives=false_negatives,
        precision=precision,
        recall=recall,
        f1=f1,
        severity_weighted_coverage=severity_coverage,
        security_coverage=security_coverage,
        dependency_attribution_validity=attribution_validity,
        stage_coverage=stage_coverage,
        llm_stage_catch_rate=llm_stage_catch_rate,
        llm_any_match_rate=llm_any_match_rate,
        consistency=consistency,
        confidence_score=confidence_score,
    )


def _extract_module_shape(raw_code: str) -> str:
    """Summarise the top-level functions, async functions and classes of *raw_code*.

    At most six names are listed per category. Source that cannot be parsed
    yields a fixed fallback message instead of raising.
    """
    try:
        tree = ast.parse(raw_code)
    except (SyntaxError, ValueError):
        # ValueError: older interpreters reject source containing null bytes
        # with ValueError rather than SyntaxError.
        return "Could not parse AST for this module."

    functions = [node.name for node in tree.body if isinstance(node, ast.FunctionDef)]
    async_functions = [node.name for node in tree.body if isinstance(node, ast.AsyncFunctionDef)]
    classes = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]

    parts: list[str] = []
    if functions:
        parts.append(f"functions={', '.join(functions[:6])}")
    if async_functions:
        parts.append(f"async_functions={', '.join(async_functions[:6])}")
    if classes:
        parts.append(f"classes={', '.join(classes[:6])}")
    if not parts:
        return "No top-level functions/classes; likely constants, helpers, or script-style module."
    return " | ".join(parts)


def _build_node_title(
    module: ModuleNode,
    findings: list[LinterFinding],
    annotations: list[ReviewAnnotation],
    status: str,
    confidence_score: float,
) -> str:
    """Build the multi-line title text attached to a module node in the HTML graph.

    The title lists the module id, status, confidence, a summary truncated to
    260 characters, the module shape, the number of bandit findings and the
    last three annotations of *annotations* (in the order given).
    """
    security_findings = [finding for finding in findings if finding.tool == "bandit"]
    latest_lines = [
        f"#{item.step_number} {item.action_type}: {item.reward_given:.2f}"
        for item in annotations[-3:]
    ]
    # Fall back to an empty summary instead of failing when neither field is set.
    summary_text = (module.summary or module.ast_summary or "").replace("\n", " ")
    return (
        f"module: {module.module_id}\n"
        f"status: {status}\n"
        f"confidence: {confidence_score:.2f}\n"
        f"summary: {summary_text[:260]}\n"
        f"shape: {_extract_module_shape(module.raw_code)}\n"
        f"security_findings: {len(security_findings)}\n"
        f"latest_reviews: {' | '.join(latest_lines) if latest_lines else 'none'}"
    )


def _derive_status(node: ModuleNode, annotations: list[ReviewAnnotation]) -> str:
    """Return the display status of *node*.

    The last annotation wins when it is ``APPROVE`` or ``REQUEST_CHANGES``;
    otherwise the node's stored ``review_status`` value is used.
    """
    if not annotations:
        return node.review_status.value
    last = annotations[-1].action_type
    if last == "APPROVE":
        return "approved"
    if last == "REQUEST_CHANGES":
        return "changes_requested"
    return node.review_status.value


def _build_json_payload(
    *,
    source_root: str,
    module_ids: list[str],
    nodes: list[ModuleNode],
    edges: list[ModuleEdge],
    findings_by_module: dict[str, list[LinterFinding]],
    annotations_by_module: dict[str, list[ReviewAnnotation]],
    metrics: ReviewQualityMetrics,
    episode_id: str | None,
) -> dict[str, object]:
    """Assemble the JSON-serialisable report payload.

    Nodes are emitted sorted by module id and edges sorted by
    ``(source, target, import_line)``. Each module's annotations are used in
    the order supplied by the caller; that order determines
    ``primary_caught_stage`` and the first stage recorded per finding.
    """
    node_payload = []
    for node in sorted(nodes, key=lambda item: item.module_id):
        annotations = annotations_by_module.get(node.module_id, [])
        module_findings = findings_by_module.get(node.module_id, [])

        finding_stage_map: dict[int, str] = {}
        finding_llm_verified_map: dict[int, bool] = {}
        for item in annotations:
            payload = _parse_note_payload(item.note)
            matched_id = payload.get("matched_finding_id")
            stage = _stage_of(item.task_id)
            if isinstance(matched_id, int) and stage is not None:
                # setdefault keeps the first stage that referenced the finding.
                finding_stage_map.setdefault(matched_id, stage)
                if stage == "hard":
                    finding_llm_verified_map[matched_id] = True

        status = _derive_status(node, annotations)
        node_payload.append(
            {
                "module_id": node.module_id,
                "name": node.name,
                "status": status,
                "summary": node.summary or node.ast_summary,
                "raw_code": node.raw_code,
                "module_shape": _extract_module_shape(node.raw_code),
                "caught_stages": sorted(
                    {
                        str(item.task_id).split("_", 1)[0]
                        for item in annotations
                        if item.task_id
                    }
                ),
                "primary_caught_stage": (_stage_of(annotations[0].task_id) if annotations else None),
                "security_findings": [
                    {
                        "line": finding.line,
                        "code": finding.code,
                        "severity": finding.severity.value,
                        "message": finding.message,
                    }
                    for finding in module_findings
                    if finding.tool == "bandit"
                ],
                "linter_findings": [
                    {
                        "id": finding.id,
                        "tool": finding.tool,
                        "line": finding.line,
                        "severity": finding.severity.value,
                        "code": finding.code,
                        "message": finding.message,
                        "caught_stage": (finding_stage_map.get(finding.id) if finding.id is not None else None),
                        "llm_first_catch": (
                            finding_stage_map.get(finding.id) == "hard" if finding.id is not None else False
                        ),
                        "llm_verified": (
                            bool(finding_llm_verified_map.get(finding.id, False))
                            if finding.id is not None
                            else False
                        ),
                    }
                    for finding in module_findings
                ],
                "reviews": [
                    {
                        "step_number": item.step_number,
                        "task_id": item.task_id,
                        "caught_stage": _stage_of(item.task_id),
                        "action_type": item.action_type,
                        "reward_given": item.reward_given,
                        "attributed_to": item.attributed_to,
                        "is_amendment": item.is_amendment,
                        "note": _parse_note_payload(item.note),
                        "created_at": item.created_at.isoformat(),
                    }
                    for item in annotations
                ],
            }
        )

    def _edge_sort_key(item: ModuleEdge) -> tuple[object, ...]:
        # NOTE(review): assumes import_line may be NULL for some edges - confirm
        # against the schema. Edges with a missing import_line sort last instead
        # of raising TypeError; ordering of all other edges is unchanged.
        return (item.source_module_id, item.target_module_id, item.import_line is None, item.import_line)

    edge_payload = [
        {
            "source": edge.source_module_id,
            "target": edge.target_module_id,
            "edge_type": edge.edge_type.value,
            "weight": edge.weight,
            "import_line": edge.import_line,
            "connection_summary": edge.connection_summary,
        }
        for edge in sorted(edges, key=_edge_sort_key)
    ]

    return {
        "report_schema_version": "1.0.0",
        "source_root": source_root,
        "episode_id": episode_id,
        "scope_modules": module_ids,
        "metrics": metrics.model_dump(),
        "nodes": node_payload,
        "edges": edge_payload,
        "rl_integrity": {
            "trajectory_reconstructable": True,
            "reward_causality_tracked": True,
            "deterministic_replay_notes": "easy/medium deterministic by construction; hard uses judge with temperature=0",
        },
    }


def _build_markdown_report(payload: dict[str, object]) -> str:
    """Render the payload built by ``_build_json_payload`` as a Markdown report.

    The report has an executive summary, a security section (modules with
    bandit findings only), a cascade attribution section (reviews that carry
    ``attributed_to``), a per-module review section and a fixed RL integrity
    section. The result ends with exactly one trailing newline.
    """
    metrics = payload["metrics"]
    lines: list[str] = []

    lines.append("# GraphReview Report")
    lines.append("")
    lines.append("## Executive Summary")
    lines.append(f"- Source root: {payload['source_root']}")
    lines.append(f"- Episode id: {payload.get('episode_id') or 'all'}")
    lines.append(f"- Modules in scope: {len(payload['scope_modules'])}")
    lines.append(f"- Confidence score: {metrics['confidence_score']:.3f}")
    lines.append(f"- Precision: {metrics['precision']:.3f} | Recall: {metrics['recall']:.3f} | F1: {metrics['f1']:.3f}")
    lines.append(
        "- Security coverage: "
        f"{metrics['security_coverage']:.3f} | Dependency attribution validity: {metrics['dependency_attribution_validity']:.3f}"
    )
    lines.append(f"- Stage coverage: {metrics['stage_coverage']:.3f}")
    lines.append(f"- LLM first-catch rate: {metrics['llm_stage_catch_rate']:.3f}")
    lines.append(f"- LLM any-match rate: {metrics['llm_any_match_rate']:.3f}")
    lines.append("")

    lines.append("## Security Analysis")
    for node in payload["nodes"]:
        security_findings = node["security_findings"]
        if not security_findings:
            continue
        lines.append(f"### {node['module_id']}")
        for finding in security_findings:
            lines.append(
                "- "
                f"[{finding['severity'].upper()}] {finding['code']} line {finding['line']}: {finding['message']}"
            )
        lines.append("")

    lines.append("## Cascade Attribution Summary")
    for node in payload["nodes"]:
        attributions = [review for review in node["reviews"] if review.get("attributed_to")]
        if not attributions:
            continue
        lines.append(f"### {node['module_id']}")
        for item in attributions:
            lines.append(
                "- "
                f"step {item['step_number']} -> attributed_to={item['attributed_to']} "
                f"action={item['action_type']} reward={item['reward_given']:.2f}"
            )
        lines.append("")

    lines.append("## Module Reviews")
    for node in payload["nodes"]:
        lines.append(f"### {node['module_id']}")
        lines.append(f"- Status: {node['status']}")
        lines.append(f"- Summary: {node['summary']}")
        lines.append(f"- Shape: {node['module_shape']}")
        lines.append(f"- Findings: {len(node['linter_findings'])}")
        lines.append(f"- Reviews: {len(node['reviews'])}")
        if node["reviews"]:
            latest = node["reviews"][-1]
            lines.append(
                "- Latest review: "
                f"step {latest['step_number']} action={latest['action_type']} reward={latest['reward_given']:.2f}"
            )
        lines.append("")

    lines.append("## RL Integrity")
    lines.append("- Trajectory reconstructable from DB annotations and episode records.")
    lines.append("- Reward causality linked to each persisted action payload.")
    lines.append("- Easy/Medium deterministic replay expected; Hard constrained by temperature=0 judge policy.")
    lines.append("")

    return "\n".join(lines).strip() + "\n"


def generate_phase5_outputs(
    *,
    source_root: str | Path,
    db_path: str | Path | None = None,
    output_dir: str | Path = "outputs",
    episode_id: str | None = None,
    module_filter: list[str] | None = None,
    hops: int = 1,
    report_prefix: str = "graphreview",
) -> GeneratedArtifacts:
    """Write the JSON report, Markdown report and HTML graph for a source root.

    Args:
        source_root: Root of the analysed sources; resolved to an absolute path.
        db_path: Database location forwarded to ``GraphManager`` and ``Store``.
        output_dir: Directory for the generated files; created if missing.
        episode_id: When given, only annotations of this episode are loaded.
        module_filter: Seed modules for the report scope; ``None`` or empty
            means every module of the graph.
        hops: Number of neighbour hops added around the seed modules.
        report_prefix: File-name prefix for the three artifacts.

    Returns:
        The artifact paths together with the module, edge and annotation
        counts and the overall confidence score.
    """
    source_root_text = str(Path(source_root).resolve())
    context = _load_context(source_root_text, db_path, module_filter, hops)
    store_root = context.store.config.source_root

    with Session(context.store.engine) as session:
        nodes = list(
            session.exec(
                select(ModuleNode).where(
                    ModuleNode.source_root == store_root,
                    ModuleNode.module_id.in_(context.module_ids),
                )
            ).all()
        )
        # Only edges with both endpoints inside the scope are reported.
        edges = list(
            session.exec(
                select(ModuleEdge).where(
                    ModuleEdge.source_root == store_root,
                    ModuleEdge.source_module_id.in_(context.module_ids),
                    ModuleEdge.target_module_id.in_(context.module_ids),
                )
            ).all()
        )
        findings = list(
            session.exec(
                select(LinterFinding).where(
                    LinterFinding.source_root == store_root,
                    LinterFinding.module_id.in_(context.module_ids),
                )
            ).all()
        )
        annotation_query = select(ReviewAnnotation).where(
            ReviewAnnotation.source_root == store_root,
            ReviewAnnotation.module_id.in_(context.module_ids),
        )
        if episode_id:
            annotation_query = annotation_query.where(ReviewAnnotation.episode_id == episode_id)
        annotations = list(session.exec(annotation_query).all())

    findings_by_module: dict[str, list[LinterFinding]] = defaultdict(list)
    for finding in findings:
        findings_by_module[finding.module_id].append(finding)

    annotations_by_module: dict[str, list[ReviewAnnotation]] = defaultdict(list)
    for annotation in annotations:
        annotations_by_module[annotation.module_id].append(annotation)
    # Chronological order per module: "latest" and "first" lookups downstream rely on it.
    for module_annotations in annotations_by_module.values():
        module_annotations.sort(key=lambda item: (item.created_at, item.step_number))

    metrics = _compute_metrics(
        module_ids=context.module_ids,
        graph_manager=context.graph_manager,
        findings_by_module=findings_by_module,
        annotations_by_module=annotations_by_module,
    )
    payload = _build_json_payload(
        source_root=source_root_text,
        module_ids=context.module_ids,
        nodes=nodes,
        edges=edges,
        findings_by_module=findings_by_module,
        annotations_by_module=annotations_by_module,
        metrics=metrics,
        episode_id=episode_id,
    )

    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)
    json_path = output_root / f"{report_prefix}_report.json"
    markdown_path = output_root / f"{report_prefix}_report.md"
    html_path = output_root / f"{report_prefix}_graph.html"

    json_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
    markdown_path.write_text(_build_markdown_report(payload), encoding="utf-8")

    node_map = {node.module_id: node for node in nodes}
    graph = context.graph_manager.load_graph()
    centrality = context.graph_manager.centrality()

    html_nodes: list[dict[str, object]] = []
    for module_id in context.module_ids:
        node = node_map.get(module_id)
        if node is None:
            # In scope according to the graph but without a stored node row.
            continue
        module_annotations = annotations_by_module.get(module_id, [])
        status = _derive_status(node, module_annotations)
        html_nodes.append(
            {
                "id": module_id,
                "label": module_id,
                "status": status,
                # Node size grows linearly with the module's centrality value.
                "size": 8.0 + (centrality.get(module_id, 0.0) * 42.0),
                "title": _build_node_title(
                    module=node,
                    findings=findings_by_module.get(module_id, []),
                    annotations=module_annotations,
                    status=status,
                    confidence_score=metrics.confidence_score,
                ),
            }
        )

    html_edges: list[dict[str, object]] = []
    for edge in edges:
        if edge.source_module_id not in graph or edge.target_module_id not in graph:
            continue
        html_edges.append(
            {
                "source": edge.source_module_id,
                "target": edge.target_module_id,
                "edge_type": edge.edge_type.value,
                "weight": edge.weight,
                "title": (
                    f"{edge.edge_type.value}: {edge.connection_summary or edge.import_line}"
                ),
            }
        )

    render_graph_html(nodes=html_nodes, edges=html_edges, output_path=html_path)

    return GeneratedArtifacts(
        markdown_path=str(markdown_path),
        json_path=str(json_path),
        html_path=str(html_path),
        module_count=len(html_nodes),
        edge_count=len(html_edges),
        annotation_count=len(annotations),
        confidence_score=metrics.confidence_score,
    )