"""Unit tests for the storage backends and the backend factory function."""

import json
import os
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch

import pytest

from storage import (
    HuggingFaceStorage,
    LocalStorage,
    get_storage_backend,
)


def _create_repo_not_found_error():
    """Build a RepositoryNotFoundError backed by a mocked 404 response.

    The exception is constructed with a ``response`` keyword, so a MagicMock
    carrying a status code and an empty header mapping stands in for it.
    """
    from huggingface_hub.utils import RepositoryNotFoundError

    fake_response = MagicMock(status_code=404, headers={})
    return RepositoryNotFoundError("Not found", response=fake_response)


def _sample_embeddings():
    """Return a fresh single-entry payload: ``"alice"`` mapped to 192 floats."""
    return {"alice": [0.1] * 192}


def _hf_storage():
    """Return a HuggingFaceStorage for ``user/repo`` with a dummy token."""
    return HuggingFaceStorage(repo_id="user/repo", token="token")


class TestLocalStorage:
    """Behaviour of the LocalStorage backend with the filesystem mocked out."""

    def test_load_file_exists(self):
        """load() returns the decoded JSON content when the file is present."""
        expected = _sample_embeddings()
        backend = LocalStorage()
        fake_open = mock_open(read_data=json.dumps(expected))

        with patch.object(Path, "exists", return_value=True), patch(
            "builtins.open", fake_open
        ):
            loaded = backend.load()

        assert loaded == expected

    def test_load_file_not_exists(self):
        """load() returns an empty dict when the file is absent."""
        backend = LocalStorage()

        with patch.object(Path, "exists", return_value=False):
            loaded = backend.load()

        assert loaded == {}

    def test_save(self):
        """save() creates the directory and opens the target file once."""
        payload = _sample_embeddings()
        backend = LocalStorage()

        with patch.object(Path, "mkdir") as mkdir_spy, patch(
            "builtins.open", mock_open()
        ) as open_spy:
            backend.save(payload)

        mkdir_spy.assert_called_once_with(exist_ok=True)
        open_spy.assert_called_once()


class TestHuggingFaceStorage:
    """Behaviour of the HuggingFaceStorage backend with the Hub mocked out."""

    def test_init_with_token(self):
        """An explicitly passed token and repo id are stored on the instance."""
        backend = HuggingFaceStorage(repo_id="user/repo", token="test_token")

        assert backend.repo_id == "user/repo"
        assert backend.token == "test_token"

    def test_init_with_env_token(self):
        """Without an explicit token, HF_TOKEN from the environment is used."""
        with patch.dict(os.environ, {"HF_TOKEN": "env_token"}):
            backend = HuggingFaceStorage(repo_id="user/repo")

        assert backend.token == "env_token"

    def test_load_success(self):
        """load() returns the JSON content of the downloaded file."""
        expected = _sample_embeddings()
        backend = _hf_storage()
        fake_open = mock_open(read_data=json.dumps(expected))

        with patch(
            "storage.hf_hub_download", return_value="/tmp/embeddings.json"
        ), patch("builtins.open", fake_open):
            loaded = backend.load()

        assert loaded == expected

    def test_load_entry_not_found(self):
        """load() returns an empty dict when the download raises EntryNotFoundError."""
        from huggingface_hub.utils import EntryNotFoundError

        backend = _hf_storage()

        with patch(
            "storage.hf_hub_download",
            side_effect=EntryNotFoundError("Not found"),
        ):
            loaded = backend.load()

        assert loaded == {}

    def test_load_repo_not_found(self):
        """load() returns an empty dict when the download raises RepositoryNotFoundError."""
        backend = _hf_storage()

        with patch("storage.hf_hub_download") as download_stub:
            download_stub.side_effect = _create_repo_not_found_error()
            loaded = backend.load()

        assert loaded == {}

    def test_load_fallback_to_local_cache(self):
        """A generic download failure falls back to the existing local file."""
        expected = _sample_embeddings()
        backend = _hf_storage()
        fake_open = mock_open(read_data=json.dumps(expected))

        with patch(
            "storage.hf_hub_download", side_effect=Exception("Network error")
        ), patch.object(Path, "exists", return_value=True), patch(
            "builtins.open", fake_open
        ):
            loaded = backend.load()

        assert loaded == expected

    def test_load_fallback_no_cache(self):
        """A generic download failure with no local file yields an empty dict."""
        backend = _hf_storage()

        with patch(
            "storage.hf_hub_download", side_effect=Exception("Network error")
        ), patch.object(Path, "exists", return_value=False):
            loaded = backend.load()

        assert loaded == {}

    def test_save_success(self):
        """save() uploads the file exactly once."""
        payload = _sample_embeddings()
        backend = _hf_storage()

        with patch.object(Path, "mkdir"), patch(
            "builtins.open", mock_open()
        ), patch.object(backend.api, "upload_file") as upload_spy:
            backend.save(payload)

        upload_spy.assert_called_once()

    def test_save_creates_repo_if_not_exists(self):
        """save() creates the repo once when the first upload reports it missing."""
        payload = _sample_embeddings()
        backend = _hf_storage()

        with patch.object(Path, "mkdir"), patch(
            "builtins.open", mock_open()
        ), patch.object(backend.api, "upload_file") as upload_spy:
            # First upload attempt fails with "repo not found"; the retry succeeds.
            upload_spy.side_effect = [
                _create_repo_not_found_error(),
                None,
            ]
            with patch.object(backend.api, "create_repo") as create_spy:
                backend.save(payload)

        create_spy.assert_called_once()

    def test_save_raises_on_error(self):
        """save() propagates an upload error that is not a missing repository."""
        payload = _sample_embeddings()
        backend = _hf_storage()

        with patch.object(Path, "mkdir"), patch(
            "builtins.open", mock_open()
        ), patch.object(
            backend.api, "upload_file", side_effect=Exception("Persistent error")
        ):
            with pytest.raises(Exception, match="Persistent error"):
                backend.save(payload)


class TestGetStorageBackend:
    """Selection logic of the get_storage_backend factory."""

    def test_returns_local_storage_by_default(self):
        """With an empty environment the factory returns a LocalStorage."""
        with patch.dict(os.environ, {}, clear=True):
            # clear=True already empties the environment; the explicit pop is
            # kept as a guard that HF_EMBEDDINGS_REPO is definitely absent.
            os.environ.pop("HF_EMBEDDINGS_REPO", None)
            chosen = get_storage_backend()

        assert isinstance(chosen, LocalStorage)

    def test_returns_hf_storage_when_configured(self):
        """With HF_EMBEDDINGS_REPO set the factory returns a HuggingFaceStorage."""
        with patch.dict(os.environ, {"HF_EMBEDDINGS_REPO": "user/repo"}):
            chosen = get_storage_backend()

        assert isinstance(chosen, HuggingFaceStorage)
        assert chosen.repo_id == "user/repo"