# survey-questionnaires / questionnaire_rag.py
# Uploaded by umangchaudhry ("Upload 2 files", commit f81886e, verified)
import os
import json
import re
from typing import List, Dict, Any, Optional
from pathlib import Path
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
# Optionally load environment variables (e.g. OPENAI_API_KEY, PINECONE_API_KEY,
# model overrides) from a local .env file. python-dotenv is a soft dependency:
# if it is not installed, we silently continue and rely on the process
# environment being configured externally.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
class PollQuestionnaireRAG:
    """Query a pre-built Pinecone vector store of poll questionnaires.

    Retrieval returns chunk-level matches from Pinecone; full question
    records are then reconstructed from a local JSON index so the LLM sees
    complete context (response options, skip logic, sibling variants, and
    question sequence).
    """

    def __init__(self, openai_api_key: str, persist_directory: str = "./questionnaire_vectorstores", verbose: bool = False):
        """Initialize embeddings, LLM, the Pinecone index, and local metadata.

        Args:
            openai_api_key: OpenAI API key used for embeddings and chat.
            persist_directory: Directory holding poll_catalog.json and
                questions_index.json (built by create_questionnaire_vectorstores.py).
            verbose: If True, print progress and the loaded poll catalog.

        Raises:
            ValueError: If PINECONE_API_KEY is unset or persist_directory
                does not exist.
        """
        self.openai_api_key = openai_api_key
        self.persist_directory = persist_directory
        self.verbose = verbose

        # Get Pinecone API key from environment (set by the app)
        pinecone_api_key = os.getenv("PINECONE_API_KEY")
        if pinecone_api_key is None:
            raise ValueError("PINECONE_API_KEY environment variable not set")
        self.pinecone_api_key = pinecone_api_key

        # Fix: the supplied OpenAI key was previously ignored (both clients
        # silently fell back to the OPENAI_API_KEY env var); pass it through
        # explicitly so the constructor argument actually takes effect.
        self.embeddings = OpenAIEmbeddings(
            model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
            api_key=self.openai_api_key,
        )
        chat_model = os.getenv("OPENAI_MODEL", "gpt-4o")
        # temperature=0 keeps answers deterministic for repeatable listings.
        self.llm = ChatOpenAI(model=chat_model, temperature=0, api_key=self.openai_api_key)

        # The vector data lives in Pinecone, but the catalog/question JSON
        # files are local — fail fast if they were never generated.
        if not os.path.exists(persist_directory):
            raise ValueError(
                f"Vector store not found at {persist_directory}\n"
                "Run create_questionnaire_vectorstores.py first to create it"
            )

        if self.verbose:
            print("Loading vector store...")
        index_name = os.getenv("PINECONE_INDEX_NAME", "poll-questionnaire-index")
        pc = Pinecone(api_key=self.pinecone_api_key)
        self.index = pc.Index(index_name)
        self.vectorstore = PineconeVectorStore(index=self.index, embedding=self.embeddings)

        # Load catalog and questions index
        self.poll_catalog = self._load_catalog()
        self.questions_by_id = self._load_questions_index()

        if self.verbose:
            print(f"Loaded {len(self.questions_by_id)} questions from {len(self.poll_catalog)} polls")
            self._print_catalog()

    def _load_catalog(self) -> Dict[str, Dict]:
        """Load the poll catalog JSON; returns {} if the file is absent."""
        catalog_path = Path(self.persist_directory) / "poll_catalog.json"
        if catalog_path.exists():
            with open(catalog_path, 'r') as f:
                return json.load(f)
        return {}

    def _load_questions_index(self) -> Dict[str, Dict]:
        """Load the question-id -> full-record index; {} if absent."""
        questions_path = Path(self.persist_directory) / "questions_index.json"
        if questions_path.exists():
            with open(questions_path, 'r') as f:
                return json.load(f)
        return {}

    def _print_catalog(self):
        """Print the loaded polls, sorted by poll date."""
        print("\nAvailable Polls:")
        for poll_date in sorted(self.poll_catalog.keys()):
            info = self.poll_catalog[poll_date]
            month_str = f" ({info['month']})" if info['month'] else ""
            print(f"  • {poll_date}{month_str}: {info['num_questions']} questions")

    def _get_prompt(self) -> "ChatPromptTemplate":
        """Create the system prompt — single source of truth.

        Return annotation is a string so the class can be defined without
        langchain importable (e.g. in isolated tests/tooling).
        """
        return ChatPromptTemplate.from_messages([
            ("system", """You are an expert assistant for analyzing poll questionnaires.
You have access to complete question data including:
- Question text and response options
- Poll date (year and month)
- Question sequence (what came before/after)
- Sibling variants (alternate versions for different respondent groups)
- Topics and question types
- Skip logic and sampling details
CRITICAL: When listing questions, ALWAYS include sampling information inline.
Guidelines for responding:
1. **When asked "what questions were asked" or similar listing requests:**
- List ALL the questions provided in order
- Include the full question text
- Include response options for each question
- **ALWAYS note sampling/skip logic inline** using clear language:
* "Asked to all respondents" (not "ASK ALL")
* "Asked to half the sample" (not "HALFSAMP1=1")
* "Asked only if respondent is Republican/voted/etc." (not "POLPARTY=1")
- If sibling variants exist, note "One of two versions shown to different groups"
- Use a clear, scannable format with sampling info clearly visible
Example format:
1. **Question text** (Asked to all respondents)
- Response options...
2. **Question text** (Asked to half the sample)
- Response options...
2. **When asked about specific topics or themes:**
- List relevant questions
- Explain what topics they cover
- Note any patterns or connections
- Include sampling information if relevant
- Include citation to the poll(s) used
3. **When asked about question sequence or context:**
- Explain what respondents experienced
- Note what came before/after specific questions
- Clarify if sequence varied by sampling group
4. **About sibling variants:**
- These are DIFFERENT versions shown to DIFFERENT groups
- Respondents never see both versions
- Always mention when they exist in question lists
5. **General style:**
- Use natural language - NO jargon like "HALFSAMP3=1" or "ASK ALL"
- Variable names (like VAND8A) are internal - mention only if asked
- Be proactive about explaining sampling - don't wait to be asked
- Always cite which poll(s) you're referencing
Available polls:
{catalog}
"""),
            ("human", """Context from relevant questions:
{context}
Question: {question}
Answer:""")
        ])

    def query(self, question: str, k: Optional[int] = None) -> str:
        """
        Answer a question about poll questionnaires.

        Args:
            question: Natural language question about the polls
            k: Number of documents to retrieve (auto-detected if None)

        Returns:
            String response from the LLM
        """
        result = self._query_internal(question, k)
        return result['answer']

    def query_with_metadata(self, question: str, k: Optional[int] = None) -> Dict[str, Any]:
        """
        Answer a question and return full metadata.

        Args:
            question: Natural language question about the polls
            k: Number of documents to retrieve (auto-detected if None)

        Returns:
            Dict with 'answer', 'source_questions', 'num_sources', 'filters_applied'
        """
        return self._query_internal(question, k)

    def _query_internal(self, question: str, k: Optional[int] = None) -> Dict[str, Any]:
        """Shared implementation behind query() and query_with_metadata().

        Retrieves chunk matches, reconstructs full question records in
        survey order, and asks the LLM to answer using that context.
        """
        # Auto-detect if user wants a complete listing — those need a much
        # wider retrieval net than a focused question.
        q_lower = question.lower()
        list_indicators = [
            'what questions', 'list questions', 'all questions',
            'show questions', 'which questions', 'questions were asked',
            'questions asked', 'what was asked'
        ]
        wants_complete_list = any(indicator in q_lower for indicator in list_indicators)

        if k is None:
            k = 50 if wants_complete_list else 8

        # Parse temporal (year/month) filters from the question text.
        filters = self._extract_filters(question)

        search_kwargs: Dict[str, Any] = {"k": k}
        if filters:
            search_kwargs["filter"] = filters
        retriever = self.vectorstore.as_retriever(search_kwargs=search_kwargs)

        docs = retriever.invoke(question)

        # Reconstruct full questions from chunk IDs. De-duplicate: several
        # retrieved chunks can point at the same underlying question, and
        # repeating a record would inflate the context and source counts.
        full_questions = []
        seen_ids = set()
        for doc in docs:
            q_id = doc.metadata['question_id']
            if q_id in self.questions_by_id and q_id not in seen_ids:
                seen_ids.add(q_id)
                full_questions.append(self.questions_by_id[q_id])

        # Sort by position to maintain survey order
        full_questions.sort(key=lambda q: q['position'])

        context = self._format_context(full_questions)
        prompt = self._get_prompt()

        chain = (
            {
                # Context and catalog are precomputed above; the lambdas
                # simply inject them regardless of the chain input.
                "context": lambda _: context,
                "question": RunnablePassthrough(),
                "catalog": lambda _: self._get_catalog_summary()
            }
            | prompt
            | self.llm
            | StrOutputParser()
        )

        answer = chain.invoke(question)

        return {
            'answer': answer,
            'source_questions': full_questions,
            'num_sources': len(full_questions),
            'filters_applied': filters
        }

    def _extract_filters(self, question: str) -> Dict[str, Any]:
        """Extract temporal (year/month) metadata filters from the question.

        Returns a Pinecone metadata filter: {} for none, a single implicit
        equality condition, or an explicit "$and" of several conditions.
        """
        filter_conditions = []
        q_lower = question.lower()

        # Year filter: any 4-digit year in the 2000s.
        year_match = re.search(r'\b(20\d{2})\b', question)
        if year_match:
            filter_conditions.append({"year": int(year_match.group(1))})

        # Month filter. Word-boundary match avoids firing on substrings
        # (e.g. "dismayed" must not trigger the May filter). NOTE: a bare
        # modal "may" still matches — an accepted limitation of keyword
        # filtering. First month mentioned wins.
        months = {
            'january': 'January', 'february': 'February', 'march': 'March',
            'april': 'April', 'may': 'May', 'june': 'June',
            'july': 'July', 'august': 'August', 'september': 'September',
            'october': 'October', 'november': 'November', 'december': 'December'
        }
        for month_lower, month_proper in months.items():
            if re.search(rf'\b{month_lower}\b', q_lower):
                filter_conditions.append({"month": month_proper})
                break

        # Pinecone filter syntax: single condition stands alone; multiple
        # conditions must be combined under "$and".
        if len(filter_conditions) == 0:
            return {}
        elif len(filter_conditions) == 1:
            return filter_conditions[0]
        else:
            return {"$and": filter_conditions}

    def _format_context(self, questions: List[Dict]) -> str:
        """Render full question records as an LLM-readable context string."""
        if not questions:
            return "No relevant questions found."

        context_parts = []
        for q in questions:
            lines = [
                f"--- Question {q['position'] + 1} from {q['survey_name']} ({q['poll_date']}) ---",
                f"Variable: {q['variable_name']}",
                f"Question: {q['question_text']}",
                f"Response Options: {' | '.join(q['response_options'])}",
                f"Topics: {', '.join(q['topics'])}",
                f"Question Type: {q['question_type']}",
                f"Administration: {q['ask_condition']}",
            ]

            # Skip logic / sampling details, when present.
            if q.get('skip_logic'):
                lines.append(f"Skip Logic: {q['skip_logic']}")
            if q.get('half_sample_group'):
                lines.append(f"Half Sample Group: {q['half_sample_group']}")

            # Sibling variants: alternate wordings shown to other groups.
            if q.get('sibling_variants'):
                lines.append("")
                lines.append("Alternate Versions (shown to different groups):")
                for sib in q['sibling_variants']:
                    sib_group = sib.get('half_sample_group', 'other group')
                    lines.append(f"  - [{sib_group}] {sib['question_text']}")

            # Sequence context: what respondents saw before this question.
            if q.get('previous_question'):
                prev_vars = q.get('previous_question_variants', [])
                lines.append("")
                if len(prev_vars) > 1:
                    lines.append("Previous Question (respondents saw one of these):")
                    for pv in prev_vars:
                        lines.append(f"  - {pv['question_text']}")
                else:
                    lines.append(f"Previous Question: {q['previous_question']['question_text']}")

            # ...and what came after it.
            if q.get('next_question'):
                next_vars = q.get('next_question_variants', [])
                lines.append("")
                if len(next_vars) > 1:
                    lines.append("Next Question (respondents saw one of these):")
                    for nv in next_vars:
                        lines.append(f"  - {nv['question_text']}")
                else:
                    lines.append(f"Next Question: {q['next_question']['question_text']}")

            context_parts.append("\n".join(lines).strip())

        return "\n\n".join(context_parts)

    def _get_catalog_summary(self) -> str:
        """Get a one-line-per-poll summary of available polls."""
        lines = []
        for poll_date in sorted(self.poll_catalog.keys()):
            info = self.poll_catalog[poll_date]
            month_str = f" ({info['month']})" if info['month'] else ""
            lines.append(f"- {poll_date}{month_str}: {info['num_questions']} questions")
        return "\n".join(lines)

    def get_question_by_variable(self, variable_name: str) -> Optional[Dict]:
        """Get a specific question record by variable name, or None."""
        # Linear scan over values: the index is keyed by question_id, not
        # variable name, and is small enough that a second index isn't worth it.
        for q in self.questions_by_id.values():
            if q['variable_name'] == variable_name:
                return q
        return None
def main():
    """CLI interface"""
    import sys

    # Both API keys must be available before constructing the RAG object.
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        print("Error: OPENAI_API_KEY not set")
        print("Set it with: export OPENAI_API_KEY='your-key'")
        sys.exit(1)

    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    if not pinecone_api_key:
        print("Error: PINECONE_API_KEY not set")
        print("Set it with: export PINECONE_API_KEY='your-key'")
        sys.exit(1)

    # Build the RAG pipeline; a missing vector store raises ValueError.
    try:
        rag = PollQuestionnaireRAG(
            openai_api_key=openai_api_key, verbose=True)
    except ValueError as err:
        print(f"\nError: {err}")
        sys.exit(1)

    cli_args = sys.argv[1:]
    banner = "=" * 80

    if cli_args:
        # Single-query mode: treat all CLI arguments as one question.
        query = " ".join(cli_args)
        print(f"\nQuery: {query}")
        print(banner + "\n")
        print(rag.query(query))
        return

    # Interactive mode when no query was provided on the command line.
    print("\n" + banner)
    print("Interactive Mode - Type 'quit' to exit")
    print(banner + "\n")
    print("Example questions:")
    print("  • What questions have been asked about gun violence?")
    print("  • What questions about the economy were asked in June 2025?")
    print("  • What came before VAND8A?")
    print("  • Show me all immigration questions")

    while True:
        try:
            question = input("\n\nYour question: ").strip()
            if not question or question.lower() in ['quit', 'exit', 'q']:
                break
            print("\nThinking...\n")
            print(rag.query(question))
        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as err:
            print(f"Error: {err}")


if __name__ == "__main__":
    main()