# rag_ui.py — Streamlit RAG app for the post-neoliberalism literature corpus
# (removed stray Hugging Face Spaces page chrome: "Spaces / Runtime error")
| # activate venv: source .venv/bin/activate | |
| ### To stage, commit, and push after edits, save file and then run in terminal: | |
| ## stage | |
| # git add rag_ui.py | |
| ## commit | |
| # git commit -m "EXPLAIN CHANGES" | |
| ## push | |
| # git push hf main | |
| import streamlit as st | |
| st.set_page_config(page_title="Post-Neoliberalism Literature RAG", layout="centered") | |
| import os | |
| import json | |
| import numpy as np | |
| import faiss | |
| from openai import OpenAI | |
| import re | |
| import gzip | |
| from huggingface_hub import hf_hub_download | |
| from rank_bm25 import BM25Okapi | |
| import io | |
| from docx import Document | |
| import hashlib | |
| import math | |
# Caching for search results function: the call site converts chunk_idx_pool
# to a tuple specifically so the arguments are hashable, but the original had
# no cache decorator — every rerun repeated the embedding call and BM25 pass.
@st.cache_data(show_spinner=False)
def cached_search(query, chunk_idx_pool_tuple, n_final):
    """Cached wrapper around hybrid_search.

    Parameters:
        query: the user's natural-language question.
        chunk_idx_pool_tuple: tuple of chunk indices restricting the search,
            or None/empty to search all sources. Must be a tuple (not a list)
            so Streamlit can hash it as a cache key.
        n_final: number of chunks to return.

    Returns the list of chunk dicts produced by hybrid_search.
    """
    return hybrid_search(
        query,
        chunk_idx_pool=list(chunk_idx_pool_tuple) if chunk_idx_pool_tuple else None,
        n_final=n_final,
    )
| ############### TOKENIZER AND NORM FUNCTION ############## | |
def query_tokenize(text):
    """Lower-case *text* and split it into word tokens (runs of \\w+)."""
    lowered = text.lower()
    return re.findall(r"\w+", lowered)
def l2_normalize(vecs, axis=1, epsilon=1e-10):
    """Scale *vecs* to unit L2 norm along *axis*; epsilon guards against /0."""
    magnitudes = np.linalg.norm(vecs, ord=2, axis=axis, keepdims=True)
    safe_magnitudes = magnitudes + epsilon
    return vecs / safe_magnitudes
############# DOWNLOAD DATA AND INDEX ##############
# Debug prints about the scratch directory; these land in the container log,
# not the Streamlit UI. /tmp is used below to decompress the FAISS index.
print("Checking /tmp/ directory...")
print("Exists?", os.path.exists("/tmp"))
print("Writeable?", os.access("/tmp", os.W_OK))
print("Listing:", os.listdir("/tmp"))
# Hugging Face dataset repo that hosts the prebuilt corpus and index.
HF_USERNAME = "mkegel"
HF_REPONAME = "post_n_RAG_chunks"
# Gzipped JSON: chunk texts + metadata + embeddings + BM25 token lists.
chunks_gz = hf_hub_download(
    repo_id=f"{HF_USERNAME}/{HF_REPONAME}",
    filename="zotero_chunks_with_embeddings.json.gz",
    repo_type="dataset"
)
# Gzipped FAISS index built over the same chunks' embeddings.
faiss_gz = hf_hub_download(
    repo_id=f"{HF_USERNAME}/{HF_REPONAME}",
    filename="zotero_chunks.index.gz",
    repo_type="dataset"
)
### PARAMETERS ###
# Embedding model for query embeddings — presumably the same model used to
# embed the corpus chunks; confirm against the dataset build pipeline.
EMBED_MODEL = "text-embedding-3-large"
TOPK_SPARSE = 20     # BM25 candidates fetched per query
TOPK_DENSE = 20      # dense (embedding) candidates fetched per query
CONTEXT_CHUNKS = 15  # default number of chunks forwarded to the LLM
REASONING_MODELS = {"o3", "o4-mini"}  # Models using responses endpoint and reasoning
TEMPERATURE_MODELS = {"gpt-4.1", "gpt-4.1-mini"}  # Models using completions endpoint with temp
# --- Load Data ---
# OpenAI client; the API key is read from the environment.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@st.cache_resource(show_spinner="Loading corpus and search indexes...")
def load_search_data():
    """Load the chunk corpus, FAISS dense index, and BM25 sparse index.

    Cached with st.cache_resource: without it, every Streamlit rerun (i.e.
    every widget interaction) re-read the gzipped corpus, re-decompressed
    the FAISS index, and rebuilt BM25 from scratch.

    Returns:
        (chunks, faiss_index, bm25) where chunks is the list of chunk dicts,
        faiss_index the deserialized FAISS index, and bm25 a BM25Okapi model.
    """
    with gzip.open(chunks_gz, "rt", encoding="utf-8") as f:
        chunks = json.load(f)
    # faiss.read_index needs a real file path, so decompress to /tmp first.
    with gzip.open(faiss_gz, "rb") as fidx, open("/tmp/zotero_chunks.index", "wb") as fout:
        fout.write(fidx.read())
    faiss_index = faiss.read_index("/tmp/zotero_chunks.index")
    # Chunks ship with pre-tokenized text ("tokens") for BM25.
    tokenized_texts = [c["tokens"] for c in chunks]
    bm25 = BM25Okapi(tokenized_texts)
    return chunks, faiss_index, bm25

chunks, faiss_index, bm25 = load_search_data()
| # --- Utility to build author/title list for dropdown --- | |
def primary_author(authors_string):
    """Return the first author's last name, lower-cased, for sorting.

    *authors_string* follows the "Lastname, Firstname; Lastname, Firstname"
    convention; the leading run of characters up to the first comma,
    semicolon, or space is treated as the last name. Falsy input yields "".
    """
    if not authors_string:
        return ""
    cleaned = authors_string.strip()
    match = re.match(r"([^,; ]+)", cleaned)
    if not match:
        return cleaned.lower()
    return (match.group(1) or "").strip().lower()
def make_source_label(chunk):
    """Build a display label: Lastname, full authors - "Title" (year)."""
    meta = chunk["metadata"]
    lead = primary_author(meta["authors"]).title()
    authors = meta["authors"]
    title = meta["title"]
    year = meta["year"]
    return f'{lead}, {authors} - "{title}" ({year})'
# --- Build dropdown groupings: (author_last, title) -> list of chunk indices ---
source_groups = {}
for idx, c in enumerate(chunks):
    meta = c["metadata"]
    key = (primary_author(meta["authors"]), meta["title"])
    # setdefault replaces the explicit "if key not in ..." membership dance.
    source_groups.setdefault(key, []).append(idx)
# Sort by author last name, then case-insensitive title, for the UI list.
sources_sorted = sorted(source_groups.keys(), key=lambda x: (x[0], x[1].lower()))
source_labels = [f'{author.title()} - "{title}"' for author, title in sources_sorted]
source_key_map = dict(zip(source_labels, sources_sorted))  # Map label to (author_last, title)
| # --- Retrieval Functions --- | |
| ########### BM25-BASED SPARSE SEARCH ########### | |
def sparse_search(query, chunk_idx_pool=None, k=TOPK_SPARSE):
    """BM25 sparse retrieval.

    Parameters:
        query: the user's question (tokenized with query_tokenize).
        chunk_idx_pool: optional list of global chunk indices restricting
            the search; None searches the whole corpus.
        k: number of results to return.

    Returns:
        (idxs, scores): top-k global chunk indices and their BM25 scores,
        highest score first.
    """
    query_tokens = query_tokenize(query)
    if chunk_idx_pool is None:
        scores = bm25.get_scores(query_tokens)
        order = np.argsort(scores)[::-1][:k]
        return order, np.array(scores)[order]
    # Subset search: get_batch_scores aligns positionally with chunk_idx_pool.
    scores = bm25.get_batch_scores(query_tokens, chunk_idx_pool)
    # Compute the descending ranking once (the original redundantly ran the
    # argsort/reverse/slice a second time to index the scores).
    order = np.argsort(scores)[::-1][:k]
    idxs = [chunk_idx_pool[i] for i in order]
    return idxs, np.array(scores)[order]
| ########### DENSE (COSINE) RETRIEVAL ############## | |
def dense_search(query, chunk_idx_pool=None, k=TOPK_DENSE, model=EMBED_MODEL):
    """Dense retrieval via OpenAI embeddings + FAISS.

    Parameters:
        query: the user's question (embedded with *model*).
        chunk_idx_pool: optional list of global chunk indices restricting
            the search; None uses the prebuilt full-corpus index.
        k: number of results to return.
        model: embedding model name.

    Returns:
        (idxs, dists): top-k global chunk indices and their L2 distances.
        All vectors are L2-normalized, so L2 ranking is equivalent to
        cosine-similarity ranking.
    """
    # Embed the query and normalize so distance ranking matches cosine.
    resp = client.embeddings.create(input=query, model=model)
    emb = np.array(resp.data[0].embedding, dtype="float32").reshape(1, -1)
    emb = l2_normalize(emb, axis=1)
    if chunk_idx_pool is not None:
        # Build a throwaway index over just the selected sources' embeddings.
        chunk_embs = np.array([chunks[i]["embedding"] for i in chunk_idx_pool], dtype="float32")
        chunk_embs = l2_normalize(chunk_embs, axis=1)
        faiss_subindex = faiss.IndexFlatL2(emb.shape[1])
        faiss_subindex.add(chunk_embs)
        # BUG FIX: when the pool is smaller than k, FAISS pads the result
        # with -1 ids, and chunk_idx_pool[-1] silently mapped the padding to
        # the pool's last chunk. Cap k and drop any negative ids.
        k = min(k, len(chunk_idx_pool))
        dists, ranks = faiss_subindex.search(emb, k)
        idxs = [chunk_idx_pool[i] for i in ranks[0] if i >= 0]
        return idxs, dists[0][:len(idxs)]
    # Full-corpus search: the prebuilt index is assumed to hold the same
    # normalized embeddings — confirm against the index build pipeline.
    dists, ranks = faiss_index.search(emb, k)
    return ranks[0], dists[0]
def hybrid_search(query, chunk_idx_pool=None, k_sparse=TOPK_SPARSE, k_dense=TOPK_DENSE, n_final=CONTEXT_CHUNKS):
    """Fuse BM25 and dense retrieval with Reciprocal Rank Fusion (RRF).

    Pipeline: run both retrievers, fuse by RRF, add the immediately
    preceding/following chunk of each top-3 hit for context, then select up
    to n_final chunks while (when searching all sources) capping how many
    may come from one source or one author.

    Returns:
        Up to n_final chunk dicts (shallow copies annotated with a
        "retrieval_rationale" list), sorted by (paper_id, chunk_id).
    """
    sparse_idx, sparse_scores = sparse_search(query, chunk_idx_pool, k=k_sparse)
    dense_idx, dense_dists = dense_search(query, chunk_idx_pool, k=k_dense)
    # Normalize index containers to 1-D arrays regardless of retriever types.
    sparse_idx = np.atleast_1d(sparse_idx)
    dense_idx = np.atleast_1d(dense_idx)
    # O(1) membership for pool filtering (the original did O(n) list "in").
    pool_set = set(chunk_idx_pool) if chunk_idx_pool is not None else None
    # RRF: score(idx) = sum over retrievers of 1/(k_rrf + rank), with a huge
    # rank (9999) when a retriever did not return the chunk at all.
    k_rrf = 60  # adjust as needed (RRF constant)
    sparse_ranks = {idx: rank for rank, idx in enumerate(sparse_idx)}
    dense_ranks = {idx: rank for rank, idx in enumerate(dense_idx)}
    hybrid_scores = {}
    for idx in set(sparse_ranks) | set(dense_ranks):
        rr_bm25 = 1 / (k_rrf + sparse_ranks.get(idx, 9999))
        rr_dense = 1 / (k_rrf + dense_ranks.get(idx, 9999))
        hybrid_scores[idx] = rr_bm25 + rr_dense
    # Hoisted: the original re-sorted hybrid_scores inside the selection loop
    # to compute each chunk's combined rank (O(m log m) per selected chunk).
    fused_order = sorted(hybrid_scores, key=hybrid_scores.get, reverse=True)
    combined_ranks = {idx: rank for rank, idx in enumerate(fused_order)}
    best_idxs = fused_order[:n_final]
    # Neighbor lookup table built once (the original scanned the whole corpus
    # per neighbor). setdefault keeps the FIRST match, like the original next().
    neighbor_key = {}
    for i, c in enumerate(chunks):
        neighbor_key.setdefault((c["paper_id"], c["chunk_id"]), i)
    # Add preceding/following chunks for the top 3 fused hits.
    extra_idxs = set()
    for rank_idx in best_idxs[:3]:
        chunk = chunks[rank_idx]
        for offset in (-1, 1):
            neighbor = neighbor_key.get((chunk["paper_id"], chunk["chunk_id"] + offset))
            if neighbor is not None and (pool_set is None or neighbor in pool_set):
                extra_idxs.add(neighbor)
    # Preserve fused order first, then neighbors; dict.fromkeys de-duplicates.
    all_final_idxs = list(dict.fromkeys(list(best_idxs) + list(extra_idxs)))
    if chunk_idx_pool is None:  # Only apply capping when searching all sources
        max_per_source = math.ceil(n_final * 0.5)
        max_per_author = math.ceil(n_final * 0.7)
    else:
        # Source-filtered search: caps set to n_final, i.e. effectively off.
        max_per_source = max_per_author = n_final
    selected_chunks = []
    source_counts = {}
    author_counts = {}
    for i in all_final_idxs:
        if i >= len(chunks) or (pool_set is not None and i not in pool_set):
            continue
        chunk = chunks[i]
        meta = chunk["metadata"]
        source_id = (meta.get("title", ""), meta.get("authors", ""))  # By title & authors (source)
        author_id = meta.get("authors", "")
        s_count = source_counts.get(source_id, 0)
        a_count = author_counts.get(author_id, 0)
        if s_count >= max_per_source or a_count >= max_per_author:
            continue
        # Explain why this chunk made the cut (surfaced in the UI).
        rationale = []
        sparse_rank = sparse_ranks.get(i)
        dense_rank = dense_ranks.get(i)
        combined_rank = combined_ranks.get(i)
        if sparse_rank is not None and sparse_rank < 3:
            rationale.append("high sparse similarity (BM25 rank top-3)")
        if dense_rank is not None and dense_rank < 3:
            rationale.append("high dense similarity (embedding rank top-3)")
        if combined_rank is not None and combined_rank < 3:
            rationale.append("high combined score (RRF top-3)")
        selected_chunk = dict(chunk)  # shallow copy, to avoid mutating source
        selected_chunk["retrieval_rationale"] = rationale if rationale else ["selected via hybrid search"]
        selected_chunks.append(selected_chunk)
        source_counts[source_id] = s_count + 1
        author_counts[author_id] = a_count + 1
        # Stop early if we have enough
        if len(selected_chunks) >= n_final:
            break
    # --- Sort so that, within each paper_id, chunk_id is ascending ---
    selected_chunks.sort(key=lambda c: (c['paper_id'], c['chunk_id']))
    return selected_chunks
def build_context_prompt(selected_chunks):
    """Render retrieved chunks as a numbered, citation-tagged context string.

    Each entry carries "[n] Source: ..." citation info, the chunk id and
    section, and up to 850 characters of text (ellipsized when truncated).
    Entries are separated by "---" horizontal rules.
    """
    def render(position, c):
        meta = c["metadata"]
        header = (
            f'[{position}] Source: "{meta["title"]}" ({meta["authors"]}, {meta["year"]}) '
            f"(Chunk {c.get('chunk_id', '')}, Section: {c.get('section', '')})"
        )
        body = c["text"]
        if len(body) > 850:
            body = body[:850] + "..."
        return f"{header}\n{body}"

    rendered = [render(pos, c) for pos, c in enumerate(selected_chunks, 1)]
    return "\n\n---\n\n".join(rendered)
def ask_llm(user_query, context_texts, model, temperature=0.3, reasoning_effort=None, max_output_tokens=1500):
    """Answer *user_query* from *context_texts* via the OpenAI API.

    Models in REASONING_MODELS are routed to the Responses endpoint (with an
    optional reasoning-effort setting); all others go through Chat
    Completions with *temperature*. Returns the answer text — with a warning
    appended when the response hit the token cap — or an error string if the
    API call raised.

    Parameters:
        user_query: the user's question, inserted verbatim into the prompt.
        context_texts: pre-rendered context (see build_context_prompt).
        model: API model id; REASONING_MODELS membership picks the endpoint.
        temperature: sampling temperature (chat-completions branch only).
        reasoning_effort: "low"/"medium"/"high", or None to omit the setting.
        max_output_tokens: hard cap on generated tokens for either endpoint.
    """
    # The citation format demanded here is repeated in system_msg below —
    # keep the two in sync if either is edited.
    prompt = f"""You are a helpful and rigorous research assistant. You assist social scientists in analyzing and synthesizing academic literature to answer research questions.
You are provided with CONTEXT from academic sources. Use only this information to answer the USER QUESTION. When referencing the context, quote the text directly and **always cite the source** using the following format: (Title, First Author, Year, Chunk #).
Your answer should be:
- Accurate, concise, and well-organized
- Written in coherent, formal academic prose
- Analytical in tone (aim to help users think critically about the literature)
- Grounded **strictly** in the provided context (do not add external knowledge)
Avoid:
- Bulleted lists
- Repetition
- Speculation beyond the given context
---
CONTEXT:
{context_texts}
USER QUESTION:
{user_query}
Answer:
"""
    system_msg = (
        "You are a research assistant helping social scientists understand and synthesize academic literature. Respond only based on the provided chunks of academic content. Always quote and cite your sources, using this format: (Title, First Author, Year, Chunk #). Your goal is to help clarify and connect insights from the literature with precision and depth."
    )
    if model in REASONING_MODELS:
        # Build the optional reasoning parameter ({"effort": ...}) only when set.
        reasoning_dict = None
        if reasoning_effort is not None:
            reasoning_dict = {"effort": reasoning_effort}
        try:
            resp = client.responses.create(
                model=model,
                input=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt}
                ],
                reasoning=reasoning_dict if reasoning_dict else None,
                max_output_tokens=max_output_tokens,
            )
            # Check for truncated outputs
            answer = resp.output_text.strip() if resp.output_text else ""
            if resp.status == "incomplete" and hasattr(resp, "incomplete_details") and \
                    getattr(resp.incomplete_details, "reason", None) == "max_output_tokens":
                answer += "\n\n[Warning: Response was cut off due to reaching the maximum output length. Try refining your question or reducing context size to get a more complete answer.]"
            return answer
        except Exception as e:
            # Surface API failures as text so the UI shows them instead of crashing.
            return f"Model `{model}` call failed: {e}"
    else:  # Use chat completions endpoint
        try:
            completions = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": prompt}
                ],
                temperature=temperature,
                max_tokens=max_output_tokens,
            )
            return completions.choices[0].message.content.strip()
        except Exception as e:
            return f"Model `{model}` call failed: {e}"
# --- Pricing Table (per 1M tokens): USD ---
# NOTE(review): rates are hard-coded — verify against current OpenAI pricing.
MODEL_PRICING = {
    "gpt-4.1": {"input": 2.00, "output": 8.00},
    "gpt-4.1-mini": {"input": 0.40, "output": 1.60},
    "o3": {"input": 10.00, "output": 40.00},
    "o4-mini": {"input": 1.10, "output": 4.40},
}
# UI label -> API model id; insertion order drives the selectbox ordering.
model_label_map = {
    "GPT-4.1": "gpt-4.1",
    "GPT-4.1-mini": "gpt-4.1-mini",
    "GPT o3 (reasoning)": "o3",
    "GPT o4-mini (small reasoning)": "o4-mini",
}
model_friendly_names = list(model_label_map.keys())
# === STREAMLIT UI ===
st.title("Post-Neoliberalism Literature Review Gizmo")
st.markdown("Your question:")
# Question box and Ask button side by side.
row1_col1, row1_col2 = st.columns([6, 1])
with row1_col1:
    question = st.text_area("Your question:", height=80, label_visibility="collapsed",)
with row1_col2:
    ask_clicked = st.button("Ask 🔎")
st.markdown("---")
# Chat history persists across Streamlit reruns via session state.
if "history" not in st.session_state:
    st.session_state["history"] = []
# --- Settings UI ---
retrieval_col, llm_col = st.columns(2)
with retrieval_col:
    st.subheader("Retrieval Settings")
    # Empty selection means "search everything".
    selected_labels = st.multiselect(
        "Select sources to search (default is _all_):",
        source_labels,
        default=[]
    )
    # chunk_idx_pool definition moves here:
    # None = unrestricted; otherwise the union of chunk indices belonging
    # to the selected sources.
    chunk_idx_pool = None
    if selected_labels:
        selected_keys = [source_key_map[label] for label in selected_labels]
        chunk_idx_pool = [i for key in selected_keys for i in source_groups[key]]
    context_chunk_count = st.number_input(
        "Number of chunks passed on to the LLM:",
        min_value=3,
        max_value=30,
        value=15,
        step=1
    )
with llm_col:
    st.subheader("LLM Settings")
    selected_model_name = st.selectbox("Choose an OpenAI model:", model_friendly_names, index=0)
    selected_model = model_label_map[selected_model_name]
    # Max output tokens UI -- show as "words" (converted to tokens further down).
    max_output_words = st.number_input(
        "Max response length (# of words):",
        min_value=50,
        max_value=2000,
        value=800,
        step=50
    )
    # Advanced controls:
    # Widgets are rendered disabled (not hidden) when they don't apply to
    # the selected model, with a caption explaining why.
    with st.expander("Advanced LLM Controls (Optional)"):
        if selected_model not in TEMPERATURE_MODELS:
            st.caption("Temperature is only used for GPT-4.1 and GPT-4.1-mini.")
        temp_value = st.slider(
            "Model randomness (temperature): Lower = more deterministic outputs (only GPT-4.1 and 4.1-mini)",
            0.0, 0.5, value=0.3, step=0.05,
            disabled=selected_model not in TEMPERATURE_MODELS,
            key="temperature_slider"
        )
        if selected_model not in REASONING_MODELS:
            st.caption("Reasoning effort is only used for o3 and o4-mini.")
        reasoning_effort = st.selectbox(
            "Reasoning effort (only for o3 and o4-mini):",
            ["default", "low", "medium", "high"],
            index=2,
            disabled=selected_model not in REASONING_MODELS,
            key="reasoning_effort"
        )
    user_temperature = float(temp_value)
    # "default" means: omit the reasoning parameter from the API call.
    user_reasoning = reasoning_effort if reasoning_effort != "default" else None
# Convert words to tokens for API call (model-aware token multiplier)
# Reasoning models also spend output tokens on hidden reasoning, so the
# word->token multiplier grows with the chosen effort level.
# NOTE(review): multipliers are heuristics — confirm against observed usage.
if selected_model in REASONING_MODELS:
    if user_reasoning == "low":
        output_token_multiplier = 7
    elif user_reasoning == "medium" or user_reasoning is None:
        output_token_multiplier = 12
    elif user_reasoning == "high":
        output_token_multiplier = 18
    else:
        output_token_multiplier = 12  # default (unreachable with current selectbox options)
else:
    output_token_multiplier = 1.5  # plain text: roughly 1.5 tokens per English word
user_max_output_tokens = int(max_output_words * output_token_multiplier)
# --- Pricing estimate (dollars only) ---
chunk_token = 750  # ~500-600 words per chunk ≈ 750 tokens
# Input = context chunks + question (~1.3 tokens/word) + ~1800 tokens of
# prompt scaffolding (instructions + system message).
input_tok = context_chunk_count * chunk_token + len(question.split()) * 1.3 + 1800
output_tok = user_max_output_tokens
rates = MODEL_PRICING[selected_model]
input_cost = (input_tok / 1_000_000) * rates["input"]
output_cost = (output_tok / 1_000_000) * rates["output"]
total_cost = input_cost + output_cost
# Show price estimate, turn red if over $1
if total_cost > 1:
    st.error(f"**API cost estimate for this query:** ${total_cost:.3f}")
else:
    st.info(f"**API cost estimate for this query:** ${total_cost:.3f}")
# Main ask flow: runs only when the button was clicked with a non-empty question.
if ask_clicked and question.strip():
    with st.spinner("Retrieving and generating answer..."):
        # To use caching, chunk_idx_pool must be hashable (convert to tuple)
        pool_tuple = tuple(chunk_idx_pool) if chunk_idx_pool is not None else None
        relevant_chunks = cached_search(question, pool_tuple, context_chunk_count)
        context = build_context_prompt(relevant_chunks)
        answer = ask_llm(
            question,
            context,
            model=selected_model,
            temperature=user_temperature,
            reasoning_effort=user_reasoning,
            max_output_tokens=user_max_output_tokens
        )
        # Save both Q, A, and context chunks in chat history
        st.session_state["history"].append({"role": "user", "content": question})
        st.session_state["history"].append({
            "role": "assistant",
            "content": answer,
            "context_chunks": relevant_chunks
        })
        st.header("Answer")
        st.markdown(f"**Assistant:**\n\n{answer}")
        # Evidence panel: the exact chunks that were sent to the model,
        # truncated to 500 chars each, plus their retrieval rationale.
        with st.expander("Show evidence (retrieved chunks)"):
            for i, c in enumerate(relevant_chunks, 1):
                meta = c["metadata"]
                st.write(
                    f"[{i}] {meta['title']} ({meta['authors']}, {meta['year']}) (Chunk {c.get('chunk_id', '')}):\n"
                    f"{c['text'][:500]}{'...' if len(c['text']) > 500 else ''}"
                )
                rationale = c.get('retrieval_rationale', [])
                if rationale:
                    st.caption("Retrieval rationale: " + "; ".join(rationale))
                st.markdown("---")
def render_chat_docx(history, with_chunks=True):
    """Build a python-docx Document from the chat history.

    User turns become bulleted "You:" paragraphs with the question text
    appended as a bold run; assistant turns become "Assistant:" paragraphs,
    each optionally followed by numbered evidence chunks truncated to 400
    characters. Returns the Document (caller saves it).
    """
    doc = Document()
    doc.add_heading("Chat History Export", 0)
    for turn in history:
        role = turn["role"]
        if role == "user":
            para = doc.add_paragraph("You:", style="List Bullet")
            run = para.add_run(turn["content"])
            # NOTE(review): this bolds the question text, not the "You:"
            # label (assistant answers are not bold) — confirm intended.
            run.bold = True
        elif role == "assistant":
            answer_para = doc.add_paragraph("Assistant:", style="List Bullet")
            answer_para.add_run(turn["content"])
            if with_chunks and "context_chunks" in turn:
                doc.add_paragraph("Evidence Chunks:", style="List Number")
                for pos, c in enumerate(turn["context_chunks"], 1):
                    meta = c["metadata"]
                    snippet = c["text"][:400] + ("..." if len(c["text"]) > 400 else "")
                    entry = (
                        f"[{pos}] {meta['title']} ({meta['authors']}, {meta['year']}) "
                        f"(Chunk {c.get('chunk_id', '')}):\n{snippet}"
                    )
                    doc.add_paragraph(entry, style="List Continue")
    return doc
# Layout for Chat History heading with export controls to the right
# The button only flips a session-state flag so the export expander (and its
# download button) survives the rerun the button click triggers.
if "show_download_expander" not in st.session_state:
    st.session_state["show_download_expander"] = False
chat_col, dl_col = st.columns([6, 2])
with chat_col:
    st.header("Chat History")
with dl_col:
    if st.button("**DOWNLOAD HISTORY**"):
        st.session_state["show_download_expander"] = True
if st.session_state.get("show_download_expander", False):
    with st.expander("Export options", expanded=True):
        include_chunks = st.checkbox("Include context chunks in download", value=True, key="include_chunks_dl")
        # Serialize the DOCX into memory; download_button needs raw bytes.
        doc = render_chat_docx(st.session_state["history"], with_chunks=include_chunks)
        tmpfile = io.BytesIO()
        doc.save(tmpfile)
        st.download_button(
            label="Download DOCX",
            data=tmpfile.getvalue(),
            file_name="chat_history.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            key="download_chat_docx"
        )
# Render the full chat transcript (oldest first).
for turn in st.session_state["history"]:
    if turn["role"] == "user":
        st.write(f"**You:** {turn['content']}")
    elif turn["role"] == "assistant":
        st.write(f"**Assistant:** {turn['content']}")
        # Show evidence as expandable if available
        if "context_chunks" in turn:
            with st.expander("Show retrieved chunks", expanded=False):
                for i, c in enumerate(turn["context_chunks"], 1):
                    meta = c["metadata"]
                    st.write(
                        f"[{i}] {meta['title']} ({meta['authors']}, {meta['year']}) (Chunk {c.get('chunk_id', '')}):\n"
                        f"{c['text'][:500]}{'...' if len(c['text']) > 500 else ''}"
                    )
                    rationale = c.get('retrieval_rationale', [])
                    if rationale:
                        st.caption("Retrieval rationale: " + "; ".join(rationale))