|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; |
|
|
import { getOauthClient } from './oauth2.js'; |
|
|
import { OAuth2Client } from 'google-auth-library'; |
|
|
import * as fs from 'fs'; |
|
|
import * as path from 'path'; |
|
|
import http from 'http'; |
|
|
import open from 'open'; |
|
|
import crypto from 'crypto'; |
|
|
import * as os from 'os'; |
|
|
|
|
|
// Module-level mocks. Vitest hoists every `vi.mock` call above the imports,
// so the module under test is loaded against these mocked versions.

// Fully automock the browser, networking, crypto and Google auth dependencies;
// individual tests install the concrete behaviour they need via `vi.mocked`.
vi.mock('crypto');
vi.mock('google-auth-library');
vi.mock('http');
vi.mock('open');

// Keep the real `os` module (the suite still relies on `os.tmpdir()`), but
// swap `homedir` for a mock so each test can point it at a throwaway directory.
vi.mock('os', async (importOriginal) => {
  const actualOs = await importOriginal<typeof import('os')>();
  return { ...actualOs, homedir: vi.fn() };
});
|
|
|
|
|
describe('oauth2', () => {
  // Temporary directory returned by the mocked `os.homedir()`, so anything the
  // code under test writes relative to the home directory lands here instead
  // of in the real user's home.
  let tempHomeDir: string;

  beforeEach(() => {
    tempHomeDir = fs.mkdtempSync(
      path.join(os.tmpdir(), 'gemini-cli-test-home-'),
    );
    vi.mocked(os.homedir).mockReturnValue(tempHomeDir);
  });

  afterEach(() => {
    fs.rmSync(tempHomeDir, { recursive: true, force: true });
    // The module automocks (open, http.createServer, OAuth2Client,
    // crypto.randomBytes) are shared across tests; clear their recorded calls
    // so call assertions in one test cannot be satisfied by another test.
    vi.clearAllMocks();
  });

  /**
   * Drives the full browser-based login flow against mocks:
   * getOauthClient starts a local HTTP server, opens the auth URL, and the
   * test then simulates the OAuth redirect hitting `/oauth2callback`.
   * Verifies the token exchange, credential installation, and that the
   * tokens are persisted to `<home>/.gemini/oauth_creds.json`.
   */
  it('should perform a web login', async () => {
    const mockAuthUrl = 'https://example.com/auth';
    const mockCode = 'test-code';
    const mockState = 'test-state';
    const mockTokens = {
      access_token: 'test-access-token',
      refresh_token: 'test-refresh-token',
    };

    // --- OAuth2Client stub: returns the fixed auth URL and tokens. ---
    const mockGenerateAuthUrl = vi.fn().mockReturnValue(mockAuthUrl);
    const mockGetToken = vi.fn().mockResolvedValue({ tokens: mockTokens });
    const mockSetCredentials = vi.fn();
    const mockOAuth2Client = {
      generateAuthUrl: mockGenerateAuthUrl,
      getToken: mockGetToken,
      setCredentials: mockSetCredentials,
      credentials: mockTokens,
    } as unknown as OAuth2Client;
    vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client);

    // Make randomBytes deterministic; the same value is sent back as the
    // `state` query parameter in the simulated callback below. A plain string
    // stands in for the Buffer, hence the cast.
    vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never);
    vi.mocked(open).mockImplementation(async () => ({}) as never);

    // --- HTTP server stub: capture the request listener and the port. ---
    let requestCallback!: http.RequestListener<
      typeof http.IncomingMessage,
      typeof http.ServerResponse
    >;

    // Resolves once getOauthClient has called `listen`, i.e. once it is ready
    // to receive the redirect.
    let serverListeningCallback: (value: unknown) => void;
    const serverListeningPromise = new Promise(
      (resolve) => (serverListeningCallback = resolve),
    );

    let capturedPort = 0;
    const mockHttpServer = {
      listen: vi.fn((port: number, callback?: () => void) => {
        capturedPort = port;
        if (callback) {
          callback();
        }
        serverListeningCallback(undefined);
      }),
      close: vi.fn((callback?: () => void) => {
        if (callback) {
          callback();
        }
      }),
      on: vi.fn(),
      address: () => ({ port: capturedPort }),
    };
    vi.mocked(http.createServer).mockImplementation((cb) => {
      requestCallback = cb as http.RequestListener<
        typeof http.IncomingMessage,
        typeof http.ServerResponse
      >;
      return mockHttpServer as unknown as http.Server;
    });

    // --- Run the flow. ---
    const clientPromise = getOauthClient();

    // Wait until the local server is "listening" before delivering the
    // redirect, otherwise requestCallback would not be captured yet.
    await serverListeningPromise;

    const mockReq = {
      url: `/oauth2callback?code=${mockCode}&state=${mockState}`,
    } as http.IncomingMessage;
    const mockRes = {
      writeHead: vi.fn(),
      end: vi.fn(),
    } as unknown as http.ServerResponse;

    await requestCallback(mockReq, mockRes);

    const client = await clientPromise;

    // --- Assertions. ---
    expect(client).toBe(mockOAuth2Client);
    expect(open).toHaveBeenCalledWith(mockAuthUrl);
    // The redirect URI must point at the port the local server listened on.
    expect(mockGetToken).toHaveBeenCalledWith({
      code: mockCode,
      redirect_uri: `http://localhost:${capturedPort}/oauth2callback`,
    });
    expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);

    // Tokens are persisted under the (mocked) home directory.
    const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
    const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8'));
    expect(tokenData).toEqual(mockTokens);
  });
});
|
|
|