|
|
import { ZodLiteral, ZodObject, ZodType, z } from "zod"; |
|
|
import { |
|
|
CancelledNotificationSchema, |
|
|
ClientCapabilities, |
|
|
ErrorCode, |
|
|
isJSONRPCError, |
|
|
isJSONRPCRequest, |
|
|
isJSONRPCResponse, |
|
|
isJSONRPCNotification, |
|
|
JSONRPCError, |
|
|
JSONRPCNotification, |
|
|
JSONRPCRequest, |
|
|
JSONRPCResponse, |
|
|
McpError, |
|
|
Notification, |
|
|
PingRequestSchema, |
|
|
Progress, |
|
|
ProgressNotification, |
|
|
ProgressNotificationSchema, |
|
|
Request, |
|
|
RequestId, |
|
|
Result, |
|
|
ServerCapabilities, |
|
|
RequestMeta, |
|
|
} from "../types"; |
|
|
import { Transport, TransportSendOptions } from "./transport"; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Callback invoked for each progress notification received for an outgoing
 * request. It receives the notification's params with `progressToken` removed.
 */
export type ProgressCallback = (progress: Progress) => void;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Additional initialization options for `Protocol`.
 */
export type ProtocolOptions = {
  /**
   * When exactly `true`, `Protocol.request()` calls
   * `assertCapabilityForMethod()` for every outgoing request before sending
   * it. An assertion failure therefore rejects the request locally instead of
   * it being sent. Any other value (including the default, `undefined`) skips
   * this check.
   *
   * What is actually checked is up to the subclass implementation —
   * NOTE(review): presumably the remote side's advertised capabilities;
   * confirm in the concrete client/server classes.
   */
  enforceStrictCapabilities?: boolean;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Per-request timeout in milliseconds (one minute).
 *
 * `Protocol.request()` uses this when `RequestOptions.timeout` is not given.
 */
export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Options that can be given per outgoing request.
 *
 * The `TransportSendOptions` members (`relatedRequestId`, `resumptionToken`,
 * `onresumptiontoken`) are forwarded to `Transport.send()`. This applies to
 * the request itself and to any cancellation notification sent for it.
 */
export type RequestOptions = {
  /**
   * If set, the request carries a progress token (its own message ID) in
   * `params._meta.progressToken`. This callback is then invoked for each
   * progress notification that carries that token.
   */
  onprogress?: ProgressCallback;

  /**
   * Aborting this signal cancels the request. A `notifications/cancelled`
   * message is sent, and the returned promise rejects with `signal.reason`.
   * If the signal is already aborted, `request()` rejects without sending.
   */
  signal?: AbortSignal;

  /**
   * Timeout in milliseconds for this request. If it elapses without a
   * response, the request is cancelled (as with `signal`) and rejects with an
   * `McpError` whose code is `ErrorCode.RequestTimeout`.
   *
   * Defaults to `DEFAULT_REQUEST_TIMEOUT_MSEC`.
   */
  timeout?: number;

  /**
   * When `true`, each progress notification for this request restarts the
   * `timeout` countdown. Defaults to `false`.
   */
  resetTimeoutOnProgress?: boolean;

  /**
   * Upper bound, in milliseconds, on total elapsed time since the timeout was
   * first armed (just before the request is sent), regardless of
   * progress-driven resets.
   *
   * It is only checked when a progress notification would reset the timeout,
   * i.e. together with `resetTimeoutOnProgress`. Exceeding it rejects the
   * request with an `McpError` ("Maximum total timeout exceeded").
   */
  maxTotalTimeout?: number;
} & TransportSendOptions;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Options that can be given per outgoing notification.
 * They are passed through unchanged to `Transport.send()`.
 */
export type NotificationOptions = {
  /**
   * ID of the incoming request this notification relates to.
   *
   * Notifications sent through `RequestHandlerExtra.sendNotification` set this
   * to the request being handled. How the transport uses it is
   * transport-specific — NOTE(review): presumably to route the message
   * alongside that request's response; confirm per transport.
   */
  relatedRequestId?: RequestId;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Extra data passed to request handlers alongside the parsed request.
 */
export type RequestHandlerExtra<
  SendRequestT extends Request,
  SendNotificationT extends Notification
> = {
  /**
   * Aborted, with the peer-supplied reason, when a `notifications/cancelled`
   * for this request arrives. Once aborted, the handler's result or error is
   * no longer sent back.
   */
  signal: AbortSignal;

  /**
   * The transport's `sessionId` at the time the request was received, if the
   * transport provides one.
   */
  sessionId?: string;

  /**
   * The incoming request's `params._meta`, if present.
   */
  _meta?: RequestMeta;

  /**
   * JSON-RPC ID of the request being handled.
   */
  requestId: RequestId;

  /**
   * Sends a notification with `relatedRequestId` set to this request's ID.
   */
  sendNotification: (notification: SendNotificationT) => Promise<void>;

  /**
   * Sends a request (via `Protocol.request()`) with `relatedRequestId` set to
   * this request's ID. A `relatedRequestId` given in `options` is overridden.
   */
  sendRequest: <U extends ZodType<object>>(
    request: SendRequestT,
    resultSchema: U,
    options?: RequestOptions
  ) => Promise<z.infer<U>>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Timeout bookkeeping for one pending outgoing request, keyed by message ID.
 */
type TimeoutInfo = {
  // Handle of the currently armed timer; replaced on every reset.
  timeoutId: ReturnType<typeof setTimeout>;
  // Date.now() when the timeout was first armed; basis for maxTotalTimeout.
  startTime: number;
  // Length of one timeout period, in milliseconds.
  timeout: number;
  // Optional cap (ms) on total elapsed time across resets.
  maxTotalTimeout?: number;
  // Whether progress notifications restart the timer.
  resetTimeoutOnProgress: boolean;
  // Called when the timer fires (cancels the request).
  onTimeout: () => void;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Implements MCP protocol framing on top of a pluggable transport.
 *
 * Handles request/response correlation, notification dispatch, progress
 * notifications, cancellation (both directions) and request timeouts.
 * Subclasses supply capability checks via the abstract `assert*` methods.
 */
export abstract class Protocol<
  SendRequestT extends Request,
  SendNotificationT extends Notification,
  SendResultT extends Result
> {
  private _transport?: Transport;
  // Next outgoing message ID; also used as the progress token for that request.
  private _requestMessageId = 0;
  private _requestHandlers: Map<
    string,
    (
      request: JSONRPCRequest,
      extra: RequestHandlerExtra<SendRequestT, SendNotificationT>
    ) => Promise<SendResultT>
  > = new Map();
  // Abort controllers for incoming requests whose handlers are still running.
  private _requestHandlerAbortControllers: Map<RequestId, AbortController> =
    new Map();
  private _notificationHandlers: Map<
    string,
    (notification: JSONRPCNotification) => Promise<void>
  > = new Map();
  // Pending outgoing requests, keyed by message ID.
  private _responseHandlers: Map<
    number,
    (response: JSONRPCResponse | Error) => void
  > = new Map();
  private _progressHandlers: Map<number, ProgressCallback> = new Map();
  private _timeoutInfo: Map<number, TimeoutInfo> = new Map();

  /**
   * Callback for when the connection is closed for any reason.
   *
   * Invoked from the transport's `onclose`. It runs before pending outgoing
   * requests are rejected with `ErrorCode.ConnectionClosed`.
   */
  onclose?: () => void;

  /**
   * Callback for errors not tied to a specific outgoing request: transport
   * errors, unrecognized messages, and failures in handlers or while sending
   * responses/cancellations.
   */
  onerror?: (error: Error) => void;

  /**
   * Handler for incoming requests whose method has no registered handler.
   * If unset, such requests are answered with `ErrorCode.MethodNotFound`.
   */
  fallbackRequestHandler?: (request: Request) => Promise<SendResultT>;

  /**
   * Handler for incoming notifications whose method has no registered handler.
   * If unset, such notifications are ignored.
   */
  fallbackNotificationHandler?: (notification: Notification) => Promise<void>;

  constructor(private _options?: ProtocolOptions) {
    // The peer cancelled one of its requests: abort the matching handler.
    this.setNotificationHandler(CancelledNotificationSchema, (notification) => {
      const controller = this._requestHandlerAbortControllers.get(
        notification.params.requestId
      );
      controller?.abort(notification.params.reason);
    });

    this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
      this._onprogress(notification as unknown as ProgressNotification);
    });

    // Automatic pong: answer pings with an empty result.
    this.setRequestHandler(
      PingRequestSchema,
      () => ({} as SendResultT)
    );
  }

  /** Arms the timeout for an outgoing request and records its bookkeeping. */
  private _setupTimeout(
    messageId: number,
    timeout: number,
    maxTotalTimeout: number | undefined,
    onTimeout: () => void,
    resetTimeoutOnProgress: boolean = false
  ) {
    this._timeoutInfo.set(messageId, {
      timeoutId: setTimeout(onTimeout, timeout),
      startTime: Date.now(),
      timeout,
      maxTotalTimeout,
      resetTimeoutOnProgress,
      onTimeout,
    });
  }

  /**
   * Restarts the timeout for `messageId`.
   *
   * @returns `false` if no timeout is registered for it, `true` otherwise.
   * @throws McpError (`ErrorCode.RequestTimeout`) if `maxTotalTimeout` has
   *   been reached. The timeout is fully disarmed in that case.
   */
  private _resetTimeout(messageId: number): boolean {
    const info = this._timeoutInfo.get(messageId);
    if (!info) return false;

    const totalElapsed = Date.now() - info.startTime;
    if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) {
      // Clear the armed timer as well as the entry. Deleting only the entry
      // left a live timer unreachable from _timeoutInfo, which would later
      // fire for an already-settled request.
      this._cleanupTimeout(messageId);
      throw new McpError(
        ErrorCode.RequestTimeout,
        "Maximum total timeout exceeded",
        { maxTotalTimeout: info.maxTotalTimeout, totalElapsed }
      );
    }

    clearTimeout(info.timeoutId);
    info.timeoutId = setTimeout(info.onTimeout, info.timeout);
    return true;
  }

  /** Clears and forgets the timeout for `messageId`, if any. */
  private _cleanupTimeout(messageId: number) {
    const info = this._timeoutInfo.get(messageId);
    if (info) {
      clearTimeout(info.timeoutId);
      this._timeoutInfo.delete(messageId);
    }
  }

  /** Drops all bookkeeping (response/progress handlers, timeout) for an outgoing request. */
  private _forgetRequest(messageId: number): void {
    this._responseHandlers.delete(messageId);
    this._progressHandlers.delete(messageId);
    this._cleanupTimeout(messageId);
  }

  /**
   * Best-effort `notifications/cancelled` for one of our outgoing requests.
   * Send failures are reported through `onerror` and never thrown.
   * Nothing is sent when not connected.
   */
  private _sendCancellation(
    messageId: number,
    reason: unknown,
    options?: TransportSendOptions
  ): void {
    this._transport
      ?.send(
        {
          jsonrpc: "2.0",
          method: "notifications/cancelled",
          params: {
            requestId: messageId,
            reason: String(reason),
          },
        },
        options
      )
      .catch((error) =>
        this._onerror(new Error(`Failed to send cancellation: ${error}`))
      );
  }

  /**
   * Attaches to the given transport and starts listening for messages.
   *
   * This overwrites the transport's `onclose`, `onerror` and `onmessage`
   * callbacks. It resolves once `transport.start()` resolves.
   */
  async connect(transport: Transport): Promise<void> {
    this._transport = transport;
    this._transport.onclose = () => {
      this._onclose();
    };

    this._transport.onerror = (error: Error) => {
      this._onerror(error);
    };

    this._transport.onmessage = (message) => {
      if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
        this._onresponse(message);
      } else if (isJSONRPCRequest(message)) {
        this._onrequest(message);
      } else if (isJSONRPCNotification(message)) {
        this._onnotification(message);
      } else {
        this._onerror(
          new Error(`Unknown message type: ${JSON.stringify(message)}`)
        );
      }
    };

    await this._transport.start();
  }

  private _onclose(): void {
    const responseHandlers = this._responseHandlers;
    this._responseHandlers = new Map();
    this._progressHandlers.clear();

    // Every pending request is rejected below, so disarm their timers. Left
    // armed, they keep the process alive and later call cancel(), which could
    // send a stray cancellation over a transport connected in the meantime.
    for (const info of this._timeoutInfo.values()) {
      clearTimeout(info.timeoutId);
    }
    this._timeoutInfo.clear();

    this._transport = undefined;
    this.onclose?.();

    const error = new McpError(ErrorCode.ConnectionClosed, "Connection closed");
    for (const handler of responseHandlers.values()) {
      handler(error);
    }
  }

  private _onerror(error: Error): void {
    this.onerror?.(error);
  }

  private _onnotification(notification: JSONRPCNotification): void {
    const handler =
      this._notificationHandlers.get(notification.method) ??
      this.fallbackNotificationHandler;

    // Ignore notifications no one is listening for.
    if (handler === undefined) {
      return;
    }

    // Run the handler asynchronously; failures are reported, not thrown.
    Promise.resolve()
      .then(() => handler(notification))
      .catch((error) =>
        this._onerror(
          new Error(`Uncaught error in notification handler: ${error}`)
        )
      );
  }

  private _onrequest(request: JSONRPCRequest): void {
    const handler =
      this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler;

    if (handler === undefined) {
      this._transport
        ?.send({
          jsonrpc: "2.0",
          id: request.id,
          error: {
            code: ErrorCode.MethodNotFound,
            message: "Method not found",
          },
        })
        .catch((error) =>
          this._onerror(new Error(`Failed to send an error response: ${error}`))
        );
      return;
    }

    const abortController = new AbortController();
    this._requestHandlerAbortControllers.set(request.id, abortController);

    const fullExtra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
      signal: abortController.signal,
      sessionId: this._transport?.sessionId,
      _meta: request.params?._meta,
      sendNotification: (notification) =>
        this.notification(notification, { relatedRequestId: request.id }),
      sendRequest: (r, resultSchema, options?) =>
        this.request(r, resultSchema, {
          ...options,
          relatedRequestId: request.id,
        }),
      requestId: request.id,
    };

    // Run the handler asynchronously. A cancelled request gets no response
    // at all, neither result nor error.
    Promise.resolve()
      .then(() => handler(request, fullExtra))
      .then(
        (result) => {
          if (abortController.signal.aborted) {
            return;
          }

          return this._transport?.send({
            result,
            jsonrpc: "2.0",
            id: request.id,
          });
        },
        (error) => {
          if (abortController.signal.aborted) {
            return;
          }

          // `error` may be any thrown value, including null/undefined, so
          // guard property reads. An unguarded read would throw and the
          // request would never be answered.
          return this._transport?.send({
            jsonrpc: "2.0",
            id: request.id,
            error: {
              code: Number.isSafeInteger(error?.["code"])
                ? error["code"]
                : ErrorCode.InternalError,
              message: error?.message ?? "Internal error",
            },
          });
        }
      )
      .catch((error) =>
        this._onerror(new Error(`Failed to send response: ${error}`))
      )
      .finally(() => {
        this._requestHandlerAbortControllers.delete(request.id);
      });
  }

  private _onprogress(notification: ProgressNotification): void {
    const { progressToken, ...params } = notification.params;
    const messageId = Number(progressToken);

    const handler = this._progressHandlers.get(messageId);
    if (!handler) {
      this._onerror(
        new Error(
          `Received a progress notification for an unknown token: ${JSON.stringify(
            notification
          )}`
        )
      );
      return;
    }

    const responseHandler = this._responseHandlers.get(messageId);
    const timeoutInfo = this._timeoutInfo.get(messageId);

    if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) {
      try {
        this._resetTimeout(messageId);
      } catch (error) {
        // maxTotalTimeout exceeded. Stop tracking the request, tell the peer
        // to stop working on it, and fail the pending promise. The
        // cancellation is sent without the request's original transport
        // options, which are not retained here.
        this._forgetRequest(messageId);
        this._sendCancellation(messageId, error);
        responseHandler(error as Error);
        return;
      }
    }

    handler(params);
  }

  private _onresponse(response: JSONRPCResponse | JSONRPCError): void {
    const messageId = Number(response.id);
    const handler = this._responseHandlers.get(messageId);
    if (handler === undefined) {
      this._onerror(
        new Error(
          `Received a response for an unknown message ID: ${JSON.stringify(
            response
          )}`
        )
      );
      return;
    }

    this._forgetRequest(messageId);

    if (isJSONRPCResponse(response)) {
      handler(response);
    } else {
      const error = new McpError(
        response.error.code,
        response.error.message,
        response.error.data
      );
      handler(error);
    }
  }

  /** The currently attached transport, or `undefined` when not connected. */
  get transport(): Transport | undefined {
    return this._transport;
  }

  /** Closes the attached transport, if any. */
  async close(): Promise<void> {
    await this._transport?.close();
  }

  /**
   * Asserts that an outgoing request for `method` is allowed.
   * It is only called from `request()` when
   * `ProtocolOptions.enforceStrictCapabilities` is `true`. Throwing rejects
   * the request.
   */
  protected abstract assertCapabilityForMethod(
    method: SendRequestT["method"]
  ): void;

  /**
   * Asserts that an outgoing notification for `method` is allowed.
   * It is called for every `notification()`. Throwing rejects it.
   */
  protected abstract assertNotificationCapability(
    method: SendNotificationT["method"]
  ): void;

  /**
   * Asserts that a request handler for `method` may be registered.
   * It is called from `setRequestHandler()`, including for the built-in ping
   * handler registered by the constructor.
   */
  protected abstract assertRequestHandlerCapability(method: string): void;

  /**
   * Sends a request and waits for the response, validated by `resultSchema`.
   *
   * The returned promise rejects when:
   * - the protocol is not connected;
   * - a strict capability check fails;
   * - the request is aborted via `options.signal`;
   * - it times out;
   * - the peer answers with an error;
   * - the result fails `resultSchema`;
   * - the transport fails to send it.
   *
   * On abort or timeout, a `notifications/cancelled` is sent to the peer.
   */
  request<T extends ZodType<object>>(
    request: SendRequestT,
    resultSchema: T,
    options?: RequestOptions
  ): Promise<z.infer<T>> {
    const { relatedRequestId, resumptionToken, onresumptiontoken } =
      options ?? {};
    const transportOptions: TransportSendOptions = {
      relatedRequestId,
      resumptionToken,
      onresumptiontoken,
    };
    const signal = options?.signal;

    return new Promise((resolve, reject) => {
      if (!this._transport) {
        reject(new Error("Not connected"));
        return;
      }

      if (this._options?.enforceStrictCapabilities === true) {
        this.assertCapabilityForMethod(request.method);
      }

      signal?.throwIfAborted();

      const messageId = this._requestMessageId++;
      const jsonrpcRequest: JSONRPCRequest = {
        ...request,
        jsonrpc: "2.0",
        id: messageId,
      };

      if (options?.onprogress) {
        this._progressHandlers.set(messageId, options.onprogress);
        jsonrpcRequest.params = {
          ...request.params,
          // Merge rather than replace, so caller-supplied _meta fields survive.
          _meta: {
            ...(request.params?._meta || {}),
            progressToken: messageId,
          },
        };
      }

      // Only invoked once the signal fires, after `cancel` is initialized.
      const onAbort = () => cancel(signal?.reason);

      const cancel = (reason: unknown) => {
        this._forgetRequest(messageId);
        signal?.removeEventListener("abort", onAbort);
        this._sendCancellation(messageId, reason, transportOptions);
        reject(reason);
      };

      this._responseHandlers.set(messageId, (response) => {
        // The request is settled. A later abort of a (possibly shared) signal
        // must not cancel it, and the listener must not accumulate.
        signal?.removeEventListener("abort", onAbort);

        if (signal?.aborted) {
          return;
        }

        if (response instanceof Error) {
          return reject(response);
        }

        try {
          const result = resultSchema.parse(response.result);
          resolve(result);
        } catch (error) {
          reject(error);
        }
      });

      signal?.addEventListener("abort", onAbort, { once: true });

      const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC;
      const timeoutHandler = () =>
        cancel(
          new McpError(ErrorCode.RequestTimeout, "Request timed out", {
            timeout,
          })
        );

      this._setupTimeout(
        messageId,
        timeout,
        options?.maxTotalTimeout,
        timeoutHandler,
        options?.resetTimeoutOnProgress ?? false
      );

      this._transport.send(jsonrpcRequest, transportOptions).catch((error) => {
        // The request never went out: drop all of its bookkeeping so the
        // response/progress handlers do not linger forever.
        this._forgetRequest(messageId);
        signal?.removeEventListener("abort", onAbort);
        reject(error);
      });
    });
  }

  /**
   * Emits a notification, which is a one-way message that expects no response.
   *
   * @throws Error("Not connected") when no transport is attached.
   * It also throws whatever `assertNotificationCapability` throws.
   */
  async notification(
    notification: SendNotificationT,
    options?: NotificationOptions
  ): Promise<void> {
    if (!this._transport) {
      throw new Error("Not connected");
    }

    this.assertNotificationCapability(notification.method);

    const jsonrpcNotification: JSONRPCNotification = {
      ...notification,
      jsonrpc: "2.0",
    };

    await this._transport.send(jsonrpcNotification, options);
  }

  /**
   * Registers a handler for requests whose method matches the schema's
   * literal `method`. Any existing handler for that method is replaced.
   * The incoming request is parsed with `requestSchema` before the handler
   * runs.
   */
  setRequestHandler<
    T extends ZodObject<{
      method: ZodLiteral<string>;
    }>
  >(
    requestSchema: T,
    handler: (
      request: z.infer<T>,
      extra: RequestHandlerExtra<SendRequestT, SendNotificationT>
    ) => SendResultT | Promise<SendResultT>
  ): void {
    const method = requestSchema.shape.method.value;
    this.assertRequestHandlerCapability(method);

    this._requestHandlers.set(method, (request, extra) => {
      return Promise.resolve(handler(requestSchema.parse(request), extra));
    });
  }

  /** Removes the request handler for `method`, if any. */
  removeRequestHandler(method: string): void {
    this._requestHandlers.delete(method);
  }

  /**
   * Throws if a request handler for `method` is already registered.
   * Use this to avoid silently replacing one via `setRequestHandler()`.
   */
  assertCanSetRequestHandler(method: string): void {
    if (this._requestHandlers.has(method)) {
      throw new Error(
        `A request handler for ${method} already exists, which would be overridden`
      );
    }
  }

  /**
   * Registers a handler for notifications whose method matches the schema's
   * literal `method`. Any existing handler for that method is replaced.
   * The incoming notification is parsed with `notificationSchema` first.
   */
  setNotificationHandler<
    T extends ZodObject<{
      method: ZodLiteral<string>;
    }>
  >(
    notificationSchema: T,
    handler: (notification: z.infer<T>) => void | Promise<void>
  ): void {
    this._notificationHandlers.set(
      notificationSchema.shape.method.value,
      (notification) =>
        Promise.resolve(handler(notificationSchema.parse(notification)))
    );
  }

  /** Removes the notification handler for `method`, if any. */
  removeNotificationHandler(method: string): void {
    this._notificationHandlers.delete(method);
  }
}
|
|
|
|
|
/**
 * Merges `additional` capabilities into a copy of `base`, one level deep.
 *
 * For each key of `additional`:
 * - a non-null object value is spread over the existing value at that key
 *   when the existing value is truthy; otherwise it is stored as-is;
 * - any other value (primitive, `null`, `undefined`) replaces the base value.
 *
 * `base` is not mutated. The result may share nested object references with
 * either input.
 */
export function mergeCapabilities<
  T extends ServerCapabilities | ClientCapabilities
>(base: T, additional: T): T {
  const merged = { ...base };

  for (const [key, value] of Object.entries(additional)) {
    if (!value || typeof value !== "object") {
      // Not mergeable: the additional value simply wins.
      merged[key] = value;
      continue;
    }

    merged[key] = merged[key] ? { ...merged[key], ...value } : value;
  }

  return merged;
}
|
|
|