Spaces:
Sleeping
Sleeping
| """ | |
Phase 3 — LangGraph Multi-Agent Orchestration

Graph flow:

    orchestrator
        │  (route_by_intent)
        ├──► rag_agent
        ├──► compliance_agent
        ├──► alert_agent
        └──► finance_calculator
                 │
            tool_executor
                 │
         self_rag_evaluator
                 │  (route_self_rag)
                 ├──► rag_agent  (if faithfulness < 0.7 and retries < 2)
                 └──► END
| """ | |
| import datetime | |
| import json | |
| import os | |
| from typing import Optional, TypedDict | |
| from dotenv import load_dotenv | |
| from langgraph.graph import END, StateGraph | |
| from langchain_core.documents import Document | |
| from phase1_ingestion import ( | |
| CHROMA_DIR, | |
| PDF_PATH, | |
| TABLE_STORE_PATH, | |
| build_embeddings, | |
| build_llm, | |
| load_pdf, | |
| tag_financial_entities, | |
| ) | |
| from phase2_retrieval import ( | |
| advanced_retrieval_pipeline, | |
| build_bm25_retriever, | |
| build_cross_encoder, | |
| build_dense_retriever, | |
| load_or_build_vectorstore, | |
| ) | |
| load_dotenv() | |
# ── AgentState ────────────────────────────────────────────────────────────────
class AgentState(TypedDict):
    """State dictionary shared by every node of the graph.

    Each node reads the keys it needs and returns a dict holding only the
    keys it updates.
    """

    # --- request and routing ---
    query: str
    intent: str      # label produced by the orchestrator node
    reasoning: str   # orchestrator's justification for the chosen intent

    # --- retrieval and answer ---
    retrieved_docs: list  # list of {"page_content": str, "metadata": dict}
    final_answer: str
    citations: list       # list of page numbers

    # --- tool call ---
    tool_name: Optional[str]
    tool_input: Optional[dict]
    tool_output: Optional[dict]

    # --- Self-RAG bookkeeping ---
    faithfulness_score: float
    should_rerun: bool
    iteration_count: int  # counts Self-RAG retry loops (max 2)
def make_initial_state(query: str) -> AgentState:
    """Return a fresh state for one graph run.

    Only ``query`` is taken from the caller; every other field starts at its
    empty value (empty string/list, ``None``, ``0.0``, ``False`` or ``0``).
    """
    return dict(
        query=query,
        # routing fields, filled in by the orchestrator
        intent="",
        reasoning="",
        # retrieval / answer fields, filled in by the specialist agents
        retrieved_docs=[],
        final_answer="",
        citations=[],
        # tool call fields
        tool_name=None,
        tool_input=None,
        tool_output=None,
        # Self-RAG bookkeeping
        faithfulness_score=0.0,
        should_rerun=False,
        iteration_count=0,
    )
# ── Helpers ───────────────────────────────────────────────────────────────────
def docs_to_dicts(docs: list[Document]) -> list[dict]:
    """Convert documents into plain ``{"page_content", "metadata"}`` dicts.

    The output keeps the input order; ``metadata`` is passed through as the
    same object, not copied.
    """
    converted = []
    for doc in docs:
        converted.append(
            {"page_content": doc.page_content, "metadata": doc.metadata}
        )
    return converted
def context_from_dicts(doc_dicts: list[dict]) -> str:
    """Render document dicts as one prompt-ready context string.

    Each document becomes a ``[Page N]`` header line followed by its text;
    documents are separated by a blank line. A document whose metadata has no
    ``"page"`` entry is labelled ``[Page ?]``.
    """
    sections = []
    for entry in doc_dicts:
        page_label = entry["metadata"].get("page", "?")
        sections.append(f"[Page {page_label}]\n{entry['page_content']}")
    return "\n\n".join(sections)
def parse_json_from_response(text: str) -> dict:
    """Extract the first parseable JSON object from an LLM response.

    The text is scanned left to right; at every ``{`` a JSON decode is
    attempted and the first object that parses is returned. Anything before
    or after that object (prose, code fences, further objects, stray braces)
    is ignored.

    Args:
        text: Raw model output.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no ``{`` in ``text`` starts a valid JSON object.
    """
    decoder = json.JSONDecoder()
    last_error = None
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object, so trailing text
            # such as "... hope this helps {smile}" cannot break parsing.
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as exc:
            last_error = exc
            # NOTE: if an outer object is malformed, a well-formed nested
            # object inside it may be returned by a later attempt.
            start = text.find("{", start + 1)
        else:
            return parsed
    raise ValueError("No JSON object found in response") from last_error
# ── Finance Tools (Phase 4 will enhance generate_investment_alert with FinBERT)
def extract_financial_metrics(tool_input: dict, llm) -> dict:
    """Ask the LLM to list the financial metrics found in a context.

    Args:
        tool_input: Mapping with optional ``"company"`` (default
            ``"Infosys"``) and ``"context"`` (default ``""``) entries.
        llm: Object whose ``invoke(prompt)`` result exposes ``.content``.

    Returns:
        The JSON object parsed from the model reply, or, when the call or the
        parsing raises, ``{"company": ..., "metrics": [], "error": ...}``.
    """
    target = tool_input.get("company", "Infosys")
    # Only the first 3000 characters of the context go into the prompt.
    excerpt = tool_input.get("context", "")[:3000]
    request = f"""Extract financial metrics from the context for {target}.
Return ONLY valid JSON:
{{
"company": "{target}",
"metrics": [
{{"name": "...", "value": "...", "unit": "...", "page": 0, "yoy_change": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(request)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best effort: report the failure inside the result, never raise.
        return {"company": target, "metrics": [], "error": str(exc)}
def generate_risk_summary(tool_input: dict, llm) -> dict:
    """Ask the LLM to summarise the risk factors found in a context.

    Args:
        tool_input: Mapping with optional ``"company"`` (default
            ``"Infosys"``) and ``"context"`` (default ``""``) entries.
        llm: Object whose ``invoke(prompt)`` result exposes ``.content``.

    Returns:
        The JSON object parsed from the model reply, or, when the call or the
        parsing raises, ``{"company": ..., "risks": [], "error": ...}``.
    """
    target = tool_input.get("company", "Infosys")
    # Only the first 3000 characters of the context go into the prompt.
    excerpt = tool_input.get("context", "")[:3000]
    request = f"""Summarize risk factors from the context for {target}.
Return ONLY valid JSON:
{{
"company": "{target}",
"risks": [
{{"category": "market|credit|regulatory|operational", "description": "...", "severity": "low|medium|high", "page": 0, "mitigation": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(request)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best effort: report the failure inside the result, never raise.
        return {"company": target, "risks": [], "error": str(exc)}
def flag_compliance_issue(tool_input: dict, llm) -> dict:
    """Ask the LLM to flag compliance and regulatory issues in a context.

    Args:
        tool_input: Mapping with optional ``"company"`` (default
            ``"Infosys"``) and ``"context"`` (default ``""``) entries.
        llm: Object whose ``invoke(prompt)`` result exposes ``.content``.

    Returns:
        The JSON object parsed from the model reply, or, when the call or the
        parsing raises,
        ``{"company": ..., "compliance_flags": [], "error": ...}``.
    """
    target = tool_input.get("company", "Infosys")
    # Only the first 3000 characters of the context go into the prompt.
    excerpt = tool_input.get("context", "")[:3000]
    request = f"""Identify compliance and regulatory issues from the context for {target}.
Return ONLY valid JSON:
{{
"company": "{target}",
"compliance_flags": [
{{"regulation": "...", "issue": "...", "severity": "low|medium|high|critical", "page": 0, "recommended_action": "..."}}
]
}}
Context:
{excerpt}"""
    try:
        reply = llm.invoke(request)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best effort: report the failure inside the result, never raise.
        return {"company": target, "compliance_flags": [], "error": str(exc)}
def schedule_analyst_review(tool_input: dict, llm) -> dict:
    """Build an analyst-review meeting request with an LLM-drafted agenda.

    Args:
        tool_input: Mapping with optional ``"company"`` (default
            ``"Infosys"``), ``"priority"`` (default ``"normal"``) and
            ``"context"`` (default ``""``) entries.
        llm: Object whose ``invoke(prompt)`` result exposes ``.content``.

    Returns:
        ``{"analyst_review_request": {...}}`` describing a 60-minute credit
        review suggested for three days from today. The agenda comes from the
        model; a fixed three-item agenda is used when the model call fails or
        yields no usable (non-empty list) ``agenda_items``.
    """
    default_agenda = [
        "Review financial performance",
        "Assess risk exposure",
        "Compliance check",
    ]
    company = tool_input.get("company", "Infosys")
    priority = tool_input.get("priority", "normal")
    context = tool_input.get("context", "")
    prompt = f"""Based on the context about {company}, generate 3 agenda items for an analyst review meeting.
Return ONLY valid JSON: {{"agenda_items": ["item1", "item2", "item3"]}}
Context:
{context[:1500]}"""
    try:
        data = parse_json_from_response(llm.invoke(prompt).content)
        agenda = data.get("agenda_items")
    except Exception:
        agenda = None
    # A missing, empty or malformed agenda is as unusable as a failed call.
    if not isinstance(agenda, list) or not agenda:
        agenda = default_agenda

    # Read the clock once so the date and the meeting id cannot disagree
    # when the call straddles midnight.
    today = datetime.date.today()
    suggested_date = (today + datetime.timedelta(days=3)).isoformat()
    meeting_id = f"AR-{company[:3].upper()}-{today.strftime('%Y%m%d')}"
    return {
        "analyst_review_request": {
            "company": company,
            "review_type": "credit",
            "priority": priority,
            "suggested_date": suggested_date,
            "duration_minutes": 60,
            "agenda_items": agenda,
            "status": "SCHEDULED",
            "meeting_id": meeting_id,
        }
    }
def generate_investment_alert(tool_input: dict, llm) -> dict:
    """Ask the LLM for a buy/sell/hold/watch alert based on a context.

    The ``finbert_sentiment`` field in the requested JSON is a placeholder;
    real FinBERT sentiment is planned to be injected here in Phase 4.

    Args:
        tool_input: Mapping with optional ``"company"`` (default
            ``"Infosys"``) and ``"context"`` (default ``""``) entries.
        llm: Object whose ``invoke(prompt)`` result exposes ``.content``.

    Returns:
        The JSON object parsed from the model reply, or, when the call or the
        parsing raises, ``{"company": ..., "signal": "hold", "error": ...}``.
    """
    target = tool_input.get("company", "Infosys")
    # Only the first 3000 characters of the context go into the prompt.
    excerpt = tool_input.get("context", "")[:3000]
    request = f"""Based on the financial context for {target}, generate an investment alert.
Return ONLY valid JSON:
{{
"company": "{target}",
"signal": "buy|sell|hold|watch",
"trigger_reason": "...",
"confidence_score": 0.0,
"supporting_evidence": ["...", "..."],
"finbert_sentiment": {{"label": "positive|negative|neutral", "score": 0.0}}
}}
Note: finbert_sentiment is a placeholder β Phase 4 replaces with real FinBERT scores.
Context:
{excerpt}"""
    try:
        reply = llm.invoke(request)
        return parse_json_from_response(reply.content)
    except Exception as exc:
        # Best effort: fall back to a neutral signal and report the failure.
        return {"company": target, "signal": "hold", "error": str(exc)}
# Tool name -> implementation. Every tool is called as tool(tool_input, llm)
# and is registered under its own function name.
TOOL_REGISTRY = {
    tool.__name__: tool
    for tool in (
        extract_financial_metrics,
        generate_risk_summary,
        flag_compliance_issue,
        schedule_analyst_review,
        generate_investment_alert,
    )
}
# Orchestrator intent label -> name of the tool run for that intent.
INTENT_TO_TOOL = dict(
    rag_query="generate_risk_summary",
    compliance_check="flag_compliance_issue",
    investment_alert="generate_investment_alert",
    financial_metrics="extract_financial_metrics",
    analyst_review="schedule_analyst_review",
)
# ── Node builders (closures capture components) ──────────────────────────────
def build_agent_nodes(components: dict) -> dict:
    """Create the graph node callables, bound to the shared components.

    Args:
        components: Mapping that must provide ``"llm"``, ``"bm25_retriever"``,
            ``"dense_retriever"``, ``"cross_encoder"`` and ``"tables"`` (an
            iterable of dicts carrying a ``"text_representation"`` string).
            An optional ``"company"`` entry overrides the default company
            name ``"Infosys"`` that is handed to the tools.

    Returns:
        Dict mapping node name to a ``fn(state) -> dict`` callable; each
        callable returns only the state keys it updates.
    """
    llm = components["llm"]
    bm25_retriever = components["bm25_retriever"]
    dense_retriever = components["dense_retriever"]
    cross_encoder = components["cross_encoder"]
    tables = components["tables"]
    company = components.get("company", "Infosys")

    faithfulness_threshold = 0.7  # below this the answer is regenerated
    max_retries = 2               # upper bound on Self-RAG retry loops

    # -- Orchestrator ---------------------------------------------------------
    def orchestrator_node(state: AgentState) -> dict:
        """Classify the query into an intent label using the LLM."""
        prompt = """Classify the intent of this financial query into exactly one category.
Categories:
- rag_query : general information retrieval from the document
- compliance_check : identify regulatory/SEBI/RBI compliance issues
- investment_alert : generate buy/sell/hold investment signal
- financial_metrics : extract specific financial numbers, ratios, or metrics
- analyst_review : schedule or request an analyst review meeting
Return ONLY valid JSON: {"intent": "...", "reasoning": "..."}
Query: """ + state["query"]
        try:
            data = parse_json_from_response(llm.invoke(prompt).content)
            # Normalise case/whitespace so "Compliance_Check " still routes;
            # unknown labels are left for the router's own fallback.
            intent = str(data.get("intent") or "rag_query").strip().lower()
            # Coerce to str: a null/non-string value would break the slice below.
            reasoning = str(data.get("reasoning") or "")
        except Exception:
            intent, reasoning = "rag_query", "fallback"
        print(f"\n[ORCHESTRATOR] intent={intent} | {reasoning[:80]}")
        return {"intent": intent, "reasoning": reasoning}

    # -- Shared retrieval helpers ---------------------------------------------
    def _retrieve(query: str) -> tuple[list[dict], list, str]:
        """Run Phase 2 pipeline, return (doc_dicts, citations, context_str)."""
        docs = advanced_retrieval_pipeline(
            query, bm25_retriever, dense_retriever, cross_encoder, llm
        )
        doc_dicts = docs_to_dicts(docs)
        # De-duplicate pages in retrieval order and skip documents without a
        # page, so citations are deterministic and never contain None.
        pages = (d["metadata"].get("page") for d in doc_dicts)
        citations = list(dict.fromkeys(p for p in pages if p is not None))
        context = context_from_dicts(doc_dicts)
        return doc_dicts, citations, context

    def _generate_answer(system_context: str, query: str, context: str) -> str:
        """Answer ``query`` from the first 4000 characters of ``context``."""
        prompt = f"""{system_context}
Retrieved context:
{context[:4000]}
Question: {query}
Answer with specific facts and page citations:"""
        return llm.invoke(prompt).content

    # -- RAG Agent ------------------------------------------------------------
    def rag_agent_node(state: AgentState) -> dict:
        """Answer the query and pick the tool mapped to the current intent."""
        print(f"[RAG AGENT] Retrieving for: {state['query'][:60]}")
        doc_dicts, citations, context = _retrieve(state["query"])
        system = ("You are a financial analyst. Answer only from the provided context. "
                  "Cite page numbers.")
        answer = _generate_answer(system, state["query"], context)

        # Intents missing from the mapping run without a tool.
        tool_name = INTENT_TO_TOOL.get(state["intent"])
        tool_input = {"company": company, "context": context} if tool_name else None
        # Inject priority if analyst_review
        if tool_name == "schedule_analyst_review":
            tool_input["priority"] = (
                "urgent" if "urgent" in state["query"].lower() else "normal"
            )
        print(f"[RAG AGENT] tool={tool_name} | {len(doc_dicts)} docs | "
              f"citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": tool_name,
            "tool_input": tool_input,
        }

    # -- Compliance Agent -----------------------------------------------------
    def compliance_agent_node(state: AgentState) -> dict:
        """Answer as a compliance officer; always uses flag_compliance_issue."""
        print(f"[COMPLIANCE AGENT] Checking: {state['query'][:60]}")
        # Prepend compliance keywords to improve retrieval
        retrieval_query = f"regulatory compliance SEBI RBI disclosure {state['query']}"
        doc_dicts, citations, context = _retrieve(retrieval_query)
        system = ("You are a compliance officer specialising in SEBI and RBI regulations. "
                  "Identify all compliance risks and regulatory obligations. "
                  "Cite the exact page and regulation.")
        answer = _generate_answer(system, state["query"], context)
        tool_input = {"company": company, "context": context}
        print(f"[COMPLIANCE AGENT] {len(doc_dicts)} docs | citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "flag_compliance_issue",
            "tool_input": tool_input,
        }

    # -- Alert Agent ----------------------------------------------------------
    def alert_agent_node(state: AgentState) -> dict:
        """Answer as an investment analyst; always uses generate_investment_alert."""
        print(f"[ALERT AGENT] Analysing: {state['query'][:60]}")
        # Performance/outlook keywords are prepended to steer retrieval.
        retrieval_query = (f"revenue profit growth performance outlook "
                           f"risk sentiment {state['query']}")
        doc_dicts, citations, context = _retrieve(retrieval_query)
        system = ("You are an investment analyst. Assess the company's financial health "
                  "and generate a clear investment signal with supporting evidence. "
                  "Note: FinBERT sentiment scoring will be added in Phase 4.")
        answer = _generate_answer(system, state["query"], context)
        tool_input = {"company": company, "context": context}
        print(f"[ALERT AGENT] {len(doc_dicts)} docs | citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "generate_investment_alert",
            "tool_input": tool_input,
        }

    # -- Finance Calculator Agent ---------------------------------------------
    def finance_calculator_node(state: AgentState) -> dict:
        """Answer from retrieved text plus up to three keyword-matched tables."""
        print(f"[FINANCE CALC] Computing: {state['query'][:60]}")
        doc_dicts, citations, context = _retrieve(state["query"])

        # Augment context with matching table data: a table matches when any
        # query word longer than 3 characters occurs in its text.
        keywords = [w for w in state["query"].lower().split() if len(w) > 3]
        relevant_tables = []
        for table in tables:
            table_text = table["text_representation"].lower()  # lowered once per table
            if any(word in table_text for word in keywords):
                relevant_tables.append(table)
                if len(relevant_tables) == 3:
                    break  # only the first three matches are used

        table_context = "\n\n".join(
            t["text_representation"] for t in relevant_tables
        )
        full_context = (
            f"FINANCIAL TABLES:\n{table_context}\n\nDOCUMENT EXCERPTS:\n{context}"
            if table_context else context
        )
        system = ("You are a financial analyst specialising in quantitative analysis. "
                  "Extract precise numbers, ratios, and year-on-year changes. "
                  "Be exact β use figures directly from the tables and document.")
        answer = _generate_answer(system, state["query"], full_context)
        tool_input = {"company": company, "context": full_context}
        print(f"[FINANCE CALC] {len(doc_dicts)} docs + {len(relevant_tables)} tables | "
              f"citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "extract_financial_metrics",
            "tool_input": tool_input,
        }

    # -- Tool Executor --------------------------------------------------------
    def tool_executor_node(state: AgentState) -> dict:
        """Run the selected tool, or pass the answer through when none is set."""
        tool_name = state.get("tool_name")
        if not tool_name or tool_name not in TOOL_REGISTRY:
            print("[TOOL EXECUTOR] No tool β passing answer through")
            return {"tool_output": {"answer": state["final_answer"]}}
        print(f"[TOOL EXECUTOR] Running: {tool_name}")
        tool_fn = TOOL_REGISTRY[tool_name]
        # Tools call .get() on their input, so never hand them None.
        output = tool_fn(state.get("tool_input") or {}, llm)
        print(f"[TOOL EXECUTOR] Done β keys: {list(output.keys())}")
        return {"tool_output": output}

    # -- Self-RAG Evaluator ---------------------------------------------------
    def self_rag_evaluator_node(state: AgentState) -> dict:
        """Score answer faithfulness and decide whether to regenerate."""
        context = context_from_dicts(state["retrieved_docs"])
        prompt = f"""Rate how faithfully this answer is grounded in the provided context.
A score of 1.0 means every claim is directly supported by the context.
A score of 0.0 means the answer contains hallucinated facts not in the context.
Return ONLY valid JSON: {{"score": <0.0-1.0>, "reasoning": "..."}}
Context (truncated):
{context[:2000]}
Answer:
{state['final_answer'][:1000]}"""
        try:
            data = parse_json_from_response(llm.invoke(prompt).content)
            score = float(data.get("score", 0.8))
            score = max(0.0, min(1.0, score))  # clamp to [0, 1]
        except Exception:
            # 0.8 is above the rerun threshold: a failed or unparseable
            # evaluation accepts the answer instead of triggering a retry.
            score = 0.8

        current_iter = state["iteration_count"]
        should_rerun = score < faithfulness_threshold and current_iter < max_retries
        # The counter only advances when a retry is actually scheduled.
        new_count = current_iter + 1 if should_rerun else current_iter
        print(f"[SELF-RAG] faithfulness={score:.2f} | "
              f"rerun={should_rerun} | iter={new_count}")
        return {
            "faithfulness_score": score,
            "should_rerun": should_rerun,
            "iteration_count": new_count,
        }

    return {
        "orchestrator": orchestrator_node,
        "rag_agent": rag_agent_node,
        "compliance_agent": compliance_agent_node,
        "alert_agent": alert_agent_node,
        "finance_calculator": finance_calculator_node,
        "tool_executor": tool_executor_node,
        "self_rag_evaluator": self_rag_evaluator_node,
    }
# ── Routing functions ────────────────────────────────────────────────────────
def route_by_intent(state: AgentState) -> str:
    """Return the name of the specialist node for ``state["intent"]``.

    ``rag_query``, ``analyst_review`` and any unrecognised intent all go to
    ``rag_agent``; the other three intents have a dedicated node.
    """
    dedicated_nodes = {
        "compliance_check": "compliance_agent",
        "investment_alert": "alert_agent",
        "financial_metrics": "finance_calculator",
    }
    intent = state["intent"]
    target = dedicated_nodes.get(intent, "rag_agent")
    print(f"[ROUTER] intent={intent} β {target}")
    return target
def route_self_rag(state: AgentState) -> str:
    """Return ``"rag_agent"`` when a retry was requested, else ``"end"``."""
    if state["should_rerun"]:
        return "rag_agent"
    return "end"
# ── Graph builder ────────────────────────────────────────────────────────────
def build_graph(components: dict):
    """Wire the agent nodes into a ``StateGraph`` and return the compiled graph.

    Flow: ``orchestrator`` -> one specialist (chosen by ``route_by_intent``)
    -> ``tool_executor`` -> ``self_rag_evaluator`` -> either back to
    ``rag_agent`` or ``END`` (chosen by ``route_self_rag``).

    Args:
        components: Runtime components, passed to ``build_agent_nodes``.
    """
    node_fns = build_agent_nodes(components)
    graph = StateGraph(AgentState)
    for node_name, node_fn in node_fns.items():
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("orchestrator")

    specialists = (
        "rag_agent",
        "compliance_agent",
        "alert_agent",
        "finance_calculator",
    )
    # The router returns node names, so each label maps to itself.
    graph.add_conditional_edges(
        "orchestrator",
        route_by_intent,
        {name: name for name in specialists},
    )
    for name in specialists:
        graph.add_edge(name, "tool_executor")

    graph.add_edge("tool_executor", "self_rag_evaluator")
    graph.add_conditional_edges(
        "self_rag_evaluator",
        route_self_rag,
        {"rag_agent": "rag_agent", "end": END},
    )
    return graph.compile()
# ── Component initialiser ────────────────────────────────────────────────────
def build_components() -> dict:
    """Initialise every runtime component the graph nodes need.

    Builds the LLM and embeddings, loads the PDF pages, obtains the vector
    store, creates the BM25 and dense retrievers and the cross-encoder, and
    reads the table store from ``TABLE_STORE_PATH``.

    Returns:
        Dict with keys ``"llm"``, ``"embeddings"``, ``"bm25_retriever"``,
        ``"dense_retriever"``, ``"cross_encoder"`` and ``"tables"``.

    Raises:
        OSError: If the table store file cannot be opened.
        json.JSONDecodeError: If the table store is not valid JSON.
    """
    print("=== Initialising Phase 3 components ===")
    llm = build_llm()
    embeddings = build_embeddings()
    pages = load_pdf(PDF_PATH)
    # spaCy NER (entities metadata) is not used at query time: BM25 needs only
    # page_content and the dense store is already persisted. Skipping NER here
    # cuts a large chunk off cold-start time. (Phase 1 ingestion still tags.)
    vs = load_or_build_vectorstore(embeddings, pages)
    bm25 = build_bm25_retriever(pages)
    dense = build_dense_retriever(vs)
    cross_enc = build_cross_encoder()
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(TABLE_STORE_PATH, encoding="utf-8") as f:
        tables = json.load(f)
    print(f"[INIT] Loaded {len(tables)} tables from table store")
    return {
        "llm": llm,
        "embeddings": embeddings,
        "bm25_retriever": bm25,
        "dense_retriever": dense,
        "cross_encoder": cross_enc,
        "tables": tables,
    }
# ── Pretty-print result ──────────────────────────────────────────────────────
def print_result(result: AgentState) -> None:
    """Print a summary of a finished graph run to stdout.

    Shows the query, intent, faithfulness score, cited pages and the first
    600 characters of the answer. When ``tool_output`` is truthy, the first
    800 characters of its indented JSON dump are printed as well.
    """
    rule = "=" * 65
    print("\n" + rule)
    print(f"QUERY : {result['query']}")
    print(f"INTENT : {result['intent']}")
    print(f"FAITHFULNESS : {result['faithfulness_score']:.2f}")
    print(f"CITATIONS : pages {result['citations']}")
    print(f"\nANSWER:\n{result['final_answer'][:600]}")

    tool_output = result.get("tool_output")
    if tool_output:
        print(f"\nTOOL OUTPUT ({result['tool_name']}):")
        rendered = json.dumps(tool_output, indent=2)
        print(rendered[:800])
    print(rule)
# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke run: build the components and the graph, then push four sample
    # queries through it and print each result.
    components = build_components()
    app = build_graph(components)
    print("\n=== LangGraph compiled successfully ===")

    test_queries = [
        "What was Infosys revenue and operating margin in FY25?",
        "Are there any SEBI compliance or regulatory issues mentioned in the annual report?",
        "Generate an investment alert for Infosys based on their FY25 performance.",
        "Schedule an urgent credit review for Infosys.",
    ]
    banner = "#" * 65
    for query in test_queries:
        print(f"\n{banner}")
        result = app.invoke(make_initial_state(query))
        print_result(result)