# CardioGenRAG / rag.py
import os
import time
from pathlib import Path
from typing import List, Tuple, Optional
import tiktoken
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import fitz # PyMuPDF
# -----------------------------
# Config
# -----------------------------
PROMPT_PATH = Path("prompts/cardio_prompt.txt")
DATA_DIR = Path("data/papers")
VSTORE_DIR = Path("vectorstore")
LLAMA_CTX = int(os.getenv("LLAMA_CTX", "2048"))
ENC = tiktoken.get_encoding("cl100k_base")
def _num_tokens(text: str) -> int:
return len(ENC.encode(text))
class LLMBackend:
"""Backends: 'hf_local' (Transformers CPU), 'llamacpp' (local GGUF), 'lmstudio' (OpenAI-compatible)."""
    def __init__(self, kind: str, temperature: float = 0.2, model_path: Optional[str] = None, endpoint: Optional[str] = None, n_ctx: int = LLAMA_CTX):
self.kind = kind
self.temperature = temperature
self.model_path = model_path
self.endpoint = endpoint
self.n_ctx = n_ctx
self._hf_pipe = None # cached transformers pipeline
@classmethod
def from_hf_local(cls, temperature: float = 0.2):
return cls(kind="hf_local", temperature=temperature)
@classmethod
def from_lmstudio(cls, temperature: float = 0.2):
endpoint = os.getenv("LMSTUDIO_ENDPOINT", "http://localhost:1234/v1")
return cls(kind="lmstudio", temperature=temperature, endpoint=endpoint)
@classmethod
def from_llamacpp(cls, model_path: str, temperature: float = 0.2, n_ctx: int = LLAMA_CTX):
return cls(kind="llamacpp", temperature=temperature, model_path=model_path, n_ctx=n_ctx)
def _ensure_hf_pipe(self):
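        """Lazily build and cache a CPU text-generation pipeline (model id from HF_LOCAL_MODEL, default google/gemma-2-2b-it)."""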
if self._hf_pipe is None:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
model_id = os.getenv("HF_LOCAL_MODEL", "google/gemma-2-2b-it")
tok = AutoTokenizer.from_pretrained(model_id)
mdl = AutoModelForCausalLM.from_pretrained(model_id)
self._hf_pipe = pipeline("text-generation", model=mdl, tokenizer=tok, device=-1)
return self._hf_pipe
def generate(self, prompt: str, max_tokens: int = 120) -> str:
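        """Generate a completion for `prompt`, dispatching on the configured backend kind."""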
if self.kind == "hf_local":
pipe = self._ensure_hf_pipe()
out = pipe(
prompt,
max_new_tokens=max_tokens,
do_sample=True,
temperature=self.temperature
)
text = out[0]["generated_text"]
# return only the new continuation if the model echoes the prompt
return text.split(prompt, 1)[-1].strip()
elif self.kind == "lmstudio":
import requests
payload = {
"model": "local-model",
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": max_tokens,
}
r = requests.post(f"{self.endpoint}/chat/completions", json=payload, timeout=120)
r.raise_for_status()
return r.json()["choices"][0]["message"]["content"]
else: # llamacpp
from llama_cpp import Llama
llm = Llama(
model_path=self.model_path,
n_ctx=self.n_ctx,
n_threads=min(8, os.cpu_count() or 4),
verbose=False,
)
out = llm.create_chat_completion(
messages=[{"role": "user", "content": prompt}],
temperature=self.temperature,
max_tokens=max_tokens,
)
return out["choices"][0]["message"]["content"]
class CardioRAG:
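    """Chroma-backed RAG over local cardiology PDFs: ingestion, retrieval, prompt building, and answering."""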
def __init__(self):
self.client = chromadb.PersistentClient(
path=str(VSTORE_DIR),
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
self.collection = self.client.get_or_create_collection(name="papers")
self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Utilities
def list_sources(self) -> List[str]:
return sorted([p.name for p in DATA_DIR.glob("*.pdf")])
# Ingestion
def rebuild_index(self):
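        """Wipe the persistent Chroma store and re-ingest every PDF under DATA_DIR."""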
self.client.reset()
self.collection = self.client.get_or_create_collection(name="papers")
self.ingest_directory(DATA_DIR)
def ingest_directory(self, dir_path: Path):
pdfs = list(dir_path.glob("*.pdf"))
for pdf in pdfs:
self._ingest_pdf(pdf)
def _extract_pdf(self, pdf_path: Path) -> List[Tuple[str, str]]:
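        """Extract text with PyMuPDF, returning one (section_name, text) pair per page."""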
doc = fitz.open(pdf_path)
sections = []
for i, page in enumerate(doc):
text = page.get_text("text")
sections.append((f"page {i+1}", text))
return sections
def _chunk(self, text: str, chunk_size_tokens: int = 300, overlap_tokens: int = 30) -> List[str]:
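        """Split text into ~chunk_size_tokens cl100k_base token windows, with a small overlap between neighbours."""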
tokens = ENC.encode(text)
chunks = []
start = 0
while start < len(tokens):
end = min(start + chunk_size_tokens, len(tokens))
chunk = ENC.decode(tokens[start:end])
chunks.append(chunk)
start = end - overlap_tokens
if start < 0:
start = 0
if end == len(tokens):
break
return chunks
def _ingest_pdf(self, pdf_path: Path):
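        """Chunk one PDF, embed the chunks with MiniLM, and add them to the Chroma collection.

        Pages with fewer than 200 characters are skipped (typically scanned or near-empty pages).
        """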
sections = self._extract_pdf(pdf_path)
docs, ids, metas = [], [], []
for idx, (section_name, text) in enumerate(sections):
text = text.strip().replace("\u00ad", "")
if len(text) < 200:
continue
for cidx, chunk in enumerate(self._chunk(text)):
docs.append(chunk)
ids.append(f"{pdf_path.name}-{idx}-{cidx}")
metas.append({"source": pdf_path.name, "section": section_name})
if docs:
embeddings = self.embedder.encode(docs, show_progress_bar=False).tolist()
self.collection.add(documents=docs, embeddings=embeddings, ids=ids, metadatas=metas)
# Retrieval
def _retrieve(self, query: str, top_k: int = 4, source_filter: Optional[str] = None):
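        """Embed the query and return up to top_k (document, metadata) pairs, optionally filtered to one source PDF."""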
q_emb = self.embedder.encode([query], show_progress_bar=False).tolist()[0]
where = {"source": source_filter} if source_filter else None
res = self.collection.query(query_embeddings=[q_emb], n_results=top_k, where=where)
docs = res.get("documents", [[]])[0]
metas = res.get("metadatas", [[]])[0]
return list(zip(docs, metas))
# Prompt builder
def _build_prompt(self, query: str, reading_level: str, retrieved: List[Tuple[str, dict]], max_gen_tokens: int) -> str:
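        """Fill the prompt template, then drop or truncate retrieved blocks until the prompt fits the token budget."""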
        prompt_tmpl = PROMPT_PATH.read_text()
blocks = []
for doc, meta in retrieved:
source = f"{meta.get('source')} ({meta.get('section')})"
content = (doc or "")[:1600]
blocks.append(f"[SOURCE: {source}]\n{content}")
def assemble(bs: List[str]) -> str:
return prompt_tmpl.format(
reading_level=reading_level,
user_query=query,
context="\n\n".join(bs),
)
bs = blocks[:]
prompt = assemble(bs)
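        # Budget the prompt at ~90% of the LLAMA_CTX window minus the generation allowance,
        # with a floor of 512 tokens so very small contexts still leave room for retrieved text.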
budget = int(LLAMA_CTX * 0.9) - max_gen_tokens
budget = max(budget, 512)
while _num_tokens(prompt) > budget and len(bs) > 1:
bs.pop()
prompt = assemble(bs)
while _num_tokens(prompt) > budget and bs:
bs[-1] = bs[-1][: max(0, len(bs[-1]) - 300)]
prompt = assemble(bs)
return prompt
# Answer
def answer(self, query: str, top_k: int, reading_level: str, llm: LLMBackend, source_filter: Optional[str] = None):
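        """Retrieve context, build the prompt, generate an answer, and return (markdown, source_labels)."""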
t0 = time.time()
retrieved = self._retrieve(query=query, top_k=top_k, source_filter=source_filter)
if not retrieved:
return (
"### Cardio Summary\nI couldn’t retrieve any context from your indexed papers. "
"Check that you’ve ingested at least one **text-based** PDF and rerun `python ingest.py`.",
[]
)
t1 = time.time()
max_gen = 120
prompt = self._build_prompt(query=query, reading_level=reading_level, retrieved=retrieved, max_gen_tokens=max_gen)
t2 = time.time()
        answer_text = llm.generate(prompt, max_tokens=max_gen)
        t3 = time.time()
        md = f"### Cardio Summary\n{answer_text}\n"
sources = [f"{m.get('source')} ({m.get('section')})" for _, m in retrieved]
print(f"[timings] retrieve={t1-t0:.2f}s build_prompt={t2-t1:.2f}s generate={t3-t2:.2f}s")
return md, sources
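

# Illustrative ad-hoc entry point for local testing; the query, reading level, and backend
# choice below are placeholder assumptions, not values shipped with the app.
if __name__ == "__main__":
    rag = CardioRAG()
    if not rag.list_sources():
        print("No PDFs found in data/papers - add some and rerun ingestion first.")
    else:
        llm = LLMBackend.from_hf_local(temperature=0.2)
        md, sources = rag.answer(
            "Summarise the key findings on hypertension management.",  # example query
            top_k=4,
            reading_level="general audience",  # placeholder reading level
            llm=llm,
        )
        print(md)
        print("Sources:", sources)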