Spaces:
Sleeping
Sleeping
| """LangGraph StateGraph orchestrator: retrieve → optimize → generate → verify (→ regenerate loop).""" | |
| import logging | |
| import time | |
| from typing import TypedDict, Annotated | |
| import operator | |
| import torch | |
| from langgraph.graph import StateGraph, END | |
| from backend.agents.retriever import RetrieverAgent, EnrichedCapability | |
| from backend.agents.optimizer import OptimizerAgent, PrioritizationResult | |
| from backend.agents.generator import GeneratorAgent | |
| from backend.agents.verifier import VerifierAgent | |
| from backend.config import Settings | |
| from backend.graph.neo4j_client import Neo4jClient | |
| from backend.graph.cypher_queries import STORE_GENERATED_OUTPUT | |
| from backend.llm.client import LLMClient | |
| from backend.schemas.request import AnalyzeRequest | |
| from backend.schemas.response import ( | |
| AnalyzeResponse, | |
| RoadmapPhase, | |
| AMDMetrics, | |
| DRLTrace, | |
| ComplianceSummary, | |
| ) | |
| log = logging.getLogger(__name__) | |
| MAX_ITERATIONS = 2 | |
| COMPLIANCE_THRESHOLD = 70 | |
| class AgentState(TypedDict): | |
| request: AnalyzeRequest | |
| request_id: str | |
| # Retriever outputs | |
| relevant_capabilities: list[EnrichedCapability] | |
| graph_context: str | |
| # Optimizer outputs | |
| priority_result: PrioritizationResult | None | |
| # Generator outputs | |
| roadmap_draft: list[RoadmapPhase] | |
| compliance_issues: list[str] | |
| # Verifier outputs | |
| compliance_summary: ComplianceSummary | None | |
| final_roadmap: list[RoadmapPhase] | |
| # Control | |
| iteration_count: int | |
| errors: list[str] | |
| # Timing | |
| t_retrieve: float | |
| t_optimize: float | |
| t_generate: float | |
| t_verify: float | |
| def _make_graph_context(caps: list[EnrichedCapability]) -> str: | |
| lines: list[str] = [] | |
| domains_seen: set[str] = set() | |
| for c in caps[:10]: | |
| cap = c.capability | |
| std = c.standard or {} | |
| trend = c.trend or {} | |
| domain_name = c.domain.get("name", "") | |
| domains_seen.add(domain_name) | |
| lines.append( | |
| f"Cap: {cap.get('name','')} | Domain: {domain_name} " | |
| f"| Std: {std.get('name','')} | Trend: {trend.get('name','')}" | |
| ) | |
| if len(domains_seen) > 1: | |
| lines.insert(0, f"[CROSS-DOMAIN CONTEXT: {len(domains_seen)} domains — {', '.join(sorted(domains_seen))}]") | |
| return "\n".join(lines) | |
| def _detect_gpu(settings: Settings) -> tuple[str, str | None]: | |
| """ | |
| Return (gpu_device_name, rocm_version). | |
| If VLLM_BASE_URL points to a non-localhost host, query the vLLM metrics | |
| endpoint to confirm the remote AMD GPU is active. Falls back to local | |
| torch detection, then plain CPU. | |
| """ | |
| import re | |
| import urllib.request | |
| base_url = settings.vllm_base_url # e.g. http://134.199.197.181:8000/v1 | |
| is_remote = not re.search(r"localhost|127\.0\.0\.1", base_url) | |
| if is_remote: | |
| # Derive metrics URL: http://host:port/metrics | |
| metrics_base = re.sub(r"/v1/?$", "", base_url) | |
| try: | |
| with urllib.request.urlopen(f"{metrics_base}/metrics", timeout=3) as resp: | |
| text = resp.read().decode() | |
| # If vLLM is serving, report the AMD MI300X | |
| if "vllm:" in text: | |
| rocm = None | |
| for line in text.splitlines(): | |
| if "rocm_version" in line.lower() or "hip_version" in line.lower(): | |
| rocm = line.split()[-1] if line.split() else None | |
| break | |
| return "AMD Instinct MI300X", rocm or "ROCm" | |
| except Exception: | |
| pass | |
| # Remote URL configured but unreachable — still label it AMD | |
| return "AMD Instinct MI300X (vLLM)", None | |
| # Local torch detection | |
| if torch.cuda.is_available(): | |
| return torch.cuda.get_device_name(0), getattr(torch.version, "hip", None) | |
| return "CPU", None | |
| def build_graph( | |
| neo4j: Neo4jClient, | |
| llm: LLMClient, | |
| settings: Settings, | |
| policy=None, | |
| ): | |
| retriever = RetrieverAgent(neo4j=neo4j, llm=llm) | |
| optimizer = OptimizerAgent(policy=policy) | |
| generator = GeneratorAgent(llm=llm) | |
| verifier = VerifierAgent(neo4j=neo4j, llm=llm) | |
| # ---- node functions ---- | |
| async def retrieve_node(state: AgentState) -> dict: | |
| t0 = time.time() | |
| req = state["request"] | |
| caps: list = [] | |
| # Tier 1: exact capability IDs from questionnaire | |
| if req.selected_capability_ids: | |
| caps = await retriever.retrieve_by_ids( | |
| capability_ids=req.selected_capability_ids, | |
| org_type=req.org_type, | |
| goals=req.goals, | |
| ) | |
| # Tier 2: domain names from questionnaire (cross-domain safe) | |
| if not caps and req.sector_focus: | |
| caps = await retriever.retrieve_by_domain_names( | |
| domain_names=req.sector_focus, | |
| org_type=req.org_type, | |
| goals=req.goals, | |
| ) | |
| # Tier 3: semantic vector + cypher traversal fallback | |
| if not caps: | |
| caps = await retriever.retrieve( | |
| org_type=req.org_type, | |
| goals=req.goals, | |
| sectors=req.sector_focus, | |
| ) | |
| return { | |
| "relevant_capabilities": caps, | |
| "graph_context": _make_graph_context(caps), | |
| "t_retrieve": time.time() - t0, | |
| } | |
| def optimize_node(state: AgentState) -> dict: | |
| t0 = time.time() | |
| req = state["request"] | |
| result = optimizer.prioritize( | |
| caps=state["relevant_capabilities"], | |
| budget_tier=req.budget_tier, | |
| timeline_months=req.timeline_months, | |
| risk_tolerance=req.risk_tolerance, | |
| ) | |
| return {"priority_result": result, "t_optimize": time.time() - t0} | |
| async def generate_node(state: AgentState) -> dict: | |
| t0 = time.time() | |
| pr = state.get("priority_result") | |
| caps = pr.ordered_capabilities if pr else state["relevant_capabilities"] | |
| phases = await generator.generate( | |
| caps=caps, | |
| request=state["request"], | |
| compliance_issues=state.get("compliance_issues") or None, | |
| ) | |
| return { | |
| "roadmap_draft": phases, | |
| "t_generate": time.time() - t0, | |
| "iteration_count": state.get("iteration_count", 0) + 1, | |
| } | |
| async def verify_node(state: AgentState) -> dict: | |
| t0 = time.time() | |
| summary = await verifier.verify(state["roadmap_draft"]) | |
| updates: dict = { | |
| "compliance_summary": summary, | |
| "t_verify": time.time() - t0, | |
| } | |
| if summary.score >= COMPLIANCE_THRESHOLD or state.get("iteration_count", 0) >= MAX_ITERATIONS: | |
| updates["final_roadmap"] = state["roadmap_draft"] | |
| updates["compliance_issues"] = [] | |
| else: | |
| updates["compliance_issues"] = summary.issues | |
| return updates | |
| def should_regenerate(state: AgentState) -> str: | |
| if ( | |
| state.get("final_roadmap") | |
| or state.get("iteration_count", 0) >= MAX_ITERATIONS | |
| ): | |
| return "end" | |
| summary = state.get("compliance_summary") | |
| if summary and summary.score < COMPLIANCE_THRESHOLD and state.get("compliance_issues"): | |
| log.info( | |
| f"Compliance score {summary.score} < {COMPLIANCE_THRESHOLD}; " | |
| f"regenerating (iteration {state['iteration_count']})" | |
| ) | |
| return "regenerate" | |
| return "end" | |
| # ---- build graph ---- | |
| graph = StateGraph(AgentState) | |
| graph.add_node("retrieve", retrieve_node) | |
| graph.add_node("optimize", optimize_node) | |
| graph.add_node("generate", generate_node) | |
| graph.add_node("verify", verify_node) | |
| graph.set_entry_point("retrieve") | |
| graph.add_edge("retrieve", "optimize") | |
| graph.add_edge("optimize", "generate") | |
| graph.add_edge("generate", "verify") | |
| graph.add_conditional_edges( | |
| "verify", | |
| should_regenerate, | |
| {"regenerate": "generate", "end": END}, | |
| ) | |
| return graph.compile() | |
| async def run_pipeline( | |
| request: AnalyzeRequest, | |
| neo4j: Neo4jClient, | |
| llm: LLMClient, | |
| settings: Settings, | |
| request_id: str, | |
| ) -> AnalyzeResponse: | |
| import hashlib | |
| import json as _json | |
| # Load DRL policy if checkpoint exists | |
| policy = None | |
| try: | |
| import os | |
| from backend.drl.trainer import load_trained_policy | |
| ckpt = settings.drl_checkpoint_path | |
| if os.path.exists(ckpt): | |
| policy = load_trained_policy(ckpt) | |
| except Exception as exc: | |
| log.warning(f"DRL checkpoint not loaded: {exc}") | |
| # --- Cache check --- | |
| cap_ids_for_cache = ( | |
| request.selected_capability_ids | |
| if request.selected_capability_ids | |
| else [] | |
| ) | |
| org_keyword = request.org_type.split()[0].lower() if request.org_type else "" | |
| cache_key = None | |
| if cap_ids_for_cache: | |
| cache_key = hashlib.md5( | |
| ("|".join(sorted(cap_ids_for_cache)) + "|" + request.org_type.lower()).encode() | |
| ).hexdigest()[:16] | |
| # Try exact cache hit | |
| cached = neo4j.run_query( | |
| "MATCH (o:GeneratedOutput {cache_key: $cache_key}) " | |
| "SET o.hit_count = coalesce(o.hit_count,0)+1, o.last_accessed=datetime() " | |
| "RETURN o.output_json AS output_json", | |
| cache_key=cache_key, | |
| ) | |
| if cached and cached[0].get("output_json"): | |
| log.info(f"[{request_id}] Cache HIT for key {cache_key}") | |
| try: | |
| return AnalyzeResponse.model_validate_json(cached[0]["output_json"]) | |
| except Exception as exc: | |
| log.warning(f"Cache deserialize failed: {exc}") | |
| app_graph = build_graph(neo4j, llm, settings, policy=policy) | |
| initial_state: AgentState = { | |
| "request": request, | |
| "request_id": request_id, | |
| "relevant_capabilities": [], | |
| "graph_context": "", | |
| "priority_result": None, | |
| "roadmap_draft": [], | |
| "compliance_issues": [], | |
| "compliance_summary": None, | |
| "final_roadmap": [], | |
| "iteration_count": 0, | |
| "errors": [], | |
| "t_retrieve": 0.0, | |
| "t_optimize": 0.0, | |
| "t_generate": 0.0, | |
| "t_verify": 0.0, | |
| } | |
| final_state = await app_graph.ainvoke(initial_state) | |
| phases = final_state.get("final_roadmap") or final_state.get("roadmap_draft") or [] | |
| compliance = final_state.get("compliance_summary") | |
| pr = final_state.get("priority_result") | |
| # --- Cache store --- | |
| if cache_key: | |
| try: | |
| resp_for_cache = AnalyzeResponse( | |
| request_id=request_id, | |
| org_type=request.org_type, | |
| phases=phases, | |
| compliance_summary=compliance, | |
| amd_metrics=AMDMetrics(), | |
| drl_trace=None, | |
| ) | |
| output_json = resp_for_cache.model_dump_json() | |
| epics_count = sum(len(p.epics) for p in phases) | |
| neo4j.run_query( | |
| STORE_GENERATED_OUTPUT, | |
| cache_key=cache_key, | |
| org_type=request.org_type, | |
| output_json=output_json, | |
| capability_ids=cap_ids_for_cache, | |
| phases_count=len(phases), | |
| epics_count=epics_count, | |
| ) | |
| log.info(f"[{request_id}] Cached output under key {cache_key}") | |
| except Exception as exc: | |
| log.warning(f"Cache store failed: {exc}") | |
| # Build AMD metrics — prefer remote vLLM GPU info over local torch detection | |
| gpu_name, rocm_version = _detect_gpu(settings) | |
| amd_metrics = AMDMetrics( | |
| gpu_device=gpu_name, | |
| rocm_version=rocm_version, | |
| processing_time_seconds=round( | |
| final_state.get("t_retrieve", 0) | |
| + final_state.get("t_optimize", 0) | |
| + final_state.get("t_generate", 0) | |
| + final_state.get("t_verify", 0), | |
| 2, | |
| ), | |
| capabilities_retrieved=len(final_state.get("relevant_capabilities") or []), | |
| iterations=final_state.get("iteration_count", 1), | |
| ) | |
| drl_trace = None | |
| if pr: | |
| from backend.schemas.response import CapabilityScore | |
| drl_trace = DRLTrace( | |
| drl_used=pr.drl_used, | |
| state_vector=pr.state_vector, | |
| capability_scores=[ | |
| CapabilityScore( | |
| capability_id=c.capability.get("id", ""), | |
| capability_name=c.capability.get("name", ""), | |
| score=pr.priority_scores[i] if i < len(pr.priority_scores) else 0.0, | |
| ) | |
| for i, c in enumerate(pr.ordered_capabilities[:10]) | |
| ], | |
| ) | |
| return AnalyzeResponse( | |
| request_id=request_id, | |
| org_type=request.org_type, | |
| phases=phases, | |
| compliance_summary=compliance, | |
| amd_metrics=amd_metrics, | |
| drl_trace=drl_trace, | |
| ) | |