Spaces:
Paused
Paused
| import os | |
| import json | |
| import requests | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from mcp.server.lowlevel import Server, NotificationOptions | |
| from mcp.server.sse import SseServerTransport | |
| from mcp import types as mcp_types | |
| import uvicorn | |
| from sse_starlette import EventSourceResponse | |
| import anyio | |
| import asyncio | |
| import logging | |
| from typing import Dict | |
# Logging: DEBUG level, so every request/response logged below is emitted.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

app = FastAPI()

# Cross-origin access so Deep Agent can connect from another origin.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Airtable connection settings, read from the environment at import time.
AIRTABLE_API_TOKEN = os.getenv("AIRTABLE_API_TOKEN")
AIRTABLE_BASE_ID = os.getenv("AIRTABLE_BASE_ID")
# Table to operate on; overridable via AIRTABLE_TABLE_ID, defaulting to the
# table this server was written for.
TABLE_ID = os.getenv("AIRTABLE_TABLE_ID", "tblQECi5f7m4y2NEV")
# Base URL for every Airtable call made by this module.
AIRTABLE_API_URL = f"https://api.airtable.com/v0/{AIRTABLE_BASE_ID}/{TABLE_ID}"

if not AIRTABLE_API_TOKEN or not AIRTABLE_BASE_ID:
    # Surface misconfiguration at startup instead of as opaque 401/404 errors later
    # (a missing base id would otherwise produce a ".../v0/None/..." URL).
    logging.getLogger(__name__).warning(
        "AIRTABLE_API_TOKEN and/or AIRTABLE_BASE_ID is not set; Airtable requests will fail"
    )
# Helper function for Airtable API requests
def airtable_request(method, endpoint="", data=None, *, params=None, timeout=30):
    """Send one request to the configured Airtable table and return the decoded JSON.

    Args:
        method: HTTP method, e.g. "GET" or "POST".
        endpoint: Optional path appended to the table URL (e.g. a record id);
            empty means the table URL itself.
        data: Optional JSON-serialisable request body.
        params: Optional query-string parameters (keyword-only).
        timeout: Seconds to wait for connect/read (keyword-only). Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The response body parsed as JSON.

    Raises:
        requests.HTTPError: for 4xx/5xx responses (via ``raise_for_status``).
        requests.RequestException: for transport failures, including timeouts.
    """
    headers = {
        "Authorization": f"Bearer {AIRTABLE_API_TOKEN}",
        "Content-Type": "application/json",
    }
    url = f"{AIRTABLE_API_URL}/{endpoint}" if endpoint else AIRTABLE_API_URL
    response = requests.request(
        method, url, headers=headers, json=data, params=params, timeout=timeout
    )
    response.raise_for_status()
    return response.json()
# Tool to list records
async def list_records_tool(request: mcp_types.CallToolRequest):
    """List the records of the configured Airtable table.

    The blocking HTTP call runs in a worker thread (``asyncio.to_thread``) so it
    does not stall the event loop that serves the SSE and POST endpoints.

    Args:
        request: The tool call. It is only logged; this tool takes no arguments.

    Returns:
        ``{"success": True, "result": <JSON string of Airtable's response>}`` on
        success, or ``{"success": False, "error": <message>}`` on any failure —
        errors are reported to the caller rather than raised.

    NOTE(review): Airtable's list endpoint is paginated (a response may carry an
    ``offset``); only the first page is fetched here — confirm whether callers
    need every record.
    """
    logger.debug(f"Received list_records_tool request: {request}")
    try:
        records = await asyncio.to_thread(airtable_request, "GET")
    except Exception as e:
        response = {"success": False, "error": str(e)}
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"list_records_tool error: {response}")
        return response
    response = {"success": True, "result": json.dumps(records)}
    logger.debug(f"list_records_tool response: {response}")
    return response
# Tool to create a record
async def create_record_tool(request: mcp_types.CallToolRequest):
    """Create one record in the configured Airtable table.

    Tool arguments are read from ``request.params.arguments`` (where an MCP
    ``CallToolRequest`` carries them). A ``request.input`` mapping is accepted
    as a fallback for callers that supply one. The ``record_data`` argument
    (field name -> value) becomes the new record's fields. A missing
    ``record_data`` creates an empty record.

    The blocking HTTP call runs in a worker thread so the event loop keeps
    serving other requests.

    Returns:
        ``{"success": True, "result": <JSON string of Airtable's response>}``,
        or ``{"success": False, "error": <message>}`` on any failure (including
        a non-object ``record_data``). Errors are reported, not raised.
    """
    logger.debug(f"Received create_record_tool request: {request}")
    try:
        params = getattr(request, "params", None)
        arguments = getattr(params, "arguments", None)
        if arguments is None:
            arguments = getattr(request, "input", None)
        record_data = (arguments or {}).get("record_data", {})
        if not isinstance(record_data, dict):
            raise TypeError("record_data must be an object mapping field names to values")
        data = {"records": [{"fields": record_data}]}
        response_data = await asyncio.to_thread(airtable_request, "POST", data=data)
    except Exception as e:
        response = {"success": False, "error": str(e)}
        logger.exception(f"create_record_tool error: {response}")
        return response
    response = {"success": True, "result": json.dumps(response_data)}
    logger.debug(f"create_record_tool response: {response}")
    return response
# Tool definitions advertised to MCP clients (Deep Agent discovers them via
# tools/list). MCP expects each inputSchema to be a JSON Schema object, i.e. a
# top-level "type": "object" with the arguments under "properties".
tools = [
    mcp_types.Tool(
        name="list_airtable_records",
        description="Lists all records in the specified Airtable table",
        inputSchema={"type": "object", "properties": {}},
    ),
    mcp_types.Tool(
        name="create_airtable_record",
        description="Creates a new record in the specified Airtable table",
        inputSchema={
            "type": "object",
            "properties": {
                "record_data": {
                    "type": "object",
                    "description": "Field name to value mapping for the new record",
                },
            },
        },
    ),
]

# Tool name -> coroutine that executes it.
tool_handlers = {
    "list_airtable_records": list_records_tool,
    "create_airtable_record": create_record_tool,
}
# Create MCP server
mcp_server = Server(name="airtable-mcp")
# Kept for anything that reads these attributes. The low-level Server does not
# consult them, so the tools are also registered via its decorators below.
mcp_server.tool_handlers = tool_handlers
mcp_server.tools = tools


@mcp_server.list_tools()
async def _mcp_list_tools() -> list[mcp_types.Tool]:
    """Answer MCP ``tools/list`` with the module-level ``tools`` definitions."""
    return tools


@mcp_server.call_tool()
async def _mcp_call_tool(name: str, arguments: dict) -> list[mcp_types.TextContent]:
    """Route an MCP ``tools/call`` to the matching coroutine in ``tool_handlers``.

    The handlers take a ``CallToolRequest``, so one is rebuilt from the tool
    name and arguments. The handler's result dict is returned to the client
    as a single JSON text block.

    Raises:
        ValueError: if *name* is not a key of ``tool_handlers``.
    """
    handler = tool_handlers.get(name)
    if handler is None:
        raise ValueError(f"Unknown tool: {name}")
    call_request = mcp_types.CallToolRequest(
        method="tools/call",
        params=mcp_types.CallToolRequestParams(name=name, arguments=arguments or {}),
    )
    result = await handler(call_request)
    return [mcp_types.TextContent(type="text", text=json.dumps(result))]


# Store write streams for each session ID (for SseServerTransport messages)
write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
# Store SSE stream writers for each session ID (for manual messages)
sse_stream_writers: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
# SSE transport; "/airtable/mcp" is the path handed to clients for POSTing messages.
transport = SseServerTransport("/airtable/mcp")
# SSE endpoint for GET requests
async def handle_sse(request: Request):
    """Serve one MCP session over Server-Sent Events for the duration of a GET.

    ``transport.connect_sse`` writes the SSE response directly through the raw
    ASGI ``send`` callable, including the initial ``endpoint`` event that tells
    the client where to POST. It then yields the in-memory read/write streams
    that ``mcp_server.run`` consumes. This coroutine does not return until that
    run finishes.

    While the session is live, its write stream is registered in
    ``write_streams`` under a ``placeholder_<id>`` key. The key is removed on
    every exit path. ``connect_sse`` yields only the streams, not the session
    id, so no session-id key is recorded here.

    Returns:
        An empty ``Response``. The real SSE response was already sent through
        the transport; this return value only satisfies the framework.

    Raises:
        Re-raises any exception from the transport or MCP server after logging it.
    """
    logger.debug("Handling SSE connection request")
    placeholder_id = None  # set once the transport has handed over its streams
    try:
        async with transport.connect_sse(
            request.scope, request.receive, request._send
        ) as (read_stream, write_stream):
            placeholder_id = f"placeholder_{id(write_stream)}"
            write_streams[placeholder_id] = write_stream
            logger.debug(f"Stored write_stream with placeholder_id: {placeholder_id}")
            logger.debug("Running MCP server with streams")
            await mcp_server.run(
                read_stream, write_stream, mcp_server.create_initialization_options()
            )
    except Exception as e:
        logger.exception(f"Error in handle_sse: {e}")
        raise
    finally:
        # Unregister on normal exit as well as on error so finished sessions don't leak.
        if placeholder_id is not None:
            write_streams.pop(placeholder_id, None)
    return Response()
# Message handling endpoint for POST requests
class _ResponseAlreadySent(Response):
    """No-op response returned after the MCP transport answered via raw ``send``.

    The framework calls whatever response a handler returns. This one emits
    nothing, so a request that already has a response does not get a second
    ``http.response.start`` message.
    """

    def __init__(self) -> None:
        super().__init__(status_code=202)

    async def __call__(self, scope, receive, send) -> None:
        return None


def _replay_body_receive(body: bytes, receive):
    """Return an ASGI ``receive`` that yields *body* once, then defers to *receive*.

    ``await request.body()`` has already drained the request stream. The
    transport is handed ``receive`` to read the message from, so the cached
    body must be replayed first.
    """
    replayed = False

    async def _receive():
        nonlocal replayed
        if not replayed:
            replayed = True
            return {"type": "http.request", "body": body, "more_body": False}
        return await receive()

    return _receive


def _initialize_result(request_id):
    """Build the hard-coded JSON-RPC reply to ``initialize`` used on the manual path."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "protocolVersion": "2025-03-26",
            "capabilities": {
                "tools": {"listChanged": True},
                "prompts": {"listChanged": False},
                "resources": {"subscribe": False, "listChanged": False},
                "logging": {},
                "experimental": {},
            },
            "serverInfo": {"name": "airtable-mcp", "version": "1.0.0"},
            "instructions": "Airtable MCP server for listing and creating records.",
        },
    }


def _tools_list_result(request_id):
    """Build the JSON-RPC reply to ``tools/list`` from the module-level ``tools``."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": {
            "tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools],
            "nextCursor": None,
        },
    }


# JSON-RPC methods that may be answered directly on a registered SSE writer.
_MANUAL_REPLIES = {"initialize": _initialize_result, "tools/list": _tools_list_result}


async def handle_post_message(request: Request):
    """Accept one JSON-RPC message POSTed by an MCP client.

    Manual path: if an SSE writer is registered in ``sse_stream_writers`` for
    the request's ``session_id``, ``initialize`` and ``tools/list`` are
    answered directly on that stream.

    Forwarding path: every other message, and those two when no writer exists
    or sending fails, goes to ``transport.handle_post_message``, which writes
    its own HTTP response. The already-read body is replayed to it.

    Returns:
        - a no-op response when the transport already responded;
        - 202 after a manual reply (or if the transport sent nothing);
        - 500 if forwarding raised before any response was sent.
    """
    logger.debug("Handling POST message request")
    body = await request.body()
    logger.debug(f"Received POST message body: {body}")
    session_id = request.query_params.get("session_id")
    sse_writer = sse_stream_writers.get(session_id) if session_id else None

    try:
        message = json.loads(body)
    except ValueError:
        # Malformed JSON (or bad encoding): let the transport reject it properly.
        message = None
    method = message.get("method") if isinstance(message, dict) else None

    build_reply = _MANUAL_REPLIES.get(method)
    if sse_writer is not None and build_reply is not None:
        reply = build_reply(message.get("id"))
        logger.debug(f"Manual {method} response: {reply}")
        try:
            await sse_writer.send({"event": "message", "data": json.dumps(reply)})
        except Exception:
            # Stale writer: forget it and fall back to the transport below.
            logger.exception(f"Error sending to session {session_id} via sse_writer")
            sse_stream_writers.pop(session_id, None)
        else:
            logger.debug(f"Sent {method} response directly via SSE for session {session_id}")
            return Response(status_code=202)

    response_started = False

    async def send_tracking(asgi_message):
        # Records whether the transport began its own HTTP response.
        nonlocal response_started
        if asgi_message["type"] == "http.response.start":
            response_started = True
        await request._send(asgi_message)

    try:
        await transport.handle_post_message(
            request.scope, _replay_body_receive(body, request.receive), send_tracking
        )
    except Exception as e:
        logger.exception(f"Error handling POST message: {e}")
        if not response_started:
            return Response(status_code=500)
    else:
        logger.debug("POST message handled successfully")
    return _ResponseAlreadySent() if response_started else Response(status_code=202)
# Health check endpoint
async def health_check():
    """Report liveness with a constant payload.

    NOTE(review): no route registration for this coroutine is visible in this file.
    """
    status = "healthy"
    return dict(status=status)
# Endpoint to list tools (for debugging)
async def list_tools():
    """Return the tool definitions in the same shape the tools/list reply uses.

    ``by_alias=True, exclude_none=True`` matches the tools/list payload, so
    this debug output shows what clients actually receive. Unset optional
    fields are omitted rather than shown as null.

    NOTE(review): no route registration for this coroutine is visible in this file.
    """
    return {"tools": [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]}
if __name__ == "__main__":
    # Listen on $PORT when it is set, otherwise on 7860, on all interfaces.
    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)