# Risk-Adjustment-Version1 / TestFindingAgent.py
# Committed by: sujataprakashdatycs
# Last commit: 10baa77 (verified) - "Update TestFindingAgent.py"
import os
import json
import pandas as pd
from PyPDF2 import PdfReader
from json_repair import repair_json
from typing import List, Dict, Any, Optional
from crewai import Agent, Task, Crew, Process
from crewai_tools import SerperDevTool
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import Chroma
# Reference pages handed to the search tool as seeds (see `seed_sources=` in
# TestFindingAgent.__init__).
SEED_SOURCES = [
    # CMS: Medicare Advantage risk-adjustment overview.
    "https://www.cms.gov/medicare/payment/medicare-advantage-rates-statistics/risk-adjustment",
    # CMS: risk adjustment data validation (RADV) program.
    "https://www.cms.gov/data-research/monitoring-programs/medicare-risk-adjustment-data-validation-program",
    # CMS: FY 2024 ICD-10-CM official coding guidelines (PDF).
    "https://www.cms.gov/files/document/fy-2024-icd-10-cm-coding-guidelines-updated-02/01/2024.pdf",
    # AAPC: MEAT criteria for risk-adjustment documentation.
    "https://www.aapc.com/blog/41212-include-meat-in-your-risk-adjustment-documentation/",
]
class TestFindingAgent:
    """Find the vitals, procedures, and lab tests that support HCC diagnoses.

    Wraps a single CrewAI agent (backed by ``ChatOpenAI`` and a Serper search
    tool). For each diagnosis passed to :meth:`run`, the agent lists the tests
    that CMS/AAPC sources require for ``hcc_code`` and extracts any matching
    values from that diagnosis' chart context.
    """

    def __init__(self, hcc_code: str, model_version: str,
                 model: str = "gpt-4o", output_file: Optional[str] = None):
        """Configure the LLM, search tool, and extraction agent.

        Args:
            hcc_code: HCC code to validate; surrounding whitespace is stripped.
            model_version: Risk-adjustment model version label; stripped and
                upper-cased.
            model: Chat model name passed to ``ChatOpenAI`` (used with
                ``temperature=0``).
            output_file: Path stored on ``self.output_file``. Defaults to
                ``"<hcc_code>_<model_version>_tests.json"``, lower-cased, with
                spaces in the code replaced by underscores. ``run()`` does not
                write to it.
        """
        self.hcc_code = hcc_code.strip()
        self.model_version = model_version.strip().upper()
        self.llm = ChatOpenAI(model=model, temperature=0)
        # NOTE(review): confirm the installed crewai_tools SerperDevTool accepts
        # `seed_sources`; the CMS/AAPC seeding depends on it.
        self.search = SerperDevTool(seed_sources=SEED_SOURCES)
        safe_code = self.hcc_code.lower().replace(" ", "_")
        safe_ver = self.model_version.lower()
        self.output_file = output_file or f"{safe_code}_{safe_ver}_tests.json"
        self.agent = Agent(
            role="HCC Test & Procedure Extractor",
            goal="For each HCC diagnosis, find labs, procedures, and vitals required to support it.",
            backstory=(
                "You specialize in mapping diagnoses to supporting labs, vitals, and procedures. "
                "You always rely on CMS/AAPC sources to find the tests required for the diagnosis for the hcc code and extract available values from the patient chart context."
            ),
            tools=[self.search],
            verbose=True,
            memory=False,
            llm=self.llm,
        )

    def _extract_json_from_llm(self, raw_response: Any) -> Dict[str, Any]:
        """Extract and parse the JSON object embedded in an LLM response.

        Args:
            raw_response: The LLM output. Non-string values, such as the object
                returned by ``Crew.kickoff()``, are converted with ``str()``
                before searching.

        Returns:
            The parsed JSON object. Returns ``{}`` in three cases: no ``{...}``
            span is found, the span cannot be parsed even after
            ``repair_json``, or it does not decode to a dict.
        """
        import re

        # `re.search` only accepts str/bytes; crew results may be objects.
        text = raw_response if isinstance(raw_response, str) else str(raw_response)
        # Greedy, multi-line match from the first "{" to the last "}", so nested
        # objects stay intact and surrounding prose/code fences are dropped.
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            print("[ERROR] No JSON object found in LLM response")
            return {}
        clean_json_str = match.group(0)

        # Step 1: try a direct JSON parse.
        try:
            parsed = json.loads(clean_json_str)
        except json.JSONDecodeError as direct_err:
            print(f"[WARN] Direct JSON parsing failed: {direct_err}")
            # Step 2: repair the string, then parse again.
            try:
                parsed = json.loads(repair_json(clean_json_str))
            except Exception as repair_err:
                print(f"[ERROR] Failed to repair and parse JSON: {repair_err}")
                return {}

        if not isinstance(parsed, dict):
            print("[ERROR] Parsed JSON is not an object")
            return {}
        return parsed

    def _build_task(self, diag: Dict[str, Any]) -> "Task":
        """Build the single CrewAI task that extracts tests for one diagnosis.

        Args:
            diag: Diagnosis record; reads its ``"diagnosis"`` and ``"context"``
                keys.
        """
        return Task(
            description=(
                f"For HCC {self.hcc_code} ({self.model_version}), analyze the patient context "
                f"below for the diagnosis {diag['diagnosis']}.\n\n"
                f"Patient context:\n{diag['context']}\n\n"
                "Instructions:\n"
                "- Identify all **lab tests, procedures, and vitals** that are required to validate this diagnosis for this HCC per CMS/AAPC.\n"
                "- Extract actual values if present in the patient context. For example: BMI, blood pressure, HbA1c, lipids.\n"
                "- If something is not in the context, return an empty dict for that category.\n"
                "- Give the output in this JSON shape:\n"
                "  {\n"
                '    "vitals": {...},\n'
                '    "procedures": {...},\n'
                '    "lab_test": {...}\n'
                "  }\n"
                "- Return the output as strict JSON only (double-quoted keys and strings, no extra text)."
            ),
            expected_output="One JSON object with the keys `vitals`, `procedures`, and `lab_test`.",
            agent=self.agent,
            # NOTE(review): confirm the installed crewai Task honours `json_mode`;
            # the result is parsed leniently by _extract_json_from_llm regardless.
            json_mode=True,
        )

    def run(self, input_diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Attach CMS/AAPC-supported tests to each diagnosis.

        Runs one single-task, sequential crew per diagnosis.

        Args:
            input_diagnoses: Diagnosis dicts, each containing ``"diagnosis"``
                and ``"context"`` keys.

        Returns:
            The same dict objects, in input order, each updated in place with a
            ``"tests"`` key. That key holds the parsed JSON from the agent, or
            ``{}`` if the output could not be parsed.
        """
        updated_list = []
        for diag in input_diagnoses:
            crew = Crew(
                agents=[self.agent],
                tasks=[self._build_task(diag)],
                process=Process.sequential,
                verbose=True,
            )
            result = crew.kickoff()
            # kickoff() may return a CrewOutput object; the extractor
            # stringifies it.
            diag["tests"] = self._extract_json_from_llm(result)
            updated_list.append(diag)
        return updated_list