|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import { type MutableRefObject } from 'react'; |
|
|
import { render } from 'ink-testing-library'; |
|
|
import { act } from 'react-dom/test-utils'; |
|
|
import { SessionStatsProvider, useSessionStats } from './SessionContext.js'; |
|
|
import { describe, it, expect, vi } from 'vitest'; |
|
|
import { GenerateContentResponseUsageMetadata } from '@google/genai'; |
|
|
|
|
|
|
|
|
/**
 * Usage-metadata fixture for a "large" API response.
 *
 * The values are all distinct so that a mix-up between fields shows up in an
 * assertion, and `totalTokenCount` equals
 * `promptTokenCount + candidatesTokenCount`.
 */
const mockMetadata1: GenerateContentResponseUsageMetadata = {
  // Headline counters.
  promptTokenCount: 100,
  candidatesTokenCount: 200,
  totalTokenCount: 300,
  // Secondary counters.
  cachedContentTokenCount: 50,
  toolUsePromptTokenCount: 10,
  thoughtsTokenCount: 20,
};
|
|
|
|
|
/**
 * Usage-metadata fixture for a "small" API response.
 *
 * Every field is exactly one tenth of the matching field in `mockMetadata1`,
 * so sums of the two fixtures are easy to verify by eye (e.g. 300 + 30).
 */
const mockMetadata2: GenerateContentResponseUsageMetadata = {
  // Headline counters.
  promptTokenCount: 10,
  candidatesTokenCount: 20,
  totalTokenCount: 30,
  // Secondary counters.
  cachedContentTokenCount: 5,
  toolUsePromptTokenCount: 1,
  thoughtsTokenCount: 2,
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Test-only component that calls `useSessionStats` on every render and
 * publishes the hook's return value through `contextRef`, so a test can read
 * the value and call the functions it exposes from outside the component tree.
 *
 * @param props.contextRef Mutable ref whose `current` is overwritten with the
 *   hook's return value each time the component renders.
 * @returns `null` — the component renders no output.
 */
const TestHarness = (props: {
  contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
}) => {
  const sessionStats = useSessionStats();
  props.contextRef.current = sessionStats;
  return null;
};
|
|
|
|
|
/**
 * Behavioural tests for the session-stats context exposed by
 * `SessionStatsProvider` / `useSessionStats`.
 *
 * Each test renders `TestHarness`, which copies the hook's return value into a
 * ref, then drives the context through `startNewTurn` / `addUsage` inside
 * `act` and asserts on the resulting `stats` object.
 */
describe('SessionStatsContext', () => {
  /** Ref through which `TestHarness` exposes the hook's return value. */
  type SessionStatsRef = MutableRefObject<
    ReturnType<typeof useSessionStats> | undefined
  >;

  /**
   * Renders `TestHarness` inside a `SessionStatsProvider`.
   *
   * @returns The ref that `TestHarness` fills with the hook's return value on
   *   every render; read `.current` after each `act` to see the latest value.
   */
  const renderWithProvider = (): SessionStatsRef => {
    const contextRef: SessionStatsRef = { current: undefined };

    render(
      <SessionStatsProvider>
        <TestHarness contextRef={contextRef} />
      </SessionStatsProvider>,
    );

    return contextRef;
  };

  it('should provide the correct initial state', () => {
    const contextRef = renderWithProvider();

    const stats = contextRef.current?.stats;

    expect(stats?.sessionStartTime).toBeInstanceOf(Date);
    expect(stats?.currentTurn).toBeDefined();
    expect(stats?.cumulative.turnCount).toBe(0);
    expect(stats?.cumulative.totalTokenCount).toBe(0);
    expect(stats?.cumulative.promptTokenCount).toBe(0);
  });

  it('should increment turnCount when startNewTurn is called', () => {
    const contextRef = renderWithProvider();

    act(() => {
      contextRef.current?.startNewTurn();
    });

    const stats = contextRef.current?.stats;
    expect(stats?.currentTurn.totalTokenCount).toBe(0);
    expect(stats?.cumulative.turnCount).toBe(1);

    // Starting a turn must not invent any token usage.
    expect(stats?.cumulative.totalTokenCount).toBe(0);
  });

  it('should aggregate token usage correctly when addUsage is called', () => {
    const contextRef = renderWithProvider();

    act(() => {
      contextRef.current?.addUsage({ ...mockMetadata1, apiTimeMs: 123 });
    });

    const stats = contextRef.current?.stats;

    // Cumulative totals reflect the single reported usage.
    expect(stats?.cumulative.totalTokenCount).toBe(
      mockMetadata1.totalTokenCount ?? 0,
    );
    expect(stats?.cumulative.promptTokenCount).toBe(
      mockMetadata1.promptTokenCount ?? 0,
    );
    expect(stats?.cumulative.apiTimeMs).toBe(123);

    // Only startNewTurn advances the turn counter; addUsage alone does not.
    expect(stats?.cumulative.turnCount).toBe(0);

    // The usage is also attributed to the current turn.
    expect(stats?.currentTurn.totalTokenCount).toBe(
      mockMetadata1.totalTokenCount,
    );
    expect(stats?.currentTurn.apiTimeMs).toBe(123);
  });

  it('should correctly track a full logical turn with multiple API calls', () => {
    const contextRef = renderWithProvider();

    // One logical turn...
    act(() => {
      contextRef.current?.startNewTurn();
    });

    // ...containing two separate API calls.
    act(() => {
      contextRef.current?.addUsage({ ...mockMetadata1, apiTimeMs: 100 });
    });

    act(() => {
      contextRef.current?.addUsage({ ...mockMetadata2, apiTimeMs: 50 });
    });

    const stats = contextRef.current?.stats;

    // Two API calls still count as a single turn.
    expect(stats?.cumulative.turnCount).toBe(1);

    // Cumulative counters are the sum of both calls.
    expect(stats?.cumulative.totalTokenCount).toBe(300 + 30);
    expect(stats?.cumulative.candidatesTokenCount).toBe(200 + 20);
    expect(stats?.cumulative.thoughtsTokenCount).toBe(20 + 2);
    expect(stats?.cumulative.apiTimeMs).toBe(100 + 50);

    expect(stats?.cumulative.promptTokenCount).toBe(100 + 10);
    expect(stats?.cumulative.cachedContentTokenCount).toBe(50 + 5);
    expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10 + 1);

    // The current turn accumulates the same two calls.
    expect(stats?.currentTurn.totalTokenCount).toBe(300 + 30);
    expect(stats?.currentTurn.candidatesTokenCount).toBe(200 + 20);
    expect(stats?.currentTurn.thoughtsTokenCount).toBe(20 + 2);
    expect(stats?.currentTurn.promptTokenCount).toBe(100 + 10);
    expect(stats?.currentTurn.cachedContentTokenCount).toBe(50 + 5);
    expect(stats?.currentTurn.toolUsePromptTokenCount).toBe(10 + 1);
    expect(stats?.currentTurn.apiTimeMs).toBe(100 + 50);
  });

  it('should overwrite currentResponse with each API call', () => {
    const contextRef = renderWithProvider();

    // First API call.
    act(() => {
      contextRef.current?.addUsage({ ...mockMetadata1, apiTimeMs: 100 });
    });

    let stats = contextRef.current?.stats;

    // currentResponse mirrors the first call.
    expect(stats?.currentResponse.totalTokenCount).toBe(300);
    expect(stats?.currentResponse.apiTimeMs).toBe(100);

    // Second API call.
    act(() => {
      contextRef.current?.addUsage({ ...mockMetadata2, apiTimeMs: 50 });
    });

    stats = contextRef.current?.stats;

    // currentResponse is replaced, not summed (30, not 330).
    expect(stats?.currentResponse.totalTokenCount).toBe(30);
    expect(stats?.currentResponse.apiTimeMs).toBe(50);

    // Starting a new turn clears currentResponse.
    act(() => {
      contextRef.current?.startNewTurn();
    });

    stats = contextRef.current?.stats;

    expect(stats?.currentResponse.totalTokenCount).toBe(0);
    expect(stats?.currentResponse.apiTimeMs).toBe(0);
  });

  it('should throw an error when useSessionStats is used outside of a provider', () => {
    // Silence console.error for this test and record what is passed to it.
    const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});

    // Restore console.error in `finally` so a failing assertion (or a render
    // that throws) cannot leave the mock installed for later tests.
    try {
      const contextRef: SessionStatsRef = { current: undefined };

      // Deliberately render the harness without a SessionStatsProvider.
      render(<TestHarness contextRef={contextRef} />);

      // Fail with a clear assertion if nothing was logged, rather than with a
      // TypeError from indexing into an empty `mock.calls` array.
      expect(errorSpy).toHaveBeenCalled();

      // Only the first argument of the first call is inspected.
      expect(errorSpy.mock.calls[0]?.[0]).toContain(
        'useSessionStats must be used within a SessionStatsProvider',
      );
    } finally {
      errorSpy.mockRestore();
    }
  });
});
|
|
|