Spaces:
Build error
Build error
| import React from "react"; | |
| import { io, Socket } from "socket.io-client"; | |
| import { useQueryClient } from "@tanstack/react-query"; | |
| import EventLogger from "#/utils/event-logger"; | |
| import { handleAssistantMessage } from "#/services/actions"; | |
| import { showChatError, trackError } from "#/utils/error-handler"; | |
| import { useRate } from "#/hooks/use-rate"; | |
| import { OpenHandsParsedEvent } from "#/types/core"; | |
| import { | |
| AssistantMessageAction, | |
| CommandAction, | |
| FileEditAction, | |
| FileWriteAction, | |
| OpenHandsAction, | |
| UserMessageAction, | |
| } from "#/types/core/actions"; | |
| import { Conversation } from "#/api/open-hands.types"; | |
| import { useUserProviders } from "#/hooks/use-user-providers"; | |
| import { useActiveConversation } from "#/hooks/query/use-active-conversation"; | |
| import { OpenHandsObservation } from "#/types/core/observations"; | |
| import { | |
| isAgentStateChangeObservation, | |
| isErrorObservation, | |
| isOpenHandsAction, | |
| isOpenHandsObservation, | |
| isStatusUpdate, | |
| isUserMessage, | |
| } from "#/types/core/guards"; | |
| import { useOptimisticUserMessage } from "#/hooks/use-optimistic-user-message"; | |
| import { useWSErrorMessage } from "#/hooks/use-ws-error-message"; | |
/**
 * Type guard: true when `obj` is a non-null object that exposes a string
 * `message` property.
 *
 * @param obj - Arbitrary value, typically an error payload of unknown shape.
 * @returns Whether `obj.message` can safely be read as a string.
 */
const hasValidMessageProperty = (obj: unknown): obj is { message: string } => {
  // Reject primitives, undefined and null up front.
  if (typeof obj !== "object" || obj === null) {
    return false;
  }
  return "message" in obj && typeof obj.message === "string";
};
/**
 * Type guard: true when `event` is a non-null object carrying all of the
 * `id`, `source`, `message` and `timestamp` keys.
 *
 * Only key presence is checked; the values themselves are not validated.
 *
 * @param event - Raw payload received from the socket.
 */
const isOpenHandsEvent = (event: unknown): event is OpenHandsParsedEvent => {
  if (typeof event !== "object" || event === null) {
    return false;
  }
  const candidate: object = event;
  const requiredKeys = ["id", "source", "message", "timestamp"];
  return requiredKeys.every((key) => key in candidate);
};
/**
 * Type guard: true when `event` has an `action` field equal to `"write"`.
 *
 * @param event - Parsed event to classify.
 */
const isFileWriteAction = (
  event: OpenHandsParsedEvent,
): event is FileWriteAction => {
  if (!("action" in event)) {
    return false;
  }
  return event.action === "write";
};
/**
 * Type guard: true when `event` has an `action` field equal to `"edit"`.
 *
 * @param event - Parsed event to classify.
 */
const isFileEditAction = (
  event: OpenHandsParsedEvent,
): event is FileEditAction => {
  if (!("action" in event)) {
    return false;
  }
  return event.action === "edit";
};
/**
 * Type guard: true when `event` has an `action` field equal to `"run"`.
 *
 * @param event - Parsed event to classify.
 */
const isCommandAction = (
  event: OpenHandsParsedEvent,
): event is CommandAction => {
  if (!("action" in event)) {
    return false;
  }
  return event.action === "run";
};
/**
 * Type guard: true when `event` has `source === "agent"` and
 * `type === "message"`.
 *
 * @param event - Parsed event to classify.
 */
const isAssistantMessage = (
  event: OpenHandsParsedEvent,
): event is AssistantMessageAction => {
  // Both discriminating keys must be present before their values are compared.
  if (!("source" in event) || !("type" in event)) {
    return false;
  }
  return event.source === "agent" && event.type === "message";
};
/**
 * Type guard: true when `event` is accepted by either `isUserMessage` or
 * `isAssistantMessage`.
 *
 * @param event - Parsed event to classify.
 */
const isMessageAction = (
  event: OpenHandsParsedEvent,
): event is UserMessageAction | AssistantMessageAction => {
  if (isUserMessage(event)) {
    return true;
  }
  return isAssistantMessage(event);
};
/**
 * Connection state exposed by the WebSocket client provider.
 *
 * The numeric values are spelled out explicitly; they match the values the
 * members receive from implicit auto-numbering.
 */
export enum WsClientProviderStatus {
  CONNECTED = 0,
  DISCONNECTED = 1,
  CONNECTING = 2,
}
/** Shape of the value published through the WebSocket client context. */
interface UseWsClient {
  /** Current connection state of the socket. */
  status: WsClientProviderStatus;
  /** Loading flag derived from the provider's message-rate tracker. */
  isLoadingMessages: boolean;
  /** Every raw event recorded by the provider, in arrival order. */
  events: Array<Record<string, unknown>>;
  /** The subset of events recognised as actions or observations. */
  parsedEvents: Array<OpenHandsAction | OpenHandsObservation>;
  /** Sends a user action over the socket. */
  send: (event: Record<string, unknown>) => void;
}
/**
 * Context carrying the WebSocket client state.
 *
 * The default value describes a disconnected client whose `send` throws, so
 * sending without a provider in the tree fails loudly.
 */
const WsClientContext = React.createContext<UseWsClient>({
  status: WsClientProviderStatus.DISCONNECTED,
  isLoadingMessages: true,
  events: [],
  parsedEvents: [],
  send: (): void => {
    throw new Error("not connected");
  },
});
/** Props accepted by `WsClientProvider` (in addition to `children`). */
interface WsClientProviderProps {
  /** Identifier of the conversation the socket is opened for. */
  conversationId: string;
}
/** Loose description of the payload handed to socket error/disconnect handlers. */
interface ErrorArg {
  /** Human-readable error text, when present. */
  message?: string;
  /** Extra details; may carry a `msg_id` but can be of any shape. */
  data?: ErrorArgData | unknown;
}
/** Structured error details attached to an `ErrorArg`. */
interface ErrorArgData {
  /** Message identifier forwarded to the chat error display as `msgId`. */
  msg_id: string;
}
/**
 * Surfaces a socket error payload in the chat when it carries a string
 * `message`.
 *
 * Payloads without a string `message` are ignored, as are the transport-level
 * messages `"websocket error"` and `"timeout"`. When `data.data` is an
 * object it is forwarded as metadata, and a string `msg_id` inside it is
 * forwarded as `msgId`.
 *
 * @param data - Payload received from a socket error or disconnect event.
 */
export function updateStatusWhenErrorMessagePresent(data: ErrorArg | unknown) {
  const isRecord = (val: unknown): val is Record<string, unknown> =>
    !!val && typeof val === "object";
  const isString = (val: unknown): val is string => typeof val === "string";

  if (!isRecord(data)) {
    return;
  }
  const { message } = data;
  if (!isString(message)) {
    return;
  }
  // Transport-level noise is not worth showing to the user.
  const ignoredMessages = ["websocket error", "timeout"];
  if (ignoredMessages.includes(message)) {
    return;
  }

  const details = data.data;
  const hasDetails = isRecord(details);
  const metadata: Record<string, unknown> = hasDetails ? details : {};
  const msgId: string | undefined =
    hasDetails && isString(details.msg_id) ? details.msg_id : undefined;

  showChatError({
    message,
    source: "websocket",
    metadata,
    msgId,
  });
}
/**
 * Removes `prefix` from the start of `path`. The path is returned unchanged
 * when it does not begin with `prefix`.
 */
function stripPathPrefix(path: string, prefix: string): string {
  return path.startsWith(prefix) ? path.slice(prefix.length) : path;
}

/**
 * Opens a socket.io connection for `conversationId` and publishes its status,
 * the received events and a `send` function through `WsClientContext`.
 *
 * The socket is (re)created whenever the conversation id, URL or status
 * changes, and is skipped while the conversation is not loaded or is
 * `STOPPED`/`STARTING`. Recorded events are reset when the conversation id
 * changes.
 *
 * @param conversationId - Conversation to connect to; the connection effect
 *   throws when it is empty.
 * @param children - Subtree that consumes the context.
 */
export function WsClientProvider({
  conversationId,
  children,
}: React.PropsWithChildren<WsClientProviderProps>) {
  const { removeOptimisticUserMessage } = useOptimisticUserMessage();
  const { setErrorMessage, removeErrorMessage } = useWSErrorMessage();
  const queryClient = useQueryClient();
  const sioRef = React.useRef<Socket | null>(null);
  const [status, setStatus] = React.useState(
    WsClientProviderStatus.DISCONNECTED,
  );
  const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
  const [parsedEvents, setParsedEvents] = React.useState<
    (OpenHandsAction | OpenHandsObservation)[]
  >([]);
  // Last event with a numeric id; its id is sent as `latest_event_id` when
  // (re)connecting.
  const lastEventRef = React.useRef<Record<string, unknown> | null>(null);
  const { providers } = useUserProviders();
  const messageRateHandler = useRate({ threshold: 250 });
  const { data: conversation, refetch: refetchConversation } =
    useActiveConversation();

  /** Emits a user action on the current socket; logs an error when there is none. */
  function send(event: Record<string, unknown>) {
    if (!sioRef.current) {
      EventLogger.error("WebSocket is not connected.");
      return;
    }
    sioRef.current.emit("oh_user_action", event);
  }

  function handleConnect() {
    setStatus(WsClientProviderStatus.CONNECTED);
    removeErrorMessage();
  }

  /** Processes one `oh_event` payload: error tracking, state updates, cache invalidation. */
  function handleMessage(event: Record<string, unknown>) {
    handleAssistantMessage(event);

    if (isOpenHandsEvent(event)) {
      const isStatusUpdateError =
        isStatusUpdate(event) && event.type === "error";
      const isAgentStateChangeError =
        isAgentStateChangeObservation(event) &&
        event.extras.agent_state === "error";

      if (isStatusUpdateError || isAgentStateChangeError) {
        const errorMessage = isStatusUpdate(event)
          ? event.message
          : event.extras.reason || "Unknown error";
        trackError({
          message: errorMessage,
          source: "chat",
          metadata: { msgId: event.id },
        });
        setErrorMessage(errorMessage);
        // Error events are reported but not recorded in the event lists.
        return;
      }

      if (isOpenHandsAction(event) || isOpenHandsObservation(event)) {
        setParsedEvents((prevEvents) => [...prevEvents, event]);
      }

      if (isErrorObservation(event)) {
        trackError({
          message: event.message,
          source: "chat",
          metadata: { msgId: event.id },
        });
      } else {
        removeErrorMessage();
      }

      if (isUserMessage(event)) {
        removeOptimisticUserMessage();
      }

      if (isMessageAction(event)) {
        messageRateHandler.record(new Date().getTime());
      }

      // Invalidate the list of changed files when a file is edited or written,
      // or a command is run.
      if (
        isFileEditAction(event) ||
        isFileWriteAction(event) ||
        isCommandAction(event)
      ) {
        queryClient.invalidateQueries(
          {
            queryKey: ["file_changes", conversationId],
          },
          // With `cancelRefetch: false` an in-flight refetch is left running
          // instead of being restarted for every incoming event (e.g. while
          // an existing conversation is loading at a high message rate).
          { cancelRefetch: false },
        );

        // Invalidate the per-file diff cache when a file is edited or written.
        if (!isCommandAction(event)) {
          const cachedConversation = queryClient.getQueryData<Conversation>([
            "user",
            "conversation",
            conversationId,
          ]);
          const clonedRepositoryDirectory =
            cachedConversation?.selected_repository?.split("/").pop();

          // Strip only LEADING prefixes. Replacing the first occurrence
          // anywhere in the path would corrupt paths that merely contain the
          // repository name (repo "app" + ".../webapp/index.ts").
          let fileToInvalidate = stripPathPrefix(
            event.args.path,
            "/workspace/",
          );
          if (clonedRepositoryDirectory) {
            fileToInvalidate = stripPathPrefix(
              fileToInvalidate,
              `${clonedRepositoryDirectory}/`,
            );
          }

          queryClient.invalidateQueries({
            queryKey: ["file_diff", conversationId, fileToInvalidate],
          });
        }
      }
    }

    setEvents((prevEvents) => [...prevEvents, event]);
    // Only events whose id parses as an integer are used as the resume point.
    if (!Number.isNaN(parseInt(event.id as string, 10))) {
      lastEventRef.current = event;
    }
  }

  function handleDisconnect(data: unknown) {
    setStatus(WsClientProviderStatus.DISCONNECTED);
    const sio = sioRef.current;
    if (!sio) {
      return;
    }
    // Update the connection query with the last seen event id so that it is
    // available to any later connection attempt made with these options.
    sio.io.opts.query = sio.io.opts.query || {};
    sio.io.opts.query.latest_event_id = lastEventRef.current?.id;
    updateStatusWhenErrorMessagePresent(data);
    setErrorMessage(hasValidMessageProperty(data) ? data.message : "");
  }

  function handleError(data: unknown) {
    // set status
    setStatus(WsClientProviderStatus.DISCONNECTED);
    updateStatusWhenErrorMessagePresent(data);
    setErrorMessage(
      hasValidMessageProperty(data)
        ? data.message
        : "An unknown error occurred on the WebSocket connection.",
    );
    // check if something went wrong with the conversation.
    refetchConversation();
  }

  React.useEffect(() => {
    lastEventRef.current = null;
    // reset events when conversationId changes
    setEvents([]);
    setParsedEvents([]);
    setStatus(WsClientProviderStatus.DISCONNECTED);
  }, [conversationId]);

  React.useEffect(() => {
    if (!conversationId) {
      throw new Error("No conversation ID provided");
    }
    if (
      !conversation ||
      ["STOPPED", "STARTING"].includes(conversation.status)
    ) {
      return () => undefined; // conversation not yet loaded
    }

    // Always tear down the previous socket, not only when it is connected.
    // Its listeners were already removed by the previous cleanup, so a socket
    // that is still connecting or reconnecting would otherwise be leaked and
    // keep (re)connecting in the background.
    sioRef.current?.disconnect();

    const lastEvent = lastEventRef.current;
    const query = {
      latest_event_id: lastEvent?.id ?? -1,
      conversation_id: conversationId,
      providers_set: providers,
      session_api_key: conversation.session_api_key, // Have to set here because socketio doesn't support custom headers. :(
    };

    let baseUrl = null;
    if (conversation.url && !conversation.url.startsWith("/")) {
      baseUrl = new URL(conversation.url).host;
    } else {
      baseUrl = import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
    }

    const sio = io(baseUrl, {
      transports: ["websocket"],
      query,
    });

    sio.on("connect", handleConnect);
    sio.on("oh_event", handleMessage);
    sio.on("connect_error", handleError);
    sio.on("connect_failed", handleError);
    sio.on("disconnect", handleDisconnect);

    sioRef.current = sio;

    return () => {
      sio.off("connect", handleConnect);
      sio.off("oh_event", handleMessage);
      sio.off("connect_error", handleError);
      sio.off("connect_failed", handleError);
      sio.off("disconnect", handleDisconnect);
    };
  }, [conversationId, conversation?.url, conversation?.status]);

  // Disconnect the socket when the provider unmounts.
  React.useEffect(
    () => () => {
      const sio = sioRef.current;
      if (sio) {
        // Detach first so the disconnect handler does not set state during
        // teardown.
        sio.off("disconnect", handleDisconnect);
        sio.disconnect();
      }
    },
    [],
  );

  const value = React.useMemo<UseWsClient>(
    () => ({
      status,
      isLoadingMessages: messageRateHandler.isUnderThreshold,
      events,
      parsedEvents,
      send,
    }),
    [status, messageRateHandler.isUnderThreshold, events, parsedEvents],
  );

  return <WsClientContext value={value}>{children}</WsClientContext>;
}
/**
 * Hook returning the current WebSocket client context value.
 *
 * @returns The value supplied by the nearest provider, or the context's
 *   default value when there is none.
 */
export function useWsClient() {
  return React.useContext(WsClientContext);
}