RAG / phase4_tools.py
sumitnewold's picture
Upload 10 files
76bd1fc verified
Raw
History Blame Contribute Delete
15.2 kB
"""
Phase 4 — Enhanced Finance Tools with Real FinBERT Sentiment
Key addition over Phase 3:
generate_investment_alert now uses ProsusAI/finbert to score the retrieved
financial text before deciding the investment signal. The placeholder
sentiment from Phase 3 is replaced with actual model output.
Usage:
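from phase4_tools import build_finbert, run_finbert
from phase4_tools import build_phase4_components, build_phase4_graph
# build the Phase 3 components plus FinBERT, then compile the patched graph
# (see the __main__ entry point at the bottom of this file for a full run)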
"""
import json
import datetime
import os
import torch
from dotenv import load_dotenv
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from phase1_ingestion import (
PDF_PATH, CHROMA_DIR, TABLE_STORE_PATH,
build_embeddings, build_llm, load_pdf, tag_financial_entities,
)
from phase2_retrieval import (
advanced_retrieval_pipeline,
build_bm25_retriever,
build_cross_encoder,
build_dense_retriever,
load_or_build_vectorstore,
)
from phase3_agent import (
build_components,
build_graph,
context_from_dicts,
docs_to_dicts,
make_initial_state,
parse_json_from_response,
print_result,
# Phase 3 tools we keep unchanged
extract_financial_metrics,
generate_risk_summary,
flag_compliance_issue,
schedule_analyst_review,
)
# Load environment variables via python-dotenv at import time.
load_dotenv()

# Hugging Face model id passed to from_pretrained() for both tokenizer and model.
FINBERT_MODEL = "ProsusAI/finbert"
# Sentiment label names, listed in the logit-index order assumed for FINBERT_MODEL.
# NOTE(review): confirm this order against the checkpoint's id2label mapping.
FINBERT_LABELS = ["positive", "negative", "neutral"]
# Upper bound on tokens per FinBERT input; used as the tokenizer's max_length.
FINBERT_MAX_TOKENS = 512
# ── FinBERT loader ────────────────────────────────────────────────────────────
def build_finbert() -> tuple:
    """Load the FinBERT checkpoint named by FINBERT_MODEL.

    Returns:
        A ``(tokenizer, model)`` pair. The model has been switched to
        evaluation mode before it is returned.
    """
    print("[FINBERT] Loading {}...".format(FINBERT_MODEL))
    finbert_tokenizer = AutoTokenizer.from_pretrained(FINBERT_MODEL)
    finbert_model = AutoModelForSequenceClassification.from_pretrained(FINBERT_MODEL)
    finbert_model.eval()  # inference only
    print("[FINBERT] Model ready")
    return finbert_tokenizer, finbert_model
# ── FinBERT inference ─────────────────────────────────────────────────────────
def run_finbert(text: str, tokenizer, model) -> dict:
    """Score a single text with FinBERT.

    The text is tokenized with truncation to FINBERT_MAX_TOKENS tokens, run
    through the model without gradient tracking, and the logits are turned
    into probabilities with a softmax.

    Args:
        text: The text to score.
        tokenizer: Tokenizer matching ``model`` (see build_finbert()).
        model: Sequence-classification model exposing ``.logits`` on its output.

    Returns:
        ``{"label": "positive|negative|neutral", "score": float,
        "all_scores": {"positive": float, "negative": float, "neutral": float}}``
        where ``score`` is the probability of the dominant label and every
        probability is rounded to 4 decimals.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=FINBERT_MAX_TOKENS,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].tolist()

    # Prefer the checkpoint's own index -> label mapping so that a model with a
    # different class order is not silently mislabelled. Fall back to the
    # module-level order when the config is absent or names other labels.
    id2label = getattr(getattr(model, "config", None), "id2label", None) or {}
    labels = [str(id2label.get(i, "")).lower() for i in range(len(probs))]
    if sorted(labels) != sorted(FINBERT_LABELS):
        labels = FINBERT_LABELS

    all_scores = {label: round(prob, 4) for label, prob in zip(labels, probs)}
    dominant = max(all_scores, key=all_scores.get)
    return {"label": dominant, "score": all_scores[dominant], "all_scores": all_scores}
def _neutral_finbert_result() -> dict:
    """Return a fresh neutral placeholder for when there is no text to score."""
    return {
        "label": "neutral",
        "score": 0.5,
        "all_scores": {"positive": 0.33, "negative": 0.33, "neutral": 0.34},
    }


def run_finbert_on_docs(doc_dicts: list[dict], tokenizer, model) -> dict:
    """Run FinBERT over every retrieved chunk and aggregate the scores.

    Strategy: average the per-label scores across all non-empty chunks, then
    pick the label with the highest average. This handles the common case
    where a single page has mixed sentiment.

    Args:
        doc_dicts: Chunk dicts; each chunk's text is read from its
            ``"page_content"`` key. Chunks whose content is missing, ``None``
            or blank are skipped instead of raising.
        tokenizer: FinBERT tokenizer, forwarded to run_finbert().
        model: FinBERT model, forwarded to run_finbert().

    Returns:
        A dict with the same shape as run_finbert()'s result. When no chunk
        carries any text, a neutral placeholder is returned.
    """
    totals = {"positive": 0.0, "negative": 0.0, "neutral": 0.0}
    scored = 0
    for doc in doc_dicts or []:
        # A missing key or a None value counts as "no text", not as an error.
        text = (doc.get("page_content") or "").strip()
        if not text:
            continue
        chunk_scores = run_finbert(text, tokenizer, model)["all_scores"]
        for label, score in chunk_scores.items():
            totals[label] += score
        scored += 1

    if scored == 0:
        return _neutral_finbert_result()

    avg = {label: round(total / scored, 4) for label, total in totals.items()}
    dominant = max(avg, key=avg.get)
    return {"label": dominant, "score": avg[dominant], "all_scores": avg}
# ── Enhanced generate_investment_alert (real FinBERT) ─────────────────────────
def make_enhanced_investment_alert(finbert_tokenizer, finbert_model):
    """Build a generate_investment_alert tool backed by real FinBERT scores.

    The returned closure first scores the retrieved text with FinBERT and then
    asks the LLM for a structured alert, passing the FinBERT result in the
    prompt. build_phase4_components() stores the closure under
    ``components["enhanced_alert_fn"]``.

    Args:
        finbert_tokenizer: Tokenizer returned by build_finbert().
        finbert_model: Model returned by build_finbert().

    Returns:
        A function ``generate_investment_alert(tool_input, llm) -> dict``.
    """

    def generate_investment_alert(tool_input: dict, llm) -> dict:
        """Produce an investment alert dict for ``tool_input["company"]``.

        ``tool_input`` may carry ``"company"``, ``"context"`` and
        ``"doc_dicts"``. On any LLM or parsing failure a "hold" fallback is
        returned with the FinBERT scores and the error text.
        """
        company = tool_input.get("company", "Infosys")
        context = tool_input.get("context") or ""
        doc_dicts = tool_input.get("doc_dicts", [])  # injected by alert_agent

        # ── Step 1: Run real FinBERT ──────────────────────────────────────
        # Prefer per-chunk scoring. Without chunks, score the raw context as a
        # single chunk: the tokenizer truncates to FINBERT_MAX_TOKENS tokens
        # itself, so no character slicing is needed (slicing characters by a
        # token budget throws away most of the usable text). An empty context
        # yields the neutral placeholder instead of scoring an empty string.
        chunks = doc_dicts or [{"page_content": context}]
        finbert_result = run_finbert_on_docs(chunks, finbert_tokenizer, finbert_model)
        print(f"[FINBERT] sentiment={finbert_result['label']} "
              f"({finbert_result['score']:.2f}) | "
              f"all={finbert_result['all_scores']}")

        # ── Step 2: Let LLM decide signal using FinBERT as a hard input ───
        sentiment_hint = (
            f"FinBERT (domain-specific sentiment model) scored this text as: "
            f"{finbert_result['label'].upper()} "
            f"(confidence {finbert_result['score']:.2f}). "
            f"All scores — positive: {finbert_result['all_scores']['positive']}, "
            f"negative: {finbert_result['all_scores']['negative']}, "
            f"neutral: {finbert_result['all_scores']['neutral']}."
        )
        prompt = f"""You are a senior investment analyst.
{sentiment_hint}
Using BOTH the FinBERT sentiment above AND the financial context below,
generate a structured investment alert for {company}.
Return ONLY valid JSON:
{{
"company": "{company}",
"signal": "buy|sell|hold|watch",
"trigger_reason": "...",
"confidence_score": <0.0-1.0>,
"supporting_evidence": ["exact quote 1", "exact quote 2"],
"finbert_sentiment": {{
"label": "{finbert_result['label']}",
"score": {finbert_result['score']},
"all_scores": {{
"positive": {finbert_result['all_scores']['positive']},
"negative": {finbert_result['all_scores']['negative']},
"neutral": {finbert_result['all_scores']['neutral']}
}}
}}
}}
Financial context:
{context[:3000]}"""
        try:
            alert = parse_json_from_response(llm.invoke(prompt).content)
        except Exception as e:
            # Deliberate best-effort fallback: return what we have with the
            # FinBERT scores populated.
            return {
                "company": company,
                "signal": "hold",
                "trigger_reason": "LLM parse error — FinBERT scores still valid",
                "confidence_score": finbert_result["score"],
                "supporting_evidence": [],
                "finbert_sentiment": finbert_result,
                "error": str(e),
            }

        # The LLM only echoes the scores it was shown and may alter them;
        # always report the numbers FinBERT actually produced.
        if isinstance(alert, dict):
            alert["finbert_sentiment"] = finbert_result
        return alert

    return generate_investment_alert
# ── Build patched components dict ─────────────────────────────────────────────
def build_phase4_components() -> dict:
    """Assemble the Phase 3 components and attach the FinBERT pieces.

    Returns:
        The dict produced by build_components(), extended with three keys:
        ``"finbert_tokenizer"``, ``"finbert_model"`` and
        ``"enhanced_alert_fn"`` (the FinBERT-backed generate_investment_alert
        built by make_enhanced_investment_alert()).
    """
    phase4 = build_components()  # Phase 3 setup
    tok, mdl = build_finbert()
    # Stored in the components dict so the graph builder can pick them up.
    phase4.update(
        finbert_tokenizer=tok,
        finbert_model=mdl,
        enhanced_alert_fn=make_enhanced_investment_alert(tok, mdl),
    )
    return phase4
# ── Patch alert agent to pass doc_dicts to the tool ──────────────────────────
def build_phase4_graph(components: dict):
    """Compile the Phase 3 graph with a FinBERT-aware alert path.

    Starts from the nodes returned by build_agent_nodes(components) and
    replaces two of them:

    * ``alert_agent`` — passes the retrieved ``doc_dicts`` to the tool so
      FinBERT can be run per chunk.
    * ``tool_executor`` — dispatches ``generate_investment_alert`` to
      ``components["enhanced_alert_fn"]`` and every other tool to the
      Phase 3 TOOL_REGISTRY.

    Args:
        components: Dict with at least ``"llm"``, ``"bm25_retriever"``,
            ``"dense_retriever"``, ``"cross_encoder"`` and
            ``"enhanced_alert_fn"`` (see build_phase4_components()).

    Returns:
        The compiled LangGraph workflow.
    """
    from phase3_agent import (
        build_agent_nodes, route_by_intent, route_self_rag,
        TOOL_REGISTRY, AgentState,
    )
    from langgraph.graph import StateGraph, END

    # Build standard Phase 3 nodes
    nodes = build_agent_nodes(components)
    llm = components["llm"]
    bm25_retriever = components["bm25_retriever"]
    dense_retriever = components["dense_retriever"]
    cross_encoder = components["cross_encoder"]
    enhanced_alert = components["enhanced_alert_fn"]

    # Enhanced alert function, standard registry for everything else.
    # Built once here rather than on every tool call.
    registry = {**TOOL_REGISTRY, "generate_investment_alert": enhanced_alert}

    # ── Override alert_agent_node ─────────────────────────────────────────
    def alert_agent_node_v4(state: AgentState) -> dict:
        """Retrieve context, draft an answer and queue the alert tool."""
        print(f"[ALERT AGENT v4] Analysing: {state['query'][:60]}")
        # Bias retrieval towards performance/risk passages.
        retrieval_query = (
            f"revenue profit growth performance outlook risk sentiment {state['query']}"
        )
        docs = advanced_retrieval_pipeline(
            retrieval_query, bm25_retriever, dense_retriever, cross_encoder, llm
        )
        doc_dicts = docs_to_dicts(docs)
        citations = list({d["metadata"].get("page") for d in doc_dicts})
        context = context_from_dicts(doc_dicts)
        answer_prompt = (
            "You are an investment analyst. Assess the company's financial health "
            "and generate a clear investment signal with supporting evidence. "
            "Real FinBERT sentiment has already been computed and will be included.\n\n"
            f"Retrieved context:\n{context[:4000]}\n\n"
            f"Question: {state['query']}\n\nAnswer with specific facts and page citations:"
        )
        answer = llm.invoke(answer_prompt).content
        # Pass doc_dicts so enhanced_alert can run FinBERT per chunk
        tool_input = {
            "company": "Infosys",
            "context": context,
            "doc_dicts": doc_dicts,
        }
        print(f"[ALERT AGENT v4] {len(doc_dicts)} docs | citations={citations}")
        return {
            "retrieved_docs": doc_dicts,
            "final_answer": answer,
            "citations": citations,
            "tool_name": "generate_investment_alert",
            "tool_input": tool_input,
        }

    nodes["alert_agent"] = alert_agent_node_v4

    # ── Override tool_executor to use enhanced tool for investment_alert ──
    def tool_executor_node_v4(state: AgentState) -> dict:
        """Run the tool named in the state, or pass the answer through."""
        tool_name = state.get("tool_name")
        if not tool_name or tool_name not in registry:
            print("[TOOL EXECUTOR v4] No tool — passing answer through")
            return {"tool_output": {"answer": state["final_answer"]}}
        print(f"[TOOL EXECUTOR v4] Running: {tool_name}")
        output = registry[tool_name](state["tool_input"], llm)
        print(f"[TOOL EXECUTOR v4] Done — keys: {list(output.keys())}")
        return {"tool_output": output}

    nodes["tool_executor"] = tool_executor_node_v4

    # ── Compile graph ─────────────────────────────────────────────────────
    workflow = StateGraph(AgentState)
    for name, fn in nodes.items():
        workflow.add_node(name, fn)
    workflow.set_entry_point("orchestrator")
    workflow.add_conditional_edges(
        "orchestrator", route_by_intent,
        {
            "rag_agent": "rag_agent",
            "compliance_agent": "compliance_agent",
            "alert_agent": "alert_agent",
            "finance_calculator": "finance_calculator",
        },
    )
    # Every specialist hands over to the tool executor, whose output is then
    # evaluated; the evaluator either loops back to rag_agent or ends the run.
    for specialist in ["rag_agent", "compliance_agent",
                       "alert_agent", "finance_calculator"]:
        workflow.add_edge(specialist, "tool_executor")
    workflow.add_edge("tool_executor", "self_rag_evaluator")
    workflow.add_conditional_edges(
        "self_rag_evaluator", route_self_rag,
        {"rag_agent": "rag_agent", "end": END},
    )
    return workflow.compile()
# ── Standalone FinBERT demo ───────────────────────────────────────────────────
def run_finbert_demo(finbert_tokenizer, finbert_model) -> None:
    """Print FinBERT scores for a few representative Infosys sentences.

    Args:
        finbert_tokenizer: Tokenizer returned by build_finbert().
        finbert_model: Model returned by build_finbert().
    """
    print("\n" + "=" * 65)
    print("FINBERT STANDALONE DEMO")
    print("=" * 65)
    sentences = [
        "Revenues grew 6.1% year-on-year to ₹1,62,990 crore in FY25.",
        "Revenue fell short of analyst estimates amid macro uncertainty.",
        "Operating margin improved to 21.1%, up from 20.7% in FY24.",
        "The company faces risks from geopolitical tensions and talent attrition.",
        "Infosys won large AI-led transformation deals across key verticals.",
        "Net headcount declined as the company focused on utilisation.",
    ]
    for sentence in sentences:
        result = run_finbert(sentence, finbert_tokenizer, finbert_model)
        scores = result["all_scores"]
        # Dominant label and its probability, then the full distribution.
        print(f" [{result['label']:8s} {result['score']:.2f}] {sentence[:55]}")
        print(f" pos={scores['positive']:.3f} "
              f"neg={scores['negative']:.3f} "
              f"neu={scores['neutral']:.3f}")
# ── Main ──────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run the Phase 4 demo end to end.

    Builds the Phase 3 components plus FinBERT, prints the standalone FinBERT
    demo, compiles the Phase 4 graph and runs two sample queries through it.
    """
    # Build all components (Phase 3 + FinBERT)
    components = build_phase4_components()

    # Show FinBERT standalone demo
    run_finbert_demo(components["finbert_tokenizer"], components["finbert_model"])

    # Build Phase 4 graph (Phase 3 graph + patched alert agent + enhanced tool)
    app = build_phase4_graph(components)
    print("\n=== Phase 4 LangGraph compiled successfully ===\n")

    # Test: investment alert queries (the ones that exercise FinBERT)
    test_queries = [
        "Generate an investment alert for Infosys based on their FY25 performance.",
        "Should I be concerned about Infosys given the risk factors in their latest filing?",
    ]
    for query in test_queries:
        print(f"\n{'#' * 65}")
        result = app.invoke(make_initial_state(query))
        print_result(result)


if __name__ == "__main__":
    main()