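"""Semantic_Search app.py

Gradio app implementing a small RAG pipeline: it loads a news article URL or an
uploaded document, splits the text into chunks, embeds the chunks with Google
Generative AI embeddings, stores them in a FAISS index (pickled to disk), and
answers questions with Gemini through RetrievalQAWithSourcesChain.
"""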
import os
import pickle
import time
import gradio as gr
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import UnstructuredURLLoader, PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredHTMLLoader
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.vectorstores import FAISS
load_dotenv()
# The Gemini key is read from the environment; this assumes GOOGLE_API_KEY is set in .env.
api_key = os.getenv("GOOGLE_API_KEY")

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001", google_api_key=api_key, temperature=0.5)

# Path where the pickled FAISS index is persisted between the two Gradio callbacks.
vector_db_path = "vector_db.pkl"
def load_any_file(file_path):
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".pdf":
        loader = PyPDFLoader(file_path)
    elif ext == ".txt":
        loader = TextLoader(file_path)
    elif ext == ".docx":
        loader = Docx2txtLoader(file_path)
    elif ext in [".html", ".htm"]:
        loader = UnstructuredHTMLLoader(file_path)
    else:
        raise ValueError(f"Unsupported file type: {ext}")
    return loader.load()
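# Usage sketch (hypothetical path): load_any_file("report.pdf") returns a list of
# LangChain Document objects, which feed the text splitter in process_inputs below.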
def process_inputs(url, file):
    data = []
    if url:
        loader = UnstructuredURLLoader(urls=[url])
        data.extend(loader.load())
    if file:
        data.extend(load_any_file(file.name))
    if not data:
        return "Please provide at least a URL or upload a document.", ""
    text_splitter = RecursiveCharacterTextSplitter(
        separators=['\n\n', '\n', '.', ','],
        chunk_size=1000
    )
    docs = text_splitter.split_documents(data)
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    vectorstore = FAISS.from_documents(docs, embeddings)
    # Persist the FAISS index so answer_question can reload it later.
    with open(vector_db_path, "wb") as f:
        pickle.dump(vectorstore, f)
    return "Document processing and vector storage completed. You may now ask your question.", ""
def answer_question(query):
    if not os.path.exists(vector_db_path):
        return "Please process a document or URL first.", ""
    with open(vector_db_path, "rb") as f:
        vectorstore = pickle.load(f)
    chain = RetrievalQAWithSourcesChain.from_llm(llm=llm, retriever=vectorstore.as_retriever())
    result = chain({"question": query}, return_only_outputs=True)
    answer = result.get("answer", "No answer generated.")
    sources = result.get("sources", "No sources provided.")
    return answer, sources
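# For illustration only: result is expected to be a dict keyed by the chain's output names,
# e.g. {"answer": "...generated answer...", "sources": "https://example.com/article"}.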
with gr.Blocks(title="RockyBot: News Research Tool") as demo:
    gr.Markdown("## 📰 RockyBot: Research News Articles via URL or Upload")
    with gr.Row():
        url_input = gr.Textbox(label="News Article URL", placeholder="Paste a single article URL here")
        file_input = gr.File(label="Upload Document", file_types=[".pdf", ".txt", ".docx", ".html", ".htm"])
    with gr.Row():
        process_btn = gr.Button("Process Document or URL")
        process_status = gr.Textbox(label="Status Message", interactive=False)
    with gr.Row():
        query_input = gr.Textbox(label="Ask a Question")
        answer_output = gr.Textbox(label="Answer")
        sources_output = gr.Textbox(label="Sources")
    process_btn.click(fn=process_inputs, inputs=[url_input, file_input], outputs=[process_status, answer_output])
    query_input.submit(fn=answer_question, inputs=query_input, outputs=[answer_output, sources_output])
demo.launch()