# CognitivePulse / rag_engine.py
# Kshamaa S
# Initial deployment: CognitivePulse biomarker intelligence and coaching assistant
# 14a5ab4
"""
rag_engine.py — CognitivePulse
Retrieval-augmented generation for the coaching assistant.
Retrieves literature from data/corpus.json based on the intervention domains
identified by intervention_engine.py, generates a grounded coaching brief,
and runs a RAGAS-style faithfulness check on the output.
LLM backend: Groq inference API (openai/gpt-oss-120b).
Embedding backend: sentence-transformers (intfloat/e5-large-v2) + FAISS;
automatic TF-IDF fallback if model is unavailable.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import numpy as np
# Curated literature corpus, located in data/ next to this module.
CORPUS_PATH = Path(__file__).parent.joinpath("data", "corpus.json")
# Groq-hosted chat model; the default for both generation and faithfulness
# checking. See README.md for API key setup at console.groq.com.
GROQ_MODEL = "openai/gpt-oss-120b"
def load_corpus(path: str | os.PathLike | None = None) -> list:
    """Load the literature corpus from a JSON file.

    Args:
        path: Optional override for the corpus file. Defaults to
            ``CORPUS_PATH`` (``data/corpus.json`` next to this module).

    Returns:
        The decoded JSON document. Code in this module treats it as a list of
        dicts with ``id``, ``title``, ``summary``, ``source`` and ``domain``
        keys.
    """
    corpus_file = CORPUS_PATH if path is None else Path(path)
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail to decode non-ASCII characters in titles or summaries.
    with open(corpus_file, encoding="utf-8") as f:
        return json.load(f)
class LiteratureRetriever:
    """Embedding-backed retriever over the curated corpus with TF-IDF fallback.

    The backend is chosen once in ``__init__`` and reported in ``self.backend``:

    * ``"sentence-transformers+faiss"`` / ``"sentence-transformers"``: dense
      intfloat/e5-large-v2 embeddings encoded with ``normalize_embeddings=True``,
      so a dot product is cosine similarity.
    * ``"tfidf_fallback"``: scikit-learn TF-IDF, used when the dense backend
      raises for any reason. The failure is kept in ``self.backend_error``
      instead of being discarded.

    When faiss is installed, an ``IndexFlatIP`` is built and exposed as
    ``self._faiss_index``. Search itself uses an exact NumPy dot product, which
    ranks the same way and also works on domain-filtered subsets.
    """

    def __init__(self):
        self.corpus = load_corpus()
        # One searchable string per corpus entry, in corpus order, so indices
        # into self.texts and self.corpus are interchangeable.
        self.texts = [f"{d['title']}. {d['summary']}" for d in self.corpus]
        self.backend = None
        self.backend_error = None  # repr() of the dense-backend failure, if any
        self._init_backend()

    def _init_backend(self):
        """Try the dense embedding backend; fall back to TF-IDF if it fails."""
        try:
            self._init_dense_backend()
            return
        except Exception as exc:  # model download/load/encode can fail in many ways
            # Deliberate best-effort fallback, but keep the reason inspectable.
            self.backend_error = repr(exc)
        from sklearn.feature_extraction.text import TfidfVectorizer
        self._vectorizer = TfidfVectorizer(stop_words="english")
        self._tfidf_matrix = self._vectorizer.fit_transform(self.texts)
        self.backend = "tfidf_fallback"

    def _init_dense_backend(self):
        """Encode the corpus with e5-large-v2 and, if faiss is importable, index it."""
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer("intfloat/e5-large-v2")
        # e5 convention used here: corpus texts get "passage: ", queries get
        # "query: " (see _score).
        passage_texts = [f"passage: {t}" for t in self.texts]
        embeddings = model.encode(passage_texts, normalize_embeddings=True)
        self._st_model = model
        self._embeddings = np.array(embeddings, dtype="float32")
        self.backend = "sentence-transformers"
        try:
            import faiss
        except ImportError:
            self._faiss_index = None
            return
        index = faiss.IndexFlatIP(self._embeddings.shape[1])
        index.add(self._embeddings)
        self._faiss_index = index
        self.backend = "sentence-transformers+faiss"

    def _score(self, query: str, rows: list[int] | None):
        """Return the similarity of ``query`` to each selected corpus row (all rows if None)."""
        if self.backend.startswith("sentence-transformers"):
            q_emb = np.array(
                self._st_model.encode([f"query: {query}"], normalize_embeddings=True),
                dtype="float32",
            )
            matrix = self._embeddings if rows is None else self._embeddings[rows]
            return matrix @ q_emb[0]
        from sklearn.metrics.pairwise import cosine_similarity
        q_vec = self._vectorizer.transform([query])
        matrix = self._tfidf_matrix if rows is None else self._tfidf_matrix[rows]
        return cosine_similarity(q_vec, matrix).flatten()

    def _retrieve_indices(self, query: str, k: int,
                          candidate_idx: list[int] | None = None) -> list[int]:
        """Rank corpus rows by similarity to ``query`` and return the top ``k``.

        Args:
            query: Free-text query.
            k: Maximum number of indices to return. Values <= 0 return ``[]``.
            candidate_idx: Corpus indices to restrict the search to. ``None`` or
                an empty list searches the whole corpus.

        Returns:
            Corpus indices as plain ``int``, most similar first. Equal scores
            keep corpus order.
        """
        rows = list(candidate_idx) if candidate_idx else None
        sims = np.asarray(self._score(query, rows))
        # Stable sort, so equal scores (common with sparse TF-IDF) keep corpus
        # order deterministically.
        order = np.argsort(-sims, kind="stable")[: max(k, 0)]
        if rows is None:
            return [int(i) for i in order]
        return [rows[i] for i in order]

    def retrieve_for_domains(self, query: str, domain_tags: list, k: int = 2) -> list:
        """Return up to ``k`` corpus docs for ``query``, restricted to ``domain_tags``.

        Only docs whose ``domain`` is in ``domain_tags`` are searched. If no doc
        matches, the search falls back to the whole corpus.
        """
        candidate_idx = [i for i, d in enumerate(self.corpus) if d["domain"] in domain_tags]
        idx = self._retrieve_indices(query, k, candidate_idx=candidate_idx or None)
        return [self.corpus[i] for i in idx]

    def retrieve_for_interventions(self, interventions: list, k_per_intervention: int = 2) -> dict:
        """Returns {intervention_summary: [retrieved docs]} for each intervention.

        Each intervention dict needs ``intervention_summary`` (used both in the
        query and as the result key) and ``literature_tags`` (the domain
        filter). Interventions with no retrieved docs are omitted.
        """
        retrieved = {}
        for iv in interventions:
            query = f"interventions and evidence for {iv['intervention_summary']}"
            docs = self.retrieve_for_domains(query, iv["literature_tags"], k=k_per_intervention)
            if docs:
                retrieved[iv["intervention_summary"]] = docs
        return retrieved
# ---------------------------------------------------------------------------
# Coaching generation
# ---------------------------------------------------------------------------
# System prompt for generate_coaching(). A trailing backslash inside this
# triple-quoted string is a line continuation: those lines are joined, so long
# sentences are not split by newlines. The word limit and "plain prose" rule
# are instructions to the model only; nothing in this module enforces them.
COACHING_SYSTEM_PROMPT = """You are an AI assistant helping prepare a brain health coaching brief \
for a health professional at a preventive neurology clinic. Based on a client's biomarker risk \
profile and the research literature provided, write a structured, evidence-grounded coaching summary.
Strict rules:
- Only state claims that are directly supported by the provided source excerpts.
- Every factual claim about research findings must cite a source naturally \
(e.g. "The SPRINT MIND trial found that...").
- Never make diagnostic statements. Frame everything as risk factors and evidence-based \
lifestyle interventions, not diagnoses.
- Tone: professional, evidence-informed, supportive. Suitable for a clinician to read \
before a client coaching session.
- Structure your response as: (1) Risk Summary, (2) Priority Interventions with evidence, \
(3) Suggested coaching focus areas.
- Keep the full response under 350 words.
- Respond with plain prose only — no markdown headers, no preamble."""
def generate_coaching(coach_brief: str, retrieved_context: dict,
                      client=None, model: str = GROQ_MODEL) -> dict:
    """
    Calls Groq to generate a grounded coaching summary from the coach brief
    and retrieved literature.

    Args:
        coach_brief: Client risk brief, inserted verbatim into the user prompt.
        retrieved_context: Mapping shaped like the output of
            ``LiteratureRetriever.retrieve_for_interventions``
            (``{intervention_summary: [doc, ...]}``). Each doc needs ``id``,
            ``title``, ``source`` and ``summary`` keys.
        client: Optional Groq-compatible client exposing
            ``chat.completions.create``. A ``groq.Groq()`` client is built only
            when a model call is actually needed.
        model: Model name passed to the chat completion call.

    Returns:
        {"text": str, "sources_used": [...]}. ``sources_used`` is de-duplicated
        by ``id`` in first-seen order, and is exactly the set of sources shown
        to the model.

    Raises:
        RuntimeError: If the model returns empty content.
    """
    if not retrieved_context:
        # Nothing flagged: canned advice and no API call. Return before any
        # Groq import or client construction, which would need the groq
        # package and an API key.
        return {
            "text": ("No specific modifiable risk factors were flagged above typical ranges. "
                     "General brain-health maintenance advice applies: consistent physical activity, "
                     "quality sleep, a nutrient-dense diet, and active social and cognitive engagement "
                     "are supported across the prevention literature. "
                     "Note: this is a research prototype, not a clinical assessment."),
            "sources_used": [],
        }

    # A doc retrieved for several interventions is shown to the model (and
    # reported) once, keeping the first occurrence.
    sources_by_id = {}
    for docs in retrieved_context.values():
        for d in docs:
            sources_by_id.setdefault(d["id"], d)
    unique_sources = list(sources_by_id.values())

    context_text = "\n\n".join(
        f"[Source: {d['title']} — {d['source']}]\n{d['summary']}" for d in unique_sources
    )
    user_prompt = (
        f"CLIENT RISK BRIEF:\n{coach_brief}\n"
        "RETRIEVED RESEARCH EVIDENCE (use ONLY these sources):\n"
        f"{context_text}\n"
        "Write the coaching summary now."
    )

    if client is None:
        from groq import Groq
        client = Groq()
    response = client.chat.completions.create(
        model=model,
        max_tokens=1000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": COACHING_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
    )
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise RuntimeError("Groq returned empty content. Try increasing max_tokens.")
    return {"text": text, "sources_used": unique_sources}
# ---------------------------------------------------------------------------
# Faithfulness evaluation
# ---------------------------------------------------------------------------
# System prompt for check_faithfulness(). The reply is parsed with json.loads
# (after optional ``` fences are stripped), so the JSON shape requested below
# is what check_faithfulness returns on success. Note: faithfulness_score is
# self-reported by the model; it is not recomputed from the claim verdicts.
FAITHFULNESS_SYSTEM_PROMPT = """You are a strict clinical fact-checker. You will be given \
a generated coaching text and the source excerpts it was grounded in.
Break the coaching text into individual factual claims about research findings. For each claim, \
determine whether it is directly supported by the source excerpts (SUPPORTED), partially \
supported or overstated (PARTIAL), or not supported / hallucinated (UNSUPPORTED).
Respond ONLY with valid JSON in this exact format — no other text, no markdown code fences, \
no explanation before or after the JSON:
{
"claims": [
{"claim": "...", "verdict": "SUPPORTED|PARTIAL|UNSUPPORTED", "reason": "..."}
],
"faithfulness_score": <float between 0 and 1>
}
Keep each reason to one short sentence. Output the JSON object directly."""
def _parse_judge_json(raw: str) -> dict | None:
    """Parse the judge reply into a dict; tolerate code fences and stray prose, else None."""
    text = raw.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    candidates = [text]
    # Models sometimes wrap the object in a sentence; also try the outermost {...}.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers expect a dict with "claims" / "faithfulness_score".
        if isinstance(parsed, dict):
            return parsed
    return None


def check_faithfulness(generated_text: str, sources_used: list,
                       client=None, model: str = GROQ_MODEL) -> dict:
    """Ask an LLM judge whether ``generated_text`` is supported by ``sources_used``.

    Args:
        generated_text: Coaching text to evaluate.
        sources_used: Source docs (each with ``title`` and ``summary``) the text
            was grounded in.
        client: Optional Groq-compatible client. A ``groq.Groq()`` client is
            built only when a model call is actually needed.
        model: Model name passed to the chat completion call.

    Returns:
        On success, the judge's parsed JSON dict (requested to contain
        ``claims`` and ``faithfulness_score``). If there are no sources, a
        score of 1.0 is returned without calling the API. If the judge output
        is empty or cannot be parsed into a dict, ``faithfulness_score`` is
        ``None`` and a ``note`` explains why.
    """
    if not sources_used:
        return {"claims": [], "faithfulness_score": 1.0,
                "note": "No sources retrieved; fallback text used."}
    if client is None:
        # Deferred so the no-sources path needs neither groq nor an API key.
        from groq import Groq
        client = Groq()
    source_text = "\n\n".join(f"[{s['title']}]\n{s['summary']}" for s in sources_used)
    response = client.chat.completions.create(
        model=model,
        max_tokens=2000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": FAITHFULNESS_SYSTEM_PROMPT},
            {"role": "user", "content": (
                f"GENERATED TEXT:\n{generated_text}\n\n"
                f"SOURCE EXCERPTS:\n{source_text}\n\n"
                "Evaluate faithfulness now. Respond with JSON only."
            )},
        ],
    )
    raw = (response.choices[0].message.content or "").strip()
    if not raw:
        return {"claims": [], "faithfulness_score": None,
                "note": "Groq returned empty content (token budget). Try increasing max_tokens."}
    parsed = _parse_judge_json(raw)
    if parsed is None:
        return {"claims": [], "faithfulness_score": None,
                "note": "Could not parse evaluator output.", "raw": raw}
    return parsed
if __name__ == "__main__":
    # Manual smoke test of retrieval only (no Groq calls): build the retriever
    # and print the docs chosen for two sample interventions.
    retriever = LiteratureRetriever()
    print(f"Backend: {retriever.backend} | Corpus size: {len(retriever.corpus)}")
    sample_ivs = [
        {"intervention_summary": "Managing cardiovascular risk factors (BP / cholesterol)",
         "literature_tags": ["cardiovascular_risk"]},
        {"intervention_summary": "Improving sleep quality and duration",
         "literature_tags": ["sleep_glymphatic"]},
    ]
    retrieved = retriever.retrieve_for_interventions(sample_ivs)
    for area, docs in retrieved.items():
        print(f"\n--- {area} ---")
        for d in docs:
            print(" •", d["title"])