Spaces:
Sleeping
Sleeping
| """ | |
| rag_engine.py — CognitivePulse | |
| Retrieval-augmented generation for the coaching assistant. | |
| Retrieves literature from data/corpus.json based on the intervention domains | |
| identified by intervention_engine.py, generates a grounded coaching brief, | |
| and runs a RAGAS-style faithfulness check on the output. | |
| LLM backend: Groq inference API (openai/gpt-oss-120b). | |
| Embedding backend: sentence-transformers (intfloat/e5-large-v2) + FAISS; | |
| automatic TF-IDF fallback if model is unavailable. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
# Curated literature corpus, stored in the data/ directory next to this module.
CORPUS_PATH = Path(__file__).parent.joinpath("data", "corpus.json")

# Default Groq model id for both the coaching and the faithfulness calls.
# See README.md for API key setup at console.groq.com.
GROQ_MODEL = "openai/gpt-oss-120b"
def load_corpus(path=None) -> list:
    """Load the literature corpus from a JSON file.

    Args:
        path: Optional file to read instead of the module-level
            ``CORPUS_PATH``. Accepts anything ``open()`` accepts.

    Returns:
        The decoded JSON document. The retriever in this module treats it
        as a list of dict entries.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    corpus_file = CORPUS_PATH if path is None else path
    # Decode explicitly as UTF-8 so non-ASCII text in the corpus does not
    # depend on the platform's locale encoding.
    with open(corpus_file, encoding="utf-8") as f:
        return json.load(f)
class LiteratureRetriever:
    """Embedding-backed retriever over the curated corpus with TF-IDF fallback.

    On construction the corpus is loaded with ``load_corpus()`` and every entry
    is flattened to ``"<title>. <summary>"``. ``backend`` records which
    similarity backend ended up being initialised:

    * ``"sentence-transformers+faiss"`` - e5 embeddings; a FAISS index was built
    * ``"sentence-transformers"``       - e5 embeddings; FAISS not importable
    * ``"tfidf_fallback"``              - scikit-learn TF-IDF vectors
    """

    def __init__(self):
        self.corpus = load_corpus()
        self.texts = [f"{d['title']}. {d['summary']}" for d in self.corpus]
        self.backend = None
        self._init_backend()

    def _init_backend(self):
        """Build the embedding backend, or fall back to TF-IDF if that fails."""
        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("intfloat/e5-large-v2")
            # Corpus texts get the "passage: " prefix; queries get "query: "
            # in _retrieve_indices.
            passage_texts = [f"passage: {t}" for t in self.texts]
            embeddings = model.encode(passage_texts, normalize_embeddings=True)
            self._st_model = model
            self._embeddings = np.array(embeddings, dtype="float32")
            self.backend = "sentence-transformers"
            try:
                import faiss

                index = faiss.IndexFlatIP(self._embeddings.shape[1])
                index.add(self._embeddings)
                # NOTE: the index is stored but _retrieve_indices scores
                # against self._embeddings directly.
                self._faiss_index = index
                self.backend = "sentence-transformers+faiss"
            except ImportError:
                self._faiss_index = None
            return
        except Exception as exc:
            # Deliberate best-effort: any failure here (missing package, model
            # download error, ...) drops to TF-IDF. Log it so the downgrade is
            # visible instead of silent.
            import logging

            logging.getLogger(__name__).warning(
                "Embedding backend unavailable (%s: %s); falling back to TF-IDF.",
                type(exc).__name__, exc,
            )

        from sklearn.feature_extraction.text import TfidfVectorizer

        self._vectorizer = TfidfVectorizer(stop_words="english")
        self._tfidf_matrix = self._vectorizer.fit_transform(self.texts)
        self.backend = "tfidf_fallback"

    def _retrieve_indices(self, query: str, k: int, candidate_idx=None):
        """Return up to ``k`` corpus indices ranked by similarity to ``query``.

        Args:
            query: Free-text query.
            k: Maximum number of indices to return; ``k <= 0`` yields ``[]``.
            candidate_idx: Optional list of corpus indices to restrict the
                search to. ``None`` or an empty list searches the whole corpus.

        Returns:
            A list of plain ``int`` corpus indices, most similar first. Ties
            keep corpus order.
        """
        if k <= 0:
            return []

        # An empty candidate list means "no restriction", same as None.
        restrict = bool(candidate_idx)

        if self.backend.startswith("sentence-transformers"):
            q_emb = self._st_model.encode([f"query: {query}"], normalize_embeddings=True)
            q_emb = np.array(q_emb, dtype="float32")
            pool = self._embeddings[candidate_idx] if restrict else self._embeddings
            # Both sides are encoded with normalize_embeddings=True, so the
            # dot product is used directly as the similarity score.
            sims = pool @ q_emb[0]
        else:
            from sklearn.metrics.pairwise import cosine_similarity

            q_vec = self._vectorizer.transform([query])
            pool = self._tfidf_matrix[candidate_idx] if restrict else self._tfidf_matrix
            sims = cosine_similarity(q_vec, pool).flatten()

        # Stable sort keeps tie order deterministic.
        order = np.argsort(-sims, kind="stable")[:k]
        if restrict:
            # Map positions within the candidate subset back to corpus indices.
            return [int(candidate_idx[pos]) for pos in order]
        return [int(pos) for pos in order]

    def retrieve_for_domains(self, query: str, domain_tags: list, k: int = 2) -> list:
        """Return the top-``k`` corpus entries for ``query`` within ``domain_tags``.

        Only entries whose ``"domain"`` is in ``domain_tags`` are ranked. If no
        entry matches, the whole corpus is searched instead.
        """
        candidate_idx = [i for i, d in enumerate(self.corpus) if d["domain"] in domain_tags]
        idx = self._retrieve_indices(query, k, candidate_idx=candidate_idx or None)
        return [self.corpus[i] for i in idx]

    def retrieve_for_interventions(self, interventions: list, k_per_intervention: int = 2) -> dict:
        """Returns {intervention_summary: [retrieved docs]} for each intervention.

        Each intervention is a dict with ``"intervention_summary"`` and
        ``"literature_tags"`` keys. Interventions that retrieve no documents
        are left out of the result.
        """
        retrieved = {}
        for iv in interventions:
            query = f"interventions and evidence for {iv['intervention_summary']}"
            docs = self.retrieve_for_domains(query, iv["literature_tags"], k=k_per_intervention)
            if docs:
                retrieved[iv["intervention_summary"]] = docs
        return retrieved
| # --------------------------------------------------------------------------- | |
| # Coaching generation | |
| # --------------------------------------------------------------------------- | |
# System prompt for generate_coaching(). Built from adjacent literals so each
# prompt line is one source line; the resulting string is unchanged.
COACHING_SYSTEM_PROMPT = (
    "You are an AI assistant helping prepare a brain health coaching brief "
    "for a health professional at a preventive neurology clinic. Based on a client's biomarker risk "
    "profile and the research literature provided, write a structured, evidence-grounded coaching summary.\n"
    "Strict rules:\n"
    "- Only state claims that are directly supported by the provided source excerpts.\n"
    "- Every factual claim about research findings must cite a source naturally "
    '(e.g. "The SPRINT MIND trial found that...").\n'
    "- Never make diagnostic statements. Frame everything as risk factors and evidence-based "
    "lifestyle interventions, not diagnoses.\n"
    "- Tone: professional, evidence-informed, supportive. Suitable for a clinician to read "
    "before a client coaching session.\n"
    "- Structure your response as: (1) Risk Summary, (2) Priority Interventions with evidence, "
    "(3) Suggested coaching focus areas.\n"
    "- Keep the full response under 350 words.\n"
    "- Respond with plain prose only — no markdown headers, no preamble."
)
def generate_coaching(coach_brief: str, retrieved_context: dict,
                      client=None, model: str = GROQ_MODEL) -> dict:
    """
    Calls Groq to generate a grounded coaching summary from the coach brief
    and retrieved literature.

    Args:
        coach_brief: Client risk brief, inserted verbatim into the user prompt.
        retrieved_context: ``{area: [doc, ...]}``. Each doc is a dict with
            ``"id"``, ``"title"``, ``"source"`` and ``"summary"`` keys.
        client: Optional pre-built client exposing ``chat.completions.create``.
            A ``groq.Groq()`` client is created when omitted.
        model: Model id passed to the completion call.

    Returns {"text": str, "sources_used": [...]}. ``sources_used`` holds each
    retrieved doc once, in first-seen order.

    Raises:
        RuntimeError: If the model returns empty content.
    """
    # With nothing retrieved, return the static fallback. This is checked
    # first so the offline path needs neither the groq package nor an API key.
    if not retrieved_context:
        return {
            "text": ("No specific modifiable risk factors were flagged above typical ranges. "
                     "General brain-health maintenance advice applies: consistent physical activity, "
                     "quality sleep, a nutrient-dense diet, and active social and cognitive engagement "
                     "are supported across the prevention literature. "
                     "Note: this is a research prototype, not a clinical assessment."),
            "sources_used": [],
        }

    # A doc retrieved for several intervention areas is kept once, so the same
    # excerpt is not repeated in the prompt.
    unique_sources = []
    seen_ids = set()
    for docs in retrieved_context.values():
        for doc in docs:
            if doc["id"] in seen_ids:
                continue
            seen_ids.add(doc["id"])
            unique_sources.append(doc)

    context_text = "\n\n".join(
        f"[Source: {d['title']} — {d['source']}]\n{d['summary']}" for d in unique_sources
    )
    user_prompt = (
        "CLIENT RISK BRIEF:\n"
        f"{coach_brief}\n"
        "RETRIEVED RESEARCH EVIDENCE (use ONLY these sources):\n"
        f"{context_text}\n"
        "Write the coaching summary now."
    )

    if client is None:
        # Imported lazily so an injected client works without groq installed.
        from groq import Groq
        client = Groq()

    response = client.chat.completions.create(
        model=model,
        max_tokens=1000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": COACHING_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
    )
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise RuntimeError("Groq returned empty content. Try increasing max_tokens.")
    return {"text": text, "sources_used": unique_sources}
| # --------------------------------------------------------------------------- | |
| # Faithfulness evaluation | |
| # --------------------------------------------------------------------------- | |
# System prompt for check_faithfulness(). It asks for a bare JSON object with
# a per-claim verdict list and an overall score. Built from adjacent literals;
# the resulting string is unchanged.
FAITHFULNESS_SYSTEM_PROMPT = (
    "You are a strict clinical fact-checker. You will be given "
    "a generated coaching text and the source excerpts it was grounded in.\n"
    "Break the coaching text into individual factual claims about research findings. For each claim, "
    "determine whether it is directly supported by the source excerpts (SUPPORTED), partially "
    "supported or overstated (PARTIAL), or not supported / hallucinated (UNSUPPORTED).\n"
    "Respond ONLY with valid JSON in this exact format — no other text, no markdown code fences, "
    "no explanation before or after the JSON:\n"
    "{\n"
    '"claims": [\n'
    '{"claim": "...", "verdict": "SUPPORTED|PARTIAL|UNSUPPORTED", "reason": "..."}\n'
    "],\n"
    '"faithfulness_score": <float between 0 and 1>\n'
    "}\n"
    "Keep each reason to one short sentence. Output the JSON object directly."
)
def _parse_evaluator_json(cleaned: str):
    """Decode the evaluator reply into a dict, or return None if that fails."""
    attempts = [cleaned]
    # If the model wrapped the object in extra text, also try the outermost
    # {...} span.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if 0 <= start < end:
        attempts.append(cleaned[start:end + 1])
    for text in attempts:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            continue
        # Callers index the result by key, so only a JSON object is usable.
        if isinstance(parsed, dict):
            return parsed
    return None


def check_faithfulness(generated_text: str, sources_used: list,
                       client=None, model: str = GROQ_MODEL) -> dict:
    """Ask the model to fact-check ``generated_text`` against ``sources_used``.

    Args:
        generated_text: Coaching text to evaluate.
        sources_used: Docs the text was grounded in; each is a dict with
            ``"title"`` and ``"summary"`` keys.
        client: Optional pre-built client exposing ``chat.completions.create``.
            A ``groq.Groq()`` client is created when omitted.
        model: Model id passed to the completion call.

    Returns:
        The evaluator's JSON object when it parses. Otherwise a dict with
        ``"claims": []``, a ``"faithfulness_score"`` and an explanatory
        ``"note"``. The score is ``1.0`` when no sources were supplied, and
        ``None`` when the reply was empty or unparsable. The unparsable case
        also carries the fence-stripped reply under ``"raw"``.
    """
    # Checked first so the no-source path needs neither the groq package nor
    # an API key.
    if not sources_used:
        return {"claims": [], "faithfulness_score": 1.0,
                "note": "No sources retrieved; fallback text used."}

    if client is None:
        # Imported lazily so an injected client works without groq installed.
        from groq import Groq
        client = Groq()

    source_text = "\n\n".join(f"[{s['title']}]\n{s['summary']}" for s in sources_used)
    response = client.chat.completions.create(
        model=model,
        max_tokens=2000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": FAITHFULNESS_SYSTEM_PROMPT},
            {"role": "user", "content": (
                f"GENERATED TEXT:\n{generated_text}\n\n"
                f"SOURCE EXCERPTS:\n{source_text}\n\n"
                "Evaluate faithfulness now. Respond with JSON only."
            )},
        ],
    )
    raw = (response.choices[0].message.content or "").strip()
    if not raw:
        return {"claims": [], "faithfulness_score": None,
                "note": "Groq returned empty content (token budget). Try increasing max_tokens."}

    # The prompt forbids code fences, but strip them if the model adds them.
    raw = raw.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    parsed = _parse_evaluator_json(raw)
    if parsed is None:
        return {"claims": [], "faithfulness_score": None,
                "note": "Could not parse evaluator output.", "raw": raw}
    return parsed
if __name__ == "__main__":
    # Manual smoke test: build the retriever, report which backend was
    # initialised, then print the titles retrieved for two sample interventions.
    retriever = LiteratureRetriever()
    print(f"Backend: {retriever.backend} | Corpus size: {len(retriever.corpus)}")

    # NOTE(review): INTERVENTION_SUMMARY is imported but never referenced
    # below; confirm whether the samples were meant to come from it.
    from intervention_engine import INTERVENTION_SUMMARY

    demo_specs = (
        ("Managing cardiovascular risk factors (BP / cholesterol)", ["cardiovascular_risk"]),
        ("Improving sleep quality and duration", ["sleep_glymphatic"]),
    )
    sample_ivs = [
        {"intervention_summary": summary, "literature_tags": tags}
        for summary, tags in demo_specs
    ]

    retrieved = retriever.retrieve_for_interventions(sample_ivs)
    for heading in retrieved:
        print(f"\n--- {heading} ---")
        for entry in retrieved[heading]:
            print(" •", entry["title"])