# PaperChat / app.py
# sdmadhav's picture
# Update app.py
# fddd403 verified
import hashlib
import os
import tempfile

import streamlit as st
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Import required libraries
try:
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from smolagents import Tool, CodeAgent, InferenceClientModel
except ImportError as e:
st.error(f"Missing dependency: {e}. Please install all requirements.")
st.stop()
# Custom Retriever Tool
class RetrieverTool(Tool):
    """Agent tool that looks up passages of the loaded paper for a text query.

    ``name``, ``description``, ``inputs`` and ``output_type`` are the tool
    metadata declared for the ``Tool`` base class; ``forward`` does the lookup.
    """

    name = "retriever"
    # NOTE(review): the index built in __init__ is a BM25Retriever while this
    # text advertises "semantic search" — confirm the wording is intended.
    description = "Uses semantic search to retrieve the parts of the research paper that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Index *docs* in a BM25 retriever that returns up to 10 hits per query.

        Args:
            docs: Document chunks to index.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        """Return the retriever's hits for *query* as one formatted string.

        Each hit is preceded by a banner carrying its rank (0-based) and the
        ``page`` entry of its metadata, or ``N/A`` when that entry is absent.
        """
        assert isinstance(query, str), "Your search query must be a string"
        hits = self.retriever.invoke(query)
        sections = []
        for rank, hit in enumerate(hits):
            page = hit.metadata.get('page', 'N/A')
            banner = f"\n\n===== Document {rank} (Page {page}) =====\n"
            sections.append(banner + hit.page_content)
        return "\nRetrieved documents:\n" + "".join(sections)
# Function to load and process PDF
# Upper bound on parsed papers kept in the process-wide resource cache.
_MAX_CACHED_PAPERS = 16


def load_and_process_pdf(pdf_path):
    """Load a PDF and split it into overlapping text chunks for retrieval.

    The parsed result is cached by a SHA-256 digest of the file's bytes, not
    by its path: callers hand in a freshly named temporary file on every
    Streamlit rerun, so a path-keyed cache would never hit and would grow
    with every interaction.

    Args:
        pdf_path: Filesystem path of the PDF to read.

    Returns:
        ``(chunks, page_count)`` on success. On any failure the error is
        shown with ``st.error`` and ``(None, 0)`` is returned; failures are
        reported here, outside the cached helper, so they are never cached.
    """
    try:
        digest = hashlib.sha256()
        with open(pdf_path, "rb") as pdf_file:
            # Hash in 1 MiB pieces so large uploads are not held in memory twice.
            while block := pdf_file.read(1 << 20):
                digest.update(block)
        return _split_pdf_cached(digest.hexdigest(), pdf_path)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None, 0


@st.cache_resource(max_entries=_MAX_CACHED_PAPERS)
def _split_pdf_cached(content_digest, _pdf_path):
    """Parse and chunk the PDF; cached on ``content_digest`` only.

    ``_pdf_path`` carries a leading underscore so Streamlit leaves it out of
    the cache key. Exceptions propagate to the caller; a raised exception
    yields no return value to cache.
    """
    # Load PDF
    pages = PyPDFLoader(_pdf_path).load()
    # Split into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(pages)
    return docs_processed, len(pages)
# Function to create agent
# Upper bound on agents kept alive in the process-wide resource cache.
_MAX_CACHED_AGENTS = 8


def create_agent(_docs):
    """Create (or reuse) the RAG agent whose retriever indexes ``_docs``.

    Streamlit excludes underscore-prefixed parameters from the cache key, so
    caching this function directly would return the agent built for the first
    paper for every later upload. The chunks are therefore fingerprinted here
    and the cached builder is keyed on that fingerprint.

    Args:
        _docs: Document chunks of the current paper (objects exposing
            ``page_content`` and ``metadata``).

    Returns:
        A ``CodeAgent`` equipped with a ``RetrieverTool`` over ``_docs``.
    """
    digest = hashlib.sha256()
    for doc in _docs:
        # Only the text and page number feed the key; other metadata
        # (presumably e.g. a temp-file source path) may differ between
        # uploads of the very same paper.
        digest.update(str(doc.metadata.get("page")).encode("utf-8"))
        digest.update(b"\x00")
        digest.update(doc.page_content.encode("utf-8", errors="replace"))
        digest.update(b"\x00")
    return _build_agent(digest.hexdigest(), _docs)


@st.cache_resource(max_entries=_MAX_CACHED_AGENTS)
def _build_agent(docs_fingerprint, _docs):
    """Build the agent; cached on ``docs_fingerprint`` (``_docs`` is not hashed)."""
    retriever_tool = RetrieverTool(_docs)
    # Use FREE Hugging Face model (Qwen 2.5 72B via serverless inference)
    agent = CodeAgent(
        tools=[retriever_tool],
        model=InferenceClientModel(
            model_id="Qwen/Qwen2.5-72B-Instruct",
            token=os.getenv("HF_TOKEN"),
        ),
        max_steps=4,
        verbosity_level=0,
    )
    return agent
# Streamlit UI
def _render_sidebar():
    """Draw the sidebar and return the uploader's value (``None`` until a PDF is chosen)."""
    with st.sidebar:
        st.header("📤 Upload Paper")
        uploaded_file = st.file_uploader(
            "Choose a PDF file",
            type="pdf",
            help="Upload a research paper in PDF format",
        )
        st.markdown("---")
        st.subheader("📚 Example Questions")
        st.markdown("""
        - What is the main contribution of this paper?
        - What methodology was used?
        - What are the key results?
        - What datasets were used?
        - What are the limitations mentioned?
        """)
        st.markdown("---")
        st.subheader("ℹ️ How it works")
        st.markdown("""
        1. Upload your paper
        2. The system chunks and indexes it
        3. Ask questions naturally
        4. Get answers with source citations
        """)
    return uploaded_file


def _render_welcome():
    """Show the landing content displayed while no file is uploaded."""
    st.info("👈 Please upload a research paper PDF from the sidebar to get started.")
    st.markdown("### 🎯 What can you do with PaperChat?")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.markdown("""
        #### 📖 Understand Papers
        - Get summaries of complex papers
        - Understand methodology
        - Learn about key findings
        """)
    with col2:
        st.markdown("""
        #### 🔍 Extract Information
        - Find specific details
        - Locate datasets used
        - Identify citations
        """)
    with col3:
        st.markdown("""
        #### 💡 Learn Faster
        - Ask follow-up questions
        - Clarify concepts
        - Compare approaches
        """)


def _process_pdf_bytes(pdf_bytes):
    """Write *pdf_bytes* to a temporary PDF, process it, and always delete the file.

    Returns:
        Whatever ``load_and_process_pdf`` returns for the temporary file.
    """
    # delete=False so the file can be closed and then reopened by path.
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    tmp_path = tmp_file.name
    try:
        with tmp_file:
            tmp_file.write(pdf_bytes)
        return load_and_process_pdf(tmp_path)
    finally:
        # Cleanup runs here, not at the end of main(), because st.rerun() and
        # errors abort the script before any trailing cleanup would execute.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass  # best effort: the file may already be gone


def _render_chat(agent, paper_key):
    """Render the chat history, handle a new question, and offer a reset button.

    Args:
        agent: Object whose ``run(question)`` produces the answer.
        paper_key: Value identifying the current paper; when it differs from
            the one stored in the session the chat history starts over, so
            answers about a previous paper are not shown under a new one.
    """
    st.markdown("---")
    st.subheader("💬 Ask Questions")

    # Initialize chat history (also when the user switched to another paper)
    state = st.session_state
    if "messages" not in state or state.get("paper_key") != paper_key:
        state.messages = []
        state.paper_key = paper_key

    # Display chat history
    for message in state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if question := st.chat_input("Ask a question about the paper..."):
        state.messages.append({"role": "user", "content": question})
        with st.chat_message("user"):
            st.markdown(question)

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("🤔 Thinking..."):
                try:
                    answer = agent.run(question)
                except Exception as e:
                    # Keep the failure in the transcript so it survives reruns.
                    answer = f"Error generating answer: {e}"
                    st.error(answer)
                else:
                    st.markdown(answer)
        state.messages.append({"role": "assistant", "content": answer})

    # Clear chat button
    if st.button("🗑️ Clear Chat History"):
        state.messages = []
        st.rerun()


def main():
    """Render the PaperChat page: upload a PDF, index it, then chat about it.

    Streamlit re-executes this function on every interaction, so the chat
    history and the identity of the current paper live in ``st.session_state``.
    """
    st.set_page_config(
        page_title="PaperChat",
        page_icon="📄",
        layout="wide",
    )

    # Header
    st.title("📄 PaperChat - Research Paper Q&A Assistant")
    st.markdown("""
    Upload any research paper (PDF) and ask questions about it.
    Powered by Agentic RAG with retrieval capabilities.
    """)

    uploaded_file = _render_sidebar()
    if uploaded_file is None:
        _render_welcome()
        return

    # getvalue() yields the full upload regardless of the stream position,
    # which read() would not if the buffer had been consumed earlier.
    pdf_bytes = uploaded_file.getvalue()

    # Process PDF
    with st.spinner("🔄 Processing your paper... This may take a moment."):
        docs, num_pages = _process_pdf_bytes(pdf_bytes)

    if not docs:
        # None means load_and_process_pdf already reported an error; an empty
        # list means the PDF parsed but produced no text chunks.
        if docs is not None:
            st.warning("No extractable text was found in this PDF.")
        return

    st.success(f"✅ Paper loaded successfully! ({num_pages} pages, {len(docs)} chunks)")

    # Create agent
    with st.spinner("🤖 Initializing AI agent..."):
        agent = create_agent(docs)
    st.success("✅ Agent ready! You can now ask questions.")

    _render_chat(agent, paper_key=(uploaded_file.name, len(pdf_bytes)))


if __name__ == "__main__":
    main()