manabb committed on
Commit
156c494
Β·
verified Β·
1 Parent(s): 4ad76e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import html
import json
import os
import shutil
import tempfile

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
+
16
+ # ========================================
17
+ # ENHANCED PDF LOADER WITH METADATA
18
+ # ========================================
19
def load_pdf_with_metadata(file_path):
    """Load a PDF as one Document per page, tagging each page with citation metadata.

    Tries PyMuPDF (fitz) first for page-level extraction; falls back to
    PyPDFLoader when PyMuPDF is unavailable or fails on this file.

    Args:
        file_path: Path to a PDF file on disk.

    Returns:
        list[Document]: one Document per page, each with metadata keys
        ``source`` (basename), ``document_number`` (filename stem, e.g. "DOC001"),
        ``page_number`` (1-based) and ``total_pages``.
    """
    try:
        # PyMuPDF gives direct page-level access; imported lazily so the
        # fallback path still works when it is not installed.
        import fitz  # PyMuPDF

        documents = []
        doc = fitz.open(file_path)
        try:
            total_pages = len(doc)
            for page_num in range(total_pages):
                text = doc.load_page(page_num).get_text()
                metadata = {
                    "source": os.path.basename(file_path),
                    # The filename stem doubles as the document number, e.g. "DOC001".
                    "document_number": os.path.splitext(os.path.basename(file_path))[0],
                    "page_number": page_num + 1,  # 1-based for human-readable citations
                    "total_pages": total_pages,
                }
                documents.append(Document(page_content=text, metadata=metadata))
        finally:
            # Release the file handle even if a page fails to parse (the
            # original leaked it on mid-loop exceptions).
            doc.close()
        return documents
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit. Any PyMuPDF problem falls back here.
        loader = PyPDFLoader(file_path)
        docs = loader.load()
        for i, doc in enumerate(docs):
            doc.metadata.update({
                "source": os.path.basename(file_path),
                "document_number": os.path.splitext(os.path.basename(file_path))[0],
                "page_number": i + 1,
                "total_pages": len(docs),
            })
        return docs
55
+
56
+ # ========================================
57
+ # UPDATED CREATE INDEX WITH METADATA
58
+ # ========================================
59
def create_faiss_index(repo_id, file_path, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS index (with per-page metadata) from one PDF and push it to the Hub.

    Args:
        repo_id: Hugging Face *dataset* repo that receives the index files.
        file_path: Local path of the PDF to index.
        embedding_model: Sentence-transformers model name used for embeddings.

    Returns:
        str: status message reporting the number of indexed chunks.
    """
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    # One Document per page (carries document_number/page_number), then split
    # into overlapping chunks; split_documents preserves the page metadata.
    documents = load_pdf_with_metadata(file_path)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)

    # Persist chunk metadata alongside the index so the QA side can cite pages.
    with open("temp_metadata.json", "w") as f:
        json.dump([doc.metadata for doc in split_docs], f)

    db = FAISS.from_documents(split_docs, embeddings)
    db.save_local("temp_faiss")

    # NOTE(review): env var name "HF_token" is unusual casing — confirm it
    # matches the Space's configured secret (commonly "HF_TOKEN").
    api = HfApi(token=os.getenv("HF_token"))
    # HfApi.upload_file takes keyword-only arguments; the original positional
    # calls would raise TypeError at runtime.
    api.upload_file(path_or_fileobj="temp_faiss/index.faiss",
                    path_in_repo="index.faiss", repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_faiss/index.pkl",
                    path_in_repo="index.pkl", repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_metadata.json",
                    path_in_repo="metadata.json", repo_id=repo_id, repo_type="dataset")

    return f"✅ Created index with metadata for {len(split_docs)} chunks"
83
+
84
+ # ========================================
85
+ # ENHANCED QA CHAIN WITH CITATIONS
86
+ # ========================================
87
def generate_qa_chain_with_citations(repo_id, llm):
    """Download the FAISS index from a HF dataset repo and build a RetrievalQA chain.

    Args:
        repo_id: Hugging Face dataset repo holding index.faiss / index.pkl / metadata.json.
        llm: LangChain-compatible LLM used to generate answers.

    Returns:
        tuple: (qa_chain, metadata_path) — the chain returns source documents,
        metadata_path is the local path of the downloaded metadata.json.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Download files from the dataset repo into the local HF cache.
    faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss", repo_type="dataset")
    # pkl_path is unused below; the download's side effect is what matters —
    # NOTE(review): this assumes index.pkl lands in the SAME cache directory
    # (same repo snapshot) as index.faiss so FAISS.load_local finds both. Confirm.
    pkl_path = hf_hub_download(repo_id=repo_id, filename="index.pkl", repo_type="dataset")
    metadata_path = hf_hub_download(repo_id=repo_id, filename="metadata.json", repo_type="dataset")

    # Load vectorstore from the cache directory containing both index files.
    # allow_dangerous_deserialization is required because index.pkl is pickled —
    # only safe for repos you control.
    vectorstore = FAISS.load_local(os.path.dirname(faiss_path), embeddings, allow_dangerous_deserialization=True)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

    # Prompt instructs the model to ground answers in context and to emit
    # [DOC:..., PAGE:...] markers (complemented by the HTML citation cards).
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Answer STRICTLY based on context. Include [DOC:docnum, PAGE:pagenum] citations.

Question: {question}
Context: {context}
Answer with citations:
"""
    )

    # "stuff" chain: all retrieved chunks are concatenated into one prompt;
    # return_source_documents=True feeds the citation formatter downstream.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm, chain_type="stuff", chain_type_kwargs={"prompt": prompt_template},
        retriever=retriever, return_source_documents=True
    )
    return qa_chain, metadata_path
115
+
116
+ # ========================================
117
+ # CITATION FORMATTER WITH LINKS
118
+ # ========================================
119
def format_citations_with_links(sources, uploaded_files):
    """Render retrieved source chunks as HTML citation cards.

    Args:
        sources: Documents from the retriever; each must expose ``metadata``
            (document_number / page_number / source) and ``page_content``.
        uploaded_files: Mapping of original filename -> local file path, used
            to build a clickable "#page=N" deep link when a copy is available.

    Returns:
        str: concatenated HTML cards (empty string when sources is empty).
    """
    citations_html = []

    for source_doc in sources:
        doc_num = source_doc.metadata.get("document_number", "Unknown")
        page_num = source_doc.metadata.get("page_number", 1)
        source_file = source_doc.metadata.get("source", "Unknown")
        content = source_doc.page_content
        snippet = content[:200] + "..." if len(content) > 200 else content
        # PDF text is untrusted input that ends up in the chat page: escape it
        # so it cannot inject markup (the original interpolated it raw).
        snippet = html.escape(snippet)
        doc_num = html.escape(str(doc_num))

        # Direct dict lookup replaces the original linear scan over items().
        file_path = uploaded_files.get(source_file)

        if file_path:
            # Clickable card; "#page=N" jumps to the page in browser PDF viewers.
            citation_html = f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #007bff; background: #f8f9fa;">
<strong>📄 <a href="{html.escape(file_path)}#page={page_num}" target="_blank">{doc_num}</a></strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
"""
        else:
            # No local copy available — render a non-clickable (red) card.
            citation_html = f"""
<div style="margin: 10px 0; padding: 10px; border-left: 4px solid #dc3545; background: #f8d7da;">
<strong>📄 {doc_num}</strong>
<span style="color: #666;">(Page {page_num})</span><br>
<small>{snippet}</small>
</div>
"""

        citations_html.append(citation_html)

    return "".join(citations_html)
157
+
158
+ # ========================================
159
+ # MAIN GRADIO QUERY FUNCTION
160
+ # ========================================
161
def rag_query_with_citations(question, repo_id, history=None, uploaded_files=None):
    """Answer a question against the indexed documents and append to chat history.

    Args:
        question: User question.
        repo_id: HF dataset repo holding the FAISS index.
        history: Chat history as [user, bot] pairs; a fresh list is created when
            None. (The original's ``history=[]`` default was shared and mutated
            across calls — the classic mutable-default bug.)
        uploaded_files: Mapping filename -> local path for citation links; the
            original's ``[]`` default was also inconsistent with the dict the
            citation formatter expects.

    Returns:
        tuple: (updated history, status string — empty on success, error text on failure).
    """
    if history is None:
        history = []
    if uploaded_files is None:
        uploaded_files = {}
    try:
        # NOTE(review): create_llm_pipeline is not defined anywhere in this
        # file — the call raises NameError, which lands in the except branch.
        llm = create_llm_pipeline()
        qa_chain, metadata_path = generate_qa_chain_with_citations(repo_id, llm)

        result = qa_chain.invoke({"query": question})
        answer = result["result"]
        sources = result["source_documents"]

        # Append clickable citation cards below the answer text.
        citations = format_citations_with_links(sources, uploaded_files)

        history.append([question, f"{answer}\n\n{citations}"])
        return history, ""
    except Exception as e:
        # Surface the error in the status area; leave history unchanged.
        return history, f"❌ Error: {str(e)}"
177
+
178
+ # ========================================
179
+ # GRADIO INTERFACE - ENHANCED
180
+ # ========================================
181
# Top-level Gradio app: left column manages uploads/indexing, right column is
# the QA chat. UI strings below restore the emoji that were mojibake'd by an
# encoding round-trip in the original (e.g. "πŸ“„" for 📄).
with gr.Blocks(title="RAG QA with Citations", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 RAG QA with **Document Citations & Page Links**")

    # Maps original filename -> local temp path; filled by store_files below
    # and consumed by rag_query_with_citations to build citation links.
    uploaded_files = gr.State({})

    with gr.Row():
        # LEFT COLUMN: Document Management
        with gr.Column(scale=1):
            gr.Markdown("## 📁 Document Management")

            repo_id_input = gr.Textbox(
                label="HF Dataset Repo",
                placeholder="yourusername/rag-docs",
                value="yourusername/rag-docs"
            )

            pdf_upload = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                file_count="multiple"
            )

            with gr.Row():
                # NOTE(review): neither button has a click handler wired up in
                # this file — both are currently inert.
                create_btn = gr.Button("🚀 Create Index", variant="primary")
                clear_btn = gr.Button("🗑️ Clear Files", variant="secondary")

            index_status = gr.Markdown("📊 Status: Ready")

            # Copy each uploaded PDF to a stable temp path keyed by its name.
            def store_files(files):
                file_dict = {}
                for f in files:
                    if f:
                        # NOTE(review): assumes gradio hands back file-like
                        # objects with .read() and .name; newer gradio versions
                        # deliver plain file paths from gr.File — confirm.
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                            tmp.write(f.read())
                        file_dict[f.name] = tmp.name
                return file_dict

            pdf_upload.change(store_files, pdf_upload, uploaded_files)

        # RIGHT COLUMN: QA Interface
        with gr.Column(scale=2):
            gr.Markdown("## ❓ Document QA with Citations")

            chatbot = gr.Chatbot(height=500, show_label=True)

            with gr.Row():
                question_input = gr.Textbox(
                    label="Ask about your documents",
                    placeholder="What does section 3.2 say about compliance?",
                    lines=2
                )
                repo_id_chat = gr.Textbox(
                    label="Repo ID",
                    value="yourusername/rag-docs"
                )

            submit_btn = gr.Button("💬 Answer with Citations", variant="primary")

            # Button click and textbox Enter both run the same query; errors
            # are surfaced via index_status in the left column.
            submit_btn.click(
                rag_query_with_citations,
                inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
                outputs=[chatbot, index_status]
            )

            question_input.submit(
                rag_query_with_citations,
                inputs=[question_input, repo_id_chat, chatbot, uploaded_files],
                outputs=[chatbot, index_status]
            )

    gr.Markdown("""
### ✨ **Citation Features**
- **📄 Document Number**: Extracted from filename (e.g., DOC001)
- **📃 Page Number**: Exact page location
- **🔗 Clickable Links**: Jump to exact page in PDF
- **💬 Source Snippets**: Context preview
""")

if __name__ == "__main__":
    # share=True exposes a public gradio link; 7860 is the HF Spaces default port.
    demo.launch(share=True, server_port=7860)