File size: 6,079 Bytes
c0f74f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import streamlit as st
from agents.graph import app
from langchain_core.messages import HumanMessage
import os
import sys
import tempfile
from typing import List

# --- PATH SETUP ---
# The project root must be on sys.path BEFORE the local-package import below;
# in the original order the `models` import could fail even though the
# package exists next to this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

# Ensure you have implemented this function in FinalProject/models/retriever.py
# It should accept a list of PDF file paths and return a LangChain Retriever object.
try:
    from models.retriever import get_rag_retriever_from_paths
except ImportError:
    st.error("Could not import get_rag_retriever_from_paths. Please check your models/retriever.py file.")
    # st.stop() halts this script run cleanly; sys.exit() would raise
    # SystemExit into the Streamlit runtime instead.
    st.stop()

# --- PAGE CONFIGURATION ---
# Browser-tab title/icon and a full-width layout for the chat interface.
PAGE_SETTINGS = {
    "page_title": "GraphQuery RAG Agent",
    "page_icon": "πŸ€–",
    "layout": "wide",
}
st.set_page_config(**PAGE_SETTINGS)

# --- CACHED FUNCTION TO BUILD RAG RETRIEVER ---
# st.cache_resource keys on the argument (the list of path strings), so the
# expensive indexing runs only when that list changes between reruns.
@st.cache_resource
def load_and_index_documents(file_paths: List[str]):
    """Build and return a RAG retriever for the given PDF paths.

    Returns None when no paths are supplied or when indexing fails; failures
    are reported to the UI via st.error rather than raised.
    """
    if not file_paths:
        return None

    spinner_msg = f"Indexing {len(file_paths)} PDF file(s)... This may take a moment."
    with st.spinner(spinner_msg):
        try:
            # Delegates the actual loading/embedding to models/retriever.py.
            retriever = get_rag_retriever_from_paths(file_paths)
            st.success(f"Indexed {len(file_paths)} PDF file(s) successfully!")
            return retriever
        except Exception as exc:
            # UI boundary: surface the problem and degrade to "no retriever".
            st.error(f"Failed to index documents: {exc}")
            return None

# --- SIDEBAR (Settings, Key, and Upload) ---
# Collects the API key, handles PDF upload, and builds the RAG retriever.
# Defines `api_key` and `rag_retriever` used by the main page below.
with st.sidebar:
    st.header("βš™οΈ Agent Settings")
    st.caption("Configure your LLM and Access Key.")
    
    # API Key Input (password-masked; held only for this script run)
    api_key = st.text_input(
        "**Groq API Key (Required):**", 
        type="password",  
        help="Paste your private Groq API Key here. It is used only for this session.",
    )
    
    st.divider()
    
    # 1. FILE UPLOAD SECTION
    st.subheader("πŸ“š Document Upload")
    uploaded_files = st.file_uploader(
        "Upload your own PDFs for RAG context:", 
        type=["pdf"], 
        accept_multiple_files=True
    )
    
    # 2. FILE SAVING & INDEXING LOGIC
    file_paths = []
    rag_retriever = None
    
    if uploaded_files:
        # Streamlit files are in memory; we must write them to a temporary file 
        # so LangChain's PyPDFLoader (which needs a file path) can read them.
        # NOTE(review): TemporaryDirectory gets a fresh random name on every
        # Streamlit rerun, so the cached function below sees new paths each
        # time and presumably re-indexes on every interaction. The directory
        # (and the PDFs in it) is also deleted when this `with` exits, which
        # assumes the retriever has fully loaded the files by then — confirm.
        with tempfile.TemporaryDirectory() as temp_dir:
            for uploaded_file in uploaded_files:
                file_path = os.path.join(temp_dir, uploaded_file.name)
                # Write the file bytes to the temporary path
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                file_paths.append(file_path)

            # 3. Build the retriever and cache it based on the list of paths
            # NOTE: We pass the list of temporary paths to the cached function.
            rag_retriever = load_and_index_documents(file_paths)

    else:
        # Clear the cache if no files are uploaded to ensure a clean state
        st.info("No documents uploaded. Only Wikipedia lookup is enabled.")
        load_and_index_documents.clear() # Clears the cache for this function
    

    st.divider()
    st.subheader("πŸ› οΈ Features")
    # rag_retriever is truthy only if the indexing above succeeded.
    st.info(f"RAG (Document Context) status: {'**ENABLED**' if rag_retriever else 'DISABLED'}")
    st.info("Wikipedia Routing is always active.")
    st.text("MORE COMING SOON ⏱️")

# --- MAIN INTERFACE (Header) ---
# Static page header rendered as markdown.
HEADER_MD = """

    # 🧠 LangGraph Query Model 

    ### Multi-Source RAG Agent

    Ask a question related to your uploaded documents or general knowledge.

    """
st.markdown(HEADER_MD)
st.divider()

# --- STATE INITIALIZATION ---
# Base graph state shared by every query; the per-query "messages" field is
# attached in the submit handler below.
initial_state_base = dict(
    documents=[],
    source="",
    api_key=api_key,
    # Pass the dynamically created retriever to the graph state
    rag_retriever=rag_retriever,
)

# --- CHAT INPUT AND LOGIC ---
# clear_on_submit wipes the text box after each question; the submitted
# value is still available in user_query for this run.
with st.form(key='query_form', clear_on_submit=True):
    user_query = st.text_input(
        "**Your Question:**", 
        placeholder="e.g., What is the significance of the military-industrial complex in Russia?",
        label_visibility="collapsed"
    )
    submit_button = st.form_submit_button(label='Ask the Agent πŸš€')


# --- EXECUTION LOGIC ---
# Runs only when the form was submitted with a non-empty question.

if submit_button and user_query:
    # Key is validated lazily here so the user gets an explicit error
    # instead of a silently disabled form.
    if not api_key:
        st.error("πŸ”‘ **Error:** Please enter your Groq API Key in the sidebar to run the query.")
        st.stop()
        
    st.info("πŸ”„ **Querying the Agent...** Please wait.")
    
    # Prepare state: shallow copy of the base dict, then attach the
    # single-turn conversation for this query.
    initial_state = initial_state_base.copy()
    initial_state["messages"] = [HumanMessage(content=user_query)]
    
    with st.spinner('Thinking... Routing and Retrieving Context...'):
        try:
            # Invoke the compiled LangGraph; the answer is expected as the
            # last entry of the returned "messages" list.
            response = app.invoke(initial_state)
            
            # --- Output Display ---
            final_message = response["messages"][-1].content
            
            st.success("βœ… **Agent Response:**")
            st.markdown(final_message)
            st.divider()
            
            # Optional: Show debug info
            with st.expander("πŸ” **Debug Info (Agent Flow)**"):
                st.write(f"**Final Source:** {response.get('source', 'Unknown')}")
                if 'documents' in response and response['documents']:
                    st.write(f"**Retrieved Documents:** {len(response['documents'])} chunks used.")
                
        except Exception as e:
            # UI boundary: show any graph failure instead of crashing the app.
            st.error("❌ **Agent Failure:** An error occurred during execution.")
            st.exception(e)
            
# First-visit hint: shown only before the user supplied either input.
elif not user_query and not api_key:
    st.markdown("πŸ‘‹ Start by entering your **Groq API Key** in the sidebar and asking a question above!")