Spaces:
Sleeping
Sleeping
Teja Chowdary
Fix all problematic imports: fitz, PIL, cv2 - use conditional imports and mock functions for cloud compatibility
11aefa0 | from __future__ import annotations | |
| from typing import List, Dict, Optional, Tuple, Union | |
| import os | |
| import re | |
| from pathlib import Path | |
| from sentence_transformers import SentenceTransformer | |
| try: | |
| from langchain_chroma import Chroma # Updated import | |
| CHROMA_AVAILABLE = True | |
| except ImportError: | |
| from langchain.vectorstores import Chroma # Fallback import | |
| CHROMA_AVAILABLE = False | |
| from langchain.embeddings.base import Embeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema import Document | |
| import chromadb | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| # Mock imports for cloud compatibility | |
| FITZ_AVAILABLE = False | |
| OCR_AVAILABLE = False | |
| print("⚠️ PyMuPDF and OCR libraries not available in cloud environment - using fallbacks") | |
| import json | |
class STEmbedding(Embeddings):
    """LangChain-compatible wrapper for sentence-transformers.

    Exposes the two methods LangChain's ``Embeddings`` interface expects
    (``embed_documents`` and ``embed_query``) on top of a
    ``SentenceTransformer`` model.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load the sentence-transformers model named ``model_name``."""
        self.model = SentenceTransformer(model_name)
        # Stored so callers can report which model produced the vectors.
        self.model_name = model_name

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Args:
            texts: The strings to embed.

        Returns:
            One embedding (a list of floats) per input text, in input order.
            An empty input yields an empty list.
        """
        # Short-circuit the empty batch: if the encoder hands back an empty
        # 1-D array, the reshape below would turn it into a single empty
        # vector ([[]]) instead of "no vectors".
        if len(texts) == 0:
            return []

        embeddings = self.model.encode(texts, convert_to_numpy=True)
        # A 1-D result means a single vector; wrap it so the return value is
        # always a list of vectors.
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector as a list of floats."""
        # Encode as a one-element batch and take the only row.
        batch = self.model.encode([text], convert_to_numpy=True)
        return batch[0].tolist()
class AdvancedDocumentProcessor:
    """Enhanced document processing with multi-modal support.

    All helpers are stateless and declared as static methods so they can be
    called either on the class or on an instance.
    """

    # (chunk_size, chunk_overlap) per content type; anything not listed here
    # uses _DEFAULT_CHUNKING.
    _CHUNKING = {
        "code": (800, 100),
        "academic": (1200, 200),
        "conversation": (600, 100),
    }
    _DEFAULT_CHUNKING = (1000, 200)

    @staticmethod
    def load_text_file(file_path: str) -> str:
        """Load text from a UTF-8 text file.

        Returns:
            The file contents, or an empty string when the file cannot be
            opened or decoded (the error is printed).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError for non-UTF-8 content.
            print(f"Error loading text file {file_path}: {e}")
            return ""

    @staticmethod
    def load_markdown_file(file_path: str) -> str:
        """Load text from a markdown file (read as plain UTF-8 text)."""
        return AdvancedDocumentProcessor.load_text_file(file_path)

    @staticmethod
    def load_pdf_file(file_path: str) -> str:
        """Extract text from PDF files.

        Real PDF extraction is not implemented here. Returns an empty string
        when ``FITZ_AVAILABLE`` is false, otherwise a placeholder string
        naming the file.
        """
        if not FITZ_AVAILABLE:
            print(f"⚠️ PyMuPDF not available - cannot process PDF: {file_path}")
            return ""
        # Mock PDF processing for cloud compatibility
        print(f"⚠️ PDF processing not available in cloud environment: {file_path}")
        return f"[PDF content from {file_path} - processing not available in cloud]"

    @staticmethod
    def extract_text_from_image(image_path: str) -> str:
        """Extract text from images using OCR.

        Real OCR is not implemented here. Returns an empty string when
        ``OCR_AVAILABLE`` is false, otherwise a placeholder string naming
        the file.
        """
        if not OCR_AVAILABLE:
            print(f"⚠️ OCR libraries not available - cannot extract text from image: {image_path}")
            return ""
        # Mock OCR processing for cloud compatibility
        print(f"⚠️ OCR processing not available in cloud environment: {image_path}")
        return f"[Image content from {image_path} - OCR not available in cloud]"

    @staticmethod
    def split_text_adaptive(text: str, content_type: str = "general") -> List[str]:
        """Split ``text`` into chunks sized according to ``content_type``.

        Args:
            text: The text to split.
            content_type: ``"code"``, ``"academic"`` or ``"conversation"``
                select dedicated chunk sizes; any other value uses the
                default of 1000 characters with 200 overlap.

        Returns:
            The chunks produced by ``RecursiveCharacterTextSplitter``.
        """
        chunk_size, chunk_overlap = AdvancedDocumentProcessor._CHUNKING.get(
            content_type, AdvancedDocumentProcessor._DEFAULT_CHUNKING
        )
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )
        return text_splitter.split_text(text)
class HybridSearchEngine:
    """Combines semantic and keyword search for better retrieval.

    A TF-IDF index over ``documents`` is built at construction time. Every
    index returned by the search methods is a position in that list.
    """

    def __init__(self, documents: List[str]):
        """Store ``documents`` and build the TF-IDF index over them."""
        self.documents = documents
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=1000,
            stop_words='english',
            ngram_range=(1, 2)
        )
        # Stays None when the index could not be built; keyword search then
        # returns no hits.
        self.tfidf_matrix = None
        self._build_tfidf_index()

    def _build_tfidf_index(self):
        """Build TF-IDF index for keyword search; leave it unset on failure."""
        try:
            self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(self.documents)
        except Exception as e:
            print(f"Error building TF-IDF index: {e}")
            self.tfidf_matrix = None

    @staticmethod
    def _normalize(scores: List[float]) -> List[float]:
        """Scale ``scores`` by their maximum; returned unchanged if the max is not positive."""
        peak = max(scores)
        if peak > 0:
            return [score / peak for score in scores]
        return list(scores)

    def keyword_search(self, query: str, top_k: int = 5) -> List[Tuple[int, float]]:
        """Perform keyword-based search using TF-IDF.

        Args:
            query: The search text.
            top_k: Maximum number of hits to return.

        Returns:
            ``(document_index, cosine_similarity)`` pairs, best first,
            restricted to documents with a similarity above zero. Empty when
            the index is unavailable, ``top_k`` is not positive, or the
            search fails.
        """
        # top_k <= 0 must be handled explicitly: the slice [-0:] below would
        # select the whole array rather than nothing.
        if self.tfidf_matrix is None or top_k <= 0:
            return []
        try:
            query_vector = self.tfidf_vectorizer.transform([query])
            similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
            # argsort is ascending, so take the tail and reverse it.
            top_indices = similarities.argsort()[-top_k:][::-1]
            # Convert to native int/float so callers get plain Python values.
            return [
                (int(idx), float(similarities[idx]))
                for idx in top_indices
                if similarities[idx] > 0
            ]
        except Exception as e:
            print(f"Error in keyword search: {e}")
            return []

    def hybrid_search(self, query: str, semantic_scores: List[float],
                      alpha: float = 0.7, top_k: int = 5) -> List[Tuple[int, float]]:
        """Combine semantic and keyword search scores.

        Args:
            query: The search text used for the keyword component.
            semantic_scores: One score per position; position ``i`` is paired
                with the keyword score of document index ``i``.
            alpha: Weight of the semantic component; the keyword component
                gets ``1 - alpha``.
            top_k: Maximum number of results to return.

        Returns:
            ``(index, hybrid_score)`` pairs sorted by descending score. Both
            components are divided by their maximum before mixing when that
            maximum is positive.
        """
        if not semantic_scores:
            return []

        size = len(semantic_scores)
        keyword_scores = [0.0] * size
        for idx, score in self.keyword_search(query, top_k=size):
            if 0 <= idx < size:
                keyword_scores[idx] = score

        # Normalise each component once; the maximum is computed a single
        # time per list rather than once per element.
        semantic_norm = self._normalize(semantic_scores)
        keyword_norm = self._normalize(keyword_scores)

        fused = [
            (position, alpha * semantic + (1 - alpha) * keyword)
            for position, (semantic, keyword) in enumerate(zip(semantic_norm, keyword_norm))
        ]
        fused.sort(key=lambda item: item[1], reverse=True)
        return fused[:top_k]
class QueryExpander:
    """Expands queries to improve retrieval."""

    # Substring triggers mapped to the suffixes appended to the query when
    # the trigger occurs anywhere in it (case-insensitive).
    _RELATED_TERMS = {
        "derivative": (" calculus", " differentiation", " rate of change"),
        "algorithm": (" computer science", " programming", " data structure"),
    }

    def __init__(self):
        """Set up the synonym table, keyed by lower-case single words."""
        self.synonyms = {
            "math": ["mathematics", "mathematical", "calculation", "computation"],
            "programming": ["coding", "software development", "computer programming"],
            "physics": ["physical science", "mechanics", "dynamics"],
            "chemistry": ["chemical science", "molecular science"],
            "biology": ["biological science", "life science"],
            "study": ["learn", "research", "investigate", "examine"],
            "understand": ["comprehend", "grasp", "fathom", "realize"],
            "practice": ["exercise", "drill", "rehearse", "train"]
        }

    def expand_query(self, query: str) -> List[str]:
        """Expand a query with synonyms and related terms.

        Args:
            query: The user's original query.

        Returns:
            A list of distinct query variants. The original query is always
            first, followed by synonym substitutions and then related-term
            extensions, in a deterministic order.
        """
        expanded_queries = [query]
        lowered = query.lower()

        # Add synonyms. Matching is whole-word and case-insensitive, so
        # "Math" is expanded while the "math" inside "mathematics" is left
        # alone.
        handled = set()
        for word in re.findall(r"\w+", lowered):
            if word in handled or word not in self.synonyms:
                continue
            handled.add(word)
            pattern = re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE)
            for synonym in self.synonyms[word]:
                # A callable replacement keeps the synonym literal (no
                # backslash/group interpretation).
                expanded_queries.append(pattern.sub(lambda _match, s=synonym: s, query))

        # Add related terms
        for trigger, suffixes in self._RELATED_TERMS.items():
            if trigger in lowered:
                expanded_queries.extend(query + suffix for suffix in suffixes)

        # Remove duplicates while keeping first-seen order.
        return list(dict.fromkeys(expanded_queries))
class EnhancedRAG:
    """Enhanced RAG system with advanced features.

    Wraps a Chroma vector store and adds query expansion plus optional
    TF-IDF/semantic hybrid re-ranking. ``chunk_id`` metadata assigned by
    ``load_knowledge_base`` equals the chunk's position in ``self.documents``,
    which is what lets semantic hits be aligned with keyword hits.
    """

    # File suffix -> (processor method name, content type).
    _SUFFIX_HANDLERS = {
        ".md": ("load_markdown_file", "academic"),
        ".txt": ("load_text_file", "general"),
        ".pdf": ("load_pdf_file", "academic"),
        ".jpg": ("extract_text_from_image", "image"),
        ".jpeg": ("extract_text_from_image", "image"),
        ".png": ("extract_text_from_image", "image"),
        ".gif": ("extract_text_from_image", "image"),
        ".py": ("load_text_file", "code"),
    }

    def __init__(self, persist_path: str | None = None,
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Create the embedding model and vector store.

        Args:
            persist_path: Directory for the Chroma database. Falls back to
                the ``RAG_DB_PATH`` environment variable, then to
                ``.smartlearn/chroma_db``. The directory is created if missing.
            embedding_model: sentence-transformers model name.
        """
        persist_path = persist_path or os.getenv("RAG_DB_PATH", ".smartlearn/chroma_db")
        os.makedirs(persist_path, exist_ok=True)
        self.emb = STEmbedding(model_name=embedding_model)
        self.persist_path = persist_path
        self.vs = None
        self.processor = AdvancedDocumentProcessor()
        self.documents_loaded = False
        # Chunk texts in chunk_id order; the TF-IDF index is built over this list.
        self.documents = []
        self.hybrid_search_engine = None
        self.query_expander = QueryExpander()
        # Initialize vector store
        self._init_vector_store()

    def _init_vector_store(self):
        """Initialize the vector store, falling back to a client-less Chroma on error."""
        try:
            # Create ChromaDB client
            client = chromadb.PersistentClient(path=self.persist_path)
            # Initialize Chroma vector store
            self.vs = Chroma(
                collection_name="smartlearn",
                embedding_function=self.emb,
                client=client
            )
        except Exception as e:
            print(f"Error initializing vector store: {e}")
            # Fallback to in-memory store
            self.vs = Chroma(
                collection_name="smartlearn",
                embedding_function=self.emb
            )

    def _get_collection(self):
        """Return the store's underlying Chroma collection, or None if unavailable."""
        return getattr(self.vs, "_collection", None) or None

    def _clear_collection(self):
        """Best-effort removal of every record currently in the collection."""
        collection = self._get_collection()
        if collection is None:
            return
        try:
            # Try to get existing documents first
            existing = collection.get()
            if existing and existing.get('ids'):
                # Delete by IDs if they exist
                collection.delete(ids=existing['ids'])
                print("✅ Cleared existing documents")
        except Exception as e:
            # Continue anyway: a failed clear should not block loading.
            print(f"⚠️ Warning: Could not clear existing documents: {e}")

    def _read_file(self, file_path: Path) -> Tuple[str, str]:
        """Return ``(content, content_type)`` for a file; content is "" for unsupported suffixes."""
        handler = self._SUFFIX_HANDLERS.get(file_path.suffix.lower())
        if handler is None:
            return "", "general"
        method_name, content_type = handler
        return getattr(self.processor, method_name)(str(file_path)), content_type

    def load_knowledge_base(self, knowledge_base_path: str) -> bool:
        """Load documents from the knowledge base directory with multi-modal support.

        Existing records are cleared, every supported file under the
        directory is read, split into chunks and added to the vector store,
        and the keyword index is rebuilt.

        Args:
            knowledge_base_path: Directory to scan recursively.

        Returns:
            True when at least one chunk was loaded; False when the path is
            missing, nothing usable was found, or an error occurred.
        """
        try:
            kb_path = Path(knowledge_base_path)
            if not kb_path.exists():
                print(f"Knowledge base path not found: {knowledge_base_path}")
                return False

            self._clear_collection()
            # Reset in-memory state as well, so an empty or failed reload
            # cannot leave stale texts paired with an emptied collection.
            self.documents = []
            self.hybrid_search_engine = None
            self.documents_loaded = False

            documents = []
            chunk_id = 0
            # Sorted so chunk ids are stable across runs and platforms.
            for file_path in sorted(kb_path.rglob("*")):
                if not file_path.is_file():
                    continue
                content, content_type = self._read_file(file_path)
                if not content.strip():
                    continue
                # Split content adaptively
                chunks = self.processor.split_text_adaptive(content, content_type)
                for i, chunk in enumerate(chunks):
                    if not chunk.strip():
                        continue
                    documents.append(Document(
                        page_content=chunk,
                        metadata={
                            "source": str(file_path),
                            "chunk_id": chunk_id,
                            "content_type": content_type,
                            "chunk_index": i,
                            "total_chunks": len(chunks)
                        }
                    ))
                    chunk_id += 1

            if not documents:
                print("❌ No documents found in knowledge base")
                return False

            # Add documents to vector store
            self.vs.add_documents(documents)
            self.documents = [doc.page_content for doc in documents]
            # Initialize hybrid search engine
            self.hybrid_search_engine = HybridSearchEngine(self.documents)
            self.documents_loaded = True
            print(f"✅ Loaded {len(documents)} document chunks from knowledge base")
            return True
        except Exception as e:
            print(f"❌ Error loading knowledge base: {e}")
            return False

    @staticmethod
    def _distance_to_similarity(distance: float) -> float:
        """Map a non-negative distance to a similarity in (0, 1]; smaller distance -> larger value."""
        return 1.0 / (1.0 + max(float(distance), 0.0))

    def _rerank_hybrid(self, query: str, ranked: List[Tuple], alpha: float) -> List[Tuple]:
        """Re-order ``ranked`` (doc, similarity) pairs by fused semantic + keyword score.

        Candidates are aligned with the keyword index through their
        ``chunk_id``. If any candidate lacks a usable ``chunk_id`` the
        semantic ordering is returned unchanged.
        """
        corpus_size = len(self.documents)
        by_chunk = {}
        semantic_scores = [0.0] * corpus_size
        for doc, similarity in ranked:
            chunk_id = doc.metadata.get("chunk_id")
            valid = (
                isinstance(chunk_id, int)
                and not isinstance(chunk_id, bool)
                and 0 <= chunk_id < corpus_size
            )
            if not valid:
                return ranked
            by_chunk[chunk_id] = doc
            semantic_scores[chunk_id] = similarity

        # Ask for the full ranking, then keep only the semantic candidates:
        # documents that matched on keywords alone have no Document object here.
        fused = self.hybrid_search_engine.hybrid_search(
            query, semantic_scores, alpha=alpha, top_k=corpus_size
        )
        return [(by_chunk[idx], score) for idx, score in fused if idx in by_chunk]

    def retrieve_relevant_context(self, query: str, k: int = 3,
                                  use_hybrid: bool = True, alpha: float = 0.7) -> Tuple[str, List[Dict]]:
        """Retrieve relevant context using enhanced search.

        Args:
            query: The user's question.
            k: Maximum number of chunks to return.
            use_hybrid: Re-rank with keyword scores when a keyword index exists.
            alpha: Weight of the semantic component in hybrid re-ranking.

        Returns:
            ``(context, sources)``: the chunk texts joined by blank lines, and
            one metadata dict per chunk. ``relevance_score`` is higher-is-better:
            the hybrid score when re-ranking ran, otherwise ``1 / (1 + distance)``.
            Returns ``("", [])`` when nothing is loaded or on error.
        """
        try:
            if not self.documents_loaded:
                return "", []

            # The vector store reports a distance (lower is better). Keep the
            # best distance seen for each chunk across all expanded queries.
            best_by_chunk = {}
            for expanded_query in self.query_expander.expand_query(query):
                for doc, distance in self.vs.similarity_search_with_score(expanded_query, k=k):
                    key = doc.metadata.get('chunk_id')
                    if key not in best_by_chunk or distance < best_by_chunk[key][1]:
                        best_by_chunk[key] = (doc, distance)

            # Convert to similarities and keep the k closest chunks.
            sorted_results = sorted(
                ((doc, self._distance_to_similarity(distance))
                 for doc, distance in best_by_chunk.values()),
                key=lambda pair: pair[1],
                reverse=True,
            )[:k]

            if use_hybrid and self.hybrid_search_engine and sorted_results:
                sorted_results = self._rerank_hybrid(query, sorted_results, alpha)

            # Format results
            context = "\n\n".join(doc.page_content for doc, _ in sorted_results)
            sources = [
                {
                    "source": doc.metadata.get("source", "Unknown"),
                    "chunk_id": doc.metadata.get("chunk_id", "Unknown"),
                    "content_type": doc.metadata.get("content_type", "general"),
                    "relevance_score": float(score),
                    "chunk_index": doc.metadata.get("chunk_index", 0),
                    "total_chunks": doc.metadata.get("total_chunks", 1)
                }
                for doc, score in sorted_results
            ]
            return context, sources
        except Exception as e:
            print(f"❌ Error retrieving context: {e}")
            return "", []

    def get_knowledge_base_stats(self) -> Dict:
        """Get comprehensive statistics about the loaded knowledge base.

        Returns:
            A dict with ``status`` and ``count``; when documents are loaded it
            also carries ``embedding_model``, ``content_types`` (a count per
            type), ``hybrid_search`` and ``query_expansion``.
        """
        try:
            if not self.documents_loaded:
                return {"status": "No documents loaded", "count": 0}

            collection = self._get_collection()
            count = collection.count() if collection is not None else "Unknown"

            # Analyze content types
            content_types = {}
            if collection is not None:
                try:
                    # Only metadata is needed; skip document text and embeddings.
                    results = collection.get(include=["metadatas"])
                    for metadata in (results or {}).get('metadatas') or []:
                        if metadata and 'content_type' in metadata:
                            content_type = metadata['content_type']
                            content_types[content_type] = content_types.get(content_type, 0) + 1
                except Exception as e:
                    print(f"Warning: Could not analyze content types: {e}")

            return {
                "status": "Documents loaded",
                "count": count,
                "embedding_model": self.emb.model_name,
                "content_types": content_types,
                "hybrid_search": self.hybrid_search_engine is not None,
                "query_expansion": True
            }
        except Exception as e:
            return {"status": f"Error: {e}", "count": 0}

    def search_similar_documents(self, query: str, k: int = 5) -> List[Dict]:
        """Search for similar documents with enhanced metadata.

        Returns the ``sources`` list of ``retrieve_relevant_context``, using
        hybrid re-ranking whenever a keyword index exists. Empty when
        nothing is loaded or on error.
        """
        try:
            if not self.documents_loaded:
                return []
            _, sources = self.retrieve_relevant_context(
                query, k=k, use_hybrid=self.hybrid_search_engine is not None
            )
            return sources
        except Exception as e:
            print(f"❌ Error searching documents: {e}")
            return []

    def get_document_summary(self, source: str) -> Dict:
        """Get summary information about a specific document.

        Args:
            source: The ``source`` metadata value (the file path string
                recorded at load time).

        Returns:
            A dict with ``source``, ``total_chunks``, ``content_types``,
            ``chunk_range`` and a rough ``estimated_pages``; an empty dict
            when nothing is loaded, no chunk matches, or an error occurs.
        """
        try:
            if not self.documents_loaded:
                return {}

            # Find all chunks from this source
            source_chunks = []
            collection = self._get_collection()
            if collection is not None:
                try:
                    results = collection.get(where={"source": source}, include=["metadatas"])
                    source_chunks = (results or {}).get('metadatas') or []
                except Exception as e:
                    print(f"Warning: Could not get source chunks: {e}")
            if not source_chunks:
                return {}

            # Analyze the document
            total_chunks = len(source_chunks)
            content_types = {
                chunk['content_type']
                for chunk in source_chunks
                if chunk and 'content_type' in chunk
            }
            chunk_indices = [
                chunk['chunk_index']
                for chunk in source_chunks
                if chunk and 'chunk_index' in chunk
            ]
            return {
                "source": source,
                "total_chunks": total_chunks,
                "content_types": list(content_types),
                "chunk_range": f"{min(chunk_indices)} - {max(chunk_indices)}" if chunk_indices else "Unknown",
                "estimated_pages": max(1, total_chunks // 3)  # Rough estimate
            }
        except Exception as e:
            print(f"❌ Error getting document summary: {e}")
            return {}
class SimpleRAG:
    """Backward compatibility wrapper for EnhancedRAG."""

    def __init__(self, persist_path: str | None = None, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Create the wrapped ``EnhancedRAG`` with the same arguments."""
        self.enhanced_rag = EnhancedRAG(persist_path, embedding_model)

    def add_documents(self, docs: List[str]):
        """Add documents to the knowledge base.

        Each non-blank string is stored as a single chunk in the wrapped
        system's vector store, appended to its in-memory text list, and the
        keyword index is rebuilt so ``retrieve`` can find it.

        Args:
            docs: Raw document texts. Blank and non-string entries are
                skipped; an empty list is a no-op.
        """
        texts = [text for text in (docs or []) if isinstance(text, str) and text.strip()]
        if not texts:
            return

        rag = self.enhanced_rag
        # Continue the chunk numbering of whatever is already held in memory
        # so ids stay aligned with positions in rag.documents.
        first_id = len(rag.documents)
        batch = [
            Document(
                page_content=text,
                metadata={
                    "source": "inline",
                    "chunk_id": first_id + offset,
                    "content_type": "general",
                    "chunk_index": 0,
                    "total_chunks": 1,
                },
            )
            for offset, text in enumerate(texts)
        ]
        rag.vs.add_documents(batch)
        rag.documents = list(rag.documents) + texts
        rag.hybrid_search_engine = HybridSearchEngine(rag.documents)
        rag.documents_loaded = True

    def retrieve(self, query: str, k: int = 3) -> str:
        """Retrieve relevant context.

        Returns only the context string from the wrapped system's
        ``retrieve_relevant_context``; the source metadata is discarded.
        """
        context, _ = self.enhanced_rag.retrieve_relevant_context(query, k)
        return context