# TemHealth / agent_graph.py
# (Hugging Face upload-page residue converted to comments so this file parses:
#  "vbzvibin's picture — Upload 32 files — 1b8d0f1 verified")
import os
import pandas as pd
from typing import TypedDict, List, Annotated
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.documents import Document
from dotenv import load_dotenv
import re
# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file.
load_dotenv()
# Global Data Access
DATA_PATH = 'Data'  # directory holding cms_rules_2025.csv, claims.csv and the FAISS index
_vector_store = None  # lazily-built FAISS index cache; populated by get_vector_store()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")  # None if unset; OpenAI calls will then fail
def get_cms_context():
    """Load the CMS rules and claims datasets from the data directory.

    Returns:
        Tuple of (rules_df, claims_df) pandas DataFrames read from
        Data/cms_rules_2025.csv and Data/claims.csv respectively.
    """
    rules_path = os.path.join(DATA_PATH, 'cms_rules_2025.csv')
    claims_path = os.path.join(DATA_PATH, 'claims.csv')
    return pd.read_csv(rules_path), pd.read_csv(claims_path)
def get_vector_store():
    """Return the module-level FAISS vector store, building it on first use.

    Loads a persisted index from Data/faiss_index when available; otherwise
    embeds every CMS rule row into a Document, builds a fresh index, and
    saves it to disk for subsequent runs.
    """
    global _vector_store
    if _vector_store is not None:
        return _vector_store

    index_path = os.path.join(DATA_PATH, 'faiss_index')
    embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)

    if os.path.exists(index_path):
        _vector_store = FAISS.load_local(
            index_path, embeddings, allow_dangerous_deserialization=True
        )
    else:
        rules, _ = get_cms_context()
        # One Document per rule row; the text carries all searchable fields.
        documents = [
            Document(
                page_content=(
                    f"Rule ID: {row['Rule_ID']}, Type: {row['Type']}, Target: {row['Target']}, "
                    f"Change: {row['Change']}, Impact Score: {row['Impact_Score']}. "
                    f"Description: {row['Description']}"
                ),
                metadata={"rule_id": row['Rule_ID'], "target": row['Target']},
            )
            for _, row in rules.iterrows()
        ]
        _vector_store = FAISS.from_documents(documents, embeddings)
        # Persist so later runs skip the embedding cost.
        _vector_store.save_local(index_path)
        print(f"FAISS index created and saved to {index_path}")

    return _vector_store
class AgentState(TypedDict):
    """Shared state dict passed between all agent nodes in the LangGraph pipeline."""
    query: str                    # original user question
    regulatory_insight: str       # written by regulatory_specialist
    impact_analysis: str          # written by finance_analyst
    workflow_action: str          # written by custom_cdi_agent
    cdm_patch: str                # New field for CDM Automation; written by cdm_specialist
    final_summary: str            # written by summarizer_agent
    messages: List[BaseMessage]   # conversation history (no node in this file reads or writes it)
    context_rules: str            # raw rule text retrieved from the vector store
    context_claims_summary: str   # aggregated claims table produced by finance_analyst
# LLM Configuration
# Deterministic (temperature=0) GPT-4o client shared by every agent node below.
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=OPENAI_API_KEY)
def regulatory_specialist(state: AgentState):
    """Agent that uses RAG (Vector Search) to find relevant CMS rules.

    Retrieves the top-3 rule documents most similar to the user query and
    asks the LLM for a short insight grounded in that retrieved context.

    Returns:
        Partial state update with 'regulatory_insight' and 'context_rules'.
    """
    vector_store = get_vector_store()
    # Retrieve top 3 relevant rules
    docs = vector_store.similarity_search(state['query'], k=3)
    context = "\n\n".join(d.page_content for d in docs)
    # BUG FIX: the retrieved context and the user query were never passed to
    # the LLM, so the "insight" was generated blind. Ground the prompt.
    prompt = f"""
You are a CMS regulatory specialist.

USER QUERY: {state['query']}

RELEVANT CMS RULES:
{context}

Provide a 2-bullet point regulatory insight.
- Focus ONLY on the most critical change.
- Use clear, non-technical language.
"""
    response = llm.invoke(prompt)
    return {"regulatory_insight": response.content, "context_rules": context}
def finance_analyst(state: AgentState):
    """Agent that quantifies financial impact using claims data.

    Detects which service lines the regulatory insight mentions, aggregates
    reimbursement for those lines (or the top 5 overall as a fallback), and
    asks the LLM to summarize the dollar risk.

    Returns:
        Partial state update with 'impact_analysis' and 'context_claims_summary'.
    """
    _, claims = get_cms_context()
    insight = state['regulatory_insight'].lower()
    targets = ['Cardiology', 'Pulmonology', 'Orthopedics', 'Neurology', 'Surgery', 'Medicine', 'Oncology', 'Endocrinology', 'Gastroenterology']
    active_targets = [t for t in targets if t.lower() in insight]
    if active_targets:
        relevant_claims = claims[claims['Service_Line'].isin(active_targets)]
        claims_summary = relevant_claims.groupby('Service_Line')['Reimbursement'].agg(['sum', 'count']).to_string()
    else:
        # Fallback when no service line is named: show the first 5 lines overall.
        claims_summary = claims.groupby('Service_Line')['Reimbursement'].agg(['sum', 'count']).head(5).to_string()
    # BUG FIX: the claims aggregation and the regulatory insight were computed
    # but never placed in the prompt — the LLM had no data to summarize.
    prompt = f"""
You are a healthcare finance analyst.

REGULATORY INSIGHT:
{state['regulatory_insight']}

CLAIMS DATA (total reimbursement and claim count by service line):
{claims_summary}

Summarize the financial risk in 2 punchy bullet points. Focus on dollar values and percent shifts.
"""
    response = llm.invoke(prompt)
    return {"impact_analysis": response.content, "context_claims_summary": claims_summary}
def custom_cdi_agent(state: AgentState):
    """Agent that generates workflow and CDI (documentation improvement) actions."""
    # Pull the upstream agents' outputs into the prompt.
    insight = state['regulatory_insight']
    impact = state['impact_analysis']
    prompt = f"""
You are a CDI Lead.
INSIGHT: {insight}
IMPACT: {impact}
Give 2 bullet points for documentation improvement.
"""
    reply = llm.invoke(prompt)
    return {"workflow_action": reply.content}
def cdm_specialist(state: "AgentState"):
    """Specialist to identify CDM (charge master) conflicts and propose patches.

    Pure keyword matching on the user query — no LLM call. Flags a known
    HCPCS C1713 packaging conflict for ortho/bundle/implant queries.

    Returns:
        Partial state update with 'cdm_patch' set to an alert or an all-clear message.
    """
    query = state['query'].lower()
    # Removed a dead initial assignment ("CDM STATUS: Scanning Vector Store...")
    # that was unconditionally overwritten by both branches below.
    if "ortho" in query or "bundle" in query or "implant" in query:
        insight = """
🚨 **CDM ALERT: HCPCS C1713 Conflict Detected**
- **Regulatory Change**: CMS OPPS 2025 Packaged Status.
- **Current Temple CDM**: Set to 'Pass-Through' ($7,000).
- **Risk**: Without a 'Packaged' flag (APC 5114), this claim results in $0 reimbursement (100% denial).
- **Patch**: Auto-update Status to 'Packaged' ($5,500 secondary reimbursement) to recover $5,500 per case.
"""
    else:
        insight = "CDM Status: No immediate billing conflicts detected for this query."
    return {"cdm_patch": insight}
def summarizer_agent(state: AgentState):
    """Final node that creates a concise, GPT-style summarized answer."""
    # Assemble all upstream outputs for the board-level executive summary.
    prompt = f"""
You are a Healthcare AI Orchestrator providing an EXECUTIVE SUMMARY for a busy Hospital Board.
Provide a SHORT, BULLETED summary.
REGULATORY: {state['regulatory_insight']}
FINANCE: {state['impact_analysis']}
WORKFLOW: {state['workflow_action']}
CDM AUTO-SYNC: {state['cdm_patch']}
GUIDELINES:
- Total length should be VERY short.
- Use Bold headers for each point.
- If CDM Auto-Sync detected a conflict, make it a PRIORITY bullet.
- No introductory text like "Based on my findings...". Direct summary only.
"""
    reply = llm.invoke(prompt)
    return {"final_summary": reply.content}
def build_robust_graph():
    """Build and compile the linear agent pipeline with in-memory checkpointing.

    Pipeline order: regulatory -> finance -> cdi -> cdm -> summarizer -> END.
    """
    workflow = StateGraph(AgentState)

    pipeline = [
        ("regulatory", regulatory_specialist),
        ("finance", finance_analyst),
        ("cdi", custom_cdi_agent),
        ("cdm", cdm_specialist),
        ("summarizer", summarizer_agent),
    ]
    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("regulatory")
    # Chain each node to the next in declaration order.
    names = [node_name for node_name, _ in pipeline]
    for src, dst in zip(names, names[1:]):
        workflow.add_edge(src, dst)
    workflow.add_edge("summarizer", END)

    return workflow.compile(checkpointer=MemorySaver())
# Function to save graph visualization
def save_graph_image(graph, filename="agent_graph.png"):
    """Render the compiled graph to a PNG via Mermaid.

    Returns the filename on success, or None if rendering fails
    (e.g. missing optional rendering dependencies).
    """
    try:
        graph.get_graph().draw_mermaid_png(output_file_path=filename)
    except Exception as exc:
        # Rendering is best-effort; report and continue without an image.
        print(f"Graph visualization error: {exc}")
        return None
    return filename
if __name__ == '__main__':
    # Smoke test: run the full pipeline once and print the regulatory output.
    graph = build_robust_graph()
    # Seed every AgentState field so no node can hit a missing key;
    # the original seed omitted 'cdm_patch' and 'final_summary'.
    initial_state = {
        "query": "What are the cardiology weight shift impacts for 2025?",
        "messages": [],
        "regulatory_insight": "",
        "impact_analysis": "",
        "workflow_action": "",
        "cdm_patch": "",
        "final_summary": "",
        "context_rules": "",
        "context_claims_summary": ""
    }
    config = {"configurable": {"thread_id": "test_thread"}}
    result = graph.invoke(initial_state, config)
    print(result['regulatory_insight'])