import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import chromadb
import openai
from chromadb.utils import embedding_functions

logger = logging.getLogger(__name__)
class RAGSystem:
    """Retrieval-Augmented Generation system for chatbot functionality."""

    def __init__(self, openai_api_key: str, persist_directory: str = "chroma_db"):
        # OpenAI client (used by the generation side of the chatbot)
        self.client = openai.OpenAI(api_key=openai_api_key)

        # Initialize a persistent ChromaDB client
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Create the embedding function shared by all collections
        self.embedding_function = embedding_functions.DefaultEmbeddingFunction()

        # Separate collections for the different document types
        self.pdf_collection = self._get_or_create_collection("pdf_documents")
        self.lecture_collection = self._get_or_create_collection("lecture_content")
    def _get_or_create_collection(self, name: str):
        """Get an existing collection or create a new one."""
        try:
            return self.chroma_client.get_collection(
                name=name,
                embedding_function=self.embedding_function,
            )
        # ChromaDB raises different exception types across versions when a
        # collection is missing, so catch Exception rather than using a bare except.
        except Exception:
            return self.chroma_client.create_collection(
                name=name,
                embedding_function=self.embedding_function,
                metadata={"description": f"Collection for {name}"},
            )
    def add_pdf_content(self, session_id: str, pdf_content: str,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add PDF content to the vector database."""
        return self._add_content(self.pdf_collection, session_id, pdf_content, "pdf", metadata)

    def add_lecture_content(self, session_id: str, lecture_content: str,
                            metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Add lecture content to the vector database."""
        return self._add_content(self.lecture_collection, session_id, lecture_content, "lecture", metadata)

    def _add_content(self, collection, session_id: str, content: str, doc_type: str,
                     metadata: Optional[Dict[str, Any]] = None) -> bool:
        """Chunk content and insert it into the given collection (shared by both add_* methods)."""
        try:
            # Split content into overlapping chunks
            chunks = self._split_text(content, chunk_size=1000, overlap=200)

            # Prepare documents for insertion
            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            ids: List[str] = []

            base_metadata = {
                "session_id": session_id,
                "document_type": doc_type,
                "added_at": datetime.now().isoformat(),
                **(metadata or {}),
            }

            for i, chunk in enumerate(chunks):
                doc_id = f"{session_id}_{doc_type}_{i}_{uuid.uuid4().hex[:8]}"
                documents.append(chunk)
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_id": doc_id,
                })
                ids.append(doc_id)

            # Add to the collection
            collection.add(documents=documents, metadatas=metadatas, ids=ids)

            logger.info(f"Added {len(chunks)} {doc_type} chunks for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to add {doc_type} content: {str(e)}")
            return False
    def retrieve_relevant_content(self, session_id: str, query: str, n_results: int = 5) -> Dict[str, Any]:
        """Retrieve the most relevant stored chunks for a query."""
        try:
            # Search both collections, scoped to this session
            pdf_results = self.pdf_collection.query(
                query_texts=[query],
                n_results=n_results,
                where={"session_id": session_id},
            )
            lecture_results = self.lecture_collection.query(
                query_texts=[query],
                n_results=n_results,
                where={"session_id": session_id},
            )

            # Combine results from both sources
            all_results = []
            for source, results in (("pdf", pdf_results), ("lecture", lecture_results)):
                if results["documents"] and results["documents"][0]:
                    for i, doc in enumerate(results["documents"][0]):
                        all_results.append({
                            "content": doc,
                            "metadata": results["metadatas"][0][i],
                            "distance": results["distances"][0][i],
                            "source": source,
                        })

            # Sort by ascending distance (smaller distance = more relevant)
            all_results.sort(key=lambda x: x["distance"])

            return {
                "success": True,
                "results": all_results[:n_results],
                "total_found": len(all_results),
            }
        except Exception as e:
            logger.error(f"Content retrieval failed: {str(e)}")
            return {
                "success": False,
                "results": [],
                "total_found": 0,
                "error": str(e),
            }
    def _split_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries."""
        if len(text) <= chunk_size:
            return [text]

        chunks = []
        start = 0
        while start < len(text):
            end = min(start + chunk_size, len(text))

            # Try to end at a sentence boundary within the last 100 characters
            if end < len(text):
                search_start = max(end - 100, start)
                sentence_ends = []
                for punct in ['. ', '! ', '? ', '\n\n']:
                    pos = text.rfind(punct, search_start, end)
                    if pos > start:
                        sentence_ends.append(pos + len(punct))
                if sentence_ends:
                    end = max(sentence_ends)

            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)

            if end >= len(text):
                break

            # Advance with overlap; max() guarantees forward progress even when
            # overlap >= chunk_size, preventing an infinite loop.
            start = max(end - overlap, start + 1)

        return chunks
    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """Get statistics about stored content for a session."""
        try:
            # Count the chunks stored for this session in each collection
            pdf_count = len(self.pdf_collection.get(
                where={"session_id": session_id}
            )["ids"])
            lecture_count = len(self.lecture_collection.get(
                where={"session_id": session_id}
            )["ids"])

            return {
                "pdf_chunks": pdf_count,
                "lecture_chunks": lecture_count,
                "total_chunks": pdf_count + lecture_count,
            }
        except Exception as e:
            logger.error(f"Failed to get session stats: {str(e)}")
            return {
                "pdf_chunks": 0,
                "lecture_chunks": 0,
                "total_chunks": 0,
            }
    def clear_session_data(self, session_id: str) -> bool:
        """Clear all stored data for a specific session."""
        try:
            # Look up every document ID belonging to this session
            pdf_ids = self.pdf_collection.get(
                where={"session_id": session_id}
            )["ids"]
            lecture_ids = self.lecture_collection.get(
                where={"session_id": session_id}
            )["ids"]

            # Delete the documents from each collection
            if pdf_ids:
                self.pdf_collection.delete(ids=pdf_ids)
            if lecture_ids:
                self.lecture_collection.delete(ids=lecture_ids)

            logger.info(f"Cleared data for session {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to clear session data: {str(e)}")
            return False
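

# --- Usage sketch (illustrative; not part of the original Space) ---
# A minimal example of how this class might be wired up. It assumes the
# OPENAI_API_KEY environment variable is set; the session id, document
# text, and metadata below are hypothetical placeholders.
if __name__ == "__main__":
    import os

    rag = RAGSystem(openai_api_key=os.environ["OPENAI_API_KEY"])
    session = "demo-session"

    rag.add_pdf_content(session, "Gradient descent iteratively updates parameters...",
                        metadata={"filename": "notes.pdf"})
    rag.add_lecture_content(session, "Today's lecture covers backpropagation...")

    hits = rag.retrieve_relevant_content(session, "How does gradient descent work?", n_results=3)
    for r in hits["results"]:
        print(r["source"], round(r["distance"], 3), r["content"][:80])

    print(rag.get_session_stats(session))  # e.g. {'pdf_chunks': 1, 'lecture_chunks': 1, 'total_chunks': 2}
    rag.clear_session_data(session)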