# post-n-RAG / rag_ui.py
# Author: MarlonKegel — commit 8cd5cc6: "added diversity cap when searching across all sources"
# activate venv: source .venv/bin/activate
### To stage, commit, and push after edits, save file and then run in terminal:
## stage
# git add rag_ui.py
## commit
# git commit -m "EXPLAIN CHANGES"
## push
# git push hf main
import streamlit as st
st.set_page_config(page_title="Post-Neoliberalism Literature RAG", layout="centered")
import os
import json
import numpy as np
import faiss
from openai import OpenAI
import re
import gzip
from huggingface_hub import hf_hub_download
from rank_bm25 import BM25Okapi
import io
from docx import Document
import hashlib
import math
# Caching for search results function
@st.cache_data(show_spinner=False, max_entries=256)
def cached_search(query, chunk_idx_pool_tuple, n_final):
return hybrid_search(query, chunk_idx_pool=list(chunk_idx_pool_tuple) if chunk_idx_pool_tuple else None, n_final=n_final)
############### TOKENIZER AND NORM FUNCTION ##############
def query_tokenize(text):
return re.findall(r"\w+", text.lower())
def l2_normalize(vecs, axis=1, epsilon=1e-10):
norms = np.linalg.norm(vecs, ord=2, axis=axis, keepdims=True)
return vecs / (norms + epsilon)
############# DOWNLOAD DATA AND INDEX ##############
print("Checking /tmp/ directory...")
print("Exists?", os.path.exists("/tmp"))
print("Writeable?", os.access("/tmp", os.W_OK))
print("Listing:", os.listdir("/tmp"))
HF_USERNAME = "mkegel"
HF_REPONAME = "post_n_RAG_chunks"
chunks_gz = hf_hub_download(
repo_id=f"{HF_USERNAME}/{HF_REPONAME}",
filename="zotero_chunks_with_embeddings.json.gz",
repo_type="dataset"
)
faiss_gz = hf_hub_download(
repo_id=f"{HF_USERNAME}/{HF_REPONAME}",
filename="zotero_chunks.index.gz",
repo_type="dataset"
)
### PARAMETERS ###
EMBED_MODEL = "text-embedding-3-large"
TOPK_SPARSE = 20
TOPK_DENSE = 20
CONTEXT_CHUNKS = 15
REASONING_MODELS = {"o3", "o4-mini"} # Models using responses endpoint and reasoning
TEMPERATURE_MODELS = {"gpt-4.1", "gpt-4.1-mini"} # Models using completions endpoint with temp
# --- Load Data ---
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@st.cache_resource(show_spinner=True)
def load_search_data():
with gzip.open(chunks_gz, "rt", encoding="utf-8") as f:
chunks = json.load(f)
with gzip.open(faiss_gz, "rb") as fidx:
with open("/tmp/zotero_chunks.index", "wb") as fout:
fout.write(fidx.read())
faiss_index = faiss.read_index("/tmp/zotero_chunks.index")
# get tokens for BM25
tokenized_texts = [c["tokens"] for c in chunks]
bm25 = BM25Okapi(tokenized_texts)
return chunks, faiss_index, bm25
chunks, faiss_index, bm25 = load_search_data()
# --- Utility to build author/title list for dropdown ---
def primary_author(authors_string):
# Parse "Lastname, Firstname; ..." and return the first last name (for sorting)
if not authors_string: return ""
match = re.match(r"([^,; ]+)", authors_string.strip())
return (match.group(1) or "").strip().lower() if match else authors_string.strip().lower()
def make_source_label(chunk):
meta = chunk["metadata"]
first_author_last = primary_author(meta["authors"])
return f"{first_author_last.title()}, {meta['authors']} - \"{meta['title']}\" ({meta['year']})"
source_groups = {} # (author_last, title) -> [chunk_indices]
for idx, c in enumerate(chunks):
last = primary_author(c["metadata"]["authors"])
title = c["metadata"]["title"]
key = (last, title)
if key not in source_groups:
source_groups[key] = []
source_groups[key].append(idx)
sources_sorted = sorted(source_groups.keys(), key=lambda x: (x[0], x[1].lower()))
source_labels = [f"{author.title()} - \"{title}\"" for author, title in sources_sorted]
source_key_map = dict(zip(source_labels, sources_sorted)) # Map label to (author_last, title)
# --- Retrieval Functions ---
########### BM25-BASED SPARSE SEARCH ###########
def sparse_search(query, chunk_idx_pool=None, k=TOPK_SPARSE):
query_tokens = query_tokenize(query)
if chunk_idx_pool is None:
scores = bm25.get_scores(query_tokens)
idxs = np.argsort(scores)[::-1][:k]
return idxs, np.array(scores)[idxs]
else:
scores = bm25.get_batch_scores(query_tokens, chunk_idx_pool)
idxs = np.argsort(scores)[::-1][:k]
idxs = [chunk_idx_pool[i] for i in idxs]
scores = np.array(scores)[np.argsort(scores)[::-1][:k]]
return idxs, scores
########### DENSE (COSINE) RETRIEVAL ##############
def dense_search(query, chunk_idx_pool=None, k=TOPK_DENSE, model=EMBED_MODEL):
# Query embedding and L2 normalization
resp = client.embeddings.create(input=query, model=model)
emb = np.array(resp.data[0].embedding, dtype="float32").reshape(1, -1)
emb = l2_normalize(emb, axis=1)
if chunk_idx_pool is not None:
# Pool-specific embeddings (normalized)
chunk_embs = np.array([chunks[i]['embedding'] for i in chunk_idx_pool], dtype='float32')
chunk_embs = l2_normalize(chunk_embs, axis=1)
faiss_subindex = faiss.IndexFlatL2(emb.shape[1])
faiss_subindex.add(chunk_embs)
dists, ranks = faiss_subindex.search(emb, k)
idxs = [chunk_idx_pool[i] for i in ranks[0]]
return idxs, dists[0]
else:
# All-vector index: assumed already using normalized embeddings
dists, ranks = faiss_index.search(emb, k)
return ranks[0], dists[0]
def hybrid_search(query, chunk_idx_pool=None, k_sparse=TOPK_SPARSE, k_dense=TOPK_DENSE, n_final=CONTEXT_CHUNKS):
sparse_idx, sparse_scores = sparse_search(query, chunk_idx_pool, k=k_sparse)
dense_idx, dense_dists = dense_search(query, chunk_idx_pool, k=k_dense)
# Ensure 1D numpy arrays
sparse_idx = np.atleast_1d(sparse_idx)
sparse_scores = np.atleast_1d(sparse_scores)
dense_idx = np.atleast_1d(dense_idx)
dense_dists = np.atleast_1d(dense_dists)
all_idx = set(sparse_idx) | set(dense_idx)
# RRF computation
k_rrf = 60 # adjust as needed (RRF constant)
sparse_ranks = {idx: rank for rank, idx in enumerate(sparse_idx)}
dense_ranks = {idx: rank for rank, idx in enumerate(dense_idx)}
hybrid_scores = {}
for idx in all_idx:
rr_bm25 = 1 / (k_rrf + sparse_ranks.get(idx, 9999))
rr_dense = 1 / (k_rrf + dense_ranks.get(idx, 9999))
hybrid_scores[idx] = rr_bm25 + rr_dense
best_idxs = sorted(hybrid_scores, key=hybrid_scores.get, reverse=True)[:n_final]
# Add preceding/following chunks for the top 3
extra_idxs = set()
for rank_idx in best_idxs[:3]:
chunk = chunks[rank_idx]
pid = chunk['paper_id']
cid = chunk['chunk_id']
for offset in [-1, 1]:
neighbor_id = cid + offset
neighbor = next((i for i in range(len(chunks))
if chunks[i]['paper_id'] == pid and chunks[i]['chunk_id'] == neighbor_id
and (chunk_idx_pool is None or i in chunk_idx_pool)), None)
if neighbor is not None:
extra_idxs.add(neighbor)
all_final_idxs = list(dict.fromkeys(list(best_idxs) + list(extra_idxs)))
selected_chunks = []
source_counts = {}
author_counts = {}
if chunk_idx_pool is None: # Only apply capping when searching all sources
max_per_source = math.ceil(n_final * 0.5)
max_per_author = math.ceil(n_final * 0.7)
else:
# If subset, no caps
max_per_source = max_per_author = n_final
for i in all_final_idxs:
if i < len(chunks) and (chunk_idx_pool is None or i in chunk_idx_pool):
chunk = chunks[i]
meta = chunk["metadata"]
source_id = (meta.get("title", ""), meta.get("authors", "")) # By title & authors (source)
author_id = meta.get("authors", "")
# Count how many from this source and author so far
s_count = source_counts.get(source_id, 0)
a_count = author_counts.get(author_id, 0)
# Enforce cap only if no source filter
if s_count >= max_per_source or a_count >= max_per_author:
continue
rationale = []
sparse_rank = sparse_ranks.get(i)
dense_rank = dense_ranks.get(i)
combined_rank = list(sorted(hybrid_scores, key=hybrid_scores.get, reverse=True)).index(i) if i in hybrid_scores else None
if sparse_rank is not None and sparse_rank < 3:
rationale.append("high sparse similarity (BM25 rank top-3)")
if dense_rank is not None and dense_rank < 3:
rationale.append("high dense similarity (embedding rank top-3)")
if combined_rank is not None and combined_rank < 3:
rationale.append("high combined score (RRF top-3)")
selected_chunk = dict(chunk) # shallow copy, to avoid mutating source
selected_chunk["retrieval_rationale"] = rationale if rationale else ["selected via hybrid search"]
selected_chunks.append(selected_chunk)
# Update counts
source_counts[source_id] = s_count + 1
author_counts[author_id] = a_count + 1
# Stop early if we have enough
if len(selected_chunks) >= n_final:
break
# --- Sort so that, within each paper_id, chunk_id is ascending ---
selected_chunks.sort(key=lambda c: (c['paper_id'], c['chunk_id']))
return selected_chunks
def build_context_prompt(selected_chunks):
out = []
for i, c in enumerate(selected_chunks, 1):
meta = c["metadata"]
citation = f'[{i}] Source: "{meta["title"]}" ({meta["authors"]}, {meta["year"]})'
chunk_info = f"(Chunk {c.get('chunk_id', '')}, Section: {c.get('section', '')})"
out.append(f"{citation} {chunk_info}\n{c['text'][:850]}{'...' if len(c['text'])>850 else ''}")
return "\n\n---\n\n".join(out)
def ask_llm(user_query, context_texts, model, temperature=0.3, reasoning_effort=None, max_output_tokens=1500):
prompt = f"""You are a helpful and rigorous research assistant. You assist social scientists in analyzing and synthesizing academic literature to answer research questions.
You are provided with CONTEXT from academic sources. Use only this information to answer the USER QUESTION. When referencing the context, quote the text directly and **always cite the source** using the following format: (Title, First Author, Year, Chunk #).
Your answer should be:
- Accurate, concise, and well-organized
- Written in coherent, formal academic prose
- Analytical in tone (aim to help users think critically about the literature)
- Grounded **strictly** in the provided context (do not add external knowledge)
Avoid:
- Bulleted lists
- Repetition
- Speculation beyond the given context
---
CONTEXT:
{context_texts}
USER QUESTION:
{user_query}
Answer:
"""
system_msg = (
"You are a research assistant helping social scientists understand and synthesize academic literature. Respond only based on the provided chunks of academic content. Always quote and cite your sources, using this format: (Title, First Author, Year, Chunk #). Your goal is to help clarify and connect insights from the literature with precision and depth."
)
if model in REASONING_MODELS:
reasoning_dict = None
if reasoning_effort is not None:
reasoning_dict = {"effort": reasoning_effort}
try:
resp = client.responses.create(
model=model,
input=[
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt}
],
reasoning=reasoning_dict if reasoning_dict else None,
max_output_tokens=max_output_tokens,
)
# Check for truncated outputs
answer = resp.output_text.strip() if resp.output_text else ""
if resp.status == "incomplete" and hasattr(resp, "incomplete_details") and \
getattr(resp.incomplete_details, "reason", None) == "max_output_tokens":
answer += "\n\n[Warning: Response was cut off due to reaching the maximum output length. Try refining your question or reducing context size to get a more complete answer.]"
return answer
except Exception as e:
return f"Model `{model}` call failed: {e}"
else: # Use chat completions endpoint
try:
completions = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_msg},
{"role": "user", "content": prompt}
],
temperature=temperature,
max_tokens=max_output_tokens,
)
return completions.choices[0].message.content.strip()
except Exception as e:
return f"Model `{model}` call failed: {e}"
# --- Pricing Table (per 1M tokens): USD ---
MODEL_PRICING = {
"gpt-4.1": {"input": 2.00, "output": 8.00},
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"o3": {"input": 10.00, "output": 40.00},
"o4-mini": {"input": 1.10, "output": 4.40},
}
model_label_map = {
"GPT-4.1": "gpt-4.1",
"GPT-4.1-mini": "gpt-4.1-mini",
"GPT o3 (reasoning)": "o3",
"GPT o4-mini (small reasoning)": "o4-mini",
}
model_friendly_names = list(model_label_map.keys())
# === STREAMLIT UI ===
st.title("Post-Neoliberalism Literature Review Gizmo")
st.markdown("Your question:")
row1_col1, row1_col2 = st.columns([6, 1])
with row1_col1:
question = st.text_area("Your question:", height=80, label_visibility="collapsed",)
with row1_col2:
ask_clicked = st.button("Ask 🔎")
st.markdown("---")
if "history" not in st.session_state:
st.session_state["history"] = []
# --- Settings UI ---
retrieval_col, llm_col = st.columns(2)
with retrieval_col:
st.subheader("Retrieval Settings")
selected_labels = st.multiselect(
"Select sources to search (default is _all_):",
source_labels,
default=[]
)
# chunk_idx_pool definition moves here:
chunk_idx_pool = None
if selected_labels:
selected_keys = [source_key_map[label] for label in selected_labels]
chunk_idx_pool = [i for key in selected_keys for i in source_groups[key]]
context_chunk_count = st.number_input(
"Number of chunks passed on to the LLM:",
min_value=3,
max_value=30,
value=15,
step=1
)
with llm_col:
st.subheader("LLM Settings")
selected_model_name = st.selectbox("Choose an OpenAI model:", model_friendly_names, index=0)
selected_model = model_label_map[selected_model_name]
# Max output tokens UI -- show as "words"
max_output_words = st.number_input(
"Max response length (# of words):",
min_value=50,
max_value=2000,
value=800,
step=50
)
# Advanced controls:
with st.expander("Advanced LLM Controls (Optional)"):
if selected_model not in TEMPERATURE_MODELS:
st.caption("Temperature is only used for GPT-4.1 and GPT-4.1-mini.")
temp_value = st.slider(
"Model randomness (temperature): Lower = more deterministic outputs (only GPT-4.1 and 4.1-mini)",
0.0, 0.5, value=0.3, step=0.05,
disabled=selected_model not in TEMPERATURE_MODELS,
key="temperature_slider"
)
if selected_model not in REASONING_MODELS:
st.caption("Reasoning effort is only used for o3 and o4-mini.")
reasoning_effort = st.selectbox(
"Reasoning effort (only for o3 and o4-mini):",
["default", "low", "medium", "high"],
index=2,
disabled=selected_model not in REASONING_MODELS,
key="reasoning_effort"
)
user_temperature = float(temp_value)
user_reasoning = reasoning_effort if reasoning_effort != "default" else None
# Convert words to tokens for API call (model-aware token multiplier)
if selected_model in REASONING_MODELS:
if user_reasoning == "low":
output_token_multiplier = 7
elif user_reasoning == "medium" or user_reasoning is None:
output_token_multiplier = 12
elif user_reasoning == "high":
output_token_multiplier = 18
else:
output_token_multiplier = 12 # default
else:
output_token_multiplier = 1.5
user_max_output_tokens = int(max_output_words * output_token_multiplier)
# --- Pricing estimate (dollars only) ---
chunk_token = 750 # ~500-600 words per chunk ≈ 750 tokens
input_tok = context_chunk_count * chunk_token + len(question.split()) * 1.3 + 1800
output_tok = user_max_output_tokens
rates = MODEL_PRICING[selected_model]
input_cost = (input_tok / 1_000_000) * rates["input"]
output_cost = (output_tok / 1_000_000) * rates["output"]
total_cost = input_cost + output_cost
# Show price estimate, turn red if over $1
if total_cost > 1:
st.error(f"**API cost estimate for this query:** ${total_cost:.3f}")
else:
st.info(f"**API cost estimate for this query:** ${total_cost:.3f}")
if ask_clicked and question.strip():
with st.spinner("Retrieving and generating answer..."):
# To use caching, chunk_idx_pool must be hashable (convert to tuple)
pool_tuple = tuple(chunk_idx_pool) if chunk_idx_pool is not None else None
relevant_chunks = cached_search(question, pool_tuple, context_chunk_count)
context = build_context_prompt(relevant_chunks)
answer = ask_llm(
question,
context,
model=selected_model,
temperature=user_temperature,
reasoning_effort=user_reasoning,
max_output_tokens=user_max_output_tokens
)
# Save both Q, A, and context chunks in chat history
st.session_state["history"].append({"role": "user", "content": question})
st.session_state["history"].append({
"role": "assistant",
"content": answer,
"context_chunks": relevant_chunks
})
st.header("Answer")
st.markdown(f"**Assistant:**\n\n{answer}")
with st.expander("Show evidence (retrieved chunks)"):
for i, c in enumerate(relevant_chunks, 1):
meta = c["metadata"]
st.write(
f"[{i}] {meta['title']} ({meta['authors']}, {meta['year']}) (Chunk {c.get('chunk_id', '')}):\n"
f"{c['text'][:500]}{'...' if len(c['text']) > 500 else ''}"
)
rationale = c.get('retrieval_rationale', [])
if rationale:
st.caption("Retrieval rationale: " + "; ".join(rationale))
st.markdown("---")
def render_chat_docx(history, with_chunks=True):
doc = Document()
doc.add_heading("Chat History Export", 0)
for turn in history:
if turn["role"] == "user":
doc.add_paragraph("You:", style="List Bullet").add_run(turn["content"]).bold = True
elif turn["role"] == "assistant":
para = doc.add_paragraph("Assistant:", style="List Bullet")
para.add_run(turn["content"])
if with_chunks and "context_chunks" in turn:
doc.add_paragraph("Evidence Chunks:", style="List Number")
for i, c in enumerate(turn["context_chunks"], 1):
meta = c["metadata"]
chunk_text = f"[{i}] {meta['title']} ({meta['authors']}, {meta['year']}) (Chunk {c.get('chunk_id', '')}):\n{c['text'][:400]}{'...' if len(c['text'])>400 else ''}"
doc.add_paragraph(chunk_text, style="List Continue")
return doc
# Layout for Chat History heading with export controls to the right
if "show_download_expander" not in st.session_state:
st.session_state["show_download_expander"] = False
chat_col, dl_col = st.columns([6, 2])
with chat_col:
st.header("Chat History")
with dl_col:
if st.button("**DOWNLOAD HISTORY**"):
st.session_state["show_download_expander"] = True
if st.session_state.get("show_download_expander", False):
with st.expander("Export options", expanded=True):
include_chunks = st.checkbox("Include context chunks in download", value=True, key="include_chunks_dl")
doc = render_chat_docx(st.session_state["history"], with_chunks=include_chunks)
tmpfile = io.BytesIO()
doc.save(tmpfile)
st.download_button(
label="Download DOCX",
data=tmpfile.getvalue(),
file_name="chat_history.docx",
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
key="download_chat_docx"
)
for turn in st.session_state["history"]:
if turn["role"] == "user":
st.write(f"**You:** {turn['content']}")
elif turn["role"] == "assistant":
st.write(f"**Assistant:** {turn['content']}")
# Show evidence as expandable if available
if "context_chunks" in turn:
with st.expander("Show retrieved chunks", expanded=False):
for i, c in enumerate(turn["context_chunks"], 1):
meta = c["metadata"]
st.write(
f"[{i}] {meta['title']} ({meta['authors']}, {meta['year']}) (Chunk {c.get('chunk_id', '')}):\n"
f"{c['text'][:500]}{'...' if len(c['text']) > 500 else ''}"
)
rationale = c.get('retrieval_rationale', [])
if rationale:
st.caption("Retrieval rationale: " + "; ".join(rationale))