File size: 9,573 Bytes
f831e98
 
 
 
 
 
 
 
 
86cbe3c
f831e98
 
86cbe3c
f831e98
86cbe3c
 
 
f831e98
 
 
 
86cbe3c
 
 
f831e98
 
86cbe3c
 
 
 
 
8595be6
 
86cbe3c
 
 
 
 
 
 
a0eb181
 
86cbe3c
a0eb181
86cbe3c
a0eb181
86cbe3c
 
 
f831e98
86cbe3c
 
f831e98
 
86cbe3c
a0eb181
86cbe3c
f831e98
86cbe3c
 
 
 
 
 
f831e98
86cbe3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f831e98
86cbe3c
 
f831e98
86cbe3c
 
 
 
f831e98
a0eb181
86cbe3c
a0eb181
f831e98
a0eb181
86cbe3c
a0eb181
86cbe3c
 
 
 
 
 
f831e98
86cbe3c
 
f831e98
86cbe3c
 
f831e98
86cbe3c
 
 
a0eb181
 
 
 
 
 
 
f831e98
86cbe3c
 
f831e98
86cbe3c
 
 
 
f831e98
8595be6
 
 
 
 
 
6422ca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8595be6
 
6422ca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8595be6
 
 
 
86cbe3c
 
 
 
f831e98
86cbe3c
 
 
 
8595be6
 
 
 
 
 
 
 
 
f831e98
86cbe3c
 
 
 
f831e98
86cbe3c
 
 
8595be6
f831e98
86cbe3c
 
 
 
 
f831e98
86cbe3c
f831e98
86cbe3c
 
 
 
8595be6
 
 
 
86cbe3c
 
 
 
8595be6
 
 
 
 
 
 
 
 
 
 
 
 
 
f831e98
8595be6
 
 
 
 
f831e98
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#!/usr/bin/env python3
"""
Streamlit MCP Monitor & Query Tester
A lightweight monitoring and testing interface for the agentic system.
All database access MUST go through MCP server - no direct connections allowed.
"""

import streamlit as st
import requests
import os
import json
import pandas as pd
from typing import Dict, Any

# --- Configuration ---
# Endpoints and credentials for the backing services. Each value is read from
# the environment variable of the same name and falls back to the literal
# default shown here when that variable is not set.
AGENT_URL = os.environ.get("AGENT_URL", "http://agent:8001/query")
NEO4J_URL = os.environ.get("NEO4J_URL", "http://neo4j:7474")
MCP_URL = os.environ.get("MCP_URL", "http://mcp:8000/mcp")
# NOTE(review): the fallback key looks like a development placeholder --
# confirm that real deployments always set MCP_API_KEY explicitly.
MCP_API_KEY = os.environ.get("MCP_API_KEY", "dev-key-123")

# Page-level settings for the Streamlit app (browser tab title, icon, layout).
# The icon is the speech-balloon emoji; the previous literal was a mis-encoded
# (mojibake) rendering of the same character.
st.set_page_config(
    page_title="GraphRAG Chat",
    page_icon="💬",
    layout="wide",
)

# --- Session State ---
# Seed each session key only when it is absent, so values already stored in
# st.session_state are never overwritten by a later script run.
for _state_key, _initial_value in (
    ("messages", []),
    ("schema_info", ""),
    ("current_results", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# --- Helper Functions ---
def stream_agent_response(question: str):
    """Stream the agent's response to *question* as decoded JSON objects.

    Posts ``{"question": question}`` to ``AGENT_URL`` with a streaming
    response and decodes each non-empty line as UTF-8 JSON.

    Args:
        question: The user's question, sent verbatim to the agent.

    Yields:
        dict: One dictionary per well-formed line of the response. Lines
        that are not valid UTF-8, not valid JSON, or that decode to
        something other than a JSON object are skipped, because the
        consumer in this file treats every chunk as a mapping. If the
        request fails (connection error, timeout, non-2xx status), a single
        ``{"error": "..."}`` dictionary is yielded instead.
    """
    try:
        with requests.post(AGENT_URL, json={"question": question}, stream=True, timeout=300) as r:
            r.raise_for_status()
            for line in r.iter_lines():
                if not line:
                    continue
                try:
                    payload = json.loads(line.decode('utf-8'))
                except ValueError:
                    # ValueError covers both json.JSONDecodeError and
                    # UnicodeDecodeError, so one corrupt line cannot abort
                    # the whole stream.
                    continue
                if isinstance(payload, dict):
                    yield payload
    except requests.exceptions.RequestException as e:
        yield {"error": f"Failed to connect to agent: {e}"}

def fetch_schema_info() -> str:
    """Fetch the database schema from the MCP server as Markdown text.

    Posts an empty query to ``{MCP_URL}/discovery/get_relevant_schemas`` and
    groups the returned column entries by ``database.table``.

    Returns:
        A Markdown string with one ``**database.table**:`` heading per table
        followed by a ``- name (type)`` bullet per column. When the request
        fails, the server reports a non-success status, or no schemas are
        returned, a human-readable message is returned instead; this
        function does not raise for those cases.
    """
    try:
        response = requests.post(
            f"{MCP_URL}/discovery/get_relevant_schemas",
            headers={"x-api-key": MCP_API_KEY, "Content-Type": "application/json"},
            json={"query": ""},
            # Bound the wait so an unresponsive MCP server cannot hang the UI.
            timeout=10,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: the body was not valid JSON (raised directly by older
        # versions of requests).
        return f"Could not fetch schema: {e}"

    if not isinstance(data, dict):
        return "Error from MCP: Unknown error"
    if data.get("status") != "success":
        return f"Error from MCP: {data.get('message', 'Unknown error')}"

    schemas = data.get("schemas", [])
    if not schemas:
        return "No schema information found."

    # Group columns by table, preserving the order in which they arrive.
    tables = {}
    for s in schemas:
        table_key = f"{s.get('database', '')}.{s.get('table', '')}"

        # The type is presumably a list whose first element is the type
        # name (verify against the MCP response schema). Tolerate an empty
        # list, a missing value and a plain string instead of raising or
        # truncating the string to its first character.
        raw_type = s.get('type')
        if isinstance(raw_type, (list, tuple)):
            type_label = raw_type[0] if raw_type else ''
        elif raw_type is None:
            type_label = ''
        else:
            type_label = raw_type

        tables.setdefault(table_key, []).append(f"{s.get('name', '')} ({type_label})")

    parts = []
    for table, columns in tables.items():
        parts.append(f"**{table}**:\n")
        parts.extend(f"- {col}\n" for col in columns)
    return "".join(parts)

@st.cache_data(ttl=600)
def get_cached_schema():
    """Return the schema text, going through Streamlit's data cache.

    The decorator is configured with ``ttl=600``; the body simply delegates
    to ``fetch_schema_info`` and returns whatever string it produces.
    """
    schema_text = fetch_schema_info()
    return schema_text

@st.cache_data(ttl=10)
def check_service_health(service_name: str, url: str) -> bool:
    """Report whether *url* answers an HTTP GET within two seconds.

    Args:
        service_name: Label for the service. It is not used in the body.
        url: Address to probe.

    Returns:
        True when the response status is 200 or 401, False for any other
        status or when the request raises. The decorator caches results
        with ``ttl=10``.
    """
    try:
        status = requests.get(url, timeout=2).status_code
    except Exception:
        return False
    return status in (200, 401)

# --- UI Components ---
def display_sidebar():
    """Render the sidebar: schema listing, service status, and chat reset.

    Side effects:
        * Stores the (cached) schema text in ``st.session_state.schema_info``.
        * Clears Streamlit's data cache when "Refresh Schema" is pressed.
        * Empties ``st.session_state.messages`` and reruns the script when
          "Clear Chat History" is pressed.
    """
    # Local import keeps this block self-contained; only used for the
    # health-URL derivation below.
    from urllib.parse import urlsplit

    with st.sidebar:
        st.title("🗄️ Database Schema")

        if st.button("🔄 Refresh Schema"):
            st.cache_data.clear()

        st.session_state.schema_info = get_cached_schema()
        st.markdown(st.session_state.schema_info)

        st.markdown("---")
        st.title("🔌 Service Status")

        # Probe the health endpoint on the same host/port as MCP_URL so an
        # overridden MCP_URL is honoured. With the default MCP_URL this
        # resolves to the previously hard-coded http://mcp:8000/health,
        # which is also the fallback when MCP_URL cannot be parsed.
        mcp_parts = urlsplit(MCP_URL)
        if mcp_parts.scheme and mcp_parts.netloc:
            mcp_health_url = f"{mcp_parts.scheme}://{mcp_parts.netloc}/health"
        else:
            mcp_health_url = "http://mcp:8000/health"

        try:
            neo4j_status = "✅ Online" if check_service_health("Neo4j", NEO4J_URL) else "❌ Offline"
            mcp_status = "✅ Online" if check_service_health("MCP", mcp_health_url) else "❌ Offline"
        except Exception:
            # Best effort: a failure while checking must not break the page.
            neo4j_status = "❓ Unknown"
            mcp_status = "❓ Unknown"

        st.markdown(f"**Neo4j:** {neo4j_status}")
        st.markdown(f"**MCP Server:** {mcp_status}")

        st.markdown("---")
        if st.button("🗑️ Clear Chat History"):
            st.session_state.messages = []
            st.rerun()

def extract_sql_results(observation_content: str) -> pd.DataFrame | None:
    """Extract SQL results from execute_query tool observation."""
    try:
        if "execute_query" not in observation_content or "returned:" not in observation_content:
            return None
        
        # Look for JSON results in the observation
        if "Query returned" in observation_content and "rows:" in observation_content:
            # Extract the table format from the text
            lines = observation_content.split('\n')
            table_start = -1
            for i, line in enumerate(lines):
                if "Query returned" in line and "rows:" in line:
                    table_start = i + 1
                    break
            
            if table_start >= 0 and table_start < len(lines):
                # Find the table data
                table_lines = []
                for i in range(table_start, len(lines)):
                    line = lines[i].strip()
                    if line and not line.startswith("Final Answer"):
                        if "|" in line:  # Table format
                            table_lines.append(line)
                        elif line.startswith("PT") or line.startswith("DIAB") or line.startswith("NEURO"):  # Data row
                            table_lines.append(line)
                    elif line.startswith("Final Answer"):
                        break
                
                if len(table_lines) >= 2:  # Headers + at least one data row
                    # Parse headers
                    headers = [h.strip() for h in table_lines[0].split('|')]
                    
                    # Parse data rows
                    data_rows = []
                    for line in table_lines[1:]:
                        if "and" in line and "more rows" in line:
                            break
                        row_values = [v.strip() for v in line.split('|')]
                        if len(row_values) == len(headers):
                            data_rows.append(row_values)
                    
                    if data_rows:
                        return pd.DataFrame(data_rows, columns=headers)
    except Exception:
        pass
    return None

def main():
    """Render the chat page and handle one question/answer round.

    Draws the sidebar and the stored chat history, then, if the user
    submitted a question, streams the agent's response into a placeholder.
    Thought, observation and final-answer chunks are appended to the
    displayed text; the last observation from which a result table can be
    extracted is shown as a DataFrame with a CSV download button. Both the
    user message and the assistant reply (with its DataFrame, if any) are
    appended to ``st.session_state.messages``.
    """
    display_sidebar()
    st.title("💬 GraphRAG Conversational Agent")
    st.markdown("Ask questions about the life sciences dataset. The agent's thought process will be shown below.")

    # Display chat history
    for idx, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("dataframe") is not None:
                st.dataframe(message["dataframe"], use_container_width=True)
                csv = message["dataframe"].to_csv(index=False)
                st.download_button(
                    label="📥 Download CSV",
                    data=csv,
                    file_name="query_results.csv",
                    mime="text/csv",
                    # Every history button shares label/file name/mime, so a
                    # per-message key is required to keep widget identities
                    # unique once more than one message carries results.
                    key=f"history_download_{idx}",
                )

    if prompt := st.chat_input("Ask your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            full_response = ""
            response_box = st.empty()
            sql_results_df = None

            for chunk in stream_agent_response(prompt):
                if "error" in chunk:
                    # An error replaces whatever was streamed so far.
                    full_response = chunk["error"]
                    response_box.error(full_response)
                    break

                content = chunk.get("content", "")

                if chunk.get("type") == "thought":
                    full_response += f"🤔 *{content}*\n\n"
                elif chunk.get("type") == "observation":
                    full_response += f"{content}\n\n"
                    # Try to extract SQL results; the latest successful
                    # extraction wins.
                    df = extract_sql_results(content)
                    if df is not None:
                        sql_results_df = df
                elif chunk.get("type") == "final_answer":
                    full_response += f"**Final Answer:**\n{content}"

                # Re-render the accumulated text after every chunk.
                response_box.markdown(full_response)

            # Display DataFrame if SQL results were found
            if sql_results_df is not None:
                st.markdown("---")
                st.markdown("### 📊 Query Results")
                st.dataframe(sql_results_df, use_container_width=True)
                csv = sql_results_df.to_csv(index=False)
                st.download_button(
                    label="📥 Download CSV",
                    data=csv,
                    file_name="query_results.csv",
                    mime="text/csv",
                    key=f"download_{len(st.session_state.messages)}"
                )

        st.session_state.messages.append({
            "role": "assistant",
            "content": full_response,
            "dataframe": sql_results_df
        })

# Script entry point: render the app only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()