File size: 6,180 Bytes
8bf4d58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""MCP Server for Snowflake data warehouse."""

import logging
from typing import Any, Dict, List, Optional

# Optional dependency: use the real MCP ``Tool`` type when the ``mcp`` package
# is importable, otherwise fall back to a minimal stand-in so this module can
# still be imported.
try:
    from mcp.types import Tool
    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False

    class Tool:
        """Minimal stand-in for ``mcp.types.Tool`` when ``mcp`` is missing.

        Every keyword argument passed to the constructor (for example
        ``name``, ``description``, ``inputSchema``) is kept as an instance
        attribute, so code that inspects a tool definition still finds the
        values it was created with instead of an empty object.
        """

        def __init__(self, **kwargs):
            """Store each keyword argument as an attribute of the same name."""
            for key, value in kwargs.items():
                setattr(self, key, value)

        def __repr__(self):
            """Return a debug representation listing the stored fields."""
            fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
            return f"{type(self).__name__}({fields})"

from src.mcp.mcp_server import BaseMCPServer

logger = logging.getLogger(__name__)

# Optional dependency guard: SNOWFLAKE_AVAILABLE records whether both the
# Snowflake connector and pandas could be imported.
try:
    import snowflake.connector
    import pandas as pd
    SNOWFLAKE_AVAILABLE = True
except ImportError:
    SNOWFLAKE_AVAILABLE = False
    # NOTE(review): a missing pandas also ends up here, although the message
    # below only mentions the connector package.
    logger.warning("snowflake-connector-python not installed")


class SnowflakeMCPServer(BaseMCPServer):
    """MCP server exposing Snowflake data warehouse operations as tools.

    When the Snowflake connector is importable, three tools are registered:
    ``snowflake_query``, ``snowflake_list_tables`` and
    ``snowflake_get_schema``. The connection is opened lazily by the first
    call to :meth:`query`.

    The ``config`` dict is read for the keys ``account``, ``user``,
    ``password``, ``warehouse``, ``database``, ``schema`` and ``role``.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the server and register tools if the connector exists.

        Args:
            config: Connection settings; ``None`` is treated as an empty dict.
        """
        super().__init__("snowflake_server")
        self.config = config or {}
        self.connection = None
        self.cursor = None
        if SNOWFLAKE_AVAILABLE:
            self._register_tools()

    def _register_tools(self) -> None:
        """Register the Snowflake tool definitions with the MCP server."""
        if not SNOWFLAKE_AVAILABLE:
            logger.warning("Snowflake connector not available, skipping tool registration")
            return

        # Query tool
        query_tool = Tool(
            name="snowflake_query",
            description="Execute SQL query on Snowflake data warehouse",
            inputSchema={
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string",
                        "description": "SQL query to execute",
                    },
                },
                "required": ["sql"],
            },
        )
        self.register_tool(query_tool)

        # List tables tool
        list_tables_tool = Tool(
            name="snowflake_list_tables",
            description="List all tables in the current schema",
            inputSchema={"type": "object", "properties": {}},
        )
        self.register_tool(list_tables_tool)

        # Get table schema tool
        schema_tool = Tool(
            name="snowflake_get_schema",
            description="Get schema information for a table",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table",
                    },
                },
                "required": ["table_name"],
            },
        )
        self.register_tool(schema_tool)

    def connect(self) -> bool:
        """Establish a connection to Snowflake and open a cursor.

        Returns:
            ``True`` when both the connection and the cursor were created,
            ``False`` when the connector is unavailable or connecting failed
            (the failure is logged, not raised).
        """
        if not SNOWFLAKE_AVAILABLE:
            return False

        try:
            connection = snowflake.connector.connect(
                account=self.config.get('account'),
                user=self.config.get('user'),
                password=self.config.get('password'),
                warehouse=self.config.get('warehouse'),
                database=self.config.get('database'),
                schema=self.config.get('schema'),
                # NOTE(review): defaulting to ACCOUNTADMIN grants the most
                # privileged role when the config omits one; kept for
                # backward compatibility - consider a least-privilege role.
                role=self.config.get('role', 'ACCOUNTADMIN'),
            )
            cursor = connection.cursor()
        except Exception as e:
            logger.error("Snowflake connection failed: %s", e)
            return False

        # Publish the state only once both objects exist, so a failure above
        # never leaves a connection paired with a missing cursor.
        self.connection = connection
        self.cursor = cursor
        logger.info("Connected to Snowflake account: %s", self.config.get('account'))
        return True

    def query(self, sql_query: str, params: Optional[Any] = None) -> List[Dict]:
        """Execute a SQL statement and return its rows as dicts.

        Args:
            sql_query: SQL text to execute.
            params: Optional bind parameters forwarded to ``cursor.execute``.
                When ``None`` the statement is executed exactly as given.
                Placeholders are assumed to follow the connector's default
                ``pyformat`` style (``%s``) - confirm if the application
                changes ``snowflake.connector.paramstyle``.

        Returns:
            One dict per row keyed by column name. On failure a single-item
            list holding an ``"error"`` entry is returned instead of raising.
            A statement that produces no result set yields an empty list.
        """
        if not SNOWFLAKE_AVAILABLE:
            return [{"error": "Snowflake connector not available"}]

        if not self.connection and not self.connect():
            return [{"error": "Failed to connect to Snowflake"}]

        try:
            if params is None:
                self.cursor.execute(sql_query)
            else:
                self.cursor.execute(sql_query, params)
            description = self.cursor.description
            if description is None:
                # DB-API: no description means the statement returned no rows.
                return []
            columns = [desc[0] for desc in description]
            return [dict(zip(columns, row)) for row in self.cursor.fetchall()]
        except Exception as e:
            logger.error("Query error: %s", e)
            return [{"error": str(e), "query": sql_query}]

    def get_tables(self) -> List[str]:
        """List the table names in the configured database and schema.

        Returns:
            Table names, or an empty list when ``database`` or ``schema`` is
            missing from the config. Rows without a ``TABLE_NAME`` key (such
            as error rows returned by :meth:`query`) are skipped.
        """
        if not self.config.get('database') or not self.config.get('schema'):
            return []

        # The database is an identifier and cannot be a bind value; it comes
        # from the server configuration. The schema is bound as a parameter.
        sql = f"""
        SELECT TABLE_NAME
        FROM {self.config['database']}.INFORMATION_SCHEMA.TABLES
        WHERE TABLE_SCHEMA = %s
        """
        results = self.query(sql, (self.config['schema'],))
        return [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row]

    def get_table_schema(self, table_name: str) -> List[Dict]:
        """Return column name, data type and nullability for a table.

        Args:
            table_name: Table to describe. It is passed as a bind parameter
                because it originates from tool arguments, so it must never
                be spliced into the SQL text.

        Returns:
            The rows produced by :meth:`query`, or an empty list when
            ``database`` or ``schema`` is missing from the config.
        """
        if not self.config.get('database') or not self.config.get('schema'):
            return []

        sql = f"""
        SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE
        FROM {self.config['database']}.INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = %s
        AND TABLE_NAME = %s
        """
        return self.query(sql, (self.config['schema'], table_name))

    async def _execute_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Dispatch a tool call to the matching Snowflake operation.

        Args:
            name: Tool name (one of the three registered tool names).
            arguments: Tool arguments; ``sql`` or ``table_name`` is read
                depending on the tool.

        Returns:
            A dict with ``results``, ``tables`` or ``schema``, or a dict with
            an ``error`` entry when the config is empty or a required
            argument is missing.

        Raises:
            ValueError: If ``name`` is not a known tool.
        """
        if not self.config:
            return {"error": "Snowflake configuration not provided"}

        if name == "snowflake_query":
            sql = arguments.get("sql", "")
            if not sql:
                # Mirror the table_name validation below instead of sending
                # an empty statement to the warehouse.
                return {"error": "sql is required"}
            return {"results": self.query(sql)}

        if name == "snowflake_list_tables":
            return {"tables": self.get_tables()}

        if name == "snowflake_get_schema":
            table_name = arguments.get("table_name")
            if not table_name:
                return {"error": "table_name is required"}
            return {"schema": self.get_table_schema(table_name)}

        raise ValueError(f"Unknown tool: {name}")

    def close(self) -> None:
        """Close the cursor and connection and forget both references.

        The attributes are reset to ``None`` so that calling ``close`` twice
        is harmless and a later :meth:`query` reconnects instead of reusing a
        closed cursor. The connection is closed even if closing the cursor
        raises.
        """
        # getattr: __del__ may run on an instance whose __init__ did not
        # finish, in which case these attributes do not exist yet.
        cursor = getattr(self, "cursor", None)
        connection = getattr(self, "connection", None)
        self.cursor = None
        self.connection = None
        try:
            if cursor:
                cursor.close()
        finally:
            if connection:
                connection.close()

    def __del__(self):
        """Best-effort cleanup when the instance is garbage collected."""
        try:
            self.close()
        except Exception:
            # A finalizer must not raise; errors here (e.g. during
            # interpreter shutdown) cannot be handled by any caller.
            pass