# app.py — News Research Tool (Streamlit). Author: MahatirTusher, revision ec74dc1.
import streamlit as st
from dotenv import load_dotenv
from langchain_community.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import time
from langchain_groq import ChatGroq
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# SECURITY FIX: the previous revision hard-coded a live Groq API key in
# source control. That key must be treated as compromised and rotated.
# The key is now read from the environment (or the .env file loaded above).
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    st.error("GROQ_API_KEY is not set. Add it to your environment or .env file.")
    st.stop()
# ---- Page scaffolding ------------------------------------------------------
st.title("News Research Tool πŸ“ˆ")
st.sidebar.title("News Article URLs")

# Persist across Streamlit reruns whether a FAISS index has been built.
if "index_created" not in st.session_state:
    st.session_state.index_created = False

# Collect up to three article URLs from the sidebar, keeping only non-empty ones.
urls = [u for u in (st.sidebar.text_input(f"URL {i+1}") for i in range(3)) if u]

process_url_clicked = st.sidebar.button("Process URLs")
faiss_index_path = "faiss_index"

# One placeholder reused for progress text, errors, and the question input.
main_placeholder = st.empty()
# Groq-hosted Llama 3 70B chat model that answers the user's questions.
llm = ChatGroq(
    model="llama3-70b-8192",
    api_key=GROQ_API_KEY,
)
def save_faiss_index(vectorstore, path):
    """Persist the given FAISS vector store to *path* on local disk."""
    vectorstore.save_local(path)
def load_faiss_index(path, embeddings):
    """Load a previously saved FAISS index from *path*.

    allow_dangerous_deserialization is required because FAISS persists via
    pickle; only indexes created by this app itself should be loaded.
    """
    store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    return store
# Build and persist the FAISS index when the user clicks "Process URLs".
if process_url_clicked:
    if not urls:
        main_placeholder.error("Please provide at least one valid URL.")
    else:
        try:
            main_placeholder.text("Data Loading...Started...βœ…βœ…βœ…")
            documents = WebBaseLoader(urls).load()

            # Abort when nothing (or only whitespace) could be fetched.
            if not documents or not any(doc.page_content.strip() for doc in documents):
                main_placeholder.error("No content loaded from URLs. Try different URLs.")
                st.stop()

            main_placeholder.text("Text Splitter...Started...βœ…βœ…βœ…")
            splitter = RecursiveCharacterTextSplitter(
                separators=['\n\n', '\n', '.', ','],
                chunk_size=1000,
            )
            chunks = splitter.split_documents(documents)
            main_placeholder.text(f"Split into {len(chunks)} document chunks.")

            main_placeholder.text("Embedding Vector Started Building...βœ…βœ…βœ…")
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
            vectorstore = FAISS.from_documents(chunks, embeddings)
            save_faiss_index(vectorstore, faiss_index_path)
            st.session_state.index_created = True

            main_placeholder.text("FAISS index saved successfully! βœ…βœ…βœ…")
            time.sleep(2)  # let the success message stay visible briefly
            main_placeholder.empty()
        except Exception as e:
            main_placeholder.error(f"Error processing URLs: {str(e)}")
# Question-answering flow: runs on every rerun where the user typed a question.
query = main_placeholder.text_input("Question: ")
if query:
    if not st.session_state.index_created or not os.path.exists(faiss_index_path):
        main_placeholder.error("No FAISS index found. Please process URLs first.")
    else:
        with st.spinner("Processing your question..."):
            try:
                embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
                vectorstore = load_faiss_index(faiss_index_path, embeddings)
                qa_chain = RetrievalQAWithSourcesChain.from_llm(llm=llm, retriever=vectorstore.as_retriever())
                result = qa_chain({"question": query}, return_only_outputs=True)

                if not result.get("answer"):
                    main_placeholder.warning("No answer generated. Try a different question or URLs.")
                    st.stop()

                st.header("Answer")
                st.write(result["answer"])

                # The chain returns sources as one newline-separated string.
                sources = result.get("sources", "")
                if sources:
                    st.subheader("Sources:")
                    for source in sources.split("\n"):
                        st.write(source)
                else:
                    st.write("No sources found.")
            except Exception as e:
                main_placeholder.error(f"Error answering query: {str(e)}")