File size: 12,465 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
A test script to test the gRPC service and dynamic loader
"""
import unittest
import subprocess
import time
from unittest.mock import patch, MagicMock

# Import dynamic loader for testing (these don't need gRPC)
import diffusers_dynamic_loader as loader
from diffusers import DiffusionPipeline, StableDiffusionPipeline

# Try to import gRPC modules - may not be available during unit testing
try:
    import grpc
    import backend_pb2
    import backend_pb2_grpc
    GRPC_AVAILABLE = True
except ImportError:
    GRPC_AVAILABLE = False


@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest invokes setUp() before and tearDown() after every test method,
    so each test gets exactly one freshly spawned backend process. Test bodies
    must NOT call setUp()/tearDown() themselves: doing so spawns a second
    server on the same port and leaks the first one (its Popen handle is
    overwritten and never killed).
    """

    # Address the spawned backend listens on and the tests connect to.
    SERVER_ADDRESS = "localhost:50051"
    # Upper bound, in seconds, to wait for the backend to accept connections.
    STARTUP_TIMEOUT = 20

    def setUp(self):
        """
        Start the backend server in a subprocess before each test.
        """
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", self.SERVER_ADDRESS]
        )

    def tearDown(self) -> None:
        """
        Stop the backend server after each test and reap the process.
        """
        self.service.kill()
        self.service.wait()

    def _wait_for_server(self, channel) -> None:
        """
        Block until *channel* is connected to the backend.

        Returns as soon as the server is reachable instead of always sleeping
        a fixed amount; raises grpc.FutureTimeoutError after STARTUP_TIMEOUT
        seconds if it never comes up.
        """
        grpc.channel_ready_future(channel).result(timeout=self.STARTUP_TIMEOUT)

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                self._wait_for_server(channel)
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            self.fail(f"Server failed to start: {err}")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                self._wait_for_server(channel)
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            self.fail(f"LoadModel service failed: {err}")

    def test(self):
        """
        This method tests if the backend can generate images
        """
        try:
            with grpc.insecure_channel(self.SERVER_ADDRESS) as channel:
                self._wait_for_server(channel)
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
                print(response.message)
                self.assertTrue(response.success)
                # Tiny 16x16 output keeps generation cheap; the image is
                # written to test.jpg in the current working directory.
                image_req = backend_pb2.GenerateImageRequest(
                    positive_prompt="cat", width=16, height=16, dst="test.jpg"
                )
                result = stub.GenerateImage(image_req)
                self.assertTrue(result.success)
        except Exception as err:
            self.fail(f"Image gen service failed: {err}")


class TestDiffusersDynamicLoader(unittest.TestCase):
    """Test cases for the diffusers dynamic loader functionality."""

    @classmethod
    def setUpClass(cls):
        """Set up test fixtures - clear caches to ensure fresh discovery."""
        # Reset the loader's module-level caches so the first lookup in this
        # class performs a fresh discovery.
        # NOTE(review): only _pipeline_registry and _task_aliases are cleared;
        # whatever cache backs discover_diffusers_classes() is not reset here —
        # confirm whether it should be.
        loader._pipeline_registry = None
        loader._task_aliases = None

    def test_camel_to_kebab_conversion(self):
        """Test CamelCase to kebab-case conversion."""
        test_cases = [
            ("StableDiffusionPipeline", "stable-diffusion-pipeline"),
            ("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"),
            ("FluxPipeline", "flux-pipeline"),
            ("DiffusionPipeline", "diffusion-pipeline"),
        ]
        for input_val, expected in test_cases:
            with self.subTest(input=input_val):
                result = loader._camel_to_kebab(input_val)
                self.assertEqual(result, expected)

    def test_extract_task_keywords(self):
        """Test task keyword extraction from class names."""
        # Test text-to-image detection
        aliases = loader._extract_task_keywords("StableDiffusionPipeline")
        self.assertIn("stable-diffusion", aliases)

        # Test img2img detection
        aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline")
        self.assertIn("image-to-image", aliases)
        self.assertIn("img2img", aliases)

        # Test inpainting detection
        aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline")
        self.assertIn("inpainting", aliases)
        self.assertIn("inpaint", aliases)

        # Test depth2img detection
        aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline")
        self.assertIn("depth-to-image", aliases)

    def test_discover_pipelines_finds_known_classes(self):
        """Test that pipeline discovery finds at least one known pipeline class."""
        registry = loader.get_pipeline_registry()

        # Check that the registry is not empty
        self.assertGreater(len(registry), 0, "Pipeline registry should not be empty")

        # Check for known pipeline classes
        known_pipelines = [
            "StableDiffusionPipeline",
            "DiffusionPipeline",
        ]

        for pipeline_name in known_pipelines:
            with self.subTest(pipeline=pipeline_name):
                self.assertIn(
                    pipeline_name,
                    registry,
                    f"Expected to find {pipeline_name} in registry"
                )

    def test_discover_pipelines_caches_results(self):
        """Test that pipeline discovery results are cached."""
        # Get registry twice
        registry1 = loader.get_pipeline_registry()
        registry2 = loader.get_pipeline_registry()

        # Should be the same object (cached)
        self.assertIs(registry1, registry2, "Registry should be cached")

    def test_get_available_pipelines(self):
        """Test getting list of available pipelines."""
        available = loader.get_available_pipelines()

        # Should return a list
        self.assertIsInstance(available, list)

        # Should contain known pipelines
        self.assertIn("StableDiffusionPipeline", available)
        self.assertIn("DiffusionPipeline", available)

        # Should be sorted
        self.assertEqual(available, sorted(available))

    def test_get_available_tasks(self):
        """Test getting list of available task aliases."""
        tasks = loader.get_available_tasks()

        # Should return a list
        self.assertIsInstance(tasks, list)

        # Should be sorted
        self.assertEqual(tasks, sorted(tasks))

    def test_resolve_pipeline_class_by_name(self):
        """Test resolving pipeline class by exact name."""
        pipeline_cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
        self.assertEqual(pipeline_cls, StableDiffusionPipeline)

    def test_resolve_pipeline_class_by_name_case_insensitive(self):
        """Test that class name resolution is case-insensitive."""
        exact = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
        lowered = loader.resolve_pipeline_class(class_name="stablediffusionpipeline")
        self.assertEqual(exact, lowered)

    def test_resolve_pipeline_class_by_task(self):
        """Test resolving pipeline class by task alias."""
        aliases = loader.get_task_aliases()

        # Report a skip instead of silently passing when the alias is absent,
        # so a discovery regression shows up in the test output.
        if "stable-diffusion" not in aliases:
            self.skipTest("'stable-diffusion' task alias was not discovered")

        pipeline_cls = loader.resolve_pipeline_class(task="stable-diffusion")
        self.assertIsNotNone(pipeline_cls)

    def test_resolve_pipeline_class_unknown_name_raises(self):
        """Test that resolving unknown class name raises ValueError with helpful message."""
        with self.assertRaises(ValueError) as ctx:
            loader.resolve_pipeline_class(class_name="NonExistentPipeline")

        # Check that error message includes available pipelines
        error_msg = str(ctx.exception)
        self.assertIn("Unknown pipeline class", error_msg)
        self.assertIn("Available pipelines", error_msg)

    def test_resolve_pipeline_class_unknown_task_raises(self):
        """Test that resolving unknown task raises ValueError with helpful message."""
        with self.assertRaises(ValueError) as ctx:
            loader.resolve_pipeline_class(task="nonexistent-task-xyz")

        # Check that error message includes available tasks
        error_msg = str(ctx.exception)
        self.assertIn("Unknown task", error_msg)
        self.assertIn("Available tasks", error_msg)

    def test_resolve_pipeline_class_no_params_raises(self):
        """Test that calling with no parameters raises helpful ValueError."""
        with self.assertRaises(ValueError) as ctx:
            loader.resolve_pipeline_class()

        error_msg = str(ctx.exception)
        self.assertIn("Must provide at least one of", error_msg)

    def test_get_pipeline_info(self):
        """Test getting pipeline information."""
        info = loader.get_pipeline_info("StableDiffusionPipeline")

        self.assertEqual(info['name'], "StableDiffusionPipeline")
        self.assertIsInstance(info['aliases'], list)
        self.assertIsInstance(info['supports_single_file'], bool)

    def test_get_pipeline_info_unknown_raises(self):
        """Test that getting info for unknown pipeline raises ValueError."""
        with self.assertRaises(ValueError) as ctx:
            loader.get_pipeline_info("NonExistentPipeline")

        self.assertIn("Unknown pipeline", str(ctx.exception))

    def test_discover_diffusers_classes_pipelines(self):
        """Test generic class discovery for DiffusionPipeline."""
        classes = loader.discover_diffusers_classes("DiffusionPipeline")

        # Should return a dict
        self.assertIsInstance(classes, dict)

        # Should contain known pipeline classes
        self.assertIn("DiffusionPipeline", classes)
        self.assertIn("StableDiffusionPipeline", classes)

    def test_discover_diffusers_classes_caches_results(self):
        """Test that class discovery results are cached."""
        classes1 = loader.discover_diffusers_classes("DiffusionPipeline")
        classes2 = loader.discover_diffusers_classes("DiffusionPipeline")

        # Should be the same object (cached)
        self.assertIs(classes1, classes2)

    def test_discover_diffusers_classes_exclude_base(self):
        """Test discovering classes without base class."""
        classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False)

        # The base class itself must be left out...
        self.assertNotIn("DiffusionPipeline", classes)
        # ...while its subclasses are still discovered.
        self.assertIn("StableDiffusionPipeline", classes)

    def test_get_available_classes(self):
        """Test getting list of available classes for a base class."""
        classes = loader.get_available_classes("DiffusionPipeline")

        # Should return a sorted list
        self.assertIsInstance(classes, list)
        self.assertEqual(classes, sorted(classes))

        # Should contain known classes
        self.assertIn("StableDiffusionPipeline", classes)


class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
    """Edge-case tests for the loader's argument validation and fallbacks.

    Despite the class name, nothing here is patched or mocked; the cases drive
    the real loader with missing or bogus arguments.
    """

    def test_load_pipeline_requires_model_id(self):
        """load_diffusers_pipeline must reject a call that omits model_id."""
        with self.assertRaises(ValueError) as raised:
            loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline")

        message = str(raised.exception)
        self.assertIn("model_id is required", message)

    def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self):
        """With only a model_id, resolution should fall back to DiffusionPipeline.

        The repository id below does not exist, so the hub lookup is not
        expected to yield a specific pipeline class. Whether it fails through
        the network or because the hub is unreachable, the resolver should
        end up on its DiffusionPipeline fallback.
        """
        bogus_repo = "some/nonexistent/model"
        resolved = loader.resolve_pipeline_class(model_id=bogus_repo)
        self.assertEqual(resolved, DiffusionPipeline)