"""
rag_engine.py — Stage 3 of NeuroLens: Grounded, Personalized Prevention Coaching

Retrieves relevant literature from a curated corpus (data/corpus.json) based on
which linguistic biomarkers from Stage 2 stood out, then generates a
citation-grounded coaching summary via Groq's inference API. Includes a
RAGAS-style faithfulness check on the generated text against the retrieved
sources.

Embedding backend:
  - Primary (production / HuggingFace Spaces): sentence-transformers
    'intfloat/e5-large-v2' + FAISS, for high-quality semantic retrieval.
  - Fallback (offline / restricted environments, e.g. sandboxed dev):
    scikit-learn TF-IDF + cosine similarity. Automatically used if
    sentence-transformers or its model weights aren't available, so the
    pipeline still runs end-to-end without internet access to the
    Hugging Face Hub.

This keeps app.py agnostic to which backend is active.
"""

import json
import logging
import os
from pathlib import Path

import numpy as np

logger = logging.getLogger(__name__)

CORPUS_PATH = Path(__file__).parent / "data" / "corpus.json"

# Maps a Stage 2 biomarker "band" finding to the literature domains most
# relevant to addressing it. Used to bias retrieval toward useful content
# rather than retrieving purely by surface similarity to assessment text.
MARKER_TO_DOMAINS = {
    "semantic_fluency": ["cognitive_training", "social_engagement"],
    "phonemic_fluency": ["exercise_cognitive_reserve", "diet_nutrition"],
    "lexical_diversity": ["cognitive_reserve_early_life", "cognitive_training"],
    "idea_density": ["cognitive_reserve_early_life", "social_engagement"],
    "syntactic_complexity": ["exercise_cognitive_reserve", "diet_nutrition"],
}

# Retrieval query used for each marker. The queries are phrased around
# interventions so that results skew toward actionable literature. A marker
# missing from this table is queried by its own name
# (see build_retrieved_context).
MARKER_QUERIES = {
    "semantic_fluency": "interventions to improve semantic memory and category fluency",
    "phonemic_fluency": "exercise and vascular health for executive function and word retrieval",
    "lexical_diversity": "building cognitive reserve through lifelong learning and engagement",
    "idea_density": "early and ongoing intellectual engagement and cognitive reserve",
    "syntactic_complexity": "physical activity and diet supporting language processing and cognitive load",
}


def load_corpus() -> list:
    """Load and return the curated literature corpus stored at CORPUS_PATH.

    The file is decoded as UTF-8 explicitly (the encoding JSON mandates)
    instead of the platform default, which is not UTF-8 everywhere (e.g.
    Windows). Otherwise non-ASCII titles or summaries could fail to load.
    """
    with open(CORPUS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


class LiteratureRetriever:
    """Embedding-backed retriever over the curated corpus, with automatic
    fallback to TF-IDF if a neural embedding backend isn't available.

    Attributes set by __init__:
        corpus:  the entries returned by load_corpus(). Entries are read via
                 the keys "title", "summary" and "domain".
        texts:   one "<title>. <summary>" string per corpus entry; these are
                 what gets embedded / vectorised.
        backend: "sentence-transformers+faiss", "sentence-transformers" or
                 "tfidf_fallback", depending on what could be initialised.
    """

    def __init__(self):
        self.corpus = load_corpus()
        self.texts = [f"{d['title']}. {d['summary']}" for d in self.corpus]
        self.backend = None
        self._init_backend()

    def _init_backend(self):
        """Index self.texts with the best available embedding backend.

        Tries sentence-transformers first, adding a FAISS index when faiss is
        importable. On any failure it logs the reason and falls back to
        scikit-learn TF-IDF.
        """
        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("intfloat/e5-large-v2")
            # e5 models expect "passage: " / "query: " prefixes for best results
            passage_texts = [f"passage: {t}" for t in self.texts]
            embeddings = model.encode(passage_texts, normalize_embeddings=True)
            self._st_model = model
            self._embeddings = np.array(embeddings, dtype="float32")
            self.backend = "sentence-transformers"
            try:
                import faiss

                # Inner product over L2-normalised vectors is cosine similarity.
                # NOTE(review): this index is built but _retrieve_indices scores
                # with a NumPy dot product and never queries it.
                index = faiss.IndexFlatIP(self._embeddings.shape[1])
                index.add(self._embeddings)
                self._faiss_index = index
                self.backend = "sentence-transformers+faiss"
            except ImportError:
                self._faiss_index = None
            return
        except Exception as exc:
            # Deliberately broad: a missing package, no network access to the
            # Hugging Face Hub, or broken weights should all degrade to TF-IDF
            # rather than crash. The reason is logged so the fallback isn't silent.
            logger.warning(
                "Neural embedding backend unavailable (%s: %s); using TF-IDF fallback.",
                type(exc).__name__,
                exc,
            )

        # Fallback: TF-IDF (no network/model download required)
        from sklearn.feature_extraction.text import TfidfVectorizer

        self._vectorizer = TfidfVectorizer(stop_words="english")
        self._tfidf_matrix = self._vectorizer.fit_transform(self.texts)
        self.backend = "tfidf_fallback"

    def _similarities(self, query: str, rows=None) -> np.ndarray:
        """Return a 1-D array of similarity scores between query and corpus texts.

        rows: corpus indices to score. If None, every entry is scored. The
        result is aligned with rows (or with the full corpus).
        """
        if self.backend.startswith("sentence-transformers"):
            q_emb = np.array(
                self._st_model.encode([f"query: {query}"], normalize_embeddings=True),
                dtype="float32",
            )
            matrix = self._embeddings if rows is None else self._embeddings[rows]
            return matrix @ q_emb[0]

        from sklearn.metrics.pairwise import cosine_similarity

        q_vec = self._vectorizer.transform([query])
        matrix = self._tfidf_matrix if rows is None else self._tfidf_matrix[rows]
        return cosine_similarity(q_vec, matrix).flatten()

    def _retrieve_indices(self, query: str, k: int, candidate_idx=None) -> list:
        """Return the corpus indices of the top-k matches for query, best first.

        When candidate_idx is a non-empty sequence of corpus indices, only
        those entries are ranked. Otherwise the whole corpus is ranked. A
        non-positive k yields an empty list; a negative k used to slice from
        the end and return almost every entry.
        """
        has_candidates = candidate_idx is not None and len(candidate_idx) > 0
        rows = list(candidate_idx) if has_candidates else None
        sims = self._similarities(query, rows)
        order = np.argsort(-sims)[: max(k, 0)]
        if rows is None:
            return [int(i) for i in order]
        return [rows[i] for i in order]

    def retrieve_for_domains(self, query: str, domains: list, k: int = 2) -> list:
        """Return up to k corpus entries most similar to query, best first.

        Only entries whose "domain" is in domains are considered. If no entry
        matches any of the domains, the whole corpus is searched instead.
        """
        wanted = set(domains)
        candidate_idx = [i for i, d in enumerate(self.corpus) if d["domain"] in wanted]
        idx = self._retrieve_indices(query, k, candidate_idx=candidate_idx or None)
        return [self.corpus[i] for i in idx]


def select_retrieval_plan(bands: dict) -> dict:
    """
    Given Stage 2's `bands` dict (marker -> 'below_typical_range' /
    'within_typical_range' / 'above_typical_range'), decide which markers are
    worth addressing and which literature domains to pull from for each.

    A marker is selected only if it is 'below_typical_range' and has an entry
    in MARKER_TO_DOMAINS. The result maps each selected marker to its domain
    list, in the iteration order of `bands`.
    """
    return {
        marker: MARKER_TO_DOMAINS[marker]
        for marker, band in bands.items()
        if band == "below_typical_range" and marker in MARKER_TO_DOMAINS
    }


def build_retrieved_context(retriever: LiteratureRetriever, profile: dict,
                            k_per_marker: int = 2) -> dict:
    """Return {marker: [retrieved doc dicts]} for every marker that needs addressing.

    profile must contain a "bands" dict (see select_retrieval_plan). Each
    selected marker is queried with its MARKER_QUERIES text, or with the
    marker name if it has none, restricted to that marker's domains.
    """
    plan = select_retrieval_plan(profile["bands"])
    return {
        marker: retriever.retrieve_for_domains(
            MARKER_QUERIES.get(marker, marker), domains, k=k_per_marker
        )
        for marker, domains in plan.items()
    }


# ---------------------------------------------------------------------------
# Generation — uses Groq's inference API running an open-weight model
# (openai/gpt-oss-120b). See README.md for API key setup.
# ---------------------------------------------------------------------------

COACHING_SYSTEM_PROMPT = """You are a brain-health research communicator writing for NeuroLens, \
a non-clinical research demonstration prototype. You write supportive, non-alarmist, \
evidence-grounded coaching summaries for users based on their performance on simplified \
cognitive-linguistic tasks.

Strict rules:
- Only state claims that are directly supported by the provided source excerpts. Do not \
add outside knowledge or invent statistics.
- Every factual claim about research findings must be attributable to one of the provided \
sources; refer to sources naturally (e.g., "Research from the FINGER trial suggests...").
- Never use clinical or diagnostic language (no "you may have," "this indicates early signs of," \
"this is a symptom of"). Frame everything as "your results on this task" and "research on this topic," \
never as a statement about the user's health status.
- Tone: warm, encouraging, plain-language. No alarmism.
- End with one sentence reminding the user this is a research prototype, not a clinical assessment.
- Keep the full response under 250 words.
- Respond with plain prose only — no markdown headers, no preamble like "Here is your summary."
"""

GROQ_MODEL = "openai/gpt-oss-120b"  # see README.md for API key setup at console.groq.com


def _dedupe_sources(docs) -> list:
    """Return the docs from an iterable of corpus entries, dropping repeated
    "id"s and keeping first-seen order."""
    seen_ids = set()
    unique = []
    for doc in docs:
        if doc["id"] not in seen_ids:
            seen_ids.add(doc["id"])
            unique.append(doc)
    return unique


def generate_coaching(profile: dict, retrieved_context: dict, client=None,
                      model: str = GROQ_MODEL) -> dict:
    """Call Groq's API to generate a grounded coaching summary.

    Args:
        profile: Stage 2 profile. NOTE(review): not currently read here; kept
            for interface stability.
        retrieved_context: {marker: [corpus entry dicts]} as produced by
            build_retrieved_context(). Entries are read via "id", "title",
            "source" and "summary".
        client: an instantiated groq.Groq() client. If None, one is created
            from the GROQ_API_KEY environment variable (see README.md for
            setup). The client is created only when an API call is actually
            needed, so the canned no-findings reply works without groq or
            an API key.
        model: Groq model id.

    Returns:
        {"text": str, "sources_used": [unique corpus entries, first-seen order]}

    Raises:
        RuntimeError: the model returned empty content.
    """
    if not retrieved_context:
        return {
            "text": ("Your results across these short tasks fell within the typical ranges used for "
                     "this demo, so there's no specific area to highlight right now. Maintaining a mix "
                     "of physical activity, social engagement, varied learning, sleep, and a nutrient-dense "
                     "diet is broadly supported by the brain-health prevention literature regardless. "
                     "Reminder: NeuroLens is a research prototype, not a clinical assessment."),
            "sources_used": [],
        }

    # Several markers can retrieve the same document. Include each source only
    # once so the prompt carries no duplicate excerpts.
    unique_sources = _dedupe_sources(
        doc for docs in retrieved_context.values() for doc in docs
    )
    context_text = "\n\n".join(
        f"[Source: {d['title']} — {d['source']}]\n{d['summary']}" for d in unique_sources
    )
    markers_text = ", ".join(retrieved_context.keys())

    user_prompt = f"""The user's results on these tasks were below the typical comparison range: {markers_text}.

Here are the relevant source excerpts to ground your response (use ONLY these):

{context_text}

Write the personalized coaching summary now."""

    if client is None:
        from groq import Groq

        client = Groq()

    response = client.chat.completions.create(
        model=model,
        max_tokens=1000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": COACHING_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
    )
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise RuntimeError(
            "Groq returned empty content for coaching generation (model spent its token "
            "budget on internal reasoning). Try increasing max_tokens further."
        )

    return {"text": text, "sources_used": unique_sources}


# ---------------------------------------------------------------------------
# Faithfulness evaluation (RAGAS-style) — also via Groq.
# ---------------------------------------------------------------------------

FAITHFULNESS_SYSTEM_PROMPT = """You are a strict fact-checker. You will be given a generated \
coaching text and a set of source excerpts it was supposed to be grounded in.

Break the generated text into individual factual claims about research findings. For each \
claim, determine whether it is directly supported by the source excerpts (SUPPORTED), \
partially supported / overstated relative to the sources (PARTIAL), or not supported / \
hallucinated (UNSUPPORTED).

Respond ONLY with valid JSON in this exact format — no other text, no markdown code fences, \
no explanation before or after the JSON:
{
  "claims": [
    {"claim": "...", "verdict": "SUPPORTED|PARTIAL|UNSUPPORTED", "reason": "..."}
  ],
  "faithfulness_score": <number from 0.0 to 1.0: the fraction of claims whose verdict is SUPPORTED>
}

Keep each "reason" to one short sentence. Do not show your reasoning process — output the JSON object directly.
"""


def _parse_json_object(text: str):
    """Parse the JSON object in text and return it as a dict, or None.

    Falls back to the span between the first "{" and the last "}" so that
    stray text around the object does not break parsing. Non-object JSON
    (e.g. a bare list) is rejected.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            parsed = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None
    return parsed if isinstance(parsed, dict) else None


def check_faithfulness(generated_text: str, sources_used: list, client=None,
                       model: str = GROQ_MODEL) -> dict:
    """Ask an LLM judge which research claims in generated_text the sources support.

    Args:
        generated_text: the coaching text to evaluate.
        sources_used: corpus entries (read via "title" and "summary") the text
            was grounded in.
        client: groq.Groq() client. If None, one is created only when an API
            call is needed.
        model: Groq model id.

    Returns:
        The judge's JSON object. "claims" and "faithfulness_score" are always
        present; the score is derived from the verdicts (fraction SUPPORTED)
        when the judge omits it or returns a non-number. Malformed model
        output does not raise. An empty reply or unparseable output yields
        {"claims": [], "faithfulness_score": None, "note": ...}, plus "raw"
        when the reply could not be parsed. With no sources, returns a
        score of 1.0 without calling the API.
    """
    if not sources_used:
        return {"claims": [], "faithfulness_score": 1.0,
                "note": "No sources retrieved; generic fallback text used."}

    source_text = "\n\n".join(f"[{s['title']}]\n{s['summary']}" for s in sources_used)
    user_prompt = f"""GENERATED TEXT:
{generated_text}

SOURCE EXCERPTS:
{source_text}

Evaluate faithfulness now. Respond with JSON only."""

    if client is None:
        from groq import Groq

        client = Groq()

    response = client.chat.completions.create(
        model=model,
        max_tokens=2000,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": FAITHFULNESS_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        # Note: response_format={"type": "json_object"} is intentionally omitted —
        # it has been unreliable with this model (see Groq community reports of
        # json_validate_failed / structured outputs being ignored on gpt-oss-120b).
        # Plain prompt-based JSON instructions + the parsing below are more robust.
    )
    raw = (response.choices[0].message.content or "").strip()
    if not raw:
        return {"claims": [], "faithfulness_score": None,
                "note": "Groq returned empty content (model spent its token budget on internal "
                        "reasoning). Try increasing max_tokens further."}

    # Tolerate markdown code fences despite the prompt asking for none.
    raw = raw.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    parsed = _parse_json_object(raw)
    if parsed is None:
        return {"claims": [], "faithfulness_score": None,
                "note": "Could not parse evaluator output", "raw": raw}

    claims = parsed.setdefault("claims", [])
    score = parsed.get("faithfulness_score")
    score_is_number = isinstance(score, (int, float)) and not isinstance(score, bool)
    if not score_is_number and isinstance(claims, list) and claims:
        supported = sum(
            1 for c in claims if isinstance(c, dict) and c.get("verdict") == "SUPPORTED"
        )
        parsed["faithfulness_score"] = supported / len(claims)
    parsed.setdefault("faithfulness_score", None)
    return parsed


if __name__ == "__main__":
    retriever = LiteratureRetriever()
    print(f"Backend: {retriever.backend}, corpus size: {len(retriever.corpus)}")

    fake_bands = {
        "semantic_fluency": "below_typical_range",
        "phonemic_fluency": "within_typical_range",
        "lexical_diversity": "within_typical_range",
        "idea_density": "below_typical_range",
        "syntactic_complexity": "within_typical_range",
    }
    fake_profile = {"bands": fake_bands}
    context = build_retrieved_context(retriever, fake_profile)
    for marker, docs in context.items():
        print(f"\n--- {marker} ---")
        for d in docs:
            print("  ", d["title"])