Hugging Face Spaces log header (pasted along with the source): the Space reported a runtime error.
import os
import tempfile

import pandas as pd
import streamlit as st

# LangChain imports
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import (
    PyMuPDFLoader,
    CSVLoader,
    TextLoader,
    Docx2txtLoader,
)
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# ---------------------------------------------------
# PAGE CONFIG
# ---------------------------------------------------
# NOTE(review): the original emoji were mojibake ("ππ¬") from an encoding
# mix-up; restored to plausible glyphs — confirm against the deployed app.
# page_icon takes a single emoji, so only one glyph is used there.
st.set_page_config(
    page_title="📚💬 DocTalk - Chat With Docs",
    page_icon="💬",
    layout="wide",
)
st.title("📚💬 DocTalk - Chat With Your Documents")
# ---------------------------------------------------
# SESSION STATE
# ---------------------------------------------------
# Seed per-session keys on first run; Streamlit reruns keep existing values.
_STATE_DEFAULTS = {
    "qa_chain": None,   # retrieval chain built after a document is processed
    "messages": [],     # chat history as {"role", "content", ...} dicts
    "processed": False, # whether a document has been indexed yet
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ---------------------------------------------------
# LOAD EMBEDDINGS
# ---------------------------------------------------
def load_embeddings():
    """Build the sentence-transformer embedding model used to index chunks."""
    model = "sentence-transformers/all-MiniLM-L6-v2"
    return HuggingFaceEmbeddings(model_name=model)
# ---------------------------------------------------
# LOAD GROQ LLM
# ---------------------------------------------------
def load_llm():
    """Create the Groq chat model used to answer questions.

    Returns:
        A deterministic (temperature=0) ChatGroq instance.

    Raises:
        RuntimeError: if the GROQ_API_KEY environment variable is unset,
            so misconfiguration fails fast with a clear message instead of
            a confusing auth error deep inside the first request.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set.")
    return ChatGroq(
        groq_api_key=api_key,
        model_name="llama-3.1-8b-instant",
        temperature=0,  # deterministic answers for document QA
    )
# ---------------------------------------------------
# FILE ROUTER
# ---------------------------------------------------
def load_file(path, filename):
    """Load *path* into LangChain documents, dispatching on the extension
    of *filename* (the user-visible name; *path* is the temp-file copy).

    Returns:
        A list of Document objects.

    Raises:
        ValueError: for extensions the app does not support.
    """
    ext = os.path.splitext(filename)[1].lower()
    if ext == ".pdf":
        return PyMuPDFLoader(path).load()
    elif ext == ".csv":
        return CSVLoader(file_path=path, encoding="utf-8").load()
    elif ext == ".docx":
        return Docx2txtLoader(path).load()
    elif ext in (".txt", ".py", ".json", ".md"):
        return TextLoader(path, encoding="utf-8").load()
    elif ext in (".xlsx", ".xls"):
        # Safer Excel path via pandas; sheet_name=None reads EVERY sheet
        # (the previous version silently indexed only the first one).
        from langchain_core.documents import Document

        sheets = pd.read_excel(path, sheet_name=None)
        return [
            Document(page_content=df.to_string(), metadata={"sheet": name})
            for name, df in sheets.items()
        ]
    else:
        raise ValueError(f"Unsupported file format: {ext}")
# ---------------------------------------------------
# PROCESS DOCUMENT
# ---------------------------------------------------
def process_document(uploaded_file):
    """Index an uploaded file and return a ready-to-use retrieval QA chain.

    Pipeline: copy the upload to a temp file -> load via the file router ->
    split into overlapping chunks -> embed -> FAISS index -> retrieval chain
    with a context-only system prompt.

    Args:
        uploaded_file: a Streamlit UploadedFile (has .name and .getvalue()).

    Returns:
        A LangChain retrieval chain; invoke with {"input": question}.
    """
    ext = os.path.splitext(uploaded_file.name)[1]
    # The loaders need a real path on disk, so persist the upload first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
        tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name
    try:
        docs = load_file(tmp_path, uploaded_file.name)
    finally:
        # Remove the temp file even when loading fails, so failed uploads
        # do not accumulate on disk (the old code leaked it on error).
        os.remove(tmp_path)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=800,
        chunk_overlap=150,  # overlap preserves context across chunk edges
    )
    chunks = splitter.split_documents(docs)
    vector_store = FAISS.from_documents(chunks, load_embeddings())
    llm = load_llm()
    system_prompt = """
You are a professional assistant analyzing a document.
Answer ONLY using the provided context.
If the answer cannot be found say:
"I cannot find the answer in this document."
Context:
{context}
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    return create_retrieval_chain(
        vector_store.as_retriever(search_kwargs={"k": 4}),
        create_stuff_documents_chain(llm, prompt),
    )
# ---------------------------------------------------
# SIDEBAR
# ---------------------------------------------------
# NOTE(review): sidebar emoji were mojibake in the source; restored to
# plausible glyphs — confirm against the deployed app.
with st.sidebar:
    st.header("⚙️ Settings")
    uploaded_file = st.file_uploader(
        "Upload Document",
        type=["pdf", "csv", "xlsx", "xls", "docx", "txt", "py", "json", "md"],
    )
    if uploaded_file:
        if st.button("📄 Process Document"):
            with st.spinner("Processing document..."):
                st.session_state.qa_chain = process_document(uploaded_file)
                st.session_state.processed = True
            st.success("Document indexed successfully!")
    if st.session_state.processed:
        if st.button("🗑️ Clear Chat"):
            st.session_state.messages = []
            st.rerun()
# ---------------------------------------------------
# CHAT UI
# ---------------------------------------------------
if not st.session_state.processed:
    st.info("Upload a document from the sidebar to start chatting.")
else:
    # Replay history so the conversation survives Streamlit reruns.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
    user_input = st.chat_input("Ask something about your document")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = st.session_state.qa_chain.invoke({"input": user_input})
            answer = response["answer"]
            # "context" holds the retrieved chunks; default defensively so a
            # chain that omits it does not crash the UI with a KeyError.
            sources = response.get("context", [])
            st.markdown(answer)
            if sources:
                # NOTE(review): expander emoji was mojibake; restored glyph
                # is a best guess.
                with st.expander("📄 Sources"):
                    for i, s in enumerate(sources):
                        st.caption(f"Chunk {i + 1}: {s.page_content[:300]}...")
        st.session_state.messages.append(
            {"role": "assistant", "content": answer, "sources": sources}
        )
# ---------------------------------------------------
# FOOTER
# ---------------------------------------------------
# Attribution line rendered at the bottom of the page.
_FOOTER_MD = """
---
Built by **Chirag Kaushik**
"""
st.markdown(_FOOTER_MD)