serichard1 committed
Commit f91b5b2 · Parent: 7a52daf

fix async error: replace the shared module-level asyncio event loop with lock-guarded, per-call event loop management (running on a fresh loop in a worker thread whenever a loop is already running). A standalone sketch of the pattern follows the diff.

Files changed (1): app.py (+122 -51)
app.py CHANGED
@@ -1,10 +1,12 @@
 import asyncio
 import os
 import json
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Optional
 from contextlib import AsyncExitStack
 import mimetypes
 import tempfile
+import threading
+from concurrent.futures import ThreadPoolExecutor
 
 import gradio as gr
 from gradio.components.chatbot import ChatMessage
@@ -17,15 +19,13 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
-loop = asyncio.new_event_loop()
-asyncio.set_event_loop(loop)
-
 class MCPClientWrapper:
     def __init__(self):
-        self.session = None
-        self.exit_stack = None
+        self.session: Optional[ClientSession] = None
+        self.exit_stack: Optional[AsyncExitStack] = None
         self.tools = []
         self.connected = False
+        self._connection_lock = threading.Lock()
 
         # Initialize all LLM clients
         self.anthropic_client = None
@@ -61,10 +61,9 @@ class MCPClientWrapper:
 
         try:
             if os.getenv("LLAMAINDEX_API_KEY"):
-                # Using OpenAI-compatible endpoint for Llama
                 self.llama_client = OpenAI(
                     api_key=os.getenv("LLAMAINDEX_API_KEY"),
-                    base_url="https://api.llamaindex.ai/v1"  # Adjust based on your provider
+                    base_url="https://api.llamaindex.ai/v1"
                 )
         except Exception as e:
             print(f"⚠️ Failed to initialize Llama client: {e}")
@@ -118,28 +117,43 @@ class MCPClientWrapper:
         self.current_model = model
         return f"✅ Switched to {provider}: {model}"
 
-    def connect(self) -> str:
-        return loop.run_until_complete(self._connect())
-
-    async def _connect(self) -> str:
+    async def _cleanup_connection(self):
+        """Safely cleanup existing connection."""
         if self.exit_stack:
-            await self.exit_stack.aclose()
-
-        self.exit_stack = AsyncExitStack()
-
-        server_path = "gradio_mcp_server.py"
-
-        server_params = StdioServerParameters(
-            command="python",
-            args=[server_path],
-            env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
-        )
-
+            try:
+                await self.exit_stack.aclose()
+            except Exception as e:
+                print(f"Warning: Error during cleanup: {e}")
+            finally:
+                self.exit_stack = None
+                self.session = None
+                self.connected = False
+
+    async def _establish_connection(self) -> str:
+        """Establish MCP connection in proper async context."""
         try:
-            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
-            self.stdio, self.write = stdio_transport
-            self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
+            # Clean up any existing connection
+            await self._cleanup_connection()
+
+            self.exit_stack = AsyncExitStack()
+
+            server_path = "gradio_mcp_server.py"
+            server_params = StdioServerParameters(
+                command="python",
+                args=[server_path],
+                env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
+            )
+
+            # Enter the async context managers
+            stdio_transport = await self.exit_stack.enter_async_context(
+                stdio_client(server_params)
+            )
+            stdio, write = stdio_transport
+
+            self.session = await self.exit_stack.enter_async_context(
+                ClientSession(stdio, write)
+            )
 
             await self.session.initialize()
 
             response = await self.session.list_tools()
@@ -152,26 +166,56 @@ class MCPClientWrapper:
             self.connected = True
            tool_names = [tool["name"] for tool in self.tools]
             return f"✅ Connected to MCP Weather Server. Available tools: {', '.join(tool_names)}"
+
         except Exception as e:
             self.connected = False
+            await self._cleanup_connection()
             return f"❌ Failed to connect to MCP server: {str(e)}"
 
+    def connect(self) -> str:
+        """Thread-safe connection method for Gradio."""
+        with self._connection_lock:
+            try:
+                # Create new event loop for this operation
+                try:
+                    loop = asyncio.get_event_loop()
+                except RuntimeError:
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+
+                if loop.is_running():
+                    # If loop is already running, we need to run in a thread
+                    import concurrent.futures
+                    with concurrent.futures.ThreadPoolExecutor() as executor:
+                        future = executor.submit(self._run_connection_in_new_loop)
+                        return future.result()
+                else:
+                    return loop.run_until_complete(self._establish_connection())
+            except Exception as e:
+                return f"❌ Connection error: {str(e)}"
+
+    def _run_connection_in_new_loop(self) -> str:
+        """Run connection in a new event loop (for thread safety)."""
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self._establish_connection())
+        finally:
+            loop.close()
+
     def read_uploaded_file(self, file_path: str) -> str:
         """Read and process uploaded file content."""
         if not file_path or not os.path.exists(file_path):
             return ""
 
         try:
-            # Get file info
             file_size = os.path.getsize(file_path)
             file_name = os.path.basename(file_path)
             mime_type, _ = mimetypes.guess_type(file_path)
 
-            # Check file size (limit to 10MB)
             if file_size > 10 * 1024 * 1024:
-                return f"\n\n🔍 **File Upload Error**: {file_name} is too large (>10MB). Please upload a smaller file."
+                return f"\n\n📄 **File Upload Error**: {file_name} is too large (>10MB). Please upload a smaller file."
 
-            # Try to read as text
             encodings_to_try = ['utf-8', 'utf-16', 'latin-1', 'cp1252']
 
             for encoding in encodings_to_try:
@@ -179,12 +223,11 @@ class MCPClientWrapper:
                     with open(file_path, 'r', encoding=encoding) as f:
                         content = f.read()
 
-                    # If content is too long, truncate it
-                    max_chars = 50000  # Roughly 50k characters
+                    max_chars = 50000
                     if len(content) > max_chars:
                         content = content[:max_chars] + f"\n\n[Content truncated - showing first {max_chars} characters of {len(content)} total]"
 
-                    file_info = f"\n\n🔍 **Uploaded File**: {file_name}"
+                    file_info = f"\n\n📄 **Uploaded File**: {file_name}"
                     if mime_type:
                         file_info += f" ({mime_type})"
                     file_info += f" - {file_size:,} bytes\n\n```\n{content}\n```"
@@ -194,18 +237,16 @@ class MCPClientWrapper:
                 except UnicodeDecodeError:
                     continue
 
-            # If all text encodings fail, it's likely a binary file
-            return f"\n\n🔍 **File Upload**: {file_name} appears to be a binary file and cannot be displayed as text."
+            return f"\n\n📄 **File Upload**: {file_name} appears to be a binary file and cannot be displayed as text."
 
         except Exception as e:
-            return f"\n\n🔍 **File Upload Error**: Could not read {file_name}: {str(e)}"
+            return f"\n\n📄 **File Upload Error**: Could not read {file_name}: {str(e)}"
 
     def _convert_tools_for_provider(self, provider: str):
         """Convert MCP tools format to provider-specific format."""
         if provider == "claude":
             return self.tools
         elif provider in ["openai", "llama"]:
-            # Convert to OpenAI tools format
             openai_tools = []
             for tool in self.tools:
                 openai_tools.append({
@@ -218,7 +259,6 @@ class MCPClientWrapper:
                 })
             return openai_tools
         elif provider == "mistral":
-            # Convert to Mistral tools format
             mistral_tools = []
             for tool in self.tools:
                 mistral_tools.append({
@@ -270,6 +310,7 @@ class MCPClientWrapper:
             raise Exception(f"Error calling {provider}: {str(e)}")
 
     def process_message(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]], uploaded_file) -> tuple:
+        """Process message in thread-safe manner."""
         if not self.session or not self.connected:
             return history + [
                 {"role": "user", "content": message},
@@ -284,10 +325,44 @@ class MCPClientWrapper:
         # Combine message with file content
         full_message = message + file_content
 
-        new_messages = loop.run_until_complete(self._process_query(full_message, history))
-        return history + [{"role": "user", "content": full_message}] + new_messages, gr.Textbox(value=""), gr.File(value=None)
+        try:
+            # Run async processing in new event loop
+            new_messages = self._run_async_processing(full_message, history)
+            return history + [{"role": "user", "content": full_message}] + new_messages, gr.Textbox(value=""), gr.File(value=None)
+        except Exception as e:
+            return history + [
+                {"role": "user", "content": full_message},
+                {"role": "assistant", "content": f"❌ Error processing message: {str(e)}"}
+            ], gr.Textbox(value=""), gr.File(value=None)
+
+    def _run_async_processing(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
+        """Run async message processing in new event loop."""
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        if loop.is_running():
+            # Run in thread if event loop is already running
+            import concurrent.futures
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(self._process_in_new_loop, message, history)
+                return future.result()
+        else:
+            return loop.run_until_complete(self._process_query(message, history))
+
+    def _process_in_new_loop(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
+        """Process query in a completely new event loop."""
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(self._process_query(message, history))
+        finally:
+            loop.close()
 
     async def _process_query(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
+        """Process the actual query with LLM and tools."""
         claude_messages = []
         for msg in history:
             if isinstance(msg, ChatMessage):
@@ -305,8 +380,6 @@ class MCPClientWrapper:
         except Exception as e:
             return [{"role": "assistant", "content": f"❌ Error with {self.current_provider}: {str(e)}"}]
 
-        result_messages = []
-
         # Handle different response formats based on provider
         if self.current_provider == "claude":
             return await self._process_claude_response(response, claude_messages)
@@ -315,7 +388,7 @@ class MCPClientWrapper:
         elif self.current_provider == "mistral":
             return await self._process_mistral_response(response, claude_messages)
 
-        return result_messages
+        return []
 
     async def _process_claude_response(self, response, claude_messages):
         """Process Claude API response."""
@@ -348,11 +421,9 @@ class MCPClientWrapper:
             if isinstance(result_content, list):
                 result_content = "\n".join(str(item) for item in result_content)
 
-            # Format the response
             formatted_response = self._format_weather_response(result_content, tool_name)
             result_messages.append(formatted_response)
 
-            # Let the LLM analyze and respond
             claude_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
             next_response = await self._call_llm(claude_messages, self.current_provider, self.current_model)
 
@@ -500,6 +571,7 @@ class MCPClientWrapper:
         }
     }
 
+# Initialize client
 client = MCPClientWrapper()
 
 def gradio_interface():
@@ -547,7 +619,7 @@ def gradio_interface():
         status = gr.Textbox(
             label="🔌 Connection Status",
             interactive=False,
-            value="🔄 Connecting to weather server..."
+            value="🔄 Ready to connect..."
         )
 
         # Main chat interface
@@ -556,11 +628,10 @@ def gradio_interface():
             height=600,
             type="messages",
             show_copy_button=True,
-            avatar_images=("👤", "🤖"),
-            bubble_full_width=False
+            avatar_images=("👤", "🤖")
        )
 
-        # File upload component (already exists in your code!)
+        # File upload component
         file_upload = gr.File(
             label="📎 Upload File (optional)",
             file_count="single",
@@ -611,7 +682,7 @@ def gradio_interface():
             return f"{provider}: {model}", status_msg
         return current_model_display.value, "❌ Please select both provider and model"
 
-    # Auto-connect when the interface loads
+    # Auto-connect function
     def auto_connect():
         return client.connect()
 
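The heart of this commit is the sync-to-async bridge that connect() and _run_async_processing() both use: reuse the calling thread's event loop when it is idle, and otherwise run the coroutine to completion on a fresh loop in a worker thread, so run_until_complete() is never invoked on a loop that is already running. A minimal standalone sketch of that pattern follows; run_coro_blocking, _run_in_fresh_loop, and demo are illustrative names, not part of app.py.

import asyncio
from concurrent.futures import ThreadPoolExecutor


def _run_in_fresh_loop(coro):
    # Worker-thread path: the thread owns a brand-new event loop,
    # runs the coroutine to completion, then closes the loop.
    # (Hypothetical helper mirroring _run_connection_in_new_loop.)
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


def run_coro_blocking(coro):
    # Hypothetical mirror of the commit's bridge: get_event_loop()
    # (creating one on RuntimeError, as the commit does), then either
    # run directly on an idle loop or hop to a worker thread.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if loop.is_running():
        # A coroutine object is not bound to a loop until awaited,
        # so it can safely be driven by the worker thread's loop.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(_run_in_fresh_loop, coro).result()
    return loop.run_until_complete(coro)


async def demo():
    await asyncio.sleep(0.1)
    return "done"


if __name__ == "__main__":
    print(run_coro_blocking(demo()))  # -> "done"

One consequence of this design is that successive calls may execute on different event loops, which is why the commit guards connect() with _connection_lock and tears down stale state in _cleanup_connection() instead of reusing an MCP session across loops.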