# app/rag_system.py
# Retrieval-Augmented Generation system for the AI tutor application.
import chromadb
from chromadb.utils import embedding_functions
import openai
import os
import logging
from typing import List, Dict, Any, Optional
import uuid
from datetime import datetime
import numpy as np
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)
class RAGSystem:
    """Retrieval-Augmented Generation system for chatbot functionality.

    Text from uploaded PDFs and lecture material is split into overlapping
    chunks and stored in two ChromaDB collections (one per document type).
    Queries retrieve the most similar chunks, scoped to a session via a
    ``session_id`` metadata filter.
    """

    def __init__(self, openai_api_key: str, persist_directory: str = "chroma_db"):
        """Set up the OpenAI client, the persistent ChromaDB client, and collections.

        Args:
            openai_api_key: API key used to construct the OpenAI client.
            persist_directory: Directory where ChromaDB persists its data.
        """
        self.client = openai.OpenAI(api_key=openai_api_key)
        # Persistent client so stored embeddings survive process restarts.
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)
        # ChromaDB's default local embedding function (no API calls required).
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
        # Separate collections keep PDF and lecture content independently queryable.
        self.pdf_collection = self._get_or_create_collection("pdf_documents")
        self.lecture_collection = self._get_or_create_collection("lecture_content")

    def _get_or_create_collection(self, name: str):
        """Return the named ChromaDB collection, creating it if missing.

        Uses ChromaDB's ``get_or_create_collection`` instead of the previous
        bare ``try/except`` around get-then-create, which swallowed every
        error (including real database failures) and blindly retried a create.
        The metadata is applied only when the collection is first created.
        """
        return self.chroma_client.get_or_create_collection(
            name=name,
            embedding_function=self.embedding_function,
            metadata={"description": f"Collection for {name}"},
        )

    def _add_content(self, collection, session_id: str, content: str,
                     doc_type: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Chunk *content* and insert it into *collection*.

        Shared implementation behind ``add_pdf_content`` and
        ``add_lecture_content`` (previously duplicated code).

        Args:
            collection: Target ChromaDB collection.
            session_id: Session the content belongs to (stored in metadata).
            content: Raw text to chunk and index.
            doc_type: Document type tag ("pdf" or "lecture"); also used in IDs.
            metadata: Optional extra metadata merged into every chunk's record.

        Returns:
            True on success; False on failure (the error is logged, not raised).
        """
        try:
            chunks = self._split_text(content, chunk_size=1000, overlap=200)
            base_metadata = {
                "session_id": session_id,
                "document_type": doc_type,
                "added_at": datetime.now().isoformat(),
                **(metadata or {}),
            }
            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            ids: List[str] = []
            for i, chunk in enumerate(chunks):
                # Random suffix prevents ID collisions across repeated uploads
                # within the same session.
                doc_id = f"{session_id}_{doc_type}_{i}_{uuid.uuid4().hex[:8]}"
                documents.append(chunk)
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_id": doc_id,
                })
                ids.append(doc_id)
            # Guard against empty input: ChromaDB rejects an add() with empty lists.
            if documents:
                collection.add(documents=documents, metadatas=metadatas, ids=ids)
            logger.info("Added %d %s chunks for session %s",
                        len(chunks), doc_type, session_id)
            return True
        except Exception as e:
            logger.error("Failed to add %s content: %s", doc_type, e)
            return False

    def add_pdf_content(self, session_id: str, pdf_content: str,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add PDF content to the vector database; return True on success."""
        return self._add_content(self.pdf_collection, session_id,
                                 pdf_content, "pdf", metadata)

    def add_lecture_content(self, session_id: str, lecture_content: str,
                            metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add lecture content to the vector database; return True on success."""
        return self._add_content(self.lecture_collection, session_id,
                                 lecture_content, "lecture", metadata)

    def retrieve_relevant_content(self, session_id: str, query: str,
                                  n_results: int = 5) -> Dict[str, Any]:
        """Retrieve the chunks most relevant to *query* for a session.

        Searches both collections, merges the hits, and returns the
        ``n_results`` chunks with the smallest embedding distance.

        Returns:
            Dict with ``success``, ``results`` (list of dicts carrying
            ``content``/``metadata``/``distance``/``source``) and
            ``total_found``; on failure ``success`` is False and an
            ``error`` message is included.
        """
        try:
            all_results: List[Dict[str, Any]] = []
            for source, collection in (("pdf", self.pdf_collection),
                                       ("lecture", self.lecture_collection)):
                results = collection.query(
                    query_texts=[query],
                    n_results=n_results,
                    # Restrict hits to this session's documents.
                    where={"session_id": session_id},
                )
                docs = results.get("documents")
                if docs and docs[0]:
                    for i, doc in enumerate(docs[0]):
                        all_results.append({
                            "content": doc,
                            "metadata": results["metadatas"][0][i],
                            "distance": results["distances"][0][i],
                            "source": source,
                        })
            # Smaller distance == more similar; keep the best n overall.
            all_results.sort(key=lambda r: r["distance"])
            return {
                "success": True,
                "results": all_results[:n_results],
                "total_found": len(all_results),
            }
        except Exception as e:
            logger.error("Content retrieval failed: %s", e)
            return {
                "success": False,
                "results": [],
                "total_found": 0,
                "error": str(e),
            }

    def _split_text(self, text: str, chunk_size: int = 1000,
                    overlap: int = 200) -> List[str]:
        """Split *text* into overlapping chunks, preferring sentence boundaries.

        Args:
            text: Text to split.
            chunk_size: Maximum chunk length in characters.
            overlap: Characters of overlap between consecutive chunks.

        Returns:
            List of non-empty, stripped chunks (empty input yields []).
        """
        if not text:
            return []
        if len(text) <= chunk_size:
            return [text]
        chunks: List[str] = []
        start = 0
        text_len = len(text)
        while start < text_len:
            end = start + chunk_size
            if end < text_len:
                # Prefer to cut at a sentence/paragraph boundary found within
                # the last 100 characters of the window.
                search_start = max(end - 100, start)
                sentence_ends = []
                for punct in ('. ', '! ', '? ', '\n\n'):
                    pos = text.rfind(punct, search_start, end)
                    if pos > start:
                        sentence_ends.append(pos + len(punct))
                if sentence_ends:
                    end = max(sentence_ends)
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            # Advance by at least one character so degenerate parameters
            # (e.g. overlap >= chunk_size) cannot cause an infinite loop.
            start = max(end - overlap, start + 1)
        return chunks

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """Return chunk counts stored for *session_id*.

        Returns:
            Dict with ``pdf_chunks``, ``lecture_chunks`` and ``total_chunks``;
            all zero when the lookup fails (the error is logged, not raised).
        """
        try:
            session_filter = {"session_id": session_id}
            pdf_count = len(self.pdf_collection.get(where=session_filter)["ids"])
            lecture_count = len(self.lecture_collection.get(where=session_filter)["ids"])
            return {
                "pdf_chunks": pdf_count,
                "lecture_chunks": lecture_count,
                "total_chunks": pdf_count + lecture_count,
            }
        except Exception as e:
            logger.error("Failed to get session stats: %s", e)
            return {
                "pdf_chunks": 0,
                "lecture_chunks": 0,
                "total_chunks": 0,
            }

    def clear_session_data(self, session_id: str) -> bool:
        """Delete all stored chunks for *session_id*; return True on success."""
        try:
            session_filter = {"session_id": session_id}
            for collection in (self.pdf_collection, self.lecture_collection):
                ids = collection.get(where=session_filter)["ids"]
                # delete() with an empty id list is rejected by ChromaDB.
                if ids:
                    collection.delete(ids=ids)
            logger.info("Cleared data for session %s", session_id)
            return True
        except Exception as e:
            logger.error("Failed to clear session data: %s", e)
            return False