|
|
""" |
|
|
A test script to test the gRPC service and dynamic loader |
|
|
""" |
|
|
import unittest |
|
|
import subprocess |
|
|
import time |
|
|
from unittest.mock import patch, MagicMock |
|
|
|
|
|
|
|
|
import diffusers_dynamic_loader as loader |
|
|
from diffusers import DiffusionPipeline, StableDiffusionPipeline |
|
|
|
|
|
|
|
|
try: |
|
|
import grpc |
|
|
import backend_pb2 |
|
|
import backend_pb2_grpc |
|
|
GRPC_AVAILABLE = True |
|
|
except ImportError: |
|
|
GRPC_AVAILABLE = False |
|
|
|
|
|
|
|
|
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available") |
|
|
class TestBackendServicer(unittest.TestCase): |
|
|
""" |
|
|
TestBackendServicer is the class that tests the gRPC service |
|
|
""" |
|
|
def setUp(self): |
|
|
""" |
|
|
This method sets up the gRPC service by starting the server |
|
|
""" |
|
|
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) |
|
|
|
|
|
def tearDown(self) -> None: |
|
|
""" |
|
|
This method tears down the gRPC service by terminating the server |
|
|
""" |
|
|
self.service.kill() |
|
|
self.service.wait() |
|
|
|
|
|
def test_server_startup(self): |
|
|
""" |
|
|
This method tests if the server starts up successfully |
|
|
""" |
|
|
time.sleep(20) |
|
|
try: |
|
|
self.setUp() |
|
|
with grpc.insecure_channel("localhost:50051") as channel: |
|
|
stub = backend_pb2_grpc.BackendStub(channel) |
|
|
response = stub.Health(backend_pb2.HealthMessage()) |
|
|
self.assertEqual(response.message, b'OK') |
|
|
except Exception as err: |
|
|
print(err) |
|
|
self.fail("Server failed to start") |
|
|
finally: |
|
|
self.tearDown() |
|
|
|
|
|
def test_load_model(self): |
|
|
""" |
|
|
This method tests if the model is loaded successfully |
|
|
""" |
|
|
time.sleep(20) |
|
|
try: |
|
|
self.setUp() |
|
|
with grpc.insecure_channel("localhost:50051") as channel: |
|
|
stub = backend_pb2_grpc.BackendStub(channel) |
|
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8")) |
|
|
self.assertTrue(response.success) |
|
|
self.assertEqual(response.message, "Model loaded successfully") |
|
|
except Exception as err: |
|
|
print(err) |
|
|
self.fail("LoadModel service failed") |
|
|
finally: |
|
|
self.tearDown() |
|
|
|
|
|
def test(self): |
|
|
""" |
|
|
This method tests if the backend can generate images |
|
|
""" |
|
|
time.sleep(20) |
|
|
try: |
|
|
self.setUp() |
|
|
with grpc.insecure_channel("localhost:50051") as channel: |
|
|
stub = backend_pb2_grpc.BackendStub(channel) |
|
|
response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8")) |
|
|
print(response.message) |
|
|
self.assertTrue(response.success) |
|
|
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg") |
|
|
re = stub.GenerateImage(image_req) |
|
|
self.assertTrue(re.success) |
|
|
except Exception as err: |
|
|
print(err) |
|
|
self.fail("Image gen service failed") |
|
|
finally: |
|
|
self.tearDown() |
|
|
|
|
|
|
|
|
class TestDiffusersDynamicLoader(unittest.TestCase): |
|
|
"""Test cases for the diffusers dynamic loader functionality.""" |
|
|
|
|
|
@classmethod |
|
|
def setUpClass(cls): |
|
|
"""Set up test fixtures - clear caches to ensure fresh discovery.""" |
|
|
|
|
|
loader._pipeline_registry = None |
|
|
loader._task_aliases = None |
|
|
|
|
|
def test_camel_to_kebab_conversion(self): |
|
|
"""Test CamelCase to kebab-case conversion.""" |
|
|
test_cases = [ |
|
|
("StableDiffusionPipeline", "stable-diffusion-pipeline"), |
|
|
("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"), |
|
|
("FluxPipeline", "flux-pipeline"), |
|
|
("DiffusionPipeline", "diffusion-pipeline"), |
|
|
] |
|
|
for input_val, expected in test_cases: |
|
|
with self.subTest(input=input_val): |
|
|
result = loader._camel_to_kebab(input_val) |
|
|
self.assertEqual(result, expected) |
|
|
|
|
|
def test_extract_task_keywords(self): |
|
|
"""Test task keyword extraction from class names.""" |
|
|
|
|
|
aliases = loader._extract_task_keywords("StableDiffusionPipeline") |
|
|
self.assertIn("stable-diffusion", aliases) |
|
|
|
|
|
|
|
|
aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline") |
|
|
self.assertIn("image-to-image", aliases) |
|
|
self.assertIn("img2img", aliases) |
|
|
|
|
|
|
|
|
aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline") |
|
|
self.assertIn("inpainting", aliases) |
|
|
self.assertIn("inpaint", aliases) |
|
|
|
|
|
|
|
|
aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline") |
|
|
self.assertIn("depth-to-image", aliases) |
|
|
|
|
|
def test_discover_pipelines_finds_known_classes(self): |
|
|
"""Test that pipeline discovery finds at least one known pipeline class.""" |
|
|
registry = loader.get_pipeline_registry() |
|
|
|
|
|
|
|
|
self.assertGreater(len(registry), 0, "Pipeline registry should not be empty") |
|
|
|
|
|
|
|
|
known_pipelines = [ |
|
|
"StableDiffusionPipeline", |
|
|
"DiffusionPipeline", |
|
|
] |
|
|
|
|
|
for pipeline_name in known_pipelines: |
|
|
with self.subTest(pipeline=pipeline_name): |
|
|
self.assertIn( |
|
|
pipeline_name, |
|
|
registry, |
|
|
f"Expected to find {pipeline_name} in registry" |
|
|
) |
|
|
|
|
|
def test_discover_pipelines_caches_results(self): |
|
|
"""Test that pipeline discovery results are cached.""" |
|
|
|
|
|
registry1 = loader.get_pipeline_registry() |
|
|
registry2 = loader.get_pipeline_registry() |
|
|
|
|
|
|
|
|
self.assertIs(registry1, registry2, "Registry should be cached") |
|
|
|
|
|
def test_get_available_pipelines(self): |
|
|
"""Test getting list of available pipelines.""" |
|
|
available = loader.get_available_pipelines() |
|
|
|
|
|
|
|
|
self.assertIsInstance(available, list) |
|
|
|
|
|
|
|
|
self.assertIn("StableDiffusionPipeline", available) |
|
|
self.assertIn("DiffusionPipeline", available) |
|
|
|
|
|
|
|
|
self.assertEqual(available, sorted(available)) |
|
|
|
|
|
def test_get_available_tasks(self): |
|
|
"""Test getting list of available task aliases.""" |
|
|
tasks = loader.get_available_tasks() |
|
|
|
|
|
|
|
|
self.assertIsInstance(tasks, list) |
|
|
|
|
|
|
|
|
self.assertEqual(tasks, sorted(tasks)) |
|
|
|
|
|
def test_resolve_pipeline_class_by_name(self): |
|
|
"""Test resolving pipeline class by exact name.""" |
|
|
cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") |
|
|
self.assertEqual(cls, StableDiffusionPipeline) |
|
|
|
|
|
def test_resolve_pipeline_class_by_name_case_insensitive(self): |
|
|
"""Test that class name resolution is case-insensitive.""" |
|
|
cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") |
|
|
cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline") |
|
|
self.assertEqual(cls1, cls2) |
|
|
|
|
|
def test_resolve_pipeline_class_by_task(self): |
|
|
"""Test resolving pipeline class by task alias.""" |
|
|
|
|
|
aliases = loader.get_task_aliases() |
|
|
|
|
|
|
|
|
if "stable-diffusion" in aliases: |
|
|
cls = loader.resolve_pipeline_class(task="stable-diffusion") |
|
|
self.assertIsNotNone(cls) |
|
|
|
|
|
def test_resolve_pipeline_class_unknown_name_raises(self): |
|
|
"""Test that resolving unknown class name raises ValueError with helpful message.""" |
|
|
with self.assertRaises(ValueError) as ctx: |
|
|
loader.resolve_pipeline_class(class_name="NonExistentPipeline") |
|
|
|
|
|
|
|
|
error_msg = str(ctx.exception) |
|
|
self.assertIn("Unknown pipeline class", error_msg) |
|
|
self.assertIn("Available pipelines", error_msg) |
|
|
|
|
|
def test_resolve_pipeline_class_unknown_task_raises(self): |
|
|
"""Test that resolving unknown task raises ValueError with helpful message.""" |
|
|
with self.assertRaises(ValueError) as ctx: |
|
|
loader.resolve_pipeline_class(task="nonexistent-task-xyz") |
|
|
|
|
|
|
|
|
error_msg = str(ctx.exception) |
|
|
self.assertIn("Unknown task", error_msg) |
|
|
self.assertIn("Available tasks", error_msg) |
|
|
|
|
|
def test_resolve_pipeline_class_no_params_raises(self): |
|
|
"""Test that calling with no parameters raises helpful ValueError.""" |
|
|
with self.assertRaises(ValueError) as ctx: |
|
|
loader.resolve_pipeline_class() |
|
|
|
|
|
error_msg = str(ctx.exception) |
|
|
self.assertIn("Must provide at least one of", error_msg) |
|
|
|
|
|
def test_get_pipeline_info(self): |
|
|
"""Test getting pipeline information.""" |
|
|
info = loader.get_pipeline_info("StableDiffusionPipeline") |
|
|
|
|
|
self.assertEqual(info['name'], "StableDiffusionPipeline") |
|
|
self.assertIsInstance(info['aliases'], list) |
|
|
self.assertIsInstance(info['supports_single_file'], bool) |
|
|
|
|
|
def test_get_pipeline_info_unknown_raises(self): |
|
|
"""Test that getting info for unknown pipeline raises ValueError.""" |
|
|
with self.assertRaises(ValueError) as ctx: |
|
|
loader.get_pipeline_info("NonExistentPipeline") |
|
|
|
|
|
self.assertIn("Unknown pipeline", str(ctx.exception)) |
|
|
|
|
|
def test_discover_diffusers_classes_pipelines(self): |
|
|
"""Test generic class discovery for DiffusionPipeline.""" |
|
|
classes = loader.discover_diffusers_classes("DiffusionPipeline") |
|
|
|
|
|
|
|
|
self.assertIsInstance(classes, dict) |
|
|
|
|
|
|
|
|
self.assertIn("DiffusionPipeline", classes) |
|
|
self.assertIn("StableDiffusionPipeline", classes) |
|
|
|
|
|
def test_discover_diffusers_classes_caches_results(self): |
|
|
"""Test that class discovery results are cached.""" |
|
|
classes1 = loader.discover_diffusers_classes("DiffusionPipeline") |
|
|
classes2 = loader.discover_diffusers_classes("DiffusionPipeline") |
|
|
|
|
|
|
|
|
self.assertIs(classes1, classes2) |
|
|
|
|
|
def test_discover_diffusers_classes_exclude_base(self): |
|
|
"""Test discovering classes without base class.""" |
|
|
classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False) |
|
|
|
|
|
|
|
|
self.assertIn("StableDiffusionPipeline", classes) |
|
|
|
|
|
def test_get_available_classes(self): |
|
|
"""Test getting list of available classes for a base class.""" |
|
|
classes = loader.get_available_classes("DiffusionPipeline") |
|
|
|
|
|
|
|
|
self.assertIsInstance(classes, list) |
|
|
self.assertEqual(classes, sorted(classes)) |
|
|
|
|
|
|
|
|
self.assertIn("StableDiffusionPipeline", classes) |
|
|
|
|
|
|
|
|
class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase): |
|
|
"""Test cases using mocks to test edge cases.""" |
|
|
|
|
|
def test_load_pipeline_requires_model_id(self): |
|
|
"""Test that load_diffusers_pipeline requires model_id.""" |
|
|
with self.assertRaises(ValueError) as ctx: |
|
|
loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline") |
|
|
|
|
|
self.assertIn("model_id is required", str(ctx.exception)) |
|
|
|
|
|
def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self): |
|
|
"""Test that resolving with only model_id falls back to DiffusionPipeline.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model") |
|
|
self.assertEqual(cls, DiffusionPipeline) |
|
|
|