# fcas/synthesis_qa_backend.py
# NOTE: history was cleaned to remove exposed credentials (commit f73929c);
# the API key must now be supplied by the caller / environment.
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import faiss
import google.generativeai as genai
import numpy as np
import pandas as pd
@dataclass
class SynthesisConfig:
    """Configuration class for research synthesis parameters"""
    top_k: int = 20                        # number of chunks retrieved from FAISS per query
    min_relevance_strict: float = 0.7      # similarity threshold for "high" quality answers
    min_relevance_moderate: float = 0.6    # threshold for "moderate" quality answers
    min_relevance_threshold: float = 0.55  # absolute floor below which studies are rejected
    max_studies: int = 6                   # cap on studies fed into the synthesis prompt
    min_studies: int = 4                   # minimum studies required to attempt a synthesis
    max_synthesis_tokens: int = 4000       # size budget hint for the generated answer
    rate_limit_delay: float = 1.0          # seconds slept before each Gemini API call
    # Optional caller override; the default list is filled in __post_init__ so
    # every instance gets its own list (avoids the mutable-default pitfall).
    domain_keywords: Optional[List[str]] = None

    def __post_init__(self):
        # Lazily populate the default domain vocabulary used for relevance checks.
        if self.domain_keywords is None:
            self.domain_keywords = [
                'development', 'health', 'education', 'governance', 'poverty',
                'conflict', 'fragile', 'intervention', 'policy', 'evaluation',
                'impact', 'program', 'research', 'study', 'analysis', 'survey'
            ]
class QueryAnalyzer:
    """Analyzes queries to determine relevance to research domain"""

    # Obvious off-topic signals; the first hit short-circuits the check.
    _NON_RESEARCH = (
        'who won', 'world cup', 'sports', 'entertainment', 'celebrity',
        'weather', 'stock price', 'cryptocurrency', 'movie', 'music',
        'recipe', 'cooking', 'fashion', 'shopping', 'games', 'gaming',
    )
    # Question shapes that indicate a research-style query (weighted double).
    _RESEARCH_PATTERNS = (
        'what methods', 'what approaches', 'how do', 'how to',
        'what strategies', 'what techniques', 'how can',
        'what are the', 'which methods', 'which approaches',
    )
    # Methodological vocabulary that signals a research focus.
    _METHOD_TERMS = (
        'method', 'approach', 'strategy', 'technique', 'measure',
        'measurement', 'data', 'sample', 'study', 'research',
        'analysis', 'evaluation', 'assessment', 'design',
    )

    def __init__(self, config: SynthesisConfig):
        self.config = config

    def is_domain_relevant(self, query: str) -> Tuple[bool, float, str]:
        """
        Check if query is relevant to research domain
        Returns: (is_relevant, confidence_score, reason)
        """
        text = query.lower()

        # Reject obviously non-research topics outright.
        off_topic = next((p for p in self._NON_RESEARCH if p in text), None)
        if off_topic is not None:
            return False, 0.1, f"Query contains non-research pattern: '{off_topic}'"

        # Count the three kinds of relevance signals.
        keyword_hits = sum(kw in text for kw in self.config.domain_keywords)
        pattern_hits = sum(p in text for p in self._RESEARCH_PATTERNS)
        term_hits = sum(t in text for t in self._METHOD_TERMS)

        # Research-question patterns count double in the combined score.
        total_score = keyword_hits + (pattern_hits * 2) + term_hits

        if total_score == 0:
            return False, 0.3, "No domain-relevant keywords or research patterns found"

        # Methodological queries get a more generous confidence curve.
        if pattern_hits > 0 or term_hits >= 2:
            confidence = min(0.9, 0.6 + (total_score * 0.05))
            return True, confidence, f"Found research patterns and methodological terms (score: {total_score})"

        if keyword_hits > 0:
            confidence = min(0.9, 0.5 + (keyword_hits * 0.1))
            return True, confidence, f"Found {keyword_hits} domain-relevant keywords"

        return False, 0.3, "Insufficient domain relevance"

    def analyze_query_type(self, query: str) -> Dict[str, str]:
        """Analyze query to determine focus area and type"""
        text = query.lower()
        focus_area, query_type = "general findings", "exploratory"

        def has_any(words):
            # True when any of the trigger words appears in the query.
            return any(w in text for w in words)

        # Broad focus classification — first matching category wins.
        if has_any(('method', 'approach', 'methodology', 'technique', 'design')):
            focus_area, query_type = "methodological approaches", "methodological"
        elif has_any(('result', 'finding', 'outcome', 'impact', 'effect', 'evaluation')):
            focus_area, query_type = "key findings and outcomes", "results-focused"
        elif has_any(('challenge', 'barrier', 'problem', 'issue', 'difficulty')):
            focus_area, query_type = "challenges and barriers", "problem-identification"
        elif has_any(('recommendation', 'solution', 'strategy', 'intervention', 'policy')):
            focus_area, query_type = "strategies and recommendations", "solution-oriented"
        elif has_any(('what', 'how', 'why', 'which', 'where')):
            # Generic question word: refine the type but keep the default focus.
            query_type = "analytical"

        # FCAS-specific refinements override the broad classification above.
        fcas_rules = (
            (('sampling', 'sample', 'recruitment', 'selection'),
             "sampling and recruitment strategies"),
            (('data quality', 'validation', 'reliability', 'validity'),
             "data quality and validation"),
            (('ethical', 'ethics', 'consent', 'protection'),
             "ethical considerations"),
            (('tracking', 'mobile', 'displacement', 'attrition'),
             "population tracking and attrition"),
            (('proxy', 'indicator', 'measurement', 'counterfactual'),
             "measurement and identification strategies"),
        )
        for triggers, fcas_focus in fcas_rules:
            if has_any(triggers):
                focus_area, query_type = fcas_focus, "methodological"
                break

        return {
            'focus_area': focus_area,
            'query_type': query_type,
            'original_query': query
        }
class ResearchSynthesizer:
    """Retrieval-augmented Q&A over a FAISS index of FCAS research chunks.

    Pipeline: domain-check the query (QueryAnalyzer) -> embed it with Gemini ->
    FAISS similarity search -> group hits by study -> adaptive relevance
    filtering and ranking -> LLM synthesis with formatted references.
    """

    def __init__(self, index_path: str, metadata_path: str, api_key: str,
                 config: Optional["SynthesisConfig"] = None,
                 log_level: int = logging.INFO):
        """Initialize the research synthesis system.

        Args:
            index_path: Path to the serialized FAISS index.
            metadata_path: CSV with one metadata row per indexed chunk.
            api_key: Google Generative AI API key.
            config: Synthesis parameters; defaults are used when None.
            log_level: Logging verbosity for this instance.

        Raises:
            ValueError: On missing paths/API key or inconsistent thresholds.
            Exception: Re-raised when index/metadata loading or API setup fails.
        """
        logging.basicConfig(level=log_level, format='%(asctime)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        self.config = config or SynthesisConfig()
        self.query_analyzer = QueryAnalyzer(self.config)
        # Fail fast on bad arguments before any expensive loading.
        self._validate_inputs(index_path, metadata_path, api_key)
        try:
            # Load FAISS index and per-chunk metadata.
            self.index = faiss.read_index(index_path)
            self.metadata = pd.read_csv(metadata_path)
            # Configure the Gemini API client.
            genai.configure(api_key=api_key)
            self.logger.info(f"Loaded {self.index.ntotal} chunks from {len(self.metadata['record_id'].unique())} documents")
            self.logger.info(f"FAISS index dimensions: {self.index.d}")
            # Detect an embedding/index dimension mismatch up front.
            self._check_dimensions()
        except Exception as e:
            self.logger.error(f"Failed to initialize synthesizer: {e}")
            raise

    def _validate_inputs(self, index_path: str, metadata_path: str, api_key: str):
        """Validate constructor arguments; raise ValueError on bad input."""
        if not index_path or not metadata_path:
            raise ValueError("Index path and metadata path must be provided")
        if not api_key or api_key == "your_api_key_here":
            raise ValueError("Valid API key must be provided")
        if self.config.min_relevance_strict < self.config.min_relevance_moderate:
            raise ValueError("Strict relevance threshold must be >= moderate threshold")

    def _check_dimensions(self):
        """Compare Gemini embedding width against the FAISS index width.

        Sets self.dimension_mismatch, and target_dim/source_dim when the
        widths differ, so searches can pad/truncate query embeddings.
        """
        test_embedding = self._create_test_embedding()
        if test_embedding is not None:
            embedding_dim = test_embedding.shape[1]
            index_dim = self.index.d
            self.logger.info(f"Gemini embedding dimensions: {embedding_dim}")
            if embedding_dim != index_dim:
                self.logger.warning(f"DIMENSION MISMATCH: Gemini={embedding_dim}, FAISS={index_dim}")
                self.logger.info("Will apply dimension adjustment during search")
                self.dimension_mismatch = True
                self.target_dim = index_dim
                self.source_dim = embedding_dim
            else:
                self.logger.info("Dimensions match perfectly")
                self.dimension_mismatch = False
        else:
            # Could not embed (e.g. API failure): assume no adjustment needed.
            self.logger.warning("Could not test embedding dimensions")
            self.dimension_mismatch = False

    def _create_test_embedding(self) -> Optional[np.ndarray]:
        """Embed a throwaway string to discover the model's output width.

        Returns a (1, d) float32 array, or None when the API call fails.
        """
        try:
            time.sleep(self.config.rate_limit_delay)  # Rate limiting
            embed_result = genai.embed_content(
                model="models/gemini-embedding-001",
                content="test",
                task_type="retrieval_query"
            )
            return np.array([embed_result['embedding']], dtype="float32")
        except Exception as e:
            self.logger.error(f"Could not create test embedding: {e}")
            return None

    def _adjust_embedding_dimensions(self, embedding: np.ndarray) -> np.ndarray:
        """Pad or truncate a (1, d) query embedding to the FAISS index width."""
        if not self.dimension_mismatch:
            return embedding
        current_dim = embedding.shape[1]
        target_dim = self.target_dim
        # Fixed garbled log message (dimensions ran together without separator).
        self.logger.debug(f"Adjusting dimensions: {current_dim} -> {target_dim}")
        if current_dim < target_dim:
            # Pad with zeros
            padding = np.zeros((embedding.shape[0], target_dim - current_dim), dtype="float32")
            adjusted = np.concatenate([embedding, padding], axis=1)
        elif current_dim > target_dim:
            # Truncate (consider PCA for better semantic preservation)
            adjusted = embedding[:, :target_dim]
        else:
            adjusted = embedding
        return adjusted

    def search_relevant_chunks(self, query: str) -> List[Dict]:
        """Embed *query* and return metadata dicts for the top-k FAISS hits.

        Each result dict carries 'similarity_score' (1 / (1 + L2 distance)),
        'faiss_distance', and 'faiss_index'. Returns [] on any API or
        search error (best-effort, errors are logged).
        """
        self.logger.info(f"Searching for: '{query}'")
        try:
            time.sleep(self.config.rate_limit_delay)  # Rate limiting
            embed_result = genai.embed_content(
                model="models/gemini-embedding-001",
                content=query,
                task_type="retrieval_query"
            )
            query_embedding = np.array([embed_result['embedding']], dtype="float32")
            self.logger.debug(f"Embedding created: shape {query_embedding.shape}")
        except Exception as e:
            self.logger.error(f"Embedding creation failed: {e}")
            return []
        # Adjust dimensions if needed
        query_embedding = self._adjust_embedding_dimensions(query_embedding)
        try:
            distances, indices = self.index.search(query_embedding, self.config.top_k)
            self.logger.info(f"Search completed - found {len(indices[0])} results")
            self.logger.debug(f"Distance range: {distances[0].min():.4f} to {distances[0].max():.4f}")
        except Exception as e:
            self.logger.error(f"FAISS search failed: {e}")
            return []
        results = []
        for distance, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with -1; also guard out-of-range ids.
            if idx == -1 or idx >= len(self.metadata):
                continue
            try:
                chunk_data = self.metadata.iloc[idx].to_dict()
                # Convert L2 distance to a bounded (0, 1] similarity score.
                chunk_data['similarity_score'] = float(1 / (1 + distance))
                chunk_data['faiss_distance'] = float(distance)
                chunk_data['faiss_index'] = int(idx)
                results.append(chunk_data)
            except (IndexError, KeyError) as e:
                self.logger.warning(f"Invalid index {idx}, skipping: {e}")
                continue
        # Sort by similarity score, best first.
        results.sort(key=lambda x: x['similarity_score'], reverse=True)
        if results:
            best_score = results[0]['similarity_score']
            worst_score = results[-1]['similarity_score']
            self.logger.info(f"Similarity range: {worst_score:.4f} to {best_score:.4f}")
        return results

    def group_by_studies(self, chunks: List[Dict]) -> Dict[str, List[Dict]]:
        """Group chunk dicts by their 'record_id' (one list per study)."""
        studies = defaultdict(list)
        for chunk in chunks:
            studies[chunk['record_id']].append(chunk)
        return dict(studies)

    def filter_and_rank_studies(self, studies: Dict[str, List[Dict]],
                                query: str = "") -> Tuple[List[Dict], str]:
        """Select the most relevant studies using adaptive thresholds.

        The threshold is chosen from the best available score (strict ->
        moderate -> floor) so strong corpora are filtered strictly while weak
        matches still yield an answer. Surviving studies are re-ranked by
        similarity plus metadata and quality boosts.

        Returns:
            (selected_studies, quality_message)
        """
        study_summaries = []
        # Best score per study drives the adaptive threshold choice.
        all_best_scores = []
        for record_id, chunks in studies.items():
            best_chunk = max(chunks, key=lambda x: x['similarity_score'])
            all_best_scores.append(best_chunk['similarity_score'])
        if not all_best_scores:
            return [], "No studies found"
        max_score = max(all_best_scores)
        avg_score = np.mean(all_best_scores)  # kept for diagnostics
        # Adaptive threshold selection
        if max_score >= self.config.min_relevance_strict:
            threshold = self.config.min_relevance_strict
            quality = "high"
        elif max_score >= self.config.min_relevance_moderate:
            threshold = self.config.min_relevance_moderate
            quality = "moderate"
        elif max_score >= self.config.min_relevance_threshold:
            threshold = self.config.min_relevance_threshold
            quality = "low"
        else:
            return [], f"No studies met minimum relevance threshold. Best score: {max_score:.3f}"
        self.logger.info(f"Using {quality} quality threshold: {threshold:.3f}")
        # Build one summary dict per study clearing the threshold.
        for record_id, chunks in studies.items():
            best_chunk = max(chunks, key=lambda x: x['similarity_score'])
            if best_chunk['similarity_score'] < threshold:
                continue
            # Include supporting chunks with a slightly relaxed threshold.
            relevant_chunks = [c for c in chunks
                               if c['similarity_score'] > threshold * 0.8]
            # Limit text to prevent token overflow
            combined_texts = [c['text'] for c in relevant_chunks[:3]]
            combined_text = "\n\n".join(combined_texts)
            # Truncate if too long
            if len(combined_text) > 1500:
                combined_text = combined_text[:1500] + "..."
            study_summary = {
                'record_id': record_id,
                'combined_text': combined_text,
                'max_relevance': best_chunk['similarity_score'],
                'chunk_count': len(relevant_chunks)
            }
            # Copy study-level metadata (excluding chunk-internal fields).
            excluded_fields = {
                'record_id', 'full_text', 'text', 'chunk_id', 'section',
                'chunk_type', 'word_count', 'faiss_distance', 'faiss_index'
            }
            for key, value in best_chunk.items():
                if key not in excluded_fields and not key.startswith('similarity'):
                    study_summary[key] = value
            study_summaries.append(study_summary)

        # Enhanced scoring: base similarity + metadata + quality boosts.
        def enhanced_score(study):
            base_score = study['max_relevance']
            metadata_boost = self._calculate_metadata_boost(study, query)
            quality_boost = self._calculate_quality_boost(study)
            return base_score + metadata_boost + quality_boost

        study_summaries.sort(key=enhanced_score, reverse=True)
        selected_studies = study_summaries[:self.config.max_studies]
        quality_message = f"Selected {len(selected_studies)} studies with {quality} relevance (threshold: {threshold:.3f})"
        self.logger.info(quality_message)
        for i, study in enumerate(selected_studies, 1):
            # str() guards against non-string titles (e.g. NaN from pandas).
            title = str(study.get('title', 'No title'))[:50]
            score = enhanced_score(study)
            self.logger.debug(f" {i}. Score: {score:.4f} - {title}...")
        return selected_studies, quality_message

    def _calculate_metadata_boost(self, study: Dict, query: str) -> float:
        """Boost (capped at 0.2) for query words appearing in key metadata fields."""
        query_lower = query.lower()
        metadata_boost = 0
        boost_fields = [
            'world_bank_sector', 'world_bank_subsector', 'study_countries',
            'population', 'data_collection_method', 'analysis_type',
            'research_design', 'topic_summary', 'countries_list'
        ]
        for field in boost_fields:
            if field in study and study[field]:
                field_value = str(study[field]).lower()
                matches = sum(1 for word in query_lower.split() if word in field_value)
                metadata_boost += matches * 0.05  # Smaller, more controlled boost
        return min(metadata_boost, 0.2)  # Cap the boost

    def _calculate_quality_boost(self, study: Dict) -> float:
        """Boost for methodological quality flags and the numeric rigor score."""
        quality_boost = 0
        # Boolean quality indicators
        bool_indicators = {
            'has_randomization': 0.08,
            'has_validation': 0.05,
            'has_advanced_analysis': 0.03,
            'has_mixed_methods': 0.03
        }
        for field, boost in bool_indicators.items():
            # Case-insensitive compare also accepts real booleans / 'True'
            # strings as pandas produces from CSV bool columns (was == 'true').
            if str(study.get(field)).lower() == 'true':
                quality_boost += boost
        # Numeric quality indicators
        try:
            rigor_score = float(study.get('rigor_score', 0))
            quality_boost += min(rigor_score * 0.02, 0.1)  # Cap at 0.1
        except (ValueError, TypeError):
            pass
        return quality_boost

    def create_synthesis(self, query: str, studies: List[Dict],
                         query_analysis: Dict) -> str:
        """Create synthesized answer with improved prompt engineering.

        Returns the generated text, or an error message string on failure.
        """
        # Build concise context
        studies_context = self._build_studies_context(studies)
        # Scale requested synthesis length with the number of studies.
        if len(studies) <= 3:
            synthesis_style = "concise"
            max_length = "2-3 paragraphs"
        elif len(studies) <= 6:
            synthesis_style = "balanced"
            max_length = "3-4 paragraphs with clear sections"
        else:
            synthesis_style = "comprehensive"
            max_length = "4-5 paragraphs with detailed analysis"
        synthesis_prompt = f"""You are an expert research synthesizer analyzing studies from fragile and conflict-affected settings (FCAS).
USER QUERY: "{query}"
QUERY TYPE: {query_analysis['query_type']}
FOCUS AREA: {query_analysis['focus_area']}
STUDIES TO SYNTHESIZE ({len(studies)} studies):
{studies_context}
SYNTHESIS INSTRUCTIONS:
1. **Direct Answer First**: Start with a clear, direct answer to the user's question
2. **Evidence-Based**: Ground all claims in the provided studies with citations (Author, Year)
3. **{synthesis_style.title()} Analysis**: Write {max_length}
4. **Key Focus**: Emphasize {query_analysis['focus_area']}
5. **Geographic Context**: Note relevant country/regional patterns
6. **Methodology**: Briefly mention study designs and sample sizes when relevant
FORMAT: Use clear prose without bullet points. Include specific citations and key statistics.
LENGTH: {max_length} maximum.
Write a focused synthesis that directly addresses: "{query}" """
        try:
            time.sleep(self.config.rate_limit_delay)  # Rate limiting
            model = genai.GenerativeModel("gemini-1.5-flash")
            response = model.generate_content(synthesis_prompt)
            return response.text
        except Exception as e:
            self.logger.error(f"Synthesis generation failed: {e}")
            return f"Error creating synthesis: {e}"

    def _build_studies_context(self, studies: List[Dict]) -> str:
        """Build concise per-study context (metadata + truncated content)."""
        studies_context = ""
        for i, study in enumerate(studies, 1):
            # Essential metadata; str() guards against NaN floats from pandas,
            # which would make the slice operations raise TypeError.
            title = str(study.get('title', 'Unknown Title'))[:80]
            authors = str(study.get('authors', 'Unknown Authors'))[:50]
            year = study.get('publication_year', study.get('research_year', 'Unknown'))
            countries = str(study.get('study_countries', study.get('countries_list', 'Unknown')))[:50]
            studies_context += f"\n[{i}] {title}\n"
            studies_context += f"Authors: {authors} ({year}) | Countries: {countries}\n"
            # Key methodology info
            method_info = []
            for field, label in [
                ('research_design', 'Design'),
                ('sample_size', 'N'),
                ('rigor_score', 'Rigor')
            ]:
                if field in study and study[field]:
                    value = str(study[field])
                    if value.lower() not in ['unknown', 'nan', '']:
                        method_info.append(f"{label}: {value}")
            if method_info:
                studies_context += f"Method: {' | '.join(method_info)}\n"
            # Truncated content
            content = study['combined_text'][:800]
            studies_context += f"Content: {content}...\n"
            studies_context += "-" * 60 + "\n"
        return studies_context

    def format_references(self, studies: List[Dict]) -> str:
        """Format academic-style references, one numbered entry per study."""
        references = []
        for i, study in enumerate(studies, 1):
            title = study.get('title', 'Unknown Title')
            authors = study.get('authors', 'Unknown Authors')
            year = study.get('publication_year', study.get('research_year', 'Unknown'))
            countries = study.get('study_countries', '')
            ref = f"[{i}] {authors} ({year}). {title}"
            if countries:
                ref += f" *Countries: {countries}*"
            if study.get('max_relevance'):
                ref += f" *Relevance: {study['max_relevance']:.3f}*"
            references.append(ref)
        return "\n\n".join(references)

    def answer_research_question(self, query: str) -> Dict[str, Any]:
        """Main entry point: answer a research question with domain checking.

        Returns a dict with keys 'answer', 'references', 'study_count',
        'quality', 'suggestions' and, on success, 'quality_message' and
        'query_analysis'.
        """
        self.logger.info(f"Processing query: '{query}'")
        # Validate query length
        if len(query.strip()) < 3:
            return {
                'answer': "Query too short. Please provide a more detailed research question.",
                'references': "",
                'study_count': 0,
                'quality': "invalid",
                'suggestions': []
            }
        # Check domain relevance before spending API calls on retrieval.
        is_relevant, confidence, reason = self.query_analyzer.is_domain_relevant(query)
        if not is_relevant:
            suggestions = [
                "What sampling strategies work best in conflict-affected areas?",
                "How do researchers ensure data quality during active conflict?",
                "What are the ethical considerations for RCTs in fragile states?",
                "How do studies handle attrition bias in longitudinal FCAS research?",
                "What proxy measures are used when direct measurement is impossible?",
                "How do researchers adapt survey instruments for low-literacy populations?",
                "What methods are used to track mobile populations in conflict zones?",
                "How do studies establish counterfactuals in fragile settings?"
            ]
            return {
                'answer': f"This query appears to be outside the scope of development research in fragile and conflict-affected settings.\n\nReason: {reason}\n\nThis database contains research on development, health, education, governance, and policy interventions in FCAS contexts.",
                'references': "",
                'study_count': 0,
                'quality': "out_of_scope",
                'suggestions': suggestions
            }
        # Analyze query type
        query_analysis = self.query_analyzer.analyze_query_type(query)
        # Search for relevant chunks
        relevant_chunks = self.search_relevant_chunks(query)
        if not relevant_chunks:
            return {
                'answer': "No relevant studies found. This might be due to technical issues or very specific query terms.",
                'references': "",
                'study_count': 0,
                'quality': "no_results",
                'suggestions': ["Try broader search terms", "Check spelling", "Use more general concepts"]
            }
        # Group by studies
        studies_dict = self.group_by_studies(relevant_chunks)
        self.logger.info(f"Found {len(studies_dict)} unique studies")
        # Filter and rank studies
        top_studies, quality_message = self.filter_and_rank_studies(studies_dict, query)
        if len(top_studies) < self.config.min_studies:
            return {
                'answer': f"Found {len(studies_dict)} studies but only {len(top_studies)} met relevance criteria.\n\n{quality_message}\n\nTry using broader search terms or different keywords.",
                'references': "",
                'study_count': len(studies_dict),
                'quality': "insufficient",
                'suggestions': ["Use broader terms", "Try synonyms", "Focus on general concepts"]
            }
        # Create synthesis
        self.logger.info(f"Synthesizing findings from {len(top_studies)} studies")
        synthesis = self.create_synthesis(query, top_studies, query_analysis)
        references = self.format_references(top_studies)
        # Determine overall quality from the mean study relevance.
        avg_relevance = np.mean([s['max_relevance'] for s in top_studies])
        if avg_relevance >= self.config.min_relevance_strict:
            quality = "high"
        elif avg_relevance >= self.config.min_relevance_moderate:
            quality = "moderate"
        else:
            quality = "low"
        return {
            'answer': synthesis,
            'references': references,
            'study_count': len(top_studies),
            'quality': quality,
            'quality_message': quality_message,
            'query_analysis': query_analysis,
            'suggestions': []
        }
def main():
    """Smoke-test the synthesis pipeline against a handful of sample queries."""
    import os

    # Slightly relaxed thresholds for interactive testing.
    config = SynthesisConfig(
        top_k=25,
        min_relevance_strict=0.65,
        min_relevance_moderate=0.55,
        min_relevance_threshold=0.50,
        max_studies=8,
        min_studies=3,
    )
    api_key = os.environ.get("GOOGLE_API_KEY", "your_api_key_here")
    try:
        synthesizer = ResearchSynthesizer(
            index_path="research_chunks.faiss",
            metadata_path="chunk_metadata.csv",
            api_key=api_key,
            config=config,
            log_level=logging.INFO,
        )
        sample_queries = (
            "what sampling strategies work best in conflict zones?",
            "how do researchers ensure data quality during active conflict?",
            "what are ethical considerations for randomized trials in fragile states?",
            "how do studies handle attrition bias in FCAS research?",
            "what proxy measures are used when direct measurement is impossible?",
            "how do researchers adapt survey instruments for low-literacy populations?",
            "who won the world cup in 2022?",  # Should be rejected (non-research)
        )
        banner = "=" * 80
        for question in sample_queries:
            print("\n" + banner)
            print(f"QUERY: {question}")
            print(banner)
            outcome = synthesizer.answer_research_question(question)
            print(f"Quality: {outcome['quality']}")
            print(f"Studies: {outcome['study_count']}")
            print("\nAnswer:")
            print(outcome['answer'])
            if outcome['references']:
                print("\nReferences:")
                print(outcome['references'])
            if outcome['suggestions']:
                print("\nSuggestions:")
                for tip in outcome['suggestions']:
                    print(f" • {tip}")
    except Exception as e:
        logging.error(f"Failed to run main: {e}")
        raise


if __name__ == "__main__":
    main()