Spaces:
Running
Running
| """ | |
| Ask My Research - RAG chatbot over Anthony Maio's AI safety papers. | |
| Runs on HuggingFace Spaces using the Inference API. | |
| """ | |
| import json | |
| import os | |
| import time | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import InferenceClient | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
# =============================================================================
# Configuration
# =============================================================================
INDEX_DIR = Path("index")  # directory expected to hold faiss.index and chunks.json
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
TOP_K = 5  # how many chunks to retrieve per question
MAX_NEW_TOKENS = 512

# Rate limiting: at most RATE_LIMIT requests per client within a sliding
# window of RATE_WINDOW seconds.
RATE_LIMIT = 20
RATE_WINDOW = 60 * 60  # one hour, expressed in seconds
request_log = defaultdict(list)  # client key -> list of request timestamps

# Display metadata for citations, keyed by a substring of the chunk "source".
# Built from (key, title, url) records; dict order is the lookup order.
PAPER_INFO = {
    key: {"title": title, "url": url}
    for key, title, url in (
        (
            "manifold_model_organisms_arxiv",
            "Model Organisms of Supply-Chain Co-option",
            "https://zenodo.org/records/18203353",
        ),
        (
            "slipstream-paper",
            "Slipstream: Semantic Quantization for Multi-Agent Coordination",
            "https://zenodo.org/records/18115418",
        ),
        (
            "cmed_paper",
            "Cross-Model Epistemic Divergence (CMED)",
            "https://making-minds.ai/research/cmed",
        ),
        (
            "hdcs_paper",
            "Heterogeneous Divergence-Convergence Swarm (HDCS)",
            "https://making-minds.ai/research/hdcs",
        ),
        (
            "synthesis_paper",
            "Synthesis: Test-Driven AI Self-Extension",
            "https://making-minds.ai/research/synthesis",
        ),
        (
            "Coherence-Seeking-Architectures",
            "Coherence-Seeking Architectures for Agentic AI",
            "https://zenodo.org/records/18137928",
        ),
    )
}
# =============================================================================
# Rate Limiting
# =============================================================================
def check_rate_limit(request: gr.Request) -> tuple[bool, str]:
    """Decide whether a request is within the per-client rate limit, recording it if so.

    Uses a sliding window keyed by client IP: each client may make at most
    ``RATE_LIMIT`` requests within any ``RATE_WINDOW`` seconds.

    Args:
        request: The Gradio request. When ``None`` (no request context), the
            call is always allowed and nothing is recorded.

    Returns:
        ``(True, "")`` when allowed (the request timestamp is recorded), or
        ``(False, message)`` with a human-readable retry hint when the limit
        has been reached (nothing is recorded in that case).
    """
    if request is None:
        return True, ""

    ip = request.client.host if request.client else "unknown"
    now = time.time()

    # Evict clients whose entire history has aged out of the window so that
    # request_log does not grow forever as new IPs show up. Snapshot the items
    # first so we never mutate the dict while iterating it.
    for key, stamps in list(request_log.items()):
        if not stamps or now - stamps[-1] >= RATE_WINDOW:
            request_log.pop(key, None)

    recent = [t for t in request_log[ip] if now - t < RATE_WINDOW]
    request_log[ip] = recent

    if len(recent) >= RATE_LIMIT:
        # Seconds until the oldest recorded request leaves the window. Round up
        # to whole minutes (min 1) so the hint never says "0 minutes".
        wait_seconds = RATE_WINDOW - (now - recent[0])
        minutes = max(1, (int(wait_seconds) + 59) // 60)
        unit = "minute" if minutes == 1 else "minutes"
        return False, f"Rate limit exceeded. Please try again in {minutes} {unit}."

    recent.append(now)
    return True, ""
# =============================================================================
# Load Index and Models
# =============================================================================
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBEDDING_MODEL)

print("Loading FAISS index...")
index_path = INDEX_DIR / "faiss.index"
chunks_path = INDEX_DIR / "chunks.json"
if index_path.exists() and chunks_path.exists():
    faiss_index = faiss.read_index(str(index_path))
    with open(chunks_path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    print(f"Loaded {faiss_index.ntotal} vectors and {len(chunks)} chunks")
    # retrieve() maps FAISS row ids directly onto positions in `chunks`, so the
    # two must line up. A count mismatch suggests the files come from different
    # indexing runs; hits may then be dropped or point at the wrong text.
    if faiss_index.ntotal != len(chunks):
        print(
            f"WARNING: index has {faiss_index.ntotal} vectors but chunks.json "
            f"has {len(chunks)} entries. Re-run embed_papers.py to rebuild both."
        )
else:
    print("WARNING: Index not found. Run embed_papers.py first!")
    faiss_index = None
    chunks = []

# Initialize the Inference Client (authenticated when HF_TOKEN is set).
print("Initializing HF Inference Client...")
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    client = InferenceClient(token=hf_token)
    print("Inference client ready with authentication")
else:
    client = InferenceClient()
    print("WARNING: No HF_TOKEN found - using unauthenticated requests")
# =============================================================================
# RAG Functions
# =============================================================================
def retrieve(query: str, top_k: int = TOP_K) -> list[dict]:
    """Return up to ``top_k`` stored chunks most relevant to ``query``.

    The query is embedded with ``embed_model``, L2-normalized, and searched in
    ``faiss_index``. Each result is a shallow copy of the stored chunk dict
    with an added ``"score"`` key holding the value FAISS reported for that hit.

    Args:
        query: The user's question.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of chunk dicts in FAISS result order. The list is empty when the
        index or chunks are not loaded, or when ``top_k`` is less than 1.
    """
    if faiss_index is None or not chunks or top_k < 1:
        return []

    # FAISS's Python API requires a C-contiguous float32 query matrix.
    query_embedding = np.ascontiguousarray(
        embed_model.encode([query], convert_to_numpy=True), dtype=np.float32
    )
    faiss.normalize_L2(query_embedding)

    scores, ids = faiss_index.search(query_embedding, top_k)

    results = []
    for score, idx in zip(scores[0], ids[0]):
        # FAISS pads with id -1 when fewer than top_k vectors exist. A bare
        # `idx < len(chunks)` check would let -1 through and return chunks[-1]
        # (the last chunk) as a bogus hit, so check the lower bound too.
        if 0 <= idx < len(chunks):
            chunk = chunks[idx].copy()
            chunk["score"] = float(score)
            results.append(chunk)
    return results
def format_context(retrieved_chunks: list[dict]) -> str:
    """Render retrieved chunks as the context block inserted into the LLM prompt.

    Each chunk becomes a ``[Source N: <source>, Page <page>]`` header line
    followed by its text, numbered from 1. Missing ``source``, ``page`` and
    ``text`` fields fall back to ``"Unknown"``, ``"?"`` and ``""``. Sections are
    joined with a ``---`` separator. With no chunks, a fixed placeholder
    sentence is returned instead.
    """
    if not retrieved_chunks:
        return "No relevant context found."

    sections = (
        f"[Source {n}: {c.get('source', 'Unknown')}, Page {c.get('page', '?')}]\n"
        f"{c.get('text', '')}"
        for n, c in enumerate(retrieved_chunks, start=1)
    )
    return "\n\n---\n\n".join(sections)
| def format_citations(retrieved_chunks: list[dict]) -> str: | |
| """Format citations for display.""" | |
| if not retrieved_chunks: | |
| return "" | |
| seen_sources = set() | |
| citations = [] | |
| for chunk in retrieved_chunks: | |
| source = chunk.get("source", "Unknown") | |
| if source in seen_sources: | |
| continue | |
| seen_sources.add(source) | |
| # Look up paper info | |
| paper = None | |
| for key, info in PAPER_INFO.items(): | |
| if key.lower() in source.lower() or source.lower() in key.lower(): | |
| paper = info | |
| break | |
| if paper: | |
| citations.append(f"- [{paper['title']}]({paper['url']}) (p. {chunk.get('page', '?')})") | |
| else: | |
| citations.append(f"- {source} (p. {chunk.get('page', '?')})") | |
| return "\n".join(citations) | |
# =============================================================================
# Generation with Inference API
# =============================================================================
def build_prompt(query: str, context: str) -> str:
    """Build the Mistral ``[INST] ... [/INST]`` prompt for a RAG question.

    The system instructions and user prompt are concatenated inside a single
    instruction block. No literal ``<s>`` is prepended: Mistral's tokenizer
    adds the BOS token itself when the server encodes the text, so a literal
    ``<s>`` would produce a duplicated BOS.
    """
    system_prompt = """You are a helpful research assistant that answers questions about Anthony Maio's AI safety research papers.
IMPORTANT RULES:
1. ONLY answer based on the provided context from the papers
2. If the context doesn't contain relevant information, say "I don't have information about that in the indexed papers"
3. Be precise and cite which paper/concept you're referencing
4. Keep responses concise but informative
5. Use technical terms accurately as defined in the papers"""
    user_prompt = f"""Context from research papers:
{context}
Question: {query}
Provide a helpful answer based ONLY on the context above. If the context doesn't contain relevant information, say so."""
    return f"[INST] {system_prompt}\n\n{user_prompt} [/INST]"


def generate_response(query: str, context: str) -> str:
    """Generate an answer to ``query`` grounded in ``context`` via the HF Inference API.

    Args:
        query: The user's question.
        context: Retrieved paper excerpts to ground the answer in.

    Returns:
        The generated text with surrounding whitespace stripped.

    Raises:
        Any exception raised by ``client.text_generation`` (e.g. network or API
        errors) propagates to the caller.
    """
    prompt = build_prompt(query, context)
    response = client.text_generation(
        prompt,
        model=LLM_MODEL,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
    )
    return response.strip()
# =============================================================================
# Chat Function
# =============================================================================
def chat(message: str, history: list, request: gr.Request) -> str:
    """Answer one chat turn with retrieval-augmented generation.

    Args:
        message: The user's question.
        history: Prior turns passed by ``gr.ChatInterface``. It is not used, so
            each question is answered independently.
        request: Gradio request, used for per-client rate limiting.

    Returns:
        The generated answer followed by a ``**Sources:**`` section when
        citations exist, or a user-facing message explaining why no answer was
        produced.
    """
    # Run the cheap validations first so empty submissions and requests that
    # cannot be served do not consume the caller's rate-limit quota.
    if not (message or "").strip():
        return "Please enter a question."
    if faiss_index is None:
        return "The paper index is not loaded. Please check the Space configuration."

    allowed, error_msg = check_rate_limit(request)
    if not allowed:
        return error_msg

    retrieved = retrieve(message)
    if not retrieved:
        return "I couldn't find relevant information in the indexed papers."

    context = format_context(retrieved)

    # UI boundary: report any generation failure to the user instead of crashing.
    try:
        response = generate_response(message, context)
    except Exception as e:
        return f"Error generating response: {type(e).__name__}: {str(e)}"

    citations = format_citations(retrieved)
    if citations:
        response = f"{response}\n\n**Sources:**\n{citations}"
    return response
# =============================================================================
# Gradio Interface
# =============================================================================
# Markdown shown above the chat box.
DESCRIPTION = """
# 🔬 Ask My Research
Chat with Anthony Maio's AI safety research papers. Ask questions about:
- **CMED** - Cross-Model Epistemic Divergence (weak verifier failures)
- **HDCS** - Heterogeneous Divergence-Convergence Swarm (ensemble oversight)
- **Slipstream** - Semantic quantization for multi-agent coordination
- **Model Organisms** - Living-off-the-land failure modes in RAG agents
- **Coherence** - Architectures for agentic AI continuity
*Responses are grounded in the actual papers with citations.*
"""

# Clickable starter questions offered in the UI.
EXAMPLES = [
    "What is CMED and why does it matter for AI safety?",
    "How does Slipstream achieve 82% token reduction?",
    "What are living-off-the-land (LotL) failure modes?",
    "Explain the HDCS architecture for scalable oversight",
    "What is the Manifold Resonance Architecture (MRA)?",
    "How do weak verifiers fail to detect deceptive reasoning?",
]

demo = gr.ChatInterface(
    chat,
    type="messages",
    theme=gr.themes.Soft(secondary_hue="yellow", primary_hue="orange"),
    examples=EXAMPLES,
    description=DESCRIPTION,
    title="Ask My Research",
)

if __name__ == "__main__":
    demo.launch()