manabb committed on
Commit
fc8e85d
Β·
verified Β·
1 Parent(s): 3f43c54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +318 -40
app.py CHANGED
@@ -9,82 +9,360 @@ from langchain.document_loaders import TextLoader
9
  from langchain_text_splitters import RecursiveCharacterTextSplitter
10
  from langchain.chains import RetrievalQA
11
  from langchain.llms import HuggingFacePipeline
12
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
13
- from langchain.document_loaders import PyPDFLoader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Optional: Set HF Token if needed
16
  # os.environ['HUGGINGFACEHUB_API_TOKEN'] = 'hf_XXXX'
17
 
18
  # Initialize embedding model
19
- embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
20
 
21
  # Load HF model (lightweight for CPU)
22
- model_name = "google/flan-t5-small"
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
25
 
26
  # Wrap in pipeline
27
- pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- def process_file(file_path):
31
- # Load & split document
32
- #loader = TextLoader(file_path)
33
- loader = PyPDFLoader(file_path)
34
- documents = loader.load()
35
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
36
- docs = text_splitter.split_documents(documents)
37
-
38
- # Create vector DB
39
- vector_db = FAISS.from_documents(docs, embedding_model)
40
- retriever = vector_db.as_retriever()
41
-
42
- # Setup RetrievalQA chain
43
- qa_chain = RetrievalQA.from_chain_type(
44
- llm=llm,
45
- chain_type="stuff",
46
- retriever=retriever
47
- )
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  return qa_chain
50
 
51
- # Store the QA chain globally (across UI events)
52
- qa_chain = None
 
 
 
53
 
54
- def upload_and_prepare(file):
 
55
  global qa_chain
56
- # qa_chain = process_file(file)
57
- qa_chain = process_file(file.name)
58
- return "βœ… Document processed. You can now ask questions!"
59
 
60
  def ask_question(query):
61
  if not qa_chain:
62
- return "❌ Please upload a document first."
63
  response = qa_chain.invoke({"query": query})
64
  return response["result"]
65
 
 
66
  # Gradio UI
67
  with gr.Blocks() as demo:
68
- gr.Markdown("## 🧠 Ask Questions About Your Document (LangChain + Hugging Face)")
69
 
70
  with gr.Row():
71
- file_input = gr.File(label="πŸ“„ Upload .pdf File", type="filepath")
72
- upload_btn = gr.Button("πŸ”„ Process Document")
73
-
74
- upload_output = gr.Textbox(label="πŸ“ Status", interactive=False)
75
 
76
  with gr.Row():
77
- query_input = gr.Textbox(label="❓ Your Question")
78
  query_btn = gr.Button("🧠 Get Answer")
79
 
80
- answer_output = gr.Textbox(label="βœ… Answer", lines=4)
 
 
 
 
 
 
 
 
81
 
82
- upload_btn.click(upload_and_prepare, inputs=file_input, outputs=upload_output)
83
  query_btn.click(ask_question, inputs=query_input, outputs=answer_output)
 
 
84
 
85
  # For local dev use: demo.launch()
86
  # For HF Spaces
 
87
  if __name__ == "__main__":
88
  demo.launch()
 
89
 
90
 
 
9
  from langchain_text_splitters import RecursiveCharacterTextSplitter
10
  from langchain.chains import RetrievalQA
11
  from langchain.llms import HuggingFacePipeline
12
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
import pypdf
from langchain.prompts import PromptTemplate
from huggingface_hub import (
    HfApi,
    create_repo,
    file_exists,
    hf_hub_download,
    repo_exists,
    upload_file,
    upload_folder,
)
import shutil
import torch
from langchain_huggingface import HuggingFacePipeline

# Forward the Space secret to the env var LangChain/HF integrations expect.
# Guarded because os.environ[...] = None raises TypeError when HF_TOKEN is unset.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token

# Embedding model used for FAISS indexing and querying.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Load the tokenizer once so pad_token_id comes from the real model config.
# (The previous hard-coded fallback 128001 is Llama-3's eos id, not TinyLlama's.)
_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
_tokenizer = AutoTokenizer.from_pretrained(_model_id)

# Single text-generation pipeline (the model was previously loaded twice:
# once with defaults and once with the tuned settings; only the second was used).
pipe = pipeline(
    "text-generation",
    model=_model_id,
    tokenizer=_tokenizer,
    device_map="auto" if torch.cuda.is_available() else None,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=1.15,
    pad_token_id=_tokenizer.eos_token_id,
    trust_remote_code=True,
)

# LangChain wrapper around the HF pipeline; shared by the QA chain.
llm = HuggingFacePipeline(pipeline=pipe)

# QA chain shared across Gradio UI events; built lazily by bePrepare().
qa_chain = None

# HF dataset repo that stores the FAISS index (index.faiss + index.pkl).
repo_id = "manabb/nrl"
87
def _split_pdf(file, loader_cls):
    """Load *file* with *loader_cls* and split it into 1000-char chunks (200 overlap)."""
    documents = loader_cls(file).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return splitter.split_documents(documents)


def _upload_index(repo_id, local_dir="temp_faiss"):
    """Upload a locally-saved FAISS index (index.faiss + index.pkl) to a HF dataset repo."""
    api = HfApi(token=os.getenv("HF_TOKEN"))
    for fname in ("index.faiss", "index.pkl"):
        api.upload_file(
            path_or_fileobj=f"{local_dir}/{fname}",
            path_in_repo=fname,
            repo_id=repo_id,
            repo_type="dataset",
        )


def create_faiss_index(repo_id, file, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Create a FAISS index from a PDF and upload it to a HF dataset repo.

    Tries PyPDFLoader first; falls back to PyMuPDFLoader if that fails.
    The temp directory is removed in all cases.

    Parameters:
        repo_id: HF dataset repo to receive index.faiss / index.pkl.
        file: path to the PDF to index.
        embedding_model: sentence-transformers model name used for embeddings.

    Returns:
        A human-readable status message (βœ… on success, ❌ on failure).
    """
    message = "Index creation started"

    try:
        # Embeddings object (not a bare model-name string) is required by FAISS.
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # Start from a clean temp directory.
        if os.path.exists("temp_faiss"):
            shutil.rmtree("temp_faiss")

        # Primary path: pypdf-based loader.
        db = FAISS.from_documents(_split_pdf(file, PyPDFLoader), embeddings)
        db.save_local("temp_faiss")
        _upload_index(repo_id)

        message = "βœ… Index created successfully with PyPDFLoader and uploaded to repo"

    except Exception as e1:
        try:
            print(f"PyPDFLoader failed: {e1}")

            # Fallback path: PyMuPDF handles some PDFs pypdf cannot.
            embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
            db = FAISS.from_documents(_split_pdf(file, PyMuPDFLoader), embeddings)
            db.save_local("temp_faiss")
            _upload_index(repo_id)

            message = f"βœ… PyPDFLoader failed ({e1}), PyMuPDFLoader succeeded and uploaded to repo"

        except Exception as e2:
            message = f"❌ Both loaders failed. PyPDF: {e1}, PyMuPDF: {e2}"

    finally:
        # Always clean up the temp index directory.
        if os.path.exists("temp_faiss"):
            shutil.rmtree("temp_faiss")

    return message
145
+
146
+ # Usage
147
+ #result = create_faiss_index("your_username/your-dataset", "path/to/your/file.pdf")
148
+ #print(result)
149
+ #=============
150
def update_faiss_from_hf(repo_id, file, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Load existing FAISS from HF, add new docs, push updated version.

    Parameters:
        repo_id: HF dataset repo holding index.faiss / index.pkl.
        file: path to the new PDF to merge into the index.
        embedding_model: sentence-transformers model name; must match the
            model the stored index was built with, or similarity search
            results will be meaningless — TODO confirm this matches creation.

    Returns:
        A multi-line human-readable status message.
    """
    message = ""

    try:
        # Embeddings used both to load the stored index and to embed new chunks.
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # Download both index files; index.pkl is fetched only so it sits
        # next to index.faiss in the cache dir for FAISS.load_local().
        print("Downloading existing FAISS index...")
        faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss", repo_type="dataset")
        hf_hub_download(repo_id=repo_id, filename="index.pkl", repo_type="dataset")

        # Pickle deserialization is only acceptable because we trust our own repo.
        vectorstore = FAISS.load_local(
            folder_path=os.path.dirname(faiss_path),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
        )
        message += f"βœ… Loaded existing index with {vectorstore.index.ntotal} vectors\n"

        # Load the new document, falling back to PyMuPDF if pypdf fails.
        documents = None
        for loader_name, LoaderClass in (("PyPDFLoader", PyPDFLoader), ("PyMuPDFLoader", PyMuPDFLoader)):
            try:
                print(f"Trying {loader_name}...")
                documents = LoaderClass(file).load()
                message += f"βœ… Loaded {len(documents)} pages with {loader_name}\n"
                break
            except Exception as e:
                message += f"❌ {loader_name} failed: {str(e)[:100]}...\n"

        if documents is None:
            return "❌ All PDF loaders failed"

        # Split into the same chunking scheme used at index-creation time.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        new_docs = text_splitter.split_documents(documents)
        message += f"βœ… Created {len(new_docs)} chunks from new document\n"

        # Merge the new chunks into the existing index.
        vectorstore.add_documents(new_docs)
        message += f"βœ… Added to index. New total: {vectorstore.index.ntotal} vectors\n"

        # Save locally, then push both files back to the dataset repo.
        temp_dir = "temp_faiss_update"
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        vectorstore.save_local(temp_dir)

        api = HfApi(token=os.getenv("HF_TOKEN"))
        for fname in ("index.faiss", "index.pkl"):
            api.upload_file(
                path_or_fileobj=f"{temp_dir}/{fname}",
                path_in_repo=fname,
                repo_id=repo_id,
                repo_type="dataset",
            )

        message += f"βœ… Successfully updated repo with {len(new_docs)} new chunks!"

    except Exception as e:
        message += f"❌ Update failed: {str(e)}"

    finally:
        # Always remove the temp save directory.
        if os.path.exists("temp_faiss_update"):
            shutil.rmtree("temp_faiss_update")

    return message
234
+
235
+ # Usage
236
+ # result = update_faiss_from_hf("yourusername/my-faiss-store", "new_document.pdf")
237
+ # print(result)
238
+ #====================
239
def upload_and_prepare(file, user):
    """Gradio handler: (re)build the shared FAISS index from an uploaded PDF.

    Parameters:
        file: filepath of the uploaded PDF (gr.File with type="filepath").
        user: password typed by the uploader.

    Returns:
        Status message for the authorization textbox.
    """
    # NOTE(review): shared plain-text password in source code; consider
    # HF Spaces auth or a secret-based check instead.
    if user != "manab251225":
        return "❌ Unauthorized User"

    # One hub query instead of the original two mutually-exclusive
    # file_exists() calls (each one is a network round-trip).
    if file_exists(repo_id=repo_id, filename="index.faiss", repo_type="dataset"):
        return update_faiss_from_hf(repo_id, file)
    return create_faiss_index(repo_id, file)
252
+ #create_faiss_index(repo_id, file_input)
253
+ #======================================================================
254
+
255
def generate_qa_chain(repo_id, embedding_model="sentence-transformers/all-MiniLM-L6-v2", llm=None):
    """Build a RetrievalQA chain from the FAISS index stored in a HF dataset repo.

    Parameters:
        repo_id: HF dataset repo holding index.faiss / index.pkl.
        embedding_model: sentence-transformers model name; must match the
            model the stored index was built with — TODO confirm.
        llm: LangChain LLM to answer with (the module-level `llm` is passed
            by callers).

    Returns:
        The RetrievalQA chain, or None if anything fails (callers check
        for None before use).
    """
    try:
        # Embeddings object (not a model-name string) is required by load_local.
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

        # Download both index files; index.pkl is fetched only so it sits
        # next to index.faiss in the cache dir for FAISS.load_local().
        faiss_path = hf_hub_download(
            repo_id=repo_id,
            filename="index.faiss",
            repo_type="dataset",
        )
        hf_hub_download(
            repo_id=repo_id,
            filename="index.pkl",
            repo_type="dataset",
        )

        # Pickle deserialization is only acceptable because we trust our own repo.
        vectorstore = FAISS.load_local(
            folder_path=os.path.dirname(faiss_path),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
        )

        # Top-3 chunks per query.
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

        # Grounded-answer prompt: forces citation and a fixed fallback phrase.
        prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
Answer strictly based on the context below.
Mention rule number / circular reference.
Add interpretation.

If answer is not found, say "Not available in the provided context".

Question: {question}

Context: {context}

Answer:
""",
        )

        # "stuff" chain: concatenates retrieved chunks into one prompt.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt_template},
            retriever=retriever,
            return_source_documents=True,
        )
    except Exception as e:
        print(f"Error in generate_qa_chain: {e}")
        return None
    return qa_chain
316
 
317
+ # Usage example:
318
+ # llm = HuggingFacePipeline(...) # Your LLM setup
319
+ # qa = generate_qa_chain("your_username/your-dataset", llm=llm)
320
+ # result = qa.invoke({"query": "What is the main rule?"})
321
+ # print(result["result"])
322
 
323
+ #============================
324
def bePrepare():
    """Gradio handler: build/refresh the global QA chain from the hub index.

    Returns a status string for the status textbox (the original returned
    None, which left the wired-up Textbox output empty). Uses the module
    `repo_id` instead of re-hardcoding "manabb/nrl".
    """
    global qa_chain
    qa_chain = generate_qa_chain(repo_id, llm=llm)
    if qa_chain is None:
        return "❌ Could not load the index. Please upload a document first."
    return "βœ… Resources updated. You can now ask questions!"
 
 
327
 
328
def ask_question(query):
    """Answer *query* with the prepared QA chain; returns the answer text.

    Typos fixed in the guard message: "clik" -> "click", "udated" -> "updated".
    """
    if not qa_chain:
        return "❌ Please click the button to get the updated resources first."
    response = qa_chain.invoke({"query": query})
    return response["result"]
333
 
334
#====================
# Gradio UI (typos fixed in user-facing labels: "Clik" -> "Click",
# "udated" -> "updated", "deptd" -> "dept", "pls" -> "please").
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 For use of NRL procurement department Only")

    # Status + refresh button for (re)loading the QA chain from the hub.
    with gr.Row():
        Index_processing_output = gr.Textbox(label="πŸ“ Status", interactive=False)
        Index_processing_btn = gr.Button("πŸ”„ Click to get the updated resources")

    # Question entry.
    with gr.Row():
        query_input = gr.Textbox(label="❓ This is for NRL commercial procurement dept. Your question please")
        query_btn = gr.Button("🧠 Get Answer")

    answer_output = gr.Textbox(label="βœ… Answer", lines=10)

    # Password-gated document upload (authorized users only).
    output_msg = gr.Textbox(label="πŸ“ Authorization Message", interactive=False)
    with gr.Row():
        file_input = gr.File(label="πŸ“„ Upload .pdf File by only authorized user", type="filepath")
        upload_btn = gr.Button("πŸ”„ Process Doc")
        authorized_user = gr.Textbox(label="Write the password to upload new Circular Doc.")
        upload_btn.click(upload_and_prepare, inputs=[file_input, authorized_user], outputs=output_msg)

    query_btn.click(ask_question, inputs=query_input, outputs=answer_output)
    Index_processing_btn.click(bePrepare, inputs=None, outputs=Index_processing_output)


# For local dev use: demo.launch()
# For HF Spaces
if __name__ == "__main__":
    demo.launch()
368