import os
import tempfile

import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
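
# NOTE: assumed dependencies (not pinned here): streamlit, langchain,
# langchain-community, langchain-groq, pypdf, faiss-cpu, sentence-transformers.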

# --- Environment Variable Setup ---
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "your-groq-api-key")
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "your-huggingface-api-key")

# --- Groq LLM Initialization ---
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama3-8b-8192",
    temperature=0.1,
)

# --- HuggingFace Embeddings (local sentence-transformers model) ---
# all-MiniLM-L6-v2 is a public model that runs locally, so no HuggingFace
# API token is needed for the embeddings themselves.
embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    cache_folder="./hf_cache",
)

# --- Streamlit UI ---
st.title("📄 RAG Chat with Groq + HuggingFace")
uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
user_query = st.text_input("Ask something about the document")
submit_button = st.button("Submit")

if uploaded_file and submit_button:
    # Persist the upload to a temporary file so PyPDFLoader can read it from disk.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(uploaded_file.read())
        tmp_path = tmp_file.name

    # --- Load and Split PDF ---
    loader = PyPDFLoader(tmp_path)
    pages = loader.load_and_split()

    # --- FAISS Vector Store ---
    vectorstore = FAISS.from_documents(pages, embedding)
    retriever = vectorstore.as_retriever()
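    # The default retriever returns the 4 most similar chunks; if answers miss
    # relevant context, the amount of retrieved context can be tuned, e.g.:
    #   retriever = vectorstore.as_retriever(search_kwargs={"k": 6})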

    # --- Optional Custom Prompt ---
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an intelligent assistant. Use the following context to answer the question accurately.
Context: {context}
Question: {question}
Answer:""",
    )

    # --- RetrievalQA Chain ---
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt_template},
    )
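    # "stuff" concatenates all retrieved chunks into a single prompt, which is
    # fine for a handful of chunks; "map_reduce" or "refine" are alternatives
    # when the combined context would exceed the model's context window.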

    # --- Run the Chain ---
    # invoke() returns a dict with "result" and, because
    # return_source_documents=True, "source_documents".
    result = qa_chain.invoke({"query": user_query})

    st.markdown("### 💬 Answer")
    st.write(result["result"])

    # --- Optional: Show Source Documents ---
    with st.expander("📄 Sources"):
        for i, doc in enumerate(result["source_documents"]):
            # PyPDFLoader records the file path in "source" and the 0-based page index in "page".
            page = doc.metadata.get("page", "?")
            st.write(f"**Source {i + 1}** (page {page}) — {doc.metadata.get('source', 'Unknown')}")