"""Tests for ankigen_core.llm_interface.

Covers OpenAIClientManager, structured_output_completion, and the
crawled-page processing helpers process_crawled_page / process_crawled_pages.
"""

import asyncio
import json
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import pytest
import tenacity
from openai import APIConnectionError, APIError, AsyncOpenAI, OpenAIError
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from ankigen_core.llm_interface import (
    OpenAIClientManager,
    structured_output_completion,
    process_crawled_page,
    process_crawled_pages,
)
from ankigen_core.models import AnkiCardData, CrawledPage
from ankigen_core.utils import ResponseCache

@pytest.mark.asyncio
async def test_client_manager_init():
    """Test initial state of the client manager."""
    manager = OpenAIClientManager()
    assert manager._client is None
    assert manager._api_key is None

@pytest.mark.asyncio
async def test_client_manager_initialize_success():
    """Test successful client initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI"
    ) as mock_async_openai_constructor:
        await manager.initialize_client(valid_key)
        mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)
        assert manager.get_client() is not None

@pytest.mark.asyncio
async def test_client_manager_initialize_invalid_key_format():
    """Test initialization failure with invalid API key format."""
    manager = OpenAIClientManager()
    invalid_key = "invalid-key-format"
    with pytest.raises(ValueError, match="Invalid OpenAI API key format."):
        await manager.initialize_client(invalid_key)
    assert manager._client is None
    assert manager._api_key is None

@pytest.mark.asyncio
async def test_client_manager_initialize_openai_error():
    """Test handling of OpenAIError during client initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    error_message = "Test OpenAI Init Error"

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI", side_effect=OpenAIError(error_message)
    ) as mock_async_openai_constructor:
        with pytest.raises(OpenAIError, match=error_message):
            await manager.initialize_client(valid_key)
        mock_async_openai_constructor.assert_called_once_with(api_key=valid_key)

@pytest.mark.asyncio
async def test_client_manager_get_client_success():
    """Test getting the client after successful initialization."""
    manager = OpenAIClientManager()
    valid_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI"
    ) as mock_async_openai_constructor:
        mock_instance = mock_async_openai_constructor.return_value
        await manager.initialize_client(valid_key)
        assert manager.get_client() is mock_instance

def test_client_manager_get_client_not_initialized():
    """Test getting the client before initialization."""
    manager = OpenAIClientManager()
    with pytest.raises(RuntimeError, match="OpenAI client is not initialized."):
        manager.get_client()

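# A minimal sketch of the manager contract the tests above pin down. This is
# NOT the real implementation in ankigen_core.llm_interface; the simple
# "sk-" prefix check in particular is an illustrative assumption.
class _ClientManagerSketch:
    def __init__(self):
        self._client = None
        self._api_key = None

    async def initialize_client(self, api_key: str) -> None:
        # Reject obviously malformed keys before constructing a client.
        if not isinstance(api_key, str) or not api_key.startswith("sk-"):
            raise ValueError("Invalid OpenAI API key format.")
        self._api_key = api_key
        # AsyncOpenAI may itself raise OpenAIError; the manager lets it propagate.
        self._client = AsyncOpenAI(api_key=api_key)

    def get_client(self):
        if self._client is None:
            raise RuntimeError("OpenAI client is not initialized.")
        return self._client
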
@pytest.fixture
def mock_openai_client():
    """Mock AsyncOpenAI client whose create() returns a canned JSON response."""
    client = MagicMock(spec=AsyncOpenAI)
    client.chat = AsyncMock()
    client.chat.completions = AsyncMock()
    client.chat.completions.create = AsyncMock()
    # create_mock_chat_completion is defined later in this module; the forward
    # reference is safe because fixture bodies only run at test time.
    mock_chat_completion_response = create_mock_chat_completion(
        json.dumps([{"data": "mocked success"}])
    )
    client.chat.completions.create.return_value = mock_chat_completion_response
    return client

@pytest.fixture
def mock_response_cache():
    """Mock ResponseCache with MagicMock get/set."""
    cache = MagicMock(spec=ResponseCache)
    return cache

@pytest.mark.asyncio
async def test_structured_output_completion_cache_hit(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the response is found in the cache."""
    system_prompt = "System prompt"
    user_prompt = "User prompt"
    model = "test-model"
    cached_result = {"data": "cached result"}

    mock_response_cache.get.return_value = cached_result

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )

    # A cache hit must short-circuit the API call and must not rewrite the cache.
    mock_response_cache.get.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model
    )
    mock_openai_client.chat.completions.create.assert_not_called()
    mock_response_cache.set.assert_not_called()
    assert result == cached_result

@pytest.mark.asyncio
async def test_structured_output_completion_cache_miss_success(
    mock_openai_client, mock_response_cache
):
    """Test behavior on cache miss with a successful API call."""
    system_prompt = "System prompt for success"
    user_prompt = "User prompt for success"
    model = "test-model-success"
    expected_result = {"data": "successful API result"}

    mock_response_cache.get.return_value = None

    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = json.dumps(expected_result)
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )

    mock_response_cache.get.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model
    )
    mock_openai_client.chat.completions.create.assert_called_once_with(
        model=model,
        messages=[
            {"role": "system", "content": ANY},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.7,
    )
    mock_response_cache.set.assert_called_once_with(
        f"{system_prompt}:{user_prompt}", model, expected_result
    )
    assert result == expected_result

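# The two tests above fix the cache convention used by
# structured_output_completion: the key is f"{system_prompt}:{user_prompt}"
# and the model name is passed alongside it. A minimal sketch of that lookup,
# illustrative only (the real ResponseCache lives in ankigen_core.utils and
# may differ beyond the get/set signatures asserted here):
def _cache_lookup_sketch(
    cache: ResponseCache, system_prompt: str, user_prompt: str, model: str
):
    key = f"{system_prompt}:{user_prompt}"
    cached = cache.get(key, model)
    if cached is not None:
        return cached  # cache hit: no API call, no cache.set
    result = {"data": "fresh result"}  # placeholder for the API round trip
    cache.set(key, model, result)
    return result
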
@pytest.mark.asyncio
async def test_structured_output_completion_api_error(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the OpenAI API call raises an error."""
    system_prompt = "System prompt for error"
    user_prompt = "User prompt for error"
    model = "test-model-error"
    error_message = "Test API Error"

    mock_response_cache.get.return_value = None
    mock_openai_client.chat.completions.create.side_effect = OpenAIError(error_message)

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model,
            response_format={"type": "json_object"},
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            cache=mock_response_cache,
        )

    # Three retry attempts mean three cache lookups and three API calls.
    assert (
        mock_response_cache.get.call_count == 3
    ), f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
    assert (
        mock_openai_client.chat.completions.create.call_count == 3
    ), f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
    mock_response_cache.set.assert_not_called()

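# This test and test_structured_output_completion_invalid_json below assume
# structured_output_completion retries three times and then raises
# tenacity.RetryError. A minimal sketch of such a policy (the exact tenacity
# arguments used in ankigen_core.llm_interface are an assumption; only the
# attempt count and the RetryError are pinned down by the tests):
@tenacity.retry(stop=tenacity.stop_after_attempt(3), reraise=False)
async def _retrying_call_sketch(make_request):
    # tenacity re-invokes this coroutine function until it succeeds or the
    # third attempt fails, at which point RetryError wraps the last error.
    return await make_request()
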
@pytest.mark.asyncio
async def test_structured_output_completion_invalid_json(
    mock_openai_client, mock_response_cache
):
    """Test behavior when the API returns invalid JSON."""
    system_prompt = "System prompt for invalid json"
    user_prompt = "User prompt for invalid json"
    model = "test-model-invalid-json"
    invalid_json_content = "this is not json"

    mock_response_cache.get.return_value = None

    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = invalid_json_content
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model,
            response_format={"type": "json_object"},
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            cache=mock_response_cache,
        )

    assert (
        mock_response_cache.get.call_count == 3
    ), f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
    assert (
        mock_openai_client.chat.completions.create.call_count == 3
    ), f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
    mock_response_cache.set.assert_not_called()

@pytest.mark.asyncio
async def test_structured_output_completion_no_choices(
    mock_openai_client, mock_response_cache
):
    """Test behavior when API completion has no choices."""
    system_prompt = "System prompt no choices"
    user_prompt = "User prompt no choices"
    model = "test-model-no-choices"

    mock_response_cache.get.return_value = None
    mock_completion = MagicMock()
    mock_completion.choices = []
    mock_openai_client.chat.completions.create.return_value = mock_completion

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )
    assert result is None
    mock_response_cache.set.assert_not_called()

@pytest.mark.asyncio
async def test_structured_output_completion_no_message_content(
    mock_openai_client, mock_response_cache
):
    """Test behavior when API choice has no message content."""
    system_prompt = "System prompt no content"
    user_prompt = "User prompt no content"
    model = "test-model-no-content"

    mock_response_cache.get.return_value = None
    mock_completion = MagicMock()
    mock_message = MagicMock()
    mock_message.content = None
    mock_choice = MagicMock()
    mock_choice.message = mock_message
    mock_completion.choices = [mock_choice]
    mock_openai_client.chat.completions.create.return_value = mock_completion

    result = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model,
        response_format={"type": "json_object"},
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        cache=mock_response_cache,
    )
    assert result is None
    mock_response_cache.set.assert_not_called()

@pytest.fixture
def client_manager():
    """Fixture for the OpenAIClientManager."""
    return OpenAIClientManager()

@pytest.fixture
def sample_crawled_page():
    """Fixture for a sample CrawledPage object."""
    return CrawledPage(
        url="http://example.com",
        html_content="<html><body>This is some test content for the page.</body></html>",
        text_content="This is some test content for the page.",
        title="Test Page",
        meta_description="A test page.",
        meta_keywords=["test", "page"],
        crawl_depth=0,
    )

def create_mock_chat_completion(content: str) -> ChatCompletion:
    """Build a fully typed ChatCompletion carrying a single assistant message."""
    return ChatCompletion(
        id="chatcmpl-test123",
        choices=[
            ChatCompletionChoice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(content=content, role="assistant"),
                logprobs=None,
            )
        ],
        created=1677652288,
        model="gpt-4o",
        object="chat.completion",
        system_fingerprint="fp_test",
        usage=None,
    )

@pytest.mark.asyncio
async def test_process_crawled_page_success(mock_openai_client, sample_crawled_page):
    """Test that a JSON list response yields one card per item, tagged with the source URL."""
    mock_response_content = json.dumps(
        [
            {"front": "Q1", "back": "A1", "tags": ["tag1"]},
            {"front": "Q2", "back": "A2", "tags": ["tag2", "python"]},
        ]
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 2
    assert result_cards[0].front == "Q1"
    assert result_cards[0].source_url == sample_crawled_page.url
    assert result_cards[1].back == "A2"
    # Exactly one LLM call for the single page.
    mock_openai_client.chat.completions.create.assert_awaited_once()

@pytest.mark.asyncio
async def test_process_crawled_page_empty_llm_response_content(
    mock_openai_client, sample_crawled_page
):
    """Test that an empty LLM response yields no cards."""
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion("")
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0

@pytest.mark.asyncio
async def test_process_crawled_page_llm_returns_not_a_list(
    mock_openai_client, sample_crawled_page
):
    """Test that a non-list JSON response yields no cards."""
    mock_response_content = json.dumps({"error": "not a list as expected"})
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0

@pytest.mark.asyncio
async def test_process_crawled_page_llm_returns_dict_with_cards_key(
    mock_openai_client, sample_crawled_page
):
    """Test that a {"cards": [...]} wrapper object is unwrapped into cards."""
    mock_response_content = json.dumps(
        {"cards": [{"front": "Q1", "back": "A1", "tags": []}]}
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 1
    assert result_cards[0].front == "Q1"
    assert result_cards[0].back == "A1"
    assert result_cards[0].source_url == sample_crawled_page.url

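# Taken together, the two tests above assume process_crawled_page accepts
# either a bare JSON list of cards or a {"cards": [...]} wrapper, and
# discards any other top-level shape.
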
@pytest.mark.asyncio
async def test_process_crawled_page_json_decode_error(
    mock_openai_client, sample_crawled_page
):
    """Test that unparseable JSON from the LLM yields no cards."""
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion("this is not valid json")
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)
    assert len(result_cards) == 0

@pytest.mark.asyncio
async def test_process_crawled_page_empty_text_content(mock_openai_client):
    """Test that a page with no text content is skipped without an API call."""
    empty_content_page = CrawledPage(
        url="http://example.com/empty",
        html_content="",
        text_content="",
        title="Empty",
    )
    result_cards = await process_crawled_page(mock_openai_client, empty_content_page)
    assert len(result_cards) == 0
    mock_openai_client.chat.completions.create.assert_not_awaited()

@pytest.mark.asyncio
async def test_process_crawled_page_openai_api_error_retry(
    mock_openai_client, sample_crawled_page, caplog
):
    """Test the first-attempt success path of the retry-capable call."""
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(
            json.dumps([{"front": "Q1", "back": "A1", "tags": []}])
        )
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 1
    assert result_cards[0].front == "Q1"
    assert result_cards[0].back == "A1"
    assert mock_openai_client.chat.completions.create.call_count == 1

@pytest.mark.asyncio
async def test_process_crawled_page_openai_persistent_api_error(
    mock_openai_client, sample_crawled_page, caplog
):
    """Test that a persistent API error is caught, logged, and yields no cards."""
    mock_openai_client.chat.completions.create.side_effect = APIConnectionError(
        request=MagicMock()
    )

    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 0
    assert mock_openai_client.chat.completions.create.await_count == 1
    assert "OpenAI API error while processing page" in caplog.text

@pytest.mark.asyncio
async def test_process_crawled_page_tiktoken_truncation(
    mock_openai_client, sample_crawled_page, monkeypatch
):
    """Test that overlong page content is token-truncated before the API call."""
    long_text = "word " * 8000
    sample_crawled_page.text_content = long_text

    mock_response_content = json.dumps(
        [{"front": "TruncatedQ", "back": "TruncatedA", "tags": []}]
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(mock_response_content)
    )

    mock_encoding = MagicMock()
    # Successive encode() calls: a short first result, then results that
    # exceed the 6000-token budget so the truncation path is exercised.
    mock_encoding.encode.side_effect = [
        list(range(1000)),
        list(range(10000)),
        list(range(10000)),
    ]

    # Record how many tokens each decode() call receives.
    truncated_content = []

    def mock_decode(tokens):
        truncated_content.append(len(tokens))
        return "Truncated content"

    mock_encoding.decode = mock_decode

    mock_get_encoding = MagicMock(return_value=mock_encoding)

    # Force the cl100k_base fallback by making the model-specific lookup fail.
    with patch("tiktoken.get_encoding", mock_get_encoding):
        with patch("tiktoken.encoding_for_model", side_effect=KeyError("test")):
            result_cards = await process_crawled_page(
                mock_openai_client, sample_crawled_page, max_prompt_content_tokens=6000
            )

    assert len(result_cards) == 1
    assert result_cards[0].front == "TruncatedQ"
    assert result_cards[0].back == "TruncatedA"

    mock_get_encoding.assert_called_with("cl100k_base")
    assert mock_encoding.encode.call_count >= 2

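# A minimal sketch of the truncation logic the test above pins down: try the
# model-specific tiktoken encoding, fall back to cl100k_base, and slice the
# token list to the budget before decoding. Names here are illustrative; only
# the fallback encoding and the token-budget behavior are asserted by the test.
def _truncate_to_token_budget(text: str, model: str, max_tokens: int) -> str:
    import tiktoken

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model names raise KeyError; fall back to a generic encoding.
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    if len(tokens) <= max_tokens:
        return text
    return encoding.decode(tokens[:max_tokens])
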
@pytest.mark.asyncio
async def test_process_crawled_pages_success(mock_openai_client, sample_crawled_page):
    """Test that cards from every page are aggregated."""
    pages_to_process = [
        sample_crawled_page,
        CrawledPage(
            url="http://example.com/page2",
            html_content="",
            text_content="Content for page 2",
            title="Page 2",
        ),
    ]

    async def mock_single_page_processor(openai_client, page, model="gpt-4o", **kwargs):
        if page.url == pages_to_process[0].url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        elif page.url == pages_to_process[1].url:
            return [
                AnkiCardData(front="P2Q1", back="P2A1", source_url=page.url),
                AnkiCardData(front="P2Q2", back="P2A2", source_url=page.url),
            ]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=1
        )

    assert len(result_cards) == 3
    assert mock_processor.call_count == 2

@pytest.mark.asyncio
async def test_process_crawled_pages_partial_failure(
    mock_openai_client, sample_crawled_page
):
    """Test that one failing page does not prevent others from producing cards."""
    pages_to_process = [
        sample_crawled_page,
        CrawledPage(
            url="http://example.com/page_fail",
            html_content="",
            text_content="Content for page fail",
            title="Page Fail",
        ),
        CrawledPage(
            url="http://example.com/page3",
            html_content="",
            text_content="Content for page 3",
            title="Page 3",
        ),
    ]

    async def mock_single_page_processor_with_failure(
        openai_client, page, model="gpt-4o", **kwargs
    ):
        if page.url == pages_to_process[0].url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        elif page.url == pages_to_process[1].url:
            raise APIConnectionError(request=MagicMock())
        elif page.url == pages_to_process[2].url:
            return [AnkiCardData(front="P3Q1", back="P3A1", source_url=page.url)]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor_with_failure,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=2
        )

    assert len(result_cards) == 2
    assert mock_processor.call_count == 3

@pytest.mark.asyncio
async def test_process_crawled_pages_progress_callback(
    mock_openai_client, sample_crawled_page
):
    """Test that the progress callback fires once per completed page."""
    pages_to_process = [sample_crawled_page] * 3
    progress_log = []

    def callback(completed_count, total_count):
        progress_log.append((completed_count, total_count))

    async def mock_simple_processor(client, page, model="gpt-4o", **kwargs):
        await asyncio.sleep(0.01)
        return [AnkiCardData(front=f"{page.url}-Q", back="A", source_url=page.url)]

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_simple_processor,
    ):
        await process_crawled_pages(
            mock_openai_client,
            pages_to_process,
            progress_callback=callback,
            max_concurrent_requests=1,
        )

    assert progress_log == [(1, 3), (2, 3), (3, 3)]

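# A minimal sketch of the fan-out pattern the three tests above assume:
# bounded concurrency via an asyncio.Semaphore, per-page error isolation,
# and a progress callback after each completion. Illustrative only; the real
# process_crawled_pages may differ in details not pinned down by the tests.
async def _fan_out_sketch(
    client, pages, process_one, max_concurrent_requests=2, progress_callback=None
):
    semaphore = asyncio.Semaphore(max_concurrent_requests)
    completed = 0

    async def worker(page):
        nonlocal completed
        async with semaphore:
            try:
                cards = await process_one(client, page)
            except Exception:
                cards = []  # a failing page contributes no cards
            completed += 1
            if progress_callback:
                progress_callback(completed, len(pages))
            return cards

    results = []
    for chunk in await asyncio.gather(*(worker(p) for p in pages)):
        results.extend(chunk)
    return results
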
TEST_API_KEY = "sk-testkey1234567890abcdefghijklmnopqrstuvwxyz"

@pytest.mark.asyncio
async def test_process_crawled_page_api_error(
    client_manager, mock_openai_client, sample_crawled_page
):
    """Test handling of API error during LLM call."""
    mock_request = MagicMock()
    mock_openai_client.chat.completions.create.side_effect = APIError(
        message="Test API Error", request=mock_request, body=None
    )

    with patch.object(client_manager, "get_client", return_value=mock_openai_client):
        mock_openai_client.chat.completions.create.reset_mock()

        result_cards = await process_crawled_page(
            mock_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=1000,
        )
        assert len(result_cards) == 0

@pytest.mark.asyncio
async def test_process_crawled_page_content_truncation(
    client_manager, mock_openai_client, sample_crawled_page
):
    """Test content truncation based on max_prompt_content_tokens."""
    long_content_piece = "This is a word. "
    repetitions = 10
    sample_crawled_page.text_content = long_content_piece * repetitions

    with (
        patch.object(client_manager, "get_client", return_value=mock_openai_client),
        patch("tiktoken.encoding_for_model", side_effect=KeyError("test")),
        patch("tiktoken.get_encoding") as mock_get_encoding,
    ):
        mock_encoding = MagicMock()

        # Every encode() call reports 100 tokens, comfortably over the
        # 20-token budget passed below, so truncation is always triggered.
        system_prompt_tokens = list(range(100))
        mock_encoding.encode.return_value = system_prompt_tokens

        mock_get_encoding.return_value = mock_encoding

        mock_openai_client.chat.completions.create.return_value = (
            create_mock_chat_completion(
                json.dumps([{"front": "TestQ", "back": "TestA", "tags": []}])
            )
        )

        result = await process_crawled_page(
            mock_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=20,
        )

        # With no usable content surviving truncation, the page is expected
        # to be skipped: no cards and no API call.
        assert result == []

        mock_get_encoding.assert_called_with("cl100k_base")
        assert mock_encoding.encode.call_count >= 1

        mock_openai_client.chat.completions.create.assert_not_called()

@pytest.mark.asyncio
async def test_openai_client_manager_get_client(
    client_manager, mock_async_openai_client
):
    """Test that get_client returns the AsyncOpenAI client instance and initializes it once."""
    # Ensure a clean manager state.
    client_manager._client = None
    client_manager._api_key = None

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI", return_value=mock_async_openai_client
    ) as mock_constructor:
        await client_manager.initialize_client(TEST_API_KEY)

        client1 = client_manager.get_client()
        client2 = client_manager.get_client()

    assert client1 is mock_async_openai_client
    assert client2 is mock_async_openai_client
    mock_constructor.assert_called_once_with(api_key=TEST_API_KEY)

@pytest.fixture
def mock_async_openai_client():
    """Mock AsyncOpenAI client with a default single-card create() response.

    Defined after its first use in this file; pytest resolves fixtures by
    name at collection time, so the ordering is harmless.
    """
    client = MagicMock(spec=AsyncOpenAI)
    client.chat = AsyncMock()
    client.chat.completions = AsyncMock()
    client.chat.completions.create = AsyncMock()
    mock_process_page_response = create_mock_chat_completion(
        json.dumps([{"front": "Q_Default", "back": "A_Default", "tags": []}])
    )
    client.chat.completions.create.return_value = mock_process_page_response
    return client