| | import { useCallback, useEffect, useRef } from 'react'; |
| | import { useAgentStore, type PlanItem } from '@/store/agentStore'; |
| | import { useSessionStore } from '@/store/sessionStore'; |
| | import { useLayoutStore } from '@/store/layoutStore'; |
| | import { getWebSocketUrl } from '@/utils/api'; |
| | import { logger } from '@/utils/logger'; |
| | import type { AgentEvent } from '@/types/events'; |
| | import type { Message, TraceLog } from '@/types/agent'; |
| |
|
/** Delay (ms) before the first reconnect attempt; doubled on each retry up to WS_MAX_RECONNECT_DELAY. */
const WS_RECONNECT_DELAY = 1000;

/** Upper bound (ms) for the exponentially growing reconnect delay. */
const WS_MAX_RECONNECT_DELAY = 30000;

/** Number of consecutive failed reconnect attempts tolerated before the session is reported dead. */
const WS_MAX_RETRIES = 5;
| |
|
/** Options accepted by the agent WebSocket hook. */
interface UseAgentWebSocketOptions {
  /** Session to connect to; `null` means no connection should be held. */
  sessionId: string | null;
  /** Invoked when the backend sends a `ready` event. */
  onReady?: () => void;
  /** Invoked with the error text when the backend sends an `error` event. */
  onError?: (error: string) => void;
  /** Invoked when the session is gone (close code 4004) or reconnect attempts are exhausted. */
  onSessionDead?: (sessionId: string) => void;
}
| |
|
/** Segment list carried by an assistant message (text and tool-trace entries). */
type TurnSegments = NonNullable<Message['segments']>;

/**
 * React hook that keeps a WebSocket open for `sessionId` and translates the
 * backend's agent events into store updates (messages, trace logs, side-panel
 * tabs, plan, processing/connected flags).
 *
 * Lifecycle:
 * - Connects 100 ms after `sessionId` becomes non-null; disconnects when it
 *   changes, becomes null, or the component unmounts.
 * - Unexpected closes are retried with exponential backoff
 *   (WS_RECONNECT_DELAY doubling up to WS_MAX_RECONNECT_DELAY), at most
 *   WS_MAX_RETRIES times, after which `onSessionDead` is called.
 * - Close codes 1000, 4001, 4003 and 4004 are never retried; 4004 additionally
 *   triggers `onSessionDead`.
 * - A `{ type: 'ping' }` frame is sent every 30 s while the socket is open.
 *
 * @param options.sessionId Session to connect to, or `null` for no connection.
 * @param options.onReady Called on the backend's `ready` event.
 * @param options.onError Called with the message of an `error` event.
 * @param options.onSessionDead Called when the session is gone or retries are exhausted.
 * @returns `isConnected` (snapshot taken at render time), plus manual `connect` / `disconnect`.
 */
export function useAgentWebSocket({
  sessionId,
  onReady,
  onError,
  onSessionDead,
}: UseAgentWebSocketOptions) {
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<number | null>(null);
  const reconnectDelayRef = useRef(WS_RECONNECT_DELAY);
  const retriesRef = useRef(0);

  const {
    addMessage,
    updateMessage,
    appendToMessage,
    setProcessing,
    setConnected,
    setError,
    addTraceLog,
    updateTraceLog,
    clearTraceLogs,
    setPanelContent,
    setPanelTab,
    setActivePanelTab,
    clearPanelTabs,
    setPlan,
    setCurrentTurnMessageId,
    updateCurrentTurnTrace,
    removeLastTurn,
  } = useAgentStore();

  const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore();

  const { setSessionActive } = useSessionStore();

  const handleEvent = useCallback(
    (event: AgentEvent) => {
      if (!sessionId) return;
      // Captured as a plain `string` so the nested helpers keep the narrowed type.
      const sid: string = sessionId;

      /** Adds a new assistant message and marks it as the current turn's message. */
      const startTurnMessage = (content: string, segments: TurnSegments): void => {
        const messageId = `msg_${Date.now()}`;
        const message: Message = {
          id: messageId,
          role: 'assistant',
          content,
          timestamp: new Date().toISOString(),
          segments,
        };
        addMessage(sid, message);
        setCurrentTurnMessageId(messageId);
      };

      /** Moves any pending trace logs into a single 'tools' segment and clears them. */
      const drainTraceSegments = (): TurnSegments => {
        const pending = useAgentStore.getState().traceLogs;
        if (pending.length === 0) return [];
        const snapshot = [...pending];
        clearTraceLogs();
        return [{ type: 'tools', tools: snapshot }];
      };

      /**
       * Makes the current trace visible in the chat: refreshes the turn message
       * when one exists, otherwise starts one holding only the tools segment.
       */
      const syncTraceWithTurn = (): void => {
        if (useAgentStore.getState().currentTurnMessageId) {
          updateCurrentTurnTrace(sid);
          return;
        }
        const pending = useAgentStore.getState().traceLogs;
        startTurnMessage('', [{ type: 'tools', tools: [...pending] }]);
        clearTraceLogs();
      };

      switch (event.event_type) {
        case 'ready':
          setConnected(true);
          setProcessing(false);
          setSessionActive(sid, true);
          onReady?.();
          break;

        case 'processing':
          // A new turn starts: drop the previous trace and forget the old turn message.
          setProcessing(true);
          clearTraceLogs();
          setCurrentTurnMessageId(null);
          break;

        case 'assistant_chunk': {
          const delta = (event.data?.content as string) || '';
          if (!delta) break;

          const turnMessageId = useAgentStore.getState().currentTurnMessageId;
          if (turnMessageId) {
            // Streaming continues into the message already open for this turn.
            appendToMessage(sid, turnMessageId, delta);
          } else {
            // First chunk of the turn: open a message, preceded by any pending tool trace.
            startTurnMessage(delta, [...drainTraceSegments(), { type: 'text', content: delta }]);
          }
          break;
        }

        case 'assistant_stream_end':
          // Nothing to do: the streamed message stays open until the turn completes.
          break;

        case 'assistant_message': {
          const content = (event.data?.content as string) || '';
          const turnMessageId = useAgentStore.getState().currentTurnMessageId;
          const existing = turnMessageId
            ? useAgentStore.getState().getMessages(sid).find((m) => m.id === turnMessageId)
            : undefined;

          // Pending tool trace first, then the new text (if any).
          const added: TurnSegments = drainTraceSegments();
          if (content) {
            added.push({ type: 'text', content });
          }

          if (turnMessageId && existing) {
            updateMessage(sid, turnMessageId, {
              // Join only non-empty parts so an empty side does not leave stray blank lines.
              content: [existing.content, content].filter(Boolean).join('\n\n'),
              segments: [...(existing.segments ?? []), ...added],
            });
          } else {
            // No turn message yet, or the referenced one is not in this session:
            // start a new one rather than dropping the assistant's text.
            startTurnMessage(content, added);
          }
          break;
        }

        case 'tool_call': {
          const toolName = (event.data?.tool as string) || 'unknown';
          const toolCallId = (event.data?.tool_call_id as string) || '';
          const args = (event.data?.arguments as Record<string, string | undefined>) || {};

          // plan_tool calls are not added to the trace.
          if (toolName !== 'plan_tool') {
            const log: TraceLog = {
              id: `tool_${Date.now()}_${toolCallId}`,
              toolCallId,
              type: 'call',
              text: `Agent is executing ${toolName}...`,
              tool: toolName,
              timestamp: new Date().toISOString(),
              completed: false,
              args,
            };
            addTraceLog(log);
            syncTraceWithTurn();
          }

          // Surface scripts and uploaded file contents in the right-hand panel.
          if (
            toolName === 'hf_jobs' &&
            (args.operation === 'run' || args.operation === 'scheduled run') &&
            args.script
          ) {
            clearPanelTabs();
            setPanelTab({
              id: 'script',
              title: 'Script',
              content: args.script,
              language: 'python',
              parameters: args,
            });
            setActivePanelTab('script');
            setRightPanelOpen(true);
            setLeftSidebarOpen(false);
          } else if (toolName === 'hf_repo_files' && args.operation === 'upload' && args.content) {
            setPanelContent({
              title: `File Upload: ${args.path || 'unnamed'}`,
              content: args.content,
              parameters: args,
              language: args.path?.endsWith('.py') ? 'python' : undefined,
            });
            setRightPanelOpen(true);
            setLeftSidebarOpen(false);
          }

          logger.log('Tool call:', toolName, args);
          break;
        }

        case 'tool_output': {
          const toolName = (event.data?.tool as string) || 'unknown';
          const toolCallId = (event.data?.tool_call_id as string) || '';
          const output = (event.data?.output as string) || '';
          const success = event.data?.success as boolean;

          // A tool that was awaiting approval and now produced output was approved.
          const prevLog = useAgentStore
            .getState()
            .traceLogs.find((l) => l.toolCallId === toolCallId);
          const wasApproval = prevLog?.approvalStatus === 'pending';
          updateTraceLog(toolCallId, toolName, {
            completed: true,
            output,
            success,
            ...(wasApproval ? { approvalStatus: 'approved' as const } : {}),
          });
          updateCurrentTurnTrace(sid);

          if (toolName === 'hf_jobs' && output) {
            // Pull job URL, final status and logs out of the markdown-formatted output.
            const updates: Partial<TraceLog> = { approvalStatus: 'approved' as const };

            const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/);
            if (urlMatch) updates.jobUrl = urlMatch[1];

            const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/);
            if (statusMatch) updates.jobStatus = statusMatch[1].trim();

            if (output.includes('**Logs:**')) {
              const parts = output.split('**Logs:**');
              if (parts.length > 1) {
                const codeBlockMatch = parts[1].trim().match(/```([\s\S]*?)```/);
                if (codeBlockMatch) updates.jobLogs = codeBlockMatch[1].trim();
              }
            }

            updateTraceLog(toolCallId, toolName, updates);
            updateCurrentTurnTrace(sid);

            setPanelTab({
              id: 'output',
              title: 'Output',
              content: output,
              language: 'markdown',
            });
            // Only steal focus for failures.
            if (!success) {
              setActivePanelTab('output');
            }
          }

          logger.log('Tool output:', toolName, success);
          break;
        }

        case 'tool_log': {
          const toolName = (event.data?.tool as string) || 'unknown';
          const log = (event.data?.log as string) || '';

          if (toolName === 'hf_jobs') {
            const logsTab = useAgentStore.getState().panelTabs.find((t) => t.id === 'logs');

            // Append to the existing logs tab, or start one with a header line.
            const newContent = logsTab
              ? logsTab.content + '\n' + log
              : '--- Job execution started ---\n' + log;

            setPanelTab({
              id: 'logs',
              title: 'Logs',
              content: newContent,
              language: 'text',
            });
            setActivePanelTab('logs');

            if (!useLayoutStore.getState().isRightPanelOpen) {
              setRightPanelOpen(true);
            }
          }
          break;
        }

        case 'plan_update': {
          const plan = (event.data?.plan as PlanItem[]) || [];
          setPlan(plan);
          if (!useLayoutStore.getState().isRightPanelOpen) {
            setRightPanelOpen(true);
          }
          break;
        }

        case 'approval_required': {
          const tools = event.data?.tools as
            | Array<{
                tool: string;
                arguments: Record<string, unknown>;
                tool_call_id: string;
              }>
            | undefined;

          if (tools) {
            for (const t of tools) {
              // Reuse the trace entry created by an earlier tool_call, if there is one.
              const existing = useAgentStore
                .getState()
                .traceLogs.find((entry) => entry.toolCallId === t.tool_call_id);
              if (!existing) {
                addTraceLog({
                  id: `tool_${Date.now()}_${t.tool_call_id}`,
                  toolCallId: t.tool_call_id,
                  type: 'call',
                  text: `Approval required for ${t.tool}`,
                  tool: t.tool,
                  timestamp: new Date().toISOString(),
                  completed: false,
                  args: t.arguments,
                  approvalStatus: 'pending',
                });
              } else {
                updateTraceLog(t.tool_call_id, t.tool, {
                  approvalStatus: 'pending',
                  args: t.arguments,
                });
              }
            }

            syncTraceWithTurn();
          }

          // Show the first tool awaiting approval in the right-hand panel.
          const firstTool = tools && tools.length > 0 ? tools[0] : undefined;
          if (firstTool) {
            const args = firstTool.arguments as Record<string, string | undefined>;

            clearPanelTabs();

            if (firstTool.tool === 'hf_jobs' && args.script) {
              setPanelTab({
                id: 'script',
                title: 'Script',
                content: args.script,
                language: 'python',
                parameters: args,
              });
              setActivePanelTab('script');
            } else if (firstTool.tool === 'hf_repo_files' && args.content) {
              const filename = args.path || 'file';
              const isPython = filename.endsWith('.py');
              setPanelTab({
                id: 'content',
                title: filename.split('/').pop() || 'Content',
                content: args.content,
                language: isPython ? 'python' : 'text',
                parameters: args,
              });
              setActivePanelTab('content');
            } else {
              setPanelTab({
                id: 'args',
                title: firstTool.tool,
                content: JSON.stringify(args, null, 2),
                language: 'json',
                parameters: args,
              });
              setActivePanelTab('args');
            }

            setRightPanelOpen(true);
            setLeftSidebarOpen(false);
          }

          // The agent is now waiting on the user, not working.
          setProcessing(false);
          break;
        }

        case 'turn_complete':
          setProcessing(false);
          setCurrentTurnMessageId(null);
          break;

        case 'compacted': {
          const oldTokens = event.data?.old_tokens as number;
          const newTokens = event.data?.new_tokens as number;
          logger.log(`Context compacted: ${oldTokens} -> ${newTokens} tokens`);
          break;
        }

        case 'error': {
          const errorMsg = (event.data?.error as string) || 'Unknown error';
          setError(errorMsg);
          setProcessing(false);
          onError?.(errorMsg);
          break;
        }

        case 'shutdown':
          setConnected(false);
          setProcessing(false);
          break;

        case 'interrupted':
          setProcessing(false);
          break;

        case 'undo_complete':
          removeLastTurn(sid);
          setProcessing(false);
          break;

        default:
          logger.log('Unknown event:', event);
      }
    },
    // NOTE(review): store actions are deliberately left out of the deps; this
    // presumes they are referentially stable across renders — confirm against the stores.
    [sessionId, onReady, onError, onSessionDead]
  );

  const connect = useCallback(() => {
    if (!sessionId) return;
    const sid: string = sessionId;

    // Never open a second socket while one is open or still connecting.
    const current = wsRef.current;
    if (
      current?.readyState === WebSocket.OPEN ||
      current?.readyState === WebSocket.CONNECTING
    ) {
      return;
    }

    const wsUrl = getWebSocketUrl(sid);

    logger.log('Connecting to WebSocket:', wsUrl);
    const ws = new WebSocket(wsUrl);
    // Register immediately so every handler below can tell whether it still
    // belongs to the hook's current socket.
    wsRef.current = ws;
    const isCurrent = (): boolean => wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      logger.log('WebSocket connected');
      setConnected(true);
      reconnectDelayRef.current = WS_RECONNECT_DELAY;
      retriesRef.current = 0;
    };

    ws.onmessage = (event) => {
      // Frames from a socket that disconnect() already replaced belong to a session
      // the UI has left; ignore them.
      if (!isCurrent()) return;
      try {
        const data = JSON.parse(event.data) as AgentEvent;
        handleEvent(data);
      } catch (e) {
        logger.error('Failed to parse WebSocket message:', e);
      }
    };

    ws.onerror = (error) => {
      logger.error('WebSocket error:', error);
    };

    ws.onclose = (event) => {
      logger.log('WebSocket closed', event.code, event.reason);

      // disconnect() clears wsRef before this event is delivered. A deliberate
      // close() without a code (or during CONNECTING) typically reports 1005/1006,
      // which is not in the no-retry list. Without this guard it would flip the
      // connected flag and schedule a reconnect to a session that was left on purpose.
      if (!isCurrent()) return;

      setConnected(false);

      const noRetryCodes = [1000, 4001, 4003, 4004];

      if (event.code === 4004) {
        logger.warn(`Session ${sid} no longer exists on backend, removing.`);
        onSessionDead?.(sid);
        return;
      }

      if (noRetryCodes.includes(event.code)) {
        if (event.code !== 1000) {
          logger.warn(`WebSocket permanently closed: ${event.code} ${event.reason}`);
        }
        return;
      }

      retriesRef.current += 1;
      if (retriesRef.current > WS_MAX_RETRIES) {
        logger.warn(`WebSocket: max retries (${WS_MAX_RETRIES}) reached, giving up.`);
        onSessionDead?.(sid);
        return;
      }

      if (reconnectTimeoutRef.current !== null) {
        clearTimeout(reconnectTimeoutRef.current);
      }
      reconnectTimeoutRef.current = window.setTimeout(() => {
        reconnectTimeoutRef.current = null;
        // Back off: the *next* attempt waits twice as long, capped at the maximum.
        reconnectDelayRef.current = Math.min(
          reconnectDelayRef.current * 2,
          WS_MAX_RECONNECT_DELAY
        );
        connect();
      }, reconnectDelayRef.current);
    };
  }, [sessionId, handleEvent]);

  const disconnect = useCallback(() => {
    if (reconnectTimeoutRef.current !== null) {
      clearTimeout(reconnectTimeoutRef.current);
      reconnectTimeoutRef.current = null;
    }
    if (wsRef.current) {
      wsRef.current.close();
      wsRef.current = null;
    }
    setConnected(false);
  }, []);

  const sendPing = useCallback(() => {
    if (wsRef.current?.readyState === WebSocket.OPEN) {
      wsRef.current.send(JSON.stringify({ type: 'ping' }));
    }
  }, []);

  // (Re)connect whenever the session changes; tear down on change/unmount.
  useEffect(() => {
    if (!sessionId) {
      disconnect();
      return;
    }

    // A new session starts with a clean retry budget and the base delay.
    retriesRef.current = 0;
    reconnectDelayRef.current = WS_RECONNECT_DELAY;

    // Short delay before connecting; the cleanup cancels it if the session
    // changes again first.
    const timeoutId = setTimeout(() => {
      connect();
    }, 100);

    return () => {
      clearTimeout(timeoutId);
      disconnect();
    };
    // NOTE(review): intentionally keyed on sessionId only; adding connect/disconnect
    // would reconnect whenever the callbacks change identity.
  }, [sessionId]);

  // Periodic ping while mounted; sendPing is a no-op unless the socket is open.
  useEffect(() => {
    const interval = setInterval(sendPing, 30000);
    return () => clearInterval(interval);
  }, [sendPing]);

  return {
    // Read from a ref during render: reflects the socket state at the last render
    // and does not by itself trigger a re-render when the socket opens or closes.
    isConnected: wsRef.current?.readyState === WebSocket.OPEN,
    connect,
    disconnect,
  };
}
| |
|