# Source: LightDiffusion-Next/tests/unit/test_history.py
# Committed by Aatricks in b701455: "Deploy ZeroGPU Gradio Space snapshot"
"""
Unit tests for the HistoryManager module.
Tests cover:
- Seed sanitization edge cases
- History entry normalization
- JSON serialization/deserialization
- Deduplication logic
- Backup rotation
- Search/filter functionality
"""
import os
import json
import tempfile
import shutil
import unittest
from unittest.mock import patch, MagicMock, mock_open
from dataclasses import asdict
import sys
# Load HistoryManager.py straight from its file path rather than through the
# ``src`` package, so importing it does not execute the heavy dependencies
# pulled in by ``src/__init__.py``.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
import importlib.util

_HISTORY_MANAGER_SOURCE = os.path.join(
    os.path.dirname(__file__), "../../src/FileManaging/HistoryManager.py"
)
spec = importlib.util.spec_from_file_location("HistoryManager", _HISTORY_MANAGER_SOURCE)
HistoryManagerModule = importlib.util.module_from_spec(spec)
# Register the module under its name *before* executing it, so anything that
# resolves the module through sys.modules while it is being imported (e.g.
# dataclass processing of HistoryEntry) can find it.
sys.modules[spec.name] = HistoryManagerModule
spec.loader.exec_module(HistoryManagerModule)

# Bind the objects under test to short module-level names.
(
    sanitize_seed_for_display,
    HistoryEntry,
    HistoryManager,
    _parse_float_safe,
    _parse_int_safe,
) = (
    HistoryManagerModule.sanitize_seed_for_display,
    HistoryManagerModule.HistoryEntry,
    HistoryManagerModule.HistoryManager,
    HistoryManagerModule._parse_float_safe,
    HistoryManagerModule._parse_int_safe,
)
class TestSanitizeSeedForDisplay(unittest.TestCase):
    """Edge cases for ``sanitize_seed_for_display``."""

    def _check(self, raw, expected):
        """Assert that *raw* sanitizes to *expected*; ``None`` means rejected."""
        with self.subTest(raw=raw):
            actual = sanitize_seed_for_display(raw)
            if expected is None:
                self.assertIsNone(actual)
            else:
                self.assertEqual(actual, expected)

    def test_none_input(self):
        """A missing seed stays missing."""
        self._check(None, None)

    def test_integer_input(self):
        """Integers are rendered as their decimal string."""
        for value, text in ((12345, "12345"), (0, "0"), (-123, "-123")):
            self._check(value, text)

    def test_float_input(self):
        """Floats are truncated to their integer part before rendering."""
        for value in (12345.0, 12345.7):
            self._check(value, "12345")

    def test_simple_string(self):
        """Plain digit strings survive, minus surrounding whitespace."""
        for raw in ("12345", " 12345 "):
            self._check(raw, "12345")

    def test_tensor_dump_extraction(self):
        """The number inside a ``tensor(...)`` dump is recovered."""
        self._check("tensor(123456789)", "123456789")

    def test_array_dump_extraction(self):
        """A seed embedded in a list-like dump is recovered."""
        self._check("[1, 2, 3, seed=987654321]", "987654321")

    def test_very_long_string_extraction(self):
        """A seed buried inside an overly long string is recovered."""
        padded = "".join(("x" * 300, "12345678", "y" * 100))
        self._check(padded, "12345678")

    def test_multiline_string_extraction(self):
        """A seed on one line of a multi-line dump is recovered."""
        text = "\n".join(["line1", "line2", "seed=98765432", "line3"])
        self._check(text, "98765432")

    def test_no_numeric_in_garbage(self):
        """A dump containing no digits at all yields ``None``."""
        self._check("tensor(abc, def)", None)

    def test_short_numeric_not_extracted(self):
        """Digit runs shorter than 4 characters are not taken from dumps."""
        self._check("tensor(123)", None)

    def test_empty_string(self):
        """Empty and whitespace-only strings yield ``None``."""
        for raw in ("", " "):
            self._check(raw, None)
class TestParseFunctions(unittest.TestCase):
    """Checks for the lenient ``_parse_float_safe`` / ``_parse_int_safe`` helpers."""

    def test_parse_float_safe_valid(self):
        """Numbers and numeric strings convert to ``float``."""
        cases = [(3.14, 3.14), ("3.14", 3.14), (42, 42.0)]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(_parse_float_safe(raw), expected)

    def test_parse_float_safe_with_suffix(self):
        """A trailing seconds suffix (e.g. ``'10.5s'``) is tolerated."""
        cases = [("10.5s", 10.5), ("3.2s", 3.2)]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(_parse_float_safe(raw), expected)

    def test_parse_float_safe_invalid(self):
        """Unparseable values produce ``None`` instead of raising."""
        for raw in (None, "abc", [1, 2, 3]):
            with self.subTest(raw=raw):
                self.assertIsNone(_parse_float_safe(raw))

    def test_parse_int_safe_valid(self):
        """Ints, digit strings and floats (truncated) convert to ``int``."""
        cases = [(42, 42), ("42", 42), (3.7, 3)]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(_parse_int_safe(raw), expected)

    def test_parse_int_safe_invalid(self):
        """Unparseable values produce ``None`` instead of raising."""
        for raw in (None, "abc"):
            with self.subTest(raw=raw):
                self.assertIsNone(_parse_int_safe(raw))
class TestHistoryEntry(unittest.TestCase):
    """Serialization and defaults of the ``HistoryEntry`` dataclass."""

    def test_to_dict(self):
        """``to_dict`` exposes field values under their field names."""
        as_dict = HistoryEntry(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test prompt",
            seed="12345",
        ).to_dict()
        expected = {
            "timestamp": "2024-01-01 12:00:00",
            "prompt": "test prompt",
            "seed": "12345",
        }
        for key, value in expected.items():
            with self.subTest(key=key):
                self.assertEqual(as_dict[key], value)

    def test_from_dict(self):
        """``from_dict`` builds an entry and tolerates unknown keys."""
        payload = dict(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test prompt",
            extra_field="ignored",
        )
        restored = HistoryEntry.from_dict(payload)
        self.assertEqual(restored.timestamp, "2024-01-01 12:00:00")
        self.assertEqual(restored.prompt, "test prompt")

    def test_default_values(self):
        """Optional fields fall back to empty/None defaults."""
        blank = HistoryEntry(timestamp="", image_path="")
        self.assertEqual(blank.prompt, "")
        self.assertIsNone(blank.width)
        self.assertEqual(blank.png_metadata, {})
class TestHistoryManager(unittest.TestCase):
    """Test HistoryManager class."""

    def setUp(self):
        """Create an isolated history file and backup directory for each test."""
        self.temp_dir = tempfile.mkdtemp()
        # Registered immediately so the temp dir is removed even if the rest of
        # setUp raises (tearDown would not run in that case).
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")
        # Saving over an existing history file writes a backup into BACKUP_DIR.
        # Redirect it into the temp dir so these tests never touch (or rotate
        # out) backups in the real backup directory.
        backup_patcher = patch.object(
            HistoryManagerModule, "BACKUP_DIR",
            os.path.join(self.temp_dir, ".history_backups"),
        )
        backup_patcher.start()
        self.addCleanup(backup_patcher.stop)
        self.manager = HistoryManager(history_file=self.history_file)

    def test_load_empty_history(self):
        """Loading non-existent file should return empty list."""
        entries = self.manager.load()
        self.assertEqual(entries, [])

    def test_save_and_load(self):
        """Saved entries should be loadable."""
        entry = HistoryEntry(
            timestamp="2024-01-01 12:00:00",
            image_path="/path/to/image.png",
            prompt="test"
        )
        self.manager.save([entry])
        # Force cache invalidation
        loaded = self.manager.load(use_cache=False)
        self.assertEqual(len(loaded), 1)
        self.assertEqual(loaded[0].prompt, "test")

    def test_deduplication_by_path(self):
        """Duplicate image paths should be deduplicated."""
        data = [
            {"timestamp": "2024-01-01", "image_path": "/path/a.png", "prompt": "first"},
            {"timestamp": "2024-01-02", "image_path": "/path/a.png", "prompt": "duplicate"},
            {"timestamp": "2024-01-03", "image_path": "/path/b.png", "prompt": "second"},
        ]
        with open(self.history_file, "w") as f:
            json.dump(data, f)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 2)
        paths = [e.image_path for e in entries]
        self.assertIn("/path/a.png", paths)
        self.assertIn("/path/b.png", paths)

    def test_max_entries_limit(self):
        """History should be limited to MAX_HISTORY_ENTRIES."""
        data = [
            {"timestamp": f"2024-01-{i:02d}", "image_path": f"/path/{i}.png"}
            for i in range(150)
        ]
        with open(self.history_file, "w") as f:
            json.dump(data, f)
        entries = self.manager.load(use_cache=False)
        # NOTE(review): assumes MAX_HISTORY_ENTRIES == 100; update if it changes.
        self.assertEqual(len(entries), 100)

    def test_cache_behavior(self):
        """A cached load must not re-read the file while it looks unchanged."""
        entry = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        self.manager.save([entry])
        # First load populates the cache.
        self.assertEqual(len(self.manager.load()), 1)
        # Rewrite the file behind the manager's back. Keep the byte length and
        # restore the exact timestamps so that only the content differs.
        # Binary mode avoids newline translation changing the size.
        stat_before = os.stat(self.history_file)
        with open(self.history_file, "rb") as f:
            raw = f.read()
        with open(self.history_file, "wb") as f:
            f.write(raw.replace(b"a.png", b"z.png"))
        os.utime(
            self.history_file,
            ns=(stat_before.st_atime_ns, stat_before.st_mtime_ns),
        )
        # A cached load must serve the stale in-memory copy ...
        cached = self.manager.load(use_cache=True)
        self.assertEqual(len(cached), 1)
        self.assertTrue(cached[0].image_path.endswith("a.png"))
        # ... while bypassing the cache reads the rewritten file.
        fresh = self.manager.load(use_cache=False)
        self.assertEqual(len(fresh), 1)
        self.assertTrue(fresh[0].image_path.endswith("z.png"))

    def test_add_entry(self):
        """Adding an entry should prepend to history."""
        entry1 = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/path/b.png")
        self.manager.save([entry1])
        self.manager.add_entry(entry2)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 2)
        self.assertEqual(entries[0].image_path, "/path/b.png")

    def test_delete_entry(self):
        """Deleting an entry should remove it from history."""
        entry1 = HistoryEntry(timestamp="2024-01-01", image_path="/path/a.png")
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/path/b.png")
        self.manager.save([entry1, entry2])
        self.manager.delete_entry(0)
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 1)
        self.assertEqual(entries[0].image_path, "/path/b.png")

    def test_clear_history(self):
        """Clearing should remove all entries."""
        entries = [
            HistoryEntry(timestamp=f"2024-01-{i:02d}", image_path=f"/path/{i}.png")
            for i in range(5)
        ]
        self.manager.save(entries)
        self.manager.clear(delete_files=False)
        loaded = self.manager.load(use_cache=False)
        self.assertEqual(len(loaded), 0)

    def test_search_by_keyword(self):
        """Search should filter by keyword in prompt."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", prompt="cat photo"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", prompt="dog photo"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", prompt="cat and dog"),
        ]
        self.manager.save(entries)
        results = self.manager.search(keyword="cat")
        self.assertEqual(len(results), 2)
        results = self.manager.search(keyword="dog")
        self.assertEqual(len(results), 2)
        results = self.manager.search(keyword="bird")
        self.assertEqual(len(results), 0)

    def test_search_by_model_type(self):
        """Search should filter by model type."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", model_type="SDXL"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", model_type="SD15"),
        ]
        self.manager.save(entries)
        results = self.manager.search(model_type="SD15")
        self.assertEqual(len(results), 2)
        results = self.manager.search(model_type="SDXL")
        self.assertEqual(len(results), 1)

    def test_search_by_date_range(self):
        """Search should filter by date range."""
        entries = [
            HistoryEntry(timestamp="2024-01-15 10:00:00", image_path="/a.png"),
            HistoryEntry(timestamp="2024-02-15 10:00:00", image_path="/b.png"),
            HistoryEntry(timestamp="2024-03-15 10:00:00", image_path="/c.png"),
        ]
        self.manager.save(entries)
        results = self.manager.search(date_from="2024-02-01")
        self.assertEqual(len(results), 2)
        results = self.manager.search(date_to="2024-02-01")
        self.assertEqual(len(results), 1)
        results = self.manager.search(date_from="2024-02-01", date_to="2024-02-28")
        self.assertEqual(len(results), 1)

    def test_get_model_types(self):
        """Should return unique model types."""
        entries = [
            HistoryEntry(timestamp="2024-01-01", image_path="/a.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-02", image_path="/b.png", model_type="SDXL"),
            HistoryEntry(timestamp="2024-01-03", image_path="/c.png", model_type="SD15"),
            HistoryEntry(timestamp="2024-01-04", image_path="/d.png"),
        ]
        self.manager.save(entries)
        types = self.manager.get_model_types()
        self.assertEqual(sorted(types), ["SD15", "SDXL"])
class TestBackupRotation(unittest.TestCase):
    """Test backup rotation functionality."""

    def setUp(self):
        """Create a temp history file and redirect BACKUP_DIR into the temp dir."""
        self.temp_dir = tempfile.mkdtemp()
        # Registered immediately so the temp dir is removed even if the rest of
        # setUp raises (tearDown would not run in that case).
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.history_file = os.path.join(self.temp_dir, "test_history.json")
        self.backup_dir = os.path.join(self.temp_dir, ".history_backups")
        # Patch BACKUP_DIR using the imported module reference
        self.backup_dir_patcher = patch.object(
            HistoryManagerModule, 'BACKUP_DIR',
            self.backup_dir
        )
        self.backup_dir_patcher.start()
        # addCleanup (not tearDown) guarantees the patch is undone even if
        # constructing the manager below raises, so it cannot leak into
        # other test classes.
        self.addCleanup(self.backup_dir_patcher.stop)
        self.manager = HistoryManager(history_file=self.history_file)

    def test_backup_created_on_save(self):
        """Backup should be created when saving."""
        # First save
        entry = HistoryEntry(timestamp="2024-01-01", image_path="/a.png")
        self.manager.save([entry])
        # Second save should create backup
        entry2 = HistoryEntry(timestamp="2024-01-02", image_path="/b.png")
        self.manager.save([entry, entry2])
        backup_file = os.path.join(self.backup_dir, "history_backup_1.json")
        self.assertTrue(os.path.exists(backup_file))

    def test_restore_from_backup(self):
        """Should restore from backup when main file is corrupt."""
        # Create valid backup
        os.makedirs(self.backup_dir, exist_ok=True)
        backup_data = [{"timestamp": "2024-01-01", "image_path": "/backup.png"}]
        backup_file = os.path.join(self.backup_dir, "history_backup_1.json")
        with open(backup_file, "w") as f:
            json.dump(backup_data, f)
        # Create corrupt main file
        with open(self.history_file, "w") as f:
            f.write("not valid json {{{")
        entries = self.manager.load(use_cache=False)
        self.assertEqual(len(entries), 1)
        self.assertEqual(entries[0].image_path, "/backup.png")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()