| | """ |
| | DocMind - Retriever Module |
| | Semantic search over arXiv papers using FAISS and sentence-transformers |
| | """ |
| |
|
| | import numpy as np |
| | import faiss |
| | from sentence_transformers import SentenceTransformer |
| | from typing import List, Dict, Tuple |
| | import pickle |
| | from pathlib import Path |
| |
|
| |
|
class PaperRetriever:
    """Semantic paper search backed by a FAISS inner-product index.

    Each paper is embedded as "<title>. <abstract>" with a
    sentence-transformers model. Embeddings are L2-normalised before being
    added to an ``IndexFlatIP``, so the scores returned by :meth:`search`
    are cosine similarities.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        index_path: str = "data/faiss_index"
    ):
        """
        Initialize retriever with embedding model and FAISS index directory.

        Note that construction has side effects: the embedding model is
        loaded and the index directory is created if it does not exist.

        Args:
            model_name: HuggingFace sentence-transformer model
            index_path: Directory to save/load FAISS index
        """
        print(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.index_path = Path(index_path)
        self.index_path.mkdir(parents=True, exist_ok=True)

        # Populated by build_index() or load_index().
        self.index = None
        self.papers = []
        self.embeddings = None

    def build_index(self, papers: List[Dict]):
        """
        Build FAISS index from papers.

        Instance state (index, papers, embeddings) is replaced only after the
        whole build succeeded, so a failure leaves the previous state intact.

        Args:
            papers: List of paper dictionaries with 'title' and 'abstract'

        Raises:
            ValueError: If ``papers`` is empty.
            KeyError: If a paper lacks 'title' or 'abstract'.
        """
        if not papers:
            raise ValueError("Cannot build an index from an empty paper list")

        print(f"Building index for {len(papers)} papers...")

        # One text per paper: title and abstract joined into a single passage.
        texts = [
            f"{paper['title']}. {paper['abstract']}"
            for paper in papers
        ]

        print("Generating embeddings...")
        embeddings = self.model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        # FAISS operates on C-contiguous float32 matrices; this is a no-op
        # when the encoder already returns that layout.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

        # Normalise in place so inner product equals cosine similarity.
        faiss.normalize_L2(embeddings)
        index = faiss.IndexFlatIP(embeddings.shape[1])
        index.add(embeddings)

        # Commit everything together so papers and index never disagree.
        self.papers = papers
        self.embeddings = embeddings
        self.index = index

        print(f"Index built with {self.index.ntotal} papers")

    def save_index(self, name: str = "papers"):
        """
        Save FAISS index and metadata.

        Writes ``<name>.index``, ``<name>_papers.pkl`` and
        ``<name>_embeddings.npy`` into the index directory.

        Args:
            name: Base file name shared by the three artefacts

        Raises:
            ValueError: If no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        faiss.write_index(self.index, str(self.index_path / f"{name}.index"))

        with open(self.index_path / f"{name}_papers.pkl", 'wb') as f:
            pickle.dump(self.papers, f)

        # Writing through an open handle keeps the exact file name
        # (np.save would otherwise append ".npy" to a path lacking it).
        with open(self.index_path / f"{name}_embeddings.npy", 'wb') as f:
            np.save(f, self.embeddings)

        print(f"Saved index to {self.index_path}/{name}.*")

    def load_index(self, name: str = "papers"):
        """
        Load FAISS index and metadata.

        All three artefacts written by :meth:`save_index` must be present.
        Instance state is replaced only after every file was read
        successfully.

        Args:
            name: Base file name used when the index was saved

        Returns:
            True if the index was loaded, False if any artefact is missing.

        Raises:
            ValueError: If the stored index and paper list disagree in size.
        """
        index_file = self.index_path / f"{name}.index"
        papers_file = self.index_path / f"{name}_papers.pkl"
        embeddings_file = self.index_path / f"{name}_embeddings.npy"

        if not index_file.exists():
            print(f"No index found at {index_file}")
            return False

        for sidecar in (papers_file, embeddings_file):
            if not sidecar.exists():
                print(f"Incomplete index, missing {sidecar}")
                return False

        index = faiss.read_index(str(index_file))

        # SECURITY: pickle.load can execute arbitrary code. Only load index
        # directories that this application wrote itself.
        with open(papers_file, 'rb') as f:
            papers = pickle.load(f)

        with open(embeddings_file, 'rb') as f:
            embeddings = np.load(f)

        # search() maps FAISS row ids straight onto list positions, so a
        # size mismatch would yield wrong papers or IndexError later on.
        if index.ntotal != len(papers):
            raise ValueError(
                f"Index holds {index.ntotal} vectors but "
                f"{len(papers)} papers were stored"
            )

        self.index = index
        self.papers = papers
        self.embeddings = embeddings

        print(f"Loaded index with {len(self.papers)} papers")
        return True

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Tuple[Dict, float]]:
        """
        Search for relevant papers.

        Args:
            query: Search query string
            top_k: Maximum number of results to return

        Returns:
            List of (paper_dict, score) tuples, best match first. Contains
            fewer than ``top_k`` entries when the index is smaller, and is
            empty when ``top_k`` is not positive.

        Raises:
            ValueError: If no index has been built or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not built or loaded")

        # Never ask FAISS for more neighbours than exist (or for k <= 0).
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []

        query_embedding = self.model.encode([query], convert_to_numpy=True)
        # Same normalisation as at build time, so scores are cosine values.
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, k)

        results = []
        for idx, score in zip(indices[0], scores[0]):
            # FAISS pads missing hits with label -1; indexing the list with
            # it would silently return the *last* paper.
            if idx < 0:
                continue
            results.append((self.papers[int(idx)], float(score)))

        return results

    def get_retrieval_context(
        self,
        query: str,
        top_k: int = 5
    ) -> str:
        """
        Get formatted context string for LLM consumption.

        Each paper dict must provide 'title', 'authors', 'arxiv_id',
        'published' and 'abstract'. At most three authors are listed,
        followed by "et al." when there are more.

        Args:
            query: Search query
            top_k: Number of papers to retrieve

        Returns:
            Formatted context string with paper summaries
        """
        results = self.search(query, top_k)

        # Collect the pieces and join once instead of repeated "+=".
        sections = [f"Retrieved {len(results)} relevant papers:\n\n"]
        for i, (paper, score) in enumerate(results, 1):
            authors = paper['authors']
            author_line = ', '.join(authors[:3])
            if len(authors) > 3:
                author_line += " et al."

            sections.append(
                f"[{i}] {paper['title']}\n"
                f" Authors: {author_line}\n"
                f" arXiv ID: {paper['arxiv_id']}\n"
                f" Published: {paper['published']}\n"
                f" Relevance: {score:.3f}\n"
                f" Abstract: {paper['abstract']}\n\n"
            )

        return "".join(sections)
| |
|
| |
|
def main():
    """Demo driver: index the stored arXiv papers and run sample queries.

    Loads papers from "arxiv_papers.json" through ``ArxivFetcher``, builds
    and saves the index, then prints the top three hits for a few queries.
    Returns early with a hint when no papers are available.
    """
    # Local import: within this file only the demo uses the fetcher module.
    from fetch_arxiv_data import ArxivFetcher

    papers = ArxivFetcher().load_papers("arxiv_papers.json")
    if not papers:
        print("No papers found. Run fetch_arxiv_data.py first")
        return

    # Build the index once and persist it.
    retriever = PaperRetriever()
    retriever.build_index(papers)
    retriever.save_index()

    divider = '=' * 60
    sample_queries = (
        "diffusion models for image generation",
        "reinforcement learning from human feedback",
        "large language model alignment",
    )

    for query in sample_queries:
        print(f"\n{divider}")
        print(f"Query: {query}")
        print(divider)

        hits = retriever.search(query, top_k=3)
        for rank, (paper, score) in enumerate(hits, start=1):
            print(f"\n[{rank}] Score: {score:.3f}")
            print(f" {paper['title']}")
            print(f" {paper['arxiv_id']}")
| |
|
| |
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()