"""Core RAG system implementation"""
import os
import glob
import re
from typing import List, Tuple, Optional
import PyPDF2
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
import spaces
class RAGSystem:
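    """Minimal retrieval-augmented generation pipeline: PDF text extraction,
    overlapping chunking, sentence-transformer embeddings, FAISS similarity
    search, and answer generation via the Hugging Face Inference API."""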
def __init__(self):
self.chunks = []
self.chunk_metadata = [] # Store chunk positions for overlap visualization
self.embeddings = None
self.index = None
self.embedding_model = None
self.embedding_model_name = None
self.llm_client = None
self.llm_model_name = None
self.ready = False
def is_ready(self) -> bool:
"""Check if the system is ready to process queries"""
return self.ready and self.index is not None
def load_default_corpus(self, chunk_size: int = 500, chunk_overlap: int = 50):
"""Load the default corpus from documents folder"""
documents_dir = "documents"
if not os.path.exists(documents_dir):
return "Documents folder not found. Please upload a PDF.", "", ""
# Get all PDFs in documents folder
pdf_files = glob.glob(os.path.join(documents_dir, "*.pdf"))
if not pdf_files:
return "No PDF files found in documents folder. Please upload a PDF.", "", ""
try:
# Extract text from all PDFs
all_text = ""
corpus_summary = f"📚 **Loading {len(pdf_files)} documents:**\n\n"
for pdf_path in pdf_files:
filename = os.path.basename(pdf_path)
corpus_summary += f"- {filename}\n"
text = self.extract_text_from_pdf(pdf_path)
all_text += f"\n\n=== {filename} ===\n\n{text}"
corpus_summary += f"\n**Total text length:** {len(all_text)} characters\n"
# Chunk the combined text
self.chunks = self.chunk_text(all_text, chunk_size, chunk_overlap)
if not self.chunks:
return "Error: No valid chunks created from the documents.", "", ""
# Create embeddings
self.embeddings = self.create_embeddings(self.chunks)
# Build index
self.build_index(self.embeddings)
self.ready = True
# Format chunks for display with overlap highlighting
chunks_display = self._format_chunks_with_overlap()
status = f"✅ Success! Processed {len(pdf_files)} documents into {len(self.chunks)} chunks."
return status, chunks_display, corpus_summary
except Exception as e:
self.ready = False
return f"Error loading default corpus: {str(e)}", "", ""
def extract_text_from_pdf(self, pdf_path: str) -> str:
"""Extract text from PDF file"""
text = ""
with open(pdf_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
text += page.extract_text() + "\n"
return text
def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
"""Split text into overlapping chunks and store metadata"""
chunks = []
self.chunk_metadata = [] # Reset metadata
start = 0
text_length = len(text)
previous_end = 0
while start < text_length:
end = start + chunk_size
chunk = text[start:end]
original_end = end
# Try to break at sentence boundary
if end < text_length:
# Look for sentence endings
last_period = chunk.rfind('.')
last_newline = chunk.rfind('\n')
break_point = max(last_period, last_newline)
if break_point > chunk_size * 0.5: # Only break if we're past halfway
chunk = chunk[:break_point + 1]
end = start + break_point + 1
original_end = end
# Calculate overlap with previous chunk
overlap_start = max(0, start - previous_end) if previous_end > 0 else 0
overlap_length = min(overlap, previous_end - start) if start < previous_end else 0
chunks.append(chunk.strip())
self.chunk_metadata.append({
'start': start,
'end': original_end,
'overlap_with_previous': overlap_length,
'text': chunk
})
previous_end = original_end
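            # Slide the window forward so the next chunk re-reads the last `overlap`
            # characters of this one, preserving context across chunk boundaries.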
start = end - overlap
# Filter out very small chunks and update metadata accordingly
filtered_chunks = []
filtered_metadata = []
for i, c in enumerate(chunks):
if len(c) > 50:
filtered_chunks.append(c)
filtered_metadata.append(self.chunk_metadata[i])
self.chunk_metadata = filtered_metadata
return filtered_chunks
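    # The @spaces.GPU decorator requests a GPU slot when this method runs on a
    # Hugging Face ZeroGPU Space; outside of Spaces it is effectively a no-op.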
@spaces.GPU
def create_embeddings(self, texts: List[str]) -> np.ndarray:
"""Create embeddings for text chunks"""
if self.embedding_model is None:
self.set_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
embeddings = self.embedding_model.encode(
texts,
show_progress_bar=True,
convert_to_numpy=True
)
return embeddings
def build_index(self, embeddings: np.ndarray):
"""Build FAISS index from embeddings"""
dimension = embeddings.shape[1]
self.index = faiss.IndexFlatIP(dimension) # Inner product for cosine similarity
# Normalize embeddings for cosine similarity
faiss.normalize_L2(embeddings)
self.index.add(embeddings)
def process_document(self, pdf_path: str, chunk_size: int = 500, chunk_overlap: int = 50):
"""Process a PDF document and create searchable index"""
try:
# Extract text
text = self.extract_text_from_pdf(pdf_path)
if not text.strip():
return "Error: No text could be extracted from the PDF.", "", ""
# Chunk text
self.chunks = self.chunk_text(text, chunk_size, chunk_overlap)
if not self.chunks:
return "Error: No valid chunks created from the document.", "", ""
# Create embeddings
self.embeddings = self.create_embeddings(self.chunks)
# Build index
self.build_index(self.embeddings)
self.ready = True
# Format chunks for display with overlap highlighting
chunks_display = self._format_chunks_with_overlap()
status = f"✅ Success! Processed {len(self.chunks)} chunks from the document."
return status, chunks_display, text[:5000] # Return first 5000 chars of original text
except Exception as e:
self.ready = False
return f"Error processing document: {str(e)}", "", ""
def _format_chunks_with_overlap(self) -> str:
"""Format chunks with overlap highlighting for pedagogical display"""
if not self.chunks or not self.chunk_metadata:
return "No chunks available"
display = "### 📑 Processed Chunks\n\n"
display += "*Overlapping parts are shown separately with a yellow marker (⚠️)*\n\n"
display += "---\n\n"
for i, (chunk, metadata) in enumerate(zip(self.chunks, self.chunk_metadata), 1):
# Calculate which part is overlapping with previous chunk
if i == 1:
# First chunk has no overlap
display += f"#### 📄 Chunk {i}\n"
display += f"**{len(chunk)} characters** | 🆕 No overlap (first chunk)\n\n"
display += f"```text\n{chunk}\n```\n\n"
display += "---\n\n"
else:
# Find overlap with previous chunk
prev_chunk = self.chunks[i-2]
# Find common substring at the beginning of current chunk
overlap_length = 0
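                # Brute-force scan: keep the largest j for which the last j characters
                # of the previous chunk equal the first j characters of the current chunk.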
for j in range(1, min(len(chunk), len(prev_chunk)) + 1):
if prev_chunk[-j:] == chunk[:j]:
overlap_length = j
if overlap_length > 0:
overlap_text = chunk[:overlap_length]
remaining_text = chunk[overlap_length:]
display += f"#### 📄 Chunk {i}\n"
display += f"**{len(chunk)} characters** | ⚠️ **{overlap_length} characters overlap** with previous chunk\n\n"
# Show overlap
display += f"> **⚠️ OVERLAP ({overlap_length} chars) - Repeated from Chunk {i-1}:**\n"
display += f"> ```text\n"
for line in overlap_text.split('\n'):
display += f"> {line}\n"
display += f"> ```\n\n"
# Show the new content
display += f"**🆕 NEW CONTENT ({len(remaining_text)} chars):**\n"
display += f"```text\n{remaining_text}\n```\n\n"
# Show full chunk for reference
display += f"\n📋 Click to view complete chunk (overlap + new)
\n\n"
display += f"```text\n{chunk}\n```\n\n"
display += f" \n\n"
else:
# No overlap found (shouldn't happen normally)
display += f"#### 📄 Chunk {i}\n"
display += f"**{len(chunk)} characters** | No overlap detected\n\n"
display += f"```text\n{chunk}\n```\n\n"
display += "---\n\n"
return display
def set_embedding_model(self, model_name: str):
"""Set or change the embedding model"""
if self.embedding_model_name != model_name:
self.embedding_model_name = model_name
# Some models require trust_remote_code
try:
self.embedding_model = SentenceTransformer(model_name)
except Exception as e:
if "trust_remote_code" in str(e):
print(f"Model {model_name} requires trust_remote_code=True, loading with trust...")
self.embedding_model = SentenceTransformer(model_name, trust_remote_code=True)
else:
raise e
# If we have chunks, re-create embeddings and index
if self.chunks:
self.embeddings = self.create_embeddings(self.chunks)
self.build_index(self.embeddings)
def set_llm_model(self, model_name: str):
"""Set or change the LLM model"""
if self.llm_model_name != model_name:
self.llm_model_name = model_name
# Use HF_TOKEN from environment if available
hf_token = os.environ.get("HF_TOKEN", None)
self.llm_client = InferenceClient(model_name, token=hf_token)
@spaces.GPU
def retrieve(
self,
query: str,
top_k: int = 3,
similarity_threshold: float = 0.0
) -> List[Tuple[str, float]]:
"""Retrieve relevant chunks for a query"""
if not self.is_ready():
return []
# Encode query
query_embedding = self.embedding_model.encode(
[query],
convert_to_numpy=True
)
# Normalize for cosine similarity
faiss.normalize_L2(query_embedding)
# Search
scores, indices = self.index.search(query_embedding, top_k)
# Filter by threshold and return results
results = []
for score, idx in zip(scores[0], indices[0]):
            # FAISS pads results with index -1 when fewer than top_k vectors exist
            if idx != -1 and score >= similarity_threshold:
results.append((self.chunks[idx], float(score)))
return results
@spaces.GPU
def generate(
self,
query: str,
retrieved_chunks: List[Tuple[str, float]],
temperature: float = 0.7,
max_tokens: int = 300
) -> Tuple[str, str]:
"""Generate answer using LLM"""
if self.llm_client is None:
self.set_llm_model("meta-llama/Llama-3.2-1B-Instruct")
# Build context from retrieved chunks
context = "\n\n".join([chunk for chunk, _ in retrieved_chunks])
# Create prompt
prompt = f"""Use the following context to answer the question. If you cannot answer based on the context, say so.
Context:
{context}
Question: {query}
Answer:"""
# Generate response - try chat_completion first, fallback to text_generation
try:
# Try chat_completion first
try:
messages = [
{
"role": "user",
"content": prompt
}
]
response = self.llm_client.chat_completion(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
# Extract answer from response
if hasattr(response, 'choices') and len(response.choices) > 0:
answer = response.choices[0].message.content.strip()
elif isinstance(response, dict) and 'choices' in response:
answer = response['choices'][0]['message']['content'].strip()
else:
answer = str(response).strip()
except Exception as chat_error:
# Fallback to text_generation
print(f"Chat completion failed, trying text_generation: {chat_error}")
response = self.llm_client.text_generation(
prompt,
max_new_tokens=max_tokens,
temperature=temperature,
return_full_text=False,
)
answer = response.strip() if isinstance(response, str) else str(response).strip()
# Handle reasoning tokens (for models like Qwen)
answer = self._process_reasoning_output(answer)
return answer, prompt
except Exception as e:
import traceback
error_details = traceback.format_exc()
return f"Error generating response: {str(e)}\n\nDetails:\n{error_details}", prompt
def _process_reasoning_output(self, text: str) -> str:
"""Process output from reasoning models to separate thinking from answer"""
# Debug: print first 200 chars to see the format
print(f"[DEBUG] Processing output (first 200 chars): {text[:200]}")
# Common patterns for reasoning models
        # Qwen uses <think>...</think> tags (case-insensitive check)
        if '<think>' in text.lower():
# Extract reasoning and answer (case-insensitive)
            reasoning_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL | re.IGNORECASE)
if reasoning_match:
reasoning = reasoning_match.group(1).strip()
                answer = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE).strip()
print(f"[DEBUG] Found reasoning tokens! Reasoning length: {len(reasoning)}, Answer length: {len(answer)}")
return f"""**Answer:**
{answer}
---
<details>
<summary>🧠 Model Reasoning (click to expand)</summary>
```
{reasoning}
```
"""
# Alternative pattern: Look for common thinking patterns in text
# Some models output their reasoning inline without special tags
thinking_patterns = [
r'(Let me think.*?(?:Answer:|Response:|Conclusion:))',
r'(Okay, let\'s see.*?(?:Answer:|Response:|Conclusion:))',
r'(First,.*?(?:Therefore,|Thus,|So,|In conclusion,))',
]
for pattern in thinking_patterns:
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
reasoning = match.group(1).strip()
answer = text[match.end():].strip()
if len(reasoning) > 100 and len(answer) > 20: # Substantial reasoning and answer
print(f"[DEBUG] Found inline reasoning! Pattern matched.")
return f"""**Answer:**
{answer}
---
<details>
<summary>🧠 Model Reasoning (click to expand)</summary>
```
{reasoning}
```
"""
# Alternative pattern: text before "Answer:" or similar markers
if re.search(r'(Answer:|Final Answer:|Response:)', text, re.IGNORECASE):
            parts = re.split(r'(Answer:|Final Answer:|Response:)', text, flags=re.IGNORECASE)
if len(parts) >= 3:
reasoning = parts[0].strip()
answer = ''.join(parts[2:]).strip()
if reasoning and len(reasoning) > 50: # Only if there's substantial reasoning
print(f"[DEBUG] Found Answer: marker pattern")
return f"""**Answer:**
{answer}
---
<details>
<summary>🧠 Model Reasoning (click to expand)</summary>
```
{reasoning}
```
"""
# No reasoning pattern found, return as is
print(f"[DEBUG] No reasoning pattern found, returning as-is")
return text
def generate_example_questions(self, num_questions: int = 5) -> List[str]:
"""Generate example questions based on the corpus content"""
if not self.is_ready() or not self.chunks:
return [
"What is the main topic of this document?",
"Can you summarize the key points?",
"What are the main concepts discussed?",
]
# Sample some chunks to understand the corpus
sample_size = min(10, len(self.chunks))
import random
sample_chunks = random.sample(self.chunks, sample_size)
sample_text = "\n".join(sample_chunks[:3]) # Use first 3 sampled chunks
# Generate questions using the LLM
try:
if self.llm_client is None:
self.set_llm_model("meta-llama/Llama-3.2-1B-Instruct")
prompt = f"""Based on the following text excerpts, generate {num_questions} diverse and relevant questions that could be answered using this corpus. Make the questions specific and interesting.
Text excerpts:
{sample_text[:2000]}
Generate exactly {num_questions} questions, one per line, without numbering:"""
# Try chat_completion first, fallback to text_generation
try:
messages = [{"role": "user", "content": prompt}]
response = self.llm_client.chat_completion(
messages=messages,
max_tokens=300,
temperature=0.8,
)
# Extract questions
if hasattr(response, 'choices') and len(response.choices) > 0:
questions_text = response.choices[0].message.content.strip()
elif isinstance(response, dict) and 'choices' in response:
questions_text = response['choices'][0]['message']['content'].strip()
else:
questions_text = str(response).strip()
except Exception as chat_error:
print(f"Chat completion failed for questions, trying text_generation: {chat_error}")
response = self.llm_client.text_generation(
prompt,
max_new_tokens=300,
temperature=0.8,
return_full_text=False,
)
questions_text = response.strip() if isinstance(response, str) else str(response).strip()
# Clean up reasoning if present
            questions_text = re.sub(r'<think>.*?</think>', '', questions_text, flags=re.DOTALL)
# Parse questions
questions = [q.strip() for q in questions_text.split('\n') if q.strip()]
# Remove numbering if present
questions = [re.sub(r'^\d+[\.\)]\s*', '', q) for q in questions]
# Filter out empty or very short questions
questions = [q for q in questions if len(q) > 10]
return questions[:num_questions] if questions else self._default_questions()
except Exception as e:
import traceback
print(f"Error generating questions: {e}")
print(f"Traceback: {traceback.format_exc()}")
return self._default_questions()
def _default_questions(self) -> List[str]:
"""Return default questions if generation fails"""
return [
"What is the main topic discussed in this corpus?",
"Can you summarize the key concepts?",
"What are the main findings or arguments presented?",
]
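
# Minimal usage sketch (assumptions: a local "documents/" folder with at least one
# PDF, network access to the Hugging Face Inference API, and optionally an HF_TOKEN
# in the environment). Illustrative only; the Space's UI entry point lives elsewhere.
if __name__ == "__main__":
    rag = RAGSystem()
    status, _chunks_view, _summary = rag.load_default_corpus(chunk_size=500, chunk_overlap=50)
    print(status)
    if rag.is_ready():
        question = "What is the main topic of this document?"
        retrieved = rag.retrieve(question, top_k=3, similarity_threshold=0.2)
        answer, _prompt = rag.generate(question, retrieved, temperature=0.7, max_tokens=300)
        print(answer)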