# Source: abjasrees — "Update agents.py" (commit 6e65e5b, verified)
# engines.py
import os
import re
import json
from typing import List, Dict, Any
from PyPDF2 import PdfReader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from crewai_tools import tool
from crewai import Agent, Task, Crew, Process
# ---------- Utilities ----------
def ensure_api_key() -> None:
    """Raise RuntimeError unless OPENAI_API_KEY is set (and non-empty) in the environment."""
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set in environment.")
def safe_json_loads(s: str) -> Any:
    """
    Best-effort parse: try JSON first, then a Python literal, and fall back
    to returning the (stripped) raw string when neither succeeds.
    """
    import ast

    stripped = s.strip()
    # Attempt the strict parser first, then the permissive one.
    for parser in (json.loads, ast.literal_eval):
        try:
            return parser(stripped)
        except Exception:
            continue
    return stripped
def coerce_to_list_of_strings(obj: Any) -> List[str]:
    """
    Coerce whatever the LLM returned into a list[str].

    Handles, in order:
    - a list of strings (returned unchanged, as before);
    - any other list (generalized: elements are stringified and stripped,
      instead of collapsing the whole list into one string);
    - a dict with a 'questions' / 'items' / 'list' key holding a list;
    - a dict shaped like {"Q1": "...", "Q2": "..."} (string values taken
      in insertion order);
    - a multi-line string, with bullet/number prefixes ("1. ", "1) ",
      "- ", "* ") stripped per line;
    - anything else, stringified into a single-item list.

    Returns:
        A list of non-empty strings.
    """
    if isinstance(obj, list):
        if all(isinstance(x, str) for x in obj):
            return obj
        # Fix: a mixed list (e.g. [1, "a"]) previously fell through to the
        # last-resort branch and became [str(obj)]; coerce per element instead.
        return [str(x).strip() for x in obj if str(x).strip()]
    if isinstance(obj, dict):
        for k in ("questions", "items", "list"):
            if k in obj and isinstance(obj[k], list):
                return [str(x).strip() for x in obj[k] if str(x).strip()]
        # If dict is like { "Q1": "...", "Q2": "..." }
        if all(isinstance(v, str) for v in obj.values()):
            return [v.strip() for v in obj.values() if v.strip()]
    if isinstance(obj, str):
        lines = []
        for line in obj.splitlines():
            line = line.strip()
            # strip prefixes like "1. ", "1) ", "- ", "* "; re-strip because
            # the bullet alternative ([-*]) does not consume the following
            # space, which previously left lines like " Second question".
            line = re.sub(r'^\s*(?:[-*]|\d+[\.\)]\s*)', '', line).strip()
            if line:
                lines.append(line)
        # heuristically drop boilerplate/noise lines
        return [l for l in lines if len(l) > 2]
    # last resort
    return [str(obj)]
def normalize_yes_no_nf(s: str) -> str:
    """Map answer to exactly 'Yes' | 'No' | 'Not Found'.

    Bug fix: the original mapped any answer starting with "n" to "No" —
    including the literal "Not Found" the answering agent is instructed to
    emit. Answers starting with "not" now fall through to "Not Found".
    """
    t = s.strip().lower()
    if t.startswith("y"):
        return "Yes"
    # "no", "no evidence", ... — but NOT "not found" / "not enough evidence".
    if t.startswith("n") and not t.startswith("not"):
        return "No"
    return "Not Found"
# ---------- Embeddings ----------
class EmbeddingManager:
    """Handles PDF → text → embeddings with persistence.

    The extracted text is cached at <base_dir>/<name>.txt and the Chroma
    index is persisted under <base_dir>/<name>/, so repeated runs reuse
    both artifacts instead of re-reading the PDF or re-embedding.
    """
    def __init__(self, pdf_path: str, base_dir: str = "./embeddings"):
        self.pdf_path = pdf_path
        stem = os.path.splitext(os.path.basename(pdf_path))[0]
        self.txt_path = os.path.join(base_dir, f"{stem}.txt")
        self.persist_dir = os.path.join(base_dir, stem)
        os.makedirs(base_dir, exist_ok=True)

    def pdf_to_txt(self) -> str:
        """Extract the PDF's text into the cached .txt file; return its path."""
        # A non-empty cache file means extraction already happened.
        if os.path.exists(self.txt_path) and os.path.getsize(self.txt_path) > 0:
            return self.txt_path
        pages = PdfReader(self.pdf_path).pages
        with open(self.txt_path, "w", encoding="utf-8") as out:
            for page in pages:
                content = page.extract_text() or ""
                if content.strip():
                    out.write(content + "\n")
        return self.txt_path

    def get_or_create_embeddings(self) -> Chroma:
        """Load the persisted Chroma index when present, else build and persist it."""
        ensure_api_key()
        embeddings = OpenAIEmbeddings(openai_api_key=os.environ["OPENAI_API_KEY"])
        # A populated persist directory means the index was built on a prior run.
        if os.path.isdir(self.persist_dir) and os.listdir(self.persist_dir):
            return Chroma(persist_directory=self.persist_dir, embedding_function=embeddings)
        # Otherwise build the index from freshly chunked text.
        with open(self.pdf_to_txt(), "r", encoding="utf-8") as f:
            raw_text = f.read()
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=2048, chunk_overlap=128
        ).split_text(raw_text)
        vectordb = Chroma.from_texts(
            texts=chunks,
            embedding=embeddings,
            persist_directory=self.persist_dir,
        )
        vectordb.persist()
        return vectordb
# ---------- Question Extraction Crew ----------
class QuestionExtractionEngine:
    """
    Crew 1:
    Reads the surgical guideline embeddings and compiles clear YES/NO questions.
    Supports iterative refinement via `user_suggestions`.
    """
    def __init__(self, guideline_pdf_path: str, surgery_type: str, user_suggestions: str = ""):
        # Build (or load from disk) the guideline vector store for this engine.
        self.guideline_db = EmbeddingManager(guideline_pdf_path).get_or_create_embeddings()
        self.surgery_type = surgery_type
        self.user_suggestions = (user_suggestions or "").strip()

        # Tools are defined inside __init__ so their closures can reach
        # self.guideline_db while remaining plain functions for the
        # crewai `@tool` decorator (no `self` parameter).
        @tool("guideline_search")
        def guideline_search(query: str) -> str:
            """Search the guideline embeddings for relevant context to build questions."""
            results = self.guideline_db.similarity_search(query, k=10)
            return "\n".join([res.page_content for res in results])

        @tool("list_questions")
        def list_questions(questions: List[str]) -> str:
            """Return the extracted questions as a formatted string for debugging."""
            return "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)])

        # Optional refinement: fold user suggestions into the task prompt.
        prompt_suffix = ""
        if self.user_suggestions:
            prompt_suffix = (
                "\nIncorporate the following user suggestions/preferences while drafting questions:\n"
                f"{self.user_suggestions}\n"
            )

        # Single agent responsible for drafting atomic yes/no questions.
        self.question_compiler = Agent(
            role="Guideline Question Compiler",
            goal=(
                f"From the provided {self.surgery_type} guideline, create a list of clear YES/NO clinical questions."
                " Questions should be unambiguous, atomic, and clinically meaningful."
            ),
            verbose=True,
            memory=False,
            tools=[guideline_search, list_questions],
            backstory=(
                "You are an expert at reading medical guidelines and turning them into unambiguous yes/no "
                "assessment questions that support eligibility decisions."
            ),
        )

        # Single task: draft, echo via list_questions, then emit JSON.
        self.task_compile_questions = Task(
            description=(
                f"From the {self.surgery_type} guideline, list all yes/no questions needed to assess eligibility."
                " Each question should be answerable directly from a patient chart."
                " Avoid double-barreled questions; keep them atomic and specific."
                f"{prompt_suffix}"
                " After drafting, call the 'list_questions' tool to echo them back in a clean, numbered list."
                " Finally, also return the Python list of questions as JSON (e.g., [\"Q1\", \"Q2\"])."
            ),
            expected_output="A JSON list of yes/no clinical questions, plus a numbered preview via list_questions.",
            agent=self.question_compiler,
        )

        # One-agent, one-task sequential crew.
        self.crew = Crew(
            agents=[self.question_compiler],
            tasks=[self.task_compile_questions],
            process=Process.sequential,
            verbose=True,
        )

    def run(self) -> List[str]:
        """Kick off the crew and coerce its raw output into a list of question strings."""
        result = self.crew.kickoff(inputs={})
        # The LLM may return JSON, a Python literal, or plain text; the two
        # helpers below tolerate all three shapes.
        parsed = safe_json_loads(str(result))
        return coerce_to_list_of_strings(parsed)
# ---------- Patient-Chart Answering Crew ----------
class AnsweringEngine:
    """
    Crew 2:
    Takes ONE yes/no question and answers it using the patient chart embeddings.
    Output: {question: {"answer": "Yes/No/Not Found", "rationale": "Reason"}}
    """
    def __init__(self, patient_chart_pdf_path: str):
        # Build (or load from disk) the patient-chart vector store.
        self.patient_db = EmbeddingManager(patient_chart_pdf_path).get_or_create_embeddings()

        # Defined inside __init__ so the closure can reach self.patient_db
        # while staying a plain function for the crewai `@tool` decorator.
        @tool("patient_chart_search")
        def patient_chart_search(query: str) -> str:
            """Search the patient chart embeddings for relevant context to answer the question."""
            results = self.patient_db.similarity_search(query, k=10)
            return "\n".join([res.page_content for res in results])

        # Agent 1: searches the chart and produces answer + rationale.
        self.chart_answerer = Agent(
            role="Patient Chart Question Answerer",
            goal=(
                "Given ONE yes/no question, search the patient chart and answer with "
                "Yes, No, or Not Found. Always include a one-line rationale citing evidence."
            ),
            verbose=True,
            memory=False,
            tools=[patient_chart_search],
            backstory="You are precise and always back your yes/no answers with chart evidence.",
        )

        # Agent 2: wraps the answer into the final JSON shape (no tools needed).
        self.json_builder = Agent(
            role="JSON Question-Answer Mapper",
            goal=(
                "Wrap the question and its answer+rationale into JSON format: "
                "{question: {answer: '...', rationale: '...'}}"
            ),
            verbose=True,
            memory=False,
            tools=[],
            backstory="You ensure JSON is well-structured and complete.",
        )

    def answer_one(self, question: str) -> Dict[str, Dict[str, str]]:
        """Answer one yes/no question from the chart.

        Returns a mapping of the question to {"answer": ..., "rationale": ...},
        falling back to "Not Found" when the crew's output cannot be parsed.
        Tasks are built per call because each run binds a new question string.
        """
        # Task 1: Answer the question
        task_answer = Task(
            description=(
                f"Answer this question using the patient chart: '{question}'. "
                "The answer must be strictly Yes, No, or Not Found, plus a one-line rationale citing evidence."
                " If evidence is weak or ambiguous, return 'Not Found'."
            ),
            expected_output='A JSON object like {"answer": "Yes", "rationale": "Reason"} or a short text with both.',
            agent=self.chart_answerer,
        )
        # Task 2: Wrap into final JSON
        task_json = Task(
            description=(
                f"Take the answer and rationale from the previous step and return a JSON object "
                f"where the key is the question '{question}' "
                "and the value is an object with keys 'answer' and 'rationale'."
            ),
            expected_output=f'{{"{question}": {{"answer": "Yes", "rationale": "Reason"}}}}',
            agent=self.json_builder,
            # Feed Task 1's output into Task 2.
            context=[task_answer],
        )
        crew = Crew(
            agents=[self.chart_answerer, self.json_builder],
            tasks=[task_answer, task_json],
            process=Process.sequential,
            verbose=True,
        )
        result = crew.kickoff(inputs={})
        obj = safe_json_loads(str(result))
        # Normalize structure & fields
        # Preferred shape: {question: {"answer": ..., "rationale": ...}}
        if isinstance(obj, dict) and question in obj:
            ans = obj[question] or {}
            ans_text = normalize_yes_no_nf(str(ans.get("answer", "")))
            rationale = str(ans.get("rationale", "")).strip()
            return {question: {"answer": ans_text, "rationale": rationale or "—"}}
        # Fallback if LLM returned flat structure
        if isinstance(obj, dict) and "answer" in obj:
            ans_text = normalize_yes_no_nf(str(obj.get("answer", "")))
            rationale = str(obj.get("rationale", "")).strip()
            return {question: {"answer": ans_text, "rationale": rationale or "—"}}
        # Last resort
        return {question: {"answer": "Not Found", "rationale": "Could not parse model output."}}