Spaces:
Running
Running
| # rag_core.py - Chroma Cloud Integration | |
| import os | |
| import sys | |
| import numpy as np | |
| import json | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| import hashlib | |
| import requests | |
| import re | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import sent_tokenize, word_tokenize | |
| from nltk.stem import PorterStemmer | |
| from typing import List, Dict, Tuple | |
| import time | |
| from dotenv import load_dotenv | |
| import chromadb | |
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# --- ROBUST NLTK SETUP ---
# Point NLTK to the local 'nltk_data' directory if it exists.
# On Render, this is created during the build step by download_nltk.py
local_nltk_data_path = os.path.join(os.path.dirname(__file__), 'nltk_data')
if os.path.exists(local_nltk_data_path):
    nltk.data.path.insert(0, local_nltk_data_path)
# If nltk_data doesn't exist locally, NLTK will use default paths or download on-demand
# --- END SETUP ---

# Model configuration - matching app.py.
# Maps short UI model keys to OpenRouter model slugs.
# NOTE(review): several keys do not match the vendor of the slug they map to
# (e.g. 'gemini' -> a Gemma model) - presumably legacy UI keys; confirm.
MODEL_MAP = {
    'gemini': 'google/gemma-3-4b-it:free',
    'deepseek': 'google/gemma-3-27b-it:free',
    'qwen': 'mistralai/mistral-small-3.1-24b-instruct:free',
    'nvidia': 'nvidia/nemotron-nano-12b-v2-vl:free',
    'amazon': 'amazon/nova-2-lite-v1:free'
}

# Best → fallback order (OCR strength)
FALLBACK_ORDER = [
    'gemini',
    'deepseek',
    'qwen',
    'nvidia',
    'amazon'
]

# Chroma Cloud configuration - all three are required for RAG to be enabled.
CHROMA_TENANT = os.getenv("CHROMA_TENANT")
CHROMA_DATABASE = os.getenv("CHROMA_DATABASE")
CHROMA_API_KEY = os.getenv("CHROMA_API_KEY")

# Lazily populated by initialize_rag_system():
embedding_model = None   # SentenceTransformer once initialized
reranker_model = None    # CrossEncoder once initialized
chroma_client = None     # chromadb.CloudClient once connected

# Process-local cache of Chroma collections, keyed by collection name.
collections: Dict[str, chromadb.Collection] = {}
# In-memory keyword indexes: mode -> user_api_key -> index dict
# (with "documents", "vocabulary", "entities" sub-dicts).
keyword_indexes: Dict[str, Dict[str, Dict]] = {}

# Must match the embedding model's output dimension; also used to build
# placeholder vectors for non-searchable bookkeeping records.
EMBEDDING_DIM = 768
CHUNK_SIZE = 300      # target chunk length in characters
CHUNK_OVERLAP = 50    # words of overlap prepended from the previous chunk

# Track if RAG system is properly initialized
_rag_system_available = False
| # Initialize components | |
def initialize_rag_system():
    """
    Load the embedding and reranker models and connect to Chroma Cloud.

    Populates the module-level embedding_model, reranker_model and
    chroma_client, and sets _rag_system_available. Any failure (missing
    credentials, connection error, model-load error) disables RAG instead
    of raising, so the host app keeps working for OCR-only use.

    Returns True if successful, False otherwise.
    """
    global embedding_model, reranker_model, chroma_client, _rag_system_available
    print("RAG Core: Initializing Advanced RAG System with Chroma Cloud...")

    # All three Chroma Cloud credentials are required; bail out gracefully.
    if not (CHROMA_TENANT and CHROMA_DATABASE and CHROMA_API_KEY):
        print("WARNING: Chroma Cloud credentials not found. RAG system will be disabled.")
        print("   Set CHROMA_TENANT, CHROMA_DATABASE, and CHROMA_API_KEY to enable RAG.")
        _rag_system_available = False
        return False

    try:
        print("RAG Core: Connecting to Chroma Cloud...")
        chroma_client = chromadb.CloudClient(
            tenant=CHROMA_TENANT,
            database=CHROMA_DATABASE,
            api_key=CHROMA_API_KEY
        )
        print("RAG Core: Successfully connected to Chroma Cloud!")

        print("RAG Core: Loading advanced embedding model...")
        embedding_model = SentenceTransformer('all-mpnet-base-v2')
        print("RAG Core: Loading cross-encoder reranker...")
        reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        print("RAG Core: Advanced models loaded successfully.")

        _rag_system_available = True
        return True
    except Exception as e:
        print(f"ERROR: Failed to initialize RAG system: {e}")
        print("   RAG system will be disabled. The app will still work for OCR.")
        _rag_system_available = False
        return False
def is_rag_available():
    """Return True once initialize_rag_system() has completed successfully."""
    return _rag_system_available
def _call_openrouter_api_with_fallback(api_key, selected_model_key, prompt):
    """
    Call the OpenRouter chat-completions endpoint for a text-only prompt,
    trying the selected model first and then the rest of FALLBACK_ORDER.

    Returns the first successful completion string; if every model fails,
    returns a user-friendly error message describing the last failure
    (never raises).
    """
    # Start with the selected model, then try others in fallback order.
    models_to_try = [selected_model_key]
    models_to_try.extend(m for m in FALLBACK_ORDER if m != selected_model_key)

    last_error = None
    for model_key in models_to_try:
        model_name = MODEL_MAP.get(model_key)
        if not model_name:
            continue  # unknown key (e.g. stale client value) - skip silently
        print(f"RAG: Attempting API call with model: {model_name}...")
        try:
            response = requests.post(
                url="https://openrouter.ai/api/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": model_name,
                    "messages": [{"role": "user", "content": prompt}]
                },
                timeout=15  # fail fast so fallback models get a chance
            )
            response.raise_for_status()
            api_response = response.json()
            if 'choices' not in api_response or not api_response['choices']:
                print(f"RAG: Model {model_name} returned unexpected response format")
                last_error = f"Model {model_name} returned unexpected response format"
                continue
            result = api_response['choices'][0]['message']['content']
            print(f"RAG: Successfully processed with model: {model_name}")
            return result
        except requests.exceptions.HTTPError as http_err:
            # FIX: read the response from the exception itself instead of the
            # local `response` variable, which would be unbound (NameError) if
            # the error were raised before the assignment completed.
            err_response = getattr(http_err, 'response', None)
            error_msg = f"RAG: HTTP error for model {model_name}: {http_err}"
            if err_response is not None:
                error_msg += f"\nResponse: {err_response.text}"
                status = err_response.status_code
            else:
                status = "unknown"
            print(error_msg)
            last_error = f"API request failed for {model_name} with status {status}."
            continue
        except Exception as e:
            print(f"RAG: Error with model {model_name}: {e}")
            last_error = f"An unexpected error occurred with model {model_name}."
            continue
    # If all models failed, return a user-friendly error
    return f"I'm having trouble connecting to the AI models right now. Please check your API key and try again. Last error: {last_error}"
| def _get_collection_name(user_api_key, mode): | |
| """ | |
| Creates a unique collection name for a user based on a hash of their API key. | |
| """ | |
| user_hash = hashlib.sha256(user_api_key.encode()).hexdigest()[:16] | |
| return f"{user_hash}_{mode}" | |
def _get_or_create_collection(user_api_key, mode):
    """
    Return the cached Chroma collection for this user/mode, creating and
    caching it on first use; also hydrates the keyword index for it.
    """
    name = _get_collection_name(user_api_key, mode)
    cached = collections.get(name)
    if cached is not None:
        return cached

    print(f"RAG Core: Getting/creating collection '{name}' in Chroma Cloud")
    collection = chroma_client.get_or_create_collection(
        name=name,
        metadata={"hnsw:space": "cosine"}  # cosine distance for text embeddings
    )
    collections[name] = collection
    # One-time load of the persisted keyword index for this collection.
    _load_keyword_index(user_api_key, mode)
    return collection
def _load_keyword_index(user_api_key, mode):
    """
    Hydrate the in-memory keyword index for a user/mode from the special
    '__keyword_index__' record stored in the Chroma collection.

    Falls back to an empty index when nothing is stored, the collection is
    not cached yet, or loading fails. No-op if already loaded.
    """
    per_mode = keyword_indexes.setdefault(mode, {})
    if user_api_key in per_mode:
        return  # already hydrated for this process

    collection_name = _get_collection_name(user_api_key, mode)
    try:
        collection = collections.get(collection_name)
        if collection is None:
            per_mode[user_api_key] = {"documents": {}, "vocabulary": {}, "entities": {}}
            return
        stored = collection.get(
            ids=["__keyword_index__"],
            include=["documents"]
        )
        if stored and stored['documents'] and stored['documents'][0]:
            per_mode[user_api_key] = json.loads(stored['documents'][0])
            print("RAG Core: Loaded keyword index from Chroma Cloud")
        else:
            per_mode[user_api_key] = {"documents": {}, "vocabulary": {}, "entities": {}}
    except Exception as e:
        print(f"RAG Core: Could not load keyword index: {e}")
        per_mode[user_api_key] = {"documents": {}, "vocabulary": {}, "entities": {}}
def _save_keyword_index(user_api_key, mode):
    """
    Persist the in-memory keyword index for a user/mode into its Chroma
    collection as the special '__keyword_index__' record.

    Silently does nothing when the collection or index is not loaded;
    upsert failures are logged, not raised.
    """
    collection = collections.get(_get_collection_name(user_api_key, mode))
    index = keyword_indexes.get(mode, {}).get(user_api_key)
    if collection is None or index is None:
        return

    payload = json.dumps(index)
    try:
        collection.upsert(
            ids=["__keyword_index__"],
            documents=[payload],
            metadatas=[{"type": "keyword_index"}]
        )
        print("RAG Core: Saved keyword index to Chroma Cloud")
    except Exception as e:
        print(f"RAG Core: Error saving keyword index: {e}")
def _smart_chunking(text, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP):
    """
    Split text into roughly chunk_size-character pieces while respecting
    paragraph and sentence boundaries, then prepend a word-level overlap
    from the previous chunk so context carries across chunk boundaries.

    Returns a list of chunk strings; [] for empty or non-string input.
    """
    if not isinstance(text, str) or not text.strip():
        return []

    paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]

    raw_chunks = []
    pending = ""
    for para in paragraphs:
        # Paragraph still fits into the chunk being accumulated.
        if len(pending) + len(para) <= chunk_size:
            pending = f"{pending}\n\n{para}" if pending else para
            continue
        # Flush what we have before starting a new chunk.
        if pending:
            raw_chunks.append(pending.strip())
        if len(para) <= chunk_size:
            pending = para
            continue
        # Paragraph alone exceeds chunk_size: pack whole sentences instead.
        pending = ""
        for sentence in nltk.sent_tokenize(para):
            if len(pending) + len(sentence) <= chunk_size:
                pending = f"{pending} {sentence}" if pending else sentence
            else:
                if pending:
                    raw_chunks.append(pending.strip())
                pending = sentence
    if pending:
        raw_chunks.append(pending.strip())

    # Second pass: prepend the last `chunk_overlap` words of the previous
    # chunk so adjacent chunks share context.
    overlapped = []
    for idx, chunk in enumerate(raw_chunks):
        if idx and chunk_overlap > 0:
            tail = raw_chunks[idx - 1].split()[-chunk_overlap:]
            if tail:
                chunk = " ".join(tail) + " " + chunk
        overlapped.append(chunk)
    return overlapped
| def _enhanced_query_expansion(query: str) -> List[str]: | |
| """ | |
| Advanced query expansion with business context awareness. | |
| """ | |
| query_lower = query.lower() | |
| expanded_queries = {query} | |
| business_expansions = { | |
| r"\bgeneral manager\b": ["GM", "manager", "head", "director", "chief"], | |
| r"\bCEO\b": ["chief executive officer", "president", "director"], | |
| r"\bCFO\b": ["chief financial officer", "finance director"], | |
| r"\blocation\b": ["address", "located", "office", "headquarters", "branch"], | |
| r"\boffice\b": ["location", "branch", "headquarters", "situated"], | |
| r"\bservices\b": ["offerings", "products", "solutions", "business"], | |
| r"\bcompany\b": ["business", "organization", "firm", "corporation", "enterprise"], | |
| r"\bcontact\b": ["reach", "get in touch", "communicate"], | |
| r"\bbranch\b": ["office", "location", "division", "subsidiary"], | |
| r"\bheadquarters\b": ["main office", "head office", "corporate office"], | |
| } | |
| location_patterns = { | |
| r"\bhong\s*kong\b": ["HK", "hongkong"], | |
| r"\bsingapore\b": ["SG", "sing"], | |
| r"\bunited\s*states\b": ["USA", "US", "America"], | |
| r"\bunited\s*kingdom\b": ["UK", "Britain"], | |
| } | |
| for pattern, replacements in business_expansions.items(): | |
| if re.search(pattern, query_lower): | |
| for replacement in replacements: | |
| expanded_query = re.sub(pattern, replacement, query, flags=re.IGNORECASE) | |
| expanded_queries.add(expanded_query) | |
| for pattern, replacements in location_patterns.items(): | |
| if re.search(pattern, query_lower): | |
| for replacement in replacements: | |
| expanded_query = re.sub(pattern, replacement, query, flags=re.IGNORECASE) | |
| expanded_queries.add(expanded_query) | |
| return list(expanded_queries) | |
def _build_enhanced_keyword_index(text, doc_id, user_api_key, mode):
    """
    Index one document for keyword search in the user's in-memory index:
    stem and record every non-stopword token, extract capitalized business
    entity / street-address phrases, and store per-document stats.
    """
    if not isinstance(text, str) or not text.strip():
        return

    per_mode = keyword_indexes.setdefault(mode, {})
    index = per_mode.setdefault(
        user_api_key, {"documents": {}, "vocabulary": {}, "entities": {}}
    )

    tokens = re.findall(r'\b[a-zA-Z]{2,}\b', text.lower())
    stop_words = set(stopwords.words('english'))
    stemmer = PorterStemmer()

    # Capitalized phrases ending in a company suffix or street/building word.
    company_names = re.findall(r'\b[A-Z][a-zA-Z&\s]{1,30}(?:Ltd|Inc|Corp|Company|Group|Holdings|Limited|Corporation|Enterprise|Solutions)\b', text)
    street_refs = re.findall(r'\b[A-Z][a-zA-Z\s]{2,20}(?:Street|Road|Avenue|Lane|Drive|Plaza|Square|Center|Centre|Building|Tower|Floor)\b', text)

    # Posting list: stemmed token -> list of doc ids containing it.
    for token in tokens:
        if token in stop_words and True or len(token) <= 2:
            if token in stop_words or len(token) <= 2:
                continue
        stem = stemmer.stem(token)
        docs_for_stem = index["vocabulary"].setdefault(stem, [])
        if doc_id not in docs_for_stem:
            docs_for_stem.append(doc_id)

    # Posting list: lowercased entity phrase -> list of doc ids.
    entity_map = index.setdefault("entities", {})
    for phrase in company_names + street_refs:
        docs_for_entity = entity_map.setdefault(phrase.lower(), [])
        if doc_id not in docs_for_entity:
            docs_for_entity.append(doc_id)

    # Per-document stats used for score normalization and display.
    index["documents"][doc_id] = {
        "text": text,
        "length": len(text),
        "word_count": len(tokens),
        "entities": company_names + street_refs
    }
def _enhanced_keyword_search(query, user_api_key, mode, top_k=10):
    """
    Enhanced keyword search with business context awareness.

    Scores documents by stemmed-term overlap (1.0 per matching term) plus a
    2.0 bonus per matching entity phrase, then normalizes by document
    length so long documents don't dominate. Returns up to top_k document
    ids, best first; [] when no index exists for this user/mode.
    """
    if mode not in keyword_indexes or user_api_key not in keyword_indexes[mode]:
        return []
    keyword_index = keyword_indexes[mode][user_api_key]
    ps = PorterStemmer()

    # PERF FIX: build the stopword set once. The original called
    # stopwords.words('english') (a list, O(n) membership) for every term.
    stop_words = set(stopwords.words('english'))
    query_terms = [ps.stem(term) for term in query.lower().split()
                   if term not in stop_words and len(term) > 2]

    # Documents whose extracted entity phrases contain any raw query word.
    entity_matches = []
    if "entities" in keyword_index:
        for entity, docs in keyword_index["entities"].items():
            if any(term in entity for term in query.lower().split()):
                entity_matches.extend(docs)

    # Term-overlap score: 1.0 per stemmed query term present in the doc.
    doc_scores: Dict[str, float] = {}
    for term in query_terms:
        if term in keyword_index.get("vocabulary", {}):
            for doc_id in keyword_index["vocabulary"][term]:
                doc_scores[doc_id] = doc_scores.get(doc_id, 0) + 1.0
    # Entity hits are stronger signals: +2.0 each.
    for doc_id in entity_matches:
        doc_scores[doc_id] = doc_scores.get(doc_id, 0) + 2.0

    # Length normalization: dampen scores of very long documents.
    final_scores = {}
    for doc_id, score in doc_scores.items():
        if doc_id in keyword_index.get("documents", {}):
            doc_length = keyword_index["documents"][doc_id].get("word_count", 1)
            final_scores[doc_id] = score / (1 + np.log(1 + doc_length))

    sorted_docs = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
    return [doc_id for doc_id, score in sorted_docs]
def add_document_to_knowledge_base(user_api_key, document_text, document_id, mode):
    """
    Chunk, keyword-index, embed, and upsert a document into the user's
    Chroma Cloud collection, then persist the updated keyword index.

    Raises: re-raises any underlying failure after logging a traceback so
    the caller can surface the error to the user.
    """
    try:
        print(f"\nRAG: Adding document '{document_id}' to Chroma Cloud...")
        collection = _get_or_create_collection(user_api_key, mode)

        chunks = _smart_chunking(document_text)
        print(f"RAG: Created {len(chunks)} intelligent chunks")

        _build_enhanced_keyword_index(document_text, document_id, user_api_key, mode)
        print("RAG: Built enhanced keyword index")

        if not chunks:
            # Nothing to embed (e.g. empty text) - still persist keywords.
            print("RAG: No chunks to vectorize, saving keyword index only")
            _save_keyword_index(user_api_key, mode)
            return

        vectors = embedding_model.encode(chunks, normalize_embeddings=True)
        print("RAG: Generated embeddings")

        # One record per chunk, keyed by source document and position.
        chunk_ids = []
        chunk_metadatas = []
        for position, chunk in enumerate(chunks):
            chunk_ids.append(f"{document_id}_chunk_{position}")
            chunk_metadatas.append({
                "source_doc": document_id,
                "chunk_id": position,
                "length": len(chunk),
                "type": "document_chunk"
            })

        collection.upsert(
            ids=chunk_ids,
            embeddings=vectors.tolist(),
            documents=chunks,
            metadatas=chunk_metadatas
        )

        _save_keyword_index(user_api_key, mode)
        print(f"RAG: Successfully indexed document to Chroma Cloud. Total chunks: {len(chunks)}")
    except Exception as e:
        print(f"CRITICAL ERROR in add_document_to_knowledge_base: {e}")
        import traceback
        traceback.print_exc()
        raise e
def remove_document_from_knowledge_base(user_api_key, document_id, mode):
    """
    Delete every chunk of a document from Chroma Cloud and scrub the
    document from the in-memory keyword index (vocabulary postings,
    entity postings, per-document stats), persisting the updated index.
    Errors are logged, never raised.
    """
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        # Every chunk carries its source document id in metadata.
        collection.delete(
            where={"source_doc": document_id}
        )

        index = keyword_indexes.get(mode, {}).get(user_api_key)
        if index is not None:
            def _scrub(postings):
                # Drop document_id from every posting list; delete empty lists.
                for key in list(postings):
                    if document_id in postings[key]:
                        postings[key].remove(document_id)
                        if not postings[key]:
                            del postings[key]

            if "vocabulary" in index:
                _scrub(index["vocabulary"])
            if "entities" in index:
                _scrub(index["entities"])
            if "documents" in index and document_id in index["documents"]:
                del index["documents"][document_id]
            _save_keyword_index(user_api_key, mode)
        print(f"RAG: Removed document '{document_id}' from Chroma Cloud")
    except Exception as e:
        print(f"Error removing document: {e}")
        import traceback
        traceback.print_exc()
def _advanced_hybrid_search(query, user_api_key, mode, top_k=10):
    """
    Hybrid retrieval over Chroma Cloud: vector search across a few expanded
    query variants, plus a +0.4 score boost for chunks whose source
    document was also matched by the keyword index.

    Returns up to top_k dicts {text, source_doc, chunk_id, length, score},
    best first; [] when the collection is empty or unreachable.
    """
    collection = _get_or_create_collection(user_api_key, mode)
    # Treat any count failure as an empty collection.
    try:
        count = collection.count()
        if count == 0:
            return []
    except Exception:  # FIX: was a bare except (also caught SystemExit etc.)
        return []

    # Vector search with Chroma Cloud
    expanded_queries = _enhanced_query_expansion(query)
    all_results = {}
    for q in expanded_queries[:3]:  # Limit to avoid too much noise
        query_embedding = embedding_model.encode([q], normalize_embeddings=True)
        try:
            results = collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=min(top_k * 2, count),
                where={"type": "document_chunk"},
                include=["documents", "metadatas", "distances"]
            )
            if results and results['ids'] and results['ids'][0]:
                for doc_id, doc, metadata, distance in zip(
                    results['ids'][0],
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0]
                ):
                    # Convert cosine distance to a similarity score.
                    # FIX: the old `if distance` truthiness test scored a
                    # perfect match (distance == 0.0) as 0 instead of 1.
                    score = 1 - distance if distance is not None else 0
                    # Keep the best score seen for each chunk across variants.
                    if doc_id not in all_results or all_results[doc_id]['score'] < score:
                        all_results[doc_id] = {
                            'text': doc,
                            'source_doc': metadata.get('source_doc', ''),
                            'chunk_id': metadata.get('chunk_id', 0),
                            'length': metadata.get('length', 0),
                            'score': score
                        }
        except Exception as e:
            print(f"RAG: Search error: {e}")
            continue

    # Enhanced keyword search boost
    keyword_doc_ids = set(_enhanced_keyword_search(query, user_api_key, mode, top_k=top_k*2))
    for result in all_results.values():
        if result.get('source_doc') in keyword_doc_ids:
            result['score'] = result.get('score', 0) + 0.4

    # Sort and return top results
    ranked = sorted(all_results.values(), key=lambda r: r['score'], reverse=True)
    return ranked[:top_k]
def _intelligent_rerank(query, candidate_chunks, top_k=5):
    """
    Rerank candidate chunks with the cross-encoder plus small heuristic
    bonuses for chunks that look information-complete (location, people,
    and organization markers). Returns the top_k chunks by combined score;
    falls back to the first top_k candidates when no reranker is loaded.
    """
    if not candidate_chunks or not reranker_model:
        return candidate_chunks[:top_k]

    # Cross-encoder relevance score for each (query, chunk) pair.
    pairs = [(query, chunk["text"]) for chunk in candidate_chunks]
    cross_encoder_scores = reranker_model.predict(pairs)

    location_markers = ("located", "address", "office", "branch")
    people_markers = ("manager", "director", "ceo", "head")
    org_markers = ("company", "business", "organization")

    enhanced_scores = []
    for chunk, ce_score in zip(candidate_chunks, cross_encoder_scores):
        # PERF: lower-case once per chunk (was recomputed three times);
        # also dropped the unused enumerate index.
        text_lower = chunk["text"].lower()
        completeness_bonus = 0
        if any(marker in text_lower for marker in location_markers):
            completeness_bonus += 0.1
        if any(marker in text_lower for marker in people_markers):
            completeness_bonus += 0.1
        if any(marker in text_lower for marker in org_markers):
            completeness_bonus += 0.05
        enhanced_scores.append((chunk, ce_score + completeness_bonus))

    # Sort by combined score and keep the best top_k chunks.
    reranked = sorted(enhanced_scores, key=lambda x: x[1], reverse=True)
    return [chunk for chunk, score in reranked[:top_k]]
def query_knowledge_base(user_api_key, query_text, mode, selected_model_key):
    """
    Answer a user question from the indexed documents.

    Pipeline: expand the query, run hybrid (vector + keyword) retrieval,
    rerank with the cross-encoder, then ask the selected OpenRouter model
    (with fallback) to answer from the retrieved context. Always returns a
    user-facing string, never raises.
    """
    collection = _get_or_create_collection(user_api_key, mode)
    try:
        count = collection.count()
        # count includes the '__keyword_index__' bookkeeping record, so a
        # collection holding only that entry is effectively empty.
        if count <= 1:
            return "I don't have any documents in my knowledge base yet. Please upload some brochures or business cards first, and I'll be happy to help you find information from them!"
    except Exception:  # FIX: was a bare except (also caught SystemExit etc.)
        return "I don't have any documents in my knowledge base yet. Please upload some brochures or business cards first, and I'll be happy to help you find information from them!"

    print(f"RAG: Processing query: '{query_text}' with model: {selected_model_key}")

    # Optimized search - use only 2 query variations for speed
    expanded_queries = _enhanced_query_expansion(query_text)
    print(f"RAG: Expanded to {len(expanded_queries)} query variations")

    all_candidates = []
    seen_texts = set()
    for query in expanded_queries[:2]:  # Reduced from 3 to 2 for speed
        candidates = _advanced_hybrid_search(query, user_api_key, mode, top_k=5)  # Reduced from 8 to 5
        for candidate in candidates:
            text = candidate.get('text', '')
            # De-duplicate identical chunks found under different variants.
            if text and text not in seen_texts:
                seen_texts.add(text)
                all_candidates.append(candidate)

    # Intelligent reranking - reduced to 3 chunks for faster LLM response
    top_chunks = _intelligent_rerank(query_text, all_candidates, top_k=3)
    if not top_chunks:
        return f"I couldn't find specific information about '{query_text}' in the uploaded documents. Could you try rephrasing your question or check if the information might be in a document that hasn't been uploaded yet?"

    # Prepare context for AI model
    context = "\n\n---DOCUMENT SECTION---\n\n".join([chunk["text"] for chunk in top_chunks])
    print(f"RAG: Found {len(top_chunks)} relevant sections. Generating response with {selected_model_key}...")

    try:
        prompt = f"""You are a world-class AI assistant providing beautifully formatted, accurate answers based on document data.
**FORMATTING RULES (CRITICAL):**
- Use **bold** for names, companies, and important terms
- Use bullet points (•) for lists of items
- Use numbered lists (1. 2. 3.) for steps or rankings
- Keep responses concise but complete - aim for 2-4 sentences unless more detail is needed
- Structure longer responses with clear sections
- For contact info, format cleanly: **Name** - email@example.com - +1234567890
**ACCURACY RULES:**
- Only state facts found in the documents below
- Be direct and specific - give the exact answer first, then context
- If asked "who is X" or "what is X's role", lead with the answer immediately
**USER QUESTION:** {query_text}
**DOCUMENT DATA:**
{context}
**YOUR RESPONSE (formatted beautifully with markdown):**"""
        response = _call_openrouter_api_with_fallback(user_api_key, selected_model_key, prompt)
        return response
    except Exception as e:
        print(f"RAG: An unexpected error occurred during response generation: {e}")
        import traceback
        traceback.print_exc()
        return "I found relevant information but ran into an unexpected error while processing it. Please try again."
| # ============================================ | |
| # METADATA PERSISTENCE FUNCTIONS | |
| # ============================================ | |
def save_metadata_to_chroma(user_api_key, mode, document_id, metadata_dict):
    """
    Persist contact/brochure metadata in ChromaDB so it survives restarts.

    The dict is serialized to JSON and upserted under the id
    '__metadata__<document_id>' with a zero vector (the collection requires
    embeddings, but these records are never vector-searched).

    Returns True on success; False when RAG is unavailable or the write fails.
    """
    if not _rag_system_available:
        print("RAG: System not available, cannot save metadata to ChromaDB")
        return False
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        record_id = f"__metadata__{document_id}"
        serialized = json.dumps(metadata_dict, ensure_ascii=False)
        collection.upsert(
            ids=[record_id],
            documents=[serialized],
            embeddings=[[0.0] * EMBEDDING_DIM],  # placeholder vector, never queried
            metadatas=[{
                "type": "metadata",
                "mode": mode,
                "document_id": document_id,
                "timestamp": str(time.time())
            }]
        )
        print(f"RAG: Saved metadata for {document_id} to ChromaDB")
        return True
    except Exception as e:
        print(f"RAG: Error saving metadata to ChromaDB: {e}")
        return False
def load_all_metadata_from_chroma(user_api_key, mode):
    """
    Fetch every persisted metadata record for this user/mode.

    Returns a list of metadata dicts sorted newest-first by the optional
    '_timestamp' field; [] when RAG is unavailable, nothing is stored, or
    the read fails.
    """
    if not _rag_system_available:
        print("RAG: System not available, cannot load metadata from ChromaDB")
        return []
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        stored = collection.get(
            where={"type": "metadata"},
            include=["documents", "metadatas"]
        )
        if not stored or not stored['documents']:
            return []

        records = []
        for payload, record_meta in zip(stored['documents'], stored['metadatas']):
            if not payload or record_meta.get('type') != 'metadata':
                continue
            try:
                records.append(json.loads(payload))
            except json.JSONDecodeError:
                continue  # skip corrupt records rather than failing the load

        # Newest first, when records carry a '_timestamp' field.
        records.sort(key=lambda item: item.get('_timestamp', 0), reverse=True)
        print(f"RAG: Loaded {len(records)} metadata records from ChromaDB for {mode}")
        return records
    except Exception as e:
        print(f"RAG: Error loading metadata from ChromaDB: {e}")
        return []
def delete_metadata_from_chroma(user_api_key, mode, document_id):
    """
    Remove the persisted metadata record for a single document.

    Returns True on success; False when RAG is unavailable or deletion fails.
    """
    if not _rag_system_available:
        return False
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        collection.delete(ids=[f"__metadata__{document_id}"])
        print(f"RAG: Deleted metadata for {document_id} from ChromaDB")
        return True
    except Exception as e:
        print(f"RAG: Error deleting metadata from ChromaDB: {e}")
        return False
def delete_all_metadata_from_chroma(user_api_key, mode):
    """
    Delete every persisted metadata record for a user/mode (the
    'delete all' feature).

    Returns the number of records removed; 0 when RAG is unavailable,
    nothing is stored, or an error occurs.
    """
    if not _rag_system_available:
        print("RAG: System not available, cannot delete metadata from ChromaDB")
        return 0
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        stored = collection.get(
            where={"type": "metadata"},
            include=["metadatas"]
        )
        ids_to_delete = stored['ids'] if stored else None
        if not ids_to_delete:
            print(f"RAG: No metadata to delete for {mode}")
            return 0
        collection.delete(ids=ids_to_delete)
        print(f"RAG: Deleted {len(ids_to_delete)} metadata records from ChromaDB for {mode}")
        return len(ids_to_delete)
    except Exception as e:
        print(f"RAG: Error deleting all metadata from ChromaDB: {e}")
        import traceback
        traceback.print_exc()
        return 0
def delete_all_documents_from_chroma(user_api_key, mode):
    """
    Delete every RAG document chunk for a user/mode and reset the
    in-memory keyword index (persisting the now-empty index).

    Returns the number of chunks removed; 0 when RAG is unavailable,
    nothing is stored, or an error occurs.
    """
    if not _rag_system_available:
        print("RAG: System not available, cannot delete documents from ChromaDB")
        return 0
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        stored = collection.get(
            where={"type": "document_chunk"},
            include=["metadatas"]
        )
        chunk_ids = stored['ids'] if stored else None
        if not chunk_ids:
            print(f"RAG: No document chunks to delete for {mode}")
            return 0
        collection.delete(ids=chunk_ids)
        print(f"RAG: Deleted {len(chunk_ids)} document chunks from ChromaDB for {mode}")

        # Reset and persist the keyword index so stale terms don't linger.
        if mode in keyword_indexes and user_api_key in keyword_indexes[mode]:
            keyword_indexes[mode][user_api_key] = {"documents": {}, "vocabulary": {}, "entities": {}}
            _save_keyword_index(user_api_key, mode)
        return len(chunk_ids)
    except Exception as e:
        print(f"RAG: Error deleting all documents from ChromaDB: {e}")
        import traceback
        traceback.print_exc()
        return 0
def update_metadata_in_chroma(user_api_key, mode, document_id, field, value, contact_id=None):
    """
    Update one field of a persisted metadata record and write it back.

    'cards' mode sets the field on the top-level record; 'brochures' mode
    (when contact_id is given) sets it on the matching entry of the
    record's 'contacts' list.

    Returns True on success, False otherwise.
    """
    if not _rag_system_available:
        return False
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        record_id = f"__metadata__{document_id}"
        stored = collection.get(ids=[record_id], include=["documents"])
        if not stored or not stored['documents'] or not stored['documents'][0]:
            print(f"RAG: Metadata not found for {document_id}")
            return False

        record = json.loads(stored['documents'][0])
        if mode == 'cards':
            record[field] = value
        elif mode == 'brochures' and contact_id:
            # Update only the first contact whose id matches.
            for entry in record.get('contacts', []):
                if entry.get('id') == contact_id:
                    entry[field] = value
                    break

        # Re-persist through the normal save path.
        return save_metadata_to_chroma(user_api_key, mode, document_id, record)
    except Exception as e:
        print(f"RAG: Error updating metadata in ChromaDB: {e}")
        return False
| # ============================================ | |
| # CHAT MEMORY FUNCTIONS | |
| # ============================================ | |
def save_chat_message(user_api_key, mode, role, content):
    """
    Persist one chat turn to ChromaDB so conversations survive restarts.

    role should be 'user' or 'assistant'. Messages get a millisecond
    timestamp id and a zero placeholder embedding (never vector-searched).

    Returns True on success; False when RAG is unavailable or the write fails.
    """
    if not _rag_system_available:
        return False
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        # Millisecond timestamp keeps ids unique and roughly ordered.
        message_id = f"__chat__{mode}_{int(time.time() * 1000)}"
        payload = json.dumps(
            {"role": role, "content": content, "timestamp": time.time()},
            ensure_ascii=False
        )
        collection.upsert(
            ids=[message_id],
            documents=[payload],
            embeddings=[[0.0] * EMBEDDING_DIM],  # placeholder vector, never queried
            metadatas=[{
                "type": "chat_message",
                "mode": mode,
                "role": role,
                "timestamp": str(time.time())
            }]
        )
        return True
    except Exception as e:
        print(f"RAG: Error saving chat message: {e}")
        return False
def get_chat_history(user_api_key, mode, limit=10):
    """
    Return up to `limit` most recent chat messages for a user/mode as
    {role, content, timestamp} dicts, oldest first (conversation order).
    Returns [] when RAG is unavailable or on any read error.
    """
    if not _rag_system_available:
        return []
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        stored = collection.get(
            where={"type": "chat_message"},
            include=["documents", "metadatas"]
        )
        if not stored or not stored['documents']:
            return []

        messages = []
        for payload, msg_meta in zip(stored['documents'], stored['metadatas']):
            if not payload or msg_meta.get('type') != 'chat_message':
                continue
            try:
                messages.append(json.loads(payload))
            except json.JSONDecodeError:
                continue  # skip corrupt records

        # Chronological order, then keep only the most recent window.
        messages.sort(key=lambda m: m.get('timestamp', 0))
        return messages[-limit:]
    except Exception as e:
        print(f"RAG: Error loading chat history: {e}")
        return []
def clear_chat_history(user_api_key, mode):
    """
    Delete every stored chat message for a user/mode.

    Returns True on success (including when there was nothing to delete);
    False when RAG is unavailable or an error occurs.
    """
    if not _rag_system_available:
        return False
    try:
        collection = _get_or_create_collection(user_api_key, mode)
        stored = collection.get(
            where={"type": "chat_message"},
            include=["metadatas"]
        )
        if stored and stored['ids']:
            collection.delete(ids=stored['ids'])
            print(f"RAG: Cleared {len(stored['ids'])} chat messages for {mode}")
        return True
    except Exception as e:
        print(f"RAG: Error clearing chat history: {e}")
        return False