sujataprakashdatycs commited on
Commit
10baa77
·
verified ·
1 Parent(s): 6a42e31

Update TestFindingAgent.py

Browse files
Files changed (1) hide show
  1. TestFindingAgent.py +32 -11
TestFindingAgent.py CHANGED
@@ -3,14 +3,13 @@ import json
3
  import pandas as pd
4
  from PyPDF2 import PdfReader
5
  from json_repair import repair_json
6
- from typing import List, Dict, Any
7
  from typing import List, Dict, Any, Optional
8
  from crewai import Agent, Task, Crew, Process
9
  from crewai_tools import SerperDevTool
10
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
11
- from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_community.vectorstores import Chroma
13
 
 
14
  SEED_SOURCES = [
15
  "https://www.cms.gov/medicare/payment/medicare-advantage-rates-statistics/risk-adjustment",
16
  "https://www.cms.gov/data-research/monitoring-programs/medicare-risk-adjustment-data-validation-program",
@@ -45,6 +44,30 @@ class TestFindingAgent:
45
  llm=self.llm,
46
  )
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def run(self, input_diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
49
  updated_list = []
50
 
@@ -77,13 +100,11 @@ class TestFindingAgent:
77
  verbose=True
78
  )
79
  result = crew.kickoff()
80
- result = json.loads(repair_json(result))
81
- diag["tests"] = result
82
- updated_list.append(diag)
83
 
84
- # Save results to file
85
- # with open(self.output_file, "w", encoding="utf-8") as f:
86
- # json.dump(updated_list, f, indent=2, ensure_ascii=False)
 
 
87
 
88
- # print(f"[OUTPUT] Saved {len(updated_list)} diagnoses with appended tests to {self.output_file}")
89
- return updated_list
 
3
  import pandas as pd
4
  from PyPDF2 import PdfReader
5
  from json_repair import repair_json
 
6
  from typing import List, Dict, Any, Optional
7
  from crewai import Agent, Task, Crew, Process
8
  from crewai_tools import SerperDevTool
9
+ from langchain_openai import ChatOpenAI
 
10
  from langchain_community.vectorstores import Chroma
11
 
12
+
13
  SEED_SOURCES = [
14
  "https://www.cms.gov/medicare/payment/medicare-advantage-rates-statistics/risk-adjustment",
15
  "https://www.cms.gov/data-research/monitoring-programs/medicare-risk-adjustment-data-validation-program",
 
44
  llm=self.llm,
45
  )
46
 
47
+ def _extract_json_from_llm(self, raw_response: str) -> Dict[str, Any]:
48
+ """Extracts and repairs JSON from an LLM response safely."""
49
+ import re
50
+ match = re.search(r"\{.*\}", raw_response, re.DOTALL)
51
+ if not match:
52
+ print("[ERROR] No JSON object found in LLM response")
53
+ return {}
54
+
55
+ clean_json_str = match.group(0)
56
+
57
+ # Step 1: Try direct JSON parse
58
+ try:
59
+ return json.loads(clean_json_str)
60
+ except json.JSONDecodeError as e:
61
+ print(f"[WARN] Direct JSON parsing failed: {e}")
62
+
63
+ # Step 2: Try repairing JSON
64
+ try:
65
+ repaired = repair_json(clean_json_str)
66
+ return json.loads(repaired)
67
+ except Exception as e:
68
+ print(f"[ERROR] Failed to repair and parse JSON: {e}")
69
+ return {}
70
+
71
  def run(self, input_diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
72
  updated_list = []
73
 
 
100
  verbose=True
101
  )
102
  result = crew.kickoff()
 
 
 
103
 
104
+ # Use safe extractor
105
+ result_dict = self._extract_json_from_llm(result)
106
+
107
+ diag["tests"] = result_dict
108
+ updated_list.append(diag)
109
 
110
+ return updated_list