"""GraphQuery RAG Agent — Streamlit front-end.

Lets the user upload PDFs, builds a RAG retriever over them, and routes
questions through the LangGraph app defined in agents/graph.py.
"""
import os
import sys
import tempfile
from typing import List

import streamlit as st
from langchain_core.messages import HumanMessage

# --- PATH SETUP ---
# FIX: this must run BEFORE any project-local imports; previously the
# sys.path insert happened after `from models.retriever import ...`,
# which defeated its purpose when launching from another directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

from agents.graph import app

# Ensure you have implemented this function in FinalProject/models/retriever.py
# It should accept a list of PDF file paths and return a LangChain Retriever object.
try:
    from models.retriever import get_rag_retriever_from_paths
except ImportError:
    st.error("Could not import get_rag_retriever_from_paths. Please check your models/retriever.py file.")
    sys.exit()
# --- PAGE CONFIGURATION ---
st.set_page_config(
    page_title="GraphQuery RAG Agent",
    page_icon="🤖",  # FIX: was mojibake ("π€"); restored intended emoji — confirm choice
    layout="wide"
)
# --- CACHED FUNCTION TO BUILD RAG RETRIEVER ---
# Hashing trick: By passing file_paths (a list of strings), Streamlit hashes the list.
# The expensive function only runs if the list of paths changes (i.e., files are added/removed).
@st.cache_resource
def load_and_index_documents(file_paths: List[str]):
    """Build (or fetch from cache) a RAG retriever over the given PDF paths.

    Streamlit hashes `file_paths`, so the expensive indexing only reruns
    when the list of paths changes. Returns None for an empty list or on
    indexing failure; errors are shown via st.error rather than raised.
    """
    if not file_paths:
        return None

    n_files = len(file_paths)
    with st.spinner(f"Indexing {n_files} PDF file(s)... This may take a moment."):
        try:
            # Delegates the actual loading/embedding to models/retriever.py
            built_retriever = get_rag_retriever_from_paths(file_paths)
        except Exception as exc:
            st.error(f"Failed to index documents: {exc}")
            return None
        st.success(f"Indexed {n_files} PDF file(s) successfully!")
        return built_retriever
# --- SIDEBAR (Settings, Key, and Upload) ---
with st.sidebar:
    st.header("⚙️ Agent Settings")
    st.caption("Configure your LLM and Access Key.")

    # API Key Input (session-only; never persisted)
    api_key = st.text_input(
        "**Groq API Key (Required):**",
        type="password",
        help="Paste your private Groq API Key here. It is used only for this session.",
    )
    st.divider()

    # 1. FILE UPLOAD SECTION
    st.subheader("📄 Document Upload")
    uploaded_files = st.file_uploader(
        "Upload your own PDFs for RAG context:",
        type=["pdf"],
        accept_multiple_files=True
    )

    # 2. FILE SAVING & INDEXING LOGIC
    file_paths = []
    rag_retriever = None
    if uploaded_files:
        # Streamlit files are in memory; we must write them to disk so
        # LangChain's PyPDFLoader (which needs a file path) can read them.
        # FIX: use a STABLE directory instead of tempfile.TemporaryDirectory().
        # TemporaryDirectory picks a new random name on every Streamlit rerun,
        # so the paths — and therefore the st.cache_resource key — changed each
        # interaction and the PDFs were re-indexed every time. A fixed dir keeps
        # the cache key identical until the uploaded file set actually changes.
        upload_dir = os.path.join(tempfile.gettempdir(), "graphquery_uploads")
        os.makedirs(upload_dir, exist_ok=True)
        for uploaded_file in uploaded_files:
            file_path = os.path.join(upload_dir, uploaded_file.name)
            # Write the file bytes to the stable path
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            file_paths.append(file_path)
        # 3. Build the retriever and cache it based on the list of paths
        rag_retriever = load_and_index_documents(file_paths)
    else:
        # Clear the cache if no files are uploaded to ensure a clean state
        st.info("No documents uploaded. Only Wikipedia lookup is enabled.")
        load_and_index_documents.clear()  # Clears the cache for this function

    st.divider()
    st.subheader("🛠️ Features")
    st.info(f"RAG (Document Context) status: {'**ENABLED**' if rag_retriever else 'DISABLED'}")
    st.info("Wikipedia Routing is always active.")
    st.text("MORE COMING SOON ⏱️")
# --- MAIN INTERFACE (Header) ---
st.markdown(
    """
    # 🧠 LangGraph Query Model
    ### Multi-Source RAG Agent
    Ask a question related to your uploaded documents or general knowledge.
    """
)
st.divider()

# --- STATE INITIALIZATION ---
# Base graph state shared by every query; "messages" is added per submission.
initial_state_base = {
    "documents": [],
    "source": "",
    "api_key": api_key,
    # Pass the dynamically created retriever to the graph state
    "rag_retriever": rag_retriever
}
# --- CHAT INPUT AND LOGIC ---
# clear_on_submit resets the text box after each question.
with st.form(key='query_form', clear_on_submit=True):
    user_query = st.text_input(
        "**Your Question:**",
        placeholder="e.g., What is the significance of the military-industrial complex in Russia?",
        label_visibility="collapsed"
    )
    submit_button = st.form_submit_button(label='Ask the Agent 🚀')
# --- EXECUTION LOGIC ---
if submit_button and user_query:
    if not api_key:
        st.error("🔑 **Error:** Please enter your Groq API Key in the sidebar to run the query.")
        st.stop()

    st.info("🚀 **Querying the Agent...** Please wait.")

    # Prepare state for this run; copy so the base dict stays reusable.
    initial_state = initial_state_base.copy()
    initial_state["messages"] = [HumanMessage(content=user_query)]

    with st.spinner('Thinking... Routing and Retrieving Context...'):
        try:
            response = app.invoke(initial_state)

            # --- Output Display ---
            final_message = response["messages"][-1].content
            # FIX: this string literal was split across two source lines,
            # which is a SyntaxError for a plain (non-triple-quoted) string.
            st.success("✅ **Agent Response:**")
            st.markdown(final_message)
            st.divider()

            # Optional: Show debug info
            with st.expander("🔍 **Debug Info (Agent Flow)**"):
                st.write(f"**Final Source:** {response.get('source', 'Unknown')}")
                if 'documents' in response and response['documents']:
                    st.write(f"**Retrieved Documents:** {len(response['documents'])} chunks used.")
        except Exception as e:
            st.error("❌ **Agent Failure:** An error occurred during execution.")
            st.exception(e)
elif not user_query and not api_key:
    st.markdown("👋 Start by entering your **Groq API Key** in the sidebar and asking a question above!")