# Innovideo / test_app.py
# Bossmarc747 · commit 32cd713 ("oj")
"""
Tests for the image generation application.
This module contains unit tests for the various components of the application.
"""
import unittest
from unittest.mock import MagicMock, patch
import os
from pathlib import Path
import numpy as np
from PIL import Image
# Import application modules
from config import MODEL_REPO_ID, MAX_SEED
from model import ModelManager
from utils import save_image, format_generation_info, GenerationHistory
class TestConfig(unittest.TestCase):
"""Test the configuration module."""
def test_config_values(self):
"""Test that configuration values are properly set."""
from config import (
MODEL_REPO_ID,
DEFAULT_GUIDANCE_SCALE,
DEFAULT_INFERENCE_STEPS,
DEFAULT_WIDTH,
DEFAULT_HEIGHT,
MAX_IMAGE_SIZE,
EXAMPLE_PROMPTS
)
self.assertEqual(MODEL_REPO_ID, "stabilityai/sdxl-turbo")
self.assertEqual(DEFAULT_GUIDANCE_SCALE, 0.0)
self.assertEqual(DEFAULT_INFERENCE_STEPS, 2)
self.assertEqual(DEFAULT_WIDTH, 1024)
self.assertEqual(DEFAULT_HEIGHT, 1024)
self.assertEqual(MAX_IMAGE_SIZE, 1024)
self.assertIsInstance(EXAMPLE_PROMPTS, list)
self.assertTrue(len(EXAMPLE_PROMPTS) > 0)
class TestModelManager(unittest.TestCase):
"""Test the ModelManager class."""
@patch('model.DiffusionPipeline')
def test_init(self, mock_pipeline):
"""Test ModelManager initialization."""
manager = ModelManager()
self.assertIn(manager.device, ["cuda", "cpu"])
self.assertIsNone(manager.pipe)
@patch('model.DiffusionPipeline.from_pretrained')
def test_load_model(self, mock_from_pretrained):
"""Test model loading."""
# Setup mock
mock_pipe = MagicMock()
mock_from_pretrained.return_value = mock_pipe
mock_pipe.to.return_value = mock_pipe
# Test loading
manager = ModelManager()
manager.load_model()
# Verify calls
mock_from_pretrained.assert_called_once_with(
MODEL_REPO_ID,
torch_dtype=manager.torch_dtype
)
mock_pipe.to.assert_called_once_with(manager.device)
self.assertEqual(manager.pipe, mock_pipe)
@patch('model.DiffusionPipeline')
def test_generate_image_with_randomize(self, mock_pipeline):
"""Test image generation with randomized seed."""
# Setup mock
manager = ModelManager()
manager.pipe = MagicMock()
mock_image = MagicMock()
manager.pipe.return_value = MagicMock(images=[mock_image])
# Test generation with randomized seed
prompt = "test prompt"
image, seed = manager.generate_image(
prompt=prompt,
randomize_seed=True
)
# Verify result
self.assertEqual(image, mock_image)
self.assertGreaterEqual(seed, 0)
self.assertLessEqual(seed, MAX_SEED)
class TestUtils(unittest.TestCase):
"""Test utility functions."""
def setUp(self):
"""Set up test environment."""
# Create a test image
self.test_image = Image.new('RGB', (100, 100), color='red')
# Ensure test output directory exists
from utils import OUTPUTS_DIR
self.test_outputs_dir = OUTPUTS_DIR
self.test_outputs_dir.mkdir(exist_ok=True)
def test_save_image(self):
"""Test image saving functionality."""
prompt = "test image prompt"
filepath = save_image(self.test_image, prompt)
# Check that file was created
self.assertTrue(os.path.exists(filepath))
self.assertTrue(filepath.endswith(".png"))
# Clean up
os.remove(filepath)
def test_format_generation_info(self):
"""Test generation info formatting."""
prompt = "test prompt"
negative_prompt = "test negative"
seed = 42
width = 512
height = 512
guidance_scale = 7.5
steps = 30
info = format_generation_info(
prompt, negative_prompt, seed, width, height, guidance_scale, steps
)
# Check that all parameters are included in the info string
self.assertIn(prompt, info)
self.assertIn(negative_prompt, info)
self.assertIn(str(seed), info)
self.assertIn(str(width), info)
self.assertIn(str(height), info)
self.assertIn(str(guidance_scale), info)
self.assertIn(str(steps), info)
def test_generation_history(self):
"""Test the GenerationHistory class."""
history = GenerationHistory(max_history=3)
# Test empty history
self.assertEqual(len(history.history), 0)
self.assertEqual(history.get_latest(), [])
# Add entries
for i in range(5):
history.add(
self.test_image,
f"prompt {i}",
f"negative {i}",
i,
512,
512,
7.5,
30
)
# Check that history is limited to max_history
self.assertEqual(len(history.history), 3)
# Check that entries are in correct order (newest last)
latest = history.get_latest(1)[0]
self.assertEqual(latest["prompt"], "prompt 4")
# Test clear
history.clear()
self.assertEqual(len(history.history), 0)
if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()