Spaces:
Sleeping
Sleeping
| """ | |
| Tests for the image generation application. | |
| This module contains unit tests for the various components of the application. | |
| """ | |
| import unittest | |
| from unittest.mock import MagicMock, patch | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| # Import application modules | |
| from config import MODEL_REPO_ID, MAX_SEED | |
| from model import ModelManager | |
| from utils import save_image, format_generation_info, GenerationHistory | |
class TestConfig(unittest.TestCase):
    """Tests for the configuration module."""

    def test_config_values(self):
        """Configuration constants hold the values the application expects."""
        from config import (
            MODEL_REPO_ID,
            DEFAULT_GUIDANCE_SCALE,
            DEFAULT_INFERENCE_STEPS,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            MAX_IMAGE_SIZE,
            EXAMPLE_PROMPTS,
        )

        self.assertEqual(MODEL_REPO_ID, "stabilityai/sdxl-turbo")
        self.assertEqual(DEFAULT_GUIDANCE_SCALE, 0.0)
        self.assertEqual(DEFAULT_INFERENCE_STEPS, 2)
        self.assertEqual(DEFAULT_WIDTH, 1024)
        self.assertEqual(DEFAULT_HEIGHT, 1024)
        self.assertEqual(MAX_IMAGE_SIZE, 1024)
        self.assertIsInstance(EXAMPLE_PROMPTS, list)
        # assertGreater reports the actual length on failure, whereas
        # assertTrue(len(...) > 0) would only say "False is not true".
        self.assertGreater(len(EXAMPLE_PROMPTS), 0)
| class TestModelManager(unittest.TestCase): | |
| """Test the ModelManager class.""" | |
| def test_init(self, mock_pipeline): | |
| """Test ModelManager initialization.""" | |
| manager = ModelManager() | |
| self.assertIn(manager.device, ["cuda", "cpu"]) | |
| self.assertIsNone(manager.pipe) | |
| def test_load_model(self, mock_from_pretrained): | |
| """Test model loading.""" | |
| # Setup mock | |
| mock_pipe = MagicMock() | |
| mock_from_pretrained.return_value = mock_pipe | |
| mock_pipe.to.return_value = mock_pipe | |
| # Test loading | |
| manager = ModelManager() | |
| manager.load_model() | |
| # Verify calls | |
| mock_from_pretrained.assert_called_once_with( | |
| MODEL_REPO_ID, | |
| torch_dtype=manager.torch_dtype | |
| ) | |
| mock_pipe.to.assert_called_once_with(manager.device) | |
| self.assertEqual(manager.pipe, mock_pipe) | |
| def test_generate_image_with_randomize(self, mock_pipeline): | |
| """Test image generation with randomized seed.""" | |
| # Setup mock | |
| manager = ModelManager() | |
| manager.pipe = MagicMock() | |
| mock_image = MagicMock() | |
| manager.pipe.return_value = MagicMock(images=[mock_image]) | |
| # Test generation with randomized seed | |
| prompt = "test prompt" | |
| image, seed = manager.generate_image( | |
| prompt=prompt, | |
| randomize_seed=True | |
| ) | |
| # Verify result | |
| self.assertEqual(image, mock_image) | |
| self.assertGreaterEqual(seed, 0) | |
| self.assertLessEqual(seed, MAX_SEED) | |
class TestUtils(unittest.TestCase):
    """Tests for the helpers in the utils module."""

    def setUp(self):
        """Create a small in-memory test image and ensure OUTPUTS_DIR exists."""
        self.test_image = Image.new("RGB", (100, 100), color="red")

        from utils import OUTPUTS_DIR

        self.test_outputs_dir = OUTPUTS_DIR
        # parents=True so setUp does not fail when OUTPUTS_DIR's parent is missing.
        self.test_outputs_dir.mkdir(parents=True, exist_ok=True)

    def test_save_image(self):
        """save_image writes a PNG file to disk and returns its path."""
        filepath = save_image(self.test_image, "test image prompt")
        # Register cleanup before asserting so the file is removed even when
        # an assertion fails; missing_ok covers the case where nothing was written.
        self.addCleanup(Path(filepath).unlink, missing_ok=True)

        self.assertTrue(os.path.exists(filepath))
        # str() keeps the check valid whether save_image returns str or Path.
        self.assertTrue(str(filepath).endswith(".png"))

    def test_format_generation_info(self):
        """Every generation parameter appears in the formatted info string."""
        prompt = "test prompt"
        negative_prompt = "test negative"
        seed = 42
        # Width and height differ so the height check cannot pass merely
        # because the width value is present.
        width = 512
        height = 768
        guidance_scale = 7.5
        steps = 30

        info = format_generation_info(
            prompt, negative_prompt, seed, width, height, guidance_scale, steps
        )

        for expected in (
            prompt,
            negative_prompt,
            str(seed),
            str(width),
            str(height),
            str(guidance_scale),
            str(steps),
        ):
            self.assertIn(expected, info)

    def test_generation_history(self):
        """GenerationHistory caps its size, keeps the newest entries, and clears."""
        history = GenerationHistory(max_history=3)

        # A fresh history is empty.
        self.assertEqual(len(history.history), 0)
        self.assertEqual(history.get_latest(), [])

        # Add five entries to a history capped at three.
        for i in range(5):
            history.add(
                self.test_image,
                f"prompt {i}",
                f"negative {i}",
                i,
                512,
                512,
                7.5,
                30,
            )

        # The size is capped at max_history.
        self.assertEqual(len(history.history), 3)

        # The two oldest entries were evicted; the three newest remain.
        kept = {entry["prompt"] for entry in history.get_latest(3)}
        self.assertEqual(kept, {"prompt 2", "prompt 3", "prompt 4"})

        # get_latest(1) returns the most recently added entry.
        latest = history.get_latest(1)[0]
        self.assertEqual(latest["prompt"], "prompt 4")

        history.clear()
        self.assertEqual(len(history.history), 0)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's CLI runner.
    unittest.main()