# Risk-Adjustment-Version1 / HCCDiagnosisListEngine.py
# Uploaded by: sujataprakashdatycs
# Last commit: 6829637 (verified) - "Update HCCDiagnosisListEngine.py"
import os
import json
import pandas as pd
from PyPDF2 import PdfReader
from json_repair import repair_json
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from crewai import Agent, Task, Crew, Process
from crewai_tools import SerperDevTool
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
# ---------- CSV Resolver ----------
class CSVResolver:
    """Resolve an HCC code to the diagnoses mapped to it in a mapping CSV.

    The whole CSV is read as text. A lookup compares the requested HCC code
    against one model-version column (e.g. ``V24``) and, for each matching
    row, reads the ``Description`` and ``Diagnosis Code`` columns.
    """

    # Columns read from every matching row to build the diagnosis strings.
    REQUIRED_COLUMNS = ("Description", "Diagnosis Code")

    def __init__(self, csv_path: str):
        """Load the mapping CSV and normalise its header names.

        Args:
            csv_path: Path to the CSV (anything ``pandas.read_csv`` accepts).
        """
        # dtype=str keeps every cell as text instead of letting pandas coerce
        # code-like values to numbers.
        self.df = pd.read_csv(csv_path, dtype=str)
        # Header cells may contain line breaks (e.g. "Diagnosis\nCode").
        # Collapse every whitespace run - including "\r\n" and spaces around
        # the break - into one space so names like "Diagnosis Code" match.
        self.df = self.df.rename(columns=lambda col: " ".join(str(col).split()))

    @staticmethod
    def _response(hcc_code: str, model_version: str,
                  diagnoses: List[str], notes: List[str]) -> Dict[str, Any]:
        """Build the result dict returned by ``run`` (fixed key order)."""
        return {
            "resolved_label": hcc_code,
            "model_version_used": model_version,
            "diagnoses": diagnoses,
            "notes": notes,
        }

    def run(self, hcc_code: str, model_version: str) -> Dict[str, Any]:
        """Return the diagnoses mapped to ``hcc_code`` under ``model_version``.

        Args:
            hcc_code: HCC code to look up; surrounding whitespace is ignored
                when comparing against the CSV cells (which are also stripped).
            model_version: Model version; it is stripped and upper-cased and
                then used as the column name (e.g. "v24" -> column "V24").

        Returns:
            A dict with:
              * ``resolved_label`` - the ``hcc_code`` argument as given,
              * ``model_version_used`` - the ``model_version`` argument as given,
              * ``diagnoses`` - "Description (Diagnosis Code)" strings in CSV
                row order, skipping rows where either field is empty,
              * ``notes`` - why nothing was resolved, if applicable.
            Lookup problems (missing columns, no match) are reported through
            ``notes``; they are not raised.
        """
        version_col = model_version.strip().upper()  # e.g., "V24"
        if version_col not in self.df.columns:
            return self._response(
                hcc_code, model_version, [],
                [f"Column {version_col} not found in CSV"],
            )

        wanted = hcc_code.strip()
        # astype(str) turns empty (NaN) cells into text so .str.strip() is safe.
        matches = self.df[self.df[version_col].astype(str).str.strip() == wanted]
        if matches.empty:
            return self._response(
                hcc_code, model_version, [],
                [f"No matches for HCC {hcc_code} in {version_col}"],
            )

        # Report a malformed CSV through `notes`, like the other lookup
        # failures, instead of raising KeyError while building the list.
        missing = [col for col in self.REQUIRED_COLUMNS if col not in self.df.columns]
        if missing:
            return self._response(
                hcc_code, model_version, [],
                [f"Missing column(s) in CSV: {', '.join(missing)}"],
            )

        complete = matches.dropna(subset=list(self.REQUIRED_COLUMNS))
        diagnoses = [
            f"{description} ({code})"
            for description, code in zip(complete["Description"], complete["Diagnosis Code"])
        ]
        return self._response(hcc_code, model_version, diagnoses, [])
# ---------- HCC Diagnosis Engine ----------
# Reference URLs (CMS risk-adjustment and risk-adjustment data validation
# pages, the FY 2024 ICD-10-CM coding guidelines PDF, and an AAPC article on
# MEAT documentation) passed to SerperDevTool as `seed_sources` in
# HCCDiagnosisListEngine.__init__; the agent prompt also tells the agent to
# use "CMS/AAPC seed sources".
# NOTE(review): confirm the installed crewai_tools SerperDevTool accepts and
# actually uses `seed_sources`; if not, these URLs do not steer the search.
SEED_SOURCES = [
"https://www.cms.gov/medicare/payment/medicare-advantage-rates-statistics/risk-adjustment",
"https://www.cms.gov/data-research/monitoring-programs/medicare-risk-adjustment-data-validation-program",
"https://www.cms.gov/files/document/fy-2024-icd-10-cm-coding-guidelines-updated-02/01/2024.pdf",
"https://www.aapc.com/blog/41212-include-meat-in-your-risk-adjustment-documentation/",
]
class HCCDiagnosisListEngine:
    """Compile the diagnoses and ICD-10 codes that map to one HCC code.

    Seeds the prompt with the diagnoses found by ``CSVResolver`` in the
    mapping CSV, asks a single CrewAI agent (with a Serper search tool) to
    extend that list, and writes the resulting JSON array to a file.
    """

    def __init__(self,
                 hcc_code: str,
                 model_version: str,
                 csv_path: str,
                 model: str = "gpt-4o",
                 output_file: Optional[str] = None):
        """Set up the CSV resolver, LLM, search tool and agent.

        Args:
            hcc_code: HCC code to compile diagnoses for (whitespace-stripped).
            model_version: Model version column, e.g. "V24" (stripped and
                upper-cased).
            csv_path: Path to the HCC/ICD-10 mapping CSV read by CSVResolver.
            model: OpenAI chat model name passed to ChatOpenAI.
            output_file: Where ``run`` writes its JSON result; defaults to
                "<hcc_code>_<model_version>_diagnoses.json" (sanitised,
                lower-case) in the current directory.
        """
        self.hcc_code = hcc_code.strip()
        self.model_version = model_version.strip().upper()
        self.resolver = CSVResolver(csv_path)
        self.llm = ChatOpenAI(model=model, temperature=0)
        # Serper web search pointed at the CMS/AAPC seed URLs.
        # NOTE(review): confirm the installed SerperDevTool honours a
        # `seed_sources` argument; otherwise it may be ignored or rejected.
        self.search = SerperDevTool(seed_sources=SEED_SOURCES)

        safe_code = self._safe_name(self.hcc_code)
        safe_ver = self._safe_name(self.model_version)
        self.output_file = output_file or f"{safe_code}_{safe_ver}_diagnoses.json"

        self.agent = Agent(
            role="Diagnosis & ICD-10 Extractor",
            goal="Compile a complete list of diagnoses and ICD-10 codes mapped to the given HCC.",
            backstory=(
                "You specialize in identifying every relevant diagnosis and ICD-10 "
                "code that maps to a specific HCC. You always base results on "
                "authoritative CMS and AAPC sources, ensuring full coverage for "
                "chart reviewers."
            ),
            tools=[self.search],
            verbose=True,
            memory=False,
            llm=self.llm,
        )

    @staticmethod
    def _safe_name(text: str) -> str:
        """Lower-case ``text`` and replace every character other than letters,
        digits, '-', '_' or '.' with '_' so it is usable as a file name.

        Spaces map to '_' exactly as before; characters such as '/' or ':'
        no longer produce a path into a missing directory.
        """
        return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text.lower())

    def _build_task_description(self, resolved: Dict[str, Any]) -> str:
        """Build the agent's task prompt from the CSVResolver result dict."""
        diagnoses = resolved.get("diagnoses") or []
        if diagnoses:
            csv_block = "- " + "\n- ".join(diagnoses)
        else:
            # No empty dangling bullet: tell the agent why the CSV gave nothing.
            reason = "; ".join(resolved.get("notes") or []) or "no entries"
            csv_block = f"- (none found in CSV: {reason})"
        label = resolved.get("resolved_label", "Unknown")
        # Every instruction line ends with "\n" so bullets stay on separate lines.
        return (
            "You are compiling **a complete list of diagnoses and ICD-10 codes** "
            f"that map to HCC {self.hcc_code} ({label}) "
            f"in {self.model_version}.\n\n"
            f"Diagnoses and ICD-10 codes from CSV:\n{csv_block}\n\n"
            "Instructions:\n"
            "- Use your Serper search tool with CMS/AAPC seed sources "
            f"to identify all conditions, diagnoses, and ICD-10 codes that map to HCC {self.hcc_code}.\n"
            "- Combine both the CSV-provided diagnoses AND the search results.\n"
            "- Each entry must:\n"
            " • include a general diagnosis name,\n"
            " • list the relevant ICD-10 code(s),\n"
            " • cite one authoritative CMS or AAPC source as the `reference`.\n"
            "- Be exhaustive: include every diagnosis and ICD-10 code associated with this HCC.\n"
            "- Make sure to cover at least the general diagnosis of every entry in the diagnoses list.\n"
            "- If a diagnosis relates to specific organs, list each organ separately.\n"
            "- Keep diagnoses general enough to cover the clinical spectrum "
            "(e.g., 'Breast Cancer' broadly, not just 'Melanoma of the breast').\n"
            "- Return JSON ONLY as a list of objects with keys: diagnosis, icd10, reference.\n"
        )

    @staticmethod
    def _parse_output(output: Any) -> Any:
        """Turn the crew's output into Python data, repairing malformed JSON.

        Returns the decoded value, or None when nothing loadable is produced.
        """
        # Depending on the crewai version, kickoff() may return a CrewOutput
        # object (text in `.raw`) instead of a str; repair_json needs a str.
        text = getattr(output, "raw", output)
        if not isinstance(text, str):
            text = str(text)
        try:
            return json.loads(repair_json(text))
        except ValueError:  # includes json.JSONDecodeError
            return None

    def run(self) -> List[Dict[str, Any]]:
        """Run the agent and return the compiled diagnosis list.

        The list (objects requested with keys diagnosis, icd10, reference) is
        also written to ``self.output_file``. If the agent's output is not a
        JSON array, a warning is printed, nothing is written, and [] is
        returned.
        """
        resolved = self.resolver.run(self.hcc_code, self.model_version)
        task = Task(
            description=self._build_task_description(resolved),
            expected_output="Strict JSON array of {diagnosis, icd10, reference}",
            agent=self.agent,
            # NOTE(review): confirm the installed crewai Task accepts json_mode.
            json_mode=True,
        )
        crew = Crew(agents=[self.agent], tasks=[task], process=Process.sequential, verbose=True)
        result = self._parse_output(crew.kickoff())

        if not isinstance(result, list):
            print("[WARN] Unexpected output format from agent")
            return []

        with open(self.output_file, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
        print(f"[OUTPUT] Saved {len(result)} diagnoses to {self.output_file}")
        return result