Spaces:
Sleeping
Sleeping
import functools
import hashlib
import json
import os
import traceback
from typing import Optional
from typing import List, Tuple

from langchain_chroma import Chroma
from langchain_classic.chains import create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.retrievers import BM25Retriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .core import config
@functools.lru_cache(maxsize=1)
def get_embeddings():
    """Return the process-wide sentence-embedding model used by every vector store.

    The model is created once and reused. Every vector-store open calls this
    function, and rebuilding the model each time repeats the ~90MB load on a
    memory-constrained host. ``lru_cache`` does not cache exceptions, so a
    failed load is retried on the next call.

    Returns:
        A ``HuggingFaceEmbeddings`` for ``all-MiniLM-L6-v2`` configured to
        L2-normalise its output vectors.
    """
    token = config.HUGGINGFACE_TOKEN
    if token:
        # Keep exporting the variable this module has always set...
        os.environ["HUGGINGFACE_TOKEN"] = token
        # ...and also expose it as HF_TOKEN, the name huggingface_hub reads.
        # setdefault: never clobber a token the deployment configured itself.
        os.environ.setdefault("HF_TOKEN", token)
    # Use all-MiniLM-L6-v2: smaller model (~90MB) that works well on free tier.
    # all-mpnet-base-v2 (~420MB) is too large for Render free tier (512MB RAM).
    return HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        encode_kwargs={"normalize_embeddings": True},
    )
| def get_user_chroma_dir(user_id: str, session_id: str | None = None) -> str: | |
| # Use /tmp for ChromaDB to avoid permission issues in HF Spaces | |
| base = "/tmp/chroma_db" | |
| if session_id: | |
| return os.path.join(base, f"user_{user_id}", f"session_{session_id}") | |
| return os.path.join(base, f"user_{user_id}") | |
def get_vectorstore_for_user(user_id: str, session_id: str | None = None) -> Chroma:
    """Open the Chroma vector store isolated to one user's chat session.

    Tries a persistent store under ``get_user_chroma_dir(user_id, session_id)``.
    If creating that directory or opening the persistent store raises, it
    logs a warning and falls back to an in-memory store (nothing is written
    to disk).

    Args:
        user_id: Owner of the store.
        session_id: Chat session to scope the store to; required.

    Returns:
        A ``Chroma`` vector store using the shared embedding model.

    Raises:
        ValueError: if ``session_id`` is falsy. Errors from ``get_embeddings``
            or ``get_user_chroma_dir`` propagate unchanged.
    """
    if not session_id:
        # Enforce per-session isolation; caller must provide session_id
        raise ValueError("session_id is required for vectorstore access")
    # Loaded outside the try: a model-load failure is not a ChromaDB problem
    # and must not trigger (and then re-fail inside) the in-memory fallback.
    embeddings = get_embeddings()
    persist_dir = get_user_chroma_dir(user_id, session_id)
    try:
        os.makedirs(persist_dir, exist_ok=True)
        return Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    except Exception as e:
        print(f"⚠️ Persistent ChromaDB failed ({e}), using in-memory mode")
        # Without an explicit name every in-memory store uses Chroma's default
        # collection; if the in-process client is shared, sessions would see
        # each other's chunks. A per-session name keeps them isolated.
        return Chroma(
            collection_name=_session_collection_name(user_id, session_id),
            embedding_function=embeddings,
        )


def _session_collection_name(user_id: str, session_id: str) -> str:
    """Return a Chroma collection name unique to (user_id, session_id).

    The ids are hashed so arbitrary characters are safe. The result
    ('session_' + 40 hex chars) is 48 chars long and starts and ends with an
    alphanumeric, fitting Chroma's collection-name restrictions.
    """
    digest = hashlib.sha256(f"{user_id}\x00{session_id}".encode("utf-8")).hexdigest()
    return f"session_{digest[:40]}"
| from .core.ocr import extract_text_from_pdf_with_ocr | |
def index_pdf_for_user(
    user_id: str,
    temp_pdf_path: str,
    session_id: str | None = None,
    *,
    min_page_chars: int = 50,
    chunk_size: int = 900,
    chunk_overlap: int = 150,
):
    """Extract a PDF's text, chunk it, and add the chunks to the session's vector store.

    Pages whose text layer has fewer than ``min_page_chars`` non-whitespace-
    trimmed characters are treated as scans and sent to OCR. If no text is
    found at all, OCR is retried over the whole document unless an OCR pass
    has already succeeded over every page.

    Args:
        user_id: Owner of the vector store.
        temp_pdf_path: Path of the PDF on disk.
        session_id: Chat session whose store receives the chunks; required.
        min_page_chars: Below this many characters a page is OCR'd.
        chunk_size: Passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Passed to ``RecursiveCharacterTextSplitter``.

    Raises:
        ValueError: if ``session_id`` is missing, no text can be extracted
            even after OCR, or splitting produced no chunks.
    """
    if not session_id:
        raise ValueError("session_id is required for indexing")
    final_docs = _extract_pdf_documents(temp_pdf_path, min_page_chars)
    # Slightly smaller chunks generally improve recall; keep modest overlap for continuity
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    splits = splitter.split_documents(final_docs)
    if not splits:
        raise ValueError("No text chunks generated from the PDF.")
    vs = get_vectorstore_for_user(user_id, session_id)
    vs.add_documents(splits)


def _non_empty(docs):
    """Drop documents whose text is blank (e.g. pages where OCR found nothing)."""
    return [d for d in docs if d.page_content.strip()]


def _extract_pdf_documents(temp_pdf_path: str, min_page_chars: int) -> "List[Document]":
    """Return the PDF's per-page Documents, OCR-ing pages with too little text.

    Raises:
        ValueError: if no page yields any text, even after OCR.
    """
    raw_docs = PyPDFLoader(temp_pdf_path).load()
    final_docs = []
    pages_needs_ocr = []
    for i, doc in enumerate(raw_docs):
        content = doc.page_content or ""
        # Almost no text (empty page, just "Scanned by CamScanner", ...) suggests an image/scan.
        if len(content.strip()) < min_page_chars:
            pages_needs_ocr.append(i)
        else:
            final_docs.append(doc)

    # True only once an OCR pass has *successfully* run over every page.
    ocr_covered_all_pages = False
    if pages_needs_ocr:
        print(f"OCR needed for {len(pages_needs_ocr)} pages: {pages_needs_ocr}")
        try:
            ocr_docs = extract_text_from_pdf_with_ocr(temp_pdf_path, pages_needs_ocr)
            # Filter out OCR failures (empty text)
            final_docs.extend(_non_empty(ocr_docs))
            ocr_covered_all_pages = len(pages_needs_ocr) == len(raw_docs)
        except Exception as e:
            print(f"Warning: OCR failed ({e}). Proceeding with what we have.")

    # Last resort. With no text at all, every loaded page was already queued
    # for OCR, so this runs when the loader returned no pages or the per-page
    # OCR pass raised. (The old page-count comparison could never be true here.)
    if not final_docs and not ocr_covered_all_pages:
        print("No text found in initial pass. Attempting OCR on ALL pages...")
        try:
            final_docs = _non_empty(extract_text_from_pdf_with_ocr(temp_pdf_path))
        except Exception as e:
            print(f"Fallback OCR failed: {e}")

    if not final_docs:
        raise ValueError("No extractable text found in the PDF, even after OCR attempt.")
    # OCR pages were appended after text-layer pages; restore page order.
    # NOTE(review): assumes OCR docs carry the same 0-indexed "page" metadata as PyPDFLoader.
    final_docs.sort(key=lambda x: x.metadata.get("page", 0))
    return final_docs
def get_llm() -> ChatGroq:
    """Create the chat model used for query rewriting and answer generation.

    Runs Groq's hosted open-weight ``openai/gpt-oss-120b``. Temperature is 0
    so generations are as deterministic as possible; answers are meant to
    rely on retrieved context only.
    """
    settings = {
        "api_key": config.GROQ_API_KEY,
        "model": "openai/gpt-oss-120b",
        "temperature": 0,
    }
    return ChatGroq(**settings)
def build_conversational_chain(user_id: str, history: Optional[BaseChatMessageHistory], session_id: str | None = None):
    """Build a hybrid-retrieval RAG pipeline over one user's chat session.

    Retrieval pipeline:
    - Rewrites follow-up questions into standalone form using the chat history.
    - Expands the question with LLM-generated alternative queries.
    - Queries both the session's Chroma embedding retriever and a BM25 keyword
      retriever.
    - Merges all ranked lists with Reciprocal Rank Fusion (RRF).

    The top fused chunks are "stuffed" into a grounded-answer prompt. BM25 is
    a snapshot of the collection taken when this function runs; chunks indexed
    later are only found by the embedding retriever until the chain is rebuilt.

    Args:
        user_id: Owner of the vector store.
        history: Accepted for interface compatibility; not read here. The
            conversation is passed per call via ``inputs["chat_history"]``.
        session_id: Chat session whose documents are searched; required.

    Returns:
        An object whose ``invoke({"input": str, "chat_history": list})``
        returns ``{"answer": str, "context": list[Document]}``, the same shape
        as LangChain's ``create_retrieval_chain``.

    Raises:
        ValueError: if ``session_id`` is missing.
    """
    if not session_id:
        raise ValueError("session_id is required for chat")
    vs = get_vectorstore_for_user(user_id, session_id)
    # Embedding retriever (primary). Avoid score_threshold here due to Chroma compatibility.
    embedding_retriever = vs.as_retriever(search_kwargs={"k": 20})
    # Keyword retriever for hybrid search; None means embeddings-only retrieval.
    bm25 = _build_bm25_retriever(vs, k=20)
    llm = get_llm()

    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question"
        " which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    # Built once per chain instead of on every retrieval.
    multi_query_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Generate 2 alternative search queries to find relevant information. Return ONLY a JSON array of strings, nothing else. Example: [\"query 1\", \"query 2\"]"),
            ("human", "{q}"),
        ]
    )
    system_prompt = (
        "You are a grounded RAG assistant.\n"
        "Use ONLY the information in the retrieved context to answer.\n"
        "Do NOT use prior knowledge or invent facts.\n\n"
        "When answering from context, be clear and structured (headings, bullet points, numbered lists as needed).\n\n"
        "Retrieved context follows.\n{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

    def retrieve(query: str, chat_history) -> List[Document]:
        """Hybrid multi-query retrieval fused with RRF; returns at most 15 chunks."""
        # Follow-ups ("what about the second one?") retrieve poorly unless
        # rewritten against the conversation; no-op when there is no history.
        search_query = _standalone_question(llm, contextualize_q_prompt, query, chat_history)
        queries = _expand_queries(llm, multi_query_prompt, search_query)
        print(f"Retrieve: Processing {len(queries)} queries: {[q[:50] for q in queries]}")
        ranked_lists = []
        for i, q in enumerate(queries):
            docs = _safe_retrieve(embedding_retriever, q)
            print(f"  Query {i+1}: Embedding retriever returned {len(docs)} docs for: '{q[:50]}...'")
            ranked_lists.append(docs)
            if bm25 is not None:
                bm25_docs = _safe_retrieve(bm25, q)
                print(f"  BM25 returned {len(bm25_docs)} docs for query: {q[:50]}")
                ranked_lists.append(bm25_docs)
        out = _rrf_fuse(ranked_lists, k=60, limit=15)
        print(f"Retrieve: Final result: {len(out)} documents after deduplication and ranking")
        return out

    # Return a simple invokable object that mirrors the output shape of create_retrieval_chain
    class SimpleRAG:
        def invoke(self, inputs):
            q = inputs.get("input", "")
            # Treat an explicit None like "no history"; the prompt placeholder needs a list.
            chat_history = inputs.get("chat_history") or []
            print(f"SimpleRAG: Processing query: '{q[:100]}...'")
            docs = retrieve(q, chat_history)
            print(f"SimpleRAG: Retrieved {len(docs)} documents")
            if not docs:
                print("SimpleRAG: No documents retrieved, returning 'I don't know' response")
                return {"answer": "I don't know based on the uploaded documents. Please make sure you have uploaded PDF documents to this session.", "context": []}
            answer = question_answer_chain.invoke({
                "input": q,
                "chat_history": chat_history,
                "context": docs,
            })
            print(f"SimpleRAG: Generated answer: '{answer[:100]}...'")
            # create_stuff_documents_chain returns a string by default
            return {"answer": answer, "context": docs}

    return SimpleRAG()


def _build_bm25_retriever(vs, k: int = 20):
    """Build a BM25 keyword retriever over every chunk currently stored in ``vs``.

    Returns None (after logging) when the collection is empty or BM25 cannot
    be built, so callers fall back to embedding-only retrieval.
    """
    try:
        # NOTE: ``_collection`` is a private attribute of langchain_chroma.Chroma.
        all_data = vs._collection.get(include=["documents", "metadatas"])
        texts = all_data.get("documents") or []
        metas = all_data.get("metadatas") or []
        print(f"Chroma collection has {len(texts)} documents")
        # Skip entries stored without text; Document rejects a None page_content.
        bm25_docs = [
            Document(page_content=t, metadata=m or {})
            for t, m in zip(texts, metas)
            if t
        ]
        if not bm25_docs:
            print("WARNING: No documents found in Chroma collection - did you upload a PDF?")
            return None
        bm25 = BM25Retriever.from_documents(bm25_docs)
        bm25.k = k
        print(f"BM25 initialized with {len(bm25_docs)} documents")
        return bm25
    except Exception as e:
        print(f"BM25 initialization failed: {e}")
        traceback.print_exc()
        return None


def _standalone_question(llm, contextualize_q_prompt, query: str, chat_history) -> str:
    """Rewrite ``query`` into a standalone question using ``chat_history``.

    Returns ``query`` unchanged when there is no history, or when the rewrite
    fails or comes back blank (best-effort, never raises).
    """
    if not chat_history:
        return query
    try:
        messages = contextualize_q_prompt.format_messages(chat_history=chat_history, input=query)
        rewritten = llm.invoke(messages).content.strip()
    except Exception as e:
        print(f"Query contextualization skipped ({e}). Using the question as asked.")
        return query
    return rewritten or query


def _expand_queries(llm, multi_query_prompt, query: str) -> "List[str]":
    """Return ``[query]`` followed by distinct LLM-generated alternative queries.

    Best-effort: any LLM or parsing failure just yields ``[query]``.
    """
    queries = [query]
    try:
        raw = llm.invoke(multi_query_prompt.format_messages(q=query)).content.strip()
        # Models often wrap the array in ```json fences or add a preamble;
        # keep only the outermost [...] before parsing.
        start, end = raw.find("["), raw.rfind("]")
        if start != -1 and end > start:
            raw = raw[start:end + 1]
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            for alt in parsed:
                if not isinstance(alt, str):
                    continue
                alt = alt.strip()
                # Compare the stripped form so " foo" doesn't duplicate "foo".
                if alt and alt not in queries:
                    queries.append(alt)
        print(f"Multi-query expansion: Generated {len(queries)-1} additional queries")
    except Exception as e:
        # Log for debugging but don't fail - single query still works fine
        print(f"Multi-query expansion skipped ({e}). Continuing with original query.")
    return queries


def _safe_retrieve(retriever, query: str) -> "List[Document]":
    """Run ``retriever`` for ``query``; log and return [] instead of raising."""
    try:
        try:
            return list(retriever.invoke(query))
        except AttributeError:
            # Retrievers predating the Runnable ``invoke`` API.
            return list(retriever.get_relevant_documents(query))
    except Exception as e:
        print(f"  Retrieval with {type(retriever).__name__} failed: {e}")
        return []


def _doc_identity(doc) -> tuple:
    """Hashable identity of a chunk: its stripped text plus its metadata items."""
    meta = doc.metadata or {}
    # repr() keeps the key hashable even if a metadata value is a list/dict.
    return (doc.page_content.strip(), tuple(sorted((str(k), repr(v)) for k, v in meta.items())))


def _rrf_fuse(ranked_lists, k: int = 60, limit: int = 15) -> "List[Document]":
    """Merge ranked document lists with Reciprocal Rank Fusion.

    A chunk at 0-based rank ``r`` in any list contributes ``1 / (k + r)`` to
    its score; occurrences are grouped by ``_doc_identity``. Ties keep
    first-seen order.

    Returns:
        Up to ``limit`` copies of the best chunks, highest score first, each
        with its fused score in ``metadata["rrf_score"]``.

    The input Documents are never mutated: BM25Retriever returns the same
    stored Document objects on every call, so writing scores into them would
    leak stale scores into later retrievals and break grouping.
    """
    scores = {}
    representative = {}
    for docs in ranked_lists:
        for rank, doc in enumerate(docs):
            key = _doc_identity(doc)
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)
            representative.setdefault(key, doc)
    best = sorted(scores, key=scores.__getitem__, reverse=True)[:limit]
    fused = []
    for key in best:
        doc = representative[key]
        metadata = {**(doc.metadata or {}), "rrf_score": scores[key]}
        fused.append(doc.model_copy(update={"metadata": metadata}))
    return fused