"""
DocMind - Retriever Module

Semantic search over arXiv papers using FAISS and sentence-transformers
"""

import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


class PaperRetriever:
    """Embed papers with a sentence-transformer and search them via FAISS.

    The retriever keeps three pieces of state in sync: the FAISS index, the
    list of paper dictionaries (row ``i`` of the index corresponds to
    ``papers[i]``), and the normalized embedding matrix that was added to
    the index.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        index_path: str = "data/faiss_index"
    ) -> None:
        """
        Initialize retriever with embedding model and FAISS index

        Loads the embedding model and creates the index directory if it does
        not exist yet. No index is built or loaded here.

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        self.index = None
        self.papers: List[Dict] = []
        self.embeddings: Optional[np.ndarray] = None

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract'

        Raises:
            ValueError: If ``papers`` is empty (there is nothing to embed and
                the embedding dimension cannot be determined).
        """
        if not papers:
            raise ValueError("Cannot build an index from an empty paper list")

        print(f"Building index for {len(papers)} papers...")

        # Create text to embed: title + abstract
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        # Generate embeddings
        print("Generating embeddings...")
        raw_embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        # faiss.normalize_L2 works in place and needs a C-contiguous float32
        # matrix; coerce defensively in case the model returns another layout.
        embeddings = np.ascontiguousarray(raw_embeddings, dtype=np.float32)

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings)

        # Build FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)  # Inner product (cosine similarity)
        index.add(embeddings)

        # Commit state only after everything succeeded so a failed build does
        # not leave papers and index out of sync.
        self.papers = papers
        self.embeddings = embeddings
        self.index = index

        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """
        Save FAISS index and metadata

        Writes ``{name}.index``, ``{name}_papers.pkl`` and
        ``{name}_embeddings.npy`` into the index directory.

        Args:
            name: Base file name used for the three output files

        Raises:
            ValueError: If no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))

        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)

        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)

        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers"):
        """
        Load FAISS index and metadata

        Args:
            name: Base file name that was used with ``save_index``

        Returns:
            True when the index was loaded, False when ``{name}.index`` does
            not exist in the index directory.
        """
        index_file = self.index_path / f"{name}.index"
        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False

        self.index = faiss.read_index(str(index_file))

        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file. Only load index directories produced by a trusted save_index
        # call; consider a JSON sidecar if the data may come from elsewhere.
        with open(self.index_path / f"{name}_papers.pkl", 'rb') as f:
            self.papers = pickle.load(f)

        with open(self.index_path / f"{name}_embeddings.npy", 'rb') as f:
            self.embeddings = np.load(f)

        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers

        Args:
            query: Search query string
            top_k: Number of results to return

        Returns:
            List of (paper_dict, score) tuples, best match first. Fewer than
            ``top_k`` entries are returned when the index holds fewer papers;
            an empty list is returned when ``top_k`` is not positive.

        Raises:
            ValueError: If no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        # Never ask FAISS for more neighbours than it stores: missing slots
        # come back as label -1, which would index papers[-1] by accident.
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        # Embed query
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, k)

        # Return results, skipping padding labels and any label that has no
        # matching metadata entry.
        paper_count = len(self.papers)
        return [
            (self.papers[int(idx)], float(score))
            for idx, score in zip(indices[0], scores[0])
            if 0 <= idx < paper_count
        ]

    def get_retrieval_context(
        self,
        query: str,
        top_k: int = 5
    ) -> str:
        """
        Get formatted context string for LLM consumption

        Each paper dictionary is read for the keys 'title', 'authors',
        'arxiv_id', 'published' and 'abstract'. At most three authors are
        listed, followed by "et al." when there are more.

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)

        # Collect the pieces and join once instead of repeated string +=.
        parts = [f"Retrieved {len(results)} relevant papers:\n\n"]

        for i, (paper, score) in enumerate(results, 1):
            authors = paper['authors']
            author_line = ', '.join(authors[:3])
            if len(authors) > 3:
                author_line += " et al."

            parts.append(f"[{i}] {paper['title']}\n")
            parts.append(f" Authors: {author_line}")
            parts.append(f"\n arXiv ID: {paper['arxiv_id']}\n")
            parts.append(f" Published: {paper['published']}\n")
            parts.append(f" Relevance: {score:.3f}\n")
            parts.append(f" Abstract: {paper['abstract']}\n\n")

        return "".join(parts)


def main():
    """Example: Build and test retriever"""
    from fetch_arxiv_data import ArxivFetcher

    # Load papers
    fetcher = ArxivFetcher()
    papers = fetcher.load_papers("arxiv_papers.json")

    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first")
        return

    # Build index
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    # Test search
    test_queries = [
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment"
    ]

    for query in test_queries:
        print(f"\n{'=' * 60}")
        print(f"Query: {query}")
        print('=' * 60)

        results = retriever.search(query, top_k=3)
        for i, (paper, score) in enumerate(results, 1):
            print(f"\n[{i}] Score: {score:.3f}")
            print(f" {paper['title']}")
            print(f" {paper['arxiv_id']}")


if __name__ == "__main__":
    main()