# Hugging Face Space: RAG PDF question-answering app with citations (Gradio).
| import gradio as gr | |
| import os | |
| import shutil | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline | |
| from langchain_community.document_loaders import PyPDFLoader, PyMuPDFLoader | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
| from langchain_core.documents import Document | |
| from huggingface_hub import hf_hub_download, HfApi | |
| import tempfile | |
| # ======================================== | |
| # ENHANCED PDF LOADER WITH METADATA | |
| # ======================================== | |
def load_pdf_with_metadata(file_path):
    """Load a PDF and return one Document per page, tagged with citation metadata.

    Each Document's metadata carries:
      - source: PDF filename (basename only)
      - document_number: filename without extension (e.g. "DOC001")
      - page_number: 1-based page index
      - total_pages: page count of the document

    Tries PyMuPDF first (better text extraction); falls back to PyPDFLoader
    if PyMuPDF is unavailable or fails on this file.
    """
    base_name = os.path.basename(file_path)
    doc_number = os.path.splitext(base_name)[0]  # e.g., "DOC001"
    try:
        # PyMuPDF for better metadata extraction
        import fitz  # PyMuPDF
        documents = []
        doc = fitz.open(file_path)
        try:
            total_pages = len(doc)
            for page_num in range(total_pages):
                text = doc.load_page(page_num).get_text()
                documents.append(Document(
                    page_content=text,
                    metadata={
                        "source": base_name,
                        "document_number": doc_number,
                        "page_number": page_num + 1,
                        "total_pages": total_pages,
                    },
                ))
        finally:
            # Previously leaked when get_text() raised mid-loop.
            doc.close()
        return documents
    except Exception:
        # Fallback to PyPDFLoader. Narrowed from a bare `except:` which would
        # also trap KeyboardInterrupt/SystemExit.
        docs = PyPDFLoader(file_path).load()
        for i, doc in enumerate(docs):
            doc.metadata.update({
                "source": base_name,
                "document_number": doc_number,
                "page_number": i + 1,
                "total_pages": len(docs),
            })
        return docs
| # ======================================== | |
| # UPDATED CREATE INDEX WITH METADATA | |
| # ======================================== | |
def create_faiss_index(repo_id, file_path, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS index with per-chunk citation metadata and push it to a HF dataset repo.

    Args:
        repo_id: Hugging Face dataset repo to upload the index files to.
        file_path: a single PDF path, or a list of paths / Gradio file objects
                   (the upload widget uses file_count="multiple", so the click
                   handler passes a list here).
        embedding_model: sentence-transformers model used for embeddings.

    Returns:
        Status string describing how many chunks were indexed.
    """
    import json

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    # Accept a single path or the list emitted by the Gradio File component.
    if not file_path:
        return "β Created index with metadata for 0 chunks"
    items = file_path if isinstance(file_path, (list, tuple)) else [file_path]

    # Load every PDF with document/page metadata.
    documents = []
    for item in items:
        # Gradio file objects expose the temp path as .name; plain strings pass through.
        path = item.name if hasattr(item, "name") else item
        documents.extend(load_pdf_with_metadata(path))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)

    # Persist chunk metadata alongside the index so citations survive reload.
    with open("temp_metadata.json", "w") as f:
        json.dump([doc.metadata for doc in split_docs], f)

    db = FAISS.from_documents(split_docs, embeddings)
    db.save_local("temp_faiss")

    # Upload. HfApi.upload_file takes keyword-only arguments; the previous
    # positional calls raised TypeError at runtime.
    api = HfApi(token=os.getenv("HF_token"))
    api.upload_file(path_or_fileobj="temp_faiss/index.faiss", path_in_repo="index.faiss",
                    repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_faiss/index.pkl", path_in_repo="index.pkl",
                    repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_metadata.json", path_in_repo="metadata.json",
                    repo_id=repo_id, repo_type="dataset")
    return f"β Created index with metadata for {len(split_docs)} chunks"
| # ======================================== | |
| # ENHANCED QA CHAIN WITH CITATIONS | |
| # ======================================== | |
def generate_qa_chain_with_citations(repo_id, llm):
    """Download the FAISS index from the dataset repo and build a citation-aware QA chain.

    Args:
        repo_id: HF dataset repo holding index.faiss / index.pkl / metadata.json.
        llm: LangChain-compatible LLM used by the RetrievalQA chain.

    Returns:
        (qa_chain, metadata_path) — the chain returns source documents so the
        caller can render citations.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # All three files must land in the SAME cache: FAISS.load_local() expects
    # index.faiss and index.pkl side by side in one folder. Previously only
    # index.faiss used the dedicated cache_dir, so index.pkl ended up in the
    # default HF cache and load_local could not find it.
    cache_dir = "/tmp/hf_cache"  # Dedicated cache
    faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss",
                                 repo_type="dataset", cache_dir=cache_dir)
    pkl_path = hf_hub_download(repo_id=repo_id, filename="index.pkl",
                               repo_type="dataset", cache_dir=cache_dir)
    metadata_path = hf_hub_download(repo_id=repo_id, filename="metadata.json",
                                    repo_type="dataset", cache_dir=cache_dir)

    # Load vectorstore from the snapshot folder containing both index files.
    vectorstore = FAISS.load_local(os.path.dirname(faiss_path), embeddings,
                                   allow_dangerous_deserialization=True)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Answer STRICTLY based on context. Include [DOC:docnum, PAGE:pagenum] citations.
Question: {question}
Context: {context}
Answer with citations:
"""
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm, chain_type="stuff", chain_type_kwargs={"prompt": prompt_template},
        retriever=retriever, return_source_documents=True
    )
    return qa_chain, metadata_path
| # ======================================== | |
| # CITATION FORMATTER WITH LINKS | |
| # ======================================== | |
def format_citations_with_links(sources, uploaded_files):
    """Render retrieved source chunks as HTML citation cards.

    Args:
        sources: documents exposing .metadata (document_number, page_number,
                 source) and .page_content.
        uploaded_files: mapping of original filename -> local file path used
                 to build clickable links; anything that is not a dict
                 (e.g. a stale list/None state) is treated as "no files".

    Returns:
        Concatenated HTML string. Sources whose file is known get a clickable
        page link (blue card); unknown sources get a plain red card.
    """
    # Tolerate None / a list default from the caller instead of crashing on .items().
    files = uploaded_files if isinstance(uploaded_files, dict) else {}
    citations_html = []
    for source_doc in sources:
        meta = source_doc.metadata
        doc_num = meta.get("document_number", "Unknown")
        page_num = meta.get("page_number", 1)
        source_file = meta.get("source", "Unknown")
        content = source_doc.page_content
        snippet = content[:200] + "..." if len(content) > 200 else content
        # Direct dict lookup replaces the previous O(n) scan over items().
        file_path = files.get(source_file)
        if file_path:
            # Create clickable link to page (using PDF.js or browser)
            citation_html = f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #007bff; background: #f8f9fa;">
<strong>π <a href="{file_path}#page={page_num}" target="_blank">{doc_num}</a></strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
"""
        else:
            citation_html = f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #dc3545; background: #f8d7da;">
<strong>π {doc_num}</strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
"""
        citations_html.append(citation_html)
    return "".join(citations_html)
| #========================================= | |
| from langchain_huggingface import HuggingFacePipeline | |
| from transformers import pipeline | |
# Module-level singleton so the model pipeline is built at most once per process.
LLM_CACHE = None

def get_cached_llm():
    """Return the shared distilgpt2 generation pipeline, creating it on first use."""
    global LLM_CACHE
    if LLM_CACHE is not None:
        return LLM_CACHE
    LLM_CACHE = HuggingFacePipeline.from_model_id(
        model_id="distilgpt2",  # Smallest, fastest
        task="text-generation",
        device_map="cpu",
        pipeline_kwargs={"max_new_tokens": 100},
    )
    return LLM_CACHE
| # ======================================== | |
| # Creating the llm with model | |
| # ======================================== | |
def create_llm_pipeline():
    """Build a fresh DialoGPT-medium text-generation LLM compatible with LangChain."""
    generation_kwargs = {
        "max_new_tokens": 200,
        "do_sample": True,
        "temperature": 0.7,
        "pad_token_id": 0,  # Fix tokenizer warning
    }
    llm = HuggingFacePipeline.from_model_id(
        model_id="microsoft/DialoGPT-medium",
        task="text-generation",
        device_map="auto",
        pipeline_kwargs=generation_kwargs,
    )
    return llm
| # ======================================== | |
| # MAIN GRADIO QUERY FUNCTION | |
| # ======================================== | |
def rag_query_with_citations(question, repo_id, history=None, uploaded_files=None):
    """Answer a question over the indexed documents and append it to the chat history.

    Args:
        question: user question.
        repo_id: HF dataset repo holding the FAISS index.
        history: chat history as [question, answer] pairs. Defaults to a fresh
                 list — the previous `history=[]` mutable default was shared
                 across calls, leaking history between sessions.
        uploaded_files: mapping of filename -> local path for citation links.
                 Defaults to {} (the previous `[]` default did not match the
                 dict the formatter expects).

    Returns:
        (updated history, status) — status is "" on success, an error message
        on failure.
    """
    history = [] if history is None else history
    uploaded_files = {} if uploaded_files is None else uploaded_files
    try:
        llm = get_cached_llm()  # single cached creation per process
        qa_chain, metadata_path = generate_qa_chain_with_citations(repo_id, llm)
        result = qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = result["source_documents"]
        # Format citations
        citations = format_citations_with_links(sources, uploaded_files)
        history.append([question, f"{answer}\n\n{citations}"])
        return history, ""
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return history, f"β Error: {str(e)}"
| # ======================================== | |
| # GRADIO INTERFACE - ENHANCED | |
| # ======================================== | |
# Gradio UI: document management (left column) and citation-aware QA chat (right).
with gr.Blocks(title="NRL Chat for Commercial procurement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Ask question and get answer from NRL documents")
    # File storage state: maps uploaded filename -> local temp copy,
    # consumed by format_citations_with_links to build clickable links.
    uploaded_files = gr.State({})
    with gr.Row():
        # LEFT COLUMN: Document Management
        with gr.Column(scale=1):
            gr.Markdown("## π Document Management")
            repo_id_input = gr.Textbox(
                label="HF Dataset Repo",
                placeholder="manabb/withPDFlink",
                value="manabb/withPDFlink",
                interactive=False  # repo is fixed for this Space
            )
            pdf_upload = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                file_count="multiple"
            )
            with gr.Row():
                create_btn = gr.Button("π Create Index", variant="primary")
                clear_btn = gr.Button("ποΈ Clear Files", variant="secondary")
            index_status = gr.Markdown("π Status: Ready")
            # Store uploaded files
            # Store uploaded files - CORRECTED VERSION
            def store_files(files):
                # Copy each uploaded PDF to a stable temp file; returns
                # {uploaded name: temp path} for the uploaded_files state.
                file_dict = {}
                if not files:
                    return {}
                for file_obj in files:
                    if file_obj and hasattr(file_obj, 'name'):
                        source_path = file_obj.name  # This is the string path
                        # Create temp copy with original name
                        temp_suffix = os.path.splitext(file_obj.name)[1] or '.pdf'
                        with tempfile.NamedTemporaryFile(delete=False, suffix=temp_suffix) as tmp:
                            # Read from file path, not file object
                            with open(source_path, 'rb') as source_file:
                                shutil.copyfileobj(source_file, tmp)
                            # NOTE(review): the key is the full temp path, but
                            # format_citations_with_links matches on the PDF
                            # basename from document metadata — verify whether
                            # os.path.basename(file_obj.name) was intended here.
                            file_dict[file_obj.name] = tmp.name
                return file_dict
            # Update the event handler
            pdf_upload.change(
                store_files,
                inputs=pdf_upload,
                outputs=uploaded_files
            )
            #def store_files(files):
            #    file_dict = {}
            #    for f in files:
            #        if f:
            #            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            #                tmp.write(f.read())
            #                tmp.close()  # Explicit close
            #                file_dict[f.name] = tmp.name
            #    return file_dict
            #
            #pdf_upload.change(store_files, pdf_upload, uploaded_files)
        # RIGHT COLUMN: QA Interface
        with gr.Column(scale=2):
            gr.Markdown("## β Document QA with Citations")
            chatbot = gr.Chatbot(height=500, show_label=True)
            with gr.Row():
                question_input = gr.Textbox(
                    label="Ask about your documents",
                    placeholder="What does section 3.2 say about compliance?",
                    lines=2
                )
                repo_id_chat = gr.Textbox(
                    label="Repo ID",
                    value="manabb/withPDFlink",
                    interactive=False
                )
            submit_btn = gr.Button("π¬ Answer with Citations", variant="primary")
    # Event handlers - ADD THESE MISSING ONES
    # NOTE(review): pdf_upload emits a LIST of files here (file_count="multiple");
    # confirm create_faiss_index accepts a list rather than a single path.
    create_btn.click(
        create_faiss_index,
        inputs=[repo_id_input, pdf_upload],
        outputs=[index_status]
    )
    # Reset the filename -> temp-path state.
    clear_btn.click(
        lambda: {},
        outputs=[uploaded_files]
    )
    # Event handlers: both the button and Enter in the textbox run the query.
    submit_btn.click(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status]
    )
    question_input.submit(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status]
    )
    gr.Markdown("""
### β¨ **Citation Features**
- **π Document Number**: Extracted from filename (e.g., DOC001)
- **π Page Number**: Exact page location
- **π Clickable Links**: Jump to exact page in PDF
- **π¬ Source Snippets**: Context preview
""")
if __name__ == "__main__":
    demo.launch(share=True, server_port=7860)