import os
import json
import re
from typing import List, Dict, Any, Optional
from pathlib import Path

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
class PollQuestionnaireRAG:
    """Query a pre-built vector store for poll questionnaires"""
    def __init__(self, openai_api_key: str, persist_directory: str = "./questionnaire_vectorstores", verbose: bool = False):
        self.openai_api_key = openai_api_key
        self.persist_directory = persist_directory
        self.verbose = verbose

        # Get Pinecone API key from environment (set by the app)
        pinecone_api_key = os.getenv("PINECONE_API_KEY")
        if pinecone_api_key is None:
            raise ValueError("PINECONE_API_KEY environment variable not set")
        self.pinecone_api_key = pinecone_api_key

        # Pass the OpenAI key explicitly so the clients do not silently fall
        # back to the OPENAI_API_KEY environment variable
        self.embeddings = OpenAIEmbeddings(
            model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
            api_key=openai_api_key,
        )
        chat_model = os.getenv("OPENAI_MODEL", "gpt-4o")
        self.llm = ChatOpenAI(model=chat_model, temperature=0, api_key=openai_api_key)

        # The local directory holds the catalog/questions indexes built
        # alongside the Pinecone index
        if not os.path.exists(persist_directory):
            raise ValueError(
                f"Vector store not found at {persist_directory}\n"
                "Run create_questionnaire_vectorstores.py first to create it"
            )

        if self.verbose:
            print("Loading vector store...")

        index_name = os.getenv("PINECONE_INDEX_NAME", "poll-questionnaire-index")
        pc = Pinecone(api_key=self.pinecone_api_key)
        self.index = pc.Index(index_name)
        self.vectorstore = PineconeVectorStore(index=self.index, embedding=self.embeddings)

        # Load catalog and questions index
        self.poll_catalog = self._load_catalog()
        self.questions_by_id = self._load_questions_index()

        if self.verbose:
            print(f"Loaded {len(self.questions_by_id)} questions from {len(self.poll_catalog)} polls")
            self._print_catalog()
    def _load_catalog(self) -> Dict[str, Dict]:
        """Load poll catalog"""
        catalog_path = Path(self.persist_directory) / "poll_catalog.json"
        if catalog_path.exists():
            with open(catalog_path, 'r') as f:
                return json.load(f)
        return {}

    def _load_questions_index(self) -> Dict[str, Dict]:
        """Load questions index"""
        questions_path = Path(self.persist_directory) / "questions_index.json"
        if questions_path.exists():
            with open(questions_path, 'r') as f:
                return json.load(f)
        return {}
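    # Expected on-disk layout (inferred from how the fields are used below; the
    # authoritative schema lives in create_questionnaire_vectorstores.py):
    #   poll_catalog.json    -> {poll_date: {"month": str | None, "num_questions": int}}
    #   questions_index.json -> {question_id: {"position": int, "survey_name": str,
    #                             "poll_date": str, "variable_name": str,
    #                             "question_text": str, "response_options": [str],
    #                             "topics": [str], "question_type": str,
    #                             "ask_condition": str, ...}}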
    def _print_catalog(self):
        """Print loaded polls"""
        print("\nAvailable Polls:")
        for poll_date in sorted(self.poll_catalog.keys()):
            info = self.poll_catalog[poll_date]
            month_str = f" ({info['month']})" if info['month'] else ""
            print(f"  • {poll_date}{month_str}: {info['num_questions']} questions")
    def _get_prompt(self) -> ChatPromptTemplate:
        """Create the system prompt - single source of truth"""
        return ChatPromptTemplate.from_messages([
            ("system", """You are an expert assistant for analyzing poll questionnaires.

You have access to complete question data including:
- Question text and response options
- Poll date (year and month)
- Question sequence (what came before/after)
- Sibling variants (alternate versions for different respondent groups)
- Topics and question types
- Skip logic and sampling details

CRITICAL: When listing questions, ALWAYS include sampling information inline.

Guidelines for responding:

1. **When asked "what questions were asked" or similar listing requests:**
   - List ALL the questions provided in order
   - Include the full question text
   - Include response options for each question
   - **ALWAYS note sampling/skip logic inline** using clear language:
     * "Asked to all respondents" (not "ASK ALL")
     * "Asked to half the sample" (not "HALFSAMP1=1")
     * "Asked only if respondent is Republican/voted/etc." (not "POLPARTY=1")
   - If sibling variants exist, note "One of two versions shown to different groups"
   - Use a clear, scannable format with sampling info clearly visible

   Example format:
   1. **Question text** (Asked to all respondents)
      - Response options...
   2. **Question text** (Asked to half the sample)
      - Response options...

2. **When asked about specific topics or themes:**
   - List relevant questions
   - Explain what topics they cover
   - Note any patterns or connections
   - Include sampling information if relevant
   - Include citation to the poll(s) used

3. **When asked about question sequence or context:**
   - Explain what respondents experienced
   - Note what came before/after specific questions
   - Clarify if sequence varied by sampling group

4. **About sibling variants:**
   - These are DIFFERENT versions shown to DIFFERENT groups
   - Respondents never see both versions
   - Always mention when they exist in question lists

5. **General style:**
   - Use natural language - NO jargon like "HALFSAMP3=1" or "ASK ALL"
   - Variable names (like VAND8A) are internal - mention only if asked
   - Be proactive about explaining sampling - don't wait to be asked
   - Always cite which poll(s) you're referencing

Available polls:
{catalog}
"""),
            ("human", """Context from relevant questions:

{context}

Question: {question}

Answer:""")
        ])
    def query(self, question: str, k: Optional[int] = None) -> str:
        """
        Answer a question about poll questionnaires.

        Args:
            question: Natural language question about the polls
            k: Number of documents to retrieve (auto-detected if None)

        Returns:
            String response from the LLM
        """
        result = self._query_internal(question, k)
        return result['answer']

    def query_with_metadata(self, question: str, k: Optional[int] = None) -> Dict[str, Any]:
        """
        Answer a question and return full metadata.

        Args:
            question: Natural language question about the polls
            k: Number of documents to retrieve (auto-detected if None)

        Returns:
            Dict with 'answer', 'source_questions', 'num_sources', 'filters_applied'
        """
        return self._query_internal(question, k)
    def _query_internal(self, question: str, k: Optional[int] = None) -> Dict[str, Any]:
        """Internal query method used by both public methods"""
        # Auto-detect whether the user wants a complete list
        q_lower = question.lower()
        list_indicators = [
            'what questions', 'list questions', 'all questions',
            'show questions', 'which questions', 'questions were asked',
            'questions asked', 'what was asked'
        ]
        wants_complete_list = any(indicator in q_lower for indicator in list_indicators)

        # Set k based on query type
        if k is None:
            k = 50 if wants_complete_list else 8

        # Parse temporal filters from the question
        filters = self._extract_filters(question)

        # Retrieve relevant documents
        if filters:
            retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": k, "filter": filters}
            )
        else:
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
        docs = retriever.invoke(question)

        # Reconstruct full questions from IDs
        full_questions = []
        for doc in docs:
            q_id = doc.metadata.get('question_id')
            if q_id in self.questions_by_id:
                full_questions.append(self.questions_by_id[q_id])

        # Sort by position to maintain survey order
        full_questions.sort(key=lambda q: q['position'])

        context = self._format_context(full_questions)
        prompt = self._get_prompt()

        chain = (
            {
                "context": lambda x: context,
                "question": RunnablePassthrough(),
                "catalog": lambda x: self._get_catalog_summary()
            }
            | prompt
            | self.llm
            | StrOutputParser()
        )
        answer = chain.invoke(question)

        return {
            'answer': answer,
            'source_questions': full_questions,
            'num_sources': len(full_questions),
            'filters_applied': filters
        }
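    # Example return value from _query_internal (shape only; contents vary):
    #   {'answer': '...', 'source_questions': [<question dicts>],
    #    'num_sources': 12, 'filters_applied': {'year': 2025}}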
    def _extract_filters(self, question: str) -> Dict[str, Any]:
        """Extract temporal filters from question"""
        filter_conditions = []
        q_lower = question.lower()

        # Year filter
        year_match = re.search(r'\b(20\d{2})\b', question)
        if year_match:
            filter_conditions.append({"year": int(year_match.group(1))})

        # Month filter (note: a bare 'may' can also match the modal verb)
        months = {
            'january': 'January', 'february': 'February', 'march': 'March',
            'april': 'April', 'may': 'May', 'june': 'June',
            'july': 'July', 'august': 'August', 'september': 'September',
            'october': 'October', 'november': 'November', 'december': 'December'
        }
        for month_lower, month_proper in months.items():
            if month_lower in q_lower:
                filter_conditions.append({"month": month_proper})
                break

        # Return Pinecone metadata-filter syntax (bare equality, or $and for
        # multiple conditions)
        if len(filter_conditions) == 0:
            return {}
        elif len(filter_conditions) == 1:
            return filter_conditions[0]
        else:
            return {"$and": filter_conditions}
    def _format_context(self, questions: List[Dict]) -> str:
        """Format full questions as context for LLM"""
        if not questions:
            return "No relevant questions found."

        context_parts = []
        for q in questions:
            part = f"""
--- Question {q['position'] + 1} from {q['survey_name']} ({q['poll_date']}) ---
Variable: {q['variable_name']}
Question: {q['question_text']}
Response Options: {' | '.join(q['response_options'])}
Topics: {', '.join(q['topics'])}
Question Type: {q['question_type']}
Administration: {q['ask_condition']}
"""
            # Add skip logic/sampling info
            if q.get('skip_logic'):
                part += f"Skip Logic: {q['skip_logic']}\n"
            if q.get('half_sample_group'):
                part += f"Half Sample Group: {q['half_sample_group']}\n"

            # Add sibling variants
            if q.get('sibling_variants'):
                part += "\nAlternate Versions (shown to different groups):\n"
                for sib in q['sibling_variants']:
                    sib_group = sib.get('half_sample_group', 'other group')
                    part += f"  - [{sib_group}] {sib['question_text']}\n"

            # Add sequence context
            if q.get('previous_question'):
                prev_vars = q.get('previous_question_variants', [])
                if len(prev_vars) > 1:
                    part += "\nPrevious Question (respondents saw one of these):\n"
                    for pv in prev_vars:
                        part += f"  - {pv['question_text']}\n"
                else:
                    part += f"\nPrevious Question: {q['previous_question']['question_text']}\n"
            if q.get('next_question'):
                next_vars = q.get('next_question_variants', [])
                if len(next_vars) > 1:
                    part += "\nNext Question (respondents saw one of these):\n"
                    for nv in next_vars:
                        part += f"  - {nv['question_text']}\n"
                else:
                    part += f"\nNext Question: {q['next_question']['question_text']}\n"

            context_parts.append(part.strip())

        return "\n\n".join(context_parts)
    def _get_catalog_summary(self) -> str:
        """Get summary of available polls"""
        lines = []
        for poll_date in sorted(self.poll_catalog.keys()):
            info = self.poll_catalog[poll_date]
            month_str = f" ({info['month']})" if info['month'] else ""
            lines.append(f"- {poll_date}{month_str}: {info['num_questions']} questions")
        return "\n".join(lines)

    def get_question_by_variable(self, variable_name: str) -> Optional[Dict]:
        """Get a specific question by variable name"""
        for q in self.questions_by_id.values():
            if q['variable_name'] == variable_name:
                return q
        return None
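    # Example (variable names such as VAND8A are internal survey identifiers):
    #   q = rag.get_question_by_variable("VAND8A")
    #   if q:
    #       print(q["question_text"])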
def main():
    """CLI interface"""
    import sys

    # Get API keys
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        print("Error: OPENAI_API_KEY not set")
        print("Set it with: export OPENAI_API_KEY='your-key'")
        sys.exit(1)

    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    if not pinecone_api_key:
        print("Error: PINECONE_API_KEY not set")
        print("Set it with: export PINECONE_API_KEY='your-key'")
        sys.exit(1)

    # Initialize
    try:
        rag = PollQuestionnaireRAG(openai_api_key=openai_api_key, verbose=True)
    except ValueError as e:
        print(f"\nError: {e}")
        sys.exit(1)

    # Interactive mode if no query provided
    if len(sys.argv) < 2:
        print("\n" + "=" * 80)
        print("Interactive Mode - Type 'quit' to exit")
        print("=" * 80 + "\n")
        print("Example questions:")
        print("  • What questions have been asked about gun violence?")
        print("  • What questions about the economy were asked in June 2025?")
        print("  • What came before VAND8A?")
        print("  • Show me all immigration questions")

        while True:
            try:
                question = input("\n\nYour question: ").strip()
                if not question or question.lower() in ['quit', 'exit', 'q']:
                    break
                print("\nThinking...\n")
                answer = rag.query(question)
                print(answer)
            except KeyboardInterrupt:
                print("\n\nGoodbye!")
                break
            except Exception as e:
                print(f"Error: {e}")

    # Single query mode
    else:
        query = " ".join(sys.argv[1:])
        print(f"\nQuery: {query}")
        print("=" * 80 + "\n")
        answer = rag.query(query)
        print(answer)


if __name__ == "__main__":
    main()