"""Gradio RAG app: index uploaded PDFs into FAISS (stored on a HF dataset repo)
and answer questions with [DOC, PAGE] citations linking back to the source PDF."""

import json
import os
import shutil
import tempfile

import gradio as gr
import torch  # kept: pulled in for the transformers/langchain stack
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from huggingface_hub import hf_hub_download, HfApi


# ========================================
# ENHANCED PDF LOADER WITH METADATA
# ========================================
def load_pdf_with_metadata(file_path):
    """Load one PDF into per-page Documents with citation metadata.

    Each Document carries: source (basename), document_number (filename stem,
    e.g. "DOC001"), page_number (1-based) and total_pages.

    Tries PyMuPDF first for clean per-page text; falls back to PyPDFLoader
    if PyMuPDF is unavailable or fails on this file.
    """
    base = os.path.basename(file_path)
    doc_number = os.path.splitext(base)[0]  # e.g. "DOC001"
    try:
        import fitz  # PyMuPDF

        documents = []
        # Context manager guarantees the file handle is closed even on error.
        with fitz.open(file_path) as pdf:
            total = len(pdf)
            for page_num in range(total):
                text = pdf.load_page(page_num).get_text()
                documents.append(
                    Document(
                        page_content=text,
                        metadata={
                            "source": base,
                            "document_number": doc_number,
                            "page_number": page_num + 1,
                            "total_pages": total,
                        },
                    )
                )
        return documents
    except Exception:
        # Fallback: PyPDFLoader yields one Document per page; overwrite its
        # metadata so downstream citation formatting sees a uniform schema.
        docs = PyPDFLoader(file_path).load()
        for i, doc in enumerate(docs):
            doc.metadata.update(
                {
                    "source": base,
                    "document_number": doc_number,
                    "page_number": i + 1,
                    "total_pages": len(docs),
                }
            )
        return docs


# ========================================
# CREATE INDEX WITH METADATA
# ========================================
def create_faiss_index(repo_id, file_path,
                       embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS index + chunk metadata and upload both to a HF dataset repo.

    `file_path` may be a single path/file object or a list of them — the
    Gradio File component with file_count="multiple" passes a list.
    Returns a status string for display in the UI.
    """
    if not file_path:
        return "❌ Error: no PDF uploaded"
    # Normalize to a list of filesystem paths (Gradio file objects expose .name).
    uploads = file_path if isinstance(file_path, (list, tuple)) else [file_path]
    paths = [getattr(f, "name", f) for f in uploads]

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    # Load every PDF with citation metadata attached.
    documents = []
    for path in paths:
        documents.extend(load_pdf_with_metadata(path))

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = splitter.split_documents(documents)

    # Persist chunk metadata alongside the index so answers can cite doc/page.
    with open("temp_metadata.json", "w") as f:
        json.dump([doc.metadata for doc in split_docs], f)

    db = FAISS.from_documents(split_docs, embeddings)
    db.save_local("temp_faiss")

    # NOTE: HfApi.upload_file takes keyword-only arguments — positional calls
    # raise TypeError.
    api = HfApi(token=os.getenv("HF_token"))  # env var name as deployed — TODO confirm casing
    for local_path, repo_path in (
        ("temp_faiss/index.faiss", "index.faiss"),
        ("temp_faiss/index.pkl", "index.pkl"),
        ("temp_metadata.json", "metadata.json"),
    ):
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=repo_id,
            repo_type="dataset",
        )

    return f"✅ Created index with metadata for {len(split_docs)} chunks"


# ========================================
# ENHANCED QA CHAIN WITH CITATIONS
# ========================================
def generate_qa_chain_with_citations(repo_id, llm):
    """Download the FAISS index from `repo_id` and build a citation-aware QA chain.

    Returns (qa_chain, metadata_path). The chain returns source documents so
    the caller can render citations.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # All three artifacts must share the same cache dir: FAISS.load_local
    # expects index.faiss AND index.pkl side by side in one directory.
    cache_dir = "/tmp/hf_cache"
    faiss_path = hf_hub_download(
        repo_id=repo_id, filename="index.faiss",
        repo_type="dataset", cache_dir=cache_dir,
    )
    hf_hub_download(
        repo_id=repo_id, filename="index.pkl",
        repo_type="dataset", cache_dir=cache_dir,
    )
    metadata_path = hf_hub_download(
        repo_id=repo_id, filename="metadata.json",
        repo_type="dataset", cache_dir=cache_dir,
    )

    # allow_dangerous_deserialization: index.pkl is unpickled — safe only
    # because we control the dataset repo it came from.
    vectorstore = FAISS.load_local(
        os.path.dirname(faiss_path), embeddings,
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Answer STRICTLY based on context. Include [DOC:docnum, PAGE:pagenum] citations.

Question: {question}
Context: {context}

Answer with citations:
""",
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template},
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain, metadata_path


# ========================================
# CITATION FORMATTER WITH LINKS
# ========================================
def format_citations_with_links(sources, uploaded_files):
    """Render retrieved chunks as HTML citations with optional file links.

    `uploaded_files` maps original PDF basename -> local temp copy; when the
    cited source was uploaded in this session, the citation links to that file
    at the cited page. (Original markup was garbled in extraction; this is a
    minimal faithful reconstruction — verify styling against the deployed app.)
    """
    uploaded_files = uploaded_files or {}
    citations_html = []
    for source_doc in sources:
        doc_num = source_doc.metadata.get("document_number", "Unknown")
        page_num = source_doc.metadata.get("page_number", 1)
        source_file = source_doc.metadata.get("source", "Unknown")
        text = source_doc.page_content
        snippet = text[:200] + "..." if len(text) > 200 else text

        # Find the local copy of the cited document, if it was uploaded.
        file_path = None
        for fname, fpath in uploaded_files.items():
            if source_file == fname:
                file_path = fpath
                break

        if file_path:
            # Clickable link that jumps to the cited page in the PDF viewer.
            citation_html = (
                f'<div class="citation">'
                f'<a href="file={file_path}#page={page_num}" target="_blank">'
                f'📄 {doc_num} (Page {page_num})</a>'
                f'<p>{snippet}</p>'
                f'</div>'
            )
        else:
            citation_html = (
                f'<div class="citation">'
                f'📄 {doc_num} (Page {page_num})'
                f'<p>{snippet}</p>'
                f'</div>'
            )
        citations_html.append(citation_html)
    return "".join(citations_html)


# ========================================
# LLM CREATION (cached)
# ========================================
LLM_CACHE = None  # process-wide cache so each query doesn't reload the model


def get_cached_llm():
    """Return a lazily-created, cached lightweight text-generation LLM."""
    global LLM_CACHE
    if LLM_CACHE is None:
        LLM_CACHE = HuggingFacePipeline.from_model_id(
            model_id="distilgpt2",  # smallest/fastest — CPU-friendly demo model
            task="text-generation",
            device_map="cpu",
            pipeline_kwargs={"max_new_tokens": 100},
        )
    return LLM_CACHE


def create_llm_pipeline():
    """Create a larger LangChain-compatible LLM pipeline (kept as an alternative)."""
    return HuggingFacePipeline.from_model_id(
        model_id="microsoft/DialoGPT-medium",
        task="text-generation",
        device_map="auto",
        pipeline_kwargs={
            "max_new_tokens": 200,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": 0,  # silence tokenizer pad-token warning
        },
    )


# ========================================
# MAIN GRADIO QUERY FUNCTION
# ========================================
def rag_query_with_citations(question, repo_id, history=None, uploaded_files=None):
    """Answer `question` against the indexed documents; append to chat history.

    Returns (history, status) for the (chatbot, index_status) outputs.
    Defaults are None (not mutable literals) to avoid state shared across calls.
    """
    history = history or []
    uploaded_files = uploaded_files or {}
    try:
        llm = get_cached_llm()  # reuse the cached model across queries
        qa_chain, _metadata_path = generate_qa_chain_with_citations(repo_id, llm)

        result = qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = result["source_documents"]

        citations = format_citations_with_links(sources, uploaded_files)
        history.append([question, f"{answer}\n\n{citations}"])
        return history, ""
    except Exception as e:
        # Surface the failure in the status area instead of crashing the UI.
        return history, f"❌ Error: {str(e)}"


# ========================================
# GRADIO INTERFACE
# ========================================
with gr.Blocks(title="NRL Chat for Commercial procurement",
               theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Ask question and get answer from NRL documents")

    # Session state: original PDF basename -> local temp copy (for citation links).
    uploaded_files = gr.State({})

    with gr.Row():
        # LEFT COLUMN: document management
        with gr.Column(scale=1):
            gr.Markdown("## 📁 Document Management")
            repo_id_input = gr.Textbox(
                label="HF Dataset Repo",
                placeholder="manabb/withPDFlink",
                value="manabb/withPDFlink",
                interactive=False,
            )
            pdf_upload = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                file_count="multiple",
            )
            with gr.Row():
                create_btn = gr.Button("🚀 Create Index", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Files", variant="secondary")
            index_status = gr.Markdown("📊 Status: Ready")

            def store_files(files):
                """Copy each upload to a stable temp file.

                Keys the mapping by BASENAME so it matches the "source" field
                written into Document metadata (full temp paths never match).
                """
                file_dict = {}
                if not files:
                    return file_dict
                for file_obj in files:
                    # Gradio passes file objects whose .name is the temp path.
                    source_path = getattr(file_obj, "name", None)
                    if not source_path:
                        continue
                    suffix = os.path.splitext(source_path)[1] or ".pdf"
                    with tempfile.NamedTemporaryFile(delete=False,
                                                     suffix=suffix) as tmp:
                        with open(source_path, "rb") as src:
                            shutil.copyfileobj(src, tmp)
                        file_dict[os.path.basename(source_path)] = tmp.name
                return file_dict

            pdf_upload.change(store_files, inputs=pdf_upload,
                              outputs=uploaded_files)

        # RIGHT COLUMN: QA interface
        with gr.Column(scale=2):
            gr.Markdown("## ❓ Document QA with Citations")
            chatbot = gr.Chatbot(height=500, show_label=True)
            with gr.Row():
                question_input = gr.Textbox(
                    label="Ask about your documents",
                    placeholder="What does section 3.2 say about compliance?",
                    lines=2,
                )
                repo_id_chat = gr.Textbox(
                    label="Repo ID",
                    value="manabb/withPDFlink",
                    interactive=False,
                )
            submit_btn = gr.Button("💬 Answer with Citations", variant="primary")

    # Event handlers
    create_btn.click(
        create_faiss_index,
        inputs=[repo_id_input, pdf_upload],
        outputs=[index_status],
    )
    clear_btn.click(lambda: {}, outputs=[uploaded_files])
    submit_btn.click(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status],
    )
    question_input.submit(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status],
    )

    gr.Markdown("""
### ✨ **Citation Features**
- **📄 Document Number**: Extracted from filename (e.g., DOC001)
- **📃 Page Number**: Exact page location
- **🔗 Clickable Links**: Jump to exact page in PDF
- **💬 Source Snippets**: Context preview
""")


if __name__ == "__main__":
    demo.launch(share=True, server_port=7860)