File size: 4,523 Bytes
57a5a52
 
 
 
 
 
 
 
10baa77
57a5a52
 
10baa77
6a42e31
 
 
 
 
 
 
 
22638b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10baa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22638b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10baa77
 
 
 
 
22638b6
10baa77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import json
import pandas as pd
from PyPDF2 import PdfReader
from json_repair import repair_json
from typing import List, Dict, Any, Optional
from crewai import Agent, Task, Crew, Process
from crewai_tools import SerperDevTool
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import Chroma


# Reference pages on Medicare risk adjustment, RADV, ICD-10-CM coding guidelines
# and MEAT documentation. Used by TestFindingAgent when configuring its search tool.
SEED_SOURCES = [
    # CMS: Medicare Advantage risk adjustment
    "https://www.cms.gov/medicare/payment/medicare-advantage-rates-statistics/risk-adjustment",
    # CMS: Medicare Risk Adjustment Data Validation (RADV) program
    "https://www.cms.gov/data-research/monitoring-programs/medicare-risk-adjustment-data-validation-program",
    # CMS: FY 2024 ICD-10-CM coding guidelines (PDF)
    "https://www.cms.gov/files/document/fy-2024-icd-10-cm-coding-guidelines-updated-02/01/2024.pdf",
    # AAPC: MEAT in risk-adjustment documentation
    "https://www.aapc.com/blog/41212-include-meat-in-your-risk-adjustment-documentation/",
]


class TestFindingAgent:
    def __init__(self, hcc_code: str, model_version: str,
                 model: str = "gpt-4o", output_file: Optional[str] = None):
        self.hcc_code = hcc_code.strip()
        self.model_version = model_version.strip().upper()
        self.llm = ChatOpenAI(model=model, temperature=0)

        self.search = SerperDevTool(seed_sources=SEED_SOURCES)

        safe_code = self.hcc_code.lower().replace(" ", "_")
        safe_ver = self.model_version.lower()
        self.output_file = output_file or f"{safe_code}_{safe_ver}_tests.json"

        self.agent = Agent(
            role="HCC Test & Procedure Extractor",
            goal="For each HCC diagnosis, find labs, procedures, and vitals required to support it.",
            backstory=(
                "You specialize in mapping diagnoses to supporting labs, vitals, and procedures. "
                "You always rely on CMS/AAPC sources to find the tests required for the diagnosis for the hcc code and extract available values from the patient chart context."
            ),
            tools=[self.search],
            verbose=True,
            memory=False,
            llm=self.llm,
        )

    def _extract_json_from_llm(self, raw_response: str) -> Dict[str, Any]:
        """Extracts and repairs JSON from an LLM response safely."""
        import re
        match = re.search(r"\{.*\}", raw_response, re.DOTALL)
        if not match:
            print("[ERROR] No JSON object found in LLM response")
            return {}

        clean_json_str = match.group(0)

        # Step 1: Try direct JSON parse
        try:
            return json.loads(clean_json_str)
        except json.JSONDecodeError as e:
            print(f"[WARN] Direct JSON parsing failed: {e}")

        # Step 2: Try repairing JSON
        try:
            repaired = repair_json(clean_json_str)
            return json.loads(repaired)
        except Exception as e:
            print(f"[ERROR] Failed to repair and parse JSON: {e}")
            return {}

    def run(self, input_diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        updated_list = []

        for diag in input_diagnoses:
            task = Task(
                description=(
                    f"For HCC {self.hcc_code} ({self.model_version}), analyze this patient context:\n\n"
                    f"{diag['context']} for the diagnosis {diag['diagnosis']}\n\n"
                    "Instructions:\n"
                    "- Identify all **lab tests, procedures, and vitals** that are required to validate this diagnosis for that hcc given per CMS/AAPC.\n"
                    "- Extract actual values if present in the `context`. For example: BMI, blood pressure, HbA1c, lipids.\n"
                    "- If something is not in the context, return an empty dict for that category.\n"
                    "- Give the output as JSON given below:\n"
                    "  {\n"
                    "    'vitals': {...},\n"
                    "    'procedures': {...},\n"
                    "    'lab_test': {...}\n"
                    "  }\n"
                    "- Return the output as strict JSON only."
                ),
                expected_output="One JSON object: the updated diagnosis with `test` included.",
                agent=self.agent,
                json_mode=True,
            )

            crew = Crew(
                agents=[self.agent],
                tasks=[task],
                process=Process.sequential,
                verbose=True
            )
            result = crew.kickoff()

            # Use safe extractor
            result_dict = self._extract_json_from_llm(result)

            diag["tests"] = result_dict
            updated_list.append(diag)

        return updated_list