# NeuroLens / rag_engine.py
# Commit 8dccdbf by Kshamaa S: "Initial deployment: NeuroLens cognitive health screening pipeline"
# (header copied from the repository file view; file size at that commit: 14.2 kB)
"""
rag_engine.py — Stage 3 of NeuroLens: Grounded, Personalized Prevention Coaching

Retrieves relevant literature from a curated corpus (data/corpus.json) based on
which linguistic biomarkers from Stage 2 stood out, then generates a
citation-grounded coaching summary via Groq's inference API. Includes a
RAGAS-style faithfulness check on the generated text against the retrieved
sources.

Embedding backend:
  - Primary (production / Hugging Face Spaces): sentence-transformers
    'intfloat/e5-large-v2' (+ FAISS when installed) for high-quality semantic
    retrieval.
  - Fallback (offline / restricted environments, e.g. sandboxed dev):
    scikit-learn TF-IDF + cosine similarity. Used automatically if
    sentence-transformers or its model weights aren't available, so the
    pipeline still runs end-to-end without access to the Hugging Face Hub.

The backend is chosen inside LiteratureRetriever, which keeps app.py agnostic
to which one is active.
"""
import json
import logging
import os
from pathlib import Path

import numpy as np
# Curated literature corpus, resolved relative to this module's directory.
CORPUS_PATH = Path(__file__).parent.joinpath("data", "corpus.json")

# Stage 2 biomarker -> literature domains most useful for addressing a
# below-typical result on that marker. Retrieval is biased toward these
# domains so coaching draws on actionable content, not merely on whatever is
# closest in surface wording to the assessment text.
MARKER_TO_DOMAINS = dict(
    semantic_fluency=["cognitive_training", "social_engagement"],
    phonemic_fluency=["exercise_cognitive_reserve", "diet_nutrition"],
    lexical_diversity=["cognitive_reserve_early_life", "cognitive_training"],
    idea_density=["cognitive_reserve_early_life", "social_engagement"],
    syntactic_complexity=["exercise_cognitive_reserve", "diet_nutrition"],
)
def load_corpus(path=None) -> list:
    """Load the curated literature corpus from a JSON file.

    Args:
        path: Optional path to a corpus JSON file. Defaults to ``CORPUS_PATH``
            (``data/corpus.json`` next to this module).

    Returns:
        The decoded JSON document. Presumably a list of dicts; the rest of
        this module reads their ``id``, ``title``, ``summary``, ``source`` and
        ``domain`` keys. This is not validated here.
    """
    corpus_path = CORPUS_PATH if path is None else path
    # Read explicitly as UTF-8. The platform default encoding (e.g. cp1252 on
    # Windows) can fail on, or mis-decode, non-ASCII text in titles/summaries.
    with open(corpus_path, "r", encoding="utf-8") as f:
        return json.load(f)
class LiteratureRetriever:
    """Retriever over the curated literature corpus.

    The similarity backend is chosen once, at construction time:

    * ``"sentence-transformers"`` / ``"sentence-transformers+faiss"``:
      normalized e5-large-v2 embeddings, so a dot product is a cosine
      similarity. If ``faiss`` is importable, an ``IndexFlatIP`` is also built
      and kept on ``self._faiss_index``. Ranking itself uses an exact NumPy dot
      product over the same embeddings.
    * ``"tfidf_fallback"``: scikit-learn TF-IDF + cosine similarity. Used
      whenever the neural backend cannot be initialized (package missing,
      weights not downloadable, ...). The reason is logged as a warning.

    Attributes:
        corpus: document dicts as returned by ``load_corpus()``.
        texts: ``"title. summary"`` strings, one per corpus entry (what gets indexed).
        backend: name of the active backend (see above).
    """

    def __init__(self):
        self.corpus = load_corpus()
        self.texts = [f"{d['title']}. {d['summary']}" for d in self.corpus]
        self.backend = None
        self._init_backend()

    def _init_backend(self):
        """Initialize the neural backend if possible, else the TF-IDF fallback."""
        if self._try_init_sentence_transformers():
            return
        # Fallback: TF-IDF (always available, no network/model download required).
        from sklearn.feature_extraction.text import TfidfVectorizer

        self._vectorizer = TfidfVectorizer(stop_words="english")
        self._tfidf_matrix = self._vectorizer.fit_transform(self.texts)
        self.backend = "tfidf_fallback"

    def _try_init_sentence_transformers(self) -> bool:
        """Set up the e5 (+ optional FAISS) backend; return False if it is unavailable."""
        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("intfloat/e5-large-v2")
            # e5 models expect "passage: " / "query: " prefixes for best results.
            passage_texts = [f"passage: {t}" for t in self.texts]
            embeddings = model.encode(passage_texts, normalize_embeddings=True)
        except Exception as err:  # deliberate best-effort: any failure -> TF-IDF
            logging.getLogger(__name__).warning(
                "sentence-transformers backend unavailable (%s: %s); "
                "falling back to TF-IDF retrieval.",
                type(err).__name__,
                err,
            )
            return False

        self._st_model = model
        self._embeddings = np.array(embeddings, dtype="float32")
        self.backend = "sentence-transformers"
        self._faiss_index = None

        # FAISS is optional. Failing to use it must not discard the working
        # embedding model (previously a non-ImportError here forced TF-IDF).
        try:
            import faiss

            index = faiss.IndexFlatIP(self._embeddings.shape[1])
            index.add(self._embeddings)
        except ImportError:
            return True
        except Exception as err:
            logging.getLogger(__name__).warning(
                "Could not build FAISS index (%s: %s); continuing without it.",
                type(err).__name__,
                err,
            )
            return True
        self._faiss_index = index
        self.backend = "sentence-transformers+faiss"
        return True

    def _similarities(self, query: str, pool):
        """Return a 1-D array of similarity scores for *query*.

        Scores cover the corpus rows listed in *pool*, or every row when *pool* is None.
        """
        if self.backend.startswith("sentence-transformers"):
            q_emb = self._st_model.encode([f"query: {query}"], normalize_embeddings=True)
            q_emb = np.array(q_emb, dtype="float32")
            matrix = self._embeddings if pool is None else self._embeddings[pool]
            return matrix @ q_emb[0]
        from sklearn.metrics.pairwise import cosine_similarity

        q_vec = self._vectorizer.transform([query])
        matrix = self._tfidf_matrix if pool is None else self._tfidf_matrix[pool]
        return cosine_similarity(q_vec, matrix).flatten()

    def _retrieve_indices(self, query: str, k: int, candidate_idx=None) -> list:
        """Return up to *k* corpus indices ranked by similarity to *query*.

        Args:
            query: free-text retrieval query.
            k: maximum number of indices to return; ``k <= 0`` yields ``[]``.
            candidate_idx: optional sequence of corpus indices to rank within.
                When None or empty, the whole corpus is ranked.

        Returns:
            Plain ``int`` corpus indices, best first. Ties keep corpus order
            (stable sort), so results are deterministic.
        """
        if k <= 0:
            return []
        has_pool = candidate_idx is not None and len(candidate_idx) > 0
        pool = list(candidate_idx) if has_pool else None
        sims = np.asarray(self._similarities(query, pool))
        order = np.argsort(-sims, kind="stable")[:k]
        if pool is None:
            return [int(i) for i in order]
        return [int(pool[i]) for i in order]

    def retrieve_for_domains(self, query: str, domains: list, k: int = 2) -> list:
        """Return up to *k* corpus docs for *query*, restricted to *domains*.

        If no corpus document has a ``domain`` in *domains*, the whole corpus
        is searched instead of returning nothing.
        """
        candidate_idx = [i for i, d in enumerate(self.corpus) if d["domain"] in domains]
        idx = self._retrieve_indices(query, k, candidate_idx=candidate_idx or None)
        return [self.corpus[i] for i in idx]
def select_retrieval_plan(bands: dict) -> dict:
    """Pick the markers worth coaching on and the literature domains for each.

    Args:
        bands: Stage 2 band labels, mapping marker -> 'below_typical_range',
            'within_typical_range' or 'above_typical_range'.

    Returns:
        ``{marker: [domain, ...]}`` for each marker that is below the typical
        range AND has an entry in MARKER_TO_DOMAINS. All other markers are skipped.
    """
    return {
        marker: MARKER_TO_DOMAINS[marker]
        for marker, band in bands.items()
        if band == "below_typical_range" and marker in MARKER_TO_DOMAINS
    }
def build_retrieved_context(retriever: LiteratureRetriever, profile: dict, k_per_marker: int = 2) -> dict:
    """Retrieve supporting literature for every below-range marker in *profile*.

    Args:
        retriever: the corpus retriever to query.
        profile: Stage 2 output; only ``profile["bands"]`` is read.
        k_per_marker: maximum number of documents retrieved per marker.

    Returns:
        ``{marker: [doc, ...]}`` for each marker selected by ``select_retrieval_plan``.
    """
    plan = select_retrieval_plan(profile["bands"])
    # Hand-written retrieval query per marker. A marker missing here falls
    # back to using its own name as the query text.
    marker_queries = {
        "semantic_fluency": "interventions to improve semantic memory and category fluency",
        "phonemic_fluency": "exercise and vascular health for executive function and word retrieval",
        "lexical_diversity": "building cognitive reserve through lifelong learning and engagement",
        "idea_density": "early and ongoing intellectual engagement and cognitive reserve",
        "syntactic_complexity": "physical activity and diet supporting language processing and cognitive load",
    }
    return {
        marker: retriever.retrieve_for_domains(
            marker_queries.get(marker, marker), domains, k=k_per_marker
        )
        for marker, domains in plan.items()
    }
# ---------------------------------------------------------------------------
# Generation — uses Groq's inference API running an open-weight model
# (openai/gpt-oss-120b). See README.md for API key setup.
# ---------------------------------------------------------------------------
# System prompt for coaching generation. It is sent to the model verbatim;
# trailing backslashes join physical lines inside the literal.
COACHING_SYSTEM_PROMPT = """You are a brain-health research communicator writing for NeuroLens, \
a non-clinical research demonstration prototype. You write supportive, non-alarmist, \
evidence-grounded coaching summaries for users based on their performance on simplified \
cognitive-linguistic tasks.
Strict rules:
- Only state claims that are directly supported by the provided source excerpts. Do not \
add outside knowledge or invent statistics.
- Every factual claim about research findings must be attributable to one of the provided \
sources; refer to sources naturally (e.g., "Research from the FINGER trial suggests...").
- Never use clinical or diagnostic language (no "you may have," "this indicates early signs of," \
"this is a symptom of"). Frame everything as "your results on this task" and "research on this topic," \
never as a statement about the user's health status.
- Tone: warm, encouraging, plain-language. No alarmism.
- End with one sentence reminding the user this is a research prototype, not a clinical assessment.
- Keep the full response under 250 words.
- Respond with plain prose only — no markdown headers, no preamble like "Here is your summary."
"""
GROQ_MODEL = "openai/gpt-oss-120b"  # see README.md for API key setup at console.groq.com


def generate_coaching(
    profile: dict,
    retrieved_context: dict,
    client=None,
    model: str = GROQ_MODEL,
    max_tokens: int = 1000,
) -> dict:
    """Generate a citation-grounded coaching summary via Groq's chat API.

    Args:
        profile: Stage 2 profile. Currently unused here; kept for interface
            compatibility.
        retrieved_context: ``{marker: [doc, ...]}`` (e.g. from
            ``build_retrieved_context``). Each doc needs ``id``, ``title``,
            ``source`` and ``summary`` keys.
        client: an instantiated ``groq.Groq()`` client. If None, one is created
            from the GROQ_API_KEY environment variable (see README.md), but only
            when a model call is actually needed.
        model: Groq model id.
        max_tokens: completion token budget. Per the empty-content error
            below, the model's internal reasoning draws on this same budget.

    Returns:
        ``{"text": str, "sources_used": [doc, ...]}``, with sources
        de-duplicated by ``id`` in first-seen order.

    Raises:
        RuntimeError: if the model returns empty content.
    """
    if not retrieved_context:
        # Nothing to address: return canned text without touching the API, so
        # this path works even without groq installed or GROQ_API_KEY set.
        return {
            "text": ("Your results across these short tasks fell within the typical ranges used for "
                     "this demo, so there's no specific area to highlight right now. Maintaining a mix "
                     "of physical activity, social engagement, varied learning, sleep, and a nutrient-dense "
                     "diet is broadly supported by the brain-health prevention literature regardless. "
                     "Reminder: NeuroLens is a research prototype, not a clinical assessment."),
            "sources_used": [],
        }

    if client is None:
        from groq import Groq

        client = Groq()

    # The same doc may be retrieved for several markers. Keep the first
    # occurrence (by id) so each excerpt appears once in the prompt and in
    # sources_used.
    unique_sources = {}
    context_blocks = []
    for docs in retrieved_context.values():
        for d in docs:
            if d["id"] in unique_sources:
                continue
            unique_sources[d["id"]] = d
            # " — " separates the title from its citation. Without it the two
            # strings ran together in the source header.
            context_blocks.append(f"[Source: {d['title']} — {d['source']}]\n{d['summary']}")

    context_text = "\n\n".join(context_blocks)
    markers_text = ", ".join(retrieved_context)
    user_prompt = (
        "The user's results on these tasks were below the typical comparison range: "
        f"{markers_text}.\n"
        "Here are the relevant source excerpts to ground your response (use ONLY these):\n"
        f"{context_text}\n"
        "Write the personalized coaching summary now."
    )

    response = client.chat.completions.create(
        model=model,
        max_tokens=max_tokens,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": COACHING_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
    )
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise RuntimeError(
            "Groq returned empty content for coaching generation (model spent its token "
            "budget on internal reasoning). Try increasing max_tokens further."
        )
    return {"text": text, "sources_used": list(unique_sources.values())}
# ---------------------------------------------------------------------------
# Faithfulness evaluation (RAGAS-style) — also via Groq.
# ---------------------------------------------------------------------------
# System prompt for the LLM judge used by check_faithfulness(). It asks for a
# per-claim verdict (SUPPORTED / PARTIAL / UNSUPPORTED) and an overall
# faithfulness_score, returned as bare JSON so the reply can be parsed with
# json.loads. The text is sent verbatim; trailing backslashes join physical
# lines inside the literal, and the braces below are literal example JSON
# (this is not an f-string).
FAITHFULNESS_SYSTEM_PROMPT = """You are a strict fact-checker. You will be given a generated \
coaching text and a set of source excerpts it was supposed to be grounded in.
Break the generated text into individual factual claims about research findings. For each \
claim, determine whether it is directly supported by the source excerpts (SUPPORTED), \
partially supported / overstated relative to the sources (PARTIAL), or not supported / \
hallucinated (UNSUPPORTED).
Respond ONLY with valid JSON in this exact format — no other text, no markdown code fences, \
no explanation before or after the JSON:
{
"claims": [
{"claim": "...", "verdict": "SUPPORTED|PARTIAL|UNSUPPORTED", "reason": "..."}
],
"faithfulness_score": <float between 0 and 1, fraction of claims that are SUPPORTED>
}
Keep each "reason" to one short sentence. Do not show your reasoning process — output the
JSON object directly.
"""
def check_faithfulness(
    generated_text: str,
    sources_used: list,
    client=None,
    model: str = GROQ_MODEL,
    max_tokens: int = 2000,
) -> dict:
    """Ask an LLM judge which claims in *generated_text* the sources support.

    Args:
        generated_text: the coaching text to evaluate.
        sources_used: source docs (each with ``title`` and ``summary``) the
            text was meant to be grounded in.
        client: an instantiated ``groq.Groq()`` client. If None, one is created
            from the environment, but only when a model call is needed.
        model: Groq model id.
        max_tokens: completion token budget for the judge.

    Returns:
        The judge's JSON object, always containing ``claims`` and
        ``faithfulness_score`` keys. If evaluation is skipped or fails, it
        returns ``claims: []`` plus a ``note``. In that case
        ``faithfulness_score`` is 1.0 when there were no sources and None
        otherwise; on a parse failure, ``raw`` also holds the judge output.
    """
    if not sources_used:
        return {"claims": [], "faithfulness_score": 1.0, "note": "No sources retrieved; generic fallback text used."}

    if client is None:
        from groq import Groq

        client = Groq()

    source_text = "\n\n".join(f"[{s['title']}]\n{s['summary']}" for s in sources_used)
    user_prompt = (
        "GENERATED TEXT:\n"
        f"{generated_text}\n"
        "SOURCE EXCERPTS:\n"
        f"{source_text}\n"
        "Evaluate faithfulness now. Respond with JSON only."
    )
    response = client.chat.completions.create(
        model=model,
        max_tokens=max_tokens,
        reasoning_effort="low",
        messages=[
            {"role": "system", "content": FAITHFULNESS_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        # Note: response_format={"type": "json_object"} is intentionally omitted —
        # it has been unreliable with this model (see Groq community reports of
        # json_validate_failed / structured outputs being ignored on gpt-oss-120b).
        # Plain prompt-based JSON instructions + the parsing below are more robust.
    )
    raw = (response.choices[0].message.content or "").strip()
    if not raw:
        return {"claims": [], "faithfulness_score": None,
                "note": "Groq returned empty content (model spent its token budget on internal "
                        "reasoning). Try increasing max_tokens further."}

    parsed = _parse_json_object(raw)
    if parsed is None:
        return {"claims": [], "faithfulness_score": None, "note": "Could not parse evaluator output", "raw": raw}
    # Callers can rely on both keys even if the judge omitted one.
    parsed.setdefault("claims", [])
    parsed.setdefault("faithfulness_score", None)
    return parsed


def _parse_json_object(text: str):
    """Parse *text* as a JSON object, tolerating code fences or prose around it.

    Returns None if no JSON object can be recovered.
    """
    candidate = text.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. when the model adds a
        # preamble or trailing remark despite the instructions.
        start, end = candidate.find("{"), candidate.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            parsed = json.loads(candidate[start:end + 1])
        except json.JSONDecodeError:
            return None
    # Callers index the result as a dict; a bare list/number is unusable.
    return parsed if isinstance(parsed, dict) else None
if __name__ == "__main__":
    # Smoke test: build the retriever and show what it retrieves for a profile
    # with two below-range markers. No Groq calls are made here.
    retriever = LiteratureRetriever()
    print(f"Backend: {retriever.backend}, corpus size: {len(retriever.corpus)}")
    below_markers = {"semantic_fluency", "idea_density"}
    all_markers = (
        "semantic_fluency",
        "phonemic_fluency",
        "lexical_diversity",
        "idea_density",
        "syntactic_complexity",
    )
    fake_bands = {
        name: "below_typical_range" if name in below_markers else "within_typical_range"
        for name in all_markers
    }
    context = build_retrieved_context(retriever, {"bands": fake_bands})
    for name, hits in context.items():
        print(f"\n--- {name} ---")
        for hit in hits:
            print(f"  {hit['title']}")