| """Core RAG system implementation""" | |
| import os | |
| import glob | |
| import re | |
| from typing import List, Tuple, Optional | |
| import PyPDF2 | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| import spaces | |


class RAGSystem:
    def __init__(self):
        self.chunks = []
        self.chunk_metadata = []  # Store chunk positions for overlap visualization
        self.embeddings = None
        self.index = None
        self.embedding_model = None
        self.embedding_model_name = None
        self.llm_client = None
        self.llm_model_name = None
        self.ready = False

    def is_ready(self) -> bool:
        """Check if the system is ready to process queries"""
        return self.ready and self.index is not None

    def load_default_corpus(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """Load the default corpus from the documents folder"""
        documents_dir = "documents"
        if not os.path.exists(documents_dir):
            return "Documents folder not found. Please upload a PDF.", "", ""

        # Get all PDFs in the documents folder
        pdf_files = glob.glob(os.path.join(documents_dir, "*.pdf"))
        if not pdf_files:
            return "No PDF files found in documents folder. Please upload a PDF.", "", ""

        try:
            # Extract text from all PDFs
            all_text = ""
            corpus_summary = f"📚 **Loading {len(pdf_files)} documents:**\n\n"
            for pdf_path in pdf_files:
                filename = os.path.basename(pdf_path)
                corpus_summary += f"- {filename}\n"
                text = self.extract_text_from_pdf(pdf_path)
                all_text += f"\n\n=== {filename} ===\n\n{text}"
            corpus_summary += f"\n**Total text length:** {len(all_text)} characters\n"

            # Chunk the combined text
            self.chunks = self.chunk_text(all_text, chunk_size, chunk_overlap)
            if not self.chunks:
                return "Error: No valid chunks created from the documents.", "", ""

            # Create embeddings
            self.embeddings = self.create_embeddings(self.chunks)

            # Build the index
            self.build_index(self.embeddings)
            self.ready = True

            # Format chunks for display with overlap highlighting
            chunks_display = self._format_chunks_with_overlap()
            status = f"✅ Success! Processed {len(pdf_files)} documents into {len(self.chunks)} chunks."
            return status, chunks_display, corpus_summary
        except Exception as e:
            self.ready = False
            return f"Error loading default corpus: {str(e)}", "", ""

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file"""
        text = ""
        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page in pdf_reader.pages:
                # extract_text() can return None for image-only pages
                text += (page.extract_text() or "") + "\n"
        return text

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks and store metadata"""
        chunks = []
        self.chunk_metadata = []  # Reset metadata
        start = 0
        text_length = len(text)
        previous_end = 0

        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]
            original_end = end

            # Try to break at a sentence boundary
            if end < text_length:
                # Look for sentence endings
                last_period = chunk.rfind('.')
                last_newline = chunk.rfind('\n')
                break_point = max(last_period, last_newline)
                if break_point > chunk_size * 0.5:  # Only break if we're past halfway
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1
                    original_end = end

            # Length of the region shared with the previous chunk
            overlap_length = min(overlap, previous_end - start) if start < previous_end else 0

            chunks.append(chunk.strip())
            self.chunk_metadata.append({
                'start': start,
                'end': original_end,
                'overlap_with_previous': overlap_length,
                'text': chunk
            })
            previous_end = original_end
            # Slide the window back by the overlap, but always make progress
            # (guards against an infinite loop when overlap >= chunk length)
            start = max(end - overlap, start + 1)

        # Filter out very small chunks and update metadata accordingly
        filtered_chunks = []
        filtered_metadata = []
        for chunk, metadata in zip(chunks, self.chunk_metadata):
            if len(chunk) > 50:
                filtered_chunks.append(chunk)
                filtered_metadata.append(metadata)
        self.chunk_metadata = filtered_metadata
        return filtered_chunks
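
    # Illustration (parameters assumed for the example, not taken from a real
    # run): with chunk_size=500, overlap=50, and no sentence boundary found,
    # successive windows cover [0, 500), [450, 950), [900, 1400), and so on.
    # Each chunk re-reads the final 50 characters of its predecessor, so text
    # cut at a window edge survives intact in at least one chunk.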

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Create embeddings for text chunks"""
        if self.embedding_model is None:
            self.set_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        return embeddings

    def build_index(self, embeddings: np.ndarray):
        """Build a FAISS index from embeddings"""
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
        # Normalize embeddings in place so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def process_document(self, pdf_path: str, chunk_size: int = 500, chunk_overlap: int = 50):
        """Process a PDF document and create a searchable index"""
        try:
            # Extract text
            text = self.extract_text_from_pdf(pdf_path)
            if not text.strip():
                return "Error: No text could be extracted from the PDF.", "", ""

            # Chunk the text
            self.chunks = self.chunk_text(text, chunk_size, chunk_overlap)
            if not self.chunks:
                return "Error: No valid chunks created from the document.", "", ""

            # Create embeddings
            self.embeddings = self.create_embeddings(self.chunks)

            # Build the index
            self.build_index(self.embeddings)
            self.ready = True

            # Format chunks for display with overlap highlighting
            chunks_display = self._format_chunks_with_overlap()
            status = f"✅ Success! Processed {len(self.chunks)} chunks from the document."
            return status, chunks_display, text[:5000]  # First 5000 chars of the original text
        except Exception as e:
            self.ready = False
            return f"Error processing document: {str(e)}", "", ""

    def _format_chunks_with_overlap(self) -> str:
        """Format chunks with overlap highlighting for pedagogical display"""
        if not self.chunks or not self.chunk_metadata:
            return "No chunks available"

        display = "### 📑 Processed Chunks\n\n"
        display += "*Overlapping parts are shown separately with a yellow marker (⚠️)*\n\n"
        display += "---\n\n"

        for i, chunk in enumerate(self.chunks, 1):
            if i == 1:
                # The first chunk has no overlap
                display += f"#### 📄 Chunk {i}\n"
                display += f"**{len(chunk)} characters** | 🆕 No overlap (first chunk)\n\n"
                display += f"```text\n{chunk}\n```\n\n"
            else:
                # Find the longest suffix of the previous chunk that is a
                # prefix of the current one
                prev_chunk = self.chunks[i - 2]
                overlap_length = 0
                for j in range(1, min(len(chunk), len(prev_chunk)) + 1):
                    if prev_chunk[-j:] == chunk[:j]:
                        overlap_length = j

                if overlap_length > 0:
                    overlap_text = chunk[:overlap_length]
                    remaining_text = chunk[overlap_length:]
                    display += f"#### 📄 Chunk {i}\n"
                    display += f"**{len(chunk)} characters** | ⚠️ **{overlap_length} characters overlap** with previous chunk\n\n"

                    # Show the overlap as a quoted block
                    display += f"> **⚠️ OVERLAP ({overlap_length} chars) - Repeated from Chunk {i - 1}:**\n"
                    display += "> ```text\n"
                    for line in overlap_text.split('\n'):
                        display += f"> {line}\n"
                    display += "> ```\n\n"

                    # Show the new content
                    display += f"**🆕 NEW CONTENT ({len(remaining_text)} chars):**\n"
                    display += f"```text\n{remaining_text}\n```\n\n"

                    # Show the full chunk for reference
                    display += "<details>\n<summary>📋 Click to view complete chunk (overlap + new)</summary>\n\n"
                    display += f"```text\n{chunk}\n```\n\n"
                    display += "</details>\n\n"
                else:
                    # No overlap found (shouldn't happen normally)
                    display += f"#### 📄 Chunk {i}\n"
                    display += f"**{len(chunk)} characters** | No overlap detected\n\n"
                    display += f"```text\n{chunk}\n```\n\n"
            display += "---\n\n"
        return display

    def set_embedding_model(self, model_name: str):
        """Set or change the embedding model"""
        if self.embedding_model_name != model_name:
            self.embedding_model_name = model_name
            # Some models require trust_remote_code
            try:
                self.embedding_model = SentenceTransformer(model_name)
            except Exception as e:
                if "trust_remote_code" in str(e):
                    print(f"Model {model_name} requires trust_remote_code=True, loading with trust...")
                    self.embedding_model = SentenceTransformer(model_name, trust_remote_code=True)
                else:
                    raise
            # If we already have chunks, re-create the embeddings and index
            if self.chunks:
                self.embeddings = self.create_embeddings(self.chunks)
                self.build_index(self.embeddings)

    def set_llm_model(self, model_name: str):
        """Set or change the LLM model"""
        if self.llm_model_name != model_name:
            self.llm_model_name = model_name
            # Use HF_TOKEN from the environment if available
            hf_token = os.environ.get("HF_TOKEN", None)
            self.llm_client = InferenceClient(model_name, token=hf_token)

    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        similarity_threshold: float = 0.0
    ) -> List[Tuple[str, float]]:
        """Retrieve relevant chunks for a query"""
        if not self.is_ready():
            return []

        # Encode the query
        query_embedding = self.embedding_model.encode(
            [query],
            convert_to_numpy=True
        )

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, top_k)

        # Filter by threshold; FAISS pads with index -1 when the index
        # holds fewer than top_k vectors, so skip those slots
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and score >= similarity_threshold:
                results.append((self.chunks[idx], float(score)))
        return results

    def generate(
        self,
        query: str,
        retrieved_chunks: List[Tuple[str, float]],
        temperature: float = 0.7,
        max_tokens: int = 300
    ) -> Tuple[str, str]:
        """Generate an answer using the LLM"""
        if self.llm_client is None:
            self.set_llm_model("meta-llama/Llama-3.2-1B-Instruct")

        # Build the context from the retrieved chunks
        context = "\n\n".join([chunk for chunk, _ in retrieved_chunks])

        # Create the prompt
        prompt = f"""Use the following context to answer the question. If you cannot answer based on the context, say so.

Context:
{context}

Question: {query}

Answer:"""

        # Generate a response: try chat_completion first, fall back to text_generation
        try:
            try:
                messages = [
                    {
                        "role": "user",
                        "content": prompt
                    }
                ]
                response = self.llm_client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                # Extract the answer from the response
                if hasattr(response, 'choices') and len(response.choices) > 0:
                    answer = response.choices[0].message.content.strip()
                elif isinstance(response, dict) and 'choices' in response:
                    answer = response['choices'][0]['message']['content'].strip()
                else:
                    answer = str(response).strip()
            except Exception as chat_error:
                # Fall back to text_generation
                print(f"Chat completion failed, trying text_generation: {chat_error}")
                response = self.llm_client.text_generation(
                    prompt,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    return_full_text=False,
                )
                answer = response.strip() if isinstance(response, str) else str(response).strip()

            # Handle reasoning tokens (for models like Qwen)
            answer = self._process_reasoning_output(answer)
            return answer, prompt
        except Exception as e:
            import traceback
            error_details = traceback.format_exc()
            return f"Error generating response: {str(e)}\n\nDetails:\n{error_details}", prompt

    def _process_reasoning_output(self, text: str) -> str:
        """Process output from reasoning models to separate thinking from the answer"""
        # Debug: print the first 200 chars to see the format
        print(f"[DEBUG] Processing output (first 200 chars): {text[:200]}")

        # Common pattern for reasoning models: Qwen wraps its reasoning in
        # <think>...</think> tags (checked case-insensitively)
        if '<think>' in text.lower():
            reasoning_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL | re.IGNORECASE)
            if reasoning_match:
                reasoning = reasoning_match.group(1).strip()
                answer = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE).strip()
                print(f"[DEBUG] Found reasoning tokens! Reasoning length: {len(reasoning)}, Answer length: {len(answer)}")
                return f"""**Answer:**

{answer}

---

<details>
<summary>🧠 Model Reasoning (click to expand)</summary>

```
{reasoning}
```

</details>"""

        # Alternative pattern: some models write their reasoning inline
        # without special tags
        thinking_patterns = [
            r'(Let me think.*?(?:Answer:|Response:|Conclusion:))',
            r'(Okay, let\'s see.*?(?:Answer:|Response:|Conclusion:))',
            r'(First,.*?(?:Therefore,|Thus,|So,|In conclusion,))',
        ]
        for pattern in thinking_patterns:
            match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
            if match:
                reasoning = match.group(1).strip()
                answer = text[match.end():].strip()
                if len(reasoning) > 100 and len(answer) > 20:  # Substantial reasoning and answer
                    print("[DEBUG] Found inline reasoning! Pattern matched.")
                    return f"""**Answer:**

{answer}

---

<details>
<summary>🧠 Model Reasoning (click to expand)</summary>

```
{reasoning}
```

</details>"""

        # Alternative pattern: text before an "Answer:" or similar marker
        if re.search(r'(Answer:|Final Answer:|Response:)', text, re.IGNORECASE):
            # flags must be passed by keyword here; the third positional
            # argument of re.split is maxsplit, not flags
            parts = re.split(r'(Answer:|Final Answer:|Response:)', text, flags=re.IGNORECASE)
            if len(parts) >= 3:
                reasoning = parts[0].strip()
                answer = ''.join(parts[2:]).strip()
                if reasoning and len(reasoning) > 50:  # Only if there's substantial reasoning
                    print("[DEBUG] Found Answer: marker pattern")
                    return f"""**Answer:**

{answer}

---

<details>
<summary>🧠 Model Reasoning (click to expand)</summary>

```
{reasoning}
```

</details>"""

        # No reasoning pattern found; return the text as-is
        print("[DEBUG] No reasoning pattern found, returning as-is")
        return text

    def generate_example_questions(self, num_questions: int = 5) -> List[str]:
        """Generate example questions based on the corpus content"""
        if not self.is_ready() or not self.chunks:
            return [
                "What is the main topic of this document?",
                "Can you summarize the key points?",
                "What are the main concepts discussed?",
            ]

        # Sample some chunks to get a feel for the corpus
        import random
        sample_size = min(10, len(self.chunks))
        sample_chunks = random.sample(self.chunks, sample_size)
        sample_text = "\n".join(sample_chunks[:3])  # Use the first 3 sampled chunks

        # Generate questions using the LLM
        try:
            if self.llm_client is None:
                self.set_llm_model("meta-llama/Llama-3.2-1B-Instruct")

            prompt = f"""Based on the following text excerpts, generate {num_questions} diverse and relevant questions that could be answered using this corpus. Make the questions specific and interesting.

Text excerpts:
{sample_text[:2000]}

Generate exactly {num_questions} questions, one per line, without numbering:"""

            # Try chat_completion first, fall back to text_generation
            try:
                messages = [{"role": "user", "content": prompt}]
                response = self.llm_client.chat_completion(
                    messages=messages,
                    max_tokens=300,
                    temperature=0.8,
                )
                # Extract the questions
                if hasattr(response, 'choices') and len(response.choices) > 0:
                    questions_text = response.choices[0].message.content.strip()
                elif isinstance(response, dict) and 'choices' in response:
                    questions_text = response['choices'][0]['message']['content'].strip()
                else:
                    questions_text = str(response).strip()
            except Exception as chat_error:
                print(f"Chat completion failed for questions, trying text_generation: {chat_error}")
                response = self.llm_client.text_generation(
                    prompt,
                    max_new_tokens=300,
                    temperature=0.8,
                    return_full_text=False,
                )
                questions_text = response.strip() if isinstance(response, str) else str(response).strip()

            # Strip reasoning tags if present
            questions_text = re.sub(r'<think>.*?</think>', '', questions_text, flags=re.DOTALL)

            # Parse the questions
            questions = [q.strip() for q in questions_text.split('\n') if q.strip()]
            # Remove numbering if present
            questions = [re.sub(r'^\d+[\.\)]\s*', '', q) for q in questions]
            # Filter out empty or very short questions
            questions = [q for q in questions if len(q) > 10]

            return questions[:num_questions] if questions else self._default_questions()
        except Exception as e:
            import traceback
            print(f"Error generating questions: {e}")
            print(f"Traceback: {traceback.format_exc()}")
            return self._default_questions()

    def _default_questions(self) -> List[str]:
        """Return default questions if generation fails"""
        return [
            "What is the main topic discussed in this corpus?",
            "Can you summarize the key concepts?",
            "What are the main findings or arguments presented?",
        ]
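

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal end-to-end run of the class above, assuming a PDF exists at the
# hypothetical path "documents/sample.pdf" and that the Hugging Face
# Inference API is reachable (with HF_TOKEN set in the environment if needed).
if __name__ == "__main__":
    rag = RAGSystem()

    # Index one document: extract -> chunk -> embed -> build the FAISS index
    status, chunks_display, preview = rag.process_document(
        "documents/sample.pdf", chunk_size=500, chunk_overlap=50
    )
    print(status)

    if rag.is_ready():
        question = "What is this document about?"

        # Retrieve the top-3 chunks scoring above 0.2 cosine similarity
        hits = rag.retrieve(question, top_k=3, similarity_threshold=0.2)
        for chunk, score in hits:
            print(f"score={score:.3f}  {chunk[:80]}...")

        # Generate an answer grounded in the retrieved chunks
        answer, prompt = rag.generate(question, hits)
        print(answer)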