File size: 6,505 Bytes
d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b 3e12ae4 d4abe4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
/**
* LangChain Callback Handler for WebSocket Streaming
* Streams ReAct flow steps to frontend in real-time
*/
import { BaseCallbackHandler } from '@langchain/core/callbacks/base';
import { AgentAction, AgentFinish } from '@langchain/core/agents';
import { wsConnectionManager } from '../services/websocket.service.js';
import { logger } from '../utils/logger.js';
import type {
ThoughtMessage,
ActionStartMessage,
ActionCompleteMessage,
ActionErrorMessage,
ObservationMessage,
FinalAnswerMessage,
} from '../types/websocket.js';
import { TOOL_DISPLAY_NAMES } from '../types/websocket.js';
/**
 * Callback handler that forwards agent events (thoughts, tool start/finish/
 * failure, final answer, chain errors) to one WebSocket session through
 * `wsConnectionManager`.
 *
 * Every hook is best-effort: failures while sending are logged and swallowed
 * so that a streaming problem never propagates out of the callback.
 *
 * NOTE(review): LangChain JS's `BaseCallbackHandler` dispatches to `handle*`
 * hooks (e.g. `handleAgentAction`, `handleToolEnd`). Confirm that the `on*`
 * methods below are wired up by the caller; otherwise they will never fire.
 */
export class WebSocketStreamHandler extends BaseCallbackHandler {
  name = 'WebSocketStreamHandler';

  /** Session that receives every message emitted by this handler. */
  private readonly sessionId: string;

  /** Start time (`Date.now()`) of each in-flight tool call, keyed by tool name. */
  private readonly actionStartTimes: Map<string, number> = new Map();

  /**
   * @param sessionId - Identifier of the WebSocket session to stream to.
   */
  constructor(sessionId: string) {
    super();
    this.sessionId = sessionId;
  }

  /**
   * Works out which tool a tool-end / tool-error event belongs to.
   *
   * Preference order:
   *  1. the first tag that is an own key of `TOOL_DISPLAY_NAMES`;
   *  2. the single pending action, when exactly one tool call is in flight
   *     (lets the completion be correlated with its `action_start` even when
   *     the run carries no identifying tag);
   *  3. the literal `'unknown'`.
   *
   * @param tags - Tags attached to the run, if any.
   * @returns The resolved tool name.
   */
  private resolveToolName(tags?: string[]): string {
    // Own-property check: a plain `in` test would also match inherited keys
    // such as "toString" or "constructor".
    const tagged = tags?.find((tag) =>
      Object.prototype.hasOwnProperty.call(TOOL_DISPLAY_NAMES, tag),
    );
    if (tagged !== undefined) {
      return tagged;
    }

    if (this.actionStartTimes.size === 1) {
      const [pending] = this.actionStartTimes.keys();
      if (pending !== undefined) {
        return pending;
      }
    }

    return 'unknown';
  }

  /**
   * Removes the recorded start time for `toolName` and returns the elapsed
   * time since then.
   *
   * @param toolName - Key used when the start time was recorded.
   * @returns Elapsed milliseconds, or 0 when no start time was recorded.
   */
  private takeElapsedMs(toolName: string): number {
    const startedAt = this.actionStartTimes.get(toolName);
    this.actionStartTimes.delete(toolName);
    return startedAt === undefined ? 0 : Date.now() - startedAt;
  }

  /**
   * Streams the agent's reasoning (when `action.log` is non-empty) followed by
   * an `action_start` message, and records the tool's start time.
   *
   * @param action - The action the agent decided to take.
   */
  async onAgentAction(action: AgentAction): Promise<void> {
    try {
      // Send thought message (agent's reasoning)
      if (action.log) {
        const thoughtMsg: ThoughtMessage = {
          type: 'thought',
          content: action.log,
          timestamp: new Date().toISOString(),
          variant: 'intermediate',
        };
        await wsConnectionManager.sendToSession(this.sessionId, thoughtMsg);
        logger.debug(`Sent thought for session ${this.sessionId}`);
      }

      // Send action start
      const toolName = action.tool;
      const displayName = TOOL_DISPLAY_NAMES[toolName] || toolName;
      const actionStartMsg: ActionStartMessage = {
        type: 'action_start',
        tool_name: toolName,
        tool_display_name: displayName,
        timestamp: new Date().toISOString(),
      };

      // Record start time for duration calculation
      this.actionStartTimes.set(toolName, Date.now());

      await wsConnectionManager.sendToSession(this.sessionId, actionStartMsg);
      logger.debug(`Sent action_start for tool ${toolName}, session ${this.sessionId}`);
    } catch (error) {
      logger.error({ error, sessionId: this.sessionId }, 'Error in onAgentAction');
    }
  }

  /**
   * Streams an `action_complete` message for a finished tool call and, when
   * the parsed output carries `predictions` or `confidence`, an additional
   * `observation` message.
   *
   * @param output - Raw tool output; parsed as JSON when possible, otherwise
   *   wrapped as `{ output }`.
   * @param _runId - Unused.
   * @param _parentRunId - Unused.
   * @param _tags - Run tags, used to identify the tool.
   */
  async onToolEnd(output: string, _runId: string, _parentRunId?: string, _tags?: string[]): Promise<void> {
    try {
      // Try to parse tool output as JSON (most tools return JSON).
      // The payload shape is tool-specific and the message types are declared
      // elsewhere, so it is intentionally left untyped here.
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      let results: any;
      try {
        results = JSON.parse(output);
      } catch {
        results = { output };
      }

      const toolName = this.resolveToolName(_tags);
      const duration = this.takeElapsedMs(toolName);

      // Send action complete
      const actionCompleteMsg: ActionCompleteMessage = {
        type: 'action_complete',
        tool_name: toolName,
        duration_ms: duration,
        results,
        timestamp: new Date().toISOString(),
      };
      await wsConnectionManager.sendToSession(this.sessionId, actionCompleteMsg);
      logger.debug(`Sent action_complete for tool ${toolName}, session ${this.sessionId}`);

      // Send observation if results contain predictions/confidence.
      // `JSON.parse` can legitimately yield `null` (output "null"), so guard
      // before dereferencing.
      if (results != null && (results.predictions || results.confidence)) {
        const observationMsg: ObservationMessage = {
          type: 'observation',
          tool_name: toolName,
          findings: results,
          // Top-level confidence wins; otherwise fall back to the first
          // prediction's confidence, then to 0.
          confidence: results.confidence || (results.predictions?.[0]?.confidence ?? 0),
          timestamp: new Date().toISOString(),
        };
        await wsConnectionManager.sendToSession(this.sessionId, observationMsg);
        logger.debug(`Sent observation for tool ${toolName}, session ${this.sessionId}`);
      }
    } catch (error) {
      logger.error({ error, sessionId: this.sessionId }, 'Error in onToolEnd');
    }
  }

  /**
   * Streams an `action_error` message for a failed tool call.
   *
   * @param error - The failure raised by the tool.
   * @param _runId - Unused.
   * @param _parentRunId - Unused.
   * @param _tags - Run tags, used to identify the tool.
   */
  async onToolError(error: Error, _runId: string, _parentRunId?: string, _tags?: string[]): Promise<void> {
    try {
      const toolName = this.resolveToolName(_tags);
      const duration = this.takeElapsedMs(toolName);

      const actionErrorMsg: ActionErrorMessage = {
        type: 'action_error',
        tool_name: toolName,
        error_code: 'TOOL_ERROR',
        // Optional chaining: thrown values are not guaranteed to be Error
        // instances at runtime.
        error_message: error?.message || 'Tool execution failed',
        duration_ms: duration,
        timestamp: new Date().toISOString(),
      };
      await wsConnectionManager.sendToSession(this.sessionId, actionErrorMsg);
      logger.debug(`Sent action_error for tool ${toolName}, session ${this.sessionId}`);
    } catch (err) {
      logger.error({ error: err, sessionId: this.sessionId }, 'Error in onToolError');
    }
  }

  /**
   * Streams the agent's closing thought (when `finish.log` is non-empty)
   * followed by the `final_answer` message.
   *
   * @param finish - The agent's final result.
   */
  async onAgentFinish(finish: AgentFinish): Promise<void> {
    try {
      // Send final thought
      if (finish.log) {
        const thoughtMsg: ThoughtMessage = {
          type: 'thought',
          content: finish.log,
          timestamp: new Date().toISOString(),
          variant: 'final',
        };
        await wsConnectionManager.sendToSession(this.sessionId, thoughtMsg);
      }

      // Send final answer
      const finalAnswerMsg: FinalAnswerMessage = {
        type: 'final_answer',
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- result type is declared in ../types/websocket
        result: finish.returnValues as any,
        timestamp: new Date().toISOString(),
      };
      await wsConnectionManager.sendToSession(this.sessionId, finalAnswerMsg);
      logger.debug(`Sent final_answer for session ${this.sessionId}`);
    } catch (error) {
      logger.error({ error, sessionId: this.sessionId }, 'Error in onAgentFinish');
    }
  }

  /**
   * LLM-start hook. Intentionally a no-op.
   */
  async onLLMStart(): Promise<void> {
    // Optional: can send a message indicating LLM is thinking
  }

  /**
   * LLM-end hook. Intentionally a no-op.
   */
  async onLLMEnd(): Promise<void> {
    // Optional: can send a message indicating LLM finished
  }

  /**
   * Reports a chain-level failure to the session as a `CHAIN_ERROR` and logs it.
   *
   * @param error - The error raised by the chain.
   */
  async onChainError(error: Error): Promise<void> {
    try {
      // Awaited so that, if `sendError` is asynchronous, a rejection is caught
      // below instead of becoming an unhandled rejection.
      await wsConnectionManager.sendError(this.sessionId, 'CHAIN_ERROR', error.message);
      logger.error({ error, sessionId: this.sessionId }, 'Chain error');
    } catch (err) {
      logger.error({ error: err, sessionId: this.sessionId }, 'Error in onChainError');
    }
  }
}
|