# NOTE: scraped page chrome removed (Hugging Face upload, commit e1ced8e).
"""code_lookup node — lightweight snippet reviewer (no multi-turn tool loop).
Flow:
1. discover_code_locations() — ChromaDB semantic search (~1-2 sec per query)
2. GPT reviews the raw snippets in a SINGLE call — flags relevant ones with
a relevance tag and brief note
3. Raw flagged snippets + GPT notes go to the compliance analyst
No fetch_full_chapter in the initial pass. The compliance analyst can request
targeted chapter fetches via additional_code_queries if it needs more context.
"""
from __future__ import annotations
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from openai import OpenAI
from config import CODE_LOOKUP_MODEL, OPENAI_API_KEY
from prompts.code_lookup import CODE_REVIEWER_SYSTEM_PROMPT
from state import AgentMessage, CodeQuery, CodeSection, ComplianceState
from tools.chroma_tools import QueryCache, discover_code_locations
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Single-call snippet reviewer (replaces the multi-turn tool loop)
# ---------------------------------------------------------------------------
def _review_snippets(
    research_goal: str,
    discover_report: str,
) -> tuple[str, list[CodeSection]]:
    """Review discover results with GPT in a single call (no tool loop).

    Args:
        research_goal: The planned query plus its context, used to steer review.
        discover_report: Raw snippet report produced by discover_code_locations().

    Returns:
        (brief_review, flagged_sections) — the model's summary text and the
        snippets it flagged as relevant, truncated to 1500 chars each.
        On malformed model output, returns the raw response text and no sections.
    """
    client = OpenAI(api_key=OPENAI_API_KEY)
    response = client.chat.completions.create(
        model=CODE_LOOKUP_MODEL,
        messages=[
            {"role": "system", "content": CODE_REVIEWER_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    f"## Research Goal\n{research_goal}\n\n"
                    f"## Code Snippets from Database\n{discover_report}"
                ),
            },
        ],
        response_format={"type": "json_object"},
    )
    # content can be None on some finish reasons; fall back to an empty object.
    raw = response.choices[0].message.content or "{}"
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("GPT snippet review returned invalid JSON, using raw text.")
        return raw, []
    # json_object mode does not guarantee a top-level object — a bare array or
    # scalar would make parsed.get() raise AttributeError. Degrade the same way
    # as the invalid-JSON path.
    if not isinstance(parsed, dict):
        logger.warning("GPT snippet review returned non-object JSON, using raw text.")
        return raw, []
    flagged_sections: list[CodeSection] = []
    for item in parsed.get("relevant_sections", []):
        if not isinstance(item, dict):
            # Skip malformed entries rather than crashing the whole review.
            logger.warning("Skipping non-object relevant_sections entry: %r", item)
            continue
        flagged_sections.append(
            CodeSection(
                section_full=item.get("section_id", "?"),
                code_type=item.get("code_type", "?"),
                parent_major=item.get("chapter", "?"),
                # Cap snippet length so downstream prompts stay bounded.
                text=item.get("snippet", "")[:1500],
                relevance=item.get("relevance_note", ""),
            )
        )
    summary = parsed.get("summary", "No summary provided.")
    return summary, flagged_sections
def _run_single_lookup(
    cq: CodeQuery,
    query_cache: QueryCache | None = None,
) -> tuple[str, list[CodeSection], str]:
    """Execute discover + single-call GPT review for one planned code query.

    Returns:
        (summary, flagged_sections, discover_report_raw)
    """
    # Fold the planner's context into the goal so both ChromaDB and the
    # reviewer see the same framing.
    goal = f"{cq['query']} (Context: {cq['context']})"
    logger.info("ChromaDB query: %s", goal)

    # Step 1: semantic search against ChromaDB (fast, ~1-2s).
    raw_report = discover_code_locations(goal, cache=query_cache)

    # Step 2: one-shot GPT review of the returned snippets.
    review_summary, relevant_sections = _review_snippets(goal, raw_report)
    return review_summary, relevant_sections, raw_report
# ---------------------------------------------------------------------------
# LangGraph node functions
# ---------------------------------------------------------------------------
def initial_code_lookup(state: ComplianceState) -> dict:
    """Run the initial code lookup over the planner's code_queries.

    All queries fan out in parallel via ThreadPoolExecutor; each one is a
    single discover + single GPT review call (no multi-turn loop). Results
    are stitched back together in the planner's original order.
    """
    queries = state.get("code_queries", [])
    if not queries:
        return {
            "code_report": "No code queries were planned.",
            "code_sections": [],
            "discussion_log": [
                AgentMessage(
                    timestamp=datetime.now().strftime("%H:%M:%S"),
                    agent="code_analyst",
                    action="search_code",
                    summary="No code queries to execute.",
                    detail="The planner did not generate any code queries.",
                    evidence_refs=[],
                )
            ],
            "status_message": ["No code queries to execute."],
        }

    def _stamp() -> str:
        # Wall-clock timestamp used on every discussion-log entry.
        return datetime.now().strftime("%H:%M:%S")

    cache = QueryCache()
    # Announce every search upfront, in planner order.
    log: list[AgentMessage] = [
        AgentMessage(
            timestamp=_stamp(),
            agent="code_analyst",
            action="search_code",
            summary=f"Searching: {q['query'][:80]}...",
            detail=f"Focus area: {q['focus_area']}\nContext: {q['context']}",
            evidence_refs=[],
        )
        for q in queries
    ]

    # Fan out: every query runs concurrently; completion messages arrive in
    # whatever order the futures finish.
    outcomes: dict[int, tuple[str, list[CodeSection], str]] = {}
    with ThreadPoolExecutor(max_workers=min(len(queries), 4)) as executor:
        pending = {
            executor.submit(_run_single_lookup, q, cache): idx
            for idx, q in enumerate(queries)
        }
        for finished in as_completed(pending):
            idx = pending[finished]
            try:
                summary, flagged, raw = finished.result()
            except Exception as exc:
                # A single failed query must not sink the whole node.
                logger.error("Code query %d failed: %s", idx, exc)
                outcomes[idx] = (f"Error: {exc}", [], "")
                log.append(
                    AgentMessage(
                        timestamp=_stamp(),
                        agent="code_analyst",
                        action="search_code",
                        summary=f"Query {idx + 1} failed: {exc}",
                        detail=str(exc),
                        evidence_refs=[],
                    )
                )
                continue
            outcomes[idx] = (summary, flagged, raw)
            q = queries[idx]
            section_ids = [s["section_full"] for s in flagged]
            log.append(
                AgentMessage(
                    timestamp=_stamp(),
                    agent="code_analyst",
                    action="search_code",
                    summary=f"Flagged {len(flagged)} sections for '{q['query']}'",
                    detail=(
                        f"**Query:** {q['query']}\n"
                        f"**Focus:** {q['focus_area']}\n\n"
                        f"**Sections:** {', '.join(section_ids)}\n\n"
                        f"{summary[:800]}"
                    ),
                    evidence_refs=section_ids,
                )
            )

    # Stitch the report back together in the planner's original order.
    parts: list[str] = []
    sections: list[CodeSection] = []
    for idx, q in enumerate(queries):
        summary, flagged, _ = outcomes.get(idx, ("No result", [], ""))
        parts.append(f"### Query {idx + 1}: {q['focus_area']}\n{summary}")
        sections.extend(flagged)

    return {
        "code_report": "\n\n---\n\n".join(parts),
        "code_sections": sections,
        "discussion_log": log,
        "status_message": [
            f"Code lookup complete. {len(sections)} relevant sections "
            f"flagged across {len(queries)} queries."
        ],
    }
def targeted_code_lookup(state: ComplianceState) -> dict:
    """Run additional code lookups requested by the compliance analyst.

    These may use fetch_full_chapter for deeper context when the analyst
    needs full exception text or cross-reference detail.

    Returns a partial state update: the follow-up findings are appended to
    the existing code_report, additional_code_queries is cleared, and only
    the NEW sections/messages are returned (code_sections and discussion_log
    are presumably accumulating channels — verify against the graph wiring).
    """
    additional_queries = state.get("additional_code_queries", [])
    if not additional_queries:
        return {
            "status_message": ["No additional code queries."],
        }
    query_cache = QueryCache()
    all_sections: list[CodeSection] = []
    report_parts: list[str] = []
    discussion_messages: list[AgentMessage] = []
    for i, cq in enumerate(additional_queries):
        # Announce the follow-up search before running it.
        discussion_messages.append(
            AgentMessage(
                timestamp=datetime.now().strftime("%H:%M:%S"),
                agent="code_analyst",
                action="search_code",
                summary=f"Follow-up search: {cq['query'][:80]}...",
                detail=f"Requested by compliance analyst.\nFocus: {cq['focus_area']}",
                evidence_refs=[],
            )
        )
        try:
            summary, flagged, _raw = _run_single_lookup(cq, query_cache)
        except Exception as e:
            # Mirror initial_code_lookup: one failed query must not crash
            # the whole node. Record the error and move on.
            logger.error("Follow-up code query %d failed: %s", i, e)
            report_parts.append(
                f"### Follow-up Query {i + 1}: {cq['focus_area']}\nError: {e}"
            )
            discussion_messages.append(
                AgentMessage(
                    timestamp=datetime.now().strftime("%H:%M:%S"),
                    agent="code_analyst",
                    action="search_code",
                    summary=f"Follow-up query {i + 1} failed: {e}",
                    detail=str(e),
                    evidence_refs=[],
                )
            )
            continue
        # Header format matches initial_code_lookup's report sections so the
        # combined report reads uniformly.
        report_parts.append(
            f"### Follow-up Query {i + 1}: {cq['focus_area']}\n{summary}"
        )
        all_sections.extend(flagged)
        section_ids = ", ".join(s["section_full"] for s in flagged)
        discussion_messages.append(
            AgentMessage(
                timestamp=datetime.now().strftime("%H:%M:%S"),
                agent="code_analyst",
                action="search_code",
                summary=(
                    f"Follow-up: flagged {len(flagged)} sections "
                    f"for '{cq['query']}'"
                ),
                detail=(
                    f"**Query:** {cq['query']}\n"
                    f"**Focus:** {cq['focus_area']}\n\n"
                    f"**Sections:** {section_ids}\n\n"
                    f"{summary[:800]}"
                ),
                evidence_refs=[s["section_full"] for s in flagged],
            )
        )
    # Append follow-up findings to the existing report rather than replacing it.
    existing_report = state.get("code_report", "")
    new_report = "\n\n---\n\n".join(report_parts)
    combined_report = f"{existing_report}\n\n## FOLLOW-UP CODE RESEARCH\n\n{new_report}"
    return {
        "code_report": combined_report,
        "code_sections": all_sections,
        "additional_code_queries": [],  # Clear after processing
        "discussion_log": discussion_messages,
        "status_message": [
            f"Targeted code lookup complete. {len(all_sections)} additional sections "
            f"from {len(additional_queries)} follow-up queries."
        ],
    }