Spaces:
Running on Zero
Running on Zero
| """ | |
| Unit tests for the HistoryManager module. | |
| Tests cover: | |
| - Seed sanitization edge cases | |
| - History entry normalization | |
| - JSON serialization/deserialization | |
| - Deduplication logic | |
| - Backup rotation | |
| - Search/filter functionality | |
| """ | |
| import os | |
| import json | |
| import tempfile | |
| import shutil | |
| import unittest | |
| from unittest.mock import patch, MagicMock, mock_open | |
| from dataclasses import asdict | |
| import sys | |
# Load HistoryManager.py straight from its file path. A regular package
# import would execute src/__init__.py and drag in its heavy dependencies;
# the project root is still placed on sys.path so HistoryManager's own
# imports can resolve.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")),
)
import importlib.util

spec = importlib.util.spec_from_file_location(
    "HistoryManager",
    os.path.join(
        os.path.dirname(__file__),
        "../../src/FileManaging/HistoryManager.py",
    ),
)
HistoryManagerModule = importlib.util.module_from_spec(spec)
# Register under its name before executing it, following the standard
# importlib "import a source file directly" recipe.
sys.modules["HistoryManager"] = HistoryManagerModule
spec.loader.exec_module(HistoryManagerModule)

# Module-level aliases for the objects under test.
(
    sanitize_seed_for_display,
    HistoryEntry,
    HistoryManager,
    _parse_float_safe,
    _parse_int_safe,
) = (
    HistoryManagerModule.sanitize_seed_for_display,
    HistoryManagerModule.HistoryEntry,
    HistoryManagerModule.HistoryManager,
    HistoryManagerModule._parse_float_safe,
    HistoryManagerModule._parse_int_safe,
)
class TestSanitizeSeedForDisplay(unittest.TestCase):
    """Exercise sanitize_seed_for_display on clean and garbled seed values."""

    def test_none_input(self):
        """A missing seed (None) stays None."""
        outcome = sanitize_seed_for_display(None)
        self.assertIsNone(outcome)

    def test_integer_input(self):
        """Integers are rendered as their decimal string."""
        table = ((12345, "12345"), (0, "0"), (-123, "-123"))
        for raw, wanted in table:
            with self.subTest(raw=raw):
                self.assertEqual(sanitize_seed_for_display(raw), wanted)

    def test_float_input(self):
        """Floats are reduced to the string of their integer part."""
        table = ((12345.0, "12345"), (12345.7, "12345"))
        for raw, wanted in table:
            with self.subTest(raw=raw):
                self.assertEqual(sanitize_seed_for_display(raw), wanted)

    def test_simple_string(self):
        """Plain numeric strings survive, with surrounding whitespace stripped."""
        table = (("12345", "12345"), (" 12345 ", "12345"))
        for raw, wanted in table:
            with self.subTest(raw=raw):
                self.assertEqual(sanitize_seed_for_display(raw), wanted)

    def test_tensor_dump_extraction(self):
        """The number inside a tensor(...) repr is recovered."""
        self.assertEqual(
            sanitize_seed_for_display("tensor(123456789)"), "123456789"
        )

    def test_array_dump_extraction(self):
        """A seed embedded in an array-like dump is recovered."""
        self.assertEqual(
            sanitize_seed_for_display("[1, 2, 3, seed=987654321]"), "987654321"
        )

    def test_very_long_string_extraction(self):
        """An oversized string still yields the seed buried inside it."""
        padded = "".join(("x" * 300, "12345678", "y" * 100))
        self.assertEqual(sanitize_seed_for_display(padded), "12345678")

    def test_multiline_string_extraction(self):
        """A seed line inside multiline text is recovered."""
        text = "\n".join(("line1", "line2", "seed=98765432", "line3"))
        self.assertEqual(sanitize_seed_for_display(text), "98765432")

    def test_no_numeric_in_garbage(self):
        """A dump with no digits at all yields None."""
        self.assertIsNone(sanitize_seed_for_display("tensor(abc, def)"))

    def test_short_numeric_not_extracted(self):
        """Digit runs under four characters are not treated as seeds in dumps."""
        self.assertIsNone(sanitize_seed_for_display("tensor(123)"))

    def test_empty_string(self):
        """Empty or whitespace-only strings yield None."""
        for blank in ("", " "):
            with self.subTest(blank=blank):
                self.assertIsNone(sanitize_seed_for_display(blank))
class TestParseFunctions(unittest.TestCase):
    """Checks for the lenient numeric parsers _parse_float_safe and _parse_int_safe."""

    def _check_pairs(self, parser, pairs):
        """Assert parser(raw) == expected for every (raw, expected) pair."""
        for raw, expected in pairs:
            with self.subTest(parser=parser.__name__, raw=raw):
                self.assertEqual(parser(raw), expected)

    def _check_none(self, parser, raws):
        """Assert parser(raw) is None for every raw value."""
        for raw in raws:
            with self.subTest(parser=parser.__name__, raw=raw):
                self.assertIsNone(parser(raw))

    def test_parse_float_safe_valid(self):
        """Floats, numeric strings and ints all come back as floats."""
        self._check_pairs(
            _parse_float_safe, [(3.14, 3.14), ("3.14", 3.14), (42, 42.0)]
        )

    def test_parse_float_safe_with_suffix(self):
        """A trailing 's' suffix, as in '10.5s', is tolerated."""
        self._check_pairs(_parse_float_safe, [("10.5s", 10.5), ("3.2s", 3.2)])

    def test_parse_float_safe_invalid(self):
        """Unparseable inputs produce None rather than raising."""
        self._check_none(_parse_float_safe, [None, "abc", [1, 2, 3]])

    def test_parse_int_safe_valid(self):
        """Ints, numeric strings and floats (truncated) come back as ints."""
        self._check_pairs(_parse_int_safe, [(42, 42), ("42", 42), (3.7, 3)])

    def test_parse_int_safe_invalid(self):
        """Unparseable inputs produce None rather than raising."""
        self._check_none(_parse_int_safe, [None, "abc"])
class TestHistoryEntry(unittest.TestCase):
    """Serialization and default-value checks for the HistoryEntry dataclass."""

    def test_to_dict(self):
        """to_dict exposes the entry's fields under their own names."""
        as_dict = HistoryEntry(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test prompt",
            seed="12345",
        ).to_dict()
        expected = {
            "timestamp": "2024-01-01 12:00:00",
            "prompt": "test prompt",
            "seed": "12345",
        }
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertEqual(as_dict[key], value)

    def test_from_dict(self):
        """from_dict builds an entry even when unknown keys are present."""
        payload = dict(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test prompt",
            extra_field="ignored",
        )
        restored = HistoryEntry.from_dict(payload)
        self.assertEqual(restored.timestamp, "2024-01-01 12:00:00")
        self.assertEqual(restored.prompt, "test prompt")

    def test_default_values(self):
        """Optional fields fall back to empty / None defaults."""
        bare = HistoryEntry(timestamp="", image_path="")
        self.assertEqual(bare.prompt, "")
        self.assertIsNone(bare.width)
        self.assertEqual(bare.png_metadata, {})
class TestHistoryManager(unittest.TestCase):
    """Test HistoryManager class."""

    def setUp(self):
        """Give every test its own history file and backup directory.

        BACKUP_DIR is redirected into the temporary directory, as the backup
        rotation tests already do. save() writes rotating backups into
        BACKUP_DIR, so without this patch the tests here that save more than
        once could write into, and possibly read from, the real backup
        directory.
        """
        self.temp_dir = tempfile.mkdtemp()
        # Cleanups run even if a later setUp step fails (tearDown would not).
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")
        self.backup_dir = os.path.join(self.temp_dir, ".history_backups")
        backup_dir_patcher = patch.object(
            HistoryManagerModule, "BACKUP_DIR", self.backup_dir
        )
        backup_dir_patcher.start()
        self.addCleanup(backup_dir_patcher.stop)
        self.manager = HistoryManager(history_file=self.history_file)

    def test_load_empty_history(self):
        """Loading non-existent file should return empty list."""
        entries = self.manager.load()
        self.assertEqual(entries, [])

    def test_save_and_load(self):
        """Saved entries should be loadable."""
        entry = HistoryEntry(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test",
        )
        self.manager.save([entry])
        # Bypass the cache so the data really comes from disk.
        loaded = self.manager.load(use_cache=False)
        self.assertEqual(len(loaded), 1)
        self.assertEqual(loaded[0].prompt, "test")

    def test_deduplication_by_path(self):
        """Duplicate image paths should be deduplicated."""
        data = [
            {"timestamp": "2024-01-01", "image_path": "/path/a.png", "prompt": "first"},
            {"timestamp": "2024-01-02", "image_path": "/path/a.png", "prompt": "duplicate"},
            {"timestamp": "2024-01-03", "image_path": "/path/b.png", "prompt": "second"},
        ]
        with open(self.history_file, "w") as f:
            json.dump(data, f)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 2)
        paths = [e.image_path for e in entries]
        self.assertIn("/path/a.png", paths)
        self.assertIn("/path/b.png", paths)

    def test_max_entries_limit(self):
        """History should be limited to MAX_HISTORY_ENTRIES."""
        data = [
            {"timestamp": f"2024-01-{i:02d}", "image_path": f"/path/{i}.png"}
            for i in range(150)
        ]
        with open(self.history_file, "w") as f:
            json.dump(data, f)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 100)

    def test_cache_behavior(self):
        """A cached load of an unchanged file returns the same entries."""
        entry = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        self.manager.save([entry])
        # First load populates the cache.
        first = self.manager.load()
        cached_mtime = self.manager._cache_mtime
        # The file is untouched since, so the cached view must still match.
        second = self.manager.load(use_cache=True)
        self.assertEqual(len(second), 1)
        self.assertEqual(
            [e.image_path for e in second], [e.image_path for e in first]
        )
        self.assertEqual(self.manager._cache_mtime, cached_mtime)

    def test_add_entry(self):
        """Adding an entry should prepend to history."""
        entry1 = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/path/b.png")
        self.manager.save([entry1])
        self.manager.add_entry(entry2)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 2)
        self.assertEqual(entries[0].image_path, "/path/b.png")

    def test_delete_entry(self):
        """Deleting an entry should remove it from history."""
        entry1 = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/path/b.png")
        self.manager.save([entry1, entry2])
        self.manager.delete_entry(0)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 1)
        self.assertEqual(entries[0].image_path, "/path/b.png")

    def test_clear_history(self):
        """Clearing should remove all entries."""
        entries = [
            HistoryEntry(timestamp=f"2024-01-{i:02d}", image_path=f"/path/{i}.png")
            for i in range(5)
        ]
        self.manager.save(entries)
        self.manager.clear(delete_files=False)
        loaded = self.manager.load(use_cache=False)
        self.assertEqual(len(loaded), 0)

    def test_search_by_keyword(self):
        """Search should filter by keyword in prompt."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", prompt="cat photo"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", prompt="dog photo"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", prompt="cat and dog"),
        ]
        self.manager.save(entries)
        results = self.manager.search(keyword="cat")
        self.assertEqual(len(results), 2)
        results = self.manager.search(keyword="dog")
        self.assertEqual(len(results), 2)
        results = self.manager.search(keyword="bird")
        self.assertEqual(len(results), 0)

    def test_search_by_model_type(self):
        """Search should filter by model type."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", model_type="SDXL"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", model_type="SD15"),
        ]
        self.manager.save(entries)
        results = self.manager.search(model_type="SD15")
        self.assertEqual(len(results), 2)
        results = self.manager.search(model_type="SDXL")
        self.assertEqual(len(results), 1)

    def test_search_by_date_range(self):
        """Search should filter by date range."""
        entries = [
            HistoryEntry(timestamp="2024-01-15 10:00:00", image_path="/a.png"),
            HistoryEntry(timestamp="2024-02-15 10:00:00", image_path="/b.png"),
            HistoryEntry(timestamp="2024-03-15 10:00:00", image_path="/c.png"),
        ]
        self.manager.save(entries)
        results = self.manager.search(date_from="2024-02-01")
        self.assertEqual(len(results), 2)
        results = self.manager.search(date_to="2024-02-01")
        self.assertEqual(len(results), 1)
        results = self.manager.search(date_from="2024-02-01", date_to="2024-02-28")
        self.assertEqual(len(results), 1)

    def test_get_model_types(self):
        """Should return unique model types."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", model_type="SDXL"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-04", image_path="/d.png"),
        ]
        self.manager.save(entries)
        types = self.manager.get_model_types()
        self.assertEqual(sorted(types), ["SD15", "SDXL"])
class TestBackupRotation(unittest.TestCase):
    """Test backup rotation functionality."""

    def setUp(self):
        """Create a temp history file and redirect BACKUP_DIR next to it.

        Cleanups are registered with addCleanup immediately after each
        resource is acquired. If a later setUp step raises (for example the
        HistoryManager constructor), tearDown is skipped, but cleanups still
        run. The BACKUP_DIR patch therefore cannot leak into other tests.
        """
        self.temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")
        self.backup_dir = os.path.join(self.temp_dir, ".history_backups")
        # Patch BACKUP_DIR using the imported module reference
        self.backup_dir_patcher = patch.object(
            HistoryManagerModule, "BACKUP_DIR", self.backup_dir
        )
        self.backup_dir_patcher.start()
        self.addCleanup(self.backup_dir_patcher.stop)
        self.manager = HistoryManager(history_file=self.history_file)

    def test_backup_created_on_save(self):
        """Backup should be created when saving."""
        # First save writes the history file itself.
        entry = HistoryEntry(timestamp="2024-01-01", image_path="/a.png")
        self.manager.save([entry])
        # Second save should create backup
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/b.png")
        self.manager.save([entry, entry2])
        backup_file = os.path.join(self.backup_dir, "history_backup_1.json")
        self.assertTrue(os.path.exists(backup_file))

    def test_restore_from_backup(self):
        """Should restore from backup when main file is corrupt."""
        # Create valid backup
        os.makedirs(self.backup_dir, exist_ok=True)
        backup_data = [{"timestamp": "2024-01-01", "image_path": "/backup.png"}]
        backup_file = os.path.join(self.backup_dir, "history_backup_1.json")
        with open(backup_file, "w") as f:
            json.dump(backup_data, f)
        # Create corrupt main file
        with open(self.history_file, "w") as f:
            f.write("not valid json {{{")
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 1)
        self.assertEqual(entries[0].image_path, "/backup.png")
# Allow running this file directly as a script, in addition to test discovery.
if __name__ == "__main__":
    unittest.main()