# ScholarBot / app/app.py
# Author: vinny4 - commit 9c37331 ("initial commit")
import os
import sys
import tempfile

import streamlit as st
from dotenv import load_dotenv

# Load variables from a local .env file before importing project code, in case
# anything under src/ reads environment variables at import time.
load_dotenv()

# HACK: make the repository root (the parent of this file's directory)
# importable so `src.pipeline` resolves when this script is launched directly.
# Installing the project as a package (`pip install -e .`) would make this
# unnecessary.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.pipeline import ChatPipeline  # noqa: E402  (must follow the sys.path fix)
st.set_page_config(page_title="ScholarBot", layout="wide")
st.title("ScholarBot: Chat with Research Papers")

# Seed each per-session slot only when it is missing, so values written by
# earlier runs of this script (the active pipeline, the Q/A history) are kept.
for _state_key, _initial_value in (("chat_pipeline", None), ("chat_history", [])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
st.sidebar.header("Input Paper")
input_method = st.sidebar.radio("Choose input method:", ("Upload PDF", "arXiv ID"))
refine_query = st.sidebar.checkbox("Refine query before answering?", value=True)


def _install_pipeline(source_key, setup, success_message):
    """Build a fresh ChatPipeline, prepare it with *setup*, and make it active.

    The pipeline is stored in session state only after *setup* returns, so an
    exception during loading leaves any previously loaded pipeline in place.

    Args:
        source_key: Hashable identifier of the paper being loaded; recorded in
            ``st.session_state.loaded_source`` so the same upload is not
            re-indexed on later reruns.
        setup: Callable that receives the new pipeline and loads the paper into it.
        success_message: Text shown once loading has finished.
    """
    with st.spinner("Setting up ScholarBot..."):
        pipeline = ChatPipeline()
        setup(pipeline)
    st.session_state.chat_pipeline = pipeline
    st.session_state.loaded_source = source_key
    st.success(success_message)


if input_method == "Upload PDF":
    uploaded_file = st.sidebar.file_uploader("Upload a PDF file", type=["pdf"])
    if uploaded_file is not None:
        # Streamlit reruns this whole script on every widget interaction, and the
        # uploader keeps returning the same file, so only index a PDF when it
        # differs from the one already loaded. (name, size) is a cheap identity;
        # NOTE(review): two distinct files sharing both would not be reloaded.
        source_key = ("pdf", uploaded_file.name, uploaded_file.size)
        if st.session_state.get("loaded_source") != source_key:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                # getvalue() returns the full contents regardless of stream position.
                tmp_file.write(uploaded_file.getvalue())
                pdf_path = tmp_file.name
            # NOTE(review): the temp file is kept (delete=False) because it is not
            # visible here whether ChatPipeline reads pdf_path after
            # setup_from_pdf returns; delete it once that is confirmed safe.
            _install_pipeline(
                source_key,
                lambda pipeline: pipeline.setup_from_pdf(pdf_path),
                "PDF loaded and indexed successfully!",
            )
else:
    # Strip surrounding whitespace so a blank / padded ID is not submitted as-is.
    arxiv_id = st.sidebar.text_input("Enter arXiv ID:").strip()
    if st.sidebar.button("Load Paper") and arxiv_id:
        _install_pipeline(
            ("arxiv", arxiv_id),
            lambda pipeline: pipeline.setup(arxiv_id=arxiv_id),
            f"arXiv paper {arxiv_id} loaded successfully!",
        )
st.subheader("Chat with the Paper")
user_input = st.text_input("Ask a question:", placeholder="e.g. What is the JointMI acquisition function?")

# Ask the active pipeline the user's question and record the exchange.
# Whitespace-only input is treated the same as an empty box.
question = user_input.strip()
if st.button("Generate Answer") and question:
    pipeline = st.session_state.chat_pipeline
    # Explicit None check: session state seeds this slot with None, and the
    # truthiness of an arbitrary pipeline object is not a reliable signal.
    if pipeline is None:
        st.warning("Please load a paper first.")
    else:
        with st.spinner("Generating answer..."):
            answer = pipeline.query(question, refine_query=refine_query)
        # History holds (question, answer) tuples in chronological order.
        st.session_state.chat_history.append((question, answer))
# Render every recorded (question, answer) pair, newest first.
if st.session_state.chat_history:
    st.markdown("---")
    # A named escape (U+1F4DC SCROLL) instead of a literal emoji, so the header
    # cannot be garbled again by a wrong source-file encoding.
    st.subheader("\N{SCROLL} Chat History")
    for question, answer in reversed(st.session_state.chat_history):
        st.markdown(f"**You:** {question}")
        st.markdown(f"**ScholarBot:** {answer}")
        st.markdown("---")