|
|
import os |
|
|
import tempfile |
|
|
import streamlit as st |
|
|
|
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.prompts import PromptTemplate |
|
|
from langchain_groq import GroqLLM |
|
|
|
|
|
|
|
|
# API credentials, read from the environment with placeholder fallbacks so the
# script still starts (authentication will simply fail) when the vars are unset.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your-groq-api-key")
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY", "your-huggingface-api-key")
|
|
|
|
|
|
|
|
# Groq-hosted LLM that will answer questions over the retrieved context.
# NOTE(review): the public `langchain_groq` package exports `ChatGroq`, not
# `GroqLLM` — confirm this name resolves in the pinned package version,
# otherwise the import at the top of the file fails at startup.
llm = GroqLLM(
    api_key=GROQ_API_KEY,
    model="llama3-8b-8192",  # Llama 3 8B, 8192-token context window
    temperature=0.1  # near-deterministic output for factual Q&A
)
|
|
|
|
|
|
|
|
# Embedding model used to vectorise PDF chunks for the FAISS index.
embedding = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # small, fast sentence encoder
    cache_folder="./hf_cache",  # cache downloaded model weights next to the app
    # NOTE(review): `huggingfacehub_api_token` is a parameter of the *Hub*
    # (hosted-inference) embeddings class; the local HuggingFaceEmbeddings
    # wrapper is not documented to accept it and may raise a validation
    # error — confirm against the pinned langchain_community version.
    huggingfacehub_api_token=HUGGINGFACE_API_KEY
)
|
|
|
|
|
|
|
|
# --- Streamlit UI ----------------------------------------------------------
# NOTE(review): the "π" characters in the UI strings look like mojibake of
# emoji from a bad encoding round-trip — confirm the intended glyphs and
# re-save the file as UTF-8 (left byte-identical here).
st.title("π RAG Chat with Groq + HuggingFace")

# Inputs: the document to index and the question to ask about it.
uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
user_query = st.text_input("Ask something about the document")
submit_button = st.button("Submit")
|
|
|
|
|
# Main RAG pipeline: runs once per click, after a PDF has been uploaded.
if uploaded_file and submit_button:
    if not user_query:
        # Guard: an empty question would still pay for a full embed + LLM
        # round trip, then produce a useless answer.
        st.warning("Please enter a question about the document.")
    else:
        # PyPDFLoader needs a real filesystem path, so spill the uploaded
        # bytes to a temporary file first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(uploaded_file.read())
            tmp_path = tmp_file.name

        try:
            # Load the PDF and split it into retrievable chunks.
            loader = PyPDFLoader(tmp_path)
            pages = loader.load_and_split()

            # Build an in-memory FAISS index over the chunks.
            # NOTE(review): this re-embeds the whole document on every submit;
            # consider st.session_state / caching to index each upload once.
            vectorstore = FAISS.from_documents(pages, embedding)
            retriever = vectorstore.as_retriever()

            # Prompt that stuffs the retrieved context ahead of the question.
            prompt_template = PromptTemplate(
                input_variables=["context", "question"],
                template="""
You are an intelligent assistant. Use the following context to answer the question accurately.
Context: {context}
Question: {question}
Answer:"""
            )

            # "stuff" chain: concatenates all retrieved chunks into a single
            # prompt and asks the LLM once.
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": prompt_template}
            )

            # Run retrieval + generation for the user's question.
            result = qa_chain({"query": user_query})
            st.markdown("### π¬ Answer")
            st.write(result["result"])

            # Show which chunks grounded the answer.
            with st.expander("π Sources"):
                # NOTE(review): i is the chunk index, not the PDF page number;
                # doc.metadata.get("page") would give the real page — confirm.
                for i, doc in enumerate(result["source_documents"]):
                    st.write(f"**Page {i+1}** β {doc.metadata.get('source', 'Unknown')}")
        finally:
            # Fix: delete=False means the OS never cleans this file up; the
            # original leaked one temp PDF per submission. Remove it once the
            # index has been built and the answer rendered.
            os.remove(tmp_path)
|
|
|