umangchaudhry commited on
Commit
6241bf7
·
verified ·
1 Parent(s): ff476de

Upload 6 files

Browse files
app.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for Survey Analysis Agent
3
+ Host on Hugging Face Spaces
4
+ """
5
+
6
+ import os
7
+ import gradio as gr
8
+ from survey_agent import SurveyAnalysisAgent
9
+ import uuid
10
+ from datetime import datetime
11
+
12
+ # Initialize agent (will be done once at startup)
13
+ agent = None
14
+ initialization_error = None
15
+
16
def initialize_agent():
    """Set up the global SurveyAnalysisAgent from environment secrets.

    Reads OPENAI_API_KEY and PINECONE_API_KEY from the environment and
    checks that the local vector-store directory exists. On any failure
    the module-level ``initialization_error`` is set to a user-facing
    message and False is returned; True is returned on success.
    """
    global agent, initialization_error

    try:
        openai_api_key = os.getenv("OPENAI_API_KEY")
        pinecone_api_key = os.getenv("PINECONE_API_KEY")

        # Fail fast, one specific message per missing prerequisite.
        if not openai_api_key:
            initialization_error = "❌ OPENAI_API_KEY not found. Please set it in Space Settings → Repository Secrets."
            return False

        if not pinecone_api_key:
            initialization_error = "❌ PINECONE_API_KEY not found. Please set it in Space Settings → Repository Secrets."
            return False

        # The questionnaire vector store must have been uploaded with the Space.
        if not os.path.exists("./questionnaire_vectorstores"):
            initialization_error = "❌ Vector store directory not found. Please upload the questionnaire_vectorstores folder."
            return False

        agent = SurveyAnalysisAgent(
            openai_api_key=openai_api_key,
            pinecone_api_key=pinecone_api_key,
            verbose=False,  # Set to False for cleaner UI
        )
        return True

    except Exception as e:
        initialization_error = f"❌ Initialization error: {str(e)}"
        return False
48
+
49
+
50
def chat(message, history, session_id):
    """Produce the assistant's reply for one chat turn.

    Args:
        message: User's message text.
        history: Chat history (unused here; kept for the Gradio signature).
        session_id: Unique session identifier, used as the agent's
            conversation-memory thread id.

    Returns:
        The answer string, or a user-facing error/status message.
    """
    # Surface startup problems before anything else.
    if initialization_error:
        return initialization_error

    if not agent:
        return "⚠️ Agent not initialized. Please refresh the page."

    if not message.strip():
        return "Please enter a question."

    try:
        # session_id doubles as the agent's thread_id for conversation memory.
        return agent.query(message, thread_id=session_id)
    except Exception as e:
        print(f"Error details: {e}")  # Log to console
        return f"❌ Error processing query: {str(e)}"
77
+
78
+
79
def create_new_session():
    """Return a fresh random session identifier (UUID4 as a string)."""
    return f"{uuid.uuid4()}"
82
+
83
+
84
def get_available_surveys():
    """Render the agent's survey catalog as a Markdown summary string."""
    if initialization_error or not agent:
        return "Agent not initialized"

    try:
        surveys = agent.questionnaire_rag.get_available_survey_names()
        polls = agent.questionnaire_rag.get_available_polls()

        # Assemble the Markdown in pieces and join once at the end.
        parts = [
            "## Available Surveys\n\n",
            f"**Survey Names:** {', '.join(surveys)}\n\n",
            "## Available Polls\n\n",
        ]
        parts.extend(
            f"- **{poll['poll_date']}** ({poll['month']} {poll['year']}): {poll['survey_name']} - {poll['num_questions']} questions\n"
            for poll in polls
        )
        return "".join(parts)
    except Exception as e:
        return f"Error retrieving survey info: {str(e)}"
103
+
104
+
105
# Initialize agent at startup (runs once at module import, before the UI builds,
# so the interface can reflect success/failure immediately).
print("🚀 Initializing Survey Analysis Agent...")
init_success = initialize_agent()

if init_success:
    print("✅ Agent initialized successfully!")
else:
    print(f"⚠️ Agent initialization failed: {initialization_error}")


# Create Gradio interface
with gr.Blocks(title="Survey Analysis Agent", theme=gr.themes.Soft()) as demo:

    # Header
    gr.Markdown("""
# 📊 Survey Analysis Agent

Ask questions about survey data using natural language. The agent can:
- Find questions from specific surveys and time periods
- Compare questions across different time periods
- Analyze question topics and themes
- Show sampling logic and question flow

**Note:** Currently only questionnaire data is available (questions, topics, response options, skip logic).
""")

    # Show initialization status (only rendered when startup failed)
    if initialization_error:
        gr.Markdown(f"## ⚠️ Setup Required\n\n{initialization_error}")

    # Session state: one UUID per browser session, used as the agent's memory thread id.
    session_id_state = gr.State(value=create_new_session())

    # Main chat interface
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_label=True,
                type="messages"  # openai-style {"role", "content"} dicts
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="e.g., What questions were asked in the June 2025 Unity Poll?",
                    show_label=False,
                    scale=4
                )
                submit = gr.Button("Send", scale=1, variant="primary")

            with gr.Row():
                clear = gr.Button("🔄 New Conversation", scale=1)

            # Example questions (clicking one fills the textbox)
            gr.Examples(
                examples=[
                    "What questions were asked in June 2025?",
                    "Show me all healthcare-related questions",
                    "What questions were asked in the Unity Poll?",
                    "Compare immigration questions from different surveys",
                ],
                inputs=msg,
                label="Example Questions"
            )

        # Sidebar with info
        with gr.Column(scale=1):
            gr.Markdown("## 📋 Available Data")
            survey_info = gr.Markdown(
                value=get_available_surveys() if init_success else "Agent not initialized",
                label="Surveys"
            )

            refresh_info = gr.Button("🔄 Refresh Survey List", size="sm")

            gr.Markdown("""
## 💡 Tips

- Be specific about time periods (e.g., "June 2025")
- Mention survey names when relevant
- Follow up with clarifications if needed
- The agent maintains conversation context

## 🔧 Current Capabilities

✅ **Available:**
- Question text and response options
- Topics and themes
- Skip logic and sampling
- Question sequencing

⏳ **Coming Soon:**
- Response frequencies (toplines)
- Cross-tabulations
- Statistical analysis
""")

    # Event handlers
    def respond(message, chat_history, session_id):
        """Append the user turn and the bot turn, then clear the textbox.

        Returns (updated history, "") — the empty string resets `msg`.
        """
        if not message.strip():
            return chat_history, ""

        # Add user message
        chat_history.append({"role": "user", "content": message})

        # Get bot response
        bot_message = chat(message, chat_history, session_id)

        # Add bot message
        chat_history.append({"role": "assistant", "content": bot_message})

        return chat_history, ""

    def clear_chat():
        """Empty the chat window and rotate to a fresh session/memory thread."""
        new_session = create_new_session()
        return [], new_session

    # Wire up events: Enter key and Send button share the same handler.
    msg.submit(respond, [msg, chatbot, session_id_state], [chatbot, msg])
    submit.click(respond, [msg, chatbot, session_id_state], [chatbot, msg])
    clear.click(clear_chat, None, [chatbot, session_id_state])
    refresh_info.click(get_available_surveys, None, survey_info)

    # Footer
    gr.Markdown("""
---
**Note:** This system uses conversation memory. You can ask follow-up questions like:
1. "What questions were asked?"
2. "June 2025, Unity Poll" (it will understand the context)
""")


# Launch the app (0.0.0.0:7860 is the standard Hugging Face Spaces binding)
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
questionnaire_rag.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Questionnaire RAG with better filtering and anti-hallucination measures.
3
+
4
+ Key improvements:
5
+ 1. Correct Pinecone filter syntax
6
+ 2. Post-retrieval validation of filters
7
+ 3. Stronger anti-hallucination prompts
8
+ 4. Explicit checks for data existence
9
+ 5. Fuzzy survey name matching
10
+ """
11
+
12
+ import os
13
+ import json
14
+ from typing import List, Dict, Any, Optional
15
+ from pathlib import Path
16
+
17
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
18
+ from langchain_pinecone import PineconeVectorStore
19
+ from pinecone import Pinecone
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema.output_parser import StrOutputParser
22
+
23
+ try:
24
+ from dotenv import load_dotenv
25
+ load_dotenv()
26
+ except ImportError:
27
+ pass
28
+
29
+
30
+ class QuestionnaireRAG:
31
+ """
32
+ Improved questionnaire RAG with:
33
+ - Better Pinecone filtering
34
+ - Post-retrieval validation
35
+ - Anti-hallucination measures
36
+ - Fuzzy survey name matching
37
+ """
38
+
39
    def __init__(
        self,
        openai_api_key: str,
        pinecone_api_key: str,
        persist_directory: str = "./questionnaire_vectorstores",
        verbose: bool = False
    ):
        """Connect to OpenAI and Pinecone and load the local question catalog.

        Args:
            openai_api_key: OpenAI key (embeddings + chat model).
            pinecone_api_key: Pinecone key for the questionnaire index.
            persist_directory: Folder containing poll_catalog.json and
                questions_index.json, produced by
                create_questionnaire_vectorstores.py.
            verbose: When True, print retrieval/filter diagnostics.

        Raises:
            ValueError: If ``persist_directory`` does not exist.
        """
        self.openai_api_key = openai_api_key
        self.pinecone_api_key = pinecone_api_key
        self.persist_directory = persist_directory
        self.verbose = verbose

        # Initialize embeddings (model overridable via OPENAI_EMBED_MODEL)
        self.embeddings = OpenAIEmbeddings(
            model=os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
        )

        # Initialize LLM; temperature 0 for deterministic, extractive answers
        chat_model = os.getenv("OPENAI_MODEL", "gpt-4o")
        self.llm = ChatOpenAI(model=chat_model, temperature=0)

        # Load vector store
        if not os.path.exists(persist_directory):
            raise ValueError(
                f"Vector store not found at {persist_directory}\n"
                "Run create_questionnaire_vectorstores.py first"
            )

        # Connect to Pinecone (index/namespace overridable via env vars;
        # an empty PINECONE_NAMESPACE string is normalized to None)
        index_name = os.getenv("PINECONE_INDEX_NAME", "poll-questionnaire-index")
        namespace = os.getenv("PINECONE_NAMESPACE") or None

        pc = Pinecone(api_key=self.pinecone_api_key)
        self.index = pc.Index(index_name)
        self.vectorstore = PineconeVectorStore(
            index=self.index,
            embedding=self.embeddings,
            namespace=namespace
        )

        # Load catalog and questions from local JSON sidecar files
        self.poll_catalog = self._load_catalog()
        self.questions_by_id = self._load_questions_index()

        if self.verbose:
            print(f"✓ Loaded {len(self.questions_by_id)} questions from {len(self.poll_catalog)} polls")
85
+
86
+ def _load_catalog(self) -> Dict[str, Dict]:
87
+ """Load poll catalog"""
88
+ catalog_path = Path(self.persist_directory) / "poll_catalog.json"
89
+ if catalog_path.exists():
90
+ with open(catalog_path, 'r') as f:
91
+ return json.load(f)
92
+ return {}
93
+
94
+ def _load_questions_index(self) -> Dict[str, Dict]:
95
+ """Load questions index"""
96
+ questions_path = Path(self.persist_directory) / "questions_index.json"
97
+ if questions_path.exists():
98
+ with open(questions_path, 'r') as f:
99
+ return json.load(f)
100
+ return {}
101
+
102
+ def get_available_survey_names(self) -> List[str]:
103
+ """Get list of unique survey names from the catalog"""
104
+ survey_names = set()
105
+ for info in self.poll_catalog.values():
106
+ survey_names.add(info["survey_name"])
107
+ return sorted(survey_names)
108
+
109
+ def _fuzzy_match_survey_name(self, requested_name: str) -> Optional[str]:
110
+ """
111
+ Fuzzy match a requested survey name to an actual stored name.
112
+
113
+ Examples:
114
+ - "Unity Poll" → "Vanderbilt_Unity_Poll"
115
+ - "unity poll" → "Vanderbilt_Unity_Poll"
116
+ - "Vanderbilt Unity" → "Vanderbilt_Unity_Poll"
117
+ """
118
+ # Get all unique survey names
119
+ available_names = self.get_available_survey_names()
120
+
121
+ # Normalize the requested name
122
+ normalized_requested = requested_name.lower().replace("_", " ").replace("-", " ")
123
+
124
+ # Try exact match first (case-insensitive)
125
+ for stored_name in available_names:
126
+ normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
127
+ if normalized_requested == normalized_stored:
128
+ return stored_name
129
+
130
+ # Try substring matching - check if requested is in stored
131
+ for stored_name in available_names:
132
+ normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
133
+ if normalized_requested in normalized_stored:
134
+ return stored_name
135
+
136
+ # Try reverse - check if stored is in requested
137
+ for stored_name in available_names:
138
+ normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
139
+ if normalized_stored in normalized_requested:
140
+ return stored_name
141
+
142
+ # Try word-level matching - if all words from requested are in stored
143
+ requested_words = set(normalized_requested.split())
144
+ for stored_name in available_names:
145
+ normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
146
+ stored_words = set(normalized_stored.split())
147
+
148
+ # Check if requested words are a subset of stored words
149
+ if requested_words.issubset(stored_words):
150
+ return stored_name
151
+
152
+ return None
153
+
154
+ def _build_pinecone_filter(self, filters: Dict[str, Any]) -> Optional[Dict[str, Any]]:
155
+ """
156
+ Build proper Pinecone metadata filter with fuzzy survey name matching.
157
+
158
+ Pinecone filter syntax:
159
+ - Simple: {"year": 2025}
160
+ - Multiple: {"$and": [{"year": 2025}, {"month": "February"}]}
161
+ """
162
+ if not filters:
163
+ return None
164
+
165
+ filter_conditions = []
166
+
167
+ # Handle year filter
168
+ if "year" in filters:
169
+ year = filters["year"]
170
+ if isinstance(year, str):
171
+ year = int(year)
172
+ filter_conditions.append({"year": {"$eq": year}})
173
+
174
+ # Handle month filter
175
+ if "month" in filters:
176
+ month = filters["month"]
177
+ # Ensure proper capitalization
178
+ if isinstance(month, str):
179
+ month = month.capitalize()
180
+ filter_conditions.append({"month": {"$eq": month}})
181
+
182
+ # Handle poll_date filter (exact match)
183
+ if "poll_date" in filters:
184
+ filter_conditions.append({"poll_date": {"$eq": filters["poll_date"]}})
185
+
186
+ # Handle survey_name filter with fuzzy matching
187
+ if "survey_name" in filters:
188
+ requested_name = filters["survey_name"]
189
+
190
+ # Try to fuzzy match the survey name
191
+ matched_name = self._fuzzy_match_survey_name(requested_name)
192
+
193
+ if matched_name:
194
+ if self.verbose and matched_name != requested_name:
195
+ print(f"🔄 Mapped survey name '{requested_name}' → '{matched_name}'")
196
+ filter_conditions.append({"survey_name": {"$eq": matched_name}})
197
+ else:
198
+ if self.verbose:
199
+ print(f"⚠️ Survey name '{requested_name}' not found in catalog")
200
+ print(f" Available: {self.get_available_survey_names()}")
201
+ # Don't add the filter if we can't match it - let other filters work
202
+
203
+ # Handle topics (if a topic is in the comma-separated list)
204
+ if "topic" in filters:
205
+ # This is trickier with comma-separated strings in metadata
206
+ # For now, we'll do post-filtering
207
+ pass
208
+
209
+ # Combine filters
210
+ if len(filter_conditions) == 0:
211
+ return None
212
+ elif len(filter_conditions) == 1:
213
+ return filter_conditions[0]
214
+ else:
215
+ return {"$and": filter_conditions}
216
+
217
+ def _validate_results(
218
+ self,
219
+ docs: List[Any],
220
+ filters: Dict[str, Any]
221
+ ) -> List[Any]:
222
+ """
223
+ Validate that retrieved documents actually match the filters.
224
+
225
+ This catches cases where:
226
+ 1. Pinecone filtering didn't work correctly
227
+ 2. We need to do additional filtering (like topic matching)
228
+ """
229
+ if not filters:
230
+ return docs
231
+
232
+ validated_docs = []
233
+
234
+ for doc in docs:
235
+ metadata = doc.metadata
236
+ valid = True
237
+
238
+ # Check year
239
+ if "year" in filters:
240
+ expected_year = int(filters["year"]) if isinstance(filters["year"], str) else filters["year"]
241
+ if metadata.get("year") != expected_year:
242
+ if self.verbose:
243
+ print(f"⚠️ Filtered out: wrong year {metadata.get('year')} != {expected_year}")
244
+ valid = False
245
+
246
+ # Check month
247
+ if "month" in filters and valid:
248
+ expected_month = filters["month"].capitalize() if isinstance(filters["month"], str) else filters["month"]
249
+ if metadata.get("month") != expected_month:
250
+ if self.verbose:
251
+ print(f"⚠️ Filtered out: wrong month {metadata.get('month')} != {expected_month}")
252
+ valid = False
253
+
254
+ # Check poll_date
255
+ if "poll_date" in filters and valid:
256
+ if metadata.get("poll_date") != filters["poll_date"]:
257
+ if self.verbose:
258
+ print(f"⚠️ Filtered out: wrong poll_date {metadata.get('poll_date')} != {filters['poll_date']}")
259
+ valid = False
260
+
261
+ # Check survey_name (with fuzzy matching)
262
+ if "survey_name" in filters and valid:
263
+ requested_name = filters["survey_name"]
264
+ matched_name = self._fuzzy_match_survey_name(requested_name)
265
+ if matched_name and metadata.get("survey_name") != matched_name:
266
+ if self.verbose:
267
+ print(f"⚠️ Filtered out: wrong survey {metadata.get('survey_name')} != {matched_name}")
268
+ valid = False
269
+
270
+ if valid:
271
+ validated_docs.append(doc)
272
+
273
+ return validated_docs
274
+
275
    def _get_prompt(self) -> ChatPromptTemplate:
        """Build the anti-hallucination chat prompt.

        Template variables: {catalog} (poll summary), {context} (retrieved
        questions), {question} (user query). Note the user's question is
        embedded in the *system* message; the human turn is only the
        literal cue "Answer:".
        """
        return ChatPromptTemplate.from_messages([
            ("system", """You are an expert assistant for analyzing poll questionnaires.

🚨 CRITICAL RULES - NEVER VIOLATE THESE:

1. **ONLY use information from the provided context**
- Do NOT make up questions, polls, or dates
- Do NOT assume a poll exists if it's not in the context
- If information is missing, say "I don't have data for [X]" rather than making it up

2. **Verify data exists before listing it**
- Before mentioning any poll, check it's actually in the context
- Before listing questions, confirm they exist in the retrieved data
- If asked about multiple time periods, explicitly state which ones have data and which don't

3. **Be explicit about what's NOT in the data**
- If asked about "2024 and 2025" but only 2025 data exists, say: "I have data for 2025, but there is no 2024 data in the retrieved results"
- Never silently skip missing data - always acknowledge it

4. **When listing questions:**
- List ALL questions from the context in order
- Include full question text and response options
- Note sampling inline in clear language:
* "Asked to all respondents" (not "ASK ALL")
* "Asked to half the sample" (not "HALFSAMP1=1")
* "Asked only if [condition]" (not technical codes)
- If sibling variants exist, note "One of two versions shown to different groups"
- Always cite which poll(s) you're using

5. **Format for scannability:**
- Use numbered lists for questions
- Bold question text
- Include response options as bullet points
- Put sampling info in parentheses after question

Available polls in the system (for reference):
{catalog}

Context (ONLY source of truth):
{context}

Question: {question}
"""),
            ("human", "Answer:")
        ])
322
+
323
+ def query(self, question: str, filters: Optional[Dict[str, Any]] = None, k: int = 20) -> str:
324
+ """
325
+ Query the questionnaire system.
326
+
327
+ Args:
328
+ question: Natural language question
329
+ filters: Optional filters (year, month, poll_date, survey_name)
330
+ k: Number of results to retrieve
331
+
332
+ Returns:
333
+ Answer string
334
+ """
335
+ result = self._query_internal(question, filters, k)
336
+ return result['answer']
337
+
338
+ def query_with_metadata(
339
+ self,
340
+ question: str,
341
+ filters: Optional[Dict[str, Any]] = None,
342
+ k: int = 20
343
+ ) -> Dict[str, Any]:
344
+ """
345
+ Query with full metadata about retrieval.
346
+
347
+ Returns:
348
+ Dict with 'answer', 'source_questions', 'num_sources', 'filters_applied'
349
+ """
350
+ return self._query_internal(question, filters, k)
351
+
352
    def _query_internal(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None,
        k: int = 20
    ) -> Dict[str, Any]:
        """Retrieve, validate, and synthesize an answer (shared by query()
        and query_with_metadata()).

        Pipeline: build Pinecone filter → vector retrieval → client-side
        validation → rebuild full question records → format context →
        one LLM call. Returns a dict with 'answer', 'source_questions',
        'num_sources', 'filters_applied'.
        """

        if self.verbose:
            print(f"\n📊 Query: {question}")
            if filters:
                print(f"🔍 Filters: {filters}")

        # Build Pinecone filter (None when nothing usable)
        pinecone_filter = self._build_pinecone_filter(filters or {})

        # Retrieve documents
        if pinecone_filter:
            if self.verbose:
                print(f"🔧 Pinecone filter: {pinecone_filter}")
            retriever = self.vectorstore.as_retriever(
                search_kwargs={"k": k, "filter": pinecone_filter}
            )
        else:
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})

        docs = retriever.invoke(question)

        if self.verbose:
            print(f"📥 Retrieved {len(docs)} documents from Pinecone")

        # Validate results match filters (client-side safety net)
        if filters:
            docs = self._validate_results(docs, filters)
            if self.verbose:
                print(f"✅ After validation: {len(docs)} documents")

        # Check if we have any results — answer without an LLM call if not
        if not docs:
            no_data_msg = f"No questionnaire data found"
            if filters:
                # NOTE(review): the comprehension's `k` shadows the `k`
                # parameter; harmless (k is no longer used) but worth renaming.
                filter_desc = ", ".join([f"{k}={v}" for k, v in filters.items()])
                no_data_msg += f" matching filters: {filter_desc}"

            return {
                "answer": no_data_msg,
                "source_questions": [],
                "num_sources": 0,
                "filters_applied": filters or {}
            }

        # Reconstruct full questions from the local index, deduplicating
        # chunks that share a question_id
        full_questions = []
        seen_ids = set()

        for doc in docs:
            q_id = doc.metadata.get('question_id')
            if q_id and q_id not in seen_ids:
                if q_id in self.questions_by_id:
                    full_questions.append(self.questions_by_id[q_id])
                    seen_ids.add(q_id)

        # Sort by position to maintain survey order
        full_questions.sort(key=lambda q: (q.get('poll_date', ''), q.get('position', 0)))

        # Format context with explicit data availability info
        context = self._format_context(full_questions, filters)

        # Get prompt
        prompt = self._get_prompt()

        # Create chain — the lambdas ignore the chain input and close over
        # the precomputed context/question/catalog values
        chain = (
            {
                "context": lambda x: context,
                "question": lambda x: question,
                "catalog": lambda x: self._get_catalog_summary()
            }
            | prompt
            | self.llm
            | StrOutputParser()
        )

        # Get answer
        answer = chain.invoke(question)

        return {
            'answer': answer,
            'source_questions': full_questions,
            'num_sources': len(full_questions),
            'filters_applied': filters or {}
        }
444
+
445
    def _format_context(
        self,
        questions: List[Dict],
        filters: Optional[Dict[str, Any]] = None
    ) -> str:
        """Format questions as LLM context with explicit data-availability
        banners (so the model can say what is and is not present).

        Args:
            questions: Full question records (from questions_index.json).
            filters: The filters that produced them, echoed in the header.

        Returns:
            A single string: availability header, separator, then one
            formatted section per question.
        """

        # No data: return an instruction, not an empty string, so the LLM
        # explicitly tells the user nothing matched.
        if not questions:
            filter_desc = ""
            if filters:
                filter_desc = f" matching {filters}"
            return f"⚠️ NO DATA RETRIEVED{filter_desc}\n\nYou must inform the user that no data exists for their query."

        context_parts = []

        # Add explicit note about what data we have
        polls_found = sorted(set(q['poll_date'] for q in questions))
        context_parts.append(f"✅ DATA AVAILABLE FOR: {', '.join(polls_found)}")

        # Add note about what was requested vs what was found
        if filters:
            if 'year' in filters and 'month' in filters:
                requested = f"{filters['month']} {filters['year']}"
                context_parts.append(f"🔍 REQUESTED: {requested}")

        context_parts.append("")  # Blank line
        context_parts.append("=" * 80)
        context_parts.append("")

        # Format each question
        for i, q in enumerate(questions, 1):
            part = f"""
--- Question {i} from {q['survey_name']} ({q['poll_date']}) ---
Variable: {q['variable_name']}
Question: {q['question_text']}
Response Options: {' | '.join(q['response_options'])}
Topics: {', '.join(q['topics'])}
Question Type: {q['question_type']}
Administration: {q['ask_condition']}
"""

            # Add skip logic/sampling
            if q.get('skip_logic'):
                part += f"Skip Logic: {q['skip_logic']}\n"

            if q.get('half_sample_group'):
                part += f"Half Sample Group: {q['half_sample_group']}\n"

            # Add sibling variants (alternate wordings shown to other groups)
            if q.get('sibling_variants'):
                part += f"\nAlternate Versions (shown to different groups):\n"
                for sib in q['sibling_variants']:
                    sib_group = sib.get('half_sample_group', 'other group')
                    part += f" - [{sib_group}] {sib['question_text']}\n"

            # Add sequence context (neighboring questions in the survey)
            if q.get('previous_question'):
                prev_vars = q.get('previous_question_variants', [])
                if len(prev_vars) > 1:
                    part += "\nPrevious Question (respondents saw one of these):\n"
                    for pv in prev_vars:
                        part += f" - {pv['question_text']}\n"
                else:
                    part += f"\nPrevious Question: {q['previous_question']['question_text']}\n"

            if q.get('next_question'):
                next_vars = q.get('next_question_variants', [])
                if len(next_vars) > 1:
                    part += "\nNext Question (respondents saw one of these):\n"
                    for nv in next_vars:
                        part += f" - {nv['question_text']}\n"
                else:
                    part += f"\nNext Question: {q['next_question']['question_text']}\n"

            context_parts.append(part.strip())

        return "\n\n".join(context_parts)
522
+
523
+ def _get_catalog_summary(self) -> str:
524
+ """Get summary of available polls"""
525
+ lines = ["Available polls:"]
526
+ for poll_date in sorted(self.poll_catalog.keys()):
527
+ info = self.poll_catalog[poll_date]
528
+ month_str = f" ({info['month']})" if info.get('month') else ""
529
+ lines.append(f"- {poll_date}{month_str}: {info['num_questions']} questions")
530
+ return "\n".join(lines)
531
+
532
+ def get_available_polls(self) -> List[Dict[str, Any]]:
533
+ """Get list of all available polls"""
534
+ return [
535
+ {
536
+ "poll_date": poll_date,
537
+ "survey_name": info["survey_name"],
538
+ "year": info["year"],
539
+ "month": info.get("month", ""),
540
+ "num_questions": info["num_questions"]
541
+ }
542
+ for poll_date, info in sorted(self.poll_catalog.items())
543
+ ]
544
+
545
+
546
def main():
    """Smoke-test CLI: exercises fuzzy matching and one end-to-end query.

    Requires OPENAI_API_KEY and PINECONE_API_KEY in the environment and a
    populated ./questionnaire_vectorstores directory; exits with status 1
    when keys are missing.
    """
    import sys

    openai_api_key = os.getenv("OPENAI_API_KEY")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")

    if not openai_api_key or not pinecone_api_key:
        print("Error: Missing API keys")
        sys.exit(1)

    rag = QuestionnaireRAG(
        openai_api_key=openai_api_key,
        pinecone_api_key=pinecone_api_key,
        verbose=True
    )

    print("\n" + "="*80)
    print("QUESTIONNAIRE RAG - TEST MODE")
    print("="*80)

    # Test fuzzy matching against several spellings of the same survey
    print("\n🧪 TEST: Fuzzy survey name matching")
    test_names = ["Unity Poll", "unity poll", "Vanderbilt Unity", "UNITY"]
    for name in test_names:
        matched = rag._fuzzy_match_survey_name(name)
        print(f" '{name}' → '{matched}'")

    # Test with the problematic query (regression check for filter mapping)
    print("\n🧪 TEST: Query that previously failed")
    print("Query: What questions were asked in the June 2025 Unity Poll?")

    filters = {"year": 2025, "month": "June", "survey_name": "Unity Poll"}
    result = rag.query_with_metadata(
        "What questions were asked in the June 2025 Unity Poll?",
        filters=filters
    )

    print(f"\n📊 Results:")
    print(f"Found: {result['num_sources']} questions")
    print(f"\n{result['answer'][:500]}...")

    print("\n" + "="*80)


if __name__ == "__main__":
    main()
questionnaire_vectorstores/poll_catalog.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "2023-06": {
3
+ "survey_name": "Vanderbilt_Unity_Poll",
4
+ "year": 2023,
5
+ "month": "June",
6
+ "poll_date": "2023-06",
7
+ "num_questions": 15,
8
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_June_questions.json"
9
+ },
10
+ "2023-03": {
11
+ "survey_name": "Vanderbilt_Unity_Poll",
12
+ "year": 2023,
13
+ "month": "March",
14
+ "poll_date": "2023-03",
15
+ "num_questions": 8,
16
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_March_questions.json"
17
+ },
18
+ "2023-09": {
19
+ "survey_name": "Vanderbilt_Unity_Poll",
20
+ "year": 2023,
21
+ "month": "September",
22
+ "poll_date": "2023-09",
23
+ "num_questions": 15,
24
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_September_questions.json"
25
+ },
26
+ "2024-06": {
27
+ "survey_name": "Vanderbilt_Unity_Poll",
28
+ "year": 2024,
29
+ "month": "June",
30
+ "poll_date": "2024-06",
31
+ "num_questions": 5,
32
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_June_questions.json"
33
+ },
34
+ "2024-03": {
35
+ "survey_name": "Vanderbilt_Unity_Poll",
36
+ "year": 2024,
37
+ "month": "March",
38
+ "poll_date": "2024-03",
39
+ "num_questions": 13,
40
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_March_questions.json"
41
+ },
42
+ "2024-10": {
43
+ "survey_name": "Vanderbilt_Unity_Poll",
44
+ "year": 2024,
45
+ "month": "October",
46
+ "poll_date": "2024-10",
47
+ "num_questions": 14,
48
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_October_questions.json"
49
+ },
50
+ "2024-09": {
51
+ "survey_name": "Vanderbilt_Unity_Poll",
52
+ "year": 2024,
53
+ "month": "September",
54
+ "poll_date": "2024-09",
55
+ "num_questions": 15,
56
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_September_questions.json"
57
+ },
58
+ "2025-02": {
59
+ "survey_name": "Vanderbilt_Unity_Poll",
60
+ "year": 2025,
61
+ "month": "February",
62
+ "poll_date": "2025-02",
63
+ "num_questions": 17,
64
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2025_February_questions.json"
65
+ },
66
+ "2025-06": {
67
+ "survey_name": "Vanderbilt_Unity_Poll",
68
+ "year": 2025,
69
+ "month": "June",
70
+ "poll_date": "2025-06",
71
+ "num_questions": 23,
72
+ "file": "questionnaire_data/Vanderbilt_Unity_Poll_2025_June_questions.json"
73
+ }
74
+ }
questionnaire_vectorstores/questions_index.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ langchain>=0.1.0
3
+ langchain-openai>=0.0.5
4
+ langchain-pinecone>=0.0.3
5
+ langgraph>=0.0.20
6
+ openai>=1.0.0
7
+ pinecone
8
+ python-dotenv>=1.0.0
9
+ pydantic>=2.0.0
survey_agent.py ADDED
@@ -0,0 +1,1175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-agent survey analysis system using LangGraph with Staged Research Briefs.
3
+
4
+ This orchestrates multiple data sources (questionnaires, toplines, crosstabs, SQL)
5
+ to answer complex survey research questions using sequential, adaptive research stages.
6
+
7
+ # TODO: REMOVE WHEN PIPELINES READY
8
+ When new pipelines (toplines, crosstabs, SQL) become available:
9
+ 1. Add pipeline name to SurveyAnalysisAgent.AVAILABLE_PIPELINES (line ~105)
10
+ 2. Add execution logic in _execute_stage() method (around line ~450)
11
+ 3. Search for "TODO: REMOVE WHEN PIPELINES READY" and remove those sections
12
+ 4. Update examples to include the new pipeline capabilities
13
+
14
+ Current Status:
15
+ - ✅ Questionnaire pipeline: ACTIVE
16
+ - ⏳ Toplines pipeline: Not yet implemented
17
+ - ⏳ Crosstabs pipeline: Not yet implemented
18
+ - ⏳ SQL pipeline: Not yet implemented
19
+ """
20
+
21
+ import os
22
+ import json
23
+ from typing import TypedDict, Literal, Annotated, List, Dict, Any, Optional, Union
24
+ from pathlib import Path
25
+ import operator
26
+
27
+ from langgraph.graph import StateGraph, START, END
28
+ from langgraph.checkpoint.memory import MemorySaver
29
+ from langchain_openai import ChatOpenAI
30
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
31
+ from pydantic import BaseModel, Field, ConfigDict
32
+
33
+ # Import the questionnaire RAG
34
+ from questionnaire_rag import QuestionnaireRAG
35
+
36
+ try:
37
+ from dotenv import load_dotenv
38
+ load_dotenv()
39
+ except ImportError:
40
+ pass
41
+
42
+
43
+ # ============================================================================
44
+ # STATE DEFINITIONS (PYDANTIC V2) - WITH STAGED RESEARCH
45
+ # ============================================================================
46
+
47
class QueryFilters(BaseModel):
    """Filters for data source queries - Pydantic v2 with strict schema.

    All fields default to None, meaning "do not filter on this dimension".
    The Field descriptions are surfaced to the planner LLM through structured
    output, so they double as instructions/examples for the model.
    """
    # extra="forbid" yields a strict JSON schema (required by OpenAI structured
    # outputs) and rejects hallucinated filter keys from the LLM.
    model_config = ConfigDict(extra="forbid")

    year: Optional[int] = Field(default=None, description="Year filter (e.g., 2025)")
    month: Optional[str] = Field(default=None, description="Month filter (e.g., 'February')")
    poll_date: Optional[str] = Field(default=None, description="Specific poll date (e.g., '2025-02-15')")
    survey_name: Optional[str] = Field(default=None, description="Survey name filter (e.g., 'Unity Poll')")
    topic: Optional[str] = Field(default=None, description="Topic filter")
    # Populated by later research stages with IDs extracted from earlier stages
    # (see _enrich_data_sources_with_context).
    question_ids: Optional[List[str]] = Field(default=None, description="Specific question IDs from previous stage")
57
+
58
+
59
class DataSource(BaseModel):
    """Represents a data source to query.

    One DataSource describes a single retrieval request against one pipeline;
    a research stage may contain several of these.
    """
    # Strict schema for OpenAI structured output.
    model_config = ConfigDict(extra="forbid")

    # Which pipeline services this request. Only "questionnaire" is currently
    # implemented (see SurveyAnalysisAgent.AVAILABLE_PIPELINES).
    source_type: Literal["questionnaire", "toplines", "crosstabs", "sql"]
    query_description: str = Field(description="What to retrieve from this source")
    filters: QueryFilters = Field(default_factory=QueryFilters, description="Filters to apply")
    # Lets staged research keep results from different stages separable
    # (e.g. "2024_questions" vs "2025_questions") at synthesis time.
    result_label: Optional[str] = Field(default=None, description="Label for these results (e.g., '2024_questions')")
67
+
68
+
69
class ResearchStage(BaseModel):
    """A single stage in a multi-stage research plan.

    Stages execute sequentially. A stage may consume context (e.g. question
    IDs) extracted from earlier stages via use_previous_results_for.
    """
    # Strict schema for OpenAI structured output.
    model_config = ConfigDict(extra="forbid")

    # 1-indexed for the LLM/humans; the executor tracks a 0-indexed cursor.
    stage_number: int = Field(description="Stage number (1-indexed)")
    description: str = Field(description="What this stage accomplishes")
    data_sources: List[DataSource] = Field(description="Data sources to query in this stage")
    depends_on_stages: List[int] = Field(default_factory=list, description="Which prior stages this depends on")
    # Free-text instruction interpreted heuristically by
    # _enrich_data_sources_with_context (currently only question-ID extraction).
    use_previous_results_for: Optional[str] = Field(
        default=None,
        description="How to use previous stage results (e.g., 'Extract question IDs from stage 1')"
    )
81
+
82
+
83
class ResearchBrief(BaseModel):
    """Research brief - can be either single-stage or multi-stage.

    Produced by the planner LLM via structured output. The action value
    drives graph routing:
      - "answer"/"followup": no data retrieval
      - "route_to_sources": simple one-shot retrieval using data_sources
      - "execute_stages": sequential research using stages
    """
    # Strict schema for OpenAI structured output.
    model_config = ConfigDict(extra="forbid")

    action: Literal["answer", "followup", "route_to_sources", "execute_stages"]
    followup_question: Optional[str] = Field(default=None, description="Follow-up question to ask user")
    reasoning: str = Field(description="Why this approach was chosen")

    # For simple queries (single-stage); empty unless action == "route_to_sources"
    data_sources: List[DataSource] = Field(default_factory=list, description="Data sources for simple queries")

    # For complex queries (multi-stage); empty unless action == "execute_stages"
    stages: List[ResearchStage] = Field(default_factory=list, description="Ordered stages of research")
96
+
97
+
98
class StageResult(BaseModel):
    """Results from executing one stage.

    One *_results field per pipeline; a field stays None when that pipeline
    was not queried in this stage (or is not yet implemented).
    """
    # Strict schema (kept consistent with the other models).
    model_config = ConfigDict(extra="forbid")

    # 1-indexed, matches ResearchStage.stage_number.
    stage_number: int
    # "partial" is set when some requested pipelines were unavailable.
    status: Literal["success", "partial", "failed"]
    questionnaire_results: Optional[Dict[str, Any]] = None
    toplines_results: Optional[Dict[str, Any]] = None
    crosstabs_results: Optional[Dict[str, Any]] = None
    sql_results: Optional[Dict[str, Any]] = None
    # Filled in by _extract_stage_context after the stage runs (e.g. question
    # IDs, and "unavailable_pipelines" noted by _execute_stage).
    extracted_context: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Key information extracted for next stages (e.g., question IDs)"
    )
112
+
113
+
114
class VerificationResult(BaseModel):
    """Result of verifying if data answers the question.

    When answers_question is False, missing_info/improvement_suggestion are
    fed back into the next research-brief attempt (retry loop).
    """
    # Strict schema for OpenAI structured output.
    model_config = ConfigDict(extra="forbid")

    answers_question: bool = Field(description="Whether the data fully answers the question")
    missing_info: Optional[str] = Field(default=None, description="What information is missing")
    improvement_suggestion: Optional[str] = Field(default=None, description="How to improve the research brief")
121
+
122
+
123
class SurveyAnalysisState(TypedDict):
    """State for the survey analysis agent - WITH STAGED RESEARCH"""
    # User interaction.
    # operator.add makes LangGraph append each node's returned messages
    # instead of replacing the list.
    messages: Annotated[List, operator.add]
    user_question: str

    # Planning: the brief produced by _generate_research_brief.
    research_brief: Optional[ResearchBrief]

    # Stage execution
    current_stage: int  # Which stage we're executing (0-indexed internally, but 1-indexed in models)
    stage_results: List[StageResult]  # Results from each completed stage

    # Legacy single-stage results (for backward compatibility with callers
    # that read these top-level keys instead of stage_results).
    questionnaire_results: Optional[Dict[str, Any]]
    toplines_results: Optional[Dict[str, Any]]
    crosstabs_results: Optional[Dict[str, Any]]
    sql_results: Optional[Dict[str, Any]]

    # Verification & synthesis
    verification: Optional[VerificationResult]
    final_answer: Optional[str]

    # Control flow: retry_count vs max_retries bounds the re-planning loop.
    retry_count: int
    max_retries: int
149
+
150
+
151
+ # ============================================================================
152
+ # SURVEY ANALYSIS ORCHESTRATOR - WITH STAGED RESEARCH
153
+ # ============================================================================
154
+
155
+ class SurveyAnalysisAgent:
156
+ """
157
+ Multi-agent system for analyzing survey data with staged research briefs.
158
+
159
+ Flow:
160
+ 1. User asks question
161
+ 2. Research brief agent decides: simple (one-shot) or complex (staged)
162
+ 3. For simple: run pipelines in parallel → verify → synthesize
163
+ 4. For complex: execute stages sequentially, each using previous results
164
+ 5. Final synthesis combines all stage results
165
+ """
166
+
167
    # TODO: REMOVE WHEN PIPELINES READY - START
    # Track which pipelines are currently available. Values must match
    # DataSource.source_type literals; _execute_stage checks membership here
    # before dispatching a request.
    AVAILABLE_PIPELINES = {"questionnaire"}  # Add "toplines", "crosstabs", "sql" as they become ready
    # TODO: REMOVE WHEN PIPELINES READY - END
171
+
172
+ def __init__(
173
+ self,
174
+ openai_api_key: str,
175
+ pinecone_api_key: str,
176
+ questionnaire_persist_dir: str = "./questionnaire_vectorstores",
177
+ max_retries: int = 2,
178
+ verbose: bool = True
179
+ ):
180
+ self.openai_api_key = openai_api_key
181
+ self.pinecone_api_key = pinecone_api_key
182
+ self.verbose = verbose
183
+ self.max_retries = max_retries
184
+
185
+ # Initialize LLM
186
+ self.llm = ChatOpenAI(
187
+ model=os.getenv("OPENAI_MODEL", "gpt-4o"),
188
+ temperature=0
189
+ )
190
+
191
+ # Initialize questionnaire RAG
192
+ if self.verbose:
193
+ print("Initializing questionnaire RAG system...")
194
+ self.questionnaire_rag = QuestionnaireRAG(
195
+ openai_api_key=openai_api_key,
196
+ pinecone_api_key=pinecone_api_key,
197
+ persist_directory=questionnaire_persist_dir,
198
+ verbose=verbose
199
+ )
200
+
201
+ # Build the graph
202
+ self.graph = self._build_graph()
203
+
204
+ if self.verbose:
205
+ print("✓ Survey analysis agent initialized with staged research capability")
206
+
207
    def _build_graph(self) -> StateGraph:
        """Build the LangGraph workflow with staged research support.

        Topology:
            START -> generate_research_brief
                  -> (followup: END | answer: synthesize | else: execute_stage)
            execute_stage -> extract_stage_context -> (loop back or verify)
            verify_results -> (synthesize | retry planning | give up)
            synthesize_response -> END

        NOTE(review): the annotation says StateGraph but workflow.compile()
        returns a compiled graph object — confirm the intended return type.
        """

        workflow = StateGraph(SurveyAnalysisState)

        # Add nodes (verify/synthesize/route-after-verification are defined
        # further down in this class).
        workflow.add_node("generate_research_brief", self._generate_research_brief)
        workflow.add_node("execute_stage", self._execute_stage)
        workflow.add_node("extract_stage_context", self._extract_stage_context)
        workflow.add_node("verify_results", self._verify_results)
        workflow.add_node("synthesize_response", self._synthesize_response)

        # Define edges
        workflow.add_edge(START, "generate_research_brief")

        # After research brief, route based on action
        workflow.add_conditional_edges(
            "generate_research_brief",
            self._route_after_brief,
            {
                "followup": END,
                "answer": "synthesize_response",
                "execute_stage": "execute_stage"
            }
        )

        # After stage execution, extract context for next stage
        workflow.add_edge("execute_stage", "extract_stage_context")

        # After context extraction, decide next step
        workflow.add_conditional_edges(
            "extract_stage_context",
            self._route_after_stage,
            {
                "next_stage": "execute_stage",  # More stages to go
                "verify": "verify_results"  # All stages done, verify
            }
        )

        # After verification, decide next step
        workflow.add_conditional_edges(
            "verify_results",
            self._route_after_verification,
            {
                "synthesize": "synthesize_response",
                "retry": "generate_research_brief",
                "give_up": "synthesize_response"
            }
        )

        # End after synthesis
        workflow.add_edge("synthesize_response", END)

        # Compile with in-memory checkpointing so conversations are resumable
        # by thread id.
        memory = MemorySaver()
        return workflow.compile(checkpointer=memory)
263
+
264
+ def _get_available_surveys_description(self) -> str:
265
+ """Get formatted description of available surveys for LLM prompt"""
266
+ survey_names = self.questionnaire_rag.get_available_survey_names()
267
+
268
+ if not survey_names:
269
+ return "No surveys currently loaded."
270
+
271
+ lines = ["Available survey names in the system:"]
272
+ for name in survey_names:
273
+ # Show both the stored name and common variations
274
+ lines.append(f" - Stored as: '{name}'")
275
+ # Parse variations
276
+ variations = []
277
+ # Remove underscores for common term
278
+ clean = name.replace("_", " ")
279
+ if clean != name:
280
+ variations.append(f"'{clean}'")
281
+ # Extract key words
282
+ words = clean.split()
283
+ if len(words) > 1:
284
+ # Last few words might be the short name
285
+ short_name = " ".join(words[-2:]) if len(words) >= 2 else words[-1]
286
+ if short_name != clean:
287
+ variations.append(f"'{short_name}'")
288
+
289
+ if variations:
290
+ lines.append(f" (users might say: {', '.join(variations)})")
291
+
292
+ lines.append("\nIMPORTANT: Use the exact stored name in your filters!")
293
+ return "\n".join(lines)
294
+
295
+ # TODO: REMOVE WHEN PIPELINES READY - START
296
+ def _get_pipeline_status_description(self) -> str:
297
+ """Get description of available vs unavailable pipelines"""
298
+ all_pipelines = {
299
+ "questionnaire": "Survey questions, response options, topics, skip logic, sampling",
300
+ "toplines": "Pre-computed response frequencies for each question",
301
+ "crosstabs": "Pre-computed cross-tabulations by demographics",
302
+ "sql": "Raw survey responses for custom analysis"
303
+ }
304
+
305
+ lines = []
306
+ for pipeline, description in all_pipelines.items():
307
+ status = "✅ AVAILABLE" if pipeline in self.AVAILABLE_PIPELINES else "❌ NOT YET AVAILABLE"
308
+ lines.append(f"{pipeline.capitalize()}: {description} {status}")
309
+
310
+ return "\n".join(lines)
311
+ # TODO: REMOVE WHEN PIPELINES READY - END
312
+
313
+ def _get_full_question_context(self, state: SurveyAnalysisState) -> str:
314
+ """
315
+ Build full question context from conversation history.
316
+
317
+ This handles cases where the user's question is split across multiple turns:
318
+ - Turn 1: "what questions were asked?"
319
+ - Turn 2: "June 2025, unity poll"
320
+
321
+ We need to combine these to understand the full intent.
322
+ """
323
+ messages = state.get("messages", [])
324
+
325
+ # Extract all human messages (excluding system/AI messages)
326
+ human_messages = []
327
+ for msg in messages:
328
+ if isinstance(msg, HumanMessage):
329
+ human_messages.append(msg.content)
330
+
331
+ if not human_messages:
332
+ return state["user_question"]
333
+
334
+ if self.verbose:
335
+ print(f"📝 Conversation history: {len(human_messages)} user message(s)")
336
+ for i, msg in enumerate(human_messages, 1):
337
+ print(f" {i}. {msg[:100]}..." if len(msg) > 100 else f" {i}. {msg}")
338
+
339
+ # If there's only one message, just use it
340
+ if len(human_messages) == 1:
341
+ return human_messages[0]
342
+
343
+ # Multiple messages - combine them intelligently
344
+ # The last message is usually the most specific (e.g., "June 2025, unity poll")
345
+ # Earlier messages provide the intent (e.g., "what questions were asked?")
346
+
347
+ # Check if the first message is a question and the second is a clarification
348
+ first_msg = human_messages[0].lower()
349
+ is_followup_scenario = any(word in first_msg for word in ["what", "which", "how", "show", "list", "tell"])
350
+
351
+ if is_followup_scenario and len(human_messages) == 2:
352
+ # Combine: "what questions were asked? [from] June 2025, unity poll"
353
+ combined = f"{human_messages[0]} (specifically: {human_messages[1]})"
354
+ if self.verbose:
355
+ print(f"🔗 Combined context: {combined}")
356
+ return combined
357
+
358
+ # For other cases, join all messages
359
+ combined = " | ".join(human_messages)
360
+ if self.verbose:
361
+ print(f"🔗 Combined context: {combined}")
362
+ return combined
363
+
364
+
365
+ # ========================================================================
366
+ # NODE FUNCTIONS
367
+ # ========================================================================
368
+
369
    def _generate_research_brief(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Generate research brief - decides single-stage vs multi-stage approach.

        Calls the planner LLM with structured output (ResearchBrief). On a
        retry, the previous verification feedback is folded into the prompt.
        Resets current_stage/stage_results so a retry starts from scratch.
        """

        if self.verbose:
            print("\n=== GENERATING RESEARCH BRIEF ===")

        # Get full question context from conversation history
        question = self._get_full_question_context(state)
        original_question = state["user_question"]  # Keep original for reference

        if self.verbose and question != original_question:
            print(f"💬 Using full context from conversation history")

        retry_count = state.get("retry_count", 0)

        # Add context from verification if this is a retry
        verification_context = ""
        if state.get("verification") and retry_count > 0:
            verification_context = f"""
Previous attempt was insufficient:
- Missing: {state['verification'].missing_info}
- Suggestion: {state['verification'].improvement_suggestion}

Please improve the research plan based on this feedback.
"""

        # NOTE: the "# TODO" markers inside this f-string are prompt content
        # for humans editing the file, not Python comments.
        system_prompt = f"""You are a research planning expert for survey data analysis.

# TODO: REMOVE WHEN PIPELINES READY - Use dynamic status
Available data sources:
{self._get_pipeline_status_description()}

# TODO: REMOVE WHEN PIPELINES READY - START
⚠️ IMPORTANT: Currently ONLY the questionnaire pipeline is available.
- Do NOT create research plans that require toplines, crosstabs, or SQL
- If the user asks for results/data/analysis that requires those sources, use action="followup" to inform them
- Focus on what CAN be answered with questionnaires alone (question text, response options, topics, skip logic)
# TODO: REMOVE WHEN PIPELINES READY - END

{self._get_available_surveys_description()}

You have FOUR possible actions:

**1. followup** - Ask clarifying question if ambiguous OR if user asks for unavailable data

**2. answer** - Answer directly without data (system questions, general knowledge)

**3. route_to_sources** - Simple query that can be answered with parallel data retrieval
Use this for:
- "What questions were asked in June 2025?"
- "Show me all healthcare questions"
- Questions that don't require sequential reasoning

**4. execute_stages** - Complex query requiring STAGED research
Use this for:
- Queries with "most/least/best/worst" (need stage 1: retrieve, stage 2: analyze)
- Comparative queries "compare 2024 vs 2025" (need separate stages to maintain context)
- Queries depending on intermediate results
- "What demographics differ most?" (stage 1: get questions, stage 2: get crosstabs for those questions)

# TODO: REMOVE WHEN PIPELINES READY - START
NOTE: Since toplines/crosstabs/SQL aren't available, only use execute_stages for comparing questionnaires
# TODO: REMOVE WHEN PIPELINES READY - END

When using stages:
- Each stage can use results from previous stages via `use_previous_results_for`
- Later stages can filter by question_ids extracted from earlier stages
- Each stage can have a `result_label` to maintain separate contexts

CRITICAL FILTERING RULES:
- **Survey Names**: User queries like "Unity Poll" or "Vanderbilt Unity Poll" should map to the exact stored name shown above
- When you see "Unity Poll" in a query, use the exact stored name in your filter
- Only specify filters if explicitly mentioned or clearly implied
- For staged queries, be explicit about how each stage uses previous results
- Use `question_ids` filter when later stages need specific questions from earlier stages
- Year and month are usually sufficient - survey_name is optional unless needed for disambiguation

{verification_context}

Examples:

# TODO: REMOVE WHEN PIPELINES READY - START
User asks for results/analysis → Inform them:
Q: "What were the topline results for June 2025?"
Brief:
  action: followup
  followup_question: "I can show you the questions asked in June 2025, but topline results aren't available yet. Would you like to see the questions?"
# TODO: REMOVE WHEN PIPELINES READY - END

User says "Unity Poll" → Use stored name in filter:
Q: "What questions were asked in June 2025 Unity Poll?"
Brief:
  action: route_to_sources
  data_sources: [questionnaire with year=2025, month=June, survey_name='Vanderbilt_Unity_Poll']

Simple Query → route_to_sources:
Q: "What questions were asked in June 2025?"
Brief:
  action: route_to_sources
  data_sources: [questionnaire with June 2025 filters]

Complex Query → execute_stages:
Q: "Compare immigration questions from 2024 vs 2025"
Brief:
  action: execute_stages
  stages:
    - stage 1: Get 2024 immigration questions (label: "2024_questions")
    - stage 2: Get 2025 immigration questions (label: "2025_questions")
    - stage 3: Compare the two sets in synthesis
"""

        # Structured output guarantees a parseable ResearchBrief instance.
        brief_generator = self.llm.with_structured_output(ResearchBrief)

        brief = brief_generator.invoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"User question: {question}\n\nGenerate a research brief.")
        ])

        # Debug dump of the chosen plan.
        if self.verbose:
            print(f"Action: {brief.action}")
            print(f"Reasoning: {brief.reasoning}")

            if brief.followup_question:
                print(f"Follow-up: {brief.followup_question}")

            if brief.action == "route_to_sources" and brief.data_sources:
                print(f"Simple query - {len(brief.data_sources)} data sources")
                for ds in brief.data_sources:
                    # Only show filters the LLM actually set.
                    filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}
                    print(f"  - {ds.source_type}: {ds.query_description}")
                    if filters_dict:
                        print(f"    Filters: {filters_dict}")

            if brief.action == "execute_stages" and brief.stages:
                print(f"Staged query - {len(brief.stages)} stages")
                for stage in brief.stages:
                    print(f"\nStage {stage.stage_number}: {stage.description}")
                    if stage.depends_on_stages:
                        print(f"  Depends on: stages {stage.depends_on_stages}")
                    if stage.use_previous_results_for:
                        print(f"  Uses previous: {stage.use_previous_results_for}")
                    for ds in stage.data_sources:
                        print(f"  - {ds.source_type}: {ds.query_description}")
                        if ds.result_label:
                            print(f"    Label: {ds.result_label}")

        # Reset the stage cursor and results so retries start fresh.
        return {
            "research_brief": brief,
            "current_stage": 0,  # Start at stage 0 (will execute stage 1 first)
            "stage_results": [],
            "messages": [AIMessage(content=f"[Research plan: {brief.action}]")]
        }
521
+
522
+ def _route_after_brief(self, state: SurveyAnalysisState) -> str:
523
+ """Route based on research brief action"""
524
+ brief = state["research_brief"]
525
+
526
+ if brief.action == "followup":
527
+ return "followup"
528
+ elif brief.action == "answer":
529
+ return "answer"
530
+ elif brief.action == "execute_stages":
531
+ return "execute_stage"
532
+ else: # route_to_sources
533
+ return "execute_stage" # We'll handle both single and staged in execute_stage
534
+
535
+ def _execute_stage(self, state: SurveyAnalysisState) -> Dict[str, Any]:
536
+ """Execute one stage of research (handles both single-stage and multi-stage)"""
537
+
538
+ brief = state["research_brief"]
539
+ current_stage_idx = state.get("current_stage", 0)
540
+ previous_stage_results = state.get("stage_results", [])
541
+
542
+ # Determine if this is single-stage or multi-stage
543
+ if brief.action == "route_to_sources":
544
+ # Single-stage: use data_sources directly
545
+ if self.verbose:
546
+ print(f"\n=== EXECUTING SINGLE-STAGE RESEARCH ===")
547
+
548
+ stage_data_sources = brief.data_sources
549
+ stage_desc = "Single-stage retrieval"
550
+
551
+ elif brief.action == "execute_stages":
552
+ # Multi-stage: get current stage
553
+ stage = brief.stages[current_stage_idx]
554
+
555
+ if self.verbose:
556
+ print(f"\n=== EXECUTING STAGE {stage.stage_number}/{len(brief.stages)} ===")
557
+ print(f"Description: {stage.description}")
558
+
559
+ stage_data_sources = stage.data_sources
560
+ stage_desc = stage.description
561
+
562
+ # If this stage depends on previous stages, enrich filters with context
563
+ if stage.use_previous_results_for and previous_stage_results:
564
+ stage_data_sources = self._enrich_data_sources_with_context(
565
+ stage_data_sources,
566
+ previous_stage_results,
567
+ stage.use_previous_results_for
568
+ )
569
+ else:
570
+ return {}
571
+
572
+ # Execute pipelines for this stage
573
+ stage_result = StageResult(
574
+ stage_number=current_stage_idx + 1,
575
+ status="success"
576
+ )
577
+
578
+ # TODO: REMOVE WHEN PIPELINES READY - Track what was attempted vs available
579
+ attempted_pipelines = []
580
+ unavailable_pipelines = []
581
+
582
+ # Run each pipeline
583
+ for ds in stage_data_sources:
584
+ filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}
585
+
586
+ # TODO: REMOVE WHEN PIPELINES READY - START
587
+ attempted_pipelines.append(ds.source_type)
588
+ # TODO: REMOVE WHEN PIPELINES READY - END
589
+
590
+ if ds.source_type == "questionnaire":
591
+ if self.verbose:
592
+ print(f"\nQuerying questionnaire: {ds.query_description}")
593
+ if filters_dict:
594
+ print(f"Filters: {filters_dict}")
595
+
596
+ result = self.questionnaire_rag.query_with_metadata(
597
+ question=ds.query_description,
598
+ filters=filters_dict if filters_dict else None
599
+ )
600
+
601
+ # Store with label if provided
602
+ if ds.result_label:
603
+ result["label"] = ds.result_label
604
+
605
+ stage_result.questionnaire_results = result if stage_result.questionnaire_results is None else {
606
+ "multiple": True,
607
+ "results": [stage_result.questionnaire_results, result]
608
+ }
609
+
610
+ if self.verbose:
611
+ print(f"Retrieved {result['num_sources']} questions")
612
+
613
+ # TODO: REMOVE WHEN PIPELINES READY - START
614
+ elif ds.source_type not in self.AVAILABLE_PIPELINES:
615
+ unavailable_pipelines.append(ds.source_type)
616
+ if self.verbose:
617
+ print(f"\n⚠️ {ds.source_type.upper()} pipeline not yet available - skipping")
618
+ print(f" Requested: {ds.query_description}")
619
+ # TODO: REMOVE WHEN PIPELINES READY - END
620
+
621
+ # TODO: REMOVE WHEN PIPELINES READY - START
622
+ # Add a note about unavailable pipelines to the stage result
623
+ if unavailable_pipelines:
624
+ if self.verbose:
625
+ print(f"\n⚠️ Stage {current_stage_idx + 1} incomplete: {len(unavailable_pipelines)} pipeline(s) unavailable")
626
+ stage_result.status = "partial"
627
+ # Store info about what was unavailable for the synthesizer
628
+ if not stage_result.extracted_context:
629
+ stage_result.extracted_context = {}
630
+ stage_result.extracted_context["unavailable_pipelines"] = unavailable_pipelines
631
+ # TODO: REMOVE WHEN PIPELINES READY - END
632
+
633
+ # Add stage result to list
634
+ updated_stage_results = previous_stage_results + [stage_result]
635
+
636
+ # For single-stage, also populate legacy fields
637
+ if brief.action == "route_to_sources":
638
+ return {
639
+ "stage_results": updated_stage_results,
640
+ "questionnaire_results": stage_result.questionnaire_results,
641
+ "toplines_results": stage_result.toplines_results,
642
+ "crosstabs_results": stage_result.crosstabs_results,
643
+ "sql_results": stage_result.sql_results
644
+ }
645
+
646
+ return {
647
+ "stage_results": updated_stage_results
648
+ }
649
+
650
    def _enrich_data_sources_with_context(
        self,
        data_sources: List[DataSource],
        previous_results: List[StageResult],
        use_instruction: str
    ) -> List[DataSource]:
        """Enrich data sources with context from previous stages.

        Interprets use_instruction heuristically: only the question-ID case is
        implemented. Returns copies (model_copy) so the original brief's data
        sources are never mutated; if the instruction is not recognized the
        input list is returned unchanged.
        """

        if self.verbose:
            print(f"  Enriching with context: {use_instruction}")

        # For now, handle the most common case: extracting question IDs
        if "question" in use_instruction.lower() and "id" in use_instruction.lower():
            # Extract question IDs from previous questionnaire results
            question_ids = []
            for prev_result in previous_results:
                if prev_result.questionnaire_results:
                    q_results = prev_result.questionnaire_results
                    if "source_questions" in q_results:
                        # NOTE(review): entries without a "question_id" key
                        # contribute None here — confirm upstream always sets it.
                        question_ids.extend([q.get("question_id") for q in q_results["source_questions"]])

            if question_ids and self.verbose:
                print(f"  Found {len(question_ids)} question IDs from previous stages")

            # Add question_ids to filters (on copies, leaving originals intact)
            enriched_sources = []
            for ds in data_sources:
                new_filters = ds.filters.model_copy()
                new_filters.question_ids = question_ids if question_ids else None

                enriched_ds = ds.model_copy()
                enriched_ds.filters = new_filters
                enriched_sources.append(enriched_ds)

            return enriched_sources

        # Unrecognized instruction: pass the sources through untouched.
        return data_sources
687
+
688
    def _extract_stage_context(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Extract key context from completed stage for use in next stages.

        Runs immediately after _execute_stage. Pulls question IDs out of the
        stage's questionnaire results and stores them on the StageResult so
        later stages (via _enrich_data_sources_with_context) can filter by them.
        """

        stage_results = state.get("stage_results", [])
        if not stage_results:
            # Defensive: nothing has executed yet.
            return {}

        # The stage that just finished is always the last entry.
        current_result = stage_results[-1]

        # Extract question IDs if questionnaire results exist
        extracted_context = {}

        if current_result.questionnaire_results:
            q_results = current_result.questionnaire_results
            if "source_questions" in q_results:
                question_ids = [q.get("question_id") for q in q_results["source_questions"]]
                extracted_context["question_ids"] = question_ids

                if self.verbose:
                    print(f"\n=== EXTRACTED CONTEXT FROM STAGE {current_result.stage_number} ===")
                    print(f"Question IDs: {len(question_ids)} extracted")

        # Update the stage result with extracted context. This mutates the
        # StageResult in place; the list in state holds the same object, so
        # no state update is returned for it.
        current_result.extracted_context = extracted_context

        return {}
715
+ def _route_after_stage(self, state: SurveyAnalysisState) -> str:
716
+ """Decide if we need to execute another stage or move to verification"""
717
+
718
+ brief = state["research_brief"]
719
+ current_stage_idx = state.get("current_stage", 0)
720
+
721
+ # Single-stage query
722
+ if brief.action == "route_to_sources":
723
+ if self.verbose:
724
+ print("\n=== SINGLE-STAGE COMPLETE → VERIFICATION ===")
725
+ return "verify"
726
+
727
+ # Multi-stage query
728
+ total_stages = len(brief.stages)
729
+ next_stage_idx = current_stage_idx + 1
730
+
731
+ if next_stage_idx < total_stages:
732
+ if self.verbose:
733
+ print(f"\n=== MORE STAGES REMAINING ({next_stage_idx + 1}/{total_stages}) → NEXT STAGE ===")
734
+ return "next_stage"
735
+ else:
736
+ if self.verbose:
737
+ print(f"\n=== ALL {total_stages} STAGES COMPLETE → VERIFICATION ===")
738
+ return "verify"
739
+
740
    def _verify_results(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Check whether the retrieved data can answer the user's question.

        Applies cheap heuristics first (auto-pass simple "what questions"
        queries that returned results, auto-pass when only not-yet-available
        pipelines were missing, auto-fail when nothing was retrieved) and only
        falls back to an LLM verification call for everything else.

        Returns:
            Dict with a "verification" VerificationResult; when verification
            fails it also carries an incremented "retry_count".
        """

        if self.verbose:
            print("\n=== VERIFYING RESULTS ===")

        # Build full question context from conversation history
        question = self._get_full_question_context(state)

        if self.verbose and question != state["user_question"]:
            print(f"💬 Using full context: {question[:150]}...")

        stage_results = state.get("stage_results", [])
        brief = state["research_brief"]

        # Build summary of what we retrieved
        retrieval_summary = []
        total_questions = 0

        # TODO: REMOVE WHEN PIPELINES READY - START
        unavailable_pipelines_found = []
        # TODO: REMOVE WHEN PIPELINES READY - END

        for stage_result in stage_results:
            if stage_result.questionnaire_results:
                q_res = stage_result.questionnaire_results
                num = q_res.get("num_sources", 0)
                total_questions += num
                retrieval_summary.append(f"Stage {stage_result.stage_number}: Retrieved {num} questions")

            # TODO: REMOVE WHEN PIPELINES READY - START
            # Check if any pipelines were unavailable
            if stage_result.extracted_context and "unavailable_pipelines" in stage_result.extracted_context:
                unavailable = stage_result.extracted_context["unavailable_pipelines"]
                unavailable_pipelines_found.extend(unavailable)
                retrieval_summary.append(f"Stage {stage_result.stage_number}: ⚠️ {', '.join(unavailable)} not yet available")
            # TODO: REMOVE WHEN PIPELINES READY - END

        if not retrieval_summary:
            retrieval_summary.append("No data was retrieved")
        # Simple heuristic: if this is a single-stage simple query and we got results, auto-pass
        if brief.action == "route_to_sources" and len(stage_results) == 1 and total_questions > 0:
            # Check if question is a simple "what questions" type query
            question_lower = question.lower()
            simple_patterns = ["what question", "which question", "list question", "show question", "questions asked"]

            if any(pattern in question_lower for pattern in simple_patterns):
                if self.verbose:
                    print(f"✓ Auto-pass: Simple question retrieval with {total_questions} results")

                return {
                    "verification": VerificationResult(
                        answers_question=True,
                        missing_info=None,
                        improvement_suggestion=None
                    )
                }

        # TODO: REMOVE WHEN PIPELINES READY - START
        # If we have unavailable pipelines but got questionnaire data, auto-pass with note
        if unavailable_pipelines_found and total_questions > 0:
            if self.verbose:
                print(f"✓ Auto-pass: Got questionnaire data, {len(unavailable_pipelines_found)} pipeline(s) not yet available")

            return {
                "verification": VerificationResult(
                    answers_question=True,
                    missing_info=None,
                    improvement_suggestion=None
                )
            }
        # TODO: REMOVE WHEN PIPELINES READY - END

        # If we got 0 results, auto-fail without calling LLM
        if total_questions == 0:
            if self.verbose:
                print("✗ Auto-fail: No results retrieved")

            return {
                "verification": VerificationResult(
                    answers_question=False,
                    missing_info="No data was retrieved",
                    improvement_suggestion="Adjust filters or search criteria"
                ),
                "retry_count": state.get("retry_count", 0) + 1
            }

        # For other cases, use LLM verification
        system_prompt = """You are a verification expert. Your ONLY job is to check if the retrieved data matches what the user asked for.

CRITICAL RULES:
1. **Match the question literally** - Don't add requirements the user didn't ask for
   - If they asked "what questions were asked?" and we retrieved questions → SUCCESS
   - If they asked "what are the results?" and we only have questions → FAILURE

2. **Don't overthink it** - Keep it simple:
   - Did we retrieve the type of data they asked for? (questions, results, etc.)
   - Is it from the right time period/survey they specified?
   - Is there enough data (at least 1 result)?

3. **Only fail if there's an actual problem**:
   - We retrieved the wrong type of data (e.g., questions when they asked for results)
   - We retrieved from the wrong time period/survey

4. **Do NOT fail if**:
   - User asked for questions and we got questions (even if we don't have "analysis")
   - User asked for data from June 2025 and that's what we got
   - The data seems sufficient to answer their actual question

Be practical, not pedantic. If the retrieved data can answer what they asked, approve it.
"""

        # Structured output forces the LLM reply into a VerificationResult.
        verifier = self.llm.with_structured_output(VerificationResult)

        verification = verifier.invoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"""
User question: "{question}"

What we retrieved:
{chr(10).join(retrieval_summary)}

Simple question: Can we answer their question with this data? YES or NO.
""")
        ])

        if self.verbose:
            print(f"Answers question: {verification.answers_question}")
            if not verification.answers_question:
                print(f"Missing: {verification.missing_info}")
                print(f"Suggestion: {verification.improvement_suggestion}")

        # ⭐ INCREMENT RETRY COUNT IF VERIFICATION FAILS
        updates = {"verification": verification}
        if not verification.answers_question:
            current_retry = state.get("retry_count", 0)
            updates["retry_count"] = current_retry + 1

        return updates
+ def _route_after_verification(self, state: SurveyAnalysisState) -> str:
881
+ """Route based on verification result"""
882
+
883
+ verification = state["verification"]
884
+ retry_count = state.get("retry_count", 0)
885
+ max_retries = state.get("max_retries", self.max_retries)
886
+
887
+ if verification.answers_question:
888
+ return "synthesize"
889
+ elif retry_count < max_retries:
890
+ if self.verbose:
891
+ print(f"\n⚠️ Retry {retry_count + 1}/{max_retries}")
892
+ return "retry"
893
+ else:
894
+ if self.verbose:
895
+ print(f"\n⚠️ Max retries reached, proceeding with partial results")
896
+ return "give_up"
897
+
898
    def _synthesize_response(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Synthesize the final user-facing answer from all stage results.

        Four paths: (1) ``followup`` — return the brief's clarifying question;
        (2) ``answer`` — ask the LLM directly with no retrieved data;
        (3) single stage with a single pipeline — pass that pipeline's answer
        through untouched; (4) otherwise — build a combined context from every
        stage and ask the LLM to synthesize.

        Returns:
            Dict with "final_answer" (str) and "messages" (list with one
            AIMessage) to merge into the graph state.
        """

        if self.verbose:
            print("\n=== SYNTHESIZING RESPONSE ===")

        brief = state["research_brief"]

        # Get full question context from conversation history
        full_question = self._get_full_question_context(state)

        if self.verbose and full_question != state["user_question"]:
            print(f"💬 Using full context: {full_question[:150]}...")

        # Handle followup action
        if brief.action == "followup":
            if self.verbose:
                print("Returning followup question")
            return {
                "final_answer": brief.followup_question,
                "messages": [AIMessage(content=brief.followup_question)]
            }

        # Handle direct answer (no data retrieval)
        if brief.action == "answer":
            if self.verbose:
                print("Generating direct answer without data")
            answer = self.llm.invoke([
                SystemMessage(content="Answer the user's question directly."),
                HumanMessage(content=full_question)
            ]).content

            return {
                "final_answer": answer,
                "messages": [AIMessage(content=answer)]
            }

        # Get stage results
        stage_results = state.get("stage_results", [])

        if not stage_results:
            if self.verbose:
                print("No stage results available")
            return {
                "final_answer": "I was unable to retrieve any data to answer your question.",
                "messages": [AIMessage(content="I was unable to retrieve any data to answer your question.")]
            }

        # CASE 1: Single stage with single pipeline → return direct answer
        if len(stage_results) == 1:
            stage_result = stage_results[0]

            # Check if only one pipeline returned data
            # NOTE(review): only the questionnaire pipeline is counted here —
            # presumably toplines/crosstabs/SQL checks get added when those
            # pipelines land; confirm before relying on this count.
            pipelines_with_data = 0
            direct_answer = None

            if stage_result.questionnaire_results:
                pipelines_with_data += 1
                direct_answer = stage_result.questionnaire_results.get("answer")

            if pipelines_with_data == 1 and direct_answer:
                if self.verbose:
                    print("Single stage, single pipeline - returning direct answer (no synthesis)")
                return {
                    "final_answer": direct_answer,
                    "messages": [AIMessage(content=direct_answer)]
                }

        # CASE 2: Multiple stages or multiple pipelines → synthesize
        if self.verbose:
            print(f"Synthesizing from {len(stage_results)} stage(s)")

        # Build context from all stages
        context_parts = []

        # TODO: REMOVE WHEN PIPELINES READY - START
        unavailable_pipelines_overall = []
        # TODO: REMOVE WHEN PIPELINES READY - END

        for i, stage_result in enumerate(stage_results, 1):
            if stage_result.questionnaire_results:
                q_res = stage_result.questionnaire_results

                # Check if this is a labeled result
                label = q_res.get("label", f"Stage {i}")

                context_parts.append(f"\n=== {label.upper()} ===")
                context_parts.append(f"Stage {i} results:")
                context_parts.append(q_res.get("answer", "No answer available"))

            # TODO: REMOVE WHEN PIPELINES READY - START
            # Track unavailable pipelines for note in synthesis
            if stage_result.extracted_context and "unavailable_pipelines" in stage_result.extracted_context:
                unavailable = stage_result.extracted_context["unavailable_pipelines"]
                unavailable_pipelines_overall.extend(unavailable)
                context_parts.append(f"\n⚠️ Note: {', '.join(unavailable)} data was requested but not yet available")
            # TODO: REMOVE WHEN PIPELINES READY - END

        # TODO: REMOVE WHEN PIPELINES READY - START
        unavailable_note = ""
        if unavailable_pipelines_overall:
            unique_unavailable = list(set(unavailable_pipelines_overall))
            unavailable_note = f"""

⚠️ IMPORTANT: The following data sources were requested but are not yet available:
{', '.join(unique_unavailable).upper()}

Please answer based on the questionnaire data that IS available, and note any limitations.
"""
        # TODO: REMOVE WHEN PIPELINES READY - END

        synthesis_prompt = f"""Synthesize results from {'multiple stages' if len(stage_results) > 1 else 'the research'} to answer the user's question.

User question: {full_question}

Research plan: {brief.reasoning}

Retrieved data:
{chr(10).join(context_parts)}

{unavailable_note}

Instructions:
- If this is a comparative query, clearly organize by the comparison dimensions
- If this is an analytical query (most/least/best/worst), perform the analysis
- Preserve important details from the research
- Use natural language, be clear and organized
- Cite which poll(s) or stage(s) information comes from
- Do NOT make up information not in the retrieved data
- TODO: REMOVE WHEN PIPELINES READY - If some data sources weren't available, clearly state this and explain what you CAN provide
"""

        final_answer = self.llm.invoke([
            SystemMessage(content="You are a survey data analyst synthesizing research results."),
            HumanMessage(content=synthesis_prompt)
        ]).content

        if self.verbose:
            print("Synthesis complete")

        return {
            "final_answer": final_answer,
            "messages": [AIMessage(content=final_answer)]
        }
+ # ========================================================================
1044
+ # PUBLIC API
1045
+ # ========================================================================
1046
+
1047
+ def query(self, question: str, thread_id: str = "default") -> str:
1048
+ """
1049
+ Query the survey analysis system.
1050
+
1051
+ Args:
1052
+ question: User's question
1053
+ thread_id: Conversation thread ID for memory
1054
+
1055
+ Returns:
1056
+ Answer string
1057
+
1058
+ Note: When using the same thread_id across multiple calls, the conversation
1059
+ context is preserved. For example:
1060
+ - Call 1: query("what questions were asked?", thread_id="user_123")
1061
+ - Call 2: query("June 2025, unity poll", thread_id="user_123")
1062
+
1063
+ The second call will understand the full context.
1064
+ """
1065
+
1066
+ # Create initial state for this turn
1067
+ # Note: LangGraph's operator.add annotation will append to existing messages
1068
+ # from the checkpointer, not replace them
1069
+ initial_state = {
1070
+ "messages": [HumanMessage(content=question)],
1071
+ "user_question": question,
1072
+ "research_brief": None,
1073
+ "current_stage": 0,
1074
+ "stage_results": [],
1075
+ "questionnaire_results": None,
1076
+ "toplines_results": None,
1077
+ "crosstabs_results": None,
1078
+ "sql_results": None,
1079
+ "verification": None,
1080
+ "final_answer": None,
1081
+ "retry_count": 0,
1082
+ "max_retries": self.max_retries
1083
+ }
1084
+
1085
+ config = {"configurable": {"thread_id": thread_id}}
1086
+
1087
+ if self.verbose:
1088
+ print(f"\n🧵 Thread ID: {thread_id}")
1089
+
1090
+ final_state = self.graph.invoke(initial_state, config)
1091
+
1092
+ return final_state["final_answer"]
1093
+
1094
+ def stream_query(self, question: str, thread_id: str = "default"):
1095
+ """Stream the query execution for real-time updates"""
1096
+
1097
+ initial_state = {
1098
+ "messages": [HumanMessage(content=question)],
1099
+ "user_question": question,
1100
+ "research_brief": None,
1101
+ "current_stage": 0,
1102
+ "stage_results": [],
1103
+ "questionnaire_results": None,
1104
+ "toplines_results": None,
1105
+ "crosstabs_results": None,
1106
+ "sql_results": None,
1107
+ "verification": None,
1108
+ "final_answer": None,
1109
+ "retry_count": 0,
1110
+ "max_retries": self.max_retries
1111
+ }
1112
+
1113
+ config = {"configurable": {"thread_id": thread_id}}
1114
+
1115
+ for event in self.graph.stream(initial_state, config):
1116
+ yield event
1117
+
1118
+
1119
+ # ============================================================================
1120
+ # CLI INTERFACE
1121
+ # ============================================================================
1122
+
1123
def main():
    """Interactive CLI: prompt for questions in a loop and print agent answers."""
    import sys

    # Both API keys are required before the agent can be constructed.
    keys = {name: os.getenv(name) for name in ("OPENAI_API_KEY", "PINECONE_API_KEY")}
    if not all(keys.values()):
        print("Error: Missing API keys")
        print("Set OPENAI_API_KEY and PINECONE_API_KEY environment variables")
        sys.exit(1)

    print("Initializing survey analysis agent...")
    agent = SurveyAnalysisAgent(
        openai_api_key=keys["OPENAI_API_KEY"],
        pinecone_api_key=keys["PINECONE_API_KEY"],
        verbose=True
    )

    bar = "=" * 80
    print("\n" + bar)
    print("SURVEY ANALYSIS AGENT (WITH STAGED RESEARCH)")
    print(bar)
    print("\nType 'quit' to exit\n")

    # One shared thread keeps conversational context for the whole session.
    thread_id = "cli_session"

    while True:
        try:
            user_text = input("\nYour question: ").strip()

            if not user_text or user_text.lower() in ('quit', 'exit', 'q'):
                print("\nGoodbye!")
                return

            print("\n" + "-" * 80)
            response = agent.query(user_text, thread_id=thread_id)
            print("\n" + bar)
            print("ANSWER:")
            print(bar)
            print(response)
            print(bar)

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            return
        except Exception as e:
            print(f"\nError: {e}")
            # Re-raise for a full traceback only when DEBUG is set.
            if os.getenv("DEBUG"):
                raise
# Run the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()