File size: 6,272 Bytes
f0743f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 | const OpenAI = require('openai');
const { logger } = require('@librechat/data-schemas');
const DALLE3 = require('../DALLE3');
// `openai` is automocked; its constructor is given a concrete shape further down via
// OpenAI.mockImplementation. The data-schemas package is reduced to a logger whose four
// level methods are spies, so tests can assert on logger.debug / logger.error calls.
jest.mock('openai');
jest.mock('@librechat/data-schemas', () => ({
  logger: {
    info: jest.fn(),
    warn: jest.fn(),
    debug: jest.fn(),
    error: jest.fn(),
  },
}));
// tiktoken is stubbed out: encoding_for_model always hands back one shared encoder object
// whose encode/decode are inert spies (they return undefined).
jest.mock('tiktoken', () => {
  const sharedEncoder = { encode: jest.fn(), decode: jest.fn() };
  return { encoding_for_model: jest.fn().mockReturnValue(sharedEncoder) };
});
// Spies shared by every test. Each `new OpenAI()` yields a client whose images.generate is
// this same `generate` mock, so tests can program API responses without holding the
// client instance. `processFileURL` is the save callback handed to the DALLE3 constructor.
const generate = jest.fn();
const processFileURL = jest.fn();
OpenAI.mockImplementation(() => ({ images: { generate } }));
// fs is replaced by an object exposing only these spies (sync existence/mkdir helpers plus
// fs.promises writeFile/readFile/unlink); none of them perform real I/O.
jest.mock('fs', () => ({
  existsSync: jest.fn(),
  mkdirSync: jest.fn(),
  promises: {
    writeFile: jest.fn(),
    readFile: jest.fn(),
    unlink: jest.fn(),
  },
}));
// path is replaced by bare spies, except extname, which slices from the last '.' onward.
// NOTE(review): for a name with no '.', lastIndexOf gives -1 and slice(-1) returns the final
// character, unlike the real path.extname (which returns '') — fine for the current
// expect.any(String) assertions, but keep in mind if stricter checks are added.
jest.mock('path', () => ({
  resolve: jest.fn(),
  join: jest.fn(),
  relative: jest.fn(),
  extname: jest.fn((fileName) => fileName.slice(fileName.lastIndexOf('.'))),
}));
/**
 * Unit tests for the DALLE3 tool.
 *
 * Relies on the module-level mocks in this file: `generate` backs `images.generate` on every
 * mocked OpenAI client, `processFileURL` is the save callback passed to the constructor, and
 * `logger` is the mocked data-schemas logger.
 */
describe('DALLE3', () => {
  let originalEnv;
  /** Default instance, rebuilt before each test with DALLE_API_KEY set. */
  let dalle;
  const mockApiKey = 'mock_api_key';

  beforeAll(() => {
    // Snapshot the environment once so every test starts from the same baseline.
    originalEnv = { ...process.env };
  });

  beforeEach(() => {
    jest.resetModules();
    process.env = { ...originalEnv, DALLE_API_KEY: mockApiKey };
    // Instance for tests that do not depend on DALLE3_SYSTEM_PROMPT.
    dalle = new DALLE3({ processFileURL });
  });

  afterEach(() => {
    // clearAllMocks only clears call history; implementations configured with
    // mockResolvedValue / mockRejectedValue persist between tests, so each test that depends
    // on generate / processFileURL configures them itself.
    jest.clearAllMocks();
    // Restore a copy rather than the snapshot object itself, so later mutations of
    // process.env cannot leak into the snapshot.
    process.env = { ...originalEnv };
  });

  it('should throw an error if all potential API keys are missing', () => {
    delete process.env.DALLE3_API_KEY;
    delete process.env.DALLE_API_KEY;
    expect(() => new DALLE3()).toThrow('Missing DALLE_API_KEY environment variable.');
  });

  it('should replace unwanted characters in input string', () => {
    const input = 'This is a test\nstring with "quotes" and new lines.';
    const expectedOutput = 'This is a test string with quotes and new lines.';
    expect(dalle.replaceUnwantedChars(input)).toBe(expectedOutput);
  });

  it('should generate markdown image URL correctly', () => {
    const imageName = 'test.png';
    const markdownImage = dalle.wrapInMarkdown(imageName);
    // Markdown image syntax with the same alt text that _call results are checked for below.
    expect(markdownImage).toBe(`![generated image](${imageName})`);
  });

  it('should call OpenAI API with correct parameters', async () => {
    const mockData = {
      prompt: 'A test prompt',
      quality: 'standard',
      size: '1024x1024',
      style: 'vivid',
    };
    const mockResponse = {
      data: [
        {
          url: 'http://example.com/img-test.png',
        },
      ],
    };
    generate.mockResolvedValue(mockResponse);
    processFileURL.mockResolvedValue({
      filepath: 'http://example.com/img-test.png',
    });
    const result = await dalle._call(mockData);
    expect(generate).toHaveBeenCalledWith({
      model: 'dall-e-3',
      quality: mockData.quality,
      style: mockData.style,
      size: mockData.size,
      prompt: mockData.prompt,
      n: 1,
    });
    expect(result).toContain('![generated image]');
  });

  it('should use the system prompt if provided', () => {
    process.env.DALLE3_SYSTEM_PROMPT = 'System prompt for testing';
    // Fresh module registry so the re-required DALLE3 reads the new env var.
    jest.resetModules();
    const DALLE3 = require('../DALLE3');
    const dalleWithSystemPrompt = new DALLE3();
    expect(dalleWithSystemPrompt.description_for_model).toBe('System prompt for testing');
  });

  it('should not use the system prompt if not provided', () => {
    delete process.env.DALLE3_SYSTEM_PROMPT;
    // Mirror the positive case: the env var is only observed by a freshly required module,
    // so re-require instead of reusing the DALLE3 loaded at the top of the file.
    jest.resetModules();
    const DALLE3 = require('../DALLE3');
    const dalleWithoutSystemPrompt = new DALLE3();
    expect(dalleWithoutSystemPrompt.description_for_model).not.toBe('System prompt for testing');
  });

  it('should throw an error if prompt is missing', async () => {
    const mockData = {
      quality: 'standard',
      size: '1024x1024',
      style: 'vivid',
    };
    await expect(dalle._call(mockData)).rejects.toThrow('Missing required field: prompt');
  });

  it('should log appropriate debug values', async () => {
    const mockData = {
      prompt: 'A test prompt',
    };
    const mockResponse = {
      data: [
        {
          url: 'http://example.com/invalid-url',
        },
      ],
    };
    generate.mockResolvedValue(mockResponse);
    // Configure the save callback explicitly instead of inheriting whatever an earlier test
    // left behind (clearAllMocks does not reset implementations).
    processFileURL.mockResolvedValue({
      filepath: 'http://example.com/invalid-url',
    });
    await dalle._call(mockData);
    expect(logger.debug).toHaveBeenCalledWith('[DALL-E-3]', {
      data: { url: 'http://example.com/invalid-url' },
      theImageUrl: 'http://example.com/invalid-url',
      extension: expect.any(String),
      imageBasename: expect.any(String),
      imageExt: expect.any(String),
      imageName: expect.any(String),
    });
  });

  it('should log an error and return the image URL if there is an error saving the image', async () => {
    const mockData = {
      prompt: 'A test prompt',
    };
    const mockResponse = {
      data: [
        {
          url: 'http://example.com/img-test.png',
        },
      ],
    };
    const error = new Error('Error while saving the image');
    generate.mockResolvedValue(mockResponse);
    processFileURL.mockRejectedValue(error);
    const result = await dalle._call(mockData);
    expect(logger.error).toHaveBeenCalledWith('Error while saving the image:', error);
    expect(result).toBe('Failed to save the image locally. Error while saving the image');
  });

  it('should handle error when saving image to Firebase Storage fails', async () => {
    const mockData = {
      prompt: 'A test prompt',
    };
    const mockImageUrl = 'http://example.com/img-test.png';
    const mockResponse = { data: [{ url: mockImageUrl }] };
    const error = new Error('Error while saving to Firebase');
    generate.mockResolvedValue(mockResponse);
    processFileURL.mockRejectedValue(error);
    const result = await dalle._call(mockData);
    expect(logger.error).toHaveBeenCalledWith('Error while saving the image:', error);
    expect(result).toContain('Failed to save the image');
  });
});
|