# NRL-Chat-bot / app.py
# (Hugging Face Space file header — last commit: "Update app.py", 045ccf2, verified)
# app.py
import os
import gradio as gr
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
import pypdf
from langchain.prompts import PromptTemplate
from huggingface_hub import upload_folder
from huggingface_hub import HfApi, upload_file
from huggingface_hub import hf_hub_download
from huggingface_hub import (
file_exists,
upload_file,
repo_exists,
create_repo,
hf_hub_download
)
import shutil
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline
# Forward the Space's HF token to the Hub client libraries.
# BUG FIX: os.environ values must be strings — assigning None (when
# HF_TOKEN is unset) raised TypeError at import time. Guard first.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    os.environ['HUGGINGFACEHUB_API_TOKEN'] = _hf_token

# Embedding model used by default for index creation / loading.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# QA chains and LLM wrappers live at module level so the Gradio event
# handlers (bePrepare / ask_question and their *1 twins) can share them.
qa_chain = None
qa_chain1 = None
llm = None
llm1 = None

# HF dataset repo that stores the FAISS index files (index.faiss / index.pkl).
repo_id = os.getenv("reposit_id")

# ============================================= google/flan-t5-small
# Lightweight seq2seq model that runs comfortably on CPU.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Wrap in a transformers pipeline, then in a LangChain LLM.
pipe1 = pipeline("text2text-generation", model=model, tokenizer=tokenizer, max_length=512)
# (The original guarded this with `if llm1 is None`, which was always true
# right after `llm1 = None` above — assign directly.)
llm1 = HuggingFacePipeline(pipeline=pipe1)
# ============================================= TinyLlama/TinyLlama-1.1B-Chat-v1.0
# Heavier chat model: fp16 on GPU when available, fp32 on CPU otherwise.
_tinyllama_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Load the tokenizer once and reuse it for both the pipeline and pad id.
_tinyllama_tokenizer = AutoTokenizer.from_pretrained(_tinyllama_name)
pipe = pipeline(
    "text-generation",
    model=_tinyllama_name,
    tokenizer=_tinyllama_tokenizer,
    device_map="auto" if torch.cuda.is_available() else None,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.95,
    do_sample=True,
    repetition_penalty=1.15,
    # BUG FIX: the original used `tokenizer.eos_token_id` — the *flan-t5*
    # tokenizer defined earlier in this file — or a hard-coded 128001
    # (a Llama-3 token id). Padding must use TinyLlama's own EOS token.
    pad_token_id=_tinyllama_tokenizer.eos_token_id,
    trust_remote_code=True
)
# LangChain wrapper around the raw transformers pipeline.
# (Direct assignment: the old `if llm is None` guard was always true.)
llm = HuggingFacePipeline(pipeline=pipe)
def _split_embed_upload(documents, embeddings, repo_id):
    """Chunk *documents*, build a FAISS index under temp_faiss/, and upload
    both index files to the HF dataset repo. Caller owns cleanup of temp_faiss.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    new_docs = text_splitter.split_documents(documents)
    db = FAISS.from_documents(new_docs, embeddings)
    db.save_local("temp_faiss")
    api = HfApi(token=os.getenv("HF_TOKEN"))
    api.upload_file(path_or_fileobj="temp_faiss/index.faiss", path_in_repo="index.faiss", repo_id=repo_id, repo_type="dataset")
    api.upload_file(path_or_fileobj="temp_faiss/index.pkl", path_in_repo="index.pkl", repo_id=repo_id, repo_type="dataset")


def create_faiss_index(repo_id, file, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Create a FAISS index from a PDF and upload it to a HF dataset repo.

    Tries PyPDFLoader first, falls back to PyMuPDFLoader on any failure.

    Args:
        repo_id: HF dataset repo that receives index.faiss / index.pkl.
        file: local path to the PDF.
        embedding_model: sentence-transformers model name used to embed chunks.

    Returns:
        A human-readable status message; never raises.
    """
    message = "Index creation started"
    try:
        # A HuggingFaceEmbeddings *object* is required (not a model-name string).
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # Start from a clean temp directory.
        if os.path.exists("temp_faiss"):
            shutil.rmtree("temp_faiss")
        documents = PyPDFLoader(file).load()
        _split_embed_upload(documents, embeddings, repo_id)
        message = "βœ… Index created successfully with PyPDFLoader and uploaded to repo"
    except Exception as e1:
        try:
            print(f"PyPDFLoader failed: {e1}")
            # Fallback loader; rebuild embeddings in case their creation failed above.
            documents = PyMuPDFLoader(file).load()
            embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
            _split_embed_upload(documents, embeddings, repo_id)
            message = f"βœ… PyPDFLoader failed ({e1}), PyMuPDFLoader succeeded and uploaded to repo"
        except Exception as e2:
            message = f"❌ Both loaders failed. PyPDF: {e1}, PyMuPDF: {e2}"
    finally:
        # BUG FIX: cleanup only — the original `return message` inside
        # `finally` would also swallow any in-flight exception.
        if os.path.exists("temp_faiss"):
            shutil.rmtree("temp_faiss")
    return message
# Usage
#result = create_faiss_index("your_username/your-dataset", "path/to/your/file.pdf")
#print(result)
#=============
def update_faiss_from_hf(repo_id, file, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Load the existing FAISS index from the HF dataset repo, append the
    chunks of a new PDF, and push the updated index back.

    Args:
        repo_id: HF dataset repo holding index.faiss / index.pkl.
        file: local path to the new PDF.
        embedding_model: must match the model used when the index was created.

    Returns:
        A multi-line status message describing each step; never raises.
    """
    message = ""
    try:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # Download both index files. index.pkl's path is unused but the
        # download is required: FAISS.load_local expects it alongside
        # index.faiss in the same cache folder.
        print("Downloading existing FAISS index...")
        faiss_path = hf_hub_download(repo_id=repo_id, filename="index.faiss", repo_type="dataset")
        hf_hub_download(repo_id=repo_id, filename="index.pkl", repo_type="dataset")
        folder_path = os.path.dirname(faiss_path)
        # Pickle deserialization is opted into explicitly — acceptable because
        # the index comes from our own repo.
        vectorstore = FAISS.load_local(
            folder_path=folder_path,
            embeddings=embeddings,
            allow_dangerous_deserialization=True
        )
        message += f"βœ… Loaded existing index with {vectorstore.index.ntotal} vectors\n"
        # Load the new document, trying the fast loader first.
        documents = None
        loaders = [
            ("PyPDFLoader", PyPDFLoader),
            ("PyMuPDFLoader", PyMuPDFLoader)
        ]
        for loader_name, LoaderClass in loaders:
            try:
                print(f"Trying {loader_name}...")
                loader = LoaderClass(file)
                documents = loader.load()
                message += f"βœ… Loaded {len(documents)} pages with {loader_name}\n"
                break
            except Exception as e:
                message += f"❌ {loader_name} failed: {str(e)[:100]}...\n"
                continue
        if documents is None:
            # BUG FIX: this early return used to be silently overridden by a
            # `return message` inside the `finally` block; it is now actually
            # delivered to the caller (cleanup in `finally` still runs).
            return "❌ All PDF loaders failed"
        # Split, embed and append to the existing index.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        new_docs = text_splitter.split_documents(documents)
        message += f"βœ… Created {len(new_docs)} chunks from new document\n"
        vectorstore.add_documents(new_docs)
        message += f"βœ… Added to index. New total: {vectorstore.index.ntotal} vectors\n"
        # Save locally, then push both files back to the Hub.
        temp_dir = "temp_faiss_update"
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        vectorstore.save_local(temp_dir)
        api = HfApi(token=os.getenv("HF_TOKEN"))
        api.upload_file(
            path_or_fileobj=f"{temp_dir}/index.faiss",
            path_in_repo="index.faiss",
            repo_id=repo_id,
            repo_type="dataset"
        )
        api.upload_file(
            path_or_fileobj=f"{temp_dir}/index.pkl",
            path_in_repo="index.pkl",
            repo_id=repo_id,
            repo_type="dataset"
        )
        message += f"βœ… Successfully updated repo with {len(new_docs)} new chunks!"
    except Exception as e:
        message += f"❌ Update failed: {str(e)}"
    finally:
        # BUG FIX: cleanup only — returning from `finally` swallowed the
        # early return above and any exception in flight.
        if os.path.exists("temp_faiss_update"):
            shutil.rmtree("temp_faiss_update")
    return message
# Usage
# result = update_faiss_from_hf("yourusername/my-faiss-store", "new_document.pdf")
# print(result)
#====================
def upload_and_prepare(file, user):
    """Password-gated upload handler for the Gradio UI.

    If the repo already has an index, append to it; otherwise create one.

    Args:
        file: local path of the uploaded PDF.
        user: password typed by the uploader, checked against the
            `uploading_password` environment variable.

    Returns:
        Status message from the chosen indexing routine, or an
        authorization error.
    """
    # Guard clause: reject unauthorized users before touching the Hub.
    if user != os.getenv("uploading_password"):
        return "❌ Unauthorized User"
    # Single existence check (the original queried the Hub twice, and its
    # flat if/if/else made the else's pairing ambiguous).
    if file_exists(repo_id=repo_id, filename="index.faiss", repo_type="dataset"):
        return update_faiss_from_hf(repo_id, file)
    return create_faiss_index(repo_id, file)
#create_faiss_index(repo_id, file_input)
#======================================================================
def generate_qa_chain(repo_id, embedding_model="sentence-transformers/all-MiniLM-L6-v2", llm=None):
    """Build a RetrievalQA chain over the FAISS index stored in a HF dataset repo.

    Args:
        repo_id: dataset repo holding index.faiss / index.pkl.
        embedding_model: sentence-transformers model name; must match the
            model used when the index was created, or retrieval quality degrades.
        llm: LangChain LLM wrapper to answer with. Callers in this file always
            pass one; with the default None the chain construction fails and
            this function returns None.

    Returns:
        A RetrievalQA chain (returns source documents), or None on any failure.
    """
    try:
        # An embeddings *object* is required by FAISS.load_local (not a string).
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # Download both halves of the index into the local HF cache.
        faiss_path = hf_hub_download(
            repo_id=repo_id,
            filename="index.faiss",
            repo_type="dataset"
        )
        # pkl_path itself is unused, but the download is required so that
        # index.pkl sits next to index.faiss when load_local reads the folder.
        pkl_path = hf_hub_download(
            repo_id=repo_id,
            filename="index.pkl",
            repo_type="dataset"
        )
        folder_path = os.path.dirname(faiss_path)
        # Pickle deserialization opted into explicitly — the repo is our own.
        vectorstore = FAISS.load_local(
            folder_path=folder_path,
            embeddings=embeddings,
            allow_dangerous_deserialization=True
        )
        # Retrieve the top-3 chunks per query.
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
        # Prompt constrains the model to the retrieved context and asks for
        # rule/circular references (domain requirement of this app).
        prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
            Answer strictly based on the context below.
            Mention rule number / circular reference.
            Add interpretation.
            If answer is not found, say "Not available in the provided context".
            Question: {question}
            Context: {context}
            Answer:
            """
        )
        # "stuff" chain: all retrieved chunks are stuffed into one prompt.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            chain_type_kwargs={"prompt": prompt_template},
            retriever=retriever,
            return_source_documents=True
        )
    except Exception as e:
        # Swallow and report: the UI treats None as "not ready".
        print(f"Error in generate_qa_chain: {e}")
        return None
    return qa_chain
# Usage example:
# llm = HuggingFacePipeline(...) # Your LLM setup
# qa = generate_qa_chain("your_username/your-dataset", llm=llm)
# result = qa.invoke({"query": "What is the main rule?"})
# print(result["result"])
#============================
def bePrepare():
    """Initialise the TinyLlama-backed QA chain for ask_question."""
    global qa_chain
    # NOTE(review): repo id is hard-coded here while uploads use the
    # `reposit_id` env var — confirm both point at the same dataset.
    chain = generate_qa_chain("manabb/nrl", llm=llm)
    qa_chain = chain
    return "I am ready, ask me questions with model tiny Lama."
def bePrepare1():
    """Initialise the flan-t5-backed QA chain for ask_question1."""
    global qa_chain1
    # NOTE(review): repo id is hard-coded here while uploads use the
    # `reposit_id` env var — confirm both point at the same dataset.
    chain = generate_qa_chain("manabb/nrl", llm=llm1)
    qa_chain1 = chain
    return "I am ready, ask me questions with model google flan-t5."
def ask_question(query):
    """Answer *query* with the TinyLlama QA chain.

    Returns the model's answer, or guidance text when the chain has not been
    initialised (via bePrepare) or the query is blank.
    """
    # BUG FIX: the original set the hint message but then invoked
    # qa_chain.invoke anyway, crashing with AttributeError on None.
    # (Message spelling mirrors the UI button label.)
    if not qa_chain:
        return "❌ Please clik the button to get the udated resources with tiny Lama."
    if not query:
        return "Blank question! "
    response = qa_chain.invoke({"query": query})
    return response["result"]
def ask_question1(query):
    """Answer *query* with the flan-t5 QA chain.

    Returns the model's answer, or guidance text when the chain has not been
    initialised (via bePrepare1) or the query is blank.
    """
    # BUG FIX: the original set the hint message but then invoked
    # qa_chain1.invoke anyway, crashing with AttributeError on None.
    # (Message spelling mirrors the UI button label.)
    if not qa_chain1:
        return "❌ Please clik the button to get the udated resources google flan-t5."
    if not query:
        return "Blank question!"
    response1 = qa_chain1.invoke({"query": query})
    return response1["result"]
#====================
# Gradio UI
# Gradio UI: two side-by-side QA panels (TinyLlama / flan-t5) over the same
# index, plus a password-gated upload row at the bottom.
with gr.Blocks(css="""
#blue-col { background: linear-gradient(135deg, #667eea, #764ba2); padding: 20px; border-radius: 10px; }
#green-col { background: #4ecdc4; padding: 20px; border-radius: 10px; }
""") as demo:
    gr.Markdown("## 🧠 For use of NRL procurement department Only")
    with gr.Row():
        # LEFT COLUMN: TinyLlama QA panel.
        with gr.Column(elem_id="blue-col",scale=1):
            gr.Markdown("## 🧠 Using heavy TinyLama Model")
            with gr.Row():
                Index_processing_output=gr.Textbox(label="πŸ“ Status for tiny lama", interactive=False)
            with gr.Row():
                Index_processing_btn = gr.Button("πŸ”„ Clik to get the udated resources with tiny Lama")
            # Button downloads the index and builds the TinyLlama chain.
            Index_processing_btn.click(bePrepare, inputs=None, outputs=Index_processing_output)
            with gr.Row():
                query_input = gr.Textbox(label="❓ Your Question pls")
            with gr.Row():
                query_btn = gr.Button("🧠 Get Answer")
            with gr.Row():
                answer_output = gr.Textbox(label="βœ… Answer", lines=4)
            query_btn.click(ask_question, inputs=query_input, outputs=answer_output)
        # RIGHT COLUMN: flan-t5 QA panel (same flow, lighter model).
        with gr.Column(elem_id="green-col",scale=2):
            gr.Markdown("## 🧠 Using ligth model - google flan-t5")
            Index_processing_output1=gr.Textbox(label="πŸ“ Status for google flan-t5", interactive=False)
            Index_processing_btn1 = gr.Button("πŸ”„ Clik to get the udated resources with google flan-t5")
            Index_processing_btn1.click(bePrepare1, inputs=None, outputs=Index_processing_output1)
            query_input1 = gr.Textbox(label="❓ Your Question pls")
            query_btn1 = gr.Button("🧠 Get Answer")
            answer_output1 = gr.Textbox(label="βœ… Answer", lines=4)
            query_btn1.click(ask_question1, inputs=query_input1, outputs=answer_output1)
    # Bottom row: upload a new PDF into the shared index (password-gated).
    with gr.Row():
        gr.Markdown("## 🧠 For uploading new PDF documents.")
        output_msg = gr.Textbox(label="πŸ“ Authorization Message", interactive=False)
        file_input = gr.File(label="πŸ“„ Upload .pdf File by only authorized user", type="filepath")
        upload_btn = gr.Button("πŸ”„ Process Doc")
        authorized_user=gr.Textbox(label="Write the password to upload new Circular Doc.")
        upload_btn.click(upload_and_prepare, inputs=[file_input,authorized_user], outputs=output_msg)

# Entry point for local dev and HF Spaces alike.
if __name__ == "__main__":
    demo.launch()