import gradio as gr import os import shutil import torch from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline from langchain_community.document_loaders import PyPDFLoader, PyMuPDFLoader from langchain_community.vectorstores import FAISS from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from langchain_core.documents import Document from huggingface_hub import hf_hub_download, HfApi import tempfile # ======================================== # ENHANCED PDF LOADER WITH METADATA # ======================================== def load_pdf_with_metadata(file_path): """Load PDF with document number and page numbers""" documents = [] try: # PyMuPDF for better metadata extraction import fitz # PyMuPDF doc = fitz.open(file_path) for page_num in range(len(doc)): page = doc.load_page(page_num) text = page.get_text() # Create Document with metadata metadata = { "source": os.path.basename(file_path), "document_number": os.path.splitext(os.path.basename(file_path))[0], # e.g., "DOC001" "page_number": page_num + 1, "total_pages": len(doc) } documents.append(Document(page_content=text, metadata=metadata)) doc.close() return documents except: # Fallback to PyPDFLoader loader = PyPDFLoader(file_path) docs = loader.load() for i, doc in enumerate(docs): doc.metadata.update({ "source": os.path.basename(file_path), "document_number": os.path.splitext(os.path.basename(file_path))[0], "page_number": i + 1, "total_pages": len(docs) }) return docs # ======================================== # UPDATED CREATE INDEX WITH METADATA # ======================================== def create_faiss_index(repo_id, file_path, embedding_model="sentence-transformers/all-MiniLM-L6-v2"): """Create FAISS with document/page metadata""" embeddings = HuggingFaceEmbeddings(model_name=embedding_model) # Load with metadata documents = load_pdf_with_metadata(file_path) text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) split_docs = text_splitter.split_documents(documents) # Save split docs metadata for later with open("temp_metadata.json", "w") as f: import json json.dump([doc.metadata for doc in split_docs], f) db = FAISS.from_documents(split_docs, embeddings) db.save_local("temp_faiss") # Upload api = HfApi(token=os.getenv("HF_token")) api.upload_file("temp_faiss/index.faiss", "index.faiss", repo_id, repo_type="dataset") api.upload_file("temp_faiss/index.pkl", "index.pkl", repo_id, repo_type="dataset") api.upload_file("temp_metadata.json", "metadata.json", repo_id, repo_type="dataset") return f"✅ Created index with metadata for {len(split_docs)} chunks" # ======================================== # ENHANCED QA CHAIN WITH CITATIONS # ======================================== def generate_qa_chain_with_citations(repo_id, llm): embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") # Download files # In generate_qa_chain_with_citations(), replace: faiss_path = hf_hub_download( repo_id=repo_id, filename="index.faiss", repo_type="dataset", cache_dir="/tmp/hf_cache" # Dedicated cache ) #faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss", repo_type="dataset") pkl_path = hf_hub_download(repo_id=repo_id, filename="index.pkl", repo_type="dataset") metadata_path = hf_hub_download(repo_id=repo_id, filename="metadata.json", repo_type="dataset") # Load vectorstore vectorstore = FAISS.load_local(os.path.dirname(faiss_path), embeddings, allow_dangerous_deserialization=True) retriever = vectorstore.as_retriever(search_kwargs={"k": 4}) prompt_template = PromptTemplate( input_variables=["context", "question"], template=""" Answer STRICTLY based on context. Include [DOC:docnum, PAGE:pagenum] citations. Question: {question} Context: {context} Answer with citations: """ ) qa_chain = RetrievalQA.from_chain_type( llm=llm, chain_type="stuff", chain_type_kwargs={"prompt": prompt_template}, retriever=retriever, return_source_documents=True ) return qa_chain, metadata_path # ======================================== # CITATION FORMATTER WITH LINKS # ======================================== def format_citations_with_links(sources, uploaded_files): """Create clickable citations with document links""" citations_html = [] for i, source_doc in enumerate(sources): doc_num = source_doc.metadata.get("document_number", "Unknown") page_num = source_doc.metadata.get("page_number", 1) source_file = source_doc.metadata.get("source", "Unknown") snippet = source_doc.page_content[:200] + "..." if len(source_doc.page_content) > 200 else source_doc.page_content # Find uploaded file path file_path = None for fname, fpath in uploaded_files.items(): if source_file == fname: file_path = fpath break if file_path: # Create clickable link to page (using PDF.js or browser) citation_html = f"""
""" else: citation_html = f"""