Spaces:
Sleeping
Sleeping
| """Core RAG system implementation""" | |
| import os | |
| import glob | |
| import re | |
| from typing import List, Tuple, Optional | |
| import PyPDF2 | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| import spaces | |
class RAGSystem:
    """Retrieval-augmented generation over text extracted from PDFs.

    The pipeline is: extract text -> split into overlapping chunks ->
    embed chunks -> index them in FAISS -> retrieve chunks for a query ->
    ask a hosted LLM to answer from the retrieved context.

    Attributes:
        chunks: Text chunks currently indexed (parallel to the index rows).
        embeddings: Embedding matrix produced for ``chunks``.
        index: FAISS inner-product index over the normalized embeddings.
        embedding_model / embedding_model_name: Loaded encoder and its name.
        llm_client / llm_model_name: Inference client and its model name.
        ready: True once a corpus has been indexed successfully.
    """

    DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    DEFAULT_LLM_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

    def __init__(self):
        self.chunks = []
        self.embeddings = None
        self.index = None
        self.embedding_model = None
        self.embedding_model_name = None
        self.llm_client = None
        self.llm_model_name = None
        self.ready = False

    def is_ready(self) -> bool:
        """Return True when a corpus has been indexed and can be queried."""
        return self.ready and self.index is not None

    # ------------------------------------------------------------------
    # Corpus loading
    # ------------------------------------------------------------------

    def load_default_corpus(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """Index every ``*.pdf`` found in the local ``documents`` folder.

        Args:
            chunk_size: Maximum number of characters per chunk.
            chunk_overlap: Number of characters shared by adjacent chunks.

        Returns:
            A ``(status, chunks_display, corpus_summary)`` tuple of strings.
            On failure the last two elements are empty and ``status``
            carries the error message.
        """
        documents_dir = "documents"
        if not os.path.exists(documents_dir):
            return "Documents folder not found. Please upload a PDF.", "", ""

        # Sorted so the chunk order does not depend on filesystem order.
        pdf_files = sorted(glob.glob(os.path.join(documents_dir, "*.pdf")))
        if not pdf_files:
            return "No PDF files found in documents folder. Please upload a PDF.", "", ""

        try:
            sections = []
            summary_parts = [f"📚 **Loading {len(pdf_files)} documents:**\n\n"]
            for pdf_path in pdf_files:
                filename = os.path.basename(pdf_path)
                summary_parts.append(f"- {filename}\n")
                text = self.extract_text_from_pdf(pdf_path)
                sections.append(f"\n\n=== {filename} ===\n\n{text}")

            all_text = "".join(sections)
            summary_parts.append(
                f"\n**Total text length:** {len(all_text)} characters\n"
            )
            corpus_summary = "".join(summary_parts)

            if not self._index_text(all_text, chunk_size, chunk_overlap):
                return "Error: No valid chunks created from the documents.", "", ""

            # Only a 200-character preview per chunk: a multi-document
            # corpus can produce a very long listing otherwise.
            chunks_display = self._format_chunks(self.chunks, preview_chars=200)
            status = (
                f"✅ Success! Processed {len(pdf_files)} documents "
                f"into {len(self.chunks)} chunks."
            )
            return status, chunks_display, corpus_summary
        except Exception as e:
            self.ready = False
            return f"Error loading default corpus: {str(e)}", "", ""

    def process_document(self, pdf_path: str, chunk_size: int = 500, chunk_overlap: int = 50):
        """Index a single PDF document.

        Args:
            pdf_path: Path of the PDF file to read.
            chunk_size: Maximum number of characters per chunk.
            chunk_overlap: Number of characters shared by adjacent chunks.

        Returns:
            A ``(status, chunks_display, text_preview)`` tuple of strings,
            where ``text_preview`` is the first 5000 characters of the
            extracted text. On failure the last two elements are empty.
        """
        try:
            text = self.extract_text_from_pdf(pdf_path)
            if not text.strip():
                return "Error: No text could be extracted from the PDF.", "", ""

            if not self._index_text(text, chunk_size, chunk_overlap):
                return "Error: No valid chunks created from the document.", "", ""

            chunks_display = self._format_chunks(self.chunks)
            status = f"✅ Success! Processed {len(self.chunks)} chunks from the document."
            return status, chunks_display, text[:5000]
        except Exception as e:
            self.ready = False
            return f"Error processing document: {str(e)}", "", ""

    def _index_text(self, text: str, chunk_size: int, chunk_overlap: int) -> bool:
        """Chunk, embed and index ``text``; commit state only on success.

        Everything is computed into locals first, so a failure (or a text
        that yields no chunks) never leaves ``self.chunks`` out of sync
        with ``self.index``.

        Returns:
            True if a new corpus was indexed, False if ``text`` produced
            no usable chunks (previous state is left untouched).
        """
        chunks = self.chunk_text(text, chunk_size, chunk_overlap)
        if not chunks:
            return False

        embeddings = self.create_embeddings(chunks)
        self.build_index(embeddings)

        self.chunks = chunks
        self.embeddings = embeddings
        self.ready = True
        return True

    @staticmethod
    def _format_chunks(chunks: List[str], preview_chars: Optional[int] = None) -> str:
        """Render chunks as Markdown, optionally truncating each to a preview."""
        parts = ["### Processed Chunks\n\n"]
        for i, chunk in enumerate(chunks, 1):
            if preview_chars is not None and len(chunk) > preview_chars:
                shown = chunk[:preview_chars] + "..."
            else:
                shown = chunk
            parts.append(f"**Chunk {i}** ({len(chunk)} chars)\n```\n{shown}\n```\n\n")
        return "".join(parts)

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Return the text of every page of a PDF, one page per line group.

        Args:
            pdf_path: Path of the PDF file to read.

        Returns:
            The concatenated page texts, each followed by a newline.
        """
        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # extract_text() may yield None for pages without a text layer
            # on some PyPDF2 versions; treat that as an empty page.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries.

        A chunk is cut at the last ``.`` or newline inside the window when
        that boundary lies past the window's midpoint; otherwise it is cut
        at ``chunk_size`` characters. Chunks of 50 characters or fewer
        (after stripping whitespace) are dropped.

        Args:
            text: The text to split.
            chunk_size: Maximum number of characters per chunk (> 0).
            overlap: Characters shared between consecutive chunks; negative
                values are treated as 0. Must be smaller than ``chunk_size``.

        Returns:
            The list of retained chunks, in document order.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                not smaller than ``chunk_size`` (either would prevent the
                window from advancing).
        """
        chunk_size = int(chunk_size)
        overlap = max(int(overlap), 0)
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]

            if end < text_length:
                break_point = max(chunk.rfind('.'), chunk.rfind('\n'))
                if break_point > chunk_size * 0.5:  # only break past halfway
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1

            chunks.append(chunk.strip())

            if end >= text_length:
                # The window reached the end of the text; stepping back by
                # `overlap` would only re-emit a tail already contained in
                # this chunk.
                break

            # A sentence break can shorten the window to less than
            # `overlap`; always move forward so the loop terminates.
            start = max(end - overlap, start + 1)

        return [c for c in chunks if len(c) > 50]  # drop very small chunks

    # ------------------------------------------------------------------
    # Embeddings and index
    # ------------------------------------------------------------------

    def _load_embedding_model(self, model_name: str) -> None:
        """Load an embedding model and record its name once loading succeeds."""
        model = SentenceTransformer(model_name)
        self.embedding_model = model
        self.embedding_model_name = model_name

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Encode texts with the current embedding model.

        Loads the default embedding model on first use.

        Args:
            texts: The strings to embed.

        Returns:
            The array returned by the model's ``encode`` call.
        """
        if self.embedding_model is None:
            # Load directly instead of going through set_embedding_model():
            # that method also re-embeds and re-indexes self.chunks, which
            # would duplicate the work the caller is about to do.
            self._load_embedding_model(self.DEFAULT_EMBEDDING_MODEL)

        return self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True,
        )

    def build_index(self, embeddings: np.ndarray):
        """Build the FAISS inner-product index from an embedding matrix.

        Rows are L2-normalized so that inner product equals cosine
        similarity. When ``embeddings`` is already a C-contiguous float32
        array no copy is made and it is normalized in place.

        Args:
            embeddings: 2-D array with one embedding per row.
        """
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        index = faiss.IndexFlatIP(vectors.shape[1])
        faiss.normalize_L2(vectors)
        index.add(vectors)
        # Publish only a fully populated index.
        self.index = index

    def set_embedding_model(self, model_name: str):
        """Switch the embedding model and re-index any loaded chunks.

        Args:
            model_name: Name of the sentence-transformers model to load.
        """
        if self.embedding_model_name == model_name and self.embedding_model is not None:
            return

        self._load_embedding_model(model_name)

        # Existing vectors came from the previous model; rebuild them.
        if self.chunks:
            self.embeddings = self.create_embeddings(self.chunks)
            self.build_index(self.embeddings)

    def set_llm_model(self, model_name: str):
        """Switch the LLM used for generation.

        The ``HF_TOKEN`` environment variable, when set, is passed to the
        inference client as its token.

        Args:
            model_name: Name of the model served through the inference client.
        """
        if self.llm_model_name == model_name and self.llm_client is not None:
            return

        hf_token = os.environ.get("HF_TOKEN", None)
        client = InferenceClient(model_name, token=hf_token)
        # Record the name only after the client exists, so a failed
        # construction can be retried with the same name.
        self.llm_client = client
        self.llm_model_name = model_name

    # ------------------------------------------------------------------
    # Retrieval and generation
    # ------------------------------------------------------------------

    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        similarity_threshold: float = 0.0
    ) -> List[Tuple[str, float]]:
        """Return the chunks most similar to ``query``.

        Args:
            query: The user question.
            top_k: Maximum number of chunks to return.
            similarity_threshold: Minimum cosine similarity for a chunk to
                be included.

        Returns:
            ``(chunk, score)`` pairs in the order returned by the index
            search; empty when no corpus is indexed.
        """
        if not self.is_ready():
            return []

        # Coerce (UI widgets may hand over floats) and never ask for more
        # neighbours than there are chunks.
        k = min(int(top_k), len(self.chunks))
        if k <= 0:
            return []

        query_embedding = self.embedding_model.encode(
            [query],
            convert_to_numpy=True,
        )
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        faiss.normalize_L2(query_embedding)

        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS reports missing neighbours with label -1; indexing
            # self.chunks with it would silently return the last chunk.
            if idx < 0 or idx >= len(self.chunks):
                continue
            if score >= similarity_threshold:
                results.append((self.chunks[idx], float(score)))
        return results

    @staticmethod
    def _extract_response_text(response) -> str:
        """Pull the message text out of a chat-completion response.

        Accepts an object exposing ``choices[0].message.content``, a dict
        of the same shape, or anything else (converted with ``str``).

        Raises:
            ValueError: If the response carries no message content.
        """
        if hasattr(response, 'choices') and len(response.choices) > 0:
            content = response.choices[0].message.content
        elif isinstance(response, dict) and 'choices' in response:
            content = response['choices'][0]['message']['content']
        else:
            content = str(response)

        if content is None:
            raise ValueError("Model returned no message content")
        return content.strip()

    def generate(
        self,
        query: str,
        retrieved_chunks: List[Tuple[str, float]],
        temperature: float = 0.7,
        max_tokens: int = 300
    ) -> Tuple[str, str]:
        """Answer ``query`` with the LLM, grounded in the retrieved chunks.

        Loads the default LLM on first use.

        Args:
            query: The user question.
            retrieved_chunks: ``(chunk, score)`` pairs used as context.
            temperature: Sampling temperature passed to the LLM.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            ``(answer, prompt)``. If the LLM call fails, ``answer`` is an
            error message and ``prompt`` is still returned.
        """
        if self.llm_client is None:
            self.set_llm_model(self.DEFAULT_LLM_MODEL)

        context = "\n\n".join(chunk for chunk, _ in retrieved_chunks)
        prompt = (
            "Use the following context to answer the question. "
            "If you cannot answer based on the context, say so.\n"
            "Context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer:"
        )

        try:
            response = self.llm_client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
            )
            answer = self._extract_response_text(response)
            # Reasoning models may interleave their thinking with the answer.
            answer = self._process_reasoning_output(answer)
            return answer, prompt
        except Exception as e:
            return f"Error generating response: {str(e)}", prompt

    @staticmethod
    def _format_reasoning(answer: str, reasoning: str) -> str:
        """Render an answer followed by a collapsible reasoning section.

        Blank lines matter in Markdown here: without one before ``---`` the
        answer's last line becomes a heading, and without one after the
        ``<summary>`` line the code fence stays inside the raw HTML block.
        """
        return (
            "**Answer:**\n\n"
            f"{answer}\n\n"
            "---\n\n"
            "<details>\n"
            "<summary>🧠 Model Reasoning (click to expand)</summary>\n\n"
            "```\n"
            f"{reasoning}\n"
            "```\n\n"
            "</details>"
        )

    def _process_reasoning_output(self, text: str) -> str:
        """Separate a reasoning model's thinking from its final answer.

        Two layouts are recognized:

        * ``<think>...</think>`` tags: the first tagged block is shown as
          the reasoning and all tagged blocks are removed from the answer.
        * An ``Answer:`` / ``Final Answer:`` / ``Response:`` marker (any
          letter case) preceded by more than 50 characters of text: the
          text before the first marker is the reasoning, the text after it
          the answer.

        Args:
            text: Raw model output.

        Returns:
            Markdown with the answer and a collapsible reasoning section,
            or ``text`` unchanged when neither layout applies.
        """
        if '<think>' in text and '</think>' in text:
            reasoning_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
            if reasoning_match:
                reasoning = reasoning_match.group(1).strip()
                answer = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
                return self._format_reasoning(answer, reasoning)

        marker = r'(Answer:|Final Answer:|Response:)'
        if re.search(marker, text, re.IGNORECASE):
            # flags must be passed by keyword: the third positional
            # parameter of re.split is maxsplit, not flags.
            parts = re.split(marker, text, maxsplit=1, flags=re.IGNORECASE)
            if len(parts) >= 3:
                reasoning = parts[0].strip()
                answer = ''.join(parts[2:]).strip()
                if reasoning and len(reasoning) > 50:  # substantial reasoning only
                    return self._format_reasoning(answer, reasoning)

        return text

    # ------------------------------------------------------------------
    # Example questions
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_questions(questions_text: str) -> List[str]:
        """Turn raw LLM output into a clean list of question strings.

        Removes ``<think>`` blocks, splits on lines, strips leading list
        markers (``1.``, ``2)``, ``-``, ``*``, ``•``) and drops entries of
        10 characters or fewer.
        """
        cleaned = re.sub(r'<think>.*?</think>', '', questions_text, flags=re.DOTALL)
        lines = [line.strip() for line in cleaned.split('\n') if line.strip()]
        # Bullets must be followed by whitespace so "**bold**" or "-5 ..."
        # openings are not mangled.
        lines = [re.sub(r'^(?:\d+[\.\)]\s*|[-*•]\s+)', '', line) for line in lines]
        return [line for line in lines if len(line) > 10]

    def generate_example_questions(self, num_questions: int = 5) -> List[str]:
        """Suggest questions that the indexed corpus could answer.

        Up to three randomly chosen chunks (at most 2000 characters in
        total) are shown to the LLM, which is asked to write the questions.

        Args:
            num_questions: Number of questions to request and return.

        Returns:
            Up to ``num_questions`` generated questions; a fixed list of
            generic questions when no corpus is indexed or generation fails.
        """
        if not self.is_ready() or not self.chunks:
            return [
                "What is the main topic of this document?",
                "Can you summarize the key points?",
                "What are the main concepts discussed?",
            ]

        import random

        sample_chunks = random.sample(self.chunks, min(3, len(self.chunks)))
        sample_text = "\n".join(sample_chunks)

        try:
            if self.llm_client is None:
                self.set_llm_model(self.DEFAULT_LLM_MODEL)

            prompt = (
                f"Based on the following text excerpts, generate {num_questions} "
                "diverse and relevant questions that could be answered using this "
                "corpus. Make the questions specific and interesting.\n"
                "Text excerpts:\n"
                f"{sample_text[:2000]}\n"
                f"Generate exactly {num_questions} questions, one per line, "
                "without numbering:"
            )
            response = self.llm_client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=300,
                temperature=0.8,
            )
            questions = self._parse_questions(self._extract_response_text(response))
            return questions[:num_questions] if questions else self._default_questions()
        except Exception as e:
            print(f"Error generating questions: {e}")
            return self._default_questions()

    def _default_questions(self) -> List[str]:
        """Return generic questions used when generation fails."""
        return [
            "What is the main topic discussed in this corpus?",
            "Can you summarize the key concepts?",
            "What are the main findings or arguments presented?",
        ]