import os
import re
import unicodedata
from pathlib import Path
from typing import List

from dotenv import load_dotenv
import google.generativeai as genai
from huggingface_hub import InferenceClient

load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)

vectordb = None
retriever = None
embeddings = None
rag_initialized = False
uploaded_documents = []
last_index_mtime = None

RAG_DATA_DIR = Path(__file__).resolve().parent.parent / "rag_data"
FAISS_INDEX_PATH = RAG_DATA_DIR / "faiss_index"
INSUFFICIENT_CONTEXT_MARKER = "i don't have enough information in the documents"
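# Environment variables read by this module (both optional):
#   GEMINI_API_KEY - enables the Gemini generation paths below.
#   HF_API_TOKEN   - passed to the Hugging Face InferenceClient fallback.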
def initialize_embeddings():
    """Initialize the multilingual embedding model."""
    global embeddings
    if embeddings is not None:
        return embeddings
    print("Loading multilingual embedding model...")
    from langchain_huggingface import HuggingFaceEmbeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        encode_kwargs={"normalize_embeddings": True},
    )
    print("Embedding model loaded.")
    return embeddings
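# Note: with normalize_embeddings=True the vectors are unit length, so the L2
# distances the FAISS index returns during retrieval decrease monotonically as
# cosine similarity increases.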
def clean_text(text: str) -> str:
    """Clean and normalize text for embedding."""
    if not isinstance(text, str) or not text.strip():
        return ""
    normalized_text = unicodedata.normalize("NFKC", text)
    # Keep whitespace (newlines and tabs are category Cc) so the collapse step
    # below can turn it into single spaces; drop symbol, control, format,
    # surrogate, and unassigned characters.
    cleaned_chars = [
        char for char in normalized_text
        if char.isspace()
        or unicodedata.category(char) not in ["So", "Cn", "Cc", "Cf", "Cs"]
    ]
    cleaned_text = "".join(cleaned_chars)
    cleaned_text = re.sub(r"\s+", " ", cleaned_text).strip()
    return cleaned_text
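# Illustrative example (not exercised anywhere in this module):
#   clean_text("Hello\n\n  world ✨")  ->  "Hello world"
# The sparkles character is Unicode category "So" and is dropped; the whitespace
# run is collapsed to a single space.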
def load_and_process_pdf(pdf_path: str) -> List:
    """Load a PDF and split it into Document chunks."""
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    print(f"Loading PDF: {pdf_path}")
    loader = PyPDFLoader(pdf_path)
    docs = loader.load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=80,
    )
    chunks = splitter.split_documents(docs)
    print(f"Loaded {len(docs)} pages, created {len(chunks)} chunks.")
    return chunks
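# RecursiveCharacterTextSplitter measures chunk_size/chunk_overlap in characters
# by default, so 300/80 yields small, overlapping chunks suited to precise retrieval.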
def create_vector_store(chunks: List) -> bool:
    """Create or update the FAISS vector store with document chunks."""
    global vectordb, retriever, rag_initialized
    from langchain_community.vectorstores import FAISS
    initialize_embeddings()
    texts = [doc.page_content for doc in chunks]
    metadatas = [doc.metadata for doc in chunks]
    processed_texts = []
    processed_metadatas = []
    for i, text in enumerate(texts):
        cleaned_text = clean_text(text)
        if cleaned_text:
            processed_texts.append(cleaned_text)
            processed_metadatas.append(metadatas[i])
    if not processed_texts:
        print("No valid texts after cleaning.")
        return False
    print(f"Processing {len(processed_texts)} text chunks for embedding...")
    if vectordb is None:
        vectordb = FAISS.from_texts(processed_texts, embeddings, metadatas=processed_metadatas)
    else:
        new_vectordb = FAISS.from_texts(processed_texts, embeddings, metadatas=processed_metadatas)
        vectordb.merge_from(new_vectordb)
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})
    rag_initialized = True
    save_vector_store()
    _sync_uploaded_documents()
    print("Vector store created/updated successfully.")
    return True
def save_vector_store():
    """Save the FAISS index to disk."""
    global vectordb, last_index_mtime
    if vectordb is None:
        return
    RAG_DATA_DIR.mkdir(parents=True, exist_ok=True)
    vectordb.save_local(str(FAISS_INDEX_PATH))
    last_index_mtime = _get_index_mtime()
    print(f"Vector store saved to {FAISS_INDEX_PATH}.")
def load_vector_store() -> bool:
    """Load the FAISS index from disk if it exists."""
    global vectordb, retriever, rag_initialized, last_index_mtime
    if not FAISS_INDEX_PATH.exists():
        return False
    try:
        from langchain_community.vectorstores import FAISS
        initialize_embeddings()
        vectordb = FAISS.load_local(
            str(FAISS_INDEX_PATH),
            embeddings,
            allow_dangerous_deserialization=True,
        )
        retriever = vectordb.as_retriever(search_kwargs={"k": 4})
        rag_initialized = True
        last_index_mtime = _get_index_mtime()
        _sync_uploaded_documents()
        print("Loaded existing vector store from disk.")
        return True
    except Exception as e:
        print(f"Failed to load vector store: {e}")
        return False
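# allow_dangerous_deserialization=True is required because LangChain's FAISS store
# pickles its docstore; it is acceptable here only because the index is written by
# this application itself, never downloaded from an untrusted source.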
def rag_answer(question: str) -> dict:
    """Answer a question using RAG: check the document index first, then fall back to Gemini/HF."""
    global retriever, vectordb, last_index_mtime
    result = {
        "answer": "",
        "source": "none",
        "context_found": False,
        "relevance_score": 0.0,
    }
    if FAISS_INDEX_PATH.exists():
        current_mtime = _get_index_mtime()
        if (not rag_initialized or retriever is None) or (
            current_mtime and last_index_mtime and current_mtime > last_index_mtime
        ):
            load_vector_store()
    if not rag_initialized or retriever is None:
        result["source"] = "gemini"
        result["answer"] = _ask_gemini_directly(question)
        return result
    docs_with_scores = vectordb.similarity_search_with_score(question, k=4)
    if not docs_with_scores:
        print(f"No documents found for question: {question}")
        result["source"] = "gemini"
        result["answer"] = _ask_gemini_directly(question)
        return result
    # FAISS returns L2 distances here, so a lower score means a closer match.
    best_score = docs_with_scores[0][1] if docs_with_scores else float("inf")
    result["relevance_score"] = float(best_score)
    print(f"\nQuestion: {question}")
    print(f"Retrieved {len(docs_with_scores)} documents:")
    for i, (doc, score) in enumerate(docs_with_scores):
        preview = doc.page_content[:100].replace("\n", " ")
        print(f" [{i + 1}] Score: {score:.3f} - {preview}...")
    print(f"Using RAG with relevance score: {best_score}")
    docs = [doc for doc, score in docs_with_scores]
    context = "\n\n".join([d.page_content for d in docs])
    result["context_found"] = True
    prompt = (
        "You are a helpful assistant. Answer the question based ONLY on the following "
        "context from the PDF document. If the context doesn't contain enough information "
        "to answer the question, say \"I don't have enough information in the documents to "
        "answer this question.\"\n\n"
        "Context from PDF:\n"
        f"{context}\n\n"
        f"Question: {question}\n\n"
        "Answer (in English):"
    )
    try:
        gemini_key = os.getenv("GEMINI_API_KEY")
        if gemini_key:
            try:
                model = genai.GenerativeModel("models/gemini-2.5-flash")
                response = model.generate_content(prompt)
                rag_answer_text = (response.text or "").strip()
                if _is_insufficient_context_answer(rag_answer_text):
                    print("RAG context not sufficient. Falling back to direct AI answer.")
                    result["answer"] = _ask_gemini_directly(question)
                    result["source"] = "gemini"
                    return result
                result["answer"] = rag_answer_text
                result["source"] = "rag"
                return result
            except Exception as gemini_error:
                error_msg = str(gemini_error)
                print(f"Gemini error in RAG: {error_msg[:200]}...")
                if "429" in error_msg or "quota" in error_msg.lower():
                    print("Gemini quota exceeded. Using Hugging Face for RAG.")
        print("Using Hugging Face for RAG answer...")
        rag_answer_text = _ask_huggingface_free(prompt).strip()
        if _is_insufficient_context_answer(rag_answer_text):
            print("RAG context not sufficient. Falling back to direct AI answer.")
            result["answer"] = _ask_gemini_directly(question)
            result["source"] = "gemini"
            return result
        result["answer"] = rag_answer_text
        result["source"] = "rag"
    except Exception as e:
        print(f"All RAG generation failed: {e}")
        result["answer"] = "Sorry, unable to generate answer. Please try again later."
        result["source"] = "error"
    return result
def _ask_huggingface_free(prompt: str) -> str:
    """Use the free Hugging Face Inference API with token if available."""
    hf_token = os.getenv("HF_API_TOKEN")
    try:
        client = InferenceClient(token=hf_token)
    except Exception as e:
        raise Exception(f"Failed to create Hugging Face client: {e}")
    messages = [{"role": "user", "content": prompt}]
    try:
        print("Calling Hugging Face API (Qwen2.5-72B-Instruct)...")
        response = client.chat_completion(
            messages=messages,
            model="Qwen/Qwen2.5-72B-Instruct",
            max_tokens=500,
            temperature=0.7,
        )
        return response.choices[0].message.content
    except Exception as e:
        error_str = str(e)
        print(f"Hugging Face primary model error: {e}")
        try:
            print("Trying backup model (Microsoft Phi-3)...")
            response = client.chat_completion(
                messages=messages,
                model="microsoft/Phi-3-mini-4k-instruct",
                max_tokens=500,
                temperature=0.7,
            )
            return response.choices[0].message.content
        except Exception as e2:
            print(f"Backup model also failed: {e2}")
            raise Exception(f"All HF models failed: {error_str}")
def _ask_gemini_directly(question: str) -> str:
    """Fallback: Ask Gemini directly without RAG context, with Hugging Face fallback."""
    prompt = (
        "Answer the following question helpfully and accurately:\n\n"
        f"Question: {question}\n\n"
        "Answer:"
    )
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key:
        try:
            model = genai.GenerativeModel("models/gemini-2.5-flash")
            response = model.generate_content(prompt)
            return response.text
        except Exception as gemini_error:
            error_msg = str(gemini_error)
            print(f"Gemini API error: {error_msg[:200]}...")
            if "429" in error_msg or "quota" in error_msg.lower():
                print("Gemini quota exceeded. Switching to Hugging Face.")
            else:
                print("Gemini error. Switching to Hugging Face.")
    else:
        print("No Gemini API key, using Hugging Face.")
    try:
        print("Using Hugging Face for direct answer...")
        return _ask_huggingface_free(prompt)
    except Exception as hf_error:
        print(f"Hugging Face error: {hf_error}")
        return (
            "Sorry, both AI services are unavailable. "
            f"Gemini quota exceeded, and Hugging Face error: {str(hf_error)}"
        )
def get_rag_status() -> dict:
    """Get the current status of the RAG system."""
    if not rag_initialized and FAISS_INDEX_PATH.exists():
        load_vector_store()
    _sync_uploaded_documents()
    return {
        "initialized": rag_initialized,
        "documents_count": len(uploaded_documents),
        "documents": uploaded_documents,
        "has_embeddings": embeddings is not None,
        "has_vector_store": vectordb is not None,
    }
def clear_rag_data():
    """Clear all RAG data."""
    global vectordb, retriever, rag_initialized, uploaded_documents, last_index_mtime
    vectordb = None
    retriever = None
    rag_initialized = False
    uploaded_documents = []
    last_index_mtime = None
    if FAISS_INDEX_PATH.exists():
        import shutil
        shutil.rmtree(FAISS_INDEX_PATH)
    print("RAG data cleared.")
    return True
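# Note: only the FAISS index directory is removed; any uploaded PDFs remain in
# RAG_DATA_DIR, so rebuild_vector_store_from_pdfs() can recreate the index later.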
def _get_index_mtime():
    """Return the modification time of the on-disk FAISS index file, or None if absent."""
    index_file = FAISS_INDEX_PATH / "index.faiss"
    if index_file.exists():
        return index_file.stat().st_mtime
    return None
def _is_insufficient_context_answer(answer_text: str) -> bool:
    """Detect an empty answer or the model's 'not enough information' reply."""
    if not answer_text:
        return True
    normalized = answer_text.strip().lower()
    return INSUFFICIENT_CONTEXT_MARKER in normalized
def _sync_uploaded_documents():
    """Refresh uploaded_documents from the PDFs currently present in RAG_DATA_DIR."""
    global uploaded_documents
    if not RAG_DATA_DIR.exists():
        uploaded_documents = []
        return
    uploaded_documents = sorted(
        [pdf.name for pdf in RAG_DATA_DIR.glob("*.pdf") if pdf.is_file()]
    )
def rebuild_vector_store_from_pdfs() -> bool:
    """Rebuild vector store from all PDFs in rag_data directory."""
    global vectordb, retriever, rag_initialized
    _sync_uploaded_documents()
    if not uploaded_documents:
        print("No PDFs found in rag_data to rebuild vector store.")
        return False
    initialize_embeddings()
    vectordb = None
    retriever = None
    rag_initialized = False
    all_chunks = []
    for filename in uploaded_documents:
        pdf_path = RAG_DATA_DIR / filename
        try:
            chunks = load_and_process_pdf(str(pdf_path))
            all_chunks.extend(chunks)
        except Exception as e:
            print(f"Skipping PDF '{filename}' due to processing error: {e}")
    if not all_chunks:
        print("No chunks generated from PDFs. Rebuild aborted.")
        return False
    success = create_vector_store(all_chunks)
    if success:
        print(f"Rebuilt vector store from {len(uploaded_documents)} PDF(s).")
    return success
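# Illustrative usage only (a minimal sketch; the module itself defines no entry
# point and is normally driven by an API layer). The question string below is a
# hypothetical example for a manual smoke test.
if __name__ == "__main__":
    # Load a previously saved index if one exists, otherwise try to rebuild it
    # from whatever PDFs are sitting in rag_data/.
    if not load_vector_store():
        rebuild_vector_store_from_pdfs()
    print(get_rag_status())
    sample = rag_answer("What topics do the uploaded documents cover?")
    print(f"[{sample['source']}] {sample['answer']}")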