|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; |
|
|
import { retryWithBackoff } from './retry.js'; |
|
|
import { setSimulate429 } from './testUtils.js'; |
|
|
|
|
|
|
|
|
/**
 * An `Error` that may additionally carry an HTTP status code.
 *
 * The tests below use it to simulate failed HTTP calls (for example 429 or
 * 500 responses) by attaching `status` to a plain `Error`.
 */
interface HttpError extends Error {
  /** HTTP status code attached to the simulated failure, if any. */
  status?: number;
}
|
|
|
|
|
|
|
|
/**
 * Builds a vitest mock that rejects with a simulated HTTP 500 error on its
 * first `failures` calls and resolves with `successValue` on every later call.
 *
 * @param failures - how many leading calls should throw.
 * @param successValue - value resolved once the failures are used up.
 */
const createFailingFunction = (
  failures: number,
  successValue = 'success',
) => {
  let callCount = 0;

  return vi.fn(async () => {
    callCount += 1;
    if (callCount > failures) {
      return successValue;
    }
    // The message embeds the 1-based call number so tests can assert on
    // exactly which attempt produced the final error.
    const error: HttpError = new Error(`Simulated error attempt ${callCount}`);
    error.status = 500;
    throw error;
  });
};
|
|
|
|
|
|
|
|
/**
 * Marker error used by the `shouldRetry` test: a predicate can recognise it
 * via `instanceof` and decline to retry.
 */
class NonRetryableError extends Error {
  // Own `name` property so the error is identifiable when printed.
  override name = 'NonRetryableError';

  constructor(message: string) {
    super(message);
  }
}
|
|
|
|
|
describe('retryWithBackoff', () => { |
|
|
beforeEach(() => {
  // Back-off delays are driven manually via vi.runAllTimersAsync /
  // vi.advanceTimersByTimeAsync in the individual tests.
  vi.useFakeTimers();

  // Reset the 429-simulation toggle from testUtils so every test starts from
  // the same known state.
  setSimulate429(false);

  // Silence warnings emitted during retries. This is installed as a spy
  // rather than by assigning `console.warn = vi.fn()`: direct assignment is
  // never undone by vi.restoreAllMocks() in afterEach, so the real
  // console.warn would stay replaced for the rest of the test run.
  vi.spyOn(console, 'warn').mockImplementation(() => undefined);
});
|
|
|
|
|
afterEach(() => {
  // Undo every vi.spyOn() made during the test, e.g. the setTimeout spies in
  // the maxDelayMs and jitter tests. This deliberately runs BEFORE restoring
  // real timers: those spies wrap the fake-timer setTimeout, so restoring
  // them afterwards could put the fake setTimeout back on the global.
  // NOTE(review): ordering rationale is based on vitest spy/fake-timer
  // semantics — confirm before reordering.
  vi.restoreAllMocks();
  // Switch back from the fake timers installed in beforeEach.
  vi.useRealTimers();
});
|
|
|
|
|
it('should return the result on the first attempt if successful', async () => {
  // Zero scripted failures: the very first call succeeds.
  const succeedsImmediately = createFailingFunction(0);

  await expect(retryWithBackoff(succeedsImmediately)).resolves.toBe('success');

  expect(succeedsImmediately).toHaveBeenCalledTimes(1);
});
|
|
|
|
|
it('should retry and succeed if failures are within maxAttempts', async () => {
  // Two failures followed by a success fits inside a budget of three attempts.
  const flakyFn = createFailingFunction(2);

  const pending = retryWithBackoff(flakyFn, {
    maxAttempts: 3,
    initialDelayMs: 10,
  });

  // Fire the back-off timers scheduled between attempts.
  await vi.runAllTimersAsync();

  await expect(pending).resolves.toBe('success');
  expect(flakyFn).toHaveBeenCalledTimes(3);
});
|
|
|
|
|
it('should throw an error if all attempts fail', async () => {
  // Three scripted failures with only three attempts allowed: nothing succeeds.
  const alwaysFailing = createFailingFunction(3);

  const pending = retryWithBackoff(alwaysFailing, {
    maxAttempts: 3,
    initialDelayMs: 10,
  });

  // Attach the rejection expectation before advancing timers so a handler is
  // already in place when the final attempt rejects.
  const rejection = expect(pending).rejects.toThrow(
    'Simulated error attempt 3',
  );

  await vi.runAllTimersAsync();
  await rejection;

  expect(alwaysFailing).toHaveBeenCalledTimes(3);
});
|
|
|
|
|
it('should not retry if shouldRetry returns false', async () => {
  const throwsNonRetryable = vi.fn(async () => {
    throw new NonRetryableError('Non-retryable error');
  });

  // Treat NonRetryableError as terminal; anything else may be retried.
  const shouldRetry = (error: Error) => !(error instanceof NonRetryableError);

  const pending = retryWithBackoff(throwsNonRetryable, {
    shouldRetry,
    initialDelayMs: 10,
  });

  // No timers are advanced: the rejection must happen without any back-off.
  await expect(pending).rejects.toThrow('Non-retryable error');
  expect(throwsNonRetryable).toHaveBeenCalledTimes(1);
});
|
|
|
|
|
it('should use default shouldRetry if not provided, retrying on 429', async () => {
  // Always fails with HTTP 429. Typed as HttpError (like the rest of this
  // file) instead of casting through `any`.
  const mockFn = vi.fn(async () => {
    const error: HttpError = new Error('Too Many Requests');
    error.status = 429;
    throw error;
  });

  const promise = retryWithBackoff(mockFn, {
    maxAttempts: 2,
    initialDelayMs: 10,
  });

  // Attach the expectation before running timers so the rejection is handled
  // as soon as it happens.
  const assertionPromise = expect(promise).rejects.toThrow('Too Many Requests');

  await vi.runAllTimersAsync();
  await assertionPromise;

  // The default predicate retried the 429 until maxAttempts was exhausted.
  expect(mockFn).toHaveBeenCalledTimes(2);
});
|
|
|
|
|
it('should use default shouldRetry if not provided, not retrying on 400', async () => {
  // Always fails with HTTP 400. Typed as HttpError (like the rest of this
  // file) instead of casting through `any`.
  const mockFn = vi.fn(async () => {
    const error: HttpError = new Error('Bad Request');
    error.status = 400;
    throw error;
  });

  const promise = retryWithBackoff(mockFn, {
    maxAttempts: 2,
    initialDelayMs: 10,
  });

  // No timers are advanced: a non-retryable status must reject immediately.
  await expect(promise).rejects.toThrow('Bad Request');
  expect(mockFn).toHaveBeenCalledTimes(1);
});
|
|
|
|
|
it('should respect maxDelayMs', async () => {
  // Three failures, so three back-off delays precede the fourth (successful) attempt.
  const flakyFn = createFailingFunction(3);
  const setTimeoutSpy = vi.spyOn(global, 'setTimeout');

  const pending = retryWithBackoff(flakyFn, {
    maxAttempts: 4,
    initialDelayMs: 100,
    maxDelayMs: 250,
  });

  // 1000ms exceeds the sum of the largest allowed delays (130 + 260 + 325).
  await vi.advanceTimersByTimeAsync(1000);
  await pending;

  const observedDelays = setTimeoutSpy.mock.calls.map(([, ms]) => ms as number);

  // Nominal delays 100, 200, then capped at 250; each may vary by ±30% jitter.
  const nominalDelays = [100, 200, 250];
  expect(observedDelays.length).toBe(nominalDelays.length);
  nominalDelays.forEach((nominal, index) => {
    expect(observedDelays[index]).toBeGreaterThanOrEqual(nominal * 0.7);
    expect(observedDelays[index]).toBeLessThanOrEqual(nominal * 1.3);
  });
});
|
|
|
|
|
it('should handle jitter correctly, ensuring varied delays', async () => {
  let mockFn = createFailingFunction(5);
  const setTimeoutSpy = vi.spyOn(global, 'setTimeout');

  // With maxAttempts: 2 each run fails twice and schedules exactly one
  // back-off delay. The closure reads `mockFn` at call time, so reassigning
  // it below gives the second run a fresh failing function.
  const runRetry = () =>
    retryWithBackoff(mockFn, {
      maxAttempts: 2,
      initialDelayMs: 100,
      maxDelayMs: 1000,
    });

  // Delay argument of every setTimeout call since the spy was last cleared.
  const captureDelays = () =>
    setTimeoutSpy.mock.calls.map((call) => call[1] as number);

  // First run.
  const promise1 = runRetry();
  const assertionPromise1 = expect(promise1).rejects.toThrow();
  await vi.runAllTimersAsync();
  await assertionPromise1;

  const firstDelaySet = captureDelays();
  setTimeoutSpy.mockClear();

  // Second run with a fresh failing function.
  mockFn = createFailingFunction(5);
  const promise2 = runRetry();
  const assertionPromise2 = expect(promise2).rejects.toThrow();
  await vi.runAllTimersAsync();
  await assertionPromise2;

  const secondDelaySet = captureDelays();

  // Each run must have captured exactly its one back-off delay; this
  // replaces the previous hand-rolled `throw` with a proper assertion.
  expect(firstDelaySet).toHaveLength(1);
  expect(secondDelaySet).toHaveLength(1);

  // Random jitter should make the two runs' delays differ.
  expect(firstDelaySet[0]).not.toBe(secondDelaySet[0]);

  // Both delays stay within ±30% of initialDelayMs.
  for (const delay of [...firstDelaySet, ...secondDelaySet]) {
    expect(delay).toBeGreaterThanOrEqual(100 * 0.7);
    expect(delay).toBeLessThanOrEqual(100 * 1.3);
  }
});
|
|
|
|
|
describe('Flash model fallback for OAuth users', () => { |
|
|
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
  const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');

  // Flipped by onPersistent429; until then every call is rate-limited.
  let switchedToFallback = false;
  const mockFn = vi.fn().mockImplementation(async () => {
    if (switchedToFallback) {
      return 'success';
    }
    const error: HttpError = new Error('Rate limit exceeded');
    error.status = 429;
    throw error;
  });

  const promise = retryWithBackoff(mockFn, {
    maxAttempts: 3,
    initialDelayMs: 100,
    onPersistent429: async (authType?: string) => {
      switchedToFallback = true;
      return await fallbackCallback(authType);
    },
    authType: 'oauth-personal',
  });

  // Run the back-off timers between the rate-limited attempts.
  await vi.runAllTimersAsync();

  await expect(promise).resolves.toBe('success');

  // The fallback receives the auth type that triggered it.
  expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');

  // Two rate-limited calls, then one successful call after the fallback.
  expect(mockFn).toHaveBeenCalledTimes(3);
});
|
|
|
|
|
it('should NOT trigger fallback for API key users', async () => {
  const fallbackCallback = vi.fn();

  const alwaysRateLimited = vi.fn(async () => {
    const error: HttpError = new Error('Rate limit exceeded');
    error.status = 429;
    throw error;
  });

  const promise = retryWithBackoff(alwaysRateLimited, {
    maxAttempts: 3,
    initialDelayMs: 100,
    onPersistent429: fallbackCallback,
    authType: 'gemini-api-key',
  });

  // Convert the rejection into a value up front so it is handled before the
  // timers that drive the retries are run.
  const outcomePromise = promise.catch((error) => error);
  await vi.runAllTimersAsync();
  const outcome = await outcomePromise;

  // The original rate-limit error surfaces once the retries are exhausted...
  expect(outcome).toBeInstanceOf(Error);
  expect(outcome.message).toBe('Rate limit exceeded');

  // ...and API-key auth never reaches the fallback.
  expect(fallbackCallback).not.toHaveBeenCalled();
});
|
|
|
|
|
it('should reset attempt counter and continue after successful fallback', async () => {
  let fallbackCalled = false;
  const fallbackCallback = vi.fn().mockImplementation(async () => {
    fallbackCalled = true;
    return 'gemini-2.5-flash';
  });

  // Before the fallback, every call is rate-limited. After it, the first call
  // still fails with a retryable 500 and the next one succeeds. Two attempts
  // are already spent on 429s and maxAttempts is 3, so reaching the success
  // is only possible if the fallback reset the attempt counter.
  let callsAfterFallback = 0;
  const mockFn = vi.fn().mockImplementation(async () => {
    if (!fallbackCalled) {
      const error: HttpError = new Error('Rate limit exceeded');
      error.status = 429;
      throw error;
    }
    callsAfterFallback++;
    if (callsAfterFallback === 1) {
      const error: HttpError = new Error('Server error');
      error.status = 500;
      throw error;
    }
    return 'success';
  });

  const promise = retryWithBackoff(mockFn, {
    maxAttempts: 3,
    initialDelayMs: 100,
    onPersistent429: fallbackCallback,
    authType: 'oauth-personal',
  });

  // Attach the expectation before running timers so an unexpected rejection
  // is reported as a test failure rather than as an unhandled rejection.
  const assertionPromise = expect(promise).resolves.toBe('success');
  await vi.runAllTimersAsync();
  await assertionPromise;

  expect(fallbackCallback).toHaveBeenCalledOnce();
  // 2 rate-limited calls + 2 calls after the fallback = 4 > maxAttempts (3).
  expect(mockFn).toHaveBeenCalledTimes(4);
});
|
|
|
|
|
it('should continue with original error if fallback is rejected', async () => {
  // The fallback resolves to null, i.e. it declines to switch models.
  const fallbackCallback = vi.fn().mockResolvedValue(null);

  const alwaysRateLimited = vi.fn(async () => {
    const error: HttpError = new Error('Rate limit exceeded');
    error.status = 429;
    throw error;
  });

  const promise = retryWithBackoff(alwaysRateLimited, {
    maxAttempts: 3,
    initialDelayMs: 100,
    onPersistent429: fallbackCallback,
    authType: 'oauth-personal',
  });

  // Convert the rejection into a value up front so it is handled before the
  // timers that drive the retries are run.
  const outcomePromise = promise.catch((error) => error);
  await vi.runAllTimersAsync();
  const outcome = await outcomePromise;

  // The caller sees the original rate-limit error...
  expect(outcome).toBeInstanceOf(Error);
  expect(outcome.message).toBe('Rate limit exceeded');
  // ...even though the fallback was consulted.
  expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
|
|
|
|
|
it('should handle mixed error types (only count consecutive 429s)', async () => {
  const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
  let attempts = 0;
  let fallbackOccurred = false;

  // Error sequence before the fallback: 429, 500, 429, 429, ...
  // The 500 breaks the first 429 streak, so the fallback should fire only
  // after the 4th call, once two *consecutive* 429s have been seen. A counter
  // that ignored the 500 (or counted every error) would fire earlier and
  // change the total call count asserted below.
  const mockFn = vi.fn().mockImplementation(async () => {
    attempts++;
    if (fallbackOccurred) {
      return 'success';
    }
    if (attempts === 2) {
      const error: HttpError = new Error('Server error');
      error.status = 500;
      throw error;
    }
    const error: HttpError = new Error('Rate limit exceeded');
    error.status = 429;
    throw error;
  });

  const promise = retryWithBackoff(mockFn, {
    maxAttempts: 5,
    initialDelayMs: 100,
    onPersistent429: async (authType?: string) => {
      fallbackOccurred = true;
      return await fallbackCallback(authType);
    },
    authType: 'oauth-personal',
  });

  // Attach the expectation before running timers so an unexpected rejection
  // is reported as a test failure rather than as an unhandled rejection.
  const assertionPromise = expect(promise).resolves.toBe('success');
  await vi.runAllTimersAsync();
  await assertionPromise;

  expect(fallbackCallback).toHaveBeenCalledOnce();
  expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
  // 429, 500, 429, 429 (fallback), then success on the fallback model.
  expect(mockFn).toHaveBeenCalledTimes(5);
});
|
|
}); |
|
|
}); |
|
|
|