# EA_strat_optimizer / backend/agents/orchestrator.py
# Commit author: TheQuantEd
# deploy: AMD EA Strategy Optimizer — Neo4j + FastAPI + Streamlit
# commit 6252f54
"""LangGraph StateGraph orchestrator: retrieve → optimize → generate → verify (→ regenerate loop)."""
import logging
import time
from typing import TypedDict, Annotated
import operator
import torch
from langgraph.graph import StateGraph, END
from backend.agents.retriever import RetrieverAgent, EnrichedCapability
from backend.agents.optimizer import OptimizerAgent, PrioritizationResult
from backend.agents.generator import GeneratorAgent
from backend.agents.verifier import VerifierAgent
from backend.config import Settings
from backend.graph.neo4j_client import Neo4jClient
from backend.graph.cypher_queries import STORE_GENERATED_OUTPUT
from backend.llm.client import LLMClient
from backend.schemas.request import AnalyzeRequest
from backend.schemas.response import (
AnalyzeResponse,
RoadmapPhase,
AMDMetrics,
DRLTrace,
ComplianceSummary,
)
log = logging.getLogger(name=__name__)

# Generation budget: once this many drafts have been generated, verify_node
# accepts the latest draft even if it is still below the compliance threshold.
MAX_ITERATIONS = 2

# Minimum verifier score (``summary.score``) at which a draft is accepted
# without another generate -> verify round.
COMPLIANCE_THRESHOLD = 70
class AgentState(TypedDict):
    """Shared state flowing through the graph built by ``build_graph``.

    ``run_pipeline`` seeds every key before invoking the graph; each node
    function returns a dict containing only the keys it updates.
    """

    # The incoming analysis request (read by the retrieve, optimize and generate nodes).
    request: AnalyzeRequest
    # Caller-supplied identifier for this pipeline run (passed in by run_pipeline).
    request_id: str
    # Retriever outputs
    # Capabilities returned by the tiered retrieval in retrieve_node.
    relevant_capabilities: list[EnrichedCapability]
    # Text summary built by _make_graph_context from relevant_capabilities
    # (not read by any node in this module).
    graph_context: str
    # Optimizer outputs
    # Result of OptimizerAgent.prioritize; None until optimize_node has run.
    priority_result: PrioritizationResult | None
    # Generator outputs
    # Most recent roadmap produced by generate_node.
    roadmap_draft: list[RoadmapPhase]
    # Written by verify_node and read back by generate_node as regeneration
    # feedback; reset to [] once a draft is accepted.
    compliance_issues: list[str]
    # Verifier outputs
    # Latest VerifierAgent.verify result (exposes .score and .issues); None before verify runs.
    compliance_summary: ComplianceSummary | None
    # Draft accepted by verify_node: set when score >= COMPLIANCE_THRESHOLD or
    # iteration_count >= MAX_ITERATIONS.
    final_roadmap: list[RoadmapPhase]
    # Control
    # Number of generate_node runs so far (incremented by one per generation).
    iteration_count: int
    # Initialised to [] by run_pipeline; no node in this module writes to it.
    # NOTE(review): `Annotated` and `operator` are imported but unused — if errors
    # were meant to accumulate across nodes this likely wants
    # Annotated[list[str], operator.add]; confirm intent.
    errors: list[str]
    # Timing: wall-clock seconds (time.time() deltas) spent in each node.
    # NOTE(review): no reducer is declared, so presumably a regenerate loop keeps
    # only the last generate/verify timing — confirm if totals should accumulate.
    t_retrieve: float
    t_optimize: float
    t_generate: float
    t_verify: float
def _make_graph_context(caps: list[EnrichedCapability]) -> str:
lines: list[str] = []
domains_seen: set[str] = set()
for c in caps[:10]:
cap = c.capability
std = c.standard or {}
trend = c.trend or {}
domain_name = c.domain.get("name", "")
domains_seen.add(domain_name)
lines.append(
f"Cap: {cap.get('name','')} | Domain: {domain_name} "
f"| Std: {std.get('name','')} | Trend: {trend.get('name','')}"
)
if len(domains_seen) > 1:
lines.insert(0, f"[CROSS-DOMAIN CONTEXT: {len(domains_seen)} domains — {', '.join(sorted(domains_seen))}]")
return "\n".join(lines)
def _detect_gpu(settings: Settings) -> tuple[str, str | None]:
"""
Return (gpu_device_name, rocm_version).
If VLLM_BASE_URL points to a non-localhost host, query the vLLM metrics
endpoint to confirm the remote AMD GPU is active. Falls back to local
torch detection, then plain CPU.
"""
import re
import urllib.request
base_url = settings.vllm_base_url # e.g. http://134.199.197.181:8000/v1
is_remote = not re.search(r"localhost|127\.0\.0\.1", base_url)
if is_remote:
# Derive metrics URL: http://host:port/metrics
metrics_base = re.sub(r"/v1/?$", "", base_url)
try:
with urllib.request.urlopen(f"{metrics_base}/metrics", timeout=3) as resp:
text = resp.read().decode()
# If vLLM is serving, report the AMD MI300X
if "vllm:" in text:
rocm = None
for line in text.splitlines():
if "rocm_version" in line.lower() or "hip_version" in line.lower():
rocm = line.split()[-1] if line.split() else None
break
return "AMD Instinct MI300X", rocm or "ROCm"
except Exception:
pass
# Remote URL configured but unreachable — still label it AMD
return "AMD Instinct MI300X (vLLM)", None
# Local torch detection
if torch.cuda.is_available():
return torch.cuda.get_device_name(0), getattr(torch.version, "hip", None)
return "CPU", None
def build_graph(
    neo4j: Neo4jClient,
    llm: LLMClient,
    settings: Settings,
    policy=None,
):
    """Assemble and compile the retrieve → optimize → generate → verify graph.

    The four agents are constructed once and captured by the node closures.
    After ``verify`` the graph either ends or loops back to ``generate`` with
    the verifier's issues; the loop is bounded by ``MAX_ITERATIONS`` generations.

    Args:
        neo4j: graph client shared by the retriever and verifier agents.
        llm: LLM client shared by the retriever, generator and verifier agents.
        settings: accepted for interface compatibility; not read here.
        policy: optional trained DRL policy handed to ``OptimizerAgent``.

    Returns:
        The compiled ``StateGraph(AgentState)``.
    """
    retriever = RetrieverAgent(neo4j=neo4j, llm=llm)
    optimizer = OptimizerAgent(policy=policy)
    generator = GeneratorAgent(llm=llm)
    verifier = VerifierAgent(neo4j=neo4j, llm=llm)

    async def _tiered_capabilities(req):
        # Each tier is tried only while nothing has been found; the semantic
        # search is the final fallback and its result is returned as-is.
        if req.selected_capability_ids:
            # Tier 1: exact capability IDs from questionnaire
            found = await retriever.retrieve_by_ids(
                capability_ids=req.selected_capability_ids,
                org_type=req.org_type,
                goals=req.goals,
            )
            if found:
                return found
        if req.sector_focus:
            # Tier 2: domain names from questionnaire (cross-domain safe)
            found = await retriever.retrieve_by_domain_names(
                domain_names=req.sector_focus,
                org_type=req.org_type,
                goals=req.goals,
            )
            if found:
                return found
        # Tier 3: semantic vector + cypher traversal fallback
        return await retriever.retrieve(
            org_type=req.org_type,
            goals=req.goals,
            sectors=req.sector_focus,
        )

    async def retrieve_node(state: AgentState) -> dict:
        started = time.time()
        found = await _tiered_capabilities(state["request"])
        return {
            "relevant_capabilities": found,
            "graph_context": _make_graph_context(found),
            "t_retrieve": time.time() - started,
        }

    def optimize_node(state: AgentState) -> dict:
        started = time.time()
        req = state["request"]
        ranked = optimizer.prioritize(
            caps=state["relevant_capabilities"],
            budget_tier=req.budget_tier,
            timeline_months=req.timeline_months,
            risk_tolerance=req.risk_tolerance,
        )
        return {"priority_result": ranked, "t_optimize": time.time() - started}

    async def generate_node(state: AgentState) -> dict:
        started = time.time()
        ranked = state.get("priority_result")
        if ranked:
            ordered = ranked.ordered_capabilities
        else:
            ordered = state["relevant_capabilities"]
        # An empty issue list is passed as None so the generator treats it as a first pass.
        feedback = state.get("compliance_issues") or None
        draft = await generator.generate(
            caps=ordered,
            request=state["request"],
            compliance_issues=feedback,
        )
        return {
            "roadmap_draft": draft,
            "t_generate": time.time() - started,
            "iteration_count": state.get("iteration_count", 0) + 1,
        }

    async def verify_node(state: AgentState) -> dict:
        started = time.time()
        draft = state["roadmap_draft"]
        summary = await verifier.verify(draft)
        out: dict = {
            "compliance_summary": summary,
            "t_verify": time.time() - started,
        }
        accepted = summary.score >= COMPLIANCE_THRESHOLD
        exhausted = state.get("iteration_count", 0) >= MAX_ITERATIONS
        if not (accepted or exhausted):
            # Feed the issues back so the next generate call can address them.
            out["compliance_issues"] = summary.issues
            return out
        out["final_roadmap"] = draft
        out["compliance_issues"] = []
        return out

    def should_regenerate(state: AgentState) -> str:
        iterations = state.get("iteration_count", 0)
        if state.get("final_roadmap") or iterations >= MAX_ITERATIONS:
            return "end"
        summary = state.get("compliance_summary")
        retry = (
            summary
            and summary.score < COMPLIANCE_THRESHOLD
            and state.get("compliance_issues")
        )
        if not retry:
            return "end"
        log.info(
            f"Compliance score {summary.score} < {COMPLIANCE_THRESHOLD}; "
            f"regenerating (iteration {state['iteration_count']})"
        )
        return "regenerate"

    workflow = StateGraph(AgentState)
    for node_name, node_fn in (
        ("retrieve", retrieve_node),
        ("optimize", optimize_node),
        ("generate", generate_node),
        ("verify", verify_node),
    ):
        workflow.add_node(node_name, node_fn)
    workflow.set_entry_point("retrieve")
    for source, target in (
        ("retrieve", "optimize"),
        ("optimize", "generate"),
        ("generate", "verify"),
    ):
        workflow.add_edge(source, target)
    workflow.add_conditional_edges(
        "verify",
        should_regenerate,
        {"regenerate": "generate", "end": END},
    )
    return workflow.compile()
async def run_pipeline(
request: AnalyzeRequest,
neo4j: Neo4jClient,
llm: LLMClient,
settings: Settings,
request_id: str,
) -> AnalyzeResponse:
import hashlib
import json as _json
# Load DRL policy if checkpoint exists
policy = None
try:
import os
from backend.drl.trainer import load_trained_policy
ckpt = settings.drl_checkpoint_path
if os.path.exists(ckpt):
policy = load_trained_policy(ckpt)
except Exception as exc:
log.warning(f"DRL checkpoint not loaded: {exc}")
# --- Cache check ---
cap_ids_for_cache = (
request.selected_capability_ids
if request.selected_capability_ids
else []
)
org_keyword = request.org_type.split()[0].lower() if request.org_type else ""
cache_key = None
if cap_ids_for_cache:
cache_key = hashlib.md5(
("|".join(sorted(cap_ids_for_cache)) + "|" + request.org_type.lower()).encode()
).hexdigest()[:16]
# Try exact cache hit
cached = neo4j.run_query(
"MATCH (o:GeneratedOutput {cache_key: $cache_key}) "
"SET o.hit_count = coalesce(o.hit_count,0)+1, o.last_accessed=datetime() "
"RETURN o.output_json AS output_json",
cache_key=cache_key,
)
if cached and cached[0].get("output_json"):
log.info(f"[{request_id}] Cache HIT for key {cache_key}")
try:
return AnalyzeResponse.model_validate_json(cached[0]["output_json"])
except Exception as exc:
log.warning(f"Cache deserialize failed: {exc}")
app_graph = build_graph(neo4j, llm, settings, policy=policy)
initial_state: AgentState = {
"request": request,
"request_id": request_id,
"relevant_capabilities": [],
"graph_context": "",
"priority_result": None,
"roadmap_draft": [],
"compliance_issues": [],
"compliance_summary": None,
"final_roadmap": [],
"iteration_count": 0,
"errors": [],
"t_retrieve": 0.0,
"t_optimize": 0.0,
"t_generate": 0.0,
"t_verify": 0.0,
}
final_state = await app_graph.ainvoke(initial_state)
phases = final_state.get("final_roadmap") or final_state.get("roadmap_draft") or []
compliance = final_state.get("compliance_summary")
pr = final_state.get("priority_result")
# --- Cache store ---
if cache_key:
try:
resp_for_cache = AnalyzeResponse(
request_id=request_id,
org_type=request.org_type,
phases=phases,
compliance_summary=compliance,
amd_metrics=AMDMetrics(),
drl_trace=None,
)
output_json = resp_for_cache.model_dump_json()
epics_count = sum(len(p.epics) for p in phases)
neo4j.run_query(
STORE_GENERATED_OUTPUT,
cache_key=cache_key,
org_type=request.org_type,
output_json=output_json,
capability_ids=cap_ids_for_cache,
phases_count=len(phases),
epics_count=epics_count,
)
log.info(f"[{request_id}] Cached output under key {cache_key}")
except Exception as exc:
log.warning(f"Cache store failed: {exc}")
# Build AMD metrics — prefer remote vLLM GPU info over local torch detection
gpu_name, rocm_version = _detect_gpu(settings)
amd_metrics = AMDMetrics(
gpu_device=gpu_name,
rocm_version=rocm_version,
processing_time_seconds=round(
final_state.get("t_retrieve", 0)
+ final_state.get("t_optimize", 0)
+ final_state.get("t_generate", 0)
+ final_state.get("t_verify", 0),
2,
),
capabilities_retrieved=len(final_state.get("relevant_capabilities") or []),
iterations=final_state.get("iteration_count", 1),
)
drl_trace = None
if pr:
from backend.schemas.response import CapabilityScore
drl_trace = DRLTrace(
drl_used=pr.drl_used,
state_vector=pr.state_vector,
capability_scores=[
CapabilityScore(
capability_id=c.capability.get("id", ""),
capability_name=c.capability.get("name", ""),
score=pr.priority_scores[i] if i < len(pr.priority_scores) else 0.0,
)
for i, c in enumerate(pr.ordered_capabilities[:10])
],
)
return AnalyzeResponse(
request_id=request_id,
org_type=request.org_type,
phases=phases,
compliance_summary=compliance,
amd_metrics=amd_metrics,
drl_trace=drl_trace,
)