import importlib
import unittest

# The extension directory name contains hyphens, so it cannot be named in a
# regular `import` statement; importlib accepts the dotted path as a string.
# The second argument is importlib's `package` anchor, which only matters for
# relative names (the name given here is absolute).
utils = importlib.import_module(
    'extensions.sd-webui-controlnet.tests.utils',
    'utils',
)

# Runs before the `scripts` import below; keep this order.
# NOTE(review): presumably this prepares the import path so that `scripts`
# resolves -- confirm against tests/utils.
utils.setup_test_env()

from scripts import external_code
class TestGetAllUnitsFrom(unittest.TestCase):
    """Exercises ``external_code.get_all_units_from`` with simple inputs."""

    def setUp(self):
        """Create the unit settings dict and a ``ControlNetUnit`` built from it.

        The dict is stored as ``self.control_unit`` and the object constructed
        from those same keyword arguments as ``self.object_unit``.
        """
        model = utils.get_model()
        image = utils.readImage("test/test_files/img2img_basic.png")
        balanced_mode = external_code.ControlMode.BALANCED.value

        unit_settings = {
            "module": "none",
            "model": model,
            "image": image,
            "resize_mode": 1,
            "low_vram": False,
            "processor_res": 64,
            "control_mode": balanced_mode,
        }
        self.control_unit = unit_settings
        self.object_unit = external_code.ControlNetUnit(**unit_settings)

    def test_empty_converts(self):
        """An empty argument list produces an empty list of units."""
        self.assertListEqual(external_code.get_all_units_from([]), [])

    def test_object_forwards(self):
        """A ``ControlNetUnit`` in the arguments comes back as an equal unit."""
        unit = self.object_unit
        result = external_code.get_all_units_from([unit])
        self.assertListEqual(result, [unit])
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()