Spaces:
Sleeping
Sleeping
| # GET /models | |
| # GET /tools | |
| # POST /chat -> Groq -> response | |
| import asyncio | |
| import json | |
| import traceback | |
| from typing import List, Optional | |
| from contextlib import AsyncExitStack | |
| import uuid | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from fastapi import FastAPI, Request, HTTPException, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp.client.stdio import stdio_client | |
| from groq import Groq, APIConnectionError | |
| from dotenv import load_dotenv | |
| import os | |
| import httpx | |
# Live sessions, keyed by the session_id string handed out at init time.
sessions: dict[str, "MCPClient"] = {}
# API keys that currently own a session (one session per key is enforced at init).
unique_apikeys: list[str] = []
class MCPClient:
    """One user's chat session: a Groq client plus a stdio connection to an MCP tool server.

    The MCP server is started as ``uv run server.py`` and its tools are exposed to the
    Groq chat-completions API as function tools. ``messages`` holds the running
    conversation for this session, starting with a fixed system prompt.
    """

    def __init__(self):
        self.session: Optional[ClientSession] = None
        # Owns the stdio transport and ClientSession contexts entered in connect().
        self.exit_stack = AsyncExitStack()
        self.current_model = None  # model id used for completions; set by the caller
        self.groq = None
        self.api_key = None
        self.stdio = None
        self.write = None
        self.messages = [{
            "role": "system",
            "content": "You are a helpful assistant that have access to different tools via MCP. Make complete answers."
        }]
        self.tool_use = True  # when falsy, completions are requested without tools
        self.models = None  # sorted model ids, filled by populate_model()
        self.tools = []  # MCP tools converted to Groq function-tool definitions

    async def connect(self, api_key: str) -> bool:
        """Create the Groq client and connect to the MCP tool server over stdio.

        Returns:
            True on success, False if constructing the Groq client raised.
            Errors while starting or initialising the MCP server propagate.
        """
        try:
            # TLS verification is left on: every request carries the API key.
            self.groq = Groq(api_key=api_key, http_client=httpx.Client(timeout=30))
            self.api_key = api_key
        except Exception as e:  # also covers APIConnectionError
            traceback.print_exception(e)
            return False
        server_params = StdioServerParameters(command="uv", args=["run", "server.py"])
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        self.tools = [{"type": "function", "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.inputSchema
        }} for tool in tools]
        return True

    def populate_model(self):
        """Fetch the model ids available to this API key and store them sorted in ``models``."""
        self.models = sorted([m.id for m in self.groq.models.list().data])

    async def process_query(self, query: str) -> str:
        """Process a query using Groq and available tools.

        If the model requests tools, every requested call is executed on the MCP
        server. The assistant's tool-call message and one ``tool`` message per result
        are added. Then a single follow-up completion (without tools) produces the
        answer. The whole turn is appended to ``self.messages`` only once it has
        completed; on any error the history is left unchanged and the exception
        propagates.

        Returns:
            The non-empty assistant text parts, joined with newlines.
        """
        turn = [{"role": "user", "content": query}]
        message = (await self._complete(turn, with_tools=self.tool_use)).choices[0].message

        final_text = []
        if message.content:
            final_text.append(message.content)

        if message.tool_calls:
            # The API requires the assistant message that issued the tool calls to
            # precede the matching ``tool`` result messages.
            assistant_msg = {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        },
                    }
                    for call in message.tool_calls
                ],
            }
            if message.content is not None:
                assistant_msg["content"] = message.content
            turn.append(assistant_msg)
            for call in message.tool_calls:
                turn.append(await self._run_tool_call(call))
            message = (await self._complete(turn, with_tools=False)).choices[0].message
            if message.content:
                final_text.append(message.content)

        # Record the final reply so later turns see the assistant's own answers.
        if message.content:
            turn.append({"role": "assistant", "content": message.content})
        self.messages.extend(turn)
        return "\n".join(final_text)

    async def _complete(self, turn, *, with_tools):
        """Run one chat completion over the history plus ``turn``, off the event loop."""
        request = {"model": self.current_model, "messages": self.messages + turn}
        if with_tools and self.tools:
            request["tools"] = self.tools
            request["temperature"] = 0
        else:
            request["temperature"] = 0.7
        # The Groq client is synchronous; run it in a worker thread so other
        # requests on this server are not blocked while waiting for the API.
        return await asyncio.to_thread(self.groq.chat.completions.create, **request)

    async def _run_tool_call(self, call) -> dict:
        """Execute one model-requested tool call on the MCP server and build its ``tool`` message."""
        name = call.function.name
        # Treat a missing/empty arguments string as "no arguments".
        args = json.loads(call.function.arguments or "{}")
        print(f"[Calling tool {name} with args {call.function.arguments}]")
        result = await self.session.call_tool(name, args)
        return {
            "role": "tool",
            "tool_call_id": call.id,
            "content": self._tool_result_text(result),
        }

    @staticmethod
    def _tool_result_text(result) -> str:
        """Flatten an MCP tool result's content items to text (text items verbatim, others via str)."""
        pieces = []
        for item in result.content:
            text = getattr(item, "text", None)
            pieces.append(text if text is not None else str(item))
        return "\n".join(pieces)
app = FastAPI()
# CORS is wide open: any origin, method and header, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Serve the ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Module-level client instance; init_server builds a separate MCPClient per session.
mcp = MCPClient()
class InitRequest(BaseModel):
    # Groq API key used to create the session's client.
    api_key: str


class InitResponse(BaseModel):
    # On failure: success is False, session_id is "" and error carries the message.
    success: bool
    session_id: str
    models: list | None = None
    error: str | None = None


class LogoutRequest(BaseModel):
    # Session to tear down.
    session_id: str
def get_mcp_client(session_id: str) -> MCPClient:
    """Get the MCPClient for a given session_id, or raise 404.

    Raises:
        HTTPException: status 404 when ``session_id`` is not a live session.
    """
    # The annotation is MCPClient, not MCPClient | None: a missing session raises
    # instead of returning None.
    client = sessions.get(session_id)
    if client is None:
        raise HTTPException(status_code=404, detail="Invalid session_id. Please re-initialize.")
    return client
def root():
    """Return the front-end page, ``index.html``, resolved relative to the working directory."""
    index_page = "index.html"
    return FileResponse(index_page)
async def init_server(req: InitRequest):
    """
    Initializes a new MCP client session. Returns a session_id.

    Only one live session per API key is allowed. On any failure the partially
    started client is closed, the API key reservation is released, and an
    InitResponse with success=False and the error message is returned.
    """
    api_key = req.api_key
    # Check and reserve the key before the first await, so two concurrent
    # requests with the same key cannot both pass the duplicate check (and so a
    # rejected request never registers an unreachable session).
    if api_key in unique_apikeys:
        return InitResponse(
            session_id="",
            models=None,
            error="Session with this API key already exists. We won't re-return you the session ID. Bye-bye Hacker !!",
            success=False
        )
    unique_apikeys.append(api_key)

    session_id = str(uuid.uuid4())
    client = MCPClient()
    try:
        ok = await client.connect(api_key)
        if ok is False:
            raise RuntimeError("Failed to connect to MCP or Groq with API key.")
        # populate_model() performs a synchronous HTTP call; keep it off the event loop.
        await asyncio.to_thread(client.populate_model)
    except Exception as e:
        traceback.print_exception(e)
        unique_apikeys.remove(api_key)
        try:
            # Shut down the MCP subprocess/transport if connect() got that far.
            await client.exit_stack.aclose()
        except Exception as cleanup_error:
            traceback.print_exception(cleanup_error)
        return InitResponse(
            session_id="",
            models=None,
            error=str(e),
            success=False
        )

    sessions[session_id] = client
    return InitResponse(
        session_id=session_id,
        models=client.models,
        error=None,
        success=True
    )
class ChatRequest(BaseModel):
    # session_id must come from a successful init call; model must be one of that
    # session's listed models.
    session_id: str
    query: str
    tool_use: bool | None = True
    model: str | None = "llama-3.3-70b-versatile"


class ChatResponse(BaseModel):
    # On failure output is "" and error carries the message.
    output: str
    error: str | None = None
async def chat(req: ChatRequest):
    """
    Handles chat requests for a given session.

    The requested model is validated against the session's model list before any
    session setting is changed. Every error, including an unknown session_id, is
    reported in ChatResponse.error with an empty output.
    """
    try:
        client = get_mcp_client(req.session_id)
        if req.model not in (client.models or []):
            raise ValueError(f"Model not recognized: Not in the model list: {client.models}")
        # Mutate session settings only once the request is known to be valid, so a
        # rejected request leaves the session untouched.
        client.tool_use = req.tool_use
        client.current_model = req.model
        result = await client.process_query(req.query)
        return ChatResponse(output=result)
    except Exception as e:
        traceback.print_exception(e)
        return ChatResponse(output="", error=str(e))
async def logout(logout_req: LogoutRequest):
    """Clean up session resources.

    Returns {"success": False} for an unknown session_id. Otherwise the session
    and its API-key reservation are removed, and the MCP connection is closed
    best-effort. A failure while closing is logged and does not fail the request,
    because the session is already gone.
    """
    client = sessions.pop(logout_req.session_id, None)
    if client is None:
        return {"success": False}
    if client.api_key in unique_apikeys:
        unique_apikeys.remove(client.api_key)
    try:
        # NOTE(review): the transport contexts were entered in the init request's
        # task; closing them from this request's task may raise (anyio cancel
        # scopes) — hence best-effort.
        await client.exit_stack.aclose()
    except Exception as e:
        traceback.print_exception(e)
    return {"success": True}