# DIVERSIFAIR / rag_query.py
# Author: Courtney Ford
# Commit 6a7581e: rag_query update and fallback
from typing import List, Tuple, Optional
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
def format_context_with_citations(results: List[Tuple[Document, float]]) -> str:
    """Render retrieved chunks as numbered, citation-labelled source blocks.

    Each (Document, score) pair becomes a "[Source N]" block carrying its
    citation and jurisdiction, plus optional status/language caveats, so the
    downstream prompt can reference sources by number. Blocks are separated
    by a "---" divider line.
    """
    blocks = []
    for idx, (doc, _score) in enumerate(results, start=1):
        meta = doc.metadata
        lines = [
            f"[Source {idx}]",
            f"Citation: {meta.get('citation', 'Unknown Source')}",
            f"Jurisdiction: {meta.get('entity', 'Unknown')}",
        ]
        # Flag non-published documents so the model treats them cautiously.
        status = meta.get("status", "")
        if status and status.lower() not in ("published", ""):
            lines.append(f"Status: {status}")
        # Flag translated material the same way.
        language = meta.get("language", "")
        if language and language.lower() not in ("english", ""):
            lines.append(
                f"Language: {language} translation - interpret with caution"
            )
        lines.append(f"Content: {doc.page_content}")
        blocks.append("\n".join(lines))
    return "\n---\n".join(blocks)
def create_rag_chain(model_name: str = "gpt-4o-mini", temperature: float = 0):
    """Build the answering chain: prompt template -> chat model -> plain string.

    Args:
        model_name: OpenAI chat model to use.
        temperature: Sampling temperature; 0 keeps answers deterministic.

    Returns:
        A LangChain runnable that expects {"context": ..., "question": ...}
        (context as produced by format_context_with_citations) and yields the
        model's answer as a string.
    """
    llm = ChatOpenAI(model=model_name, temperature=temperature)
    # The system message pins the model to the retrieved sources and to the
    # [Source N] citation format used by format_context_with_citations.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert AI policy analyst with deep knowledge of AI regulations across multiple jurisdictions.
Your task is to answer questions about AI regulations based ONLY on the provided source documents.
CRITICAL INSTRUCTIONS:
1. Answer based ONLY on the provided sources - do not use external knowledge
2. Always cite your sources using the [Source N] format from the context
3. When comparing jurisdictions, explicitly note which jurisdiction each point comes from
4. If the sources don't contain enough information to answer fully, say so
5. Be precise with legal citations (e.g., "According to EU AI Act, Article 5(1)(a)...")
6. For comparative questions, structure your answer to clearly contrast different jurisdictions
7. When asked about an EU member state (e.g., Germany, France, Spain) and the sources discuss that country's policies but do not include country-specific primary legislation, note that as an EU member state, the country is also subject to EU regulations like the EU AI Act. If EU sources are available, reference them as applicable law for that member state.
Format your citations like this:
- "According to the EU AI Act, Article 5..." [Source 1]
- "The US approach differs..." [Source 3]
""",
            ),
            (
                "user",
                """Context from relevant AI regulation documents:
{context}
Question: {question}
Please provide a comprehensive answer with proper citations.""",
            ),
        ]
    )
    # Create chain
    chain = prompt | llm | StrOutputParser()
    return chain
def extract_article_references(question: str) -> List[str]:
    """
    Extract Article/Section references from question.
    E.g., "Article 5", "Section 3", "Article 5(1)(a)"

    Only the leading number is captured (so "Article 5(1)(a)" yields "5").
    All Article numbers are returned before any Section numbers, matching
    case-insensitively.

    Args:
        question: Question text
    Returns:
        List of article/section numbers as strings
    """
    import re

    found: List[str] = []
    # Scan for "Article <n>" first, then "Section <n>", preserving that order.
    for keyword in ("Article", "Section"):
        found.extend(re.findall(keyword + r"\s+(\d+)", question, re.IGNORECASE))
    return found
def rerank_by_document_priority(
    results: "List[Tuple[Document, float]]",
    boost_factor: float = 0.3,
    detected_entity: "Optional[str]" = None,
) -> "List[Tuple[Document, float]]":
    """
    Rerank results to prioritize:
    1. Documents from detected entity (if specified)
    2. Primary legislation (highest priority)
    3. Draft legislation (medium priority)
    4. Articles over preambles
    5. White Papers/Reports (lowest priority)

    Scores are FAISS distances, so LOWER is better: multiplying by a factor
    below 1 promotes a chunk, above 1 demotes it. Boosted scores are used only
    for ordering; each returned tuple keeps the chunk's ORIGINAL score.

    Args:
        results: List of (Document, score) tuples
        boost_factor: How much to adjust scores
        detected_entity: If specified, heavily penalize non-matching entities

    Returns:
        The same (Document, original_score) tuples, reordered by priority.
    """
    # Document types treated as primary legislation.
    primary_types = ("Article_style", "US_Congress", "Special_cases")
    scored = []  # (doc, original_score, boosted_score)
    for doc, score in results:
        status = str(doc.metadata.get("status", "")).lower()
        doc_type = doc.metadata.get("document_type", "")
        filename = doc.metadata.get("filename", "")
        doc_entity = doc.metadata.get("entity", "")
        if detected_entity and doc_entity != detected_entity:
            boosted = score * 2.0  # Wrong jurisdiction: push down in ranking
        elif ("passed" in status or "enacted" in status) and doc_type in primary_types:
            boosted = score * (1 - boost_factor * 3)  # Enacted primary law: strongest boost
        elif "preamble" in filename.lower():
            boosted = score * (1 + boost_factor * 2)  # Preambles rank below articles
        elif "draft" in status or doc_type in primary_types:
            boosted = score * (1 - boost_factor * 1.5)  # Drafts / other primary docs
        elif doc_type == "Paragraph_style":
            boosted = score * (1 + boost_factor)  # Reports / white papers: demote
        else:
            boosted = score
        scored.append((doc, score, boosted))
    # Sort by boosted score (ascending = best first for distance metrics).
    # BUGFIX: the previous implementation sorted (doc, boosted) pairs and then
    # zipped them positionally against the UNSORTED input to "restore" original
    # scores, which attached the wrong original score to each document.
    # Carrying the original score through the sort keeps every doc paired with
    # its own score while preserving the same priority ordering.
    scored.sort(key=lambda item: item[2])
    return [(doc, original_score) for doc, original_score, _ in scored]
def extract_entity_with_llm(
    question: str, metadata_df, model_name: str = "gpt-4o-mini"
) -> Optional[str]:
    """LLM-based fallback for detecting which jurisdiction a question targets.

    Used when regex entity detection fails. Prompts a chat model with the
    valid entity list from metadata_df and accepts the answer only if it is
    an exact member of that list.

    Args:
        question: User's question.
        metadata_df: DataFrame with an 'Entity' column listing valid
            jurisdictions.
        model_name: OpenAI chat model to use (temperature fixed at 0).

    Returns:
        A validated entity name, or None for "NONE"/"COMPARISON" answers,
        unrecognized answers, or any invocation failure.
    """
    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    # Get list of valid entities
    valid_entities = metadata_df["Entity"].unique().tolist()
    llm = ChatOpenAI(model=model_name, temperature=0)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at identifying countries and jurisdictions mentioned in questions.
Given a question, identify which country or jurisdiction it's asking about.
Valid entities (return EXACTLY one of these if mentioned, or "NONE" if no jurisdiction mentioned):
{entities}
Rules:
- Return ONLY the entity name, nothing else
- Handle variations: "Chinese" → "China", "American" → "USA", "German" → "Germany"
- If multiple entities mentioned and it's a comparison, return "COMPARISON"
- If no specific entity mentioned, return "NONE"
- Match to the closest valid entity from the list above
Examples:
Question: "What are Chinese AI regulations?"
Answer: China
Question: "How does Germany regulate AI?"
Answer: Germany
Question: "Compare US and EU approaches"
Answer: COMPARISON
Question: "What are best practices for AI governance?"
Answer: NONE""",
            ),
            ("user", "Question: {question}\nAnswer:"),
        ]
    )
    chain = prompt | llm | StrOutputParser()
    try:
        result = chain.invoke(
            {"question": question, "entities": ", ".join(valid_entities)}
        ).strip()
        # Validate result: only trust exact matches against the known list.
        if result == "COMPARISON":
            return None  # Let comparison detection handle it
        elif result == "NONE":
            return None
        elif result in valid_entities:
            return result
        else:
            # LLM returned something not in our list, ignore
            print(f"LLM returned invalid entity: {result}")
            return None
    except Exception as e:
        # Best-effort fallback: any API/parsing failure degrades to "no entity".
        print(f"LLM entity extraction failed: {e}")
        return None
def extract_document_references(
    question: str, metadata_df
) -> Tuple[List[str], Optional[str]]:
    """
    Detect if question mentions specific documents by name or entity/country.

    Uses manual aliases for major regulations, fuzzy title matching as a
    fallback, and LLM-based entity extraction as the final fallback.

    Args:
        question: User's question
        metadata_df: DataFrame with document metadata (must have 'Title',
            'File_name' and 'Entity' columns), or None to skip entity and
            title matching (aliases still apply).

    Returns:
        Tuple of (list of filenames to boost, suggested entity filter)
    """
    import re

    question_lower = question.lower()
    matching_files: List[str] = []
    suggested_entity: Optional[str] = None

    # FIRST: cheap regex pass — whole-word match of each known entity name.
    if metadata_df is not None:
        for entity in metadata_df["Entity"].unique():
            pattern = r"\b" + re.escape(entity.lower()) + r"\b"
            if re.search(pattern, question_lower):
                suggested_entity = entity
                print(f"Regex detected entity: {entity}")
                break  # Stop at first match

    # Manual aliases for well-known regulations -> (filename, entity).
    document_aliases = {
        "eu ai act": ("EU_AI_Act_Articles.pdf", "EU"),
        "ai act": ("EU_AI_Act_Articles.pdf", "EU"),
        "artificial intelligence act": ("EU_AI_Act_Articles.pdf", "EU"),
        "gdpr": ("EU_GDPR_Articles.pdf", "EU"),
        "general data protection regulation": ("EU_GDPR_Articles.pdf", "EU"),
        "ccpa": ("ccpa_act.pdf", "USA"),
        "california consumer privacy act": ("ccpa_act.pdf", "USA"),
        "dsa": ("EU_DSA_Articles.pdf", "EU"),
        "digital services act": ("EU_DSA_Articles.pdf", "EU"),
        "dma": ("EU_DMA_Articles.pdf", "EU"),
        "digital markets act": ("EU_DMA_Articles.pdf", "EU"),
        "dga": ("EU_DGA_Articles.pdf", "EU"),
        "data governance act": ("EU_DGA_Articles.pdf", "EU"),
    }
    # Check manual aliases
    for alias, (filename, entity) in document_aliases.items():
        if alias in question_lower:
            matching_files.append(filename)
            if suggested_entity is None:  # Only override if not already set
                suggested_entity = entity

    # Fuzzy matching against titles (only if no specific doc found).
    # BUGFIX: previously this loop ran even when metadata_df was None and
    # crashed with AttributeError on .iterrows(); it is now guarded like the
    # entity pass above.
    if not matching_files and metadata_df is not None:
        common_words = {
            "the",
            "of",
            "and",
            "a",
            "an",
            "for",
            "in",
            "on",
            "to",
            "act",
            "regulation",
            "law",
            "policy",
            "framework",
            "strategy",
            "guide",
        }
        question_significant = set(question_lower.split()) - common_words
        for _, row in metadata_df.iterrows():
            title_words = set(str(row["Title"]).lower().split())
            # Require at least two significant words in common with the title.
            if len((title_words - common_words) & question_significant) >= 2:
                matching_files.append(row["File_name"])
                if suggested_entity is None:  # Only override if not already set
                    suggested_entity = row["Entity"]

    # Final fallback: ask the LLM to identify the jurisdiction.
    if suggested_entity is None and metadata_df is not None:
        print("Regex entity detection failed, trying LLM extraction...")
        suggested_entity = extract_entity_with_llm(question, metadata_df)
        if suggested_entity:
            print(f"LLM detected entity: {suggested_entity}")

    return list(set(matching_files)), suggested_entity
# def extract_document_references(
# question: str, metadata_df
# ) -> Tuple[List[str], Optional[str]]:
# """
# Detect if question mentions specific documents by name or entity/country.
# """
# import pandas as pd
# import re
# question_lower = question.lower()
# matching_files = []
# suggested_entity = None
# if metadata_df is not None:
# for entity in metadata_df["Entity"].unique():
# entity_lower = entity.lower()
# pattern = r"\b" + re.escape(entity_lower) + r"\b"
# if re.search(pattern, question_lower):
# suggested_entity = entity
# break # Stop at first match
# document_aliases = {
# "eu ai act": ("EU_AI_Act_Articles.pdf", "EU"),
# "ai act": ("EU_AI_Act_Articles.pdf", "EU"),
# "artificial intelligence act": ("EU_AI_Act_Articles.pdf", "EU"),
# "gdpr": ("EU_GDPR_Articles.pdf", "EU"),
# "general data protection regulation": ("EU_GDPR_Articles.pdf", "EU"),
# "ccpa": ("ccpa_act.pdf", "USA"),
# "california consumer privacy act": ("ccpa_act.pdf", "USA"),
# "dsa": ("EU_DSA_Articles.pdf", "EU"),
# "digital services act": ("EU_DSA_Articles.pdf", "EU"),
# "dma": ("EU_DMA_Articles.pdf", "EU"),
# "digital markets act": ("EU_DMA_Articles.pdf", "EU"),
# "dga": ("EU_DGA_Articles.pdf", "EU"),
# "data governance act": ("EU_DGA_Articles.pdf", "EU"),
# }
# # Check manual aliases
# for alias, (filename, entity) in document_aliases.items():
# if alias in question_lower:
# matching_files.append(filename)
# if suggested_entity is None: # Only override if not already set
# suggested_entity = entity
# if not matching_files:
# common_words = {
# "the",
# "of",
# "and",
# "a",
# "an",
# "for",
# "in",
# "on",
# "to",
# "act",
# "regulation",
# "law",
# "policy",
# "framework",
# "strategy",
# "guide",
# }
# question_words = set(question_lower.split())
# question_significant = question_words - common_words
# for _, row in metadata_df.iterrows():
# title = str(row["Title"]).lower()
# filename = row["File_name"]
# entity = row["Entity"]
# title_words = set(title.split())
# title_significant = title_words - common_words
# matches = title_significant & question_significant
# if len(matches) >= 2:
# matching_files.append(filename)
# if suggested_entity is None: # Only override if not already set
# suggested_entity = entity
# return list(set(matching_files)), suggested_entity
def is_comparison_question(question: str) -> bool:
    """Detect if question is comparing multiple jurisdictions.

    Performs a case-insensitive substring check against a fixed set of
    comparison markers; returns True on the first hit.
    """
    lowered = question.lower()
    comparison_markers = (
        "differ from",
        "compared to",
        "versus",
        "vs",
        "vs.",
        "difference between",
        "differences between",
        "compare",
        "comparison",
        "contrast",
        "unlike",
        "similar to",
        "different from",
        "in contrast to",
        "as opposed to",
    )
    for marker in comparison_markers:
        if marker in lowered:
            return True
    return False
def ask_question_with_llm(
    vectorstore,
    question: str,
    metadata_df=None,
    entity: Optional[str] = None,
    k: int = 10,
    model_name: str = "gpt-4o-mini",
    boost_specific_docs: bool = True,
    auto_detect_entity: bool = True,
) -> dict:
    """
    Answer a question end-to-end: retrieve chunks, filter/rerank them, and
    generate a cited answer with the RAG chain.

    Args:
        vectorstore: FAISS vectorstore
        question: Question to ask
        metadata_df: DataFrame with document metadata
        entity: Optional entity filter (if None and auto_detect_entity=True, will auto-detect)
        k: Number of chunks to retrieve
        model_name: OpenAI model to use
        boost_specific_docs: If True, boost documents mentioned in the question
        auto_detect_entity: If True, auto-detect entity from document references
    Returns:
        Dictionary with answer, sources, and metadata
    """
    # If it's a comparison question, disable auto entity detection so chunks
    # from every jurisdiction stay eligible.
    if is_comparison_question(question):
        auto_detect_entity = False
        print("Comparison question detected - retrieving from all jurisdictions")
    # Check if question references specific documents
    referenced_docs = []
    detected_entity = None
    if boost_specific_docs and metadata_df is not None:
        referenced_docs, detected_entity = extract_document_references(
            question, metadata_df
        )
    # Use detected entity if no entity specified and auto-detection enabled
    if entity is None and auto_detect_entity and detected_entity:
        entity = detected_entity
        print(f"Auto-detected entity filter: {entity}")
    # Store original entity for EU fallback
    original_entity = entity
    # Also check for article/section references
    referenced_articles = extract_article_references(question)
    # If specific articles are mentioned AND specific documents are detected,
    # search docstore directly by metadata (bypass semantic search entirely)
    if referenced_articles and referenced_docs:
        print(f"Detected Article/Section references: {referenced_articles}")
        print(f"Using direct metadata search (bypassing semantic search)...")
        matched_chunks = []
        # NOTE(review): reaches into FAISS's private docstore._dict; works with
        # the current LangChain FAISS wrapper but is not a public API — verify
        # on library upgrades.
        for doc_id, doc in vectorstore.docstore._dict.items():
            doc_article = doc.metadata.get("article") or doc.metadata.get("section")
            filename = doc.metadata.get("filename")
            doc_entity = doc.metadata.get("entity")
            if (
                doc_article
                and str(doc_article) in referenced_articles
                and filename in referenced_docs
                and (entity is None or doc_entity == entity)
            ):
                # 0.0 marks these as exact matches (best possible distance).
                matched_chunks.append((doc, 0.0))
        if matched_chunks:
            results = matched_chunks[:k]
            print(
                f"Found {len(matched_chunks)} chunks for Article {referenced_articles[0]}"
            )
        else:
            print("No metadata matches found, falling back to semantic search...")
            # Fall back to regular semantic search
            retrieve_k = k * 5
            # NOTE(review): retrieve_k is already k * 5, so this call fetches
            # k * 25 candidates — confirm the double multiplication is intended.
            results = vectorstore.similarity_search_with_score(
                question, k=retrieve_k * 5
            )
            if entity:
                results = [
                    (doc, score)
                    for doc, score in results
                    if doc.metadata.get("entity") == entity
                ]
                results = results[:k]
            else:
                results = results[:k]
    else:
        # Regular retrieval when no specific articles mentioned
        # Retrieve more documents for potential boosting
        retrieve_k = k * 5
        # NOTE(review): same k * 25 over-fetch as above — confirm intended.
        results = vectorstore.similarity_search_with_score(question, k=retrieve_k * 5)
        if entity:
            entity_results = [
                (doc, score)
                for doc, score in results
                if doc.metadata.get("entity") == entity
            ]
            # EU MEMBER STATE FALLBACK
            # If no results and entity is an EU member state, expand to all results
            # Let reranking prioritize EU legislation
            if not entity_results and original_entity:
                eu_members = [
                    "Germany",
                    "France",
                    "Italy",
                    "Spain",
                    "Netherlands",
                    "Poland",
                    "Belgium",
                    "Austria",
                    "Greece",
                    "Portugal",
                    "Czechia",
                    "Hungary",
                    "Sweden",
                    "Denmark",
                    "Finland",
                    "Slovakia",
                    "Ireland",
                    "Croatia",
                    "Lithuania",
                    "Slovenia",
                    "Latvia",
                    "Estonia",
                    "Cyprus",
                    "Luxembourg",
                    "Malta",
                ]
                if original_entity in eu_members:
                    print(
                        f"No {original_entity}-specific documents found. Expanding search to include EU regulations as {original_entity} is an EU member state..."
                    )
                    entity_results = results[:retrieve_k]
            results = entity_results[:retrieve_k]
    # Boost/rerank only on the semantic-search paths; direct metadata hits are
    # already exact matches and pre-truncated to k.
    if not (referenced_articles and referenced_docs):
        if referenced_docs:
            boosted_results = []
            other_results = []
            for doc, score in results:
                filename = doc.metadata.get("filename")
                if filename in referenced_docs:
                    # Lower distance = stronger match, so 0.3x promotes the doc.
                    boosted_results.append((doc, score * 0.3))
                else:
                    other_results.append((doc, score))
            results = boosted_results + other_results
        # RERANK BY DOCUMENT PRIORITY - prioritize primary legislation
        results = rerank_by_document_priority(results, boost_factor=0.3)
        results = results[:k]
    if not results:
        # Nothing survived retrieval/filtering: return an empty-answer payload.
        return {
            "answer": "No relevant documents found for this query.",
            "sources": [],
            "question": question,
            "num_sources": 0,
        }
    context = format_context_with_citations(results)
    chain = create_rag_chain(model_name=model_name)
    answer = chain.invoke({"context": context, "question": question})
    # Build the sources list mirroring the [Source N] numbering in the context.
    sources = []
    for i, (doc, score) in enumerate(results, 1):
        sources.append(
            {
                "number": i,
                "citation": doc.metadata.get("citation"),
                "entity": doc.metadata.get("entity"),
                "score": float(score),
                "text_preview": doc.page_content[:200],
            }
        )
    return {
        "answer": answer,
        "sources": sources,
        "question": question,
        "entity_filter": entity,
        "num_sources": len(sources),
    }
def print_rag_answer(result: dict):
    """Pretty-print a RAG result dict: question, answer, then deduplicated sources.

    Sources sharing the same citation and jurisdiction are grouped; groups with
    more than three chunks are shown as a range instead of listing every number.
    """
    divider = "=" * 80
    print("\n" + divider)
    print(f"QUESTION: {result['question']}")
    if result.get("entity_filter"):
        print(f"FILTER: {result['entity_filter']}")
    print(divider)
    print("\nANSWER:\n")
    print(result["answer"])
    print(f"\n\n{divider}")
    print(f"SOURCES ({result['num_sources']} retrieved):")
    print(divider)
    # Group sources by (citation, entity) to avoid repetition; dict insertion
    # order preserves first-seen ordering.
    grouped = {}
    for source in result["sources"]:
        key = f"{source['citation']}|{source['entity']}"
        entry = grouped.get(key)
        if entry is None:
            grouped[key] = {
                "citation": source["citation"],
                "entity": source["entity"],
                "source_numbers": [source["number"]],
            }
        else:
            entry["source_numbers"].append(source["number"])
    # Print deduplicated sources
    for group in grouped.values():
        numbers = group["source_numbers"]
        count = len(numbers)
        if count > 3:
            # Same doc appears many times: collapse to a range display.
            print(f"\n{group['citation']}")
            print(f" Jurisdiction: {group['entity']}")
            print(
                f" (Referenced in sources {numbers[0]}-{numbers[-1]}, {count} chunks total)"
            )
        else:
            source_nums = ", ".join(str(n) for n in numbers)
            print(f"\n[{source_nums}] {group['citation']}")
            print(f" Jurisdiction: {group['entity']}")
    print()