""" rag_engine.py — CognitivePulse Retrieval-augmented generation for the coaching assistant. Retrieves literature from data/corpus.json based on the intervention domains identified by intervention_engine.py, generates a grounded coaching brief, and runs a RAGAS-style faithfulness check on the output. LLM backend: Groq inference API (openai/gpt-oss-120b). Embedding backend: sentence-transformers (intfloat/e5-large-v2) + FAISS; automatic TF-IDF fallback if model is unavailable. """ from __future__ import annotations import json import os from pathlib import Path import numpy as np CORPUS_PATH = Path(__file__).parent / "data" / "corpus.json" GROQ_MODEL = "openai/gpt-oss-120b" # see README.md for API key setup at console.groq.com def load_corpus() -> list: with open(CORPUS_PATH) as f: return json.load(f) class LiteratureRetriever: """Embedding-backed retriever over the curated corpus with TF-IDF fallback.""" def __init__(self): self.corpus = load_corpus() self.texts = [f"{d['title']}. {d['summary']}" for d in self.corpus] self.backend = None self._init_backend() def _init_backend(self): try: from sentence_transformers import SentenceTransformer model = SentenceTransformer("intfloat/e5-large-v2") passage_texts = [f"passage: {t}" for t in self.texts] embeddings = model.encode(passage_texts, normalize_embeddings=True) self._st_model = model self._embeddings = np.array(embeddings, dtype="float32") self.backend = "sentence-transformers" try: import faiss index = faiss.IndexFlatIP(self._embeddings.shape[1]) index.add(self._embeddings) self._faiss_index = index self.backend = "sentence-transformers+faiss" except ImportError: self._faiss_index = None return except Exception: pass from sklearn.feature_extraction.text import TfidfVectorizer self._vectorizer = TfidfVectorizer(stop_words="english") self._tfidf_matrix = self._vectorizer.fit_transform(self.texts) self.backend = "tfidf_fallback" def _retrieve_indices(self, query: str, k: int, candidate_idx=None): if self.backend.startswith("sentence-transformers"): q_emb = self._st_model.encode([f"query: {query}"], normalize_embeddings=True) q_emb = np.array(q_emb, dtype="float32") subset = self._embeddings[candidate_idx] if candidate_idx else self._embeddings sims = subset @ q_emb[0] order = np.argsort(-sims)[:k] return [candidate_idx[i] for i in order] if candidate_idx else list(order) else: from sklearn.metrics.pairwise import cosine_similarity q_vec = self._vectorizer.transform([query]) matrix = self._tfidf_matrix[candidate_idx] if candidate_idx else self._tfidf_matrix sims = cosine_similarity(q_vec, matrix).flatten() order = np.argsort(-sims)[:k] return [candidate_idx[i] for i in order] if candidate_idx else list(order) def retrieve_for_domains(self, query: str, domain_tags: list, k: int = 2) -> list: candidate_idx = [i for i, d in enumerate(self.corpus) if d["domain"] in domain_tags] idx = self._retrieve_indices(query, k, candidate_idx=candidate_idx or None) return [self.corpus[i] for i in idx] def retrieve_for_interventions(self, interventions: list, k_per_intervention: int = 2) -> dict: """Returns {intervention_summary: [retrieved docs]} for each intervention.""" retrieved = {} for iv in interventions: query = f"interventions and evidence for {iv['intervention_summary']}" docs = self.retrieve_for_domains(query, iv["literature_tags"], k=k_per_intervention) if docs: retrieved[iv["intervention_summary"]] = docs return retrieved # --------------------------------------------------------------------------- # Coaching generation # --------------------------------------------------------------------------- COACHING_SYSTEM_PROMPT = """You are an AI assistant helping prepare a brain health coaching brief \ for a health professional at a preventive neurology clinic. Based on a client's biomarker risk \ profile and the research literature provided, write a structured, evidence-grounded coaching summary. Strict rules: - Only state claims that are directly supported by the provided source excerpts. - Every factual claim about research findings must cite a source naturally \ (e.g. "The SPRINT MIND trial found that..."). - Never make diagnostic statements. Frame everything as risk factors and evidence-based \ lifestyle interventions, not diagnoses. - Tone: professional, evidence-informed, supportive. Suitable for a clinician to read \ before a client coaching session. - Structure your response as: (1) Risk Summary, (2) Priority Interventions with evidence, \ (3) Suggested coaching focus areas. - Keep the full response under 350 words. - Respond with plain prose only — no markdown headers, no preamble.""" def generate_coaching(coach_brief: str, retrieved_context: dict, client=None, model: str = GROQ_MODEL) -> dict: """ Calls Groq to generate a grounded coaching summary from the coach brief and retrieved literature. Returns {"text": str, "sources_used": [...]}. """ from groq import Groq if client is None: client = Groq() if not retrieved_context: return { "text": ("No specific modifiable risk factors were flagged above typical ranges. " "General brain-health maintenance advice applies: consistent physical activity, " "quality sleep, a nutrient-dense diet, and active social and cognitive engagement " "are supported across the prevention literature. " "Note: this is a research prototype, not a clinical assessment."), "sources_used": [], } context_blocks = [] all_sources = [] for area, docs in retrieved_context.items(): for d in docs: context_blocks.append(f"[Source: {d['title']} — {d['source']}]\n{d['summary']}") all_sources.append(d) context_text = "\n\n".join(context_blocks) user_prompt = f"""CLIENT RISK BRIEF: {coach_brief} RETRIEVED RESEARCH EVIDENCE (use ONLY these sources): {context_text} Write the coaching summary now.""" response = client.chat.completions.create( model=model, max_tokens=1000, reasoning_effort="low", messages=[ {"role": "system", "content": COACHING_SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], ) text = (response.choices[0].message.content or "").strip() if not text: raise RuntimeError("Groq returned empty content. Try increasing max_tokens.") seen_ids = set() unique_sources = [s for s in all_sources if s["id"] not in seen_ids and not seen_ids.add(s["id"])] return {"text": text, "sources_used": unique_sources} # --------------------------------------------------------------------------- # Faithfulness evaluation # --------------------------------------------------------------------------- FAITHFULNESS_SYSTEM_PROMPT = """You are a strict clinical fact-checker. You will be given \ a generated coaching text and the source excerpts it was grounded in. Break the coaching text into individual factual claims about research findings. For each claim, \ determine whether it is directly supported by the source excerpts (SUPPORTED), partially \ supported or overstated (PARTIAL), or not supported / hallucinated (UNSUPPORTED). Respond ONLY with valid JSON in this exact format — no other text, no markdown code fences, \ no explanation before or after the JSON: { "claims": [ {"claim": "...", "verdict": "SUPPORTED|PARTIAL|UNSUPPORTED", "reason": "..."} ], "faithfulness_score": } Keep each reason to one short sentence. Output the JSON object directly.""" def check_faithfulness(generated_text: str, sources_used: list, client=None, model: str = GROQ_MODEL) -> dict: import json as _json from groq import Groq if client is None: client = Groq() if not sources_used: return {"claims": [], "faithfulness_score": 1.0, "note": "No sources retrieved; fallback text used."} source_text = "\n\n".join(f"[{s['title']}]\n{s['summary']}" for s in sources_used) response = client.chat.completions.create( model=model, max_tokens=2000, reasoning_effort="low", messages=[ {"role": "system", "content": FAITHFULNESS_SYSTEM_PROMPT}, {"role": "user", "content": ( f"GENERATED TEXT:\n{generated_text}\n\n" f"SOURCE EXCERPTS:\n{source_text}\n\n" "Evaluate faithfulness now. Respond with JSON only." )}, ], ) raw = (response.choices[0].message.content or "").strip() if not raw: return {"claims": [], "faithfulness_score": None, "note": "Groq returned empty content (token budget). Try increasing max_tokens."} raw = raw.removeprefix("```json").removeprefix("```").removesuffix("```").strip() try: return _json.loads(raw) except _json.JSONDecodeError: return {"claims": [], "faithfulness_score": None, "note": "Could not parse evaluator output.", "raw": raw} if __name__ == "__main__": retriever = LiteratureRetriever() print(f"Backend: {retriever.backend} | Corpus size: {len(retriever.corpus)}") from intervention_engine import INTERVENTION_SUMMARY sample_ivs = [ {"intervention_summary": "Managing cardiovascular risk factors (BP / cholesterol)", "literature_tags": ["cardiovascular_risk"]}, {"intervention_summary": "Improving sleep quality and duration", "literature_tags": ["sleep_glymphatic"]}, ] retrieved = retriever.retrieve_for_interventions(sample_ivs) for area, docs in retrieved.items(): print(f"\n--- {area} ---") for d in docs: print(" •", d["title"])