Spaces:
Build error
Build error
Upload 6 files
Browse files- app.py +247 -0
- questionnaire_rag.py +592 -0
- questionnaire_vectorstores/poll_catalog.json +74 -0
- questionnaire_vectorstores/questions_index.json +0 -0
- requirements.txt +9 -0
- survey_agent.py +1175 -0
app.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio interface for Survey Analysis Agent
|
| 3 |
+
Host on Hugging Face Spaces
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from survey_agent import SurveyAnalysisAgent
|
| 9 |
+
import uuid
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
# Initialize agent (will be done once at startup)
# Module-level globals shared by initialize_agent(), chat(), and the UI:
agent = None  # SurveyAnalysisAgent instance; set by initialize_agent() on success
initialization_error = None  # user-facing error string; set when setup fails
|
| 15 |
+
|
| 16 |
+
def initialize_agent():
    """Set up the global SurveyAnalysisAgent from environment configuration.

    Reads OPENAI_API_KEY and PINECONE_API_KEY from the environment, checks
    that the local vector store directory exists, then constructs the agent.
    On any failure a user-facing message is stored in the module-level
    ``initialization_error`` global.

    Returns:
        True when the agent was created, False otherwise.
    """
    global agent, initialization_error

    try:
        openai_key = os.getenv("OPENAI_API_KEY")
        pinecone_key = os.getenv("PINECONE_API_KEY")

        # Guard clauses: record a message for the UI and bail out early.
        if not openai_key:
            initialization_error = "❌ OPENAI_API_KEY not found. Please set it in Space Settings → Repository Secrets."
            return False

        if not pinecone_key:
            initialization_error = "❌ PINECONE_API_KEY not found. Please set it in Space Settings → Repository Secrets."
            return False

        # The prebuilt vector store artifacts must be uploaded with the Space.
        if not os.path.exists("./questionnaire_vectorstores"):
            initialization_error = "❌ Vector store directory not found. Please upload the questionnaire_vectorstores folder."
            return False

        agent = SurveyAnalysisAgent(
            openai_api_key=openai_key,
            pinecone_api_key=pinecone_key,
            verbose=False,  # Set to False for cleaner UI
        )
        return True

    except Exception as exc:
        initialization_error = f"❌ Initialization error: {str(exc)}"
        return False
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def chat(message, history, session_id):
    """Answer a single user message through the agent.

    Args:
        message: User's message.
        history: Chat history (list of message dicts). Not consumed here —
            the agent keeps its own memory keyed by the session id.
        session_id: Unique session identifier for conversation memory.

    Returns:
        The agent's answer text, or a user-facing error string.
    """
    # Surface configuration problems instead of attempting a query.
    if initialization_error:
        return initialization_error

    if not agent:
        return "⚠️ Agent not initialized. Please refresh the page."

    if not message.strip():
        return "Please enter a question."

    try:
        # The session id doubles as the agent's conversation thread id.
        return agent.query(message, thread_id=session_id)
    except Exception as exc:
        print(f"Error details: {exc}")  # Log to console
        return f"❌ Error processing query: {str(exc)}"
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def create_new_session():
    """Return a fresh random session identifier (UUID4 rendered as a string)."""
    return f"{uuid.uuid4()}"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_available_surveys():
    """Render a markdown summary of the surveys and polls the agent knows about.

    Returns:
        Markdown text listing survey names and per-poll details, or an
        error string when the agent is unavailable or the lookup fails.
    """
    if initialization_error or not agent:
        return "Agent not initialized"

    try:
        names = agent.questionnaire_rag.get_available_survey_names()
        catalog = agent.questionnaire_rag.get_available_polls()

        # Build the markdown in pieces and join once at the end.
        pieces = [
            "## Available Surveys\n\n",
            f"**Survey Names:** {', '.join(names)}\n\n",
            "## Available Polls\n\n",
        ]
        pieces.extend(
            f"- **{poll['poll_date']}** ({poll['month']} {poll['year']}): {poll['survey_name']} - {poll['num_questions']} questions\n"
            for poll in catalog
        )
        return "".join(pieces)
    except Exception as exc:
        return f"Error retrieving survey info: {str(exc)}"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Initialize agent at startup so a configuration problem is reported
# immediately in the UI rather than on the first query.
print("🚀 Initializing Survey Analysis Agent...")
init_success = initialize_agent()

if init_success:
    print("✅ Agent initialized successfully!")
else:
    print(f"⚠️ Agent initialization failed: {initialization_error}")


# Create Gradio interface
with gr.Blocks(title="Survey Analysis Agent", theme=gr.themes.Soft()) as demo:

    # Header
    gr.Markdown("""
    # 📊 Survey Analysis Agent

    Ask questions about survey data using natural language. The agent can:
    - Find questions from specific surveys and time periods
    - Compare questions across different time periods
    - Analyze question topics and themes
    - Show sampling logic and question flow

    **Note:** Currently only questionnaire data is available (questions, topics, response options, skip logic).
    """)

    # Show initialization status (rendered only when startup failed)
    if initialization_error:
        gr.Markdown(f"## ⚠️ Setup Required\n\n{initialization_error}")

    # Session state: one UUID per browser session, used as the agent's
    # conversation thread id so follow-up questions keep their context.
    session_id_state = gr.State(value=create_new_session())

    # Main chat interface
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=500,
                show_label=True,
                type="messages"  # expects {"role": ..., "content": ...} dicts
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="e.g., What questions were asked in the June 2025 Unity Poll?",
                    show_label=False,
                    scale=4
                )
                submit = gr.Button("Send", scale=1, variant="primary")

            with gr.Row():
                clear = gr.Button("🔄 New Conversation", scale=1)

            # Example questions (clicking one fills the textbox)
            gr.Examples(
                examples=[
                    "What questions were asked in June 2025?",
                    "Show me all healthcare-related questions",
                    "What questions were asked in the Unity Poll?",
                    "Compare immigration questions from different surveys",
                ],
                inputs=msg,
                label="Example Questions"
            )

        # Sidebar with info
        with gr.Column(scale=1):
            gr.Markdown("## 📋 Available Data")
            survey_info = gr.Markdown(
                value=get_available_surveys() if init_success else "Agent not initialized",
                label="Surveys"
            )

            refresh_info = gr.Button("🔄 Refresh Survey List", size="sm")

            gr.Markdown("""
            ## 💡 Tips

            - Be specific about time periods (e.g., "June 2025")
            - Mention survey names when relevant
            - Follow up with clarifications if needed
            - The agent maintains conversation context

            ## 🔧 Current Capabilities

            ✅ **Available:**
            - Question text and response options
            - Topics and themes
            - Skip logic and sampling
            - Question sequencing

            ⏳ **Coming Soon:**
            - Response frequencies (toplines)
            - Cross-tabulations
            - Statistical analysis
            """)

    # Event handlers
    def respond(message, chat_history, session_id):
        """Handle message and update chat history"""
        if not message.strip():
            # Nothing to send; keep history and clear the textbox.
            return chat_history, ""

        # Add user message
        chat_history.append({"role": "user", "content": message})

        # Get bot response
        bot_message = chat(message, chat_history, session_id)

        # Add bot message
        chat_history.append({"role": "assistant", "content": bot_message})

        # Second return value "" clears the input textbox.
        return chat_history, ""

    def clear_chat():
        """Clear chat and create new session"""
        # A fresh session id resets the agent's conversation memory too.
        new_session = create_new_session()
        return [], new_session

    # Wire up events (Enter key and Send button share the same handler)
    msg.submit(respond, [msg, chatbot, session_id_state], [chatbot, msg])
    submit.click(respond, [msg, chatbot, session_id_state], [chatbot, msg])
    clear.click(clear_chat, None, [chatbot, session_id_state])
    refresh_info.click(get_available_surveys, None, survey_info)

    # Footer
    gr.Markdown("""
    ---
    **Note:** This system uses conversation memory. You can ask follow-up questions like:
    1. "What questions were asked?"
    2. "June 2025, Unity Poll" (it will understand the context)
    """)


# Launch the app (0.0.0.0:7860 is the standard Hugging Face Spaces binding)
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
|
questionnaire_rag.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Questionnaire RAG with better filtering and anti-hallucination measures.
|
| 3 |
+
|
| 4 |
+
Key improvements:
|
| 5 |
+
1. Correct Pinecone filter syntax
|
| 6 |
+
2. Post-retrieval validation of filters
|
| 7 |
+
3. Stronger anti-hallucination prompts
|
| 8 |
+
4. Explicit checks for data existence
|
| 9 |
+
5. Fuzzy survey name matching
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
from typing import List, Dict, Any, Optional
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
| 18 |
+
from langchain_pinecone import PineconeVectorStore
|
| 19 |
+
from pinecone import Pinecone
|
| 20 |
+
from langchain.prompts import ChatPromptTemplate
|
| 21 |
+
from langchain.schema.output_parser import StrOutputParser
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from dotenv import load_dotenv
|
| 25 |
+
load_dotenv()
|
| 26 |
+
except ImportError:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class QuestionnaireRAG:
|
| 31 |
+
"""
|
| 32 |
+
Improved questionnaire RAG with:
|
| 33 |
+
- Better Pinecone filtering
|
| 34 |
+
- Post-retrieval validation
|
| 35 |
+
- Anti-hallucination measures
|
| 36 |
+
- Fuzzy survey name matching
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
    self,
    openai_api_key: str,
    pinecone_api_key: str,
    persist_directory: str = "./questionnaire_vectorstores",
    verbose: bool = False
):
    """Wire up embeddings, LLM, Pinecone vector store, and local indexes.

    Args:
        openai_api_key: Key for the OpenAI embedding/chat models.
        pinecone_api_key: Key for the Pinecone vector database.
        persist_directory: Folder holding poll_catalog.json and
            questions_index.json produced by the ingestion script.
        verbose: When True, print diagnostic output during queries.

    Raises:
        ValueError: If persist_directory does not exist.
    """
    self.openai_api_key = openai_api_key
    self.pinecone_api_key = pinecone_api_key
    self.persist_directory = persist_directory
    self.verbose = verbose

    # Initialize embeddings (model overridable via OPENAI_EMBED_MODEL)
    self.embeddings = OpenAIEmbeddings(
        model=os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
    )

    # Initialize LLM; temperature=0 for deterministic, grounded answers
    chat_model = os.getenv("OPENAI_MODEL", "gpt-4o")
    self.llm = ChatOpenAI(model=chat_model, temperature=0)

    # Load vector store — fail fast if the local artifacts are missing
    if not os.path.exists(persist_directory):
        raise ValueError(
            f"Vector store not found at {persist_directory}\n"
            "Run create_questionnaire_vectorstores.py first"
        )

    # Connect to Pinecone; index and namespace are configurable via env
    # vars, and an empty PINECONE_NAMESPACE string falls back to None
    index_name = os.getenv("PINECONE_INDEX_NAME", "poll-questionnaire-index")
    namespace = os.getenv("PINECONE_NAMESPACE") or None

    pc = Pinecone(api_key=self.pinecone_api_key)
    self.index = pc.Index(index_name)
    self.vectorstore = PineconeVectorStore(
        index=self.index,
        embedding=self.embeddings,
        namespace=namespace
    )

    # Load catalog and questions from the local JSON side files
    self.poll_catalog = self._load_catalog()
    self.questions_by_id = self._load_questions_index()

    if self.verbose:
        print(f"✓ Loaded {len(self.questions_by_id)} questions from {len(self.poll_catalog)} polls")
|
| 85 |
+
|
| 86 |
+
def _load_catalog(self) -> Dict[str, Dict]:
    """Load the poll catalog JSON from the persist directory.

    Returns:
        Mapping of poll identifiers to poll metadata dicts, or an empty
        dict when the catalog file does not exist (so callers can treat
        "no catalog" the same as "empty catalog").
    """
    catalog_path = Path(self.persist_directory) / "poll_catalog.json"
    if catalog_path.exists():
        # Explicit UTF-8: survey text may contain non-ASCII characters,
        # and the platform default encoding is not guaranteed to decode it.
        with open(catalog_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}
|
| 93 |
+
|
| 94 |
+
def _load_questions_index(self) -> Dict[str, Dict]:
    """Load the question_id -> full-question index from the persist directory.

    Returns:
        Mapping of question ids to full question records, or an empty
        dict when the index file does not exist.
    """
    questions_path = Path(self.persist_directory) / "questions_index.json"
    if questions_path.exists():
        # Explicit UTF-8: question text may contain non-ASCII characters,
        # and the platform default encoding is not guaranteed to decode it.
        with open(questions_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return {}
|
| 101 |
+
|
| 102 |
+
def get_available_survey_names(self) -> List[str]:
|
| 103 |
+
"""Get list of unique survey names from the catalog"""
|
| 104 |
+
survey_names = set()
|
| 105 |
+
for info in self.poll_catalog.values():
|
| 106 |
+
survey_names.add(info["survey_name"])
|
| 107 |
+
return sorted(survey_names)
|
| 108 |
+
|
| 109 |
+
def _fuzzy_match_survey_name(self, requested_name: str) -> Optional[str]:
|
| 110 |
+
"""
|
| 111 |
+
Fuzzy match a requested survey name to an actual stored name.
|
| 112 |
+
|
| 113 |
+
Examples:
|
| 114 |
+
- "Unity Poll" → "Vanderbilt_Unity_Poll"
|
| 115 |
+
- "unity poll" → "Vanderbilt_Unity_Poll"
|
| 116 |
+
- "Vanderbilt Unity" → "Vanderbilt_Unity_Poll"
|
| 117 |
+
"""
|
| 118 |
+
# Get all unique survey names
|
| 119 |
+
available_names = self.get_available_survey_names()
|
| 120 |
+
|
| 121 |
+
# Normalize the requested name
|
| 122 |
+
normalized_requested = requested_name.lower().replace("_", " ").replace("-", " ")
|
| 123 |
+
|
| 124 |
+
# Try exact match first (case-insensitive)
|
| 125 |
+
for stored_name in available_names:
|
| 126 |
+
normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
|
| 127 |
+
if normalized_requested == normalized_stored:
|
| 128 |
+
return stored_name
|
| 129 |
+
|
| 130 |
+
# Try substring matching - check if requested is in stored
|
| 131 |
+
for stored_name in available_names:
|
| 132 |
+
normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
|
| 133 |
+
if normalized_requested in normalized_stored:
|
| 134 |
+
return stored_name
|
| 135 |
+
|
| 136 |
+
# Try reverse - check if stored is in requested
|
| 137 |
+
for stored_name in available_names:
|
| 138 |
+
normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
|
| 139 |
+
if normalized_stored in normalized_requested:
|
| 140 |
+
return stored_name
|
| 141 |
+
|
| 142 |
+
# Try word-level matching - if all words from requested are in stored
|
| 143 |
+
requested_words = set(normalized_requested.split())
|
| 144 |
+
for stored_name in available_names:
|
| 145 |
+
normalized_stored = stored_name.lower().replace("_", " ").replace("-", " ")
|
| 146 |
+
stored_words = set(normalized_stored.split())
|
| 147 |
+
|
| 148 |
+
# Check if requested words are a subset of stored words
|
| 149 |
+
if requested_words.issubset(stored_words):
|
| 150 |
+
return stored_name
|
| 151 |
+
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
def _build_pinecone_filter(self, filters: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
| 155 |
+
"""
|
| 156 |
+
Build proper Pinecone metadata filter with fuzzy survey name matching.
|
| 157 |
+
|
| 158 |
+
Pinecone filter syntax:
|
| 159 |
+
- Simple: {"year": 2025}
|
| 160 |
+
- Multiple: {"$and": [{"year": 2025}, {"month": "February"}]}
|
| 161 |
+
"""
|
| 162 |
+
if not filters:
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
filter_conditions = []
|
| 166 |
+
|
| 167 |
+
# Handle year filter
|
| 168 |
+
if "year" in filters:
|
| 169 |
+
year = filters["year"]
|
| 170 |
+
if isinstance(year, str):
|
| 171 |
+
year = int(year)
|
| 172 |
+
filter_conditions.append({"year": {"$eq": year}})
|
| 173 |
+
|
| 174 |
+
# Handle month filter
|
| 175 |
+
if "month" in filters:
|
| 176 |
+
month = filters["month"]
|
| 177 |
+
# Ensure proper capitalization
|
| 178 |
+
if isinstance(month, str):
|
| 179 |
+
month = month.capitalize()
|
| 180 |
+
filter_conditions.append({"month": {"$eq": month}})
|
| 181 |
+
|
| 182 |
+
# Handle poll_date filter (exact match)
|
| 183 |
+
if "poll_date" in filters:
|
| 184 |
+
filter_conditions.append({"poll_date": {"$eq": filters["poll_date"]}})
|
| 185 |
+
|
| 186 |
+
# Handle survey_name filter with fuzzy matching
|
| 187 |
+
if "survey_name" in filters:
|
| 188 |
+
requested_name = filters["survey_name"]
|
| 189 |
+
|
| 190 |
+
# Try to fuzzy match the survey name
|
| 191 |
+
matched_name = self._fuzzy_match_survey_name(requested_name)
|
| 192 |
+
|
| 193 |
+
if matched_name:
|
| 194 |
+
if self.verbose and matched_name != requested_name:
|
| 195 |
+
print(f"🔄 Mapped survey name '{requested_name}' → '{matched_name}'")
|
| 196 |
+
filter_conditions.append({"survey_name": {"$eq": matched_name}})
|
| 197 |
+
else:
|
| 198 |
+
if self.verbose:
|
| 199 |
+
print(f"⚠️ Survey name '{requested_name}' not found in catalog")
|
| 200 |
+
print(f" Available: {self.get_available_survey_names()}")
|
| 201 |
+
# Don't add the filter if we can't match it - let other filters work
|
| 202 |
+
|
| 203 |
+
# Handle topics (if a topic is in the comma-separated list)
|
| 204 |
+
if "topic" in filters:
|
| 205 |
+
# This is trickier with comma-separated strings in metadata
|
| 206 |
+
# For now, we'll do post-filtering
|
| 207 |
+
pass
|
| 208 |
+
|
| 209 |
+
# Combine filters
|
| 210 |
+
if len(filter_conditions) == 0:
|
| 211 |
+
return None
|
| 212 |
+
elif len(filter_conditions) == 1:
|
| 213 |
+
return filter_conditions[0]
|
| 214 |
+
else:
|
| 215 |
+
return {"$and": filter_conditions}
|
| 216 |
+
|
| 217 |
+
def _validate_results(
|
| 218 |
+
self,
|
| 219 |
+
docs: List[Any],
|
| 220 |
+
filters: Dict[str, Any]
|
| 221 |
+
) -> List[Any]:
|
| 222 |
+
"""
|
| 223 |
+
Validate that retrieved documents actually match the filters.
|
| 224 |
+
|
| 225 |
+
This catches cases where:
|
| 226 |
+
1. Pinecone filtering didn't work correctly
|
| 227 |
+
2. We need to do additional filtering (like topic matching)
|
| 228 |
+
"""
|
| 229 |
+
if not filters:
|
| 230 |
+
return docs
|
| 231 |
+
|
| 232 |
+
validated_docs = []
|
| 233 |
+
|
| 234 |
+
for doc in docs:
|
| 235 |
+
metadata = doc.metadata
|
| 236 |
+
valid = True
|
| 237 |
+
|
| 238 |
+
# Check year
|
| 239 |
+
if "year" in filters:
|
| 240 |
+
expected_year = int(filters["year"]) if isinstance(filters["year"], str) else filters["year"]
|
| 241 |
+
if metadata.get("year") != expected_year:
|
| 242 |
+
if self.verbose:
|
| 243 |
+
print(f"⚠️ Filtered out: wrong year {metadata.get('year')} != {expected_year}")
|
| 244 |
+
valid = False
|
| 245 |
+
|
| 246 |
+
# Check month
|
| 247 |
+
if "month" in filters and valid:
|
| 248 |
+
expected_month = filters["month"].capitalize() if isinstance(filters["month"], str) else filters["month"]
|
| 249 |
+
if metadata.get("month") != expected_month:
|
| 250 |
+
if self.verbose:
|
| 251 |
+
print(f"⚠️ Filtered out: wrong month {metadata.get('month')} != {expected_month}")
|
| 252 |
+
valid = False
|
| 253 |
+
|
| 254 |
+
# Check poll_date
|
| 255 |
+
if "poll_date" in filters and valid:
|
| 256 |
+
if metadata.get("poll_date") != filters["poll_date"]:
|
| 257 |
+
if self.verbose:
|
| 258 |
+
print(f"⚠️ Filtered out: wrong poll_date {metadata.get('poll_date')} != {filters['poll_date']}")
|
| 259 |
+
valid = False
|
| 260 |
+
|
| 261 |
+
# Check survey_name (with fuzzy matching)
|
| 262 |
+
if "survey_name" in filters and valid:
|
| 263 |
+
requested_name = filters["survey_name"]
|
| 264 |
+
matched_name = self._fuzzy_match_survey_name(requested_name)
|
| 265 |
+
if matched_name and metadata.get("survey_name") != matched_name:
|
| 266 |
+
if self.verbose:
|
| 267 |
+
print(f"⚠️ Filtered out: wrong survey {metadata.get('survey_name')} != {matched_name}")
|
| 268 |
+
valid = False
|
| 269 |
+
|
| 270 |
+
if valid:
|
| 271 |
+
validated_docs.append(doc)
|
| 272 |
+
|
| 273 |
+
return validated_docs
|
| 274 |
+
|
| 275 |
+
def _get_prompt(self) -> ChatPromptTemplate:
    """Get the improved system prompt with anti-hallucination measures.

    NOTE: {catalog}, {context}, and {question} are ChatPromptTemplate
    placeholders, filled in by the chain built in _query_internal.
    """
    return ChatPromptTemplate.from_messages([
        ("system", """You are an expert assistant for analyzing poll questionnaires.

🚨 CRITICAL RULES - NEVER VIOLATE THESE:

1. **ONLY use information from the provided context**
   - Do NOT make up questions, polls, or dates
   - Do NOT assume a poll exists if it's not in the context
   - If information is missing, say "I don't have data for [X]" rather than making it up

2. **Verify data exists before listing it**
   - Before mentioning any poll, check it's actually in the context
   - Before listing questions, confirm they exist in the retrieved data
   - If asked about multiple time periods, explicitly state which ones have data and which don't

3. **Be explicit about what's NOT in the data**
   - If asked about "2024 and 2025" but only 2025 data exists, say: "I have data for 2025, but there is no 2024 data in the retrieved results"
   - Never silently skip missing data - always acknowledge it

4. **When listing questions:**
   - List ALL questions from the context in order
   - Include full question text and response options
   - Note sampling inline in clear language:
     * "Asked to all respondents" (not "ASK ALL")
     * "Asked to half the sample" (not "HALFSAMP1=1")
     * "Asked only if [condition]" (not technical codes)
   - If sibling variants exist, note "One of two versions shown to different groups"
   - Always cite which poll(s) you're using

5. **Format for scannability:**
   - Use numbered lists for questions
   - Bold question text
   - Include response options as bullet points
   - Put sampling info in parentheses after question

Available polls in the system (for reference):
{catalog}

Context (ONLY source of truth):
{context}

Question: {question}
"""),
        ("human", "Answer:")
    ])
|
| 322 |
+
|
| 323 |
+
def query(self, question: str, filters: Optional[Dict[str, Any]] = None, k: int = 20) -> str:
    """Query the questionnaire system and return just the answer text.

    Args:
        question: Natural language question.
        filters: Optional filters (year, month, poll_date, survey_name).
        k: Number of results to retrieve.

    Returns:
        Answer string.
    """
    return self._query_internal(question, filters, k)["answer"]
|
| 337 |
+
|
| 338 |
+
def query_with_metadata(
    self,
    question: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 20
) -> Dict[str, Any]:
    """Run a query and return the full retrieval metadata alongside the answer.

    Returns:
        Dict with 'answer', 'source_questions', 'num_sources', 'filters_applied'.
    """
    return self._query_internal(question, filters=filters, k=k)
|
| 351 |
+
|
| 352 |
+
def _query_internal(
    self,
    question: str,
    filters: Optional[Dict[str, Any]] = None,
    k: int = 20
) -> Dict[str, Any]:
    """Internal query implementation.

    Pipeline: build Pinecone filter → retrieve chunks → re-validate
    metadata → reconstruct/deduplicate full questions → format context →
    run the LLM chain.

    Returns:
        Dict with 'answer', 'source_questions', 'num_sources',
        'filters_applied'.
    """

    if self.verbose:
        print(f"\n📊 Query: {question}")
        if filters:
            print(f"🔍 Filters: {filters}")

    # Build Pinecone filter
    pinecone_filter = self._build_pinecone_filter(filters or {})

    # Retrieve documents (filtered retriever only when a filter was built)
    if pinecone_filter:
        if self.verbose:
            print(f"🔧 Pinecone filter: {pinecone_filter}")
        retriever = self.vectorstore.as_retriever(
            search_kwargs={"k": k, "filter": pinecone_filter}
        )
    else:
        retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})

    docs = retriever.invoke(question)

    if self.verbose:
        print(f"📥 Retrieved {len(docs)} documents from Pinecone")

    # Validate results match filters — belt and braces on top of the
    # Pinecone filter, and the only place topic-style filters can apply
    if filters:
        docs = self._validate_results(docs, filters)
        if self.verbose:
            print(f"✅ After validation: {len(docs)} documents")

    # Check if we have any results; return early rather than letting the
    # LLM invent an answer from an empty context
    if not docs:
        no_data_msg = f"No questionnaire data found"
        if filters:
            filter_desc = ", ".join([f"{k}={v}" for k, v in filters.items()])
            no_data_msg += f" matching filters: {filter_desc}"

        return {
            "answer": no_data_msg,
            "source_questions": [],
            "num_sources": 0,
            "filters_applied": filters or {}
        }

    # Reconstruct full questions: retrieved chunks carry a question_id,
    # which is looked up in the local index; seen_ids deduplicates chunks
    # that belong to the same question
    full_questions = []
    seen_ids = set()

    for doc in docs:
        q_id = doc.metadata.get('question_id')
        if q_id and q_id not in seen_ids:
            if q_id in self.questions_by_id:
                full_questions.append(self.questions_by_id[q_id])
            seen_ids.add(q_id)

    # Sort by position to maintain survey order (grouped by poll date)
    full_questions.sort(key=lambda q: (q.get('poll_date', ''), q.get('position', 0)))

    # Format context with explicit data availability info
    context = self._format_context(full_questions, filters)

    # Get prompt
    prompt = self._get_prompt()

    # Create chain; the lambdas ignore their input and return the values
    # precomputed above (context/question/catalog are closed over)
    chain = (
        {
            "context": lambda x: context,
            "question": lambda x: question,
            "catalog": lambda x: self._get_catalog_summary()
        }
        | prompt
        | self.llm
        | StrOutputParser()
    )

    # Get answer
    answer = chain.invoke(question)

    return {
        'answer': answer,
        'source_questions': full_questions,
        'num_sources': len(full_questions),
        'filters_applied': filters or {}
    }
|
| 444 |
+
|
| 445 |
+
def _format_context(
|
| 446 |
+
self,
|
| 447 |
+
questions: List[Dict],
|
| 448 |
+
filters: Optional[Dict[str, Any]] = None
|
| 449 |
+
) -> str:
|
| 450 |
+
"""Format questions as context with explicit data availability"""
|
| 451 |
+
|
| 452 |
+
if not questions:
|
| 453 |
+
filter_desc = ""
|
| 454 |
+
if filters:
|
| 455 |
+
filter_desc = f" matching {filters}"
|
| 456 |
+
return f"⚠️ NO DATA RETRIEVED{filter_desc}\n\nYou must inform the user that no data exists for their query."
|
| 457 |
+
|
| 458 |
+
context_parts = []
|
| 459 |
+
|
| 460 |
+
# Add explicit note about what data we have
|
| 461 |
+
polls_found = sorted(set(q['poll_date'] for q in questions))
|
| 462 |
+
context_parts.append(f"✅ DATA AVAILABLE FOR: {', '.join(polls_found)}")
|
| 463 |
+
|
| 464 |
+
# Add note about what was requested vs what was found
|
| 465 |
+
if filters:
|
| 466 |
+
if 'year' in filters and 'month' in filters:
|
| 467 |
+
requested = f"{filters['month']} {filters['year']}"
|
| 468 |
+
context_parts.append(f"🔍 REQUESTED: {requested}")
|
| 469 |
+
|
| 470 |
+
context_parts.append("") # Blank line
|
| 471 |
+
context_parts.append("=" * 80)
|
| 472 |
+
context_parts.append("")
|
| 473 |
+
|
| 474 |
+
# Format each question
|
| 475 |
+
for i, q in enumerate(questions, 1):
|
| 476 |
+
part = f"""
|
| 477 |
+
--- Question {i} from {q['survey_name']} ({q['poll_date']}) ---
|
| 478 |
+
Variable: {q['variable_name']}
|
| 479 |
+
Question: {q['question_text']}
|
| 480 |
+
Response Options: {' | '.join(q['response_options'])}
|
| 481 |
+
Topics: {', '.join(q['topics'])}
|
| 482 |
+
Question Type: {q['question_type']}
|
| 483 |
+
Administration: {q['ask_condition']}
|
| 484 |
+
"""
|
| 485 |
+
|
| 486 |
+
# Add skip logic/sampling
|
| 487 |
+
if q.get('skip_logic'):
|
| 488 |
+
part += f"Skip Logic: {q['skip_logic']}\n"
|
| 489 |
+
|
| 490 |
+
if q.get('half_sample_group'):
|
| 491 |
+
part += f"Half Sample Group: {q['half_sample_group']}\n"
|
| 492 |
+
|
| 493 |
+
# Add sibling variants
|
| 494 |
+
if q.get('sibling_variants'):
|
| 495 |
+
part += f"\nAlternate Versions (shown to different groups):\n"
|
| 496 |
+
for sib in q['sibling_variants']:
|
| 497 |
+
sib_group = sib.get('half_sample_group', 'other group')
|
| 498 |
+
part += f" - [{sib_group}] {sib['question_text']}\n"
|
| 499 |
+
|
| 500 |
+
# Add sequence context
|
| 501 |
+
if q.get('previous_question'):
|
| 502 |
+
prev_vars = q.get('previous_question_variants', [])
|
| 503 |
+
if len(prev_vars) > 1:
|
| 504 |
+
part += "\nPrevious Question (respondents saw one of these):\n"
|
| 505 |
+
for pv in prev_vars:
|
| 506 |
+
part += f" - {pv['question_text']}\n"
|
| 507 |
+
else:
|
| 508 |
+
part += f"\nPrevious Question: {q['previous_question']['question_text']}\n"
|
| 509 |
+
|
| 510 |
+
if q.get('next_question'):
|
| 511 |
+
next_vars = q.get('next_question_variants', [])
|
| 512 |
+
if len(next_vars) > 1:
|
| 513 |
+
part += "\nNext Question (respondents saw one of these):\n"
|
| 514 |
+
for nv in next_vars:
|
| 515 |
+
part += f" - {nv['question_text']}\n"
|
| 516 |
+
else:
|
| 517 |
+
part += f"\nNext Question: {q['next_question']['question_text']}\n"
|
| 518 |
+
|
| 519 |
+
context_parts.append(part.strip())
|
| 520 |
+
|
| 521 |
+
return "\n\n".join(context_parts)
|
| 522 |
+
|
| 523 |
+
def _get_catalog_summary(self) -> str:
|
| 524 |
+
"""Get summary of available polls"""
|
| 525 |
+
lines = ["Available polls:"]
|
| 526 |
+
for poll_date in sorted(self.poll_catalog.keys()):
|
| 527 |
+
info = self.poll_catalog[poll_date]
|
| 528 |
+
month_str = f" ({info['month']})" if info.get('month') else ""
|
| 529 |
+
lines.append(f"- {poll_date}{month_str}: {info['num_questions']} questions")
|
| 530 |
+
return "\n".join(lines)
|
| 531 |
+
|
| 532 |
+
def get_available_polls(self) -> List[Dict[str, Any]]:
    """Return every poll in the catalog as a list of summary dicts.

    Polls are ordered chronologically by their date key. ``month`` falls
    back to an empty string when the catalog entry has no month recorded.
    """
    polls = []
    for poll_date, info in sorted(self.poll_catalog.items()):
        polls.append({
            "poll_date": poll_date,
            "survey_name": info["survey_name"],
            "year": info["year"],
            "month": info.get("month", ""),
            "num_questions": info["num_questions"],
        })
    return polls
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def main():
    """Test CLI"""
    import sys

    # Both keys must be present before any network client is constructed.
    creds = (os.getenv("OPENAI_API_KEY"), os.getenv("PINECONE_API_KEY"))
    if not all(creds):
        print("Error: Missing API keys")
        sys.exit(1)
    openai_key, pinecone_key = creds

    pipeline = QuestionnaireRAG(
        openai_api_key=openai_key,
        pinecone_api_key=pinecone_key,
        verbose=True
    )

    banner = "=" * 80
    print("\n" + banner)
    print("QUESTIONNAIRE RAG - TEST MODE")
    print(banner)

    # Exercise fuzzy survey-name matching against common user spellings.
    print("\n🧪 TEST: Fuzzy survey name matching")
    for candidate in ("Unity Poll", "unity poll", "Vanderbilt Unity", "UNITY"):
        resolved = pipeline._fuzzy_match_survey_name(candidate)
        print(f" '{candidate}' → '{resolved}'")

    # Regression check for a query that previously returned no data.
    print("\n🧪 TEST: Query that previously failed")
    print("Query: What questions were asked in the June 2025 Unity Poll?")

    outcome = pipeline.query_with_metadata(
        "What questions were asked in the June 2025 Unity Poll?",
        filters={"year": 2025, "month": "June", "survey_name": "Unity Poll"}
    )

    print(f"\n📊 Results:")
    print(f"Found: {outcome['num_sources']} questions")
    print(f"\n{outcome['answer'][:500]}...")

    print("\n" + banner)


if __name__ == "__main__":
    main()
|
questionnaire_vectorstores/poll_catalog.json
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"2023-06": {
|
| 3 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 4 |
+
"year": 2023,
|
| 5 |
+
"month": "June",
|
| 6 |
+
"poll_date": "2023-06",
|
| 7 |
+
"num_questions": 15,
|
| 8 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_June_questions.json"
|
| 9 |
+
},
|
| 10 |
+
"2023-03": {
|
| 11 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 12 |
+
"year": 2023,
|
| 13 |
+
"month": "March",
|
| 14 |
+
"poll_date": "2023-03",
|
| 15 |
+
"num_questions": 8,
|
| 16 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_March_questions.json"
|
| 17 |
+
},
|
| 18 |
+
"2023-09": {
|
| 19 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 20 |
+
"year": 2023,
|
| 21 |
+
"month": "September",
|
| 22 |
+
"poll_date": "2023-09",
|
| 23 |
+
"num_questions": 15,
|
| 24 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2023_September_questions.json"
|
| 25 |
+
},
|
| 26 |
+
"2024-06": {
|
| 27 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 28 |
+
"year": 2024,
|
| 29 |
+
"month": "June",
|
| 30 |
+
"poll_date": "2024-06",
|
| 31 |
+
"num_questions": 5,
|
| 32 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_June_questions.json"
|
| 33 |
+
},
|
| 34 |
+
"2024-03": {
|
| 35 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 36 |
+
"year": 2024,
|
| 37 |
+
"month": "March",
|
| 38 |
+
"poll_date": "2024-03",
|
| 39 |
+
"num_questions": 13,
|
| 40 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_March_questions.json"
|
| 41 |
+
},
|
| 42 |
+
"2024-10": {
|
| 43 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 44 |
+
"year": 2024,
|
| 45 |
+
"month": "October",
|
| 46 |
+
"poll_date": "2024-10",
|
| 47 |
+
"num_questions": 14,
|
| 48 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_October_questions.json"
|
| 49 |
+
},
|
| 50 |
+
"2024-09": {
|
| 51 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 52 |
+
"year": 2024,
|
| 53 |
+
"month": "September",
|
| 54 |
+
"poll_date": "2024-09",
|
| 55 |
+
"num_questions": 15,
|
| 56 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2024_September_questions.json"
|
| 57 |
+
},
|
| 58 |
+
"2025-02": {
|
| 59 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 60 |
+
"year": 2025,
|
| 61 |
+
"month": "February",
|
| 62 |
+
"poll_date": "2025-02",
|
| 63 |
+
"num_questions": 17,
|
| 64 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2025_February_questions.json"
|
| 65 |
+
},
|
| 66 |
+
"2025-06": {
|
| 67 |
+
"survey_name": "Vanderbilt_Unity_Poll",
|
| 68 |
+
"year": 2025,
|
| 69 |
+
"month": "June",
|
| 70 |
+
"poll_date": "2025-06",
|
| 71 |
+
"num_questions": 23,
|
| 72 |
+
"file": "questionnaire_data/Vanderbilt_Unity_Poll_2025_June_questions.json"
|
| 73 |
+
}
|
| 74 |
+
}
|
questionnaire_vectorstores/questions_index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
langchain>=0.1.0
|
| 3 |
+
langchain-openai>=0.0.5
|
| 4 |
+
langchain-pinecone>=0.0.3
|
| 5 |
+
langgraph>=0.0.20
|
| 6 |
+
openai>=1.0.0
|
| 7 |
+
pinecone
|
| 8 |
+
python-dotenv>=1.0.0
|
| 9 |
+
pydantic>=2.0.0
|
survey_agent.py
ADDED
|
@@ -0,0 +1,1175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-agent survey analysis system using LangGraph with Staged Research Briefs.
|
| 3 |
+
|
| 4 |
+
This orchestrates multiple data sources (questionnaires, toplines, crosstabs, SQL)
|
| 5 |
+
to answer complex survey research questions using sequential, adaptive research stages.
|
| 6 |
+
|
| 7 |
+
# TODO: REMOVE WHEN PIPELINES READY
|
| 8 |
+
When new pipelines (toplines, crosstabs, SQL) become available:
|
| 9 |
+
1. Add pipeline name to SurveyAnalysisAgent.AVAILABLE_PIPELINES (line ~105)
|
| 10 |
+
2. Add execution logic in _execute_stage() method (around line ~450)
|
| 11 |
+
3. Search for "TODO: REMOVE WHEN PIPELINES READY" and remove those sections
|
| 12 |
+
4. Update examples to include the new pipeline capabilities
|
| 13 |
+
|
| 14 |
+
Current Status:
|
| 15 |
+
- ✅ Questionnaire pipeline: ACTIVE
|
| 16 |
+
- ⏳ Toplines pipeline: Not yet implemented
|
| 17 |
+
- ⏳ Crosstabs pipeline: Not yet implemented
|
| 18 |
+
- ⏳ SQL pipeline: Not yet implemented
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import json
|
| 23 |
+
from typing import TypedDict, Literal, Annotated, List, Dict, Any, Optional, Union
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
import operator
|
| 26 |
+
|
| 27 |
+
from langgraph.graph import StateGraph, START, END
|
| 28 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 29 |
+
from langchain_openai import ChatOpenAI
|
| 30 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 31 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 32 |
+
|
| 33 |
+
# Import the questionnaire RAG
|
| 34 |
+
from questionnaire_rag import QuestionnaireRAG
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from dotenv import load_dotenv
|
| 38 |
+
load_dotenv()
|
| 39 |
+
except ImportError:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============================================================================
|
| 44 |
+
# STATE DEFINITIONS (PYDANTIC V2) - WITH STAGED RESEARCH
|
| 45 |
+
# ============================================================================
|
| 46 |
+
|
| 47 |
+
class QueryFilters(BaseModel):
    """Filters for data source queries - Pydantic v2 with strict schema.

    All fields are optional; an unset field means "no filter on this axis".
    ``extra="forbid"`` makes structured LLM output fail fast if the model
    invents a filter name that is not declared here.
    """
    model_config = ConfigDict(extra="forbid")

    year: Optional[int] = Field(default=None, description="Year filter (e.g., 2025)")
    month: Optional[str] = Field(default=None, description="Month filter (e.g., 'February')")
    poll_date: Optional[str] = Field(default=None, description="Specific poll date (e.g., '2025-02-15')")
    survey_name: Optional[str] = Field(default=None, description="Survey name filter (e.g., 'Unity Poll')")
    topic: Optional[str] = Field(default=None, description="Topic filter")
    question_ids: Optional[List[str]] = Field(default=None, description="Specific question IDs from previous stage")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class DataSource(BaseModel):
    """Represents a single data source to query within a research plan."""
    model_config = ConfigDict(extra="forbid")

    # Which pipeline to hit; only "questionnaire" is currently implemented
    # (see SurveyAnalysisAgent.AVAILABLE_PIPELINES).
    source_type: Literal["questionnaire", "toplines", "crosstabs", "sql"]
    query_description: str = Field(description="What to retrieve from this source")
    filters: QueryFilters = Field(default_factory=QueryFilters, description="Filters to apply")
    result_label: Optional[str] = Field(default=None, description="Label for these results (e.g., '2024_questions')")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ResearchStage(BaseModel):
    """A single stage in a multi-stage research plan.

    Stages run sequentially; a stage may consume context extracted from the
    stages listed in ``depends_on_stages`` (e.g., question IDs found earlier).
    """
    model_config = ConfigDict(extra="forbid")

    stage_number: int = Field(description="Stage number (1-indexed)")
    description: str = Field(description="What this stage accomplishes")
    data_sources: List[DataSource] = Field(description="Data sources to query in this stage")
    depends_on_stages: List[int] = Field(default_factory=list, description="Which prior stages this depends on")
    use_previous_results_for: Optional[str] = Field(
        default=None,
        description="How to use previous stage results (e.g., 'Extract question IDs from stage 1')"
    )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ResearchBrief(BaseModel):
    """Research brief - can be either single-stage or multi-stage.

    ``action`` selects the downstream flow:
      - "answer": respond directly without querying data sources
      - "followup": ask the user ``followup_question`` before proceeding
      - "route_to_sources": simple one-shot query using ``data_sources``
      - "execute_stages": sequential staged research using ``stages``
    """
    model_config = ConfigDict(extra="forbid")

    action: Literal["answer", "followup", "route_to_sources", "execute_stages"]
    followup_question: Optional[str] = Field(default=None, description="Follow-up question to ask user")
    reasoning: str = Field(description="Why this approach was chosen")

    # For simple queries (single-stage)
    data_sources: List[DataSource] = Field(default_factory=list, description="Data sources for simple queries")

    # For complex queries (multi-stage)
    stages: List[ResearchStage] = Field(default_factory=list, description="Ordered stages of research")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class StageResult(BaseModel):
    """Results from executing one stage of a research plan."""
    model_config = ConfigDict(extra="forbid")

    stage_number: int
    status: Literal["success", "partial", "failed"]
    # Per-pipeline raw results; only the pipelines actually queried during
    # this stage are populated, the rest stay None.
    questionnaire_results: Optional[Dict[str, Any]] = None
    toplines_results: Optional[Dict[str, Any]] = None
    crosstabs_results: Optional[Dict[str, Any]] = None
    sql_results: Optional[Dict[str, Any]] = None
    extracted_context: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Key information extracted for next stages (e.g., question IDs)"
    )
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class VerificationResult(BaseModel):
    """Result of verifying whether the collected data answers the question.

    When ``answers_question`` is False, ``missing_info`` and
    ``improvement_suggestion`` guide the retry loop that regenerates the
    research brief.
    """
    model_config = ConfigDict(extra="forbid")

    answers_question: bool = Field(description="Whether the data fully answers the question")
    missing_info: Optional[str] = Field(default=None, description="What information is missing")
    improvement_suggestion: Optional[str] = Field(default=None, description="How to improve the research brief")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SurveyAnalysisState(TypedDict):
    """State for the survey analysis agent - WITH STAGED RESEARCH."""
    # User interaction. ``operator.add`` tells LangGraph to append new
    # messages to the list on each state update instead of replacing it.
    messages: Annotated[List, operator.add]
    user_question: str

    # Planning: the brief produced by the research-brief node.
    research_brief: Optional[ResearchBrief]

    # Stage execution
    current_stage: int  # Which stage we're executing (0-indexed internally, but 1-indexed in models)
    stage_results: List[StageResult]  # Results from each completed stage

    # Legacy single-stage results (for backward compatibility)
    questionnaire_results: Optional[Dict[str, Any]]
    toplines_results: Optional[Dict[str, Any]]
    crosstabs_results: Optional[Dict[str, Any]]
    sql_results: Optional[Dict[str, Any]]

    # Verification & synthesis
    verification: Optional[VerificationResult]
    final_answer: Optional[str]

    # Control flow: retry bookkeeping for the verify → re-brief loop.
    retry_count: int
    max_retries: int
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ============================================================================
|
| 152 |
+
# SURVEY ANALYSIS ORCHESTRATOR - WITH STAGED RESEARCH
|
| 153 |
+
# ============================================================================
|
| 154 |
+
|
| 155 |
+
class SurveyAnalysisAgent:
|
| 156 |
+
"""
|
| 157 |
+
Multi-agent system for analyzing survey data with staged research briefs.
|
| 158 |
+
|
| 159 |
+
Flow:
|
| 160 |
+
1. User asks question
|
| 161 |
+
2. Research brief agent decides: simple (one-shot) or complex (staged)
|
| 162 |
+
3. For simple: run pipelines in parallel → verify → synthesize
|
| 163 |
+
4. For complex: execute stages sequentially, each using previous results
|
| 164 |
+
5. Final synthesis combines all stage results
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
# TODO: REMOVE WHEN PIPELINES READY - START
|
| 168 |
+
# Track which pipelines are currently available
|
| 169 |
+
AVAILABLE_PIPELINES = {"questionnaire"} # Add "toplines", "crosstabs", "sql" as they become ready
|
| 170 |
+
# TODO: REMOVE WHEN PIPELINES READY - END
|
| 171 |
+
|
| 172 |
+
def __init__(
    self,
    openai_api_key: str,
    pinecone_api_key: str,
    questionnaire_persist_dir: str = "./questionnaire_vectorstores",
    max_retries: int = 2,
    verbose: bool = True
):
    """Initialize the orchestrator.

    Args:
        openai_api_key: API key for the ChatOpenAI LLM (and downstream RAG).
        pinecone_api_key: API key for the Pinecone-backed questionnaire RAG.
        questionnaire_persist_dir: Directory holding questionnaire
            vector-store artifacts (poll catalog + question index).
        max_retries: How many times verification may send the flow back to
            regenerate the research brief.
        verbose: When True, prints initialization and progress messages.
    """
    self.openai_api_key = openai_api_key
    self.pinecone_api_key = pinecone_api_key
    self.verbose = verbose
    self.max_retries = max_retries

    # Initialize LLM. Model is overridable via OPENAI_MODEL; temperature 0
    # for deterministic planning/synthesis output.
    self.llm = ChatOpenAI(
        model=os.getenv("OPENAI_MODEL", "gpt-4o"),
        temperature=0
    )

    # Initialize questionnaire RAG (currently the only active pipeline).
    if self.verbose:
        print("Initializing questionnaire RAG system...")
    self.questionnaire_rag = QuestionnaireRAG(
        openai_api_key=openai_api_key,
        pinecone_api_key=pinecone_api_key,
        persist_directory=questionnaire_persist_dir,
        verbose=verbose
    )

    # Build the LangGraph workflow.
    self.graph = self._build_graph()

    if self.verbose:
        print("✓ Survey analysis agent initialized with staged research capability")
|
| 206 |
+
|
| 207 |
+
def _build_graph(self) -> StateGraph:
    """Build the LangGraph workflow with staged research support.

    Topology:
        START → generate_research_brief
            → followup: END (wait for user input)
            → answer: synthesize_response
            → execute_stage
        execute_stage → extract_stage_context
            → next_stage: execute_stage (loop while stages remain)
            → verify: verify_results
        verify_results
            → synthesize / give_up: synthesize_response
            → retry: generate_research_brief
        synthesize_response → END

    Returns:
        The compiled graph, checkpointed with an in-memory saver so a
        conversation can span multiple invocations via a thread id.
    """
    workflow = StateGraph(SurveyAnalysisState)

    # Add nodes (handlers are methods on this agent)
    workflow.add_node("generate_research_brief", self._generate_research_brief)
    workflow.add_node("execute_stage", self._execute_stage)
    workflow.add_node("extract_stage_context", self._extract_stage_context)
    workflow.add_node("verify_results", self._verify_results)
    workflow.add_node("synthesize_response", self._synthesize_response)

    # Define edges
    workflow.add_edge(START, "generate_research_brief")

    # After research brief, route based on action
    workflow.add_conditional_edges(
        "generate_research_brief",
        self._route_after_brief,
        {
            "followup": END,
            "answer": "synthesize_response",
            "execute_stage": "execute_stage"
        }
    )

    # After stage execution, extract context for next stage
    workflow.add_edge("execute_stage", "extract_stage_context")

    # After context extraction, decide next step
    workflow.add_conditional_edges(
        "extract_stage_context",
        self._route_after_stage,
        {
            "next_stage": "execute_stage",  # More stages to go
            "verify": "verify_results"  # All stages done, verify
        }
    )

    # After verification, decide next step
    workflow.add_conditional_edges(
        "verify_results",
        self._route_after_verification,
        {
            "synthesize": "synthesize_response",
            "retry": "generate_research_brief",
            "give_up": "synthesize_response"
        }
    )

    # End after synthesis
    workflow.add_edge("synthesize_response", END)

    # Compile with memory so state persists across turns
    memory = MemorySaver()
    return workflow.compile(checkpointer=memory)
|
| 263 |
+
|
| 264 |
+
def _get_available_surveys_description(self) -> str:
|
| 265 |
+
"""Get formatted description of available surveys for LLM prompt"""
|
| 266 |
+
survey_names = self.questionnaire_rag.get_available_survey_names()
|
| 267 |
+
|
| 268 |
+
if not survey_names:
|
| 269 |
+
return "No surveys currently loaded."
|
| 270 |
+
|
| 271 |
+
lines = ["Available survey names in the system:"]
|
| 272 |
+
for name in survey_names:
|
| 273 |
+
# Show both the stored name and common variations
|
| 274 |
+
lines.append(f" - Stored as: '{name}'")
|
| 275 |
+
# Parse variations
|
| 276 |
+
variations = []
|
| 277 |
+
# Remove underscores for common term
|
| 278 |
+
clean = name.replace("_", " ")
|
| 279 |
+
if clean != name:
|
| 280 |
+
variations.append(f"'{clean}'")
|
| 281 |
+
# Extract key words
|
| 282 |
+
words = clean.split()
|
| 283 |
+
if len(words) > 1:
|
| 284 |
+
# Last few words might be the short name
|
| 285 |
+
short_name = " ".join(words[-2:]) if len(words) >= 2 else words[-1]
|
| 286 |
+
if short_name != clean:
|
| 287 |
+
variations.append(f"'{short_name}'")
|
| 288 |
+
|
| 289 |
+
if variations:
|
| 290 |
+
lines.append(f" (users might say: {', '.join(variations)})")
|
| 291 |
+
|
| 292 |
+
lines.append("\nIMPORTANT: Use the exact stored name in your filters!")
|
| 293 |
+
return "\n".join(lines)
|
| 294 |
+
|
| 295 |
+
# TODO: REMOVE WHEN PIPELINES READY - START
|
| 296 |
+
def _get_pipeline_status_description(self) -> str:
|
| 297 |
+
"""Get description of available vs unavailable pipelines"""
|
| 298 |
+
all_pipelines = {
|
| 299 |
+
"questionnaire": "Survey questions, response options, topics, skip logic, sampling",
|
| 300 |
+
"toplines": "Pre-computed response frequencies for each question",
|
| 301 |
+
"crosstabs": "Pre-computed cross-tabulations by demographics",
|
| 302 |
+
"sql": "Raw survey responses for custom analysis"
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
lines = []
|
| 306 |
+
for pipeline, description in all_pipelines.items():
|
| 307 |
+
status = "✅ AVAILABLE" if pipeline in self.AVAILABLE_PIPELINES else "❌ NOT YET AVAILABLE"
|
| 308 |
+
lines.append(f"{pipeline.capitalize()}: {description} {status}")
|
| 309 |
+
|
| 310 |
+
return "\n".join(lines)
|
| 311 |
+
# TODO: REMOVE WHEN PIPELINES READY - END
|
| 312 |
+
|
| 313 |
+
def _get_full_question_context(self, state: SurveyAnalysisState) -> str:
    """
    Reassemble the user's full intent from the conversation history.

    A request may arrive split across turns:
    - Turn 1: "what questions were asked?"
    - Turn 2: "June 2025, unity poll"

    The human messages are merged into a single context string so that
    planning/verification sees the whole question, not just the last turn.
    """
    user_turns = [
        msg.content
        for msg in state.get("messages", [])
        if isinstance(msg, HumanMessage)
    ]

    # No recorded human turns: fall back to the question on the state.
    if not user_turns:
        return state["user_question"]

    if self.verbose:
        print(f"📝 Conversation history: {len(user_turns)} user message(s)")
        for i, msg in enumerate(user_turns, 1):
            print(f"   {i}. {msg[:100]}..." if len(msg) > 100 else f"   {i}. {msg}")

    # A single turn already carries the entire intent.
    if len(user_turns) == 1:
        return user_turns[0]

    # Heuristic: an interrogative opener followed by exactly one
    # clarification turn reads as "question + specifics".
    opener = user_turns[0].lower()
    opener_is_question = any(
        word in opener for word in ["what", "which", "how", "show", "list", "tell"]
    )

    if opener_is_question and len(user_turns) == 2:
        # e.g. "what questions were asked? (specifically: June 2025, unity poll)"
        merged = f"{user_turns[0]} (specifically: {user_turns[1]})"
    else:
        # Otherwise, keep every turn joined in order.
        merged = " | ".join(user_turns)

    if self.verbose:
        print(f"🔗 Combined context: {merged}")
    return merged
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# ========================================================================
|
| 366 |
+
# NODE FUNCTIONS
|
| 367 |
+
# ========================================================================
|
| 368 |
+
|
| 369 |
+
def _generate_research_brief(self, state: SurveyAnalysisState) -> Dict[str, Any]:
    """Generate research brief - decides single-stage vs multi-stage approach.

    Builds a planning prompt (including live pipeline/survey availability and,
    on retries, the previous verification feedback) and asks the LLM for a
    structured ResearchBrief choosing one of: followup, answer,
    route_to_sources, or execute_stages.

    Returns a state update containing the brief, a reset stage counter, empty
    stage results, and a marker AIMessage recording the chosen action.
    """

    if self.verbose:
        print("\n=== GENERATING RESEARCH BRIEF ===")

    # Get full question context from conversation history (multi-turn
    # questions are merged so the planner sees the whole intent).
    question = self._get_full_question_context(state)
    original_question = state["user_question"]  # Keep original for reference

    if self.verbose and question != original_question:
        print(f"💬 Using full context from conversation history")

    retry_count = state.get("retry_count", 0)

    # Add context from verification if this is a retry, so the planner can
    # correct whatever the previous attempt was missing.
    verification_context = ""
    if state.get("verification") and retry_count > 0:
        verification_context = f"""
Previous attempt was insufficient:
- Missing: {state['verification'].missing_info}
- Suggestion: {state['verification'].improvement_suggestion}

Please improve the research plan based on this feedback.
"""

    # NOTE: the "# TODO: REMOVE WHEN PIPELINES READY" lines below are part of
    # the prompt text sent to the LLM, not Python comments.
    system_prompt = f"""You are a research planning expert for survey data analysis.

# TODO: REMOVE WHEN PIPELINES READY - Use dynamic status
Available data sources:
{self._get_pipeline_status_description()}

# TODO: REMOVE WHEN PIPELINES READY - START
⚠️ IMPORTANT: Currently ONLY the questionnaire pipeline is available.
- Do NOT create research plans that require toplines, crosstabs, or SQL
- If the user asks for results/data/analysis that requires those sources, use action="followup" to inform them
- Focus on what CAN be answered with questionnaires alone (question text, response options, topics, skip logic)
# TODO: REMOVE WHEN PIPELINES READY - END

{self._get_available_surveys_description()}

You have FOUR possible actions:

**1. followup** - Ask clarifying question if ambiguous OR if user asks for unavailable data

**2. answer** - Answer directly without data (system questions, general knowledge)

**3. route_to_sources** - Simple query that can be answered with parallel data retrieval
Use this for:
- "What questions were asked in June 2025?"
- "Show me all healthcare questions"
- Questions that don't require sequential reasoning

**4. execute_stages** - Complex query requiring STAGED research
Use this for:
- Queries with "most/least/best/worst" (need stage 1: retrieve, stage 2: analyze)
- Comparative queries "compare 2024 vs 2025" (need separate stages to maintain context)
- Queries depending on intermediate results
- "What demographics differ most?" (stage 1: get questions, stage 2: get crosstabs for those questions)

# TODO: REMOVE WHEN PIPELINES READY - START
NOTE: Since toplines/crosstabs/SQL aren't available, only use execute_stages for comparing questionnaires
# TODO: REMOVE WHEN PIPELINES READY - END

When using stages:
- Each stage can use results from previous stages via `use_previous_results_for`
- Later stages can filter by question_ids extracted from earlier stages
- Each stage can have a `result_label` to maintain separate contexts

CRITICAL FILTERING RULES:
- **Survey Names**: User queries like "Unity Poll" or "Vanderbilt Unity Poll" should map to the exact stored name shown above
- When you see "Unity Poll" in a query, use the exact stored name in your filter
- Only specify filters if explicitly mentioned or clearly implied
- For staged queries, be explicit about how each stage uses previous results
- Use `question_ids` filter when later stages need specific questions from earlier stages
- Year and month are usually sufficient - survey_name is optional unless needed for disambiguation

{verification_context}

Examples:

# TODO: REMOVE WHEN PIPELINES READY - START
User asks for results/analysis → Inform them:
Q: "What were the topline results for June 2025?"
Brief:
action: followup
followup_question: "I can show you the questions asked in June 2025, but topline results aren't available yet. Would you like to see the questions?"
# TODO: REMOVE WHEN PIPELINES READY - END

User says "Unity Poll" → Use stored name in filter:
Q: "What questions were asked in June 2025 Unity Poll?"
Brief:
action: route_to_sources
data_sources: [questionnaire with year=2025, month=June, survey_name='Vanderbilt_Unity_Poll']

Simple Query → route_to_sources:
Q: "What questions were asked in June 2025?"
Brief:
action: route_to_sources
data_sources: [questionnaire with June 2025 filters]

Complex Query → execute_stages:
Q: "Compare immigration questions from 2024 vs 2025"
Brief:
action: execute_stages
stages:
- stage 1: Get 2024 immigration questions (label: "2024_questions")
- stage 2: Get 2025 immigration questions (label: "2025_questions")
- stage 3: Compare the two sets in synthesis
"""

    # Structured output forces the LLM reply into the ResearchBrief schema.
    brief_generator = self.llm.with_structured_output(ResearchBrief)

    brief = brief_generator.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"User question: {question}\n\nGenerate a research brief.")
    ])

    # Verbose trace of the chosen plan (action, data sources, stages).
    if self.verbose:
        print(f"Action: {brief.action}")
        print(f"Reasoning: {brief.reasoning}")

        if brief.followup_question:
            print(f"Follow-up: {brief.followup_question}")

        if brief.action == "route_to_sources" and brief.data_sources:
            print(f"Simple query - {len(brief.data_sources)} data sources")
            for ds in brief.data_sources:
                # Drop unset filter fields so only active filters are shown.
                filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}
                print(f"  - {ds.source_type}: {ds.query_description}")
                if filters_dict:
                    print(f"    Filters: {filters_dict}")

        if brief.action == "execute_stages" and brief.stages:
            print(f"Staged query - {len(brief.stages)} stages")
            for stage in brief.stages:
                print(f"\nStage {stage.stage_number}: {stage.description}")
                if stage.depends_on_stages:
                    print(f"  Depends on: stages {stage.depends_on_stages}")
                if stage.use_previous_results_for:
                    print(f"  Uses previous: {stage.use_previous_results_for}")
                for ds in stage.data_sources:
                    print(f"  - {ds.source_type}: {ds.query_description}")
                    if ds.result_label:
                        print(f"    Label: {ds.result_label}")

    return {
        "research_brief": brief,
        "current_stage": 0,  # Start at stage 0 (will execute stage 1 first)
        "stage_results": [],
        "messages": [AIMessage(content=f"[Research plan: {brief.action}]")]
    }
|
| 521 |
+
|
| 522 |
+
def _route_after_brief(self, state: SurveyAnalysisState) -> str:
|
| 523 |
+
"""Route based on research brief action"""
|
| 524 |
+
brief = state["research_brief"]
|
| 525 |
+
|
| 526 |
+
if brief.action == "followup":
|
| 527 |
+
return "followup"
|
| 528 |
+
elif brief.action == "answer":
|
| 529 |
+
return "answer"
|
| 530 |
+
elif brief.action == "execute_stages":
|
| 531 |
+
return "execute_stage"
|
| 532 |
+
else: # route_to_sources
|
| 533 |
+
return "execute_stage" # We'll handle both single and staged in execute_stage
|
| 534 |
+
|
| 535 |
+
def _execute_stage(self, state: SurveyAnalysisState) -> Dict[str, Any]:
    """Execute one stage of research (handles both single-stage and multi-stage).

    For ``route_to_sources`` briefs the brief's data sources are run directly;
    for ``execute_stages`` the stage indexed by ``current_stage`` is run, with
    filters optionally enriched from earlier stage results. Currently only the
    questionnaire pipeline is queried; other source types are recorded as
    unavailable and the stage is marked "partial".

    Returns a state update appending a new StageResult (and, for single-stage
    plans, also the legacy per-pipeline result fields).
    """

    brief = state["research_brief"]
    current_stage_idx = state.get("current_stage", 0)
    previous_stage_results = state.get("stage_results", [])

    # Determine if this is single-stage or multi-stage
    if brief.action == "route_to_sources":
        # Single-stage: use data_sources directly
        if self.verbose:
            print(f"\n=== EXECUTING SINGLE-STAGE RESEARCH ===")

        stage_data_sources = brief.data_sources
        # NOTE(review): stage_desc is assigned in both branches but not read
        # afterwards in this method — presumably kept for logging; confirm.
        stage_desc = "Single-stage retrieval"

    elif brief.action == "execute_stages":
        # Multi-stage: get current stage
        stage = brief.stages[current_stage_idx]

        if self.verbose:
            print(f"\n=== EXECUTING STAGE {stage.stage_number}/{len(brief.stages)} ===")
            print(f"Description: {stage.description}")

        stage_data_sources = stage.data_sources
        stage_desc = stage.description

        # If this stage depends on previous stages, enrich filters with context
        # (e.g. inject question IDs extracted from an earlier stage).
        if stage.use_previous_results_for and previous_stage_results:
            stage_data_sources = self._enrich_data_sources_with_context(
                stage_data_sources,
                previous_stage_results,
                stage.use_previous_results_for
            )
    else:
        # Actions without data retrieval ("followup"/"answer") produce no
        # stage result; return an empty state update.
        return {}

    # Execute pipelines for this stage; status may be downgraded to
    # "partial" below if some requested pipelines are unavailable.
    stage_result = StageResult(
        stage_number=current_stage_idx + 1,
        status="success"
    )

    # TODO: REMOVE WHEN PIPELINES READY - Track what was attempted vs available
    attempted_pipelines = []
    unavailable_pipelines = []

    # Run each pipeline
    for ds in stage_data_sources:
        # Keep only filter fields the planner actually set.
        filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}

        # TODO: REMOVE WHEN PIPELINES READY - START
        attempted_pipelines.append(ds.source_type)
        # TODO: REMOVE WHEN PIPELINES READY - END

        if ds.source_type == "questionnaire":
            if self.verbose:
                print(f"\nQuerying questionnaire: {ds.query_description}")
                if filters_dict:
                    print(f"Filters: {filters_dict}")

            result = self.questionnaire_rag.query_with_metadata(
                question=ds.query_description,
                filters=filters_dict if filters_dict else None
            )

            # Store with label if provided (labels keep comparison contexts
            # separate during synthesis).
            if ds.result_label:
                result["label"] = ds.result_label

            # First questionnaire result is stored directly; a second one
            # wraps both in a {"multiple": True, "results": [...]} envelope.
            stage_result.questionnaire_results = result if stage_result.questionnaire_results is None else {
                "multiple": True,
                "results": [stage_result.questionnaire_results, result]
            }

            if self.verbose:
                print(f"Retrieved {result['num_sources']} questions")

        # TODO: REMOVE WHEN PIPELINES READY - START
        elif ds.source_type not in self.AVAILABLE_PIPELINES:
            unavailable_pipelines.append(ds.source_type)
            if self.verbose:
                print(f"\n⚠️ {ds.source_type.upper()} pipeline not yet available - skipping")
                print(f"   Requested: {ds.query_description}")
        # TODO: REMOVE WHEN PIPELINES READY - END

    # TODO: REMOVE WHEN PIPELINES READY - START
    # Add a note about unavailable pipelines to the stage result so the
    # verifier/synthesizer can report the limitation to the user.
    if unavailable_pipelines:
        if self.verbose:
            print(f"\n⚠️ Stage {current_stage_idx + 1} incomplete: {len(unavailable_pipelines)} pipeline(s) unavailable")
        stage_result.status = "partial"
        # Store info about what was unavailable for the synthesizer
        if not stage_result.extracted_context:
            stage_result.extracted_context = {}
        stage_result.extracted_context["unavailable_pipelines"] = unavailable_pipelines
    # TODO: REMOVE WHEN PIPELINES READY - END

    # Add stage result to list (append, never mutate previous results).
    updated_stage_results = previous_stage_results + [stage_result]

    # For single-stage, also populate legacy fields
    if brief.action == "route_to_sources":
        return {
            "stage_results": updated_stage_results,
            "questionnaire_results": stage_result.questionnaire_results,
            "toplines_results": stage_result.toplines_results,
            "crosstabs_results": stage_result.crosstabs_results,
            "sql_results": stage_result.sql_results
        }

    return {
        "stage_results": updated_stage_results
    }
|
| 649 |
+
|
| 650 |
+
def _enrich_data_sources_with_context(
|
| 651 |
+
self,
|
| 652 |
+
data_sources: List[DataSource],
|
| 653 |
+
previous_results: List[StageResult],
|
| 654 |
+
use_instruction: str
|
| 655 |
+
) -> List[DataSource]:
|
| 656 |
+
"""Enrich data sources with context from previous stages"""
|
| 657 |
+
|
| 658 |
+
if self.verbose:
|
| 659 |
+
print(f" Enriching with context: {use_instruction}")
|
| 660 |
+
|
| 661 |
+
# For now, handle the most common case: extracting question IDs
|
| 662 |
+
if "question" in use_instruction.lower() and "id" in use_instruction.lower():
|
| 663 |
+
# Extract question IDs from previous questionnaire results
|
| 664 |
+
question_ids = []
|
| 665 |
+
for prev_result in previous_results:
|
| 666 |
+
if prev_result.questionnaire_results:
|
| 667 |
+
q_results = prev_result.questionnaire_results
|
| 668 |
+
if "source_questions" in q_results:
|
| 669 |
+
question_ids.extend([q.get("question_id") for q in q_results["source_questions"]])
|
| 670 |
+
|
| 671 |
+
if question_ids and self.verbose:
|
| 672 |
+
print(f" Found {len(question_ids)} question IDs from previous stages")
|
| 673 |
+
|
| 674 |
+
# Add question_ids to filters
|
| 675 |
+
enriched_sources = []
|
| 676 |
+
for ds in data_sources:
|
| 677 |
+
new_filters = ds.filters.model_copy()
|
| 678 |
+
new_filters.question_ids = question_ids if question_ids else None
|
| 679 |
+
|
| 680 |
+
enriched_ds = ds.model_copy()
|
| 681 |
+
enriched_ds.filters = new_filters
|
| 682 |
+
enriched_sources.append(enriched_ds)
|
| 683 |
+
|
| 684 |
+
return enriched_sources
|
| 685 |
+
|
| 686 |
+
return data_sources
|
| 687 |
+
|
| 688 |
+
def _extract_stage_context(self, state: SurveyAnalysisState) -> Dict[str, Any]:
|
| 689 |
+
"""Extract key context from completed stage for use in next stages"""
|
| 690 |
+
|
| 691 |
+
stage_results = state.get("stage_results", [])
|
| 692 |
+
if not stage_results:
|
| 693 |
+
return {}
|
| 694 |
+
|
| 695 |
+
current_result = stage_results[-1]
|
| 696 |
+
|
| 697 |
+
# Extract question IDs if questionnaire results exist
|
| 698 |
+
extracted_context = {}
|
| 699 |
+
|
| 700 |
+
if current_result.questionnaire_results:
|
| 701 |
+
q_results = current_result.questionnaire_results
|
| 702 |
+
if "source_questions" in q_results:
|
| 703 |
+
question_ids = [q.get("question_id") for q in q_results["source_questions"]]
|
| 704 |
+
extracted_context["question_ids"] = question_ids
|
| 705 |
+
|
| 706 |
+
if self.verbose:
|
| 707 |
+
print(f"\n=== EXTRACTED CONTEXT FROM STAGE {current_result.stage_number} ===")
|
| 708 |
+
print(f"Question IDs: {len(question_ids)} extracted")
|
| 709 |
+
|
| 710 |
+
# Update the stage result with extracted context
|
| 711 |
+
current_result.extracted_context = extracted_context
|
| 712 |
+
|
| 713 |
+
return {}
|
| 714 |
+
|
| 715 |
+
def _route_after_stage(self, state: SurveyAnalysisState) -> str:
|
| 716 |
+
"""Decide if we need to execute another stage or move to verification"""
|
| 717 |
+
|
| 718 |
+
brief = state["research_brief"]
|
| 719 |
+
current_stage_idx = state.get("current_stage", 0)
|
| 720 |
+
|
| 721 |
+
# Single-stage query
|
| 722 |
+
if brief.action == "route_to_sources":
|
| 723 |
+
if self.verbose:
|
| 724 |
+
print("\n=== SINGLE-STAGE COMPLETE → VERIFICATION ===")
|
| 725 |
+
return "verify"
|
| 726 |
+
|
| 727 |
+
# Multi-stage query
|
| 728 |
+
total_stages = len(brief.stages)
|
| 729 |
+
next_stage_idx = current_stage_idx + 1
|
| 730 |
+
|
| 731 |
+
if next_stage_idx < total_stages:
|
| 732 |
+
if self.verbose:
|
| 733 |
+
print(f"\n=== MORE STAGES REMAINING ({next_stage_idx + 1}/{total_stages}) → NEXT STAGE ===")
|
| 734 |
+
return "next_stage"
|
| 735 |
+
else:
|
| 736 |
+
if self.verbose:
|
| 737 |
+
print(f"\n=== ALL {total_stages} STAGES COMPLETE → VERIFICATION ===")
|
| 738 |
+
return "verify"
|
| 739 |
+
|
| 740 |
+
def _verify_results(self, state: SurveyAnalysisState) -> Dict[str, Any]:
    """Verify that retrieved data answers the question.

    Applies cheap heuristics first (auto-pass for simple "what questions"
    retrievals and for partial results where only unavailable pipelines are
    missing; auto-fail when nothing was retrieved) and only falls back to an
    LLM check for the remaining cases. On failure, increments retry_count so
    the graph can re-plan.
    """

    if self.verbose:
        print("\n=== VERIFYING RESULTS ===")

    # Build full question context from conversation history
    question = self._get_full_question_context(state)

    if self.verbose and question != state["user_question"]:
        print(f"💬 Using full context: {question[:150]}...")

    stage_results = state.get("stage_results", [])
    brief = state["research_brief"]

    # Build summary of what we retrieved (one line per stage/issue).
    retrieval_summary = []
    total_questions = 0

    # TODO: REMOVE WHEN PIPELINES READY - START
    unavailable_pipelines_found = []
    # TODO: REMOVE WHEN PIPELINES READY - END

    for stage_result in stage_results:
        if stage_result.questionnaire_results:
            q_res = stage_result.questionnaire_results
            num = q_res.get("num_sources", 0)
            total_questions += num
            retrieval_summary.append(f"Stage {stage_result.stage_number}: Retrieved {num} questions")

        # TODO: REMOVE WHEN PIPELINES READY - START
        # Check if any pipelines were unavailable
        if stage_result.extracted_context and "unavailable_pipelines" in stage_result.extracted_context:
            unavailable = stage_result.extracted_context["unavailable_pipelines"]
            unavailable_pipelines_found.extend(unavailable)
            retrieval_summary.append(f"Stage {stage_result.stage_number}: ⚠️ {', '.join(unavailable)} not yet available")
        # TODO: REMOVE WHEN PIPELINES READY - END

    if not retrieval_summary:
        retrieval_summary.append("No data was retrieved")

    # Simple heuristic: if this is a single-stage simple query and we got results, auto-pass
    if brief.action == "route_to_sources" and len(stage_results) == 1 and total_questions > 0:
        # Check if question is a simple "what questions" type query
        question_lower = question.lower()
        simple_patterns = ["what question", "which question", "list question", "show question", "questions asked"]

        if any(pattern in question_lower for pattern in simple_patterns):
            if self.verbose:
                print(f"✓ Auto-pass: Simple question retrieval with {total_questions} results")

            return {
                "verification": VerificationResult(
                    answers_question=True,
                    missing_info=None,
                    improvement_suggestion=None
                )
            }

    # TODO: REMOVE WHEN PIPELINES READY - START
    # If we have unavailable pipelines but got questionnaire data, auto-pass
    # with a note — the synthesizer will explain the limitation.
    if unavailable_pipelines_found and total_questions > 0:
        if self.verbose:
            print(f"✓ Auto-pass: Got questionnaire data, {len(unavailable_pipelines_found)} pipeline(s) not yet available")

        return {
            "verification": VerificationResult(
                answers_question=True,
                missing_info=None,
                improvement_suggestion=None
            )
        }
    # TODO: REMOVE WHEN PIPELINES READY - END

    # If we got 0 results, auto-fail without calling LLM (saves a call and
    # triggers a retry with adjusted filters).
    if total_questions == 0:
        if self.verbose:
            print("✗ Auto-fail: No results retrieved")

        return {
            "verification": VerificationResult(
                answers_question=False,
                missing_info="No data was retrieved",
                improvement_suggestion="Adjust filters or search criteria"
            ),
            "retry_count": state.get("retry_count", 0) + 1
        }

    # For other cases, use LLM verification
    system_prompt = """You are a verification expert. Your ONLY job is to check if the retrieved data matches what the user asked for.

CRITICAL RULES:
1. **Match the question literally** - Don't add requirements the user didn't ask for
   - If they asked "what questions were asked?" and we retrieved questions → SUCCESS
   - If they asked "what are the results?" and we only have questions → FAILURE

2. **Don't overthink it** - Keep it simple:
   - Did we retrieve the type of data they asked for? (questions, results, etc.)
   - Is it from the right time period/survey they specified?
   - Is there enough data (at least 1 result)?

3. **Only fail if there's an actual problem**:
   - We retrieved the wrong type of data (e.g., questions when they asked for results)
   - We retrieved from the wrong time period/survey

4. **Do NOT fail if**:
   - User asked for questions and we got questions (even if we don't have "analysis")
   - User asked for data from June 2025 and that's what we got
   - The data seems sufficient to answer their actual question

Be practical, not pedantic. If the retrieved data can answer what they asked, approve it.
"""

    # Structured output forces a yes/no VerificationResult.
    verifier = self.llm.with_structured_output(VerificationResult)

    verification = verifier.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"""
User question: "{question}"

What we retrieved:
{chr(10).join(retrieval_summary)}

Simple question: Can we answer their question with this data? YES or NO.
""")
    ])

    if self.verbose:
        print(f"Answers question: {verification.answers_question}")
        if not verification.answers_question:
            print(f"Missing: {verification.missing_info}")
            print(f"Suggestion: {verification.improvement_suggestion}")

    # ⭐ INCREMENT RETRY COUNT IF VERIFICATION FAILS
    updates = {"verification": verification}
    if not verification.answers_question:
        current_retry = state.get("retry_count", 0)
        updates["retry_count"] = current_retry + 1

    return updates
|
| 879 |
+
|
| 880 |
+
def _route_after_verification(self, state: SurveyAnalysisState) -> str:
|
| 881 |
+
"""Route based on verification result"""
|
| 882 |
+
|
| 883 |
+
verification = state["verification"]
|
| 884 |
+
retry_count = state.get("retry_count", 0)
|
| 885 |
+
max_retries = state.get("max_retries", self.max_retries)
|
| 886 |
+
|
| 887 |
+
if verification.answers_question:
|
| 888 |
+
return "synthesize"
|
| 889 |
+
elif retry_count < max_retries:
|
| 890 |
+
if self.verbose:
|
| 891 |
+
print(f"\n⚠️ Retry {retry_count + 1}/{max_retries}")
|
| 892 |
+
return "retry"
|
| 893 |
+
else:
|
| 894 |
+
if self.verbose:
|
| 895 |
+
print(f"\n⚠️ Max retries reached, proceeding with partial results")
|
| 896 |
+
return "give_up"
|
| 897 |
+
|
| 898 |
+
def _synthesize_response(self, state: SurveyAnalysisState) -> Dict[str, Any]:
    """Synthesize final response from all results.

    Handles four paths: (1) return the planner's follow-up question verbatim,
    (2) answer directly with no data, (3) short-circuit a single-stage /
    single-pipeline result without an extra LLM call, or (4) ask the LLM to
    synthesize across multiple stages/pipelines, noting any data sources that
    were requested but unavailable.
    """

    if self.verbose:
        print("\n=== SYNTHESIZING RESPONSE ===")

    brief = state["research_brief"]

    # Get full question context from conversation history
    full_question = self._get_full_question_context(state)

    if self.verbose and full_question != state["user_question"]:
        print(f"💬 Using full context: {full_question[:150]}...")

    # Handle followup action: echo the clarifying question back to the user.
    if brief.action == "followup":
        if self.verbose:
            print("Returning followup question")
        return {
            "final_answer": brief.followup_question,
            "messages": [AIMessage(content=brief.followup_question)]
        }

    # Handle direct answer (no data retrieval)
    if brief.action == "answer":
        if self.verbose:
            print("Generating direct answer without data")
        answer = self.llm.invoke([
            SystemMessage(content="Answer the user's question directly."),
            HumanMessage(content=full_question)
        ]).content

        return {
            "final_answer": answer,
            "messages": [AIMessage(content=answer)]
        }

    # Get stage results
    stage_results = state.get("stage_results", [])

    if not stage_results:
        if self.verbose:
            print("No stage results available")
        return {
            "final_answer": "I was unable to retrieve any data to answer your question.",
            "messages": [AIMessage(content="I was unable to retrieve any data to answer your question.")]
        }

    # CASE 1: Single stage with single pipeline → return direct answer
    # (skips the synthesis LLM call entirely).
    if len(stage_results) == 1:
        stage_result = stage_results[0]

        # Check if only one pipeline returned data
        pipelines_with_data = 0
        direct_answer = None

        if stage_result.questionnaire_results:
            pipelines_with_data += 1
            direct_answer = stage_result.questionnaire_results.get("answer")

        if pipelines_with_data == 1 and direct_answer:
            if self.verbose:
                print("Single stage, single pipeline - returning direct answer (no synthesis)")
            return {
                "final_answer": direct_answer,
                "messages": [AIMessage(content=direct_answer)]
            }

    # CASE 2: Multiple stages or multiple pipelines → synthesize
    if self.verbose:
        print(f"Synthesizing from {len(stage_results)} stage(s)")

    # Build context from all stages, one labeled section per stage.
    context_parts = []

    # TODO: REMOVE WHEN PIPELINES READY - START
    unavailable_pipelines_overall = []
    # TODO: REMOVE WHEN PIPELINES READY - END

    for i, stage_result in enumerate(stage_results, 1):
        if stage_result.questionnaire_results:
            q_res = stage_result.questionnaire_results

            # Check if this is a labeled result (labels keep comparison
            # contexts, e.g. "2024_questions", distinct).
            label = q_res.get("label", f"Stage {i}")

            context_parts.append(f"\n=== {label.upper()} ===")
            context_parts.append(f"Stage {i} results:")
            context_parts.append(q_res.get("answer", "No answer available"))

        # TODO: REMOVE WHEN PIPELINES READY - START
        # Track unavailable pipelines for note in synthesis
        if stage_result.extracted_context and "unavailable_pipelines" in stage_result.extracted_context:
            unavailable = stage_result.extracted_context["unavailable_pipelines"]
            unavailable_pipelines_overall.extend(unavailable)
            context_parts.append(f"\n⚠️ Note: {', '.join(unavailable)} data was requested but not yet available")
        # TODO: REMOVE WHEN PIPELINES READY - END

    # TODO: REMOVE WHEN PIPELINES READY - START
    unavailable_note = ""
    if unavailable_pipelines_overall:
        # Deduplicate across stages before surfacing in the prompt.
        unique_unavailable = list(set(unavailable_pipelines_overall))
        unavailable_note = f"""

⚠️ IMPORTANT: The following data sources were requested but are not yet available:
{', '.join(unique_unavailable).upper()}

Please answer based on the questionnaire data that IS available, and note any limitations.
"""
    # TODO: REMOVE WHEN PIPELINES READY - END

    synthesis_prompt = f"""Synthesize results from {'multiple stages' if len(stage_results) > 1 else 'the research'} to answer the user's question.

User question: {full_question}

Research plan: {brief.reasoning}

Retrieved data:
{chr(10).join(context_parts)}

{unavailable_note}

Instructions:
- If this is a comparative query, clearly organize by the comparison dimensions
- If this is an analytical query (most/least/best/worst), perform the analysis
- Preserve important details from the research
- Use natural language, be clear and organized
- Cite which poll(s) or stage(s) information comes from
- Do NOT make up information not in the retrieved data
- TODO: REMOVE WHEN PIPELINES READY - If some data sources weren't available, clearly state this and explain what you CAN provide
"""

    final_answer = self.llm.invoke([
        SystemMessage(content="You are a survey data analyst synthesizing research results."),
        HumanMessage(content=synthesis_prompt)
    ]).content

    if self.verbose:
        print("Synthesis complete")

    return {
        "final_answer": final_answer,
        "messages": [AIMessage(content=final_answer)]
    }
|
| 1042 |
+
|
| 1043 |
+
# ========================================================================
|
| 1044 |
+
# PUBLIC API
|
| 1045 |
+
# ========================================================================
|
| 1046 |
+
|
| 1047 |
+
def _make_initial_state(self, question: str) -> dict:
    """Build the per-turn initial graph state for *question*.

    Shared by query() and stream_query() so the two entry points cannot
    drift apart (previously each built an identical 13-key dict by hand).
    LangGraph's operator.add annotation on ``messages`` means the
    checkpointer APPENDS this turn's message to any prior conversation
    history rather than replacing it.
    """
    return {
        "messages": [HumanMessage(content=question)],
        "user_question": question,
        "research_brief": None,
        "current_stage": 0,
        "stage_results": [],
        "questionnaire_results": None,
        "toplines_results": None,
        "crosstabs_results": None,
        "sql_results": None,
        "verification": None,
        "final_answer": None,
        "retry_count": 0,
        "max_retries": self.max_retries,
    }

def query(self, question: str, thread_id: str = "default") -> str:
    """
    Query the survey analysis system.

    Args:
        question: User's question
        thread_id: Conversation thread ID for memory

    Returns:
        Answer string

    Note: When using the same thread_id across multiple calls, the conversation
    context is preserved. For example:
    - Call 1: query("what questions were asked?", thread_id="user_123")
    - Call 2: query("June 2025, unity poll", thread_id="user_123")

    The second call will understand the full context.
    """
    initial_state = self._make_initial_state(question)
    config = {"configurable": {"thread_id": thread_id}}

    if self.verbose:
        print(f"\n🧵 Thread ID: {thread_id}")

    # Run the full LangGraph pipeline to completion for this turn.
    final_state = self.graph.invoke(initial_state, config)

    return final_state["final_answer"]

def stream_query(self, question: str, thread_id: str = "default"):
    """Stream the query execution for real-time updates.

    Yields raw LangGraph events as each node of the pipeline completes,
    instead of blocking until the final answer like query().
    """
    initial_state = self._make_initial_state(question)
    config = {"configurable": {"thread_id": thread_id}}

    for event in self.graph.stream(initial_state, config):
        yield event
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
# ============================================================================
|
| 1120 |
+
# CLI INTERFACE
|
| 1121 |
+
# ============================================================================
|
| 1122 |
+
|
| 1123 |
+
def main():
    """Interactive CLI loop: read questions from stdin, print agent answers.

    Requires OPENAI_API_KEY and PINECONE_API_KEY in the environment;
    exits with status 1 if either is missing.
    """
    import sys

    openai_api_key = os.getenv("OPENAI_API_KEY")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")

    # Fail fast before constructing the agent if credentials are absent.
    if not (openai_api_key and pinecone_api_key):
        print("Error: Missing API keys")
        print("Set OPENAI_API_KEY and PINECONE_API_KEY environment variables")
        sys.exit(1)

    print("Initializing survey analysis agent...")
    agent = SurveyAnalysisAgent(
        openai_api_key=openai_api_key,
        pinecone_api_key=pinecone_api_key,
        verbose=True,
    )

    bar = "=" * 80
    print("\n" + bar)
    print("SURVEY ANALYSIS AGENT (WITH STAGED RESEARCH)")
    print(bar)
    print("\nType 'quit' to exit\n")

    # Single thread id for the whole session so the agent keeps context.
    thread_id = "cli_session"

    while True:
        try:
            question = input("\nYour question: ").strip()

            # Empty input or an exit keyword ends the session.
            if not question or question.lower() in ('quit', 'exit', 'q'):
                print("\nGoodbye!")
                break

            print("\n" + "-" * 80)
            answer = agent.query(question, thread_id=thread_id)
            print("\n" + bar)
            print("ANSWER:")
            print(bar)
            print(answer)
            print(bar)

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as exc:
            # Best-effort reporting; re-raise only when DEBUG is set so the
            # interactive loop survives transient failures by default.
            print(f"\nError: {exc}")
            if os.getenv("DEBUG"):
                raise
|
| 1172 |
+
|
| 1173 |
+
|
| 1174 |
+
# Run the interactive CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|