Spaces:
Sleeping
Sleeping
| """Core RAG system implementation""" | |
| import os | |
| from typing import List, Tuple, Optional | |
| import PyPDF2 | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| import spaces | |
class RAGSystem:
    """Retrieval-augmented generation over the text of a single PDF.

    Pipeline: extract text with PyPDF2 -> split into overlapping character
    chunks -> embed chunks with a SentenceTransformer -> index them in a FAISS
    inner-product index over L2-normalised vectors (i.e. cosine similarity)
    -> answer questions with a Hugging Face ``InferenceClient`` text-generation
    model, using the top retrieved chunks as prompt context.
    """

    # Models used when a caller has not explicitly chosen one.
    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    DEFAULT_LLM_MODEL = "HuggingFaceH4/zephyr-7b-beta"

    def __init__(self):
        self.chunks = []  # chunk texts; position i corresponds to index row i
        self.embeddings = None  # embedding matrix last passed to build_index
        self.index = None  # faiss.IndexFlatIP over normalised embeddings
        self.embedding_model = None  # SentenceTransformer, loaded lazily
        self.embedding_model_name = None
        self.llm_client = None  # InferenceClient, created lazily
        self.llm_model_name = None
        self.ready = False  # True once a document has been fully indexed

    def is_ready(self) -> bool:
        """Return True if a document is indexed and queries can be served."""
        return self.ready and self.index is not None

    def load_default_corpus(self, chunk_size: int = 500, chunk_overlap: int = 50) -> str:
        """Process ``default_corpus.pdf`` from the working directory, if present.

        Returns:
            The status message from :meth:`process_document`, or a message
            asking for an upload when the default file does not exist.
        """
        default_path = "default_corpus.pdf"
        if os.path.exists(default_path):
            return self.process_document(default_path, chunk_size, chunk_overlap)
        return "Default corpus not found. Please upload a PDF."

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Return the text of every page of *pdf_path*, each followed by a newline."""
        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # Pages are read lazily from the open file, so extraction must
            # happen inside the ``with``. A None/empty result for a page
            # (e.g. image-only pages) is treated as empty text instead of
            # raising TypeError on concatenation.
            page_texts = [(page.extract_text() or "") + "\n" for page in pdf_reader.pages]
        # join() avoids quadratic string concatenation on long documents.
        return "".join(page_texts)

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split *text* into overlapping character chunks.

        Each chunk is at most ``chunk_size`` characters. When a chunk is not
        the last one, it is shortened to end at the last '.' or newline,
        provided that boundary lies past the chunk's midpoint. Consecutive
        chunks share ``overlap`` characters when possible. Chunks of 50
        characters or fewer (after stripping) are discarded.

        Args:
            text: The text to split.
            chunk_size: Maximum chunk length in characters; must be positive.
            overlap: Characters shared between consecutive chunks. Negative
                values are treated as 0.

        Returns:
            The list of stripped chunks longer than 50 characters.

        Raises:
            ValueError: If ``chunk_size`` is not positive (the loop could never
                advance).
        """
        # UI sliders may deliver floats; slicing requires ints.
        chunk_size = int(chunk_size)
        overlap = max(int(overlap), 0)
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")

        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]
            # Try to break at a sentence/line boundary, but only if that keeps
            # more than half of the chunk.
            if end < text_length:
                break_point = max(chunk.rfind('.'), chunk.rfind('\n'))
                if break_point > chunk_size * 0.5:
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1
            chunks.append(chunk.strip())
            if end >= text_length:
                # This chunk already reaches the end of the text; another pass
                # would only emit a suffix of it.
                break
            next_start = end - overlap
            # If the overlap would stall or rewind (overlap >= distance just
            # advanced), continue without overlap rather than loop forever.
            start = next_start if next_start > start else end
        return [c for c in chunks if len(c) > 50]  # Filter out very small chunks

    def _load_embedding_model(self, model_name: str) -> None:
        """Instantiate *model_name* and make it the active embedding model."""
        self.embedding_model = SentenceTransformer(model_name)
        # Recorded only after a successful load so a failed load can be retried.
        self.embedding_model_name = model_name

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Encode *texts* with the active embedding model, loading the default if needed."""
        if self.embedding_model is None:
            # Load directly instead of via set_embedding_model(): that method
            # re-embeds and re-indexes self.chunks as a side effect, which made
            # the first processed document get embedded twice.
            self._load_embedding_model(self.DEFAULT_EMBEDDING_MODEL)
        return self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )

    def build_index(self, embeddings: np.ndarray):
        """Build a cosine-similarity FAISS index from a 2-D embedding matrix.

        The rows are L2-normalised in place (when *embeddings* is already
        float32 and C-contiguous, this mutates the caller's array, as before).
        """
        # FAISS requires C-contiguous float32. When the input already is, this
        # returns the same object.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)  # Inner product on unit vectors == cosine
        faiss.normalize_L2(embeddings)
        index.add(embeddings)
        self.index = index

    def process_document(self, pdf_path: str, chunk_size: int = 500, chunk_overlap: int = 50) -> str:
        """Extract, chunk, embed and index a PDF so it can be queried.

        New chunks are validated before they replace the current ones. The
        "no text" and "no chunks" error returns therefore leave a previously
        indexed document intact, with ``chunks`` still matching ``index``.

        Returns:
            A human-readable status message. Failures are reported in the
            message rather than raised.
        """
        try:
            text = self.extract_text_from_pdf(pdf_path)
            if not text.strip():
                return "Error: No text could be extracted from the PDF."
            chunks = self.chunk_text(text, chunk_size, chunk_overlap)
            if not chunks:
                return "Error: No valid chunks created from the document."
            embeddings = self.create_embeddings(chunks)
            # Commit only after all fallible pre-index work has succeeded.
            self.chunks = chunks
            self.embeddings = embeddings
            self.build_index(embeddings)
            self.ready = True
            return f"Success! Processed {len(self.chunks)} chunks from the document."
        except Exception as e:
            # State may be partially updated; refuse queries until a document
            # is processed successfully.
            self.ready = False
            return f"Error processing document: {str(e)}"

    def set_embedding_model(self, model_name: str):
        """Switch the embedding model, re-embedding and re-indexing any loaded chunks."""
        if self.embedding_model_name == model_name and self.embedding_model is not None:
            return
        self._load_embedding_model(model_name)
        # Existing vectors come from a different model (possibly a different
        # dimension), so the index must be rebuilt.
        if self.chunks:
            self.embeddings = self.create_embeddings(self.chunks)
            self.build_index(self.embeddings)

    def set_llm_model(self, model_name: str):
        """Switch the text-generation model used by :meth:`generate`."""
        if self.llm_model_name == model_name and self.llm_client is not None:
            return
        self.llm_client = InferenceClient(model_name)
        # Recorded only after the client was created so a failure can be retried.
        self.llm_model_name = model_name

    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        similarity_threshold: float = 0.0
    ) -> List[Tuple[str, float]]:
        """Return up to *top_k* ``(chunk, cosine_score)`` pairs for *query*.

        Results come in FAISS search order. Pairs scoring below
        *similarity_threshold* are dropped. Returns ``[]`` when no document
        is indexed.
        """
        if not self.is_ready():
            return []
        # Never ask for more neighbours than exist. FAISS pads the surplus
        # with label -1, which would otherwise index self.chunks[-1].
        k = min(int(top_k), self.index.ntotal)
        if k <= 0:
            return []
        query_embedding = np.ascontiguousarray(
            self.embedding_model.encode([query], convert_to_numpy=True),
            dtype=np.float32,
        )
        # Normalise so the inner product equals cosine similarity.
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, k)
        return [
            (self.chunks[int(idx)], float(score))
            for score, idx in zip(scores[0], indices[0])
            if idx >= 0 and score >= similarity_threshold
        ]

    def generate(
        self,
        query: str,
        retrieved_chunks: List[Tuple[str, float]],
        temperature: float = 0.7,
        max_tokens: int = 300
    ) -> Tuple[str, str]:
        """Answer *query* with the LLM, grounding it on *retrieved_chunks*.

        Returns:
            ``(answer, prompt)``. If the inference call raises, *answer* is an
            error message instead; the prompt is always returned for display
            and debugging.
        """
        if self.llm_client is None:
            self.set_llm_model(self.DEFAULT_LLM_MODEL)
        # Build context from the retrieved chunk texts (scores are not used).
        context = "\n\n".join(chunk for chunk, _ in retrieved_chunks)
        prompt = f"""You are a helpful assistant. Use the following context to answer the question.
If you cannot answer based on the context, say so.
Context:
{context}
Question: {query}
Answer:"""
        try:
            response = self.llm_client.text_generation(
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                return_full_text=False
            )
            return response, prompt
        except Exception as e:
            return f"Error generating response: {str(e)}", prompt