|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest'; |
|
|
import { renderHook, act, waitFor } from '@testing-library/react'; |
|
|
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js'; |
|
|
import { useInput } from 'ink'; |
|
|
import { |
|
|
useReactToolScheduler, |
|
|
TrackedToolCall, |
|
|
TrackedCompletedToolCall, |
|
|
TrackedExecutingToolCall, |
|
|
TrackedCancelledToolCall, |
|
|
} from './useReactToolScheduler.js'; |
|
|
import { Config, EditorType, AuthType } from '@google/gemini-cli-core'; |
|
|
import { Part, PartListUnion } from '@google/genai'; |
|
|
import { UseHistoryManagerReturn } from './useHistoryManager.js'; |
|
|
import { HistoryItem, MessageType, StreamingState } from '../types.js'; |
|
|
import { Dispatch, SetStateAction } from 'react'; |
|
|
import { LoadedSettings } from '../../config/settings.js'; |
|
|
|
|
|
|
|
|
const mockSendMessageStream = vi |
|
|
.fn() |
|
|
.mockReturnValue((async function* () {})()); |
|
|
const mockStartChat = vi.fn(); |
|
|
|
|
|
const MockedGeminiClientClass = vi.hoisted(() => |
|
|
vi.fn().mockImplementation(function (this: any, _config: any) { |
|
|
|
|
|
this.startChat = mockStartChat; |
|
|
this.sendMessageStream = mockSendMessageStream; |
|
|
this.addHistory = vi.fn(); |
|
|
}), |
|
|
); |
|
|
|
|
|
const MockedUserPromptEvent = vi.hoisted(() => |
|
|
vi.fn().mockImplementation(() => {}), |
|
|
); |
|
|
|
|
|
vi.mock('@google/gemini-cli-core', async (importOriginal) => { |
|
|
const actualCoreModule = (await importOriginal()) as any; |
|
|
return { |
|
|
...actualCoreModule, |
|
|
GitService: vi.fn(), |
|
|
GeminiClient: MockedGeminiClientClass, |
|
|
UserPromptEvent: MockedUserPromptEvent, |
|
|
}; |
|
|
}); |
|
|
|
|
|
const mockUseReactToolScheduler = useReactToolScheduler as Mock; |
|
|
vi.mock('./useReactToolScheduler.js', async (importOriginal) => { |
|
|
const actualSchedulerModule = (await importOriginal()) as any; |
|
|
return { |
|
|
...(actualSchedulerModule || {}), |
|
|
useReactToolScheduler: vi.fn(), |
|
|
}; |
|
|
}); |
|
|
|
|
|
vi.mock('ink', async (importOriginal) => { |
|
|
const actualInkModule = (await importOriginal()) as any; |
|
|
return { ...(actualInkModule || {}), useInput: vi.fn() }; |
|
|
}); |
|
|
|
|
|
vi.mock('./shellCommandProcessor.js', () => ({ |
|
|
useShellCommandProcessor: vi.fn().mockReturnValue({ |
|
|
handleShellCommand: vi.fn(), |
|
|
}), |
|
|
})); |
|
|
|
|
|
vi.mock('./atCommandProcessor.js', () => ({ |
|
|
handleAtCommand: vi |
|
|
.fn() |
|
|
.mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }), |
|
|
})); |
|
|
|
|
|
vi.mock('../utils/markdownUtilities.js', () => ({ |
|
|
findLastSafeSplitPoint: vi.fn((s: string) => s.length), |
|
|
})); |
|
|
|
|
|
vi.mock('./useStateAndRef.js', () => ({ |
|
|
useStateAndRef: vi.fn((initial) => { |
|
|
let val = initial; |
|
|
const ref = { current: val }; |
|
|
const setVal = vi.fn((updater) => { |
|
|
if (typeof updater === 'function') { |
|
|
val = updater(val); |
|
|
} else { |
|
|
val = updater; |
|
|
} |
|
|
ref.current = val; |
|
|
}); |
|
|
return [ref, setVal]; |
|
|
}), |
|
|
})); |
|
|
|
|
|
vi.mock('./useLogger.js', () => ({ |
|
|
useLogger: vi.fn().mockReturnValue({ |
|
|
logMessage: vi.fn().mockResolvedValue(undefined), |
|
|
}), |
|
|
})); |
|
|
|
|
|
const mockStartNewTurn = vi.fn(); |
|
|
const mockAddUsage = vi.fn(); |
|
|
vi.mock('../contexts/SessionContext.js', () => ({ |
|
|
useSessionStats: vi.fn(() => ({ |
|
|
startNewTurn: mockStartNewTurn, |
|
|
addUsage: mockAddUsage, |
|
|
})), |
|
|
})); |
|
|
|
|
|
vi.mock('./slashCommandProcessor.js', () => ({ |
|
|
handleSlashCommand: vi.fn().mockReturnValue(false), |
|
|
})); |
|
|
|
|
|
const mockParseAndFormatApiError = vi.hoisted(() => vi.fn()); |
|
|
vi.mock('../utils/errorParsing.js', () => ({ |
|
|
parseAndFormatApiError: mockParseAndFormatApiError, |
|
|
})); |
|
|
|
|
|
|
|
|
|
|
|
describe('mergePartListUnions', () => { |
|
|
it('should merge multiple PartListUnion arrays', () => { |
|
|
const list1: PartListUnion = [{ text: 'Hello' }]; |
|
|
const list2: PartListUnion = [ |
|
|
{ inlineData: { mimeType: 'image/png', data: 'abc' } }, |
|
|
]; |
|
|
const list3: PartListUnion = [{ text: 'World' }, { text: '!' }]; |
|
|
const result = mergePartListUnions([list1, list2, list3]); |
|
|
expect(result).toEqual([ |
|
|
{ text: 'Hello' }, |
|
|
{ inlineData: { mimeType: 'image/png', data: 'abc' } }, |
|
|
{ text: 'World' }, |
|
|
{ text: '!' }, |
|
|
]); |
|
|
}); |
|
|
|
|
|
it('should handle empty arrays in the input list', () => { |
|
|
const list1: PartListUnion = [{ text: 'First' }]; |
|
|
const list2: PartListUnion = []; |
|
|
const list3: PartListUnion = [{ text: 'Last' }]; |
|
|
const result = mergePartListUnions([list1, list2, list3]); |
|
|
expect(result).toEqual([{ text: 'First' }, { text: 'Last' }]); |
|
|
}); |
|
|
|
|
|
it('should handle a single PartListUnion array', () => { |
|
|
const list1: PartListUnion = [ |
|
|
{ text: 'One' }, |
|
|
{ inlineData: { mimeType: 'image/jpeg', data: 'xyz' } }, |
|
|
]; |
|
|
const result = mergePartListUnions([list1]); |
|
|
expect(result).toEqual(list1); |
|
|
}); |
|
|
|
|
|
it('should return an empty array if all input arrays are empty', () => { |
|
|
const list1: PartListUnion = []; |
|
|
const list2: PartListUnion = []; |
|
|
const result = mergePartListUnions([list1, list2]); |
|
|
expect(result).toEqual([]); |
|
|
}); |
|
|
|
|
|
it('should handle input list being empty', () => { |
|
|
const result = mergePartListUnions([]); |
|
|
expect(result).toEqual([]); |
|
|
}); |
|
|
|
|
|
it('should correctly merge when PartListUnion items are single Parts not in arrays', () => { |
|
|
const part1: Part = { text: 'Single part 1' }; |
|
|
const part2: Part = { inlineData: { mimeType: 'image/gif', data: 'gif' } }; |
|
|
const listContainingSingleParts: PartListUnion[] = [ |
|
|
part1, |
|
|
[part2], |
|
|
{ text: 'Another single part' }, |
|
|
]; |
|
|
const result = mergePartListUnions(listContainingSingleParts); |
|
|
expect(result).toEqual([ |
|
|
{ text: 'Single part 1' }, |
|
|
{ inlineData: { mimeType: 'image/gif', data: 'gif' } }, |
|
|
{ text: 'Another single part' }, |
|
|
]); |
|
|
}); |
|
|
|
|
|
it('should handle a mix of arrays and single parts, including empty arrays and undefined/null parts if they were possible (though PartListUnion typing restricts this)', () => { |
|
|
const list1: PartListUnion = [{ text: 'A' }]; |
|
|
const list2: PartListUnion = []; |
|
|
const part3: Part = { text: 'B' }; |
|
|
const list4: PartListUnion = [ |
|
|
{ text: 'C' }, |
|
|
{ inlineData: { mimeType: 'text/plain', data: 'D' } }, |
|
|
]; |
|
|
const result = mergePartListUnions([list1, list2, part3, list4]); |
|
|
expect(result).toEqual([ |
|
|
{ text: 'A' }, |
|
|
{ text: 'B' }, |
|
|
{ text: 'C' }, |
|
|
{ inlineData: { mimeType: 'text/plain', data: 'D' } }, |
|
|
]); |
|
|
}); |
|
|
|
|
|
it('should preserve the order of parts from the input arrays', () => { |
|
|
const listA: PartListUnion = [{ text: '1' }, { text: '2' }]; |
|
|
const listB: PartListUnion = [{ text: '3' }]; |
|
|
const listC: PartListUnion = [{ text: '4' }, { text: '5' }]; |
|
|
const result = mergePartListUnions([listA, listB, listC]); |
|
|
expect(result).toEqual([ |
|
|
{ text: '1' }, |
|
|
{ text: '2' }, |
|
|
{ text: '3' }, |
|
|
{ text: '4' }, |
|
|
{ text: '5' }, |
|
|
]); |
|
|
}); |
|
|
|
|
|
it('should handle cases where some PartListUnion items are single Parts and others are arrays of Parts', () => { |
|
|
const singlePart1: Part = { text: 'First single' }; |
|
|
const arrayPart1: Part[] = [ |
|
|
{ text: 'Array item 1' }, |
|
|
{ text: 'Array item 2' }, |
|
|
]; |
|
|
const singlePart2: Part = { |
|
|
inlineData: { mimeType: 'application/json', data: 'e30=' }, |
|
|
}; |
|
|
const arrayPart2: Part[] = [{ text: 'Last array item' }]; |
|
|
|
|
|
const result = mergePartListUnions([ |
|
|
singlePart1, |
|
|
arrayPart1, |
|
|
singlePart2, |
|
|
arrayPart2, |
|
|
]); |
|
|
expect(result).toEqual([ |
|
|
{ text: 'First single' }, |
|
|
{ text: 'Array item 1' }, |
|
|
{ text: 'Array item 2' }, |
|
|
{ inlineData: { mimeType: 'application/json', data: 'e30=' } }, |
|
|
{ text: 'Last array item' }, |
|
|
]); |
|
|
}); |
|
|
}); |
|
|
|
|
|
|
|
|
describe('useGeminiStream', () => { |
|
|
let mockAddItem: Mock; |
|
|
let mockSetShowHelp: Mock; |
|
|
let mockConfig: Config; |
|
|
let mockOnDebugMessage: Mock; |
|
|
let mockHandleSlashCommand: Mock; |
|
|
let mockScheduleToolCalls: Mock; |
|
|
let mockCancelAllToolCalls: Mock; |
|
|
let mockMarkToolsAsSubmitted: Mock; |
|
|
|
|
|
beforeEach(() => { |
|
|
vi.clearAllMocks(); |
|
|
|
|
|
mockAddItem = vi.fn(); |
|
|
mockSetShowHelp = vi.fn(); |
|
|
|
|
|
const mockGetGeminiClient = vi.fn().mockImplementation(() => { |
|
|
|
|
|
|
|
|
const clientInstance = new MockedGeminiClientClass(mockConfig); |
|
|
return clientInstance; |
|
|
}); |
|
|
|
|
|
mockConfig = { |
|
|
apiKey: 'test-api-key', |
|
|
model: 'gemini-pro', |
|
|
sandbox: false, |
|
|
targetDir: '/test/dir', |
|
|
debugMode: false, |
|
|
question: undefined, |
|
|
fullContext: false, |
|
|
coreTools: [], |
|
|
toolDiscoveryCommand: undefined, |
|
|
toolCallCommand: undefined, |
|
|
mcpServerCommand: undefined, |
|
|
mcpServers: undefined, |
|
|
userAgent: 'test-agent', |
|
|
userMemory: '', |
|
|
geminiMdFileCount: 0, |
|
|
alwaysSkipModificationConfirmation: false, |
|
|
vertexai: false, |
|
|
showMemoryUsage: false, |
|
|
contextFileName: undefined, |
|
|
getToolRegistry: vi.fn( |
|
|
() => ({ getToolSchemaList: vi.fn(() => []) }) as any, |
|
|
), |
|
|
getProjectRoot: vi.fn(() => '/test/dir'), |
|
|
getCheckpointingEnabled: vi.fn(() => false), |
|
|
getGeminiClient: mockGetGeminiClient, |
|
|
getUsageStatisticsEnabled: () => true, |
|
|
getDebugMode: () => false, |
|
|
addHistory: vi.fn(), |
|
|
} as unknown as Config; |
|
|
mockOnDebugMessage = vi.fn(); |
|
|
mockHandleSlashCommand = vi.fn().mockResolvedValue(false); |
|
|
|
|
|
|
|
|
mockScheduleToolCalls = vi.fn(); |
|
|
mockCancelAllToolCalls = vi.fn(); |
|
|
mockMarkToolsAsSubmitted = vi.fn(); |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[], |
|
|
mockScheduleToolCalls, |
|
|
mockCancelAllToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
|
|
|
|
|
|
mockStartChat.mockClear().mockResolvedValue({ |
|
|
sendMessageStream: mockSendMessageStream, |
|
|
} as unknown as any); |
|
|
mockSendMessageStream |
|
|
.mockClear() |
|
|
.mockReturnValue((async function* () {})()); |
|
|
}); |
|
|
|
|
|
const mockLoadedSettings: LoadedSettings = { |
|
|
merged: { preferredEditor: 'vscode' }, |
|
|
user: { path: '/user/settings.json', settings: {} }, |
|
|
workspace: { path: '/workspace/.gemini/settings.json', settings: {} }, |
|
|
errors: [], |
|
|
forScope: vi.fn(), |
|
|
setValue: vi.fn(), |
|
|
} as unknown as LoadedSettings; |
|
|
|
|
|
const renderTestHook = ( |
|
|
initialToolCalls: TrackedToolCall[] = [], |
|
|
geminiClient?: any, |
|
|
) => { |
|
|
let currentToolCalls = initialToolCalls; |
|
|
const setToolCalls = (newToolCalls: TrackedToolCall[]) => { |
|
|
currentToolCalls = newToolCalls; |
|
|
}; |
|
|
|
|
|
mockUseReactToolScheduler.mockImplementation(() => [ |
|
|
currentToolCalls, |
|
|
mockScheduleToolCalls, |
|
|
mockCancelAllToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
const client = geminiClient || mockConfig.getGeminiClient(); |
|
|
|
|
|
const { result, rerender } = renderHook( |
|
|
(props: { |
|
|
client: any; |
|
|
history: HistoryItem[]; |
|
|
addItem: UseHistoryManagerReturn['addItem']; |
|
|
setShowHelp: Dispatch<SetStateAction<boolean>>; |
|
|
config: Config; |
|
|
onDebugMessage: (message: string) => void; |
|
|
handleSlashCommand: ( |
|
|
cmd: PartListUnion, |
|
|
) => Promise< |
|
|
| import('./slashCommandProcessor.js').SlashCommandActionReturn |
|
|
| boolean |
|
|
>; |
|
|
shellModeActive: boolean; |
|
|
loadedSettings: LoadedSettings; |
|
|
toolCalls?: TrackedToolCall[]; // Allow passing updated toolCalls |
|
|
}) => { |
|
|
|
|
|
if (props.toolCalls) { |
|
|
setToolCalls(props.toolCalls); |
|
|
} |
|
|
return useGeminiStream( |
|
|
props.client, |
|
|
props.history, |
|
|
props.addItem, |
|
|
props.setShowHelp, |
|
|
props.config, |
|
|
props.onDebugMessage, |
|
|
props.handleSlashCommand, |
|
|
props.shellModeActive, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
() => Promise.resolve(), |
|
|
); |
|
|
}, |
|
|
{ |
|
|
initialProps: { |
|
|
client, |
|
|
history: [], |
|
|
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'], |
|
|
setShowHelp: mockSetShowHelp, |
|
|
config: mockConfig, |
|
|
onDebugMessage: mockOnDebugMessage, |
|
|
handleSlashCommand: mockHandleSlashCommand as unknown as ( |
|
|
cmd: PartListUnion, |
|
|
) => Promise< |
|
|
| import('./slashCommandProcessor.js').SlashCommandActionReturn |
|
|
| boolean |
|
|
>, |
|
|
shellModeActive: false, |
|
|
loadedSettings: mockLoadedSettings, |
|
|
toolCalls: initialToolCalls, |
|
|
}, |
|
|
}, |
|
|
); |
|
|
return { |
|
|
result, |
|
|
rerender, |
|
|
mockMarkToolsAsSubmitted, |
|
|
mockSendMessageStream, |
|
|
client, |
|
|
}; |
|
|
}; |
|
|
|
|
|
it('should not submit tool responses if not all tool calls are completed', () => { |
|
|
const toolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
request: { |
|
|
callId: 'call1', |
|
|
name: 'tool1', |
|
|
args: {}, |
|
|
isClientInitiated: false, |
|
|
}, |
|
|
status: 'success', |
|
|
responseSubmittedToGemini: false, |
|
|
response: { |
|
|
callId: 'call1', |
|
|
responseParts: [{ text: 'tool 1 response' }], |
|
|
error: undefined, |
|
|
resultDisplay: 'Tool 1 success display', |
|
|
}, |
|
|
tool: { |
|
|
name: 'tool1', |
|
|
description: 'desc1', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
startTime: Date.now(), |
|
|
endTime: Date.now(), |
|
|
} as TrackedCompletedToolCall, |
|
|
{ |
|
|
request: { callId: 'call2', name: 'tool2', args: {} }, |
|
|
status: 'executing', |
|
|
responseSubmittedToGemini: false, |
|
|
tool: { |
|
|
name: 'tool2', |
|
|
description: 'desc2', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
startTime: Date.now(), |
|
|
liveOutput: '...', |
|
|
} as TrackedExecutingToolCall, |
|
|
]; |
|
|
|
|
|
const { mockMarkToolsAsSubmitted, mockSendMessageStream } = |
|
|
renderTestHook(toolCalls); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
expect(mockMarkToolsAsSubmitted).not.toHaveBeenCalled(); |
|
|
expect(mockSendMessageStream).not.toHaveBeenCalled(); |
|
|
}); |
|
|
|
|
|
it('should submit tool responses when all tool calls are completed and ready', async () => { |
|
|
const toolCall1ResponseParts: PartListUnion = [ |
|
|
{ text: 'tool 1 final response' }, |
|
|
]; |
|
|
const toolCall2ResponseParts: PartListUnion = [ |
|
|
{ text: 'tool 2 final response' }, |
|
|
]; |
|
|
const completedToolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
request: { |
|
|
callId: 'call1', |
|
|
name: 'tool1', |
|
|
args: {}, |
|
|
isClientInitiated: false, |
|
|
}, |
|
|
status: 'success', |
|
|
responseSubmittedToGemini: false, |
|
|
response: { callId: 'call1', responseParts: toolCall1ResponseParts }, |
|
|
} as TrackedCompletedToolCall, |
|
|
{ |
|
|
request: { |
|
|
callId: 'call2', |
|
|
name: 'tool2', |
|
|
args: {}, |
|
|
isClientInitiated: false, |
|
|
}, |
|
|
status: 'error', |
|
|
responseSubmittedToGemini: false, |
|
|
response: { callId: 'call2', responseParts: toolCall2ResponseParts }, |
|
|
} as TrackedCompletedToolCall, |
|
|
]; |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[], |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
const { rerender } = renderHook(() => |
|
|
useGeminiStream( |
|
|
new MockedGeminiClientClass(mockConfig), |
|
|
[], |
|
|
mockAddItem, |
|
|
mockSetShowHelp, |
|
|
mockConfig, |
|
|
mockOnDebugMessage, |
|
|
mockHandleSlashCommand, |
|
|
false, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
() => Promise.resolve(), |
|
|
), |
|
|
); |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
completedToolCalls, |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
|
|
|
act(() => { |
|
|
rerender(); |
|
|
}); |
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(1); |
|
|
expect(mockSendMessageStream).toHaveBeenCalledTimes(1); |
|
|
}); |
|
|
|
|
|
const expectedMergedResponse = mergePartListUnions([ |
|
|
toolCall1ResponseParts, |
|
|
toolCall2ResponseParts, |
|
|
]); |
|
|
expect(mockSendMessageStream).toHaveBeenCalledWith( |
|
|
expectedMergedResponse, |
|
|
expect.any(AbortSignal), |
|
|
); |
|
|
}); |
|
|
|
|
|
it('should handle all tool calls being cancelled', async () => { |
|
|
const cancelledToolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
request: { |
|
|
callId: '1', |
|
|
name: 'testTool', |
|
|
args: {}, |
|
|
isClientInitiated: false, |
|
|
}, |
|
|
status: 'cancelled', |
|
|
response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, |
|
|
responseSubmittedToGemini: false, |
|
|
} as TrackedCancelledToolCall, |
|
|
]; |
|
|
const client = new MockedGeminiClientClass(mockConfig); |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[], |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
const { rerender } = renderHook(() => |
|
|
useGeminiStream( |
|
|
client, |
|
|
[], |
|
|
mockAddItem, |
|
|
mockSetShowHelp, |
|
|
mockConfig, |
|
|
mockOnDebugMessage, |
|
|
mockHandleSlashCommand, |
|
|
false, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
() => Promise.resolve(), |
|
|
), |
|
|
); |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
cancelledToolCalls, |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
|
|
|
act(() => { |
|
|
rerender(); |
|
|
}); |
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); |
|
|
expect(client.addHistory).toHaveBeenCalledWith({ |
|
|
role: 'user', |
|
|
parts: [{ text: 'cancelled' }], |
|
|
}); |
|
|
|
|
|
expect(mockSendMessageStream).not.toHaveBeenCalled(); |
|
|
}); |
|
|
}); |
|
|
|
|
|
describe('Session Stats Integration', () => { |
|
|
it('should call startNewTurn and addUsage for a simple prompt', async () => { |
|
|
const mockMetadata = { totalTokenCount: 123 }; |
|
|
const mockStream = (async function* () { |
|
|
yield { type: 'content', value: 'Response' }; |
|
|
yield { type: 'usage_metadata', value: mockMetadata }; |
|
|
})(); |
|
|
mockSendMessageStream.mockReturnValue(mockStream); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery('Hello, world!'); |
|
|
}); |
|
|
|
|
|
expect(mockStartNewTurn).toHaveBeenCalledTimes(1); |
|
|
expect(mockAddUsage).toHaveBeenCalledTimes(1); |
|
|
expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata); |
|
|
}); |
|
|
|
|
|
it('should only call addUsage for a tool continuation prompt', async () => { |
|
|
const mockMetadata = { totalTokenCount: 456 }; |
|
|
const mockStream = (async function* () { |
|
|
yield { type: 'content', value: 'Final Answer' }; |
|
|
yield { type: 'usage_metadata', value: mockMetadata }; |
|
|
})(); |
|
|
mockSendMessageStream.mockReturnValue(mockStream); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery([{ text: 'tool response' }], { |
|
|
isContinuation: true, |
|
|
}); |
|
|
}); |
|
|
|
|
|
expect(mockStartNewTurn).not.toHaveBeenCalled(); |
|
|
expect(mockAddUsage).toHaveBeenCalledTimes(1); |
|
|
expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata); |
|
|
}); |
|
|
|
|
|
it('should not call addUsage if the stream contains no usage metadata', async () => { |
|
|
|
|
|
const mockStream = (async function* () { |
|
|
yield { type: 'content', value: 'Some response text' }; |
|
|
})(); |
|
|
mockSendMessageStream.mockReturnValue(mockStream); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery('Query with no usage data'); |
|
|
}); |
|
|
|
|
|
expect(mockStartNewTurn).toHaveBeenCalledTimes(1); |
|
|
expect(mockAddUsage).not.toHaveBeenCalled(); |
|
|
}); |
|
|
|
|
|
it('should not call startNewTurn for a slash command', async () => { |
|
|
mockHandleSlashCommand.mockReturnValue(true); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery('/stats'); |
|
|
}); |
|
|
|
|
|
expect(mockStartNewTurn).not.toHaveBeenCalled(); |
|
|
expect(mockSendMessageStream).not.toHaveBeenCalled(); |
|
|
}); |
|
|
}); |
|
|
|
|
|
it('should not flicker streaming state to Idle between tool completion and submission', async () => { |
|
|
const toolCallResponseParts: PartListUnion = [ |
|
|
{ text: 'tool 1 final response' }, |
|
|
]; |
|
|
|
|
|
const initialToolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
request: { callId: 'call1', name: 'tool1', args: {} }, |
|
|
status: 'executing', |
|
|
responseSubmittedToGemini: false, |
|
|
tool: { |
|
|
name: 'tool1', |
|
|
description: 'desc', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
startTime: Date.now(), |
|
|
} as TrackedExecutingToolCall, |
|
|
]; |
|
|
|
|
|
const completedToolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
...(initialToolCalls[0] as TrackedExecutingToolCall), |
|
|
status: 'success', |
|
|
response: { |
|
|
callId: 'call1', |
|
|
responseParts: toolCallResponseParts, |
|
|
error: undefined, |
|
|
resultDisplay: 'Tool 1 success display', |
|
|
}, |
|
|
endTime: Date.now(), |
|
|
} as TrackedCompletedToolCall, |
|
|
]; |
|
|
|
|
|
const { result, rerender, client } = renderTestHook(initialToolCalls); |
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
|
|
|
|
|
|
|
|
|
act(() => { |
|
|
rerender({ |
|
|
client, |
|
|
history: [], |
|
|
addItem: mockAddItem, |
|
|
setShowHelp: mockSetShowHelp, |
|
|
config: mockConfig, |
|
|
onDebugMessage: mockOnDebugMessage, |
|
|
handleSlashCommand: |
|
|
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand, |
|
|
shellModeActive: false, |
|
|
loadedSettings: mockLoadedSettings, |
|
|
|
|
|
|
|
|
toolCalls: completedToolCalls, |
|
|
}); |
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
|
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockSendMessageStream).toHaveBeenCalledWith( |
|
|
toolCallResponseParts, |
|
|
expect.any(AbortSignal), |
|
|
); |
|
|
}); |
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
}); |
|
|
|
|
|
describe('User Cancellation', () => { |
|
|
let useInputCallback: (input: string, key: any) => void; |
|
|
const mockUseInput = useInput as Mock; |
|
|
|
|
|
beforeEach(() => { |
|
|
|
|
|
mockUseInput.mockImplementation((callback) => { |
|
|
useInputCallback = callback; |
|
|
}); |
|
|
}); |
|
|
|
|
|
const simulateEscapeKeyPress = () => { |
|
|
act(() => { |
|
|
useInputCallback('', { escape: true }); |
|
|
}); |
|
|
}; |
|
|
|
|
|
it('should cancel an in-progress stream when escape is pressed', async () => { |
|
|
const mockStream = (async function* () { |
|
|
yield { type: 'content', value: 'Part 1' }; |
|
|
|
|
|
await new Promise(() => {}); |
|
|
})(); |
|
|
mockSendMessageStream.mockReturnValue(mockStream); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
|
|
|
await act(async () => { |
|
|
result.current.submitQuery('test query'); |
|
|
}); |
|
|
|
|
|
|
|
|
await waitFor(() => { |
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
}); |
|
|
|
|
|
|
|
|
simulateEscapeKeyPress(); |
|
|
|
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockAddItem).toHaveBeenCalledWith( |
|
|
{ |
|
|
type: MessageType.INFO, |
|
|
text: 'Request cancelled.', |
|
|
}, |
|
|
expect.any(Number), |
|
|
); |
|
|
}); |
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Idle); |
|
|
}); |
|
|
|
|
|
it('should not do anything if escape is pressed when not responding', () => { |
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Idle); |
|
|
|
|
|
|
|
|
simulateEscapeKeyPress(); |
|
|
|
|
|
|
|
|
expect(mockAddItem).not.toHaveBeenCalledWith( |
|
|
expect.objectContaining({ |
|
|
text: 'Request cancelled.', |
|
|
}), |
|
|
expect.any(Number), |
|
|
); |
|
|
}); |
|
|
|
|
|
it('should prevent further processing after cancellation', async () => { |
|
|
let continueStream: () => void; |
|
|
const streamPromise = new Promise<void>((resolve) => { |
|
|
continueStream = resolve; |
|
|
}); |
|
|
|
|
|
const mockStream = (async function* () { |
|
|
yield { type: 'content', value: 'Initial' }; |
|
|
await streamPromise; |
|
|
yield { type: 'content', value: ' Canceled' }; |
|
|
})(); |
|
|
mockSendMessageStream.mockReturnValue(mockStream); |
|
|
|
|
|
const { result } = renderTestHook(); |
|
|
|
|
|
await act(async () => { |
|
|
result.current.submitQuery('long running query'); |
|
|
}); |
|
|
|
|
|
await waitFor(() => { |
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
}); |
|
|
|
|
|
|
|
|
simulateEscapeKeyPress(); |
|
|
|
|
|
|
|
|
act(() => { |
|
|
continueStream(); |
|
|
}); |
|
|
|
|
|
|
|
|
await new Promise((resolve) => setTimeout(resolve, 50)); |
|
|
|
|
|
|
|
|
const lastCall = mockAddItem.mock.calls.find( |
|
|
(call) => call[0].type === 'gemini', |
|
|
); |
|
|
expect(lastCall?.[0].text).toBe('Initial'); |
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Idle); |
|
|
}); |
|
|
|
|
|
it('should not cancel if a tool call is in progress (not just responding)', async () => { |
|
|
const toolCalls: TrackedToolCall[] = [ |
|
|
{ |
|
|
request: { callId: 'call1', name: 'tool1', args: {} }, |
|
|
status: 'executing', |
|
|
responseSubmittedToGemini: false, |
|
|
tool: { |
|
|
name: 'tool1', |
|
|
description: 'desc1', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
startTime: Date.now(), |
|
|
liveOutput: '...', |
|
|
} as TrackedExecutingToolCall, |
|
|
]; |
|
|
|
|
|
const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); |
|
|
const { result } = renderTestHook(toolCalls); |
|
|
|
|
|
|
|
|
expect(result.current.streamingState).toBe(StreamingState.Responding); |
|
|
|
|
|
|
|
|
simulateEscapeKeyPress(); |
|
|
|
|
|
|
|
|
expect(abortSpy).not.toHaveBeenCalled(); |
|
|
}); |
|
|
}); |
|
|
|
|
|
describe('Client-Initiated Tool Calls', () => { |
|
|
it('should execute a client-initiated tool without sending a response to Gemini', async () => { |
|
|
const clientToolRequest = { |
|
|
shouldScheduleTool: true, |
|
|
toolName: 'save_memory', |
|
|
toolArgs: { fact: 'test fact' }, |
|
|
}; |
|
|
mockHandleSlashCommand.mockResolvedValue(clientToolRequest); |
|
|
|
|
|
const completedToolCall: TrackedCompletedToolCall = { |
|
|
request: { |
|
|
callId: 'client-call-1', |
|
|
name: clientToolRequest.toolName, |
|
|
args: clientToolRequest.toolArgs, |
|
|
isClientInitiated: true, |
|
|
}, |
|
|
status: 'success', |
|
|
responseSubmittedToGemini: false, |
|
|
response: { |
|
|
callId: 'client-call-1', |
|
|
responseParts: [{ text: 'Memory saved' }], |
|
|
resultDisplay: 'Success: Memory saved', |
|
|
error: undefined, |
|
|
}, |
|
|
tool: { |
|
|
name: clientToolRequest.toolName, |
|
|
description: 'Saves memory', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
}; |
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[], |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
const { result, rerender } = renderHook(() => |
|
|
useGeminiStream( |
|
|
new MockedGeminiClientClass(mockConfig), |
|
|
[], |
|
|
mockAddItem, |
|
|
mockSetShowHelp, |
|
|
mockConfig, |
|
|
mockOnDebugMessage, |
|
|
mockHandleSlashCommand, |
|
|
false, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
() => Promise.resolve(), |
|
|
), |
|
|
); |
|
|
|
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery('/memory add "test fact"'); |
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[completedToolCall], |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
|
|
|
act(() => { |
|
|
rerender(); |
|
|
}); |
|
|
|
|
|
|
|
|
await waitFor(() => { |
|
|
|
|
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith([ |
|
|
'client-call-1', |
|
|
]); |
|
|
|
|
|
expect(mockSendMessageStream).not.toHaveBeenCalled(); |
|
|
}); |
|
|
}); |
|
|
}); |
|
|
|
|
|
describe('Memory Refresh on save_memory', () => { |
|
|
it('should call performMemoryRefresh when a save_memory tool call completes successfully', async () => { |
|
|
const mockPerformMemoryRefresh = vi.fn(); |
|
|
const completedToolCall: TrackedCompletedToolCall = { |
|
|
request: { |
|
|
callId: 'save-mem-call-1', |
|
|
name: 'save_memory', |
|
|
args: { fact: 'test' }, |
|
|
isClientInitiated: true, |
|
|
}, |
|
|
status: 'success', |
|
|
responseSubmittedToGemini: false, |
|
|
response: { |
|
|
callId: 'save-mem-call-1', |
|
|
responseParts: [{ text: 'Memory saved' }], |
|
|
resultDisplay: 'Success: Memory saved', |
|
|
error: undefined, |
|
|
}, |
|
|
tool: { |
|
|
name: 'save_memory', |
|
|
description: 'Saves memory', |
|
|
getDescription: vi.fn(), |
|
|
} as any, |
|
|
}; |
|
|
|
|
|
mockUseReactToolScheduler.mockReturnValue([ |
|
|
[completedToolCall], |
|
|
mockScheduleToolCalls, |
|
|
mockMarkToolsAsSubmitted, |
|
|
]); |
|
|
|
|
|
const { rerender } = renderHook(() => |
|
|
useGeminiStream( |
|
|
new MockedGeminiClientClass(mockConfig), |
|
|
[], |
|
|
mockAddItem, |
|
|
mockSetShowHelp, |
|
|
mockConfig, |
|
|
mockOnDebugMessage, |
|
|
mockHandleSlashCommand, |
|
|
false, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
mockPerformMemoryRefresh, |
|
|
), |
|
|
); |
|
|
|
|
|
act(() => { |
|
|
rerender(); |
|
|
}); |
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockPerformMemoryRefresh).toHaveBeenCalledTimes(1); |
|
|
}); |
|
|
}); |
|
|
}); |
|
|
|
|
|
describe('Error Handling', () => { |
|
|
it('should call parseAndFormatApiError with the correct authType on stream initialization failure', async () => { |
|
|
|
|
|
const mockError = new Error('Rate limit exceeded'); |
|
|
const mockAuthType = AuthType.LOGIN_WITH_GOOGLE_PERSONAL; |
|
|
mockParseAndFormatApiError.mockClear(); |
|
|
mockSendMessageStream.mockReturnValue( |
|
|
(async function* () { |
|
|
yield { type: 'content', value: '' }; |
|
|
throw mockError; |
|
|
})(), |
|
|
); |
|
|
|
|
|
const testConfig = { |
|
|
...mockConfig, |
|
|
getContentGeneratorConfig: vi.fn(() => ({ |
|
|
authType: mockAuthType, |
|
|
})), |
|
|
} as unknown as Config; |
|
|
|
|
|
const { result } = renderHook(() => |
|
|
useGeminiStream( |
|
|
new MockedGeminiClientClass(testConfig), |
|
|
[], |
|
|
mockAddItem, |
|
|
mockSetShowHelp, |
|
|
testConfig, |
|
|
mockOnDebugMessage, |
|
|
mockHandleSlashCommand, |
|
|
false, |
|
|
() => 'vscode' as EditorType, |
|
|
() => {}, |
|
|
() => Promise.resolve(), |
|
|
), |
|
|
); |
|
|
|
|
|
|
|
|
await act(async () => { |
|
|
await result.current.submitQuery('test query'); |
|
|
}); |
|
|
|
|
|
|
|
|
await waitFor(() => { |
|
|
expect(mockParseAndFormatApiError).toHaveBeenCalledWith( |
|
|
'Rate limit exceeded', |
|
|
mockAuthType, |
|
|
); |
|
|
}); |
|
|
}); |
|
|
}); |
|
|
}); |
|
|
|