# Spaces: Sleeping
import os
import sys
import tempfile

import streamlit as st
from dotenv import load_dotenv

# Load variables from a .env file (if present) into os.environ before the
# pipeline module is imported, in case it reads configuration at import time.
load_dotenv()

# HACK: put the repository root (parent of this file's directory) on sys.path so
# `src.pipeline` resolves when the app is launched from this directory.
# Installing the project with `pip install -e .` would make this unnecessary.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.pipeline import ChatPipeline  # noqa: E402  (needs the sys.path tweak above)
st.set_page_config(page_title="ScholarBot", layout="wide")
st.title("ScholarBot: Chat with Research Papers")

# Seed per-session values only when absent: st.session_state survives reruns,
# so existing values must not be overwritten on each execution of the script.
for state_key, default in (("chat_pipeline", None), ("chat_history", [])):
    if state_key not in st.session_state:
        st.session_state[state_key] = default
def _load_pipeline(source_key, setup):
    """Build a fresh ChatPipeline, run ``setup`` on it, and make it the active one.

    Args:
        source_key: Hashable identifier of the paper being loaded. It is stored in
            ``st.session_state.loaded_source`` so later reruns can skip reloading.
        setup: Callable that receives the new pipeline and loads the paper into it.

    Returns:
        True if setup succeeded and the pipeline was installed. False if
        ``setup`` raised; the error is then shown on the page and no pipeline
        is left active.
    """
    try:
        with st.spinner("Setting up ScholarBot..."):
            pipeline = ChatPipeline()
            setup(pipeline)
    except Exception as exc:  # UI boundary: report the failure on the page
        # Drop the previous pipeline so questions are not silently answered
        # from a paper the user has just tried to replace.
        st.session_state.chat_pipeline = None
        st.session_state.loaded_source = None
        st.error(f"Could not load the paper: {exc}")
        return False
    # Install only after setup succeeded, so a failed load never leaves a
    # half-initialised pipeline in session state.
    st.session_state.chat_pipeline = pipeline
    st.session_state.loaded_source = source_key
    # Earlier answers were about a different paper.
    st.session_state.chat_history = []
    return True


st.sidebar.header("Input Paper")
input_method = st.sidebar.radio("Choose input method:", ("Upload PDF", "arXiv ID"))
refine_query = st.sidebar.checkbox("Refine query before answering?", value=True)

if input_method == "Upload PDF":
    uploaded_file = st.sidebar.file_uploader("Upload a PDF file", type=["pdf"])
    if uploaded_file is not None:
        # Streamlit reruns this script on every interaction while the file stays
        # in the uploader, so only (re)index when a different file shows up.
        source_key = ("pdf", uploaded_file.name, uploaded_file.size)
        if st.session_state.get("loaded_source") != source_key:
            # setup_from_pdf takes a filesystem path, so persist the upload first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                pdf_path = tmp_file.name
            if _load_pipeline(source_key, lambda p: p.setup_from_pdf(pdf_path)):
                st.success("PDF loaded and indexed successfully!")
            else:
                # The pipeline that used this copy was discarded; remove it.
                try:
                    os.remove(pdf_path)
                except OSError:
                    pass  # best-effort cleanup of a temporary file
else:
    arxiv_id = st.sidebar.text_input("Enter arXiv ID:").strip()
    # An explicit click always (re)loads, even for the current paper.
    if st.sidebar.button("Load Paper") and arxiv_id:
        if _load_pipeline(("arxiv", arxiv_id), lambda p: p.setup(arxiv_id=arxiv_id)):
            st.success(f"arXiv paper {arxiv_id} loaded successfully!")
st.subheader("Chat with the Paper")
user_input = st.text_input(
    "Ask a question:",
    placeholder="e.g. What is the JointMI acquisition function?",
)
# Ignore whitespace-only input instead of sending an empty question.
question = user_input.strip()
if st.button("Generate Answer") and question:
    pipeline = st.session_state.get("chat_pipeline")
    # Compare against None explicitly: the pipeline object's truthiness is
    # not something this script controls.
    if pipeline is None:
        st.warning("Please load a paper first.")
    else:
        try:
            with st.spinner("Generating answer..."):
                answer = pipeline.query(question, refine_query=refine_query)
        except Exception as exc:  # UI boundary: keep the page (and history) rendering
            st.error(f"Could not generate an answer: {exc}")
        else:
            st.session_state.chat_history.append((question, answer))
if st.session_state.chat_history:
    st.markdown("---")
    # The scraped source had a mis-encoded emoji here ("๐"); the original
    # glyph cannot be recovered, so the heading is plain text.
    st.subheader("Chat History")
    # Show the most recent exchange first.
    for question, answer in reversed(st.session_state.chat_history):
        st.markdown(f"**You:** {question}")
        st.markdown(f"**ScholarBot:** {answer}")
        st.markdown("---")