"""Query module for RAG system with background model loading.

Importing this module starts loading the sentence-transformer embedding model
in a daemon thread, so the first query does not pay the full load cost up
front. ``ask()`` is the main entry point: it retrieves chunks from Qdrant,
asks an OpenAI model for a JSON answer with citations, validates those
citations, and writes a debug dump to ``validation_results.json``.
"""

import os
import json
import time
import threading
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from openai import OpenAI
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

from citations import parse_llm_response, process_citations, format_citations_display
from config import MODEL_NAME, COLLECTION_NAME

load_dotenv()

# ============================================================================
# Configuration
# ============================================================================

LLM_MODEL = "gpt-4o"
SOURCE_COUNT = 10
SCORE_THRESHOLD = 0.4
# Seconds retrieve_context() waits for the background model load before giving up.
MODEL_LOAD_TIMEOUT = 120

# ============================================================================
# Background Model Loading
# ============================================================================

EMBEDDING_MODEL = None
# Set once the background load has FINISHED -- whether it succeeded or failed --
# so waiters never sit out the whole timeout after a failed load.
_model_loaded = threading.Event()
# Exception raised by the background load, if any; re-raised to query callers.
_model_load_error: Optional[BaseException] = None


def _load_model_background():
    """Load the embedding model in a background thread.

    On success ``EMBEDDING_MODEL`` holds the model; on failure the exception
    is stored in ``_model_load_error``. ``_model_loaded`` is set in both cases.
    """
    global EMBEDDING_MODEL, _model_load_error
    print("🔄 Loading embedding model in background...")
    try:
        EMBEDDING_MODEL = SentenceTransformer(MODEL_NAME)
    except Exception as exc:  # surfaced to callers by _wait_for_embedding_model()
        _model_load_error = exc
        print(f"❌ Embedding model failed to load: {exc}")
    else:
        print("✅ Embedding model loaded!")
    finally:
        _model_loaded.set()


def is_model_ready():
    """Return True if the embedding model has finished loading successfully."""
    return _model_loaded.is_set() and EMBEDDING_MODEL is not None


# Start loading immediately when module is imported
_loading_thread = threading.Thread(target=_load_model_background, daemon=True)
_loading_thread.start()

# ============================================================================
# Context Retrieval
# ============================================================================


def _wait_for_embedding_model() -> SentenceTransformer:
    """Block until the background load finishes and return the loaded model.

    Raises:
        TimeoutError: Loading did not finish within MODEL_LOAD_TIMEOUT seconds.
        RuntimeError: The background load failed; the original error is chained.
    """
    if not _model_loaded.is_set():
        print("⏳ Waiting for embedding model to finish loading...")
        # Event.wait() returns False only when the timeout elapsed.
        if not _model_loaded.wait(timeout=MODEL_LOAD_TIMEOUT):
            raise TimeoutError("Model loading timeout - please try again")
    if _model_load_error is not None:
        raise RuntimeError("Embedding model failed to load") from _model_load_error
    return EMBEDDING_MODEL


def retrieve_context(question: str) -> List[Any]:
    """Retrieve relevant chunks from Qdrant.

    Embeds ``question`` with the background-loaded model and queries
    ``COLLECTION_NAME`` for up to ``SOURCE_COUNT`` points scoring at least
    ``SCORE_THRESHOLD``.

    Args:
        question: User's question.

    Returns:
        The ``points`` list from Qdrant's ``query_points`` response (may be empty).

    Raises:
        TimeoutError: The embedding model did not finish loading in time.
        RuntimeError: The embedding model failed to load.
    """
    start = time.perf_counter()

    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )

    # Wait for model to be loaded (if still loading)
    model = _wait_for_embedding_model()
    query_vector = model.encode(question).tolist()

    results = client.query_points(
        collection_name=COLLECTION_NAME,
        query=query_vector,
        limit=SOURCE_COUNT,
        score_threshold=SCORE_THRESHOLD,
    )

    elapsed = (time.perf_counter() - start) * 1000
    print(f"[TIMING] Retrieval: {elapsed:.0f}ms")

    return results.points


def format_context(results: List[Any]) -> str:
    """Format retrieved chunks into context string for LLM.

    Sources are labelled ``[Source 1]``, ``[Source 2]``, ... in result order.
    The LLM is told to cite by these labels, and save_validation_results()
    numbers sources the same way (1-based).
    """
    return "\n---\n".join(
        f"[Source {i}]\n"
        f"Title: {hit.payload['title']}\n"
        f"URL: {hit.payload['url']}\n"
        f"Content: {hit.payload['text']}\n"
        for i, hit in enumerate(results, 1)
    )


# ============================================================================
# LLM Answer Generation
# ============================================================================

_SYSTEM_PROMPT = """Answer the user's question using ONLY the provided sources from 80,000 Hours articles.

STEP 1: Write your answer
- Write a clear, concise answer to the question
- Use a natural, conversational tone
- After EACH substantive claim, add [1], [2], [3], etc. in order
- Example: "Career capital is important [1]. You can build it through work [2]."

STEP 2: Provide citations
- For each [N] in your answer, provide a citation with:
  * citation_id: The number from your answer (1 for [1], 2 for [2], etc.)
  * source_id: Which source it came from (match the [Source N] label exactly)
  * quote: Copy the EXACT sentences from that source, word-for-word

EXAMPLE - If you found text in [Source 3]:
- Your answer: "Career capital helps you succeed [1]."
- Your citation: {"citation_id": 1, "source_id": 3, "quote": "Career capital includes..."}

CRITICAL RULES:
1. Number citations in ORDER: [1] is first, [2] is second, [3] is third, etc.
2. Copy quotes EXACTLY - No changes, NO ellipses, No paraphrasing
3. source_id MUST match the source number: [Source 1] → source_id: 1, [Source 5] → source_id: 5
4. Each quote must be complete sentences from the source

OUTPUT FORMAT (valid JSON):
{
  "answer": "Your answer with [1], [2], [3] after each claim.",
  "citations": [
    {
      "citation_id": 1,
      "source_id": 2,
      "quote": "Exact sentence from the source."
    },
    {
      "citation_id": 2,
      "source_id": 5,
      "quote": "Another exact sentence from a different source."
    }
  ]
}"""


def generate_answer_with_citations(
    question: str,
    context: str,
    results: List[Any],
    llm_model: str,
    openai_api_key: str
) -> Dict[str, Any]:
    """Generate answer with structured citations using OpenAI.

    Args:
        question: User's question
        context: Formatted context from source chunks (see format_context)
        results: Source chunks from Qdrant, in the same order as in ``context``
        llm_model: OpenAI model name
        openai_api_key: OpenAI API key

    Returns:
        Dict with keys "answer", "citations" (validated), "validation_errors",
        "total_citations" and "valid_citations". If the LLM output could not
        be parsed, "citations" is empty and "answer" holds the parser's
        fallback text.

    Errors from the OpenAI client are not caught here and propagate to the caller.
    """
    client = OpenAI(api_key=openai_api_key)

    user_prompt = f"""Context from 80,000 Hours articles:

{context}

Question: {question}

Provide your answer in JSON format with exact quotes from the sources."""

    llm_start = time.perf_counter()
    response = client.chat.completions.create(
        model=llm_model,
        messages=[
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt}
        ],
        # JSON mode: the prompt must mention JSON, which both prompts do.
        response_format={"type": "json_object"}
    )
    llm_time = (time.perf_counter() - llm_start) * 1000
    print(f"[TIMING] LLM call: {llm_time:.0f}ms")

    # Parse LLM response
    raw_content = response.choices[0].message.content
    parsed = parse_llm_response(raw_content)

    if "validation_errors" in parsed:
        return {
            # Parser's fallback answer; fall back to the raw LLM text if absent.
            "answer": parsed.get("answer", raw_content),
            "citations": [],
            "validation_errors": parsed["validation_errors"],
            "total_citations": 0,
            "valid_citations": 0
        }

    answer = parsed.get("answer", "")
    citations = parsed.get("citations", [])

    # Validate citations against the retrieved source chunks
    validation_start = time.perf_counter()
    result = process_citations(citations, results)
    validation_time = (time.perf_counter() - validation_start) * 1000
    print(f"[TIMING] Validation: {validation_time:.0f}ms")

    return {
        "answer": answer,
        "citations": result["validated_citations"],
        "validation_errors": result["validation_errors"],
        "total_citations": len(citations),
        "valid_citations": len(result["validated_citations"])
    }


# ============================================================================
# Results Processing & Display
# ============================================================================


def save_validation_results(question: str, result: Dict[str, Any], results: List[Any], _unused_time: float):
    """Save detailed validation results to JSON file for debugging.

    Overwrites ``validation_results.json`` in the current working directory.

    Args:
        question: User's question.
        result: Dict returned by generate_answer_with_citations().
        results: Source chunks from Qdrant.
        _unused_time: Ignored; kept for backward compatibility with callers.
    """
    validation_output = {
        "question": question,
        "answer": result["answer"],
        "citations": result["citations"],
        "validation_errors": result["validation_errors"],
        "stats": {
            "total_citations": result["total_citations"],
            "valid_citations": result["valid_citations"]
        },
        "sources": [
            {
                # 1-based, matching the [Source N] labels from format_context().
                "source_id": i,
                "title": hit.payload['title'],
                "url": hit.payload['url'],
                "chunk_id": hit.payload.get('chunk_id'),
                # Qdrant's score; it is cosine similarity only if the
                # collection uses the cosine distance -- NOTE(review): confirm.
                "cosine_similarity": hit.score,
                "text": hit.payload['text']
            }
            for i, hit in enumerate(results, 1)
        ]
    }

    with open("validation_results.json", "w", encoding="utf-8") as f:
        json.dump(validation_output, f, ensure_ascii=False, indent=2)

    print("\n[INFO] Validation results saved to validation_results.json")


def display_results(question: str, result: Dict[str, Any], context: Optional[str] = None):
    """Display query results to console.

    Args:
        question: User's question.
        result: Dict returned by generate_answer_with_citations().
        context: Formatted source context to print first; skipped when falsy.
    """
    print(f"Question: {question}\n")

    if context:
        print("=" * 80)
        print("RETRIEVED CONTEXT:")
        print("=" * 80)
        print(context)
        print("\n")

    print("=" * 80)
    print("ANSWER:")
    print("=" * 80)
    print(result["answer"])
    print("\n")

    print("=" * 80)
    print("CITATIONS (Verified Quotes):")
    print("=" * 80)
    print(format_citations_display(result["citations"]))

    if result["validation_errors"]:
        print("\n" + "=" * 80)
        print("VALIDATION WARNINGS:")
        print("=" * 80)
        for error in result["validation_errors"]:
            print(f"⚠ [Citation {error['citation_id']}] {error['reason']}")

    print("\n" + "=" * 80)
    print(f"Citation Stats: {result['valid_citations']}/{result['total_citations']} citations validated")
    print("=" * 80)


# ============================================================================
# Main Public API
# ============================================================================


def ask(question: str, show_context: bool = False) -> Dict[str, Any]:
    """Main RAG function: retrieve context and generate answer with validated citations.

    Args:
        question: User's question.
        show_context: Currently unused (console display is disabled); kept
            for API compatibility.

    Returns:
        Dict with "question", "answer", "citations", "validation_errors" and
        "sources" (the raw Qdrant points). The same keys are returned when no
        source clears SCORE_THRESHOLD, with empty lists.
    """
    total_start = time.perf_counter()

    results = retrieve_context(question)

    if not results:
        print("No relevant sources found above the score threshold.")
        return {
            "question": question,
            "answer": "No relevant information found in the knowledge base.",
            "citations": [],
            # Same keys as the normal path so callers can index uniformly.
            "validation_errors": [],
            "sources": []
        }

    context = format_context(results)

    result = generate_answer_with_citations(
        question=question,
        context=context,
        results=results,
        llm_model=LLM_MODEL,
        openai_api_key=os.getenv("OPENAI_API_KEY")
    )

    total_time = (time.perf_counter() - total_start) * 1000
    print(f"[TIMING] Total: {total_time:.0f}ms")

    # NOTE: console display is disabled; to re-enable, call
    # display_results(question, result, context if show_context else None).

    # Save debug output (overwrites validation_results.json on every call)
    save_validation_results(question, result, results, 0)

    return {
        "question": question,
        "answer": result["answer"],
        "citations": result["citations"],
        "validation_errors": result["validation_errors"],
        "sources": results
    }