""" Phase 3 — LangGraph Multi-Agent Orchestration Graph flow: orchestrator │ (route_by_intent) ├──► rag_agent ├──► compliance_agent ├──► alert_agent └──► finance_calculator │ tool_executor │ self_rag_evaluator │ (route_self_rag) ├──► rag_agent (if faithfulness < 0.7 and retries < 2) └──► END """ import datetime import json import os from typing import Optional, TypedDict from dotenv import load_dotenv from langgraph.graph import END, StateGraph from langchain_core.documents import Document from phase1_ingestion import ( CHROMA_DIR, PDF_PATH, TABLE_STORE_PATH, build_embeddings, build_llm, load_pdf, tag_financial_entities, ) from phase2_retrieval import ( advanced_retrieval_pipeline, build_bm25_retriever, build_cross_encoder, build_dense_retriever, load_or_build_vectorstore, ) load_dotenv() # ── AgentState ──────────────────────────────────────────────────────────────── class AgentState(TypedDict): query: str intent: str reasoning: str retrieved_docs: list # list of {"page_content": str, "metadata": dict} final_answer: str citations: list # list of page numbers tool_name: Optional[str] tool_input: Optional[dict] tool_output: Optional[dict] faithfulness_score: float should_rerun: bool iteration_count: int # counts Self-RAG retry loops (max 2) def make_initial_state(query: str) -> AgentState: return { "query": query, "intent": "", "reasoning": "", "retrieved_docs": [], "final_answer": "", "citations": [], "tool_name": None, "tool_input": None, "tool_output": None, "faithfulness_score": 0.0, "should_rerun": False, "iteration_count": 0, } # ── Helpers ─────────────────────────────────────────────────────────────────── def docs_to_dicts(docs: list[Document]) -> list[dict]: return [{"page_content": d.page_content, "metadata": d.metadata} for d in docs] def context_from_dicts(doc_dicts: list[dict]) -> str: return "\n\n".join( f"[Page {d['metadata'].get('page', '?')}]\n{d['page_content']}" for d in doc_dicts ) def parse_json_from_response(text: str) -> dict: """Robustly extract the first JSON object from an LLM response.""" start = text.find("{") end = text.rfind("}") + 1 if start == -1 or end == 0: raise ValueError("No JSON object found in response") return json.loads(text[start:end]) # ── Finance Tools (Phase 4 will enhance generate_investment_alert with FinBERT) def extract_financial_metrics(tool_input: dict, llm) -> dict: company = tool_input.get("company", "Infosys") context = tool_input.get("context", "") prompt = f"""Extract financial metrics from the context for {company}. Return ONLY valid JSON: {{ "company": "{company}", "metrics": [ {{"name": "...", "value": "...", "unit": "...", "page": 0, "yoy_change": "..."}} ] }} Context: {context[:3000]}""" try: return parse_json_from_response(llm.invoke(prompt).content) except Exception as e: return {"company": company, "metrics": [], "error": str(e)} def generate_risk_summary(tool_input: dict, llm) -> dict: company = tool_input.get("company", "Infosys") context = tool_input.get("context", "") prompt = f"""Summarize risk factors from the context for {company}. Return ONLY valid JSON: {{ "company": "{company}", "risks": [ {{"category": "market|credit|regulatory|operational", "description": "...", "severity": "low|medium|high", "page": 0, "mitigation": "..."}} ] }} Context: {context[:3000]}""" try: return parse_json_from_response(llm.invoke(prompt).content) except Exception as e: return {"company": company, "risks": [], "error": str(e)} def flag_compliance_issue(tool_input: dict, llm) -> dict: company = tool_input.get("company", "Infosys") context = tool_input.get("context", "") prompt = f"""Identify compliance and regulatory issues from the context for {company}. Return ONLY valid JSON: {{ "company": "{company}", "compliance_flags": [ {{"regulation": "...", "issue": "...", "severity": "low|medium|high|critical", "page": 0, "recommended_action": "..."}} ] }} Context: {context[:3000]}""" try: return parse_json_from_response(llm.invoke(prompt).content) except Exception as e: return {"company": company, "compliance_flags": [], "error": str(e)} def schedule_analyst_review(tool_input: dict, llm) -> dict: company = tool_input.get("company", "Infosys") priority = tool_input.get("priority", "normal") context = tool_input.get("context", "") prompt = f"""Based on the context about {company}, generate 3 agenda items for an analyst review meeting. Return ONLY valid JSON: {{"agenda_items": ["item1", "item2", "item3"]}} Context: {context[:1500]}""" try: data = parse_json_from_response(llm.invoke(prompt).content) agenda = data.get("agenda_items", []) except Exception: agenda = ["Review financial performance", "Assess risk exposure", "Compliance check"] suggested_date = (datetime.date.today() + datetime.timedelta(days=3)).isoformat() meeting_id = f"AR-{company[:3].upper()}-{datetime.date.today().strftime('%Y%m%d')}" return { "analyst_review_request": { "company": company, "review_type": "credit", "priority": priority, "suggested_date": suggested_date, "duration_minutes": 60, "agenda_items": agenda, "status": "SCHEDULED", "meeting_id": meeting_id, } } def generate_investment_alert(tool_input: dict, llm) -> dict: # FinBERT sentiment will be injected here in Phase 4 company = tool_input.get("company", "Infosys") context = tool_input.get("context", "") prompt = f"""Based on the financial context for {company}, generate an investment alert. Return ONLY valid JSON: {{ "company": "{company}", "signal": "buy|sell|hold|watch", "trigger_reason": "...", "confidence_score": 0.0, "supporting_evidence": ["...", "..."], "finbert_sentiment": {{"label": "positive|negative|neutral", "score": 0.0}} }} Note: finbert_sentiment is a placeholder — Phase 4 replaces with real FinBERT scores. Context: {context[:3000]}""" try: return parse_json_from_response(llm.invoke(prompt).content) except Exception as e: return {"company": company, "signal": "hold", "error": str(e)} TOOL_REGISTRY = { "extract_financial_metrics": extract_financial_metrics, "generate_risk_summary": generate_risk_summary, "flag_compliance_issue": flag_compliance_issue, "schedule_analyst_review": schedule_analyst_review, "generate_investment_alert": generate_investment_alert, } INTENT_TO_TOOL = { "rag_query": "generate_risk_summary", "compliance_check": "flag_compliance_issue", "investment_alert": "generate_investment_alert", "financial_metrics":"extract_financial_metrics", "analyst_review": "schedule_analyst_review", } # ── Node builders (closures capture components) ─────────────────────────────── def build_agent_nodes(components: dict) -> dict: llm = components["llm"] bm25_retriever = components["bm25_retriever"] dense_retriever= components["dense_retriever"] cross_encoder = components["cross_encoder"] tables = components["tables"] # ── Orchestrator ────────────────────────────────────────────────────────── def orchestrator_node(state: AgentState) -> dict: prompt = """Classify the intent of this financial query into exactly one category. Categories: - rag_query : general information retrieval from the document - compliance_check : identify regulatory/SEBI/RBI compliance issues - investment_alert : generate buy/sell/hold investment signal - financial_metrics : extract specific financial numbers, ratios, or metrics - analyst_review : schedule or request an analyst review meeting Return ONLY valid JSON: {"intent": "...", "reasoning": "..."} Query: """ + state["query"] try: data = parse_json_from_response(llm.invoke(prompt).content) intent = data.get("intent", "rag_query") reasoning= data.get("reasoning", "") except Exception: intent, reasoning = "rag_query", "fallback" print(f"\n[ORCHESTRATOR] intent={intent} | {reasoning[:80]}") return {"intent": intent, "reasoning": reasoning} # ── Shared retrieval helper ─────────────────────────────────────────────── def _retrieve(query: str) -> tuple[list[dict], list, str]: """Run Phase 2 pipeline, return (doc_dicts, citations, context_str).""" docs = advanced_retrieval_pipeline( query, bm25_retriever, dense_retriever, cross_encoder, llm ) doc_dicts = docs_to_dicts(docs) citations = list({d["metadata"].get("page") for d in doc_dicts}) context = context_from_dicts(doc_dicts) return doc_dicts, citations, context def _generate_answer(system_context: str, query: str, context: str) -> str: prompt = f"""{system_context} Retrieved context: {context[:4000]} Question: {query} Answer with specific facts and page citations:""" return llm.invoke(prompt).content # ── RAG Agent ───────────────────────────────────────────────────────────── def rag_agent_node(state: AgentState) -> dict: print(f"[RAG AGENT] Retrieving for: {state['query'][:60]}") doc_dicts, citations, context = _retrieve(state["query"]) system = ("You are a financial analyst. Answer only from the provided context. " "Cite page numbers.") answer = _generate_answer(system, state["query"], context) tool_name = INTENT_TO_TOOL.get(state["intent"]) tool_input = {"company": "Infosys", "context": context} if tool_name else None # Inject priority if analyst_review if tool_name == "schedule_analyst_review": tool_input["priority"] = ( "urgent" if "urgent" in state["query"].lower() else "normal" ) print(f"[RAG AGENT] tool={tool_name} | {len(doc_dicts)} docs | " f"citations={citations}") return { "retrieved_docs": doc_dicts, "final_answer": answer, "citations": citations, "tool_name": tool_name, "tool_input": tool_input, } # ── Compliance Agent ────────────────────────────────────────────────────── def compliance_agent_node(state: AgentState) -> dict: print(f"[COMPLIANCE AGENT] Checking: {state['query'][:60]}") # Prepend compliance keywords to improve retrieval retrieval_query = f"regulatory compliance SEBI RBI disclosure {state['query']}" doc_dicts, citations, context = _retrieve(retrieval_query) system = ("You are a compliance officer specialising in SEBI and RBI regulations. " "Identify all compliance risks and regulatory obligations. " "Cite the exact page and regulation.") answer = _generate_answer(system, state["query"], context) tool_input = {"company": "Infosys", "context": context} print(f"[COMPLIANCE AGENT] {len(doc_dicts)} docs | citations={citations}") return { "retrieved_docs": doc_dicts, "final_answer": answer, "citations": citations, "tool_name": "flag_compliance_issue", "tool_input": tool_input, } # ── Alert Agent ─────────────────────────────────────────────────────────── def alert_agent_node(state: AgentState) -> dict: print(f"[ALERT AGENT] Analysing: {state['query'][:60]}") retrieval_query = (f"revenue profit growth performance outlook " f"risk sentiment {state['query']}") doc_dicts, citations, context = _retrieve(retrieval_query) system = ("You are an investment analyst. Assess the company's financial health " "and generate a clear investment signal with supporting evidence. " "Note: FinBERT sentiment scoring will be added in Phase 4.") answer = _generate_answer(system, state["query"], context) tool_input = {"company": "Infosys", "context": context} print(f"[ALERT AGENT] {len(doc_dicts)} docs | citations={citations}") return { "retrieved_docs": doc_dicts, "final_answer": answer, "citations": citations, "tool_name": "generate_investment_alert", "tool_input": tool_input, } # ── Finance Calculator Agent ────────────────────────────────────────────── def finance_calculator_node(state: AgentState) -> dict: print(f"[FINANCE CALC] Computing: {state['query'][:60]}") doc_dicts, citations, context = _retrieve(state["query"]) # Augment context with matching table data query_lower = state["query"].lower() relevant_tables = [ t for t in tables if any(w in t["text_representation"].lower() for w in query_lower.split() if len(w) > 3) ][:3] table_context = "\n\n".join( t["text_representation"] for t in relevant_tables ) full_context = ( f"FINANCIAL TABLES:\n{table_context}\n\nDOCUMENT EXCERPTS:\n{context}" if table_context else context ) system = ("You are a financial analyst specialising in quantitative analysis. " "Extract precise numbers, ratios, and year-on-year changes. " "Be exact — use figures directly from the tables and document.") answer = _generate_answer(system, state["query"], full_context) tool_input = {"company": "Infosys", "context": full_context} print(f"[FINANCE CALC] {len(doc_dicts)} docs + {len(relevant_tables)} tables | " f"citations={citations}") return { "retrieved_docs": doc_dicts, "final_answer": answer, "citations": citations, "tool_name": "extract_financial_metrics", "tool_input": tool_input, } # ── Tool Executor ───────────────────────────────────────────────────────── def tool_executor_node(state: AgentState) -> dict: tool_name = state.get("tool_name") if not tool_name or tool_name not in TOOL_REGISTRY: print("[TOOL EXECUTOR] No tool — passing answer through") return {"tool_output": {"answer": state["final_answer"]}} print(f"[TOOL EXECUTOR] Running: {tool_name}") tool_fn = TOOL_REGISTRY[tool_name] output = tool_fn(state["tool_input"], llm) print(f"[TOOL EXECUTOR] Done — keys: {list(output.keys())}") return {"tool_output": output} # ── Self-RAG Evaluator ──────────────────────────────────────────────────── def self_rag_evaluator_node(state: AgentState) -> dict: context = context_from_dicts(state["retrieved_docs"]) prompt = f"""Rate how faithfully this answer is grounded in the provided context. A score of 1.0 means every claim is directly supported by the context. A score of 0.0 means the answer contains hallucinated facts not in the context. Return ONLY valid JSON: {{"score": <0.0-1.0>, "reasoning": "..."}} Context (truncated): {context[:2000]} Answer: {state['final_answer'][:1000]}""" try: data = parse_json_from_response(llm.invoke(prompt).content) score = float(data.get("score", 0.8)) score = max(0.0, min(1.0, score)) # clamp to [0, 1] except Exception: score = 0.8 # conservative default on parse failure current_iter = state["iteration_count"] should_rerun = score < 0.7 and current_iter < 2 new_count = current_iter + 1 if should_rerun else current_iter print(f"[SELF-RAG] faithfulness={score:.2f} | " f"rerun={should_rerun} | iter={new_count}") return { "faithfulness_score": score, "should_rerun": should_rerun, "iteration_count": new_count, } return { "orchestrator": orchestrator_node, "rag_agent": rag_agent_node, "compliance_agent": compliance_agent_node, "alert_agent": alert_agent_node, "finance_calculator":finance_calculator_node, "tool_executor": tool_executor_node, "self_rag_evaluator":self_rag_evaluator_node, } # ── Routing functions ───────────────────────────────────────────────────────── def route_by_intent(state: AgentState) -> str: mapping = { "rag_query": "rag_agent", "compliance_check": "compliance_agent", "investment_alert": "alert_agent", "financial_metrics": "finance_calculator", "analyst_review": "rag_agent", } node = mapping.get(state["intent"], "rag_agent") print(f"[ROUTER] intent={state['intent']} → {node}") return node def route_self_rag(state: AgentState) -> str: return "rag_agent" if state["should_rerun"] else "end" # ── Graph builder ───────────────────────────────────────────────────────────── def build_graph(components: dict): nodes = build_agent_nodes(components) workflow = StateGraph(AgentState) for name, fn in nodes.items(): workflow.add_node(name, fn) workflow.set_entry_point("orchestrator") workflow.add_conditional_edges( "orchestrator", route_by_intent, { "rag_agent": "rag_agent", "compliance_agent": "compliance_agent", "alert_agent": "alert_agent", "finance_calculator": "finance_calculator", }, ) for specialist in ["rag_agent", "compliance_agent", "alert_agent", "finance_calculator"]: workflow.add_edge(specialist, "tool_executor") workflow.add_edge("tool_executor", "self_rag_evaluator") workflow.add_conditional_edges( "self_rag_evaluator", route_self_rag, {"rag_agent": "rag_agent", "end": END}, ) return workflow.compile() # ── Component initialiser ───────────────────────────────────────────────────── def build_components() -> dict: print("=== Initialising Phase 3 components ===") llm = build_llm() embeddings = build_embeddings() pages = load_pdf(PDF_PATH) # spaCy NER (entities metadata) is not used at query time — BM25 needs only # page_content and the dense store is already persisted. Skipping NER here # cuts a large chunk off cold-start time. (Phase 1 ingestion still tags.) vs = load_or_build_vectorstore(embeddings, pages) bm25 = build_bm25_retriever(pages) dense = build_dense_retriever(vs) cross_enc = build_cross_encoder() with open(TABLE_STORE_PATH) as f: tables = json.load(f) print(f"[INIT] Loaded {len(tables)} tables from table store") return { "llm": llm, "embeddings": embeddings, "bm25_retriever": bm25, "dense_retriever": dense, "cross_encoder": cross_enc, "tables": tables, } # ── Pretty-print result ─────────────────────────────────────────────────────── def print_result(result: AgentState) -> None: print("\n" + "=" * 65) print(f"QUERY : {result['query']}") print(f"INTENT : {result['intent']}") print(f"FAITHFULNESS : {result['faithfulness_score']:.2f}") print(f"CITATIONS : pages {result['citations']}") print(f"\nANSWER:\n{result['final_answer'][:600]}") if result.get("tool_output"): out = result["tool_output"] print(f"\nTOOL OUTPUT ({result['tool_name']}):") print(json.dumps(out, indent=2)[:800]) print("=" * 65) # ── Main ────────────────────────────────────────────────────────────────────── if __name__ == "__main__": components = build_components() app = build_graph(components) print("\n=== LangGraph compiled successfully ===") test_queries = [ "What was Infosys revenue and operating margin in FY25?", "Are there any SEBI compliance or regulatory issues mentioned in the annual report?", "Generate an investment alert for Infosys based on their FY25 performance.", "Schedule an urgent credit review for Infosys.", ] for query in test_queries: print(f"\n{'#' * 65}") result = app.invoke(make_initial_state(query)) print_result(result)