# Spaces:
# Sleeping
# Sleeping  (stray UI/status text from the paste — not part of the module)
| from typing import List, Tuple, Optional | |
| from langchain_core.documents import Document | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
def format_context_with_citations(results: List[Tuple[Document, float]]) -> str:
    """Render retrieved chunks as numbered ``[Source N]`` blocks for the prompt.

    Each block carries citation and jurisdiction; draft status and non-English
    translations are flagged so the LLM can caveat its answer. Blocks are
    joined with ``\\n---\\n`` separators.
    """
    blocks = []
    for idx, (doc, _score) in enumerate(results, 1):
        meta = doc.metadata
        citation = meta.get("citation", "Unknown Source")
        entity = meta.get("entity", "Unknown")
        language = meta.get("language", "")
        status = meta.get("status", "")

        lines = [
            f"[Source {idx}]",
            f"Citation: {citation}",
            f"Jurisdiction: {entity}",
        ]
        # Surface anything that is not final, published text.
        if status and status.lower() not in ("published", ""):
            lines.append(f"Status: {status}")
        # Translated documents are lower-confidence source material.
        if language and language.lower() not in ("english", ""):
            lines.append(
                f"Language: {language} translation - interpret with caution"
            )
        lines.append(f"Content: {doc.page_content}")
        blocks.append("\n".join(lines))
    return "\n---\n".join(blocks)
def create_rag_chain(model_name: str = "gpt-4o-mini", temperature: float = 0):
    """Build the answer-generation chain: prompt | ChatOpenAI | string parser.

    The returned chain expects ``{"context": ..., "question": ...}`` on
    ``invoke`` and returns the model's answer as a plain string. The system
    prompt constrains the model to the provided sources and to the
    ``[Source N]`` citation format produced by ``format_context_with_citations``.

    Args:
        model_name: OpenAI chat model to use.
        temperature: Sampling temperature (0 = deterministic-ish answers).
    """
    llm = ChatOpenAI(model=model_name, temperature=temperature)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert AI policy analyst with deep knowledge of AI regulations across multiple jurisdictions.
Your task is to answer questions about AI regulations based ONLY on the provided source documents.
CRITICAL INSTRUCTIONS:
1. Answer based ONLY on the provided sources - do not use external knowledge
2. Always cite your sources using the [Source N] format from the context
3. When comparing jurisdictions, explicitly note which jurisdiction each point comes from
4. If the sources don't contain enough information to answer fully, say so
5. Be precise with legal citations (e.g., "According to EU AI Act, Article 5(1)(a)...")
6. For comparative questions, structure your answer to clearly contrast different jurisdictions
7. When asked about an EU member state (e.g., Germany, France, Spain) and the sources discuss that country's policies but do not include country-specific primary legislation, note that as an EU member state, the country is also subject to EU regulations like the EU AI Act. If EU sources are available, reference them as applicable law for that member state.
Format your citations like this:
- "According to the EU AI Act, Article 5..." [Source 1]
- "The US approach differs..." [Source 3]
""",
            ),
            (
                "user",
                """Context from relevant AI regulation documents:
{context}
Question: {question}
Please provide a comprehensive answer with proper citations.""",
            ),
        ]
    )
    # Create chain
    chain = prompt | llm | StrOutputParser()
    return chain
def extract_article_references(question: str) -> List[str]:
    """
    Extract Article/Section references from question.
    E.g., "Article 5", "Section 3", "Article 5(1)(a)"

    Only the leading number is captured ("Article 5(1)(a)" -> "5");
    Article matches come before Section matches in the result.

    Args:
        question: Question text
    Returns:
        List of article/section numbers as strings
    """
    import re

    found: List[str] = []
    # Case-insensitive match on "Article <n>" then "Section <n>".
    for keyword in ("Article", "Section"):
        found.extend(re.findall(keyword + r"\s+(\d+)", question, re.IGNORECASE))
    return found
def rerank_by_document_priority(
    results: List[Tuple[Document, float]],
    boost_factor: float = 0.3,
    detected_entity: Optional[str] = None,
) -> List[Tuple[Document, float]]:
    """
    Rerank results to prioritize:
    1. Documents from detected entity (if specified)
    2. Primary legislation (highest priority)
    3. Draft legislation (medium priority)
    4. Articles over preambles
    5. White Papers/Reports (lowest priority)

    Scores are treated as FAISS distances (lower = better), so boosting
    multiplies by a factor < 1 and penalizing by a factor > 1. The boosted
    score is used ONLY for ordering; each returned tuple carries the
    document's original score.

    Args:
        results: List of (Document, score) tuples
        boost_factor: How much to adjust scores
        detected_entity: If specified, heavily penalize non-matching entities

    Returns:
        The same (Document, original_score) pairs, reordered by boosted score.
    """
    primary_types = ["Article_style", "US_Congress", "Special_cases"]
    reranked = []
    for doc, score in results:
        status = str(doc.metadata.get("status", "")).lower()
        doc_type = doc.metadata.get("document_type", "")
        filename = doc.metadata.get("filename", "")
        doc_entity = doc.metadata.get("entity", "")
        if detected_entity and doc_entity != detected_entity:
            boosted = score * 2.0  # Push down in ranking
        elif ("passed" in status or "enacted" in status) and doc_type in primary_types:
            boosted = score * (1 - boost_factor * 3)  # enacted primary law: strongest boost
        elif "preamble" in filename.lower():
            boosted = score * (1 + boost_factor * 2)  # preambles rank below articles
        elif "draft" in status or doc_type in primary_types:
            boosted = score * (1 - boost_factor * 1.5)
        elif doc_type == "Paragraph_style":
            boosted = score * (1 + boost_factor)  # reports/white papers: lowest
        else:
            boosted = score
        # Carry the original score alongside the ordering key.
        reranked.append((doc, boosted, score))

    # Lower distance = better match, so sort ascending on the boosted key.
    reranked.sort(key=lambda t: t[1])

    # BUG FIX: the previous version zipped the sorted docs against the
    # *original-order* results, pairing each doc with some other doc's score.
    # Return each document with its own original score instead.
    return [(doc, original_score) for doc, _boosted, original_score in reranked]
def extract_entity_with_llm(
    question: str, metadata_df, model_name: str = "gpt-4o-mini"
) -> Optional[str]:
    """Ask an LLM which jurisdiction (from metadata_df['Entity']) a question targets.

    Returns the matching entity name, or None when the model answers "NONE",
    answers "COMPARISON" (comparison questions are handled elsewhere), returns
    something not in the valid-entity list, or the API call fails. Best-effort:
    all failures degrade to None and are only printed, never raised.

    Args:
        question: User's question text.
        metadata_df: DataFrame with an 'Entity' column listing valid jurisdictions.
        model_name: OpenAI chat model used for the extraction.
    """
    # Local imports keep the helper self-contained (they mirror the
    # module-level imports at the top of the file).
    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    # Get list of valid entities
    valid_entities = metadata_df["Entity"].unique().tolist()
    llm = ChatOpenAI(model=model_name, temperature=0)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at identifying countries and jurisdictions mentioned in questions.
Given a question, identify which country or jurisdiction it's asking about.
Valid entities (return EXACTLY one of these if mentioned, or "NONE" if no jurisdiction mentioned):
{entities}
Rules:
- Return ONLY the entity name, nothing else
- Handle variations: "Chinese" → "China", "American" → "USA", "German" → "Germany"
- If multiple entities mentioned and it's a comparison, return "COMPARISON"
- If no specific entity mentioned, return "NONE"
- Match to the closest valid entity from the list above
Examples:
Question: "What are Chinese AI regulations?"
Answer: China
Question: "How does Germany regulate AI?"
Answer: Germany
Question: "Compare US and EU approaches"
Answer: COMPARISON
Question: "What are best practices for AI governance?"
Answer: NONE""",
            ),
            ("user", "Question: {question}\nAnswer:"),
        ]
    )
    chain = prompt | llm | StrOutputParser()
    try:
        result = chain.invoke(
            {"question": question, "entities": ", ".join(valid_entities)}
        ).strip()
        # Validate result: only accept sentinel values or a known entity.
        if result == "COMPARISON":
            return None  # Let comparison detection handle it
        elif result == "NONE":
            return None
        elif result in valid_entities:
            return result
        else:
            # LLM returned something not in our list, ignore
            print(f"LLM returned invalid entity: {result}")
            return None
    except Exception as e:
        # Network/auth/model failures are non-fatal; caller treats None as
        # "no entity detected".
        print(f"LLM entity extraction failed: {e}")
        return None
def extract_document_references(
    question: str, metadata_df
) -> Tuple[List[str], Optional[str]]:
    """
    Detect if question mentions specific documents by name or entity/country.

    Detection cascade:
      1. Regex word-boundary match of known entity names (fast path).
      2. Manual aliases for major regulations (EU AI Act, GDPR, DSA, ...).
      3. Fuzzy title matching (>= 2 shared significant words) as fallback.
      4. LLM-based entity extraction as final fallback.

    Args:
        question: User's question
        metadata_df: DataFrame with document metadata (must have 'Title',
            'File_name' and 'Entity' columns). May be None, in which case
            only the manual aliases are consulted.
    Returns:
        Tuple of (list of filenames to boost, suggested entity filter)
    """
    import re

    question_lower = question.lower()
    matching_files = []
    suggested_entity = None

    # FIRST: Check for entity names with regex (fast)
    if metadata_df is not None:
        for entity in metadata_df["Entity"].unique():
            pattern = r"\b" + re.escape(entity.lower()) + r"\b"
            if re.search(pattern, question_lower):
                suggested_entity = entity
                print(f"Regex detected entity: {entity}")
                break  # Stop at first match

    document_aliases = {
        "eu ai act": ("EU_AI_Act_Articles.pdf", "EU"),
        "ai act": ("EU_AI_Act_Articles.pdf", "EU"),
        "artificial intelligence act": ("EU_AI_Act_Articles.pdf", "EU"),
        "gdpr": ("EU_GDPR_Articles.pdf", "EU"),
        "general data protection regulation": ("EU_GDPR_Articles.pdf", "EU"),
        "ccpa": ("ccpa_act.pdf", "USA"),
        "california consumer privacy act": ("ccpa_act.pdf", "USA"),
        "dsa": ("EU_DSA_Articles.pdf", "EU"),
        "digital services act": ("EU_DSA_Articles.pdf", "EU"),
        "dma": ("EU_DMA_Articles.pdf", "EU"),
        "digital markets act": ("EU_DMA_Articles.pdf", "EU"),
        "dga": ("EU_DGA_Articles.pdf", "EU"),
        "data governance act": ("EU_DGA_Articles.pdf", "EU"),
    }

    # Check manual aliases
    for alias, (filename, entity) in document_aliases.items():
        if alias in question_lower:
            matching_files.append(filename)
            if suggested_entity is None:  # Only override if not already set
                suggested_entity = entity

    # Fuzzy matching against titles (only if no specific doc found).
    # BUG FIX: guard against metadata_df being None — previously this branch
    # called metadata_df.iterrows() unconditionally and crashed when the
    # caller passed no metadata and no alias matched.
    if not matching_files and metadata_df is not None:
        common_words = {
            "the", "of", "and", "a", "an", "for", "in", "on", "to",
            "act", "regulation", "law", "policy", "framework", "strategy", "guide",
        }
        question_significant = set(question_lower.split()) - common_words
        for _, row in metadata_df.iterrows():
            title_significant = set(str(row["Title"]).lower().split()) - common_words
            # Require at least 2 overlapping significant words to count a hit.
            if len(title_significant & question_significant) >= 2:
                matching_files.append(row["File_name"])
                if suggested_entity is None:  # Only override if not already set
                    suggested_entity = row["Entity"]

    # Final fallback: ask the LLM to identify the jurisdiction.
    if suggested_entity is None and metadata_df is not None:
        print("Regex entity detection failed, trying LLM extraction...")
        suggested_entity = extract_entity_with_llm(question, metadata_df)
        if suggested_entity:
            print(f"LLM detected entity: {suggested_entity}")

    return list(set(matching_files)), suggested_entity
# NOTE: superseded draft of extract_document_references (the active version
# above adds debug logging and the LLM entity-extraction fallback); kept
# commented out for reference.
# def extract_document_references(
#     question: str, metadata_df
# ) -> Tuple[List[str], Optional[str]]:
| # """ | |
| # Detect if question mentions specific documents by name or entity/country. | |
| # """ | |
| # import pandas as pd | |
| # import re | |
| # question_lower = question.lower() | |
| # matching_files = [] | |
| # suggested_entity = None | |
| # if metadata_df is not None: | |
| # for entity in metadata_df["Entity"].unique(): | |
| # entity_lower = entity.lower() | |
| # pattern = r"\b" + re.escape(entity_lower) + r"\b" | |
| # if re.search(pattern, question_lower): | |
| # suggested_entity = entity | |
| # break # Stop at first match | |
| # document_aliases = { | |
| # "eu ai act": ("EU_AI_Act_Articles.pdf", "EU"), | |
| # "ai act": ("EU_AI_Act_Articles.pdf", "EU"), | |
| # "artificial intelligence act": ("EU_AI_Act_Articles.pdf", "EU"), | |
| # "gdpr": ("EU_GDPR_Articles.pdf", "EU"), | |
| # "general data protection regulation": ("EU_GDPR_Articles.pdf", "EU"), | |
| # "ccpa": ("ccpa_act.pdf", "USA"), | |
| # "california consumer privacy act": ("ccpa_act.pdf", "USA"), | |
| # "dsa": ("EU_DSA_Articles.pdf", "EU"), | |
| # "digital services act": ("EU_DSA_Articles.pdf", "EU"), | |
| # "dma": ("EU_DMA_Articles.pdf", "EU"), | |
| # "digital markets act": ("EU_DMA_Articles.pdf", "EU"), | |
| # "dga": ("EU_DGA_Articles.pdf", "EU"), | |
| # "data governance act": ("EU_DGA_Articles.pdf", "EU"), | |
| # } | |
| # # Check manual aliases | |
| # for alias, (filename, entity) in document_aliases.items(): | |
| # if alias in question_lower: | |
| # matching_files.append(filename) | |
| # if suggested_entity is None: # Only override if not already set | |
| # suggested_entity = entity | |
| # if not matching_files: | |
| # common_words = { | |
| # "the", | |
| # "of", | |
| # "and", | |
| # "a", | |
| # "an", | |
| # "for", | |
| # "in", | |
| # "on", | |
| # "to", | |
| # "act", | |
| # "regulation", | |
| # "law", | |
| # "policy", | |
| # "framework", | |
| # "strategy", | |
| # "guide", | |
| # } | |
| # question_words = set(question_lower.split()) | |
| # question_significant = question_words - common_words | |
| # for _, row in metadata_df.iterrows(): | |
| # title = str(row["Title"]).lower() | |
| # filename = row["File_name"] | |
| # entity = row["Entity"] | |
| # title_words = set(title.split()) | |
| # title_significant = title_words - common_words | |
| # matches = title_significant & question_significant | |
| # if len(matches) >= 2: | |
| # matching_files.append(filename) | |
| # if suggested_entity is None: # Only override if not already set | |
| # suggested_entity = entity | |
| # return list(set(matching_files)), suggested_entity | |
def is_comparison_question(question: str) -> bool:
    """Detect if question is comparing multiple jurisdictions."""
    q = question.lower()
    # Substring cues that signal a cross-jurisdiction comparison.
    cues = (
        "differ from",
        "compared to",
        "versus",
        "vs",
        "vs.",
        "difference between",
        "differences between",
        "compare",
        "comparison",
        "contrast",
        "unlike",
        "similar to",
        "different from",
        "in contrast to",
        "as opposed to",
    )
    for cue in cues:
        if cue in q:
            return True
    return False
def ask_question_with_llm(
    vectorstore,
    question: str,
    metadata_df=None,
    entity: Optional[str] = None,
    k: int = 10,
    model_name: str = "gpt-4o-mini",
    boost_specific_docs: bool = True,
    auto_detect_entity: bool = True,
) -> dict:
    """
    End-to-end RAG query: detect referenced documents/entities, retrieve
    chunks (direct metadata lookup or semantic search), boost and rerank,
    then generate a cited answer with the LLM chain.

    Args:
        vectorstore: FAISS vectorstore
        question: Question to ask
        metadata_df: DataFrame with document metadata
        entity: Optional entity filter (if None and auto_detect_entity=True, will auto-detect)
        k: Number of chunks to retrieve
        model_name: OpenAI model to use
        boost_specific_docs: If True, boost documents mentioned in the question
        auto_detect_entity: If True, auto-detect entity from document references
    Returns:
        Dictionary with answer, sources, and metadata
    """
    # If it's a comparison question, disable auto entity detection
    if is_comparison_question(question):
        auto_detect_entity = False
        print("Comparison question detected - retrieving from all jurisdictions")
    # Check if question references specific documents
    referenced_docs = []
    detected_entity = None
    if boost_specific_docs and metadata_df is not None:
        referenced_docs, detected_entity = extract_document_references(
            question, metadata_df
        )
    # Use detected entity if no entity specified and auto-detection enabled
    if entity is None and auto_detect_entity and detected_entity:
        entity = detected_entity
        print(f"Auto-detected entity filter: {entity}")
    # Store original entity for EU fallback
    original_entity = entity
    # Also check for article/section references
    referenced_articles = extract_article_references(question)
    # If specific articles are mentioned AND specific documents are detected,
    # search docstore directly by metadata (bypass semantic search entirely)
    if referenced_articles and referenced_docs:
        print(f"Detected Article/Section references: {referenced_articles}")
        print(f"Using direct metadata search (bypassing semantic search)...")
        matched_chunks = []
        # NOTE(review): reaches into the FAISS docstore's private _dict —
        # fragile across langchain versions; confirm before upgrading.
        for doc_id, doc in vectorstore.docstore._dict.items():
            doc_article = doc.metadata.get("article") or doc.metadata.get("section")
            filename = doc.metadata.get("filename")
            doc_entity = doc.metadata.get("entity")
            if (
                doc_article
                and str(doc_article) in referenced_articles
                and filename in referenced_docs
                and (entity is None or doc_entity == entity)
            ):
                # Exact metadata hit: assign 0.0 (best possible distance).
                matched_chunks.append((doc, 0.0))
        if matched_chunks:
            results = matched_chunks[:k]
            print(
                f"Found {len(matched_chunks)} chunks for Article {referenced_articles[0]}"
            )
        else:
            print("No metadata matches found, falling back to semantic search...")
            # Fall back to regular semantic search
            retrieve_k = k * 5
            # NOTE(review): k=retrieve_k * 5 fetches k*25 chunks even though
            # retrieve_k is already k*5 — confirm the extra *5 is intentional.
            results = vectorstore.similarity_search_with_score(
                question, k=retrieve_k * 5
            )
            if entity:
                results = [
                    (doc, score)
                    for doc, score in results
                    if doc.metadata.get("entity") == entity
                ]
                results = results[:k]
            else:
                results = results[:k]
    else:
        # Regular retrieval when no specific articles mentioned
        # Retrieve more documents for potential boosting
        retrieve_k = k * 5
        # NOTE(review): same k*25 over-fetch as above — confirm intentional.
        results = vectorstore.similarity_search_with_score(question, k=retrieve_k * 5)
        if entity:
            entity_results = [
                (doc, score)
                for doc, score in results
                if doc.metadata.get("entity") == entity
            ]
            # EU MEMBER STATE FALLBACK
            # If no results and entity is an EU member state, expand to all results
            # Let reranking prioritize EU legislation
            if not entity_results and original_entity:
                eu_members = [
                    "Germany",
                    "France",
                    "Italy",
                    "Spain",
                    "Netherlands",
                    "Poland",
                    "Belgium",
                    "Austria",
                    "Greece",
                    "Portugal",
                    "Czechia",
                    "Hungary",
                    "Sweden",
                    "Denmark",
                    "Finland",
                    "Slovakia",
                    "Ireland",
                    "Croatia",
                    "Lithuania",
                    "Slovenia",
                    "Latvia",
                    "Estonia",
                    "Cyprus",
                    "Luxembourg",
                    "Malta",
                ]
                if original_entity in eu_members:
                    print(
                        f"No {original_entity}-specific documents found. Expanding search to include EU regulations as {original_entity} is an EU member state..."
                    )
                    entity_results = results[:retrieve_k]
            results = entity_results[:retrieve_k]
    # Boost/rerank only on the semantic-search paths; the direct metadata
    # path above already selected and scored its chunks.
    if not (referenced_articles and referenced_docs):
        if referenced_docs:
            boosted_results = []
            other_results = []
            for doc, score in results:
                filename = doc.metadata.get("filename")
                if filename in referenced_docs:
                    # Lower distance = better, so *0.3 pulls referenced docs up.
                    boosted_results.append((doc, score * 0.3))
                else:
                    other_results.append((doc, score))
            results = boosted_results + other_results
        # RERANK BY DOCUMENT PRIORITY - prioritize primary legislation
        # NOTE(review): detected_entity is not forwarded here, so the
        # entity down-ranking branch of rerank_by_document_priority never
        # fires on this call path — confirm that is intended.
        results = rerank_by_document_priority(results, boost_factor=0.3)
        results = results[:k]
    if not results:
        return {
            "answer": "No relevant documents found for this query.",
            "sources": [],
            "question": question,
            "num_sources": 0,
        }
    context = format_context_with_citations(results)
    chain = create_rag_chain(model_name=model_name)
    answer = chain.invoke({"context": context, "question": question})
    sources = []
    for i, (doc, score) in enumerate(results, 1):
        sources.append(
            {
                "number": i,
                "citation": doc.metadata.get("citation"),
                "entity": doc.metadata.get("entity"),
                "score": float(score),
                "text_preview": doc.page_content[:200],
            }
        )
    return {
        "answer": answer,
        "sources": sources,
        "question": question,
        "entity_filter": entity,
        "num_sources": len(sources),
    }
def print_rag_answer(result: dict):
    """Pretty-print a RAG result dict: question, answer, then the sources
    deduplicated by (citation, jurisdiction)."""
    divider = "=" * 80
    print("\n" + divider)
    print(f"QUESTION: {result['question']}")
    if result.get("entity_filter"):
        print(f"FILTER: {result['entity_filter']}")
    print(divider)
    print(f"\nANSWER:\n")
    print(result["answer"])
    print(f"\n\n{divider}")
    print(f"SOURCES ({result['num_sources']} retrieved):")
    print(divider)
    # Group sources by citation to avoid repetition; dict insertion order
    # preserves first-appearance ordering.
    grouped = {}
    for src in result["sources"]:
        key = f"{src['citation']}|{src['entity']}"
        entry = grouped.setdefault(
            key,
            {
                "citation": src["citation"],
                "entity": src["entity"],
                "source_numbers": [],
            },
        )
        entry["source_numbers"].append(src["number"])
    # Print deduplicated sources
    for entry in grouped.values():
        nums = entry["source_numbers"]
        if len(nums) > 3:
            # Same doc appears many times: compress to a range summary.
            print(f"\n{entry['citation']}")
            print(f" Jurisdiction: {entry['entity']}")
            print(
                f" (Referenced in sources {nums[0]}-{nums[-1]}, {len(nums)} chunks total)"
            )
        else:
            source_nums = ", ".join(str(n) for n in nums)
            print(f"\n[{source_nums}] {entry['citation']}")
            print(f" Jurisdiction: {entry['entity']}")
    print()