# Chat_with_PDF / app.py
# Author: ChiragKaushikCK
# (Hosted on Hugging Face Spaces; header lines converted to comments so the file parses.)
import streamlit as st
import os
import tempfile
import pandas as pd
# LangChain Imports
from langchain_community.document_loaders import (
PyMuPDFLoader,
CSVLoader,
TextLoader,
Docx2txtLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_groq import ChatGroq
# ---------------------------------------------------
# PAGE CONFIG
# ---------------------------------------------------
# Page chrome. set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="πŸ“—πŸ’¬ DocTalk- Chat With Docs",
    page_icon="πŸ“—πŸ’¬",
    layout="wide",
)
st.title("πŸ“—πŸ’¬ DocTalk - Chat With Your Documents")
# ---------------------------------------------------
# SESSION STATE
# ---------------------------------------------------
# One-time initialization of the session keys this app reads across reruns:
#   qa_chain  - the retrieval chain built from the uploaded document
#   messages  - chat transcript replayed on every rerun
#   processed - whether a document has been indexed yet
for _key, _default in (("qa_chain", None), ("messages", []), ("processed", False)):
    st.session_state.setdefault(_key, _default)
# ---------------------------------------------------
# LOAD EMBEDDINGS
# ---------------------------------------------------
@st.cache_resource
def load_embeddings():
    """Return the HuggingFace embedding model used to index document chunks.

    Cached by Streamlit so the model is loaded once per server process,
    not on every rerun.
    """
    # Small, fast sentence-transformer suitable for CPU-only Spaces.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    return HuggingFaceEmbeddings(model_name=model_name)
# ---------------------------------------------------
# LOAD GROQ LLM
# ---------------------------------------------------
@st.cache_resource
def load_llm():
    """Return the Groq chat model, cached per server process.

    Raises:
        RuntimeError: if the GROQ_API_KEY environment variable is unset.
            Failing fast here gives a clear message instead of the cryptic
            authentication error Groq would raise on the first query.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; "
            "add it to the Space secrets / environment."
        )
    return ChatGroq(
        groq_api_key=api_key,
        model_name="llama-3.1-8b-instant",
        temperature=0,
    )
# ---------------------------------------------------
# FILE ROUTER
# ---------------------------------------------------
def load_file(path, filename):
    """Load the file at *path* and return a list of LangChain Documents.

    The loader is chosen by the extension of *filename* (not *path*,
    because the temp file the caller writes may carry a different suffix).

    Args:
        path: filesystem path to read.
        filename: original upload name; only its extension is used.

    Returns:
        list of langchain Document objects.

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    ext = os.path.splitext(filename)[1].lower()
    if ext == ".pdf":
        return PyMuPDFLoader(path).load()
    if ext == ".csv":
        return CSVLoader(file_path=path, encoding="utf-8").load()
    if ext == ".docx":
        return Docx2txtLoader(path).load()
    if ext in (".txt", ".py", ".json", ".md"):
        return TextLoader(path, encoding="utf-8").load()
    if ext in (".xlsx", ".xls"):
        # Excel via pandas: avoids an extra unstructured-loader dependency.
        df = pd.read_excel(path)
        from langchain_core.documents import Document
        # Attach source metadata for parity with the other loaders, which
        # record where each Document came from.
        return [Document(page_content=df.to_string(), metadata={"source": filename})]
    raise ValueError(f"Unsupported file format: {ext}")
# ---------------------------------------------------
# PROCESS DOCUMENT
# ---------------------------------------------------
def process_document(uploaded_file):
    """Index an uploaded file and return a retrieval QA chain over it.

    The upload is spooled to a temp file (the LangChain loaders need a real
    filesystem path), split into overlapping chunks, embedded into an
    in-memory FAISS index, and wired to the Groq LLM behind a strict
    answer-only-from-context prompt.

    Args:
        uploaded_file: a Streamlit UploadedFile.

    Returns:
        a runnable retrieval chain; invoke with {"input": question}.
    """
    ext = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name
    try:
        docs = load_file(tmp_path, uploaded_file.name)
    finally:
        # Remove the temp file even when loading fails; the previous version
        # only deleted it on the success path and leaked it on loader errors.
        os.remove(tmp_path)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,
    )
    chunks = splitter.split_documents(docs)
    embeddings = load_embeddings()
    vector_store = FAISS.from_documents(chunks, embeddings)
    llm = load_llm()
    # The prompt pins the model to the retrieved context and gives it a
    # fixed refusal sentence, reducing hallucinated answers.
    system_prompt = """
You are a professional assistant analyzing a document.
Answer ONLY using the provided context.
If the answer cannot be found say:
"I cannot find the answer in this document."
Context:
{context}
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    return create_retrieval_chain(
        # k=4: stuff the top 4 chunks into the prompt per query.
        vector_store.as_retriever(search_kwargs={"k": 4}),
        create_stuff_documents_chain(llm, prompt),
    )
# ---------------------------------------------------
# SIDEBAR
# ---------------------------------------------------
with st.sidebar:
    st.header("βš™οΈ Settings")
    uploaded_file = st.file_uploader(
        "Upload Document",
        type=["pdf","csv","xlsx","xls","docx","txt","py","json","md"],
    )
    # Index only on explicit button press so re-runs don't re-embed.
    if uploaded_file and st.button("πŸš€ Process Document"):
        with st.spinner("Processing document..."):
            st.session_state.qa_chain = process_document(uploaded_file)
            st.session_state.processed = True
            st.success("Document indexed successfully!")
    # Offer chat reset once a document has been indexed.
    if st.session_state.processed and st.button("πŸ—‘ Clear Chat"):
        st.session_state.messages = []
        st.rerun()
# ---------------------------------------------------
# CHAT UI
# ---------------------------------------------------
if st.session_state.processed:
    # Replay the stored transcript so history survives Streamlit reruns.
    for past in st.session_state.messages:
        with st.chat_message(past["role"]):
            st.markdown(past["content"])
    question = st.chat_input("Ask something about your document")
    if question:
        st.session_state.messages.append(
            {"role": "user", "content": question}
        )
        with st.chat_message("user"):
            st.markdown(question)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                result = st.session_state.qa_chain.invoke(
                    {"input": question}
                )
                answer = result["answer"]
                sources = result["context"]
                st.markdown(answer)
                # Show the retrieved chunks so users can verify the answer.
                if sources:
                    with st.expander("πŸ”Ž Sources"):
                        for idx, doc in enumerate(sources, start=1):
                            st.caption(
                                f"Chunk {idx}: {doc.page_content[:300]}..."
                            )
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": answer,
                "sources": sources,
            }
        )
else:
    st.info("Upload a document from the sidebar to start chatting.")
# ---------------------------------------------------
# FOOTER
# ---------------------------------------------------
# Static footer credit rendered at the bottom of every page run.
st.markdown(
"""
---
Built by **Chirag Kaushik**
"""
)