aeonshift commited on
Commit
d9de825
·
verified ·
1 Parent(s): 62935c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -23
app.py CHANGED
@@ -111,9 +111,12 @@ mcp_server = Server(name="airtable-mcp")
111
  mcp_server.tool_handlers = tool_handlers # Set as attribute
112
  mcp_server.tools = tools # Set tools as attribute for Deep Agent to discover
113
 
114
- # Store write streams for each session ID
115
  write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
116
 
 
 
 
117
  # Initialize SseServerTransport
118
  transport = SseServerTransport("/airtable/mcp")
119
 
@@ -133,15 +136,14 @@ async def handle_sse(request: Request):
133
  )
134
  logger.debug(f"Sent endpoint event: {endpoint_data}")
135
  async for session_message in write_stream_reader:
136
- # Since SseServerTransport expects a SessionMessage, we'll handle messages manually
137
- # and send raw SSE events to avoid the error in SseServerTransport's sse_writer
138
  if hasattr(session_message, 'message'):
139
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
140
  event_data = json.loads(message_data)
141
- logger.debug(f"Received SessionMessage: {event_data}")
142
  else:
143
  event_data = session_message
144
- logger.debug(f"Received dict event: {event_data}")
145
  # Extract session_id from the endpoint event
146
  if not session_id and event_data.get("event") == "endpoint":
147
  endpoint_url = event_data.get("data", "")
@@ -150,8 +152,9 @@ async def handle_sse(request: Request):
150
  placeholder_id = f"placeholder_{id(write_stream)}"
151
  if placeholder_id in write_streams:
152
  write_streams[session_id] = write_streams.pop(placeholder_id)
 
153
  logger.debug(f"Updated placeholder {placeholder_id} to session_id {session_id}")
154
- # Send the event as a raw SSE event
155
  await sse_stream_writer.send({
156
  "event": event_data.get("event", "message"),
157
  "data": event_data.get("data", json.dumps(event_data))
@@ -170,11 +173,12 @@ async def handle_sse(request: Request):
170
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
171
  except Exception as e:
172
  logger.error(f"Error in handle_sse: {str(e)}")
173
- # Clean up write_streams on error
174
  placeholder_id = f"placeholder_{id(write_stream)}"
175
  write_streams.pop(placeholder_id, None)
176
  if session_id:
177
  write_streams.pop(session_id, None)
 
178
  raise
179
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
180
 
@@ -187,8 +191,10 @@ async def handle_post_message(request: Request):
187
  try:
188
  message = json.loads(body.decode())
189
  session_id = request.query_params.get("session_id")
 
 
190
  write_stream = write_streams.get(session_id) if session_id else None
191
- if message.get("method") == "initialize" and write_stream:
192
  logger.debug("Handling initialize request manually")
193
  response = {
194
  "jsonrpc": "2.0",
@@ -218,10 +224,11 @@ async def handle_post_message(request: Request):
218
  }
219
  logger.debug(f"Manual initialize response: {response}")
220
  response_data = json.dumps(response)
221
- await write_stream.send({
222
  "event": "message",
223
  "data": response_data
224
  })
 
225
  return Response(status_code=202)
226
  if message.get("method") == "tools/list":
227
  logger.debug("Handling tools/list request manually")
@@ -236,27 +243,27 @@ async def handle_post_message(request: Request):
236
  logger.debug(f"Manual tools/list response: {response}")
237
  response_data = json.dumps(response)
238
  sent = False
239
- # First, try the session_id directly
240
- if session_id in write_streams:
241
- write_stream = write_streams[session_id]
242
  try:
243
- await write_stream.send({
244
  "event": "message",
245
  "data": response_data
246
  })
247
- logger.debug(f"Sent tools/list response to session {session_id}")
248
  sent = True
249
  except Exception as e:
250
- logger.error(f"Error sending to session {session_id}: {str(e)}")
251
- write_streams.pop(session_id, None)
252
- # If not found, look for a placeholder ID and update it
253
- if not sent:
254
  for sid, ws in list(write_streams.items()):
255
  if sid.startswith("placeholder_"):
256
  try:
257
  write_streams[session_id] = ws
 
258
  write_streams.pop(sid, None)
259
- await ws.send({
260
  "event": "message",
261
  "data": response_data
262
  })
@@ -266,12 +273,13 @@ async def handle_post_message(request: Request):
266
  except Exception as e:
267
  logger.error(f"Error sending to placeholder {sid}: {str(e)}")
268
  write_streams.pop(sid, None)
 
269
  if not sent:
270
- logger.warning(f"Failed to send tools/list response: no active write_streams found")
271
  return Response(status_code=202)
272
- # If write_stream is None, log and handle gracefully
273
- if not write_stream:
274
- logger.error(f"No write_stream found for session_id: {session_id}")
275
  return Response(status_code=202)
276
  await transport.handle_post_message(request.scope, request.receive, request._send)
277
  logger.debug("POST message handled successfully")
 
111
  mcp_server.tool_handlers = tool_handlers # Set as attribute
112
  mcp_server.tools = tools # Set tools as attribute for Deep Agent to discover
113
 
114
+ # Store write streams for each session ID (for SseServerTransport messages)
115
  write_streams: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
116
 
117
+ # Store SSE stream writers for each session ID (for manual messages)
118
+ sse_stream_writers: Dict[str, anyio.streams.memory.MemoryObjectSendStream] = {}
119
+
120
  # Initialize SseServerTransport
121
  transport = SseServerTransport("/airtable/mcp")
122
 
 
136
  )
137
  logger.debug(f"Sent endpoint event: {endpoint_data}")
138
  async for session_message in write_stream_reader:
139
+ # Handle messages from SseServerTransport
 
140
  if hasattr(session_message, 'message'):
141
  message_data = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
142
  event_data = json.loads(message_data)
143
+ logger.debug(f"Received SessionMessage from SseServerTransport: {event_data}")
144
  else:
145
  event_data = session_message
146
+ logger.debug(f"Received dict event from SseServerTransport: {event_data}")
147
  # Extract session_id from the endpoint event
148
  if not session_id and event_data.get("event") == "endpoint":
149
  endpoint_url = event_data.get("data", "")
 
152
  placeholder_id = f"placeholder_{id(write_stream)}"
153
  if placeholder_id in write_streams:
154
  write_streams[session_id] = write_streams.pop(placeholder_id)
155
+ sse_stream_writers[session_id] = sse_stream_writer
156
  logger.debug(f"Updated placeholder {placeholder_id} to session_id {session_id}")
157
+ # Forward the event to the client
158
  await sse_stream_writer.send({
159
  "event": event_data.get("event", "message"),
160
  "data": event_data.get("data", json.dumps(event_data))
 
173
  await mcp_server.run(read_stream, write_stream, mcp_server.create_initialization_options())
174
  except Exception as e:
175
  logger.error(f"Error in handle_sse: {str(e)}")
176
+ # Clean up write_streams and sse_stream_writers on error
177
  placeholder_id = f"placeholder_{id(write_stream)}"
178
  write_streams.pop(placeholder_id, None)
179
  if session_id:
180
  write_streams.pop(session_id, None)
181
+ sse_stream_writers.pop(session_id, None)
182
  raise
183
  return EventSourceResponse(sse_stream_reader, data_sender_callable=sse_writer)
184
 
 
191
  try:
192
  message = json.loads(body.decode())
193
  session_id = request.query_params.get("session_id")
194
+ # Use sse_stream_writers to send manual responses directly
195
+ sse_writer = sse_stream_writers.get(session_id) if session_id else None
196
  write_stream = write_streams.get(session_id) if session_id else None
197
+ if message.get("method") == "initialize" and sse_writer:
198
  logger.debug("Handling initialize request manually")
199
  response = {
200
  "jsonrpc": "2.0",
 
224
  }
225
  logger.debug(f"Manual initialize response: {response}")
226
  response_data = json.dumps(response)
227
+ await sse_writer.send({
228
  "event": "message",
229
  "data": response_data
230
  })
231
+ logger.debug(f"Sent initialize response directly via SSE for session {session_id}")
232
  return Response(status_code=202)
233
  if message.get("method") == "tools/list":
234
  logger.debug("Handling tools/list request manually")
 
243
  logger.debug(f"Manual tools/list response: {response}")
244
  response_data = json.dumps(response)
245
  sent = False
246
+ # First, try sending directly via sse_writer
247
+ if sse_writer:
 
248
  try:
249
+ await sse_writer.send({
250
  "event": "message",
251
  "data": response_data
252
  })
253
+ logger.debug(f"Sent tools/list response directly via SSE for session {session_id}")
254
  sent = True
255
  except Exception as e:
256
+ logger.error(f"Error sending to session {session_id} via sse_writer: {str(e)}")
257
+ sse_stream_writers.pop(session_id, None)
258
+ # If not found or failed, look for a placeholder ID and update it
259
+ if not sent and write_stream:
260
  for sid, ws in list(write_streams.items()):
261
  if sid.startswith("placeholder_"):
262
  try:
263
  write_streams[session_id] = ws
264
+ sse_stream_writers[session_id] = sse_writer
265
  write_streams.pop(sid, None)
266
+ await sse_writer.send({
267
  "event": "message",
268
  "data": response_data
269
  })
 
273
  except Exception as e:
274
  logger.error(f"Error sending to placeholder {sid}: {str(e)}")
275
  write_streams.pop(sid, None)
276
+ sse_stream_writers.pop(session_id, None)
277
  if not sent:
278
+ logger.warning(f"Failed to send tools/list response: no active write_streams or sse_writer found")
279
  return Response(status_code=202)
280
+ # If neither sse_writer nor write_stream is available, log and handle gracefully
281
+ if not sse_writer and not write_stream:
282
+ logger.error(f"No sse_writer or write_stream found for session_id: {session_id}")
283
  return Response(status_code=202)
284
  await transport.handle_post_message(request.scope, request.receive, request._send)
285
  logger.debug("POST message handled successfully")