|
|
package mcp |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"net/http" |
|
|
"os" |
|
|
"os/exec" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
"github.com/mudler/LocalAI/pkg/signals" |
|
|
|
|
|
"github.com/modelcontextprotocol/go-sdk/mcp" |
|
|
"github.com/mudler/xlog" |
|
|
) |
|
|
|
|
|
type sessionCache struct { |
|
|
mu sync.Mutex |
|
|
cache map[string][]*mcp.ClientSession |
|
|
} |
|
|
|
|
|
var ( |
|
|
cache = sessionCache{ |
|
|
cache: make(map[string][]*mcp.ClientSession), |
|
|
} |
|
|
|
|
|
client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil) |
|
|
) |
|
|
|
|
|
func SessionsFromMCPConfig( |
|
|
name string, |
|
|
remote config.MCPGenericConfig[config.MCPRemoteServers], |
|
|
stdio config.MCPGenericConfig[config.MCPSTDIOServers], |
|
|
) ([]*mcp.ClientSession, error) { |
|
|
cache.mu.Lock() |
|
|
defer cache.mu.Unlock() |
|
|
|
|
|
sessions, exists := cache.cache[name] |
|
|
if exists { |
|
|
return sessions, nil |
|
|
} |
|
|
|
|
|
allSessions := []*mcp.ClientSession{} |
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
|
|
|
|
|
|
|
for _, server := range remote.Servers { |
|
|
xlog.Debug("[MCP remote server] Configuration", "server", server) |
|
|
|
|
|
httpClient := &http.Client{ |
|
|
Timeout: 360 * time.Second, |
|
|
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), |
|
|
} |
|
|
|
|
|
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient} |
|
|
mcpSession, err := client.Connect(ctx, transport, nil) |
|
|
if err != nil { |
|
|
xlog.Error("Failed to connect to MCP server", "error", err, "url", server.URL) |
|
|
continue |
|
|
} |
|
|
xlog.Debug("[MCP remote server] Connected to MCP server", "url", server.URL) |
|
|
cache.cache[name] = append(cache.cache[name], mcpSession) |
|
|
allSessions = append(allSessions, mcpSession) |
|
|
} |
|
|
|
|
|
for _, server := range stdio.Servers { |
|
|
xlog.Debug("[MCP stdio server] Configuration", "server", server) |
|
|
command := exec.Command(server.Command, server.Args...) |
|
|
command.Env = os.Environ() |
|
|
for key, value := range server.Env { |
|
|
command.Env = append(command.Env, key+"="+value) |
|
|
} |
|
|
transport := &mcp.CommandTransport{Command: command} |
|
|
mcpSession, err := client.Connect(ctx, transport, nil) |
|
|
if err != nil { |
|
|
xlog.Error("Failed to start MCP server", "error", err, "command", command) |
|
|
continue |
|
|
} |
|
|
xlog.Debug("[MCP stdio server] Connected to MCP server", "command", command) |
|
|
cache.cache[name] = append(cache.cache[name], mcpSession) |
|
|
allSessions = append(allSessions, mcpSession) |
|
|
} |
|
|
|
|
|
signals.RegisterGracefulTerminationHandler(func() { |
|
|
for _, session := range allSessions { |
|
|
session.Close() |
|
|
} |
|
|
cancel() |
|
|
}) |
|
|
|
|
|
return allSessions, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// bearerTokenRoundTripper is an http.RoundTripper that adds an
// "Authorization: Bearer <token>" header to requests before handing them to
// base. With an empty token, requests are forwarded untouched.
type bearerTokenRoundTripper struct {
	token string
	base  http.RoundTripper
}

// RoundTrip implements http.RoundTripper.
//
// The http.RoundTripper contract says an implementation must not modify the
// caller's request, so the header is set on a clone of req and the clone is
// what gets forwarded to the base transport.
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.token == "" {
		return rt.base.RoundTrip(req)
	}

	authed := req.Clone(req.Context())
	if authed.Header == nil {
		// A hand-built request may carry a nil header map; Set would panic.
		authed.Header = make(http.Header)
	}
	authed.Header.Set("Authorization", "Bearer "+rt.token)
	return rt.base.RoundTrip(authed)
}

// newBearerTokenRoundTripper returns a round tripper that injects token as a
// bearer credential and delegates to base. A nil base falls back to
// http.DefaultTransport.
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
	if base == nil {
		base = http.DefaultTransport
	}
	return &bearerTokenRoundTripper{
		token: token,
		base:  base,
	}
}
|
|
|