RAG / phase3_agent.py
sumitnewold's picture
Upload 10 files
76bd1fc verified
Raw
History Blame Contribute Delete
22.9 kB
"""
Phase 3 — LangGraph Multi-Agent Orchestration

Graph flow:
    orchestrator
        │  (route_by_intent)
        ├──► rag_agent
        ├──► compliance_agent
        ├──► alert_agent
        └──► finance_calculator
                │
          tool_executor
                │
        self_rag_evaluator
                │  (route_self_rag)
                ├──► rag_agent  (if faithfulness < 0.7 and retries < 2)
                └──► END
"""
import datetime
import json
import os
from typing import Optional, TypedDict
from dotenv import load_dotenv
from langgraph.graph import END, StateGraph
from langchain_core.documents import Document
from phase1_ingestion import (
CHROMA_DIR,
PDF_PATH,
TABLE_STORE_PATH,
build_embeddings,
build_llm,
load_pdf,
tag_financial_entities,
)
from phase2_retrieval import (
advanced_retrieval_pipeline,
build_bm25_retriever,
build_cross_encoder,
build_dense_retriever,
load_or_build_vectorstore,
)
load_dotenv()
# ── AgentState ────────────────────────────────────────────────────────────────
class AgentState(TypedDict):
    """Shared state dictionary handed from node to node in the agent graph."""

    query: str  # the user's question, set once by make_initial_state
    intent: str  # category label written by the orchestrator node
    reasoning: str  # orchestrator's stated reason for the chosen intent
    retrieved_docs: list # list of {"page_content": str, "metadata": dict}
    final_answer: str  # LLM answer written by the specialist agent node
    citations: list # list of page numbers
    tool_name: Optional[str]  # key into TOOL_REGISTRY, or None when no tool applies
    tool_input: Optional[dict]  # argument dict passed to the selected tool
    tool_output: Optional[dict]  # dict written by the tool executor node
    faithfulness_score: float  # Self-RAG grounding score, clamped to [0, 1]
    should_rerun: bool  # True when the evaluator asks for another rag_agent pass
    iteration_count: int # counts Self-RAG retry loops (max 2)
def make_initial_state(query: str) -> AgentState:
    """Return a fresh graph state for *query* with every other field blank.

    Text fields start empty, collections start as empty lists, the tool slots
    start as ``None``, and the Self-RAG bookkeeping starts at zero / ``False``.
    """
    return dict(
        query=query,
        intent="",
        reasoning="",
        retrieved_docs=[],
        final_answer="",
        citations=[],
        tool_name=None,
        tool_input=None,
        tool_output=None,
        faithfulness_score=0.0,
        should_rerun=False,
        iteration_count=0,
    )
# ── Helpers ───────────────────────────────────────────────────────────────────
def docs_to_dicts(docs: list[Document]) -> list[dict]:
    """Convert documents into plain ``{"page_content", "metadata"}`` dicts."""
    serialised = []
    for doc in docs:
        serialised.append(
            {"page_content": doc.page_content, "metadata": doc.metadata}
        )
    return serialised
def context_from_dicts(doc_dicts: list[dict]) -> str:
    """Render document dicts as one context string for an LLM prompt.

    Each document becomes a ``[Page N]`` header line followed by its text;
    documents are separated by a blank line. A document whose metadata has no
    ``page`` key is labelled ``[Page ?]``.
    """
    sections = []
    for entry in doc_dicts:
        page_label = entry["metadata"].get("page", "?")
        sections.append(f"[Page {page_label}]\n{entry['page_content']}")
    return "\n\n".join(sections)
def parse_json_from_response(text: str) -> dict:
    """Robustly extract the first JSON object from an LLM response.

    The response may wrap the JSON in prose, contain stray braces, or contain
    more than one object. Every ``{`` is tried in order as a possible start and
    the first one that decodes as a complete JSON object wins; anything after
    that object is ignored.

    Args:
        text: Raw text returned by the LLM.

    Returns:
        The first decodable JSON object in *text*.

    Raises:
        ValueError: If *text* contains no ``{`` or no candidate decodes.
    """
    decoder = json.JSONDecoder()
    first_error = None
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses exactly one value starting at `start` and
            # tolerates trailing text, unlike json.loads on a first-{/last-} slice.
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as exc:
            if first_error is None:
                first_error = exc
        else:
            return parsed
        start = text.find("{", start + 1)

    if first_error is not None:
        raise ValueError(
            f"No valid JSON object found in response: {first_error}"
        ) from first_error
    raise ValueError("No JSON object found in response")
# ── Finance Tools (Phase 4 will enhance generate_investment_alert with FinBERT)
def extract_financial_metrics(tool_input: dict, llm) -> dict:
    """Ask the LLM to extract financial metrics for a company from the context.

    Reads ``company`` (default ``"Infosys"``) and ``context`` (default ``""``)
    from *tool_input*; only the first 3000 characters of the context are sent.

    Returns the parsed JSON reply. If the LLM call or the parsing fails, returns
    a dict with an empty ``metrics`` list and the error text under ``error``.
    """
    firm = tool_input.get("company", "Infosys")
    excerpt = tool_input.get("context", "")[:3000]
    prompt = f"""Extract financial metrics from the context for {firm}.
Return ONLY valid JSON:
{{
"company": "{firm}",
"metrics": [
{{"name": "...", "value": "...", "unit": "...", "page": 0, "yoy_change": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(prompt)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best-effort tool: report the failure in-band instead of raising.
        return {"company": firm, "metrics": [], "error": str(exc)}
def generate_risk_summary(tool_input: dict, llm) -> dict:
    """Ask the LLM to summarise risk factors for a company from the context.

    Reads ``company`` (default ``"Infosys"``) and ``context`` (default ``""``)
    from *tool_input*; only the first 3000 characters of the context are sent.

    Returns the parsed JSON reply. If the LLM call or the parsing fails, returns
    a dict with an empty ``risks`` list and the error text under ``error``.
    """
    firm = tool_input.get("company", "Infosys")
    excerpt = tool_input.get("context", "")[:3000]
    prompt = f"""Summarize risk factors from the context for {firm}.
Return ONLY valid JSON:
{{
"company": "{firm}",
"risks": [
{{"category": "market|credit|regulatory|operational", "description": "...", "severity": "low|medium|high", "page": 0, "mitigation": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(prompt)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best-effort tool: report the failure in-band instead of raising.
        return {"company": firm, "risks": [], "error": str(exc)}
def flag_compliance_issue(tool_input: dict, llm) -> dict:
    """Ask the LLM to list compliance and regulatory issues found in the context.

    Reads ``company`` (default ``"Infosys"``) and ``context`` (default ``""``)
    from *tool_input*; only the first 3000 characters of the context are sent.

    Returns the parsed JSON reply. If the LLM call or the parsing fails, returns
    a dict with an empty ``compliance_flags`` list and the error text under
    ``error``.
    """
    firm = tool_input.get("company", "Infosys")
    excerpt = tool_input.get("context", "")[:3000]
    prompt = f"""Identify compliance and regulatory issues from the context for {firm}.
Return ONLY valid JSON:
{{
"company": "{firm}",
"compliance_flags": [
{{"regulation": "...", "issue": "...", "severity": "low|medium|high|critical", "page": 0, "recommended_action": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(prompt)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best-effort tool: report the failure in-band instead of raising.
        return {"company": firm, "compliance_flags": [], "error": str(exc)}
def schedule_analyst_review(tool_input: dict, llm) -> dict:
    """Build an analyst-review request with an LLM-generated agenda.

    Reads ``company`` (default ``"Infosys"``), ``priority`` (default
    ``"normal"``) and ``context`` (default ``""``) from *tool_input*; only the
    first 1500 characters of the context are sent to the LLM.

    Returns:
        ``{"analyst_review_request": {...}}`` with a suggested date three days
        from today, a meeting id derived from the company name and today's
        date, and three agenda items. A stock agenda is used whenever the LLM
        call fails or its reply has no usable ``agenda_items`` list.
    """
    default_agenda = [
        "Review financial performance",
        "Assess risk exposure",
        "Compliance check",
    ]
    company = tool_input.get("company", "Infosys")
    priority = tool_input.get("priority", "normal")
    context = tool_input.get("context", "")
    prompt = f"""Based on the context about {company}, generate 3 agenda items for an analyst review meeting.
Return ONLY valid JSON: {{"agenda_items": ["item1", "item2", "item3"]}}
Context:
{context[:1500]}"""
    try:
        data = parse_json_from_response(llm.invoke(prompt).content)
        agenda = data.get("agenda_items")
    except Exception:
        # Best-effort: any LLM or parsing failure falls back to the stock agenda.
        agenda = None
    # A reply that parses but has a missing, empty, or non-list agenda would
    # otherwise yield a "SCHEDULED" meeting with nothing on it.
    if not isinstance(agenda, list) or not agenda:
        agenda = default_agenda

    # Read the clock once so the suggested date and the meeting id cannot
    # disagree when the call straddles midnight.
    today = datetime.date.today()
    suggested_date = (today + datetime.timedelta(days=3)).isoformat()
    meeting_id = f"AR-{company[:3].upper()}-{today.strftime('%Y%m%d')}"
    return {
        "analyst_review_request": {
            "company": company,
            "review_type": "credit",
            "priority": priority,
            "suggested_date": suggested_date,
            "duration_minutes": 60,
            "agenda_items": agenda,
            "status": "SCHEDULED",
            "meeting_id": meeting_id,
        }
    }
def generate_investment_alert(tool_input: dict, llm) -> dict:
    """Ask the LLM for a buy/sell/hold/watch investment alert from the context.

    Reads ``company`` (default ``"Infosys"``) and ``context`` (default ``""``)
    from *tool_input*; only the first 3000 characters of the context are sent.
    The ``finbert_sentiment`` field in the requested JSON is a placeholder
    filled in by the LLM.

    Returns the parsed JSON reply. If the LLM call or the parsing fails, returns
    a dict with ``signal`` set to ``"hold"`` and the error text under ``error``.
    """
    # FinBERT sentiment will be injected here in Phase 4
    company = tool_input.get("company", "Infosys")
    context = tool_input.get("context", "")
    # The dash in the Note line was mis-encoded ("β€”"); it is restored to a
    # real em dash so the LLM does not receive garbage bytes in its prompt.
    prompt = f"""Based on the financial context for {company}, generate an investment alert.
Return ONLY valid JSON:
{{
"company": "{company}",
"signal": "buy|sell|hold|watch",
"trigger_reason": "...",
"confidence_score": 0.0,
"supporting_evidence": ["...", "..."],
"finbert_sentiment": {{"label": "positive|negative|neutral", "score": 0.0}}
}}
Note: finbert_sentiment is a placeholder — Phase 4 replaces with real FinBERT scores.
Context:
{context[:3000]}"""
    try:
        return parse_json_from_response(llm.invoke(prompt).content)
    except Exception as e:
        # Best-effort tool: default to the neutral "hold" signal on failure.
        return {"company": company, "signal": "hold", "error": str(e)}
# Tool name -> implementation. Each tool is registered under its own function
# name, which is the value stored in the state's "tool_name" field.
TOOL_REGISTRY = {
    tool.__name__: tool
    for tool in (
        extract_financial_metrics,
        generate_risk_summary,
        flag_compliance_issue,
        schedule_analyst_review,
        generate_investment_alert,
    )
}
# Intent label -> name of the tool that rag_agent attaches for that intent.
INTENT_TO_TOOL = dict(
    rag_query="generate_risk_summary",
    compliance_check="flag_compliance_issue",
    investment_alert="generate_investment_alert",
    financial_metrics="extract_financial_metrics",
    analyst_review="schedule_analyst_review",
)
# ── Node builders (closures capture components) ───────────────────────────────
def build_agent_nodes(components: dict) -> dict:
    """Build the graph's node functions as closures over shared components.

    Args:
        components: Mapping with the required keys ``llm``, ``bm25_retriever``,
            ``dense_retriever``, ``cross_encoder`` and ``tables``. Two optional
            keys tune the Self-RAG loop: ``faithfulness_threshold`` (default
            0.7) and ``max_self_rag_retries`` (default 2).

    Returns:
        Dict mapping node name to node function. Every node takes the current
        ``AgentState`` and returns a dict of the state fields it updates.
    """
    llm = components["llm"]
    bm25_retriever = components["bm25_retriever"]
    dense_retriever = components["dense_retriever"]
    cross_encoder = components["cross_encoder"]
    tables = components["tables"]
    # Previously hard-coded in the evaluator; defaults keep the old behaviour.
    faithfulness_threshold = components.get("faithfulness_threshold", 0.7)
    max_self_rag_retries = components.get("max_self_rag_retries", 2)

    # ── Orchestrator ──────────────────────────────────────────────────────────
    def orchestrator_node(state: AgentState) -> dict:
        """Classify the query's intent with the LLM; fall back to ``rag_query``."""
        prompt = """Classify the intent of this financial query into exactly one category.
Categories:
- rag_query : general information retrieval from the document
- compliance_check : identify regulatory/SEBI/RBI compliance issues
- investment_alert : generate buy/sell/hold investment signal
- financial_metrics : extract specific financial numbers, ratios, or metrics
- analyst_review : schedule or request an analyst review meeting
Return ONLY valid JSON: {"intent": "...", "reasoning": "..."}
Query: """ + state["query"]
        try:
            data = parse_json_from_response(llm.invoke(prompt).content)
            # Coerce both fields to str inside the try: a null or non-string
            # value would otherwise crash the slice in the log line below or
            # the dict lookup in the router. Normalise case and whitespace so
            # "Rag_Query " still matches the lowercase routing keys.
            intent = str(data.get("intent") or "rag_query").strip().lower()
            reasoning = str(data.get("reasoning") or "")
        except Exception:
            intent, reasoning = "rag_query", "fallback"
        print(f"\n[ORCHESTRATOR] intent={intent} | {reasoning[:80]}")
        return {"intent": intent, "reasoning": reasoning}

    # ── Shared retrieval helper ───────────────────────────────────────────────
    def _retrieve(query: str) -> tuple[list[dict], list, str]:
        """Run Phase 2 pipeline, return (doc_dicts, citations, context_str)."""
        docs = advanced_retrieval_pipeline(
            query, bm25_retriever, dense_retriever, cross_encoder, llm
        )
        doc_dicts = docs_to_dicts(docs)
        # De-duplicate pages in first-seen order and skip documents without a
        # page, so citations are deterministic and never contain None.
        pages = (d["metadata"].get("page") for d in doc_dicts)
        citations = list(dict.fromkeys(p for p in pages if p is not None))
        context = context_from_dicts(doc_dicts)
        return doc_dicts, citations, context

    def _generate_answer(system_context: str, query: str, context: str) -> str:
        """Ask the LLM to answer *query* from the first 4000 chars of *context*."""
        prompt = f"""{system_context}
Retrieved context:
{context[:4000]}
Question: {query}
Answer with specific facts and page citations:"""
        return llm.invoke(prompt).content

    # ── RAG Agent ─────────────────────────────────────────────────────────────
    def rag_agent_node(state: AgentState) -> dict:
        """Retrieve, answer, and pick the follow-up tool from the intent."""
        print(f"[RAG AGENT] Retrieving for: {state['query'][:60]}")
        doc_dicts, citations, context = _retrieve(state["query"])
        system = ("You are a financial analyst. Answer only from the provided context. "
                  "Cite page numbers.")
        answer = _generate_answer(system, state["query"], context)
        tool_name = INTENT_TO_TOOL.get(state["intent"])
        tool_input = {"company": "Infosys", "context": context} if tool_name else None
        # Inject priority if analyst_review
        if tool_name == "schedule_analyst_review":
            tool_input["priority"] = (
                "urgent" if "urgent" in state["query"].lower() else "normal"
            )
        print(f"[RAG AGENT] tool={tool_name} | {len(doc_dicts)} docs | "
              f"citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": tool_name,
            "tool_input": tool_input,
        }

    # ── Compliance Agent ──────────────────────────────────────────────────────
    def compliance_agent_node(state: AgentState) -> dict:
        """Retrieve with compliance keywords and answer as a compliance officer."""
        print(f"[COMPLIANCE AGENT] Checking: {state['query'][:60]}")
        # Prepend compliance keywords to improve retrieval
        retrieval_query = f"regulatory compliance SEBI RBI disclosure {state['query']}"
        doc_dicts, citations, context = _retrieve(retrieval_query)
        system = ("You are a compliance officer specialising in SEBI and RBI regulations. "
                  "Identify all compliance risks and regulatory obligations. "
                  "Cite the exact page and regulation.")
        answer = _generate_answer(system, state["query"], context)
        tool_input = {"company": "Infosys", "context": context}
        print(f"[COMPLIANCE AGENT] {len(doc_dicts)} docs | citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "flag_compliance_issue",
            "tool_input": tool_input,
        }

    # ── Alert Agent ───────────────────────────────────────────────────────────
    def alert_agent_node(state: AgentState) -> dict:
        """Retrieve with performance keywords and answer as an investment analyst."""
        print(f"[ALERT AGENT] Analysing: {state['query'][:60]}")
        retrieval_query = (f"revenue profit growth performance outlook "
                           f"risk sentiment {state['query']}")
        doc_dicts, citations, context = _retrieve(retrieval_query)
        system = ("You are an investment analyst. Assess the company's financial health "
                  "and generate a clear investment signal with supporting evidence. "
                  "Note: FinBERT sentiment scoring will be added in Phase 4.")
        answer = _generate_answer(system, state["query"], context)
        tool_input = {"company": "Infosys", "context": context}
        print(f"[ALERT AGENT] {len(doc_dicts)} docs | citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "generate_investment_alert",
            "tool_input": tool_input,
        }

    # ── Finance Calculator Agent ──────────────────────────────────────────────
    def finance_calculator_node(state: AgentState) -> dict:
        """Retrieve, add up to three keyword-matching tables, and answer."""
        print(f"[FINANCE CALC] Computing: {state['query'][:60]}")
        doc_dicts, citations, context = _retrieve(state["query"])
        # Augment context with matching table data. Keywords are query words
        # longer than three characters, computed once rather than per table.
        keywords = [w for w in state["query"].lower().split() if len(w) > 3]
        relevant_tables = []
        for table in tables:
            table_text = table["text_representation"].lower()
            if any(word in table_text for word in keywords):
                relevant_tables.append(table)
                if len(relevant_tables) == 3:
                    break  # only the first three matches are used
        table_context = "\n\n".join(
            t["text_representation"] for t in relevant_tables
        )
        full_context = (
            f"FINANCIAL TABLES:\n{table_context}\n\nDOCUMENT EXCERPTS:\n{context}"
            if table_context else context
        )
        system = ("You are a financial analyst specialising in quantitative analysis. "
                  "Extract precise numbers, ratios, and year-on-year changes. "
                  "Be exact — use figures directly from the tables and document.")
        answer = _generate_answer(system, state["query"], full_context)
        tool_input = {"company": "Infosys", "context": full_context}
        print(f"[FINANCE CALC] {len(doc_dicts)} docs + {len(relevant_tables)} tables | "
              f"citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "extract_financial_metrics",
            "tool_input": tool_input,
        }

    # ── Tool Executor ─────────────────────────────────────────────────────────
    def tool_executor_node(state: AgentState) -> dict:
        """Run the selected tool, or pass the answer through when there is none."""
        tool_name = state.get("tool_name")
        if not tool_name or tool_name not in TOOL_REGISTRY:
            print("[TOOL EXECUTOR] No tool — passing answer through")
            return {"tool_output": {"answer": state["final_answer"]}}
        print(f"[TOOL EXECUTOR] Running: {tool_name}")
        tool_fn = TOOL_REGISTRY[tool_name]
        # Tools call .get() on their input, so never hand them None.
        output = tool_fn(state.get("tool_input") or {}, llm)
        print(f"[TOOL EXECUTOR] Done — keys: {list(output)}")
        return {"tool_output": output}

    # ── Self-RAG Evaluator ────────────────────────────────────────────────────
    def self_rag_evaluator_node(state: AgentState) -> dict:
        """Score answer faithfulness with the LLM and decide whether to retry."""
        context = context_from_dicts(state["retrieved_docs"])
        prompt = f"""Rate how faithfully this answer is grounded in the provided context.
A score of 1.0 means every claim is directly supported by the context.
A score of 0.0 means the answer contains hallucinated facts not in the context.
Return ONLY valid JSON: {{"score": <0.0-1.0>, "reasoning": "..."}}
Context (truncated):
{context[:2000]}
Answer:
{state['final_answer'][:1000]}"""
        try:
            data = parse_json_from_response(llm.invoke(prompt).content)
            score = float(data.get("score", 0.8))
            score = max(0.0, min(1.0, score))  # clamp to [0, 1]
        except Exception:
            # Lenient default: an unreadable verdict counts as passing, so a
            # flaky evaluator reply does not by itself trigger a retry.
            score = 0.8
        current_iter = state["iteration_count"]
        should_rerun = (
            score < faithfulness_threshold and current_iter < max_self_rag_retries
        )
        new_count = current_iter + 1 if should_rerun else current_iter
        print(f"[SELF-RAG] faithfulness={score:.2f} | "
              f"rerun={should_rerun} | iter={new_count}")
        return {
            "faithfulness_score": score,
            "should_rerun": should_rerun,
            "iteration_count": new_count,
        }

    return {
        "orchestrator": orchestrator_node,
        "rag_agent": rag_agent_node,
        "compliance_agent": compliance_agent_node,
        "alert_agent": alert_agent_node,
        "finance_calculator": finance_calculator_node,
        "tool_executor": tool_executor_node,
        "self_rag_evaluator": self_rag_evaluator_node,
    }
# ── Routing functions ─────────────────────────────────────────────────────────
def route_by_intent(state: AgentState) -> str:
    """Return the name of the specialist node that handles the state's intent.

    ``analyst_review`` is served by ``rag_agent``, and any intent missing from
    the table also falls back to ``rag_agent``.
    """
    mapping = {
        "rag_query": "rag_agent",
        "compliance_check": "compliance_agent",
        "investment_alert": "alert_agent",
        "financial_metrics": "finance_calculator",
        "analyst_review": "rag_agent",
    }
    node = mapping.get(state["intent"], "rag_agent")
    # The arrow in this log line was mis-encoded ("β†’"); restored to "→".
    print(f"[ROUTER] intent={state['intent']} → {node}")
    return node
def route_self_rag(state: AgentState) -> str:
    """Return ``"rag_agent"`` when a rerun was requested, otherwise ``"end"``."""
    if state["should_rerun"]:
        return "rag_agent"
    return "end"
# ── Graph builder ─────────────────────────────────────────────────────────────
def build_graph(components: dict):
    """Assemble and compile the agent graph.

    Wiring: ``orchestrator`` fans out to one specialist via ``route_by_intent``;
    every specialist feeds ``tool_executor``, which feeds
    ``self_rag_evaluator``; the evaluator either loops back to ``rag_agent`` or
    ends the run, as decided by ``route_self_rag``.
    """
    specialists = (
        "rag_agent",
        "compliance_agent",
        "alert_agent",
        "finance_calculator",
    )

    node_fns = build_agent_nodes(components)
    graph = StateGraph(AgentState)
    for node_name in node_fns:
        graph.add_node(node_name, node_fns[node_name])

    graph.set_entry_point("orchestrator")
    # route_by_intent returns a specialist's own name, so each maps to itself.
    graph.add_conditional_edges(
        "orchestrator",
        route_by_intent,
        {name: name for name in specialists},
    )

    for name in specialists:
        graph.add_edge(name, "tool_executor")
    graph.add_edge("tool_executor", "self_rag_evaluator")

    rerun_targets = {"rag_agent": "rag_agent", "end": END}
    graph.add_conditional_edges("self_rag_evaluator", route_self_rag, rerun_targets)
    return graph.compile()
# ── Component initialiser ─────────────────────────────────────────────────────
def build_components() -> dict:
    """Load every shared component the agent nodes need.

    Builds the LLM and embeddings, loads the PDF pages, opens (or builds) the
    vector store, creates the BM25 and dense retrievers and the cross-encoder,
    and reads the extracted tables from ``TABLE_STORE_PATH``.

    Returns:
        Dict with the keys ``llm``, ``embeddings``, ``bm25_retriever``,
        ``dense_retriever``, ``cross_encoder`` and ``tables``.
    """
    print("=== Initialising Phase 3 components ===")
    llm = build_llm()
    embeddings = build_embeddings()
    pages = load_pdf(PDF_PATH)
    # spaCy NER (entities metadata) is not used at query time — BM25 needs only
    # page_content and the dense store is already persisted. Skipping NER here
    # cuts a large chunk off cold-start time. (Phase 1 ingestion still tags.)
    vs = load_or_build_vectorstore(embeddings, pages)
    bm25 = build_bm25_retriever(pages)
    dense = build_dense_retriever(vs)
    cross_enc = build_cross_encoder()
    # Pin the encoding: without it open() uses the platform locale, which can
    # mis-decode or reject non-ASCII table text on non-UTF-8 systems.
    with open(TABLE_STORE_PATH, encoding="utf-8") as f:
        tables = json.load(f)
    print(f"[INIT] Loaded {len(tables)} tables from table store")
    return {
        "llm": llm,
        "embeddings": embeddings,
        "bm25_retriever": bm25,
        "dense_retriever": dense,
        "cross_encoder": cross_enc,
        "tables": tables,
    }
# ── Pretty-print result ───────────────────────────────────────────────────────
def print_result(result: AgentState) -> None:
    """Print a framed summary of a finished run to stdout.

    Shows the query, intent, faithfulness score, cited pages, the first 600
    characters of the answer and, when a tool output is present, the first 800
    characters of its JSON rendering.
    """
    banner = "=" * 65
    print("\n" + banner)
    print(f"QUERY : {result['query']}")
    print(f"INTENT : {result['intent']}")
    print(f"FAITHFULNESS : {result['faithfulness_score']:.2f}")
    print(f"CITATIONS : pages {result['citations']}")
    print(f"\nANSWER:\n{result['final_answer'][:600]}")

    tool_output = result.get("tool_output")
    if tool_output:
        print(f"\nTOOL OUTPUT ({result['tool_name']}):")
        rendered = json.dumps(tool_output, indent=2)
        print(rendered[:800])
    print(banner)
# ── Main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: build the graph once, then run one sample query per intent path.
    app = build_graph(build_components())
    print("\n=== LangGraph compiled successfully ===")

    sample_queries = (
        "What was Infosys revenue and operating margin in FY25?",
        "Are there any SEBI compliance or regulatory issues mentioned in the annual report?",
        "Generate an investment alert for Infosys based on their FY25 performance.",
        "Schedule an urgent credit review for Infosys.",
    )
    separator = "#" * 65
    for sample in sample_queries:
        print(f"\n{separator}")
        final_state = app.invoke(make_initial_state(sample))
        print_result(final_state)