import streamlit as st
import os
import tempfile
from dotenv import load_dotenv

# Copy settings from a local .env file (if present) into os.environ, e.g. HF_TOKEN.
load_dotenv()

# The RAG/agent stack is optional at import time: if any package is absent,
# report it inside the Streamlit page and halt this script run rather than
# letting a raw traceback reach the user.
try:
    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    from langchain_core.documents import Document
    from langchain_community.retrievers import BM25Retriever
    from smolagents import Tool, CodeAgent, InferenceClientModel
except ImportError as exc:
    st.error(f"Missing dependency: {exc}. Please install all requirements.")
    st.stop()
|
|
| |
class RetrieverTool(Tool):
    """smolagents tool that ranks chunks of the uploaded paper with BM25.

    BM25 is a lexical (keyword-overlap) ranking, not an embedding-based
    semantic search, so the descriptions shown to the LLM ask for queries
    that contain the terms expected in the target passages.
    """

    name = "retriever"
    description = (
        "Uses keyword (BM25) search to retrieve the parts of the research paper "
        "that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The query to perform. Passages are ranked by keyword overlap, so "
                "include the specific terms likely to appear in the target text. "
                "Use the affirmative form rather than a question."
            ),
        }
    }
    output_type = "string"

    def __init__(self, docs, *, k=10, **kwargs):
        """Build a BM25 index over *docs*.

        Args:
            docs: LangChain ``Document`` chunks to search. NOTE(review): the
                underlying BM25 index presumably cannot be built from an empty
                list — pass at least one chunk.
            k: Number of chunks returned per query (default 10, as before).
            **kwargs: Forwarded to ``smolagents.Tool``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(
            docs, k=k, preprocess_func=self._tokenize
        )

    @staticmethod
    def _tokenize(text):
        """Lower-case *text* and split it into alphanumeric tokens for BM25."""
        # The retriever's default tokenizer is a bare str.split(), which keeps
        # case and trailing punctuation, so "Results." would never match "results".
        return "".join(ch if ch.isalnum() else " " for ch in text.casefold()).split()

    @staticmethod
    def _page_label(metadata):
        """Return a human-readable page reference for a chunk's metadata."""
        label = metadata.get("page_label")
        if label:
            return str(label)
        page = metadata.get("page")
        # PyPDFLoader stores a 0-based page index; show it 1-based so the
        # agent's citations match the page numbers a reader sees.
        if isinstance(page, int):
            return str(page + 1)
        return "N/A"

    def forward(self, query: str) -> str:
        """Execute the retrieval based on the provided query.

        Returns:
            One string containing the ranked chunks, each preceded by a
            ``===== Document i (Page p) =====`` header.

        Raises:
            TypeError: If *query* is not a string (an explicit check, unlike
                ``assert``, is not stripped under ``python -O``).
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.retriever.invoke(query)

        sections = [
            f"\n\n===== Document {i} (Page {self._page_label(doc.metadata)}) =====\n"
            + doc.page_content
            for i, doc in enumerate(docs)
        ]
        return "\nRetrieved documents:\n" + "".join(sections)
|
|
| |
@st.cache_data(max_entries=16, show_spinner=False)
def _split_pdf(file_digest, _pdf_path):
    """Load the PDF at *_pdf_path* and split it into overlapping retrieval chunks.

    Streamlit builds the cache key from ``file_digest`` only (parameters whose
    names start with an underscore are not hashed), so identical file contents
    reuse the cached chunks even when written to a different temporary path.
    Exceptions propagate and are therefore never cached.

    Args:
        file_digest: Hex digest of the file's bytes; used solely as the cache key.
        _pdf_path: Filesystem path of the PDF to read.

    Returns:
        ``(chunks, page_count)``.
    """
    pages = PyPDFLoader(_pdf_path).load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    return text_splitter.split_documents(pages), len(pages)


def load_and_process_pdf(pdf_path):
    """Load a PDF and split it into chunks for retrieval.

    Results are cached by a SHA-256 of the file's bytes rather than by
    ``pdf_path``. The caller in this app writes each upload to a freshly named
    temporary file, so a path-keyed cache would never hit and would keep
    gaining entries. At most 16 distinct files are kept in the cache.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        ``(chunks, page_count)`` on success; ``(None, 0)`` after reporting the
        failure with ``st.error``.
    """
    import hashlib

    try:
        with open(pdf_path, "rb") as pdf_file:
            digest = hashlib.sha256(pdf_file.read()).hexdigest()
        return _split_pdf(digest, pdf_path)
    except Exception as e:
        # UI boundary: show the problem to the user instead of crashing the page.
        st.error(f"Error processing PDF: {e}")
        return None, 0
|
|
| |
def create_agent(_docs, *, model_id="Qwen/Qwen2.5-72B-Instruct", max_steps=4):
    """Create the RAG agent with a BM25 retriever tool over *_docs*.

    Intentionally NOT wrapped in ``st.cache_resource``. Streamlit leaves
    underscore-prefixed parameters out of the cache key, so with ``_docs`` as
    the only argument a cached version had a constant key. It returned the
    first agent ever built, for every later paper and for every user session,
    all sharing one mutable agent. Callers that want reuse should keep the
    returned agent in per-session state instead.

    Args:
        _docs: Chunks of the paper to search. The leading underscore is kept so
            existing callers keep working.
        model_id: Model id passed to ``InferenceClientModel``.
        max_steps: Step budget passed to ``CodeAgent``.

    Returns:
        A configured ``CodeAgent``.
    """
    retriever_tool = RetrieverTool(_docs)

    model = InferenceClientModel(
        model_id=model_id,
        # os.getenv returns None when HF_TOKEN is unset; the client then falls
        # back to its own default credentials (presumably a cached HF login - verify).
        token=os.getenv("HF_TOKEN"),
    )

    return CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=max_steps,
        verbosity_level=0,
    )
|
|
| |
def _render_sidebar():
    """Draw the sidebar (PDF uploader plus usage hints) and return the upload.

    Returns:
        The ``UploadedFile`` chosen by the user, or ``None`` if nothing is uploaded.
    """
    with st.sidebar:
        st.header("📤 Upload Paper")
        uploaded_file = st.file_uploader(
            "Choose a PDF file",
            type="pdf",
            help="Upload a research paper in PDF format",
        )

        st.markdown("---")
        st.subheader("📝 Example Questions")
        st.markdown("""
        - What is the main contribution of this paper?
        - What methodology was used?
        - What are the key results?
        - What datasets were used?
        - What are the limitations mentioned?
        """)

        st.markdown("---")
        st.subheader("ℹ️ How it works")
        st.markdown("""
        1. Upload your paper
        2. The system chunks and indexes it
        3. Ask questions naturally
        4. Get answers with source citations
        """)
    return uploaded_file


def _render_welcome():
    """Show the landing content displayed while no paper is uploaded."""
    st.info("👈 Please upload a research paper PDF from the sidebar to get started.")

    st.markdown("### 🎯 What can you do with PaperChat?")
    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("""
        #### 📚 Understand Papers
        - Get summaries of complex papers
        - Understand methodology
        - Learn about key findings
        """)

    with col2:
        st.markdown("""
        #### 🔍 Extract Information
        - Find specific details
        - Locate datasets used
        - Identify citations
        """)

    with col3:
        st.markdown("""
        #### 💡 Learn Faster
        - Ask follow-up questions
        - Clarify concepts
        - Compare approaches
        """)


def _load_paper(uploaded_file):
    """Index *uploaded_file* and build its agent, once per distinct file content.

    Streamlit reruns this whole script on every interaction, including every
    chat message. The processed paper is therefore kept in
    ``st.session_state["paper"]``, keyed by a SHA-256 of the file bytes, so the
    PDF is written, parsed and indexed only when the content changes. Switching
    to a different file also clears the chat history so answers about the old
    paper are not shown next to the new one.

    Returns:
        A dict with ``digest``, ``docs``, ``num_pages`` and ``agent`` keys, or
        ``None`` when the PDF could not be processed or yielded no text.
    """
    import contextlib
    import hashlib

    # getvalue() returns the whole buffer regardless of the current read
    # position, unlike read().
    data = uploaded_file.getvalue()
    digest = hashlib.sha256(data).hexdigest()

    paper = st.session_state.get("paper")
    if paper is not None and paper["digest"] == digest:
        return paper

    st.session_state.pop("paper", None)
    st.session_state.messages = []

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(data)
        tmp_path = tmp_file.name
    try:
        with st.spinner("🔄 Processing your paper... This may take a moment."):
            docs, num_pages = load_and_process_pdf(tmp_path)
    finally:
        # Remove the temp file as soon as loading is done. Deferring it to the
        # end of the script leaked files whenever st.rerun() ended the run early.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)

    if docs is None:
        # load_and_process_pdf has already shown the error to the user.
        return None
    if not docs:
        st.warning(
            "No extractable text was found in this PDF (it may be a scanned image)."
        )
        return None

    with st.spinner("🤖 Initializing AI agent..."):
        agent = create_agent(docs)

    paper = {"digest": digest, "docs": docs, "num_pages": num_pages, "agent": agent}
    st.session_state.paper = paper
    return paper


def _render_chat(agent):
    """Replay this session's chat history, answer a new question, and offer a reset.

    Args:
        agent: Object whose ``run(question)`` produces the answer.
    """
    st.markdown("---")
    st.subheader("💬 Ask Questions")

    if "messages" not in st.session_state:
        st.session_state.messages = []
    messages = st.session_state.messages

    for message in messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if question := st.chat_input("Ask a question about the paper..."):
        messages.append({"role": "user", "content": question})

        with st.chat_message("user"):
            st.markdown(question)

        with st.chat_message("assistant"):
            with st.spinner("🤖 Thinking..."):
                try:
                    # NOTE(review): smolagents documents run(..., reset=True) as
                    # the default, i.e. each question presumably starts from empty
                    # agent memory (no follow-up context) - confirm.
                    # str(): guard against non-string final answers.
                    answer = str(agent.run(question))
                except Exception as e:
                    error_msg = f"Error generating answer: {e}"
                    st.error(error_msg)
                    messages.append({"role": "assistant", "content": error_msg})
                else:
                    st.markdown(answer)
                    messages.append({"role": "assistant", "content": answer})

    if st.button("🗑️ Clear Chat History"):
        st.session_state.messages = []
        st.rerun()


def main():
    """Streamlit entry point: page chrome, upload handling, and the Q&A chat."""
    st.set_page_config(
        page_title="PaperChat",
        page_icon="📄",
        layout="wide",
    )

    st.title("📄 PaperChat - Research Paper Q&A Assistant")
    st.markdown("""
    Upload any research paper (PDF) and ask questions about it.
    Powered by Agentic RAG with retrieval capabilities.
    """)

    uploaded_file = _render_sidebar()
    if uploaded_file is None:
        _render_welcome()
        return

    paper = _load_paper(uploaded_file)
    if paper is None:
        return

    st.success(
        f"✅ Paper loaded successfully! "
        f"({paper['num_pages']} pages, {len(paper['docs'])} chunks)"
    )
    st.success("✅ Agent ready! You can now ask questions.")

    _render_chat(paper["agent"])
|
|
if __name__ == "__main__":
    # Run the app only when this file is executed as a script (as
    # `streamlit run` does), not when it is imported as a module.
    main()