# manabb — Update app.py (commit 77a614a, verified)
import gradio as gr
import os
import shutil
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.documents import Document
from huggingface_hub import hf_hub_download, HfApi
import tempfile
# ========================================
# ENHANCED PDF LOADER WITH METADATA
# ========================================
def load_pdf_with_metadata(file_path):
    """Load a PDF into per-page LangChain Documents with citation metadata.

    Each Document's metadata carries:
      - source: base filename (later matched against uploaded-file keys)
      - document_number: filename without extension (e.g. "DOC001")
      - page_number: 1-based page index
      - total_pages: total page count of the PDF

    Tries PyMuPDF (fitz) first for better text extraction; falls back to
    PyPDFLoader if PyMuPDF is unavailable or fails on this file.
    """
    base_name = os.path.basename(file_path)
    doc_number = os.path.splitext(base_name)[0]
    try:
        import fitz  # PyMuPDF

        pdf = fitz.open(file_path)
        try:
            total = len(pdf)
            return [
                Document(
                    page_content=pdf.load_page(page_num).get_text(),
                    metadata={
                        "source": base_name,
                        "document_number": doc_number,
                        "page_number": page_num + 1,
                        "total_pages": total,
                    },
                )
                for page_num in range(total)
            ]
        finally:
            pdf.close()  # always release the file handle, even on a failed page
    except Exception:
        # Fallback: PyPDFLoader yields one Document per page; attach the
        # same citation metadata so downstream code sees a uniform shape.
        docs = PyPDFLoader(file_path).load()
        for i, d in enumerate(docs):
            d.metadata.update({
                "source": base_name,
                "document_number": doc_number,
                "page_number": i + 1,
                "total_pages": len(docs),
            })
        return docs
# ========================================
# UPDATED CREATE INDEX WITH METADATA
# ========================================
def create_faiss_index(repo_id, file_path, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS index with citation metadata and upload it to an HF dataset repo.

    Args:
        repo_id: Hugging Face dataset repo to upload the index files to.
        file_path: a single PDF path, or a list of paths / Gradio file objects
            (the upload widget uses file_count="multiple", so the click handler
            passes a list).
        embedding_model: sentence-transformers model name for embeddings.

    Returns:
        Status string shown in the UI.
    """
    import json

    # Normalize input: accept None, one path, or a list of Gradio file objects.
    if not file_path:
        return "❌ No PDF uploaded"
    if not isinstance(file_path, (list, tuple)):
        file_path = [file_path]
    # Gradio file objects expose the temp path via .name; plain strings pass through.
    paths = [getattr(f, "name", f) for f in file_path]

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    documents = []
    for path in paths:
        documents.extend(load_pdf_with_metadata(path))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)

    # Persist chunk metadata alongside the index so citations survive reloads.
    with open("temp_metadata.json", "w") as f:
        json.dump([doc.metadata for doc in split_docs], f)

    db = FAISS.from_documents(split_docs, embeddings)
    db.save_local("temp_faiss")

    # Upload the index pair plus metadata to the dataset repo.
    api = HfApi(token=os.getenv("HF_token"))
    api.upload_file(path_or_fileobj="temp_faiss/index.faiss", path_in_repo="index.faiss",
                    repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_faiss/index.pkl", path_in_repo="index.pkl",
                    repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_metadata.json", path_in_repo="metadata.json",
                    repo_id=repo_id, repo_type="dataset")
    return f"✅ Created index with metadata for {len(split_docs)} chunks"
# ========================================
# ENHANCED QA CHAIN WITH CITATIONS
# ========================================
def generate_qa_chain_with_citations(repo_id, llm):
    """Download the FAISS index from an HF dataset repo and build a RetrievalQA chain.

    Args:
        repo_id: dataset repo holding index.faiss / index.pkl / metadata.json.
        llm: LangChain-compatible LLM used by the chain.

    Returns:
        (qa_chain, metadata_path): chain configured with
        return_source_documents=True so the caller can render citations, and
        the local path of the downloaded chunk-metadata JSON.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # All three artifacts MUST land in the same cache dir: FAISS.load_local
    # reads index.faiss AND index.pkl from one directory, so downloading them
    # into different caches (the original bug) breaks loading.
    cache_dir = "/tmp/hf_cache"
    faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss",
                                 repo_type="dataset", cache_dir=cache_dir)
    # Downloaded for its side effect: places index.pkl next to index.faiss.
    pkl_path = hf_hub_download(repo_id=repo_id, filename="index.pkl",
                               repo_type="dataset", cache_dir=cache_dir)
    metadata_path = hf_hub_download(repo_id=repo_id, filename="metadata.json",
                                    repo_type="dataset", cache_dir=cache_dir)

    vectorstore = FAISS.load_local(os.path.dirname(faiss_path), embeddings,
                                   allow_dangerous_deserialization=True)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Answer STRICTLY based on context. Include [DOC:docnum, PAGE:pagenum] citations.
Question: {question}
Context: {context}
Answer with citations:
""",
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt_template},
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain, metadata_path
# ========================================
# CITATION FORMATTER WITH LINKS
# ========================================
def format_citations_with_links(sources, uploaded_files):
    """Render retrieved source chunks as HTML citation cards.

    Args:
        sources: Documents from the QA chain; metadata is expected to hold
            document_number / page_number / source (base filename).
        uploaded_files: mapping of base filename -> local temp file path,
            used to build a clickable link that jumps to the cited page.

    Returns:
        Concatenated HTML string: a blue card with a page-anchored link when
        the source PDF is available locally, a red card otherwise.
    """
    import html

    cards = []
    for source_doc in sources:
        meta = source_doc.metadata
        doc_num = meta.get("document_number", "Unknown")
        page_num = meta.get("page_number", 1)
        source_file = meta.get("source", "Unknown")

        text = source_doc.page_content
        snippet = text[:200] + "..." if len(text) > 200 else text
        # Escape extracted PDF text so stray <, >, & can't break the HTML
        # (or inject markup into the chat view).
        snippet = html.escape(snippet)

        # Look up the locally stored copy of this source document.
        file_path = next(
            (fpath for fname, fpath in uploaded_files.items() if fname == source_file),
            None,
        )
        if file_path:
            cards.append(f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #007bff; background: #f8f9fa;">
<strong>📄 <a href="{file_path}#page={page_num}" target="_blank">{doc_num}</a></strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
""")
        else:
            cards.append(f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #dc3545; background: #f8d7da;">
<strong>📄 {doc_num}</strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
""")
    return "".join(cards)
#=========================================
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
LLM_CACHE = None  # Module-level singleton: the model is loaded at most once.


def get_cached_llm():
    """Return the shared distilgpt2 pipeline, creating it lazily on first call."""
    global LLM_CACHE
    if LLM_CACHE is not None:
        return LLM_CACHE
    LLM_CACHE = HuggingFacePipeline.from_model_id(
        model_id="distilgpt2",  # smallest, fastest
        task="text-generation",
        device_map="cpu",
        pipeline_kwargs={"max_new_tokens": 100},
    )
    return LLM_CACHE
# ========================================
# Creating the llm with model
# ========================================
def create_llm_pipeline():
    """Create LLM pipeline compatible with LangChain."""
    generation_kwargs = {
        "max_new_tokens": 200,
        "do_sample": True,
        "temperature": 0.7,
        "pad_token_id": 0,  # Fix tokenizer warning
    }
    return HuggingFacePipeline.from_model_id(
        model_id="microsoft/DialoGPT-medium",
        task="text-generation",
        device_map="auto",
        pipeline_kwargs=generation_kwargs,
    )
# ========================================
# MAIN GRADIO QUERY FUNCTION
# ========================================
def rag_query_with_citations(question, repo_id, history=None, uploaded_files=None):
    """Answer a question against the indexed documents and append to chat history.

    Args:
        question: user question.
        repo_id: HF dataset repo holding the FAISS index.
        history: chat history as [question, answer] pairs (Gradio Chatbot state).
        uploaded_files: mapping of base filename -> local path for citation links.

    Returns:
        (history, status): updated history plus empty status on success, or the
        unchanged history plus an error message on failure.
    """
    # Avoid mutable default arguments: the original `history=[]` was shared
    # across calls, leaking chat history whenever Gradio omitted the state.
    history = [] if history is None else history
    # format_citations_with_links iterates .items(), so the default must be a dict.
    uploaded_files = {} if uploaded_files is None else uploaded_files
    try:
        llm = get_cached_llm()  # reuse the cached pipeline (single creation)
        qa_chain, _metadata_path = generate_qa_chain_with_citations(repo_id, llm)
        result = qa_chain.invoke({"query": question})
        answer = result["result"]
        citations = format_citations_with_links(result["source_documents"], uploaded_files)
        history.append([question, f"{answer}\n\n{citations}"])
        return history, ""
    except Exception as e:  # UI boundary: surface the error instead of crashing
        return history, f"❌ Error: {str(e)}"
# ========================================
# GRADIO INTERFACE - ENHANCED
# ========================================
with gr.Blocks(title="NRL Chat for Commercial procurement", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 Ask question and get answer from NRL documents")

    # Maps base filename -> local temp copy; consumed by the citation renderer.
    uploaded_files = gr.State({})

    with gr.Row():
        # LEFT COLUMN: Document Management
        with gr.Column(scale=1):
            gr.Markdown("## 📁 Document Management")
            repo_id_input = gr.Textbox(
                label="HF Dataset Repo",
                placeholder="manabb/withPDFlink",
                value="manabb/withPDFlink",
                interactive=False,
            )
            pdf_upload = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                file_count="multiple",
            )
            with gr.Row():
                create_btn = gr.Button("🚀 Create Index", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Files", variant="secondary")
            index_status = gr.Markdown("📊 Status: Ready")

            def store_files(files):
                """Copy each uploaded PDF to a stable temp file.

                Keys are BASE filenames (not the full Gradio temp path): the
                Document metadata stores os.path.basename(...) as "source",
                and format_citations_with_links matches on that key — the
                original full-path keys could never match.
                """
                if not files:
                    return {}
                file_dict = {}
                for file_obj in files:
                    if file_obj and hasattr(file_obj, "name"):
                        source_path = file_obj.name  # Gradio provides a temp path string
                        suffix = os.path.splitext(source_path)[1] or ".pdf"
                        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                            # Copy via the path, not the (already-consumed) file object.
                            with open(source_path, "rb") as source_file:
                                shutil.copyfileobj(source_file, tmp)
                            file_dict[os.path.basename(source_path)] = tmp.name
                return file_dict

            pdf_upload.change(store_files, inputs=pdf_upload, outputs=uploaded_files)

        # RIGHT COLUMN: QA Interface
        with gr.Column(scale=2):
            gr.Markdown("## ❓ Document QA with Citations")
            chatbot = gr.Chatbot(height=500, show_label=True)
            with gr.Row():
                question_input = gr.Textbox(
                    label="Ask about your documents",
                    placeholder="What does section 3.2 say about compliance?",
                    lines=2,
                )
                repo_id_chat = gr.Textbox(
                    label="Repo ID",
                    value="manabb/withPDFlink",
                    interactive=False,
                )
            submit_btn = gr.Button("💬 Answer with Citations", variant="primary")

    # Event handlers
    create_btn.click(
        create_faiss_index,
        inputs=[repo_id_input, pdf_upload],
        outputs=[index_status],
    )
    clear_btn.click(lambda: {}, outputs=[uploaded_files])
    submit_btn.click(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status],
    )
    question_input.submit(
        rag_query_with_citations,
        inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
        outputs=[chatbot, index_status],
    )

    gr.Markdown("""
### ✨ **Citation Features**
- **📄 Document Number**: Extracted from filename (e.g., DOC001)
- **📃 Page Number**: Exact page location
- **🔗 Clickable Links**: Jump to exact page in PDF
- **💬 Source Snippets**: Context preview
""")

if __name__ == "__main__":
    demo.launch(share=True, server_port=7860)