Spaces:
No application file
No application file
| #!/usr/bin/env python3 | |
| """ | |
| Streamlit MCP Monitor & Query Tester | |
| A lightweight monitoring and testing interface for the agentic system. | |
| All database access MUST go through MCP server - no direct connections allowed. | |
| """ | |
| import streamlit as st | |
| import requests | |
| import os | |
| import json | |
| import pandas as pd | |
| from typing import Dict, Any | |
| # --- Configuration --- | |
# Service endpoints; each one can be overridden through the environment.
AGENT_URL = os.environ.get("AGENT_URL", "http://agent:8001/query")
NEO4J_URL = os.environ.get("NEO4J_URL", "http://neo4j:7474")
MCP_URL = os.environ.get("MCP_URL", "http://mcp:8000/mcp")
# NOTE(review): "dev-key-123" is a development fallback only; set MCP_API_KEY
# explicitly in any shared or deployed environment.
MCP_API_KEY = os.environ.get("MCP_API_KEY", "dev-key-123")
# Browser-tab title/icon and a wide layout for the chat page.
_PAGE_SETTINGS = {
    "page_title": "GraphRAG Chat",
    "page_icon": "π¬",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
| # --- Session State --- | |
def _init_session_state() -> None:
    """Create per-session keys on a session's first run; existing values are kept.

    Keys seeded: ``messages`` (the chat history list rendered by ``main``),
    ``schema_info`` (Markdown shown in the sidebar, initially empty) and
    ``current_results`` (initially ``None``).
    """
    initial_values = (
        ("messages", list),
        ("schema_info", str),
        ("current_results", lambda: None),
    )
    for key, make_default in initial_values:
        if key not in st.session_state:
            st.session_state[key] = make_default()


_init_session_state()
| # --- Helper Functions --- | |
def stream_agent_response(question: str):
    """Stream the agent's reply to *question* as parsed JSON values.

    The agent endpoint is expected to send one JSON document per line. Blank
    lines are ignored, and lines that are not valid UTF-8 or not valid JSON
    are skipped, so one corrupt line does not abort the whole answer.

    Args:
        question: The user's natural-language question, sent as
            ``{"question": question}``.

    Yields:
        Each decoded JSON value, in arrival order. If the HTTP request fails
        (connection error, timeout, non-2xx status, broken stream), a single
        ``{"error": "Failed to connect to agent: ..."}`` dict is yielded
        instead and the generator ends.
    """
    try:
        with requests.post(AGENT_URL, json={"question": question}, stream=True, timeout=300) as r:
            r.raise_for_status()
            for line in r.iter_lines():
                if not line:
                    continue
                try:
                    yield json.loads(line.decode("utf-8"))
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # UnicodeDecodeError is neither a JSON nor a requests error; without
                    # catching it here a single bad byte would crash the Streamlit page.
                    continue
    except requests.exceptions.RequestException as e:
        yield {"error": f"Failed to connect to agent: {e}"}
def _format_column_type(raw_type) -> str:
    """Return a display string for a schema entry's ``type`` field.

    A list/tuple contributes its first element (as before). A plain string is
    returned whole instead of being cut to its first character. A missing or
    empty value becomes ``""``.
    """
    if isinstance(raw_type, (list, tuple)):
        return str(raw_type[0]) if raw_type else ""
    if raw_type is None:
        return ""
    return str(raw_type)


def fetch_schema_info() -> str:
    """Fetch the database schema from the MCP server and render it as Markdown.

    Posts an empty discovery query to ``{MCP_URL}/discovery/get_relevant_schemas``
    and groups the returned column entries by ``"<database>.<table>"``. Each
    entry is assumed to be a dict with ``database``, ``table``, ``name`` and
    ``type`` keys; non-dict entries are skipped.

    Returns:
        Markdown with one ``**database.table**:`` heading per table, followed by
        ``- column (type)`` bullets. Tables are separated by a blank line. One of
        the following strings is returned instead when there is nothing to show:
        "No schema information found." when the list is empty,
        "Error from MCP: ..." when the server reports a non-success status, or
        "Could not fetch schema: ..." on network, HTTP or JSON-decoding failure.
        This function does not raise for those failures.
    """
    try:
        response = requests.post(
            f"{MCP_URL}/discovery/get_relevant_schemas",
            headers={"x-api-key": MCP_API_KEY, "Content-Type": "application/json"},
            json={"query": ""},
            # Without a timeout a stalled MCP server would block the Streamlit script forever.
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers an invalid JSON body on requests versions whose
        # JSONDecodeError is not a RequestException.
        return f"Could not fetch schema: {e}"

    if not isinstance(data, dict):
        return "Error from MCP: Unknown error"
    if data.get("status") != "success":
        return f"Error from MCP: {data.get('message', 'Unknown error')}"

    schemas = data.get("schemas", [])
    if not schemas:
        return "No schema information found."

    tables: dict[str, list[str]] = {}
    for entry in schemas:
        if not isinstance(entry, dict):
            continue
        table_key = f"{entry.get('database', '')}.{entry.get('table', '')}"
        column = f"{entry.get('name', '')} ({_format_column_type(entry.get('type'))})"
        tables.setdefault(table_key, []).append(column)

    # A blank line between tables is required: otherwise Markdown treats the next
    # "**table**:" line as a lazy continuation of the previous table's last bullet.
    sections = [
        f"**{table}**:\n" + "".join(f"- {col}\n" for col in columns)
        for table, columns in tables.items()
    ]
    return "\n".join(sections)
@st.cache_data(show_spinner=False)
def get_cached_schema():
    """Return the schema Markdown from ``fetch_schema_info``, memoised via ``st.cache_data``.

    Error strings returned by ``fetch_schema_info`` are cached as well, until
    the cache is cleared. The sidebar's refresh button calls
    ``st.cache_data.clear()`` before calling this, which forces a fresh fetch.
    """
    return fetch_schema_info()
@st.cache_data(ttl=10, show_spinner=False)
def check_service_health(service_name: str, url: str) -> bool:
    """Report whether *url* answers an HTTP GET with status 200 or 401.

    The probe sends no credentials, so a 401 is treated as "reachable". Results
    are cached per ``(service_name, url)`` for 10 seconds, so sidebar re-renders
    do not re-probe (with up to a 2 s timeout each) on every script rerun.

    Args:
        service_name: Display name of the service. It is not used by the probe
            itself but is part of the cache key.
        url: Endpoint to probe.

    Returns:
        True if the service responded with 200 or 401 within 2 seconds.
        False on any other status or any ``requests`` transport error.
        Unexpected non-``requests`` exceptions propagate to the caller.
    """
    try:
        response = requests.get(url, timeout=2)
    except requests.exceptions.RequestException:
        return False
    return response.status_code in (200, 401)
| # --- UI Components --- | |
def _mcp_health_url() -> str:
    """Return the URL used to probe MCP server health.

    The ``MCP_HEALTH_URL`` environment variable wins when set. Otherwise the
    URL is derived as ``<scheme>://<host[:port]>/health`` from ``MCP_URL``, so
    the probe follows the configured server instead of a hard-coded host. With
    the default ``MCP_URL`` this yields ``http://mcp:8000/health``, the
    previously hard-coded value.
    """
    from urllib.parse import urlsplit

    override = os.getenv("MCP_HEALTH_URL")
    if override:
        return override
    parts = urlsplit(MCP_URL)
    if not parts.scheme or not parts.netloc:
        return "http://mcp:8000/health"
    return f"{parts.scheme}://{parts.netloc}/health"


def display_sidebar():
    """Render the sidebar: schema viewer, service status, and a chat-reset button."""
    with st.sidebar:
        st.title("ποΈ Database Schema")
        if st.button("π Refresh Schema"):
            # Drop memoised results so the schema is fetched fresh.
            st.cache_data.clear()
            st.session_state.schema_info = get_cached_schema()
        st.markdown(st.session_state.schema_info)
        st.markdown("---")

        st.title("π Service Status")
        try:
            neo4j_status = "β Online" if check_service_health("Neo4j", NEO4J_URL) else "β Offline"
            mcp_status = "β Online" if check_service_health("MCP", _mcp_health_url()) else "β Offline"
        except Exception:
            # Best-effort status display: unexpected failures show "Unknown"
            # instead of breaking the whole page.
            neo4j_status = "β Unknown"
            mcp_status = "β Unknown"
        st.markdown(f"**Neo4j:** {neo4j_status}")
        st.markdown(f"**MCP Server:** {mcp_status}")
        st.markdown("---")

        if st.button("ποΈ Clear Chat History"):
            st.session_state.messages = []
            st.rerun()
| def extract_sql_results(observation_content: str) -> pd.DataFrame | None: | |
| """Extract SQL results from execute_query tool observation.""" | |
| try: | |
| if "execute_query" not in observation_content or "returned:" not in observation_content: | |
| return None | |
| # Look for JSON results in the observation | |
| if "Query returned" in observation_content and "rows:" in observation_content: | |
| # Extract the table format from the text | |
| lines = observation_content.split('\n') | |
| table_start = -1 | |
| for i, line in enumerate(lines): | |
| if "Query returned" in line and "rows:" in line: | |
| table_start = i + 1 | |
| break | |
| if table_start >= 0 and table_start < len(lines): | |
| # Find the table data | |
| table_lines = [] | |
| for i in range(table_start, len(lines)): | |
| line = lines[i].strip() | |
| if line and not line.startswith("Final Answer"): | |
| if "|" in line: # Table format | |
| table_lines.append(line) | |
| elif line.startswith("PT") or line.startswith("DIAB") or line.startswith("NEURO"): # Data row | |
| table_lines.append(line) | |
| elif line.startswith("Final Answer"): | |
| break | |
| if len(table_lines) >= 2: # Headers + at least one data row | |
| # Parse headers | |
| headers = [h.strip() for h in table_lines[0].split('|')] | |
| # Parse data rows | |
| data_rows = [] | |
| for line in table_lines[1:]: | |
| if "and" in line and "more rows" in line: | |
| break | |
| row_values = [v.strip() for v in line.split('|')] | |
| if len(row_values) == len(headers): | |
| data_rows.append(row_values) | |
| if data_rows: | |
| return pd.DataFrame(data_rows, columns=headers) | |
| except Exception: | |
| pass | |
| return None | |
def _render_dataframe(df: pd.DataFrame, key: str) -> None:
    """Show *df* as an interactive table followed by a CSV download button."""
    st.dataframe(df, use_container_width=True)
    st.download_button(
        label="π₯ Download CSV",
        data=df.to_csv(index=False),
        file_name="query_results.csv",
        mime="text/csv",
        key=key,
    )


def _render_history() -> None:
    """Replay stored chat messages, including any attached result tables."""
    for index, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message.get("dataframe") is not None:
                # Explicit per-message key: buttons with identical label/file name/mime
                # in several messages would otherwise get duplicate widget IDs.
                _render_dataframe(message["dataframe"], key=f"history_download_{index}")


def _stream_answer(prompt: str) -> tuple[str, pd.DataFrame | None]:
    """Stream the agent's reply to *prompt* into a live placeholder.

    Returns:
        The accumulated Markdown transcript (thoughts, observations, final answer
        or an error message), and the last SQL result table parsed from an
        observation (``None`` if none was found).
    """
    full_response = ""
    sql_results_df = None
    response_box = st.empty()
    for chunk in stream_agent_response(prompt):
        if not isinstance(chunk, dict):
            # The stream may yield any JSON value; only objects are agent events.
            continue
        if "error" in chunk:
            full_response = str(chunk["error"])
            response_box.error(full_response)
            break
        content = chunk.get("content", "")
        kind = chunk.get("type")
        if kind == "thought":
            full_response += f"π€ *{content}*\n\n"
        elif kind == "observation":
            full_response += f"{content}\n\n"
            df = extract_sql_results(content)
            if df is not None:
                sql_results_df = df
        elif kind == "final_answer":
            full_response += f"**Final Answer:**\n{content}"
        response_box.markdown(full_response)
    return full_response, sql_results_df


def main():
    """Render the chat page: sidebar, chat history, and a new streamed answer."""
    display_sidebar()
    st.title("π¬ GraphRAG Conversational Agent")
    st.markdown("Ask questions about the life sciences dataset. The agent's thought process will be shown below.")

    _render_history()

    if prompt := st.chat_input("Ask your question here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            full_response, sql_results_df = _stream_answer(prompt)
            if sql_results_df is not None:
                st.markdown("---")
                st.markdown("### π Query Results")
                _render_dataframe(sql_results_df, key=f"download_{len(st.session_state.messages)}")

        st.session_state.messages.append({
            "role": "assistant",
            "content": full_response,
            "dataframe": sql_results_df,
        })
# Build the page only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()